From a22ad3ec372278f9b98df73923c8658b7bd1e22c Mon Sep 17 00:00:00 2001 From: grumbulon Date: Thu, 19 Jan 2023 23:56:38 -0500 Subject: [PATCH] added a handler function for common stuff like 500 error, and changed the auth failure handler. Added error handling in a few places. Unexported setDBMiddleware handler --- internal/api/api.go | 28 +++++++++++++++++++--------- internal/api/auth.go | 23 ++++++++++++++--------- internal/api/jwt.go | 11 +++++------ internal/api/users.go | 19 ++++++++++++------- internal/api/zone.go | 16 ++++++++-------- 5 files changed, 58 insertions(+), 39 deletions(-) diff --git a/internal/api/api.go b/internal/api/api.go index 3596af9..17513dd 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -17,8 +17,8 @@ const ( keyPrincipalContextID key = iota ) -// SetDBMiddleware is the http Handler func for the GORM middleware with context. -func SetDBMiddleware(next http.Handler) http.Handler { +// setDBMiddleware is the http Handler func for the GORM middleware with context. +func setDBMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { db := db.InitDb() timeoutContext, cancelContext := context.WithTimeout(context.Background(), time.Second) @@ -28,11 +28,21 @@ func SetDBMiddleware(next http.Handler) http.Handler { }) } -func basicAuthFailed(w http.ResponseWriter, realm string) { - w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm)) +// handlers for very common errors. +func authFailed(w http.ResponseWriter, realm string) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Realm="%s"`, realm)) w.WriteHeader(http.StatusUnauthorized) } +func internalServerError(w http.ResponseWriter, errMsg string) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Add("Internal Server Error", errMsg) + w.WriteHeader(http.StatusInternalServerError) +} + // API subroute handler. func API() http.Handler { api := chi.NewRouter() @@ -42,15 +52,15 @@ func API() http.Handler { api.Use(jwtauth.Verifier(tokenAuth)) api.Use(jwtauth.Authenticator) - api.With(SetDBMiddleware).Post("/upload", ReceiveFile) - api.With(SetDBMiddleware).Post("/parse", ZoneFiles) + api.With(setDBMiddleware).Post("/upload", ReceiveFile) + api.With(setDBMiddleware).Post("/parse", ZoneFiles) }) // Open routes api.Group(func(api chi.Router) { - api.Use(SetDBMiddleware) - api.With(SetDBMiddleware).Post("/create", NewUser) - api.With(SetDBMiddleware).Post("/login", Login) + api.Use(setDBMiddleware) + api.With(setDBMiddleware).Post("/create", NewUser) + api.With(setDBMiddleware).Post("/login", Login) api.Post("/logout", Logout) }) diff --git a/internal/api/auth.go b/internal/api/auth.go index 40bcc98..4f70b97 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -22,7 +22,7 @@ func Login(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() if err != nil { - http.Error(w, "Unable to parse request", http.StatusInternalServerError) + internalServerError(w, "unable to parse request") return } @@ -32,20 +32,20 @@ func Login(w http.ResponseWriter, r *http.Request) { password := r.Form.Get("password") if username == "" { - http.Error(w, "No username provided", http.StatusInternalServerError) // this should prob be handled by the frontend + internalServerError(w, "no username provided") // this should prob be handled by the frontend return } if password == "" { - http.Error(w, "No password provided", http.StatusInternalServerError) // this should prob be handled by the frontend + internalServerError(w, "no password provided") // this should prob be handled by the frontend return } db, ok := r.Context().Value(keyPrincipalContextID).(*gorm.DB) if !ok { - http.Error(w, "internal server error", http.StatusInternalServerError) + internalServerError(w, "DB connection failed") return } @@ -53,7 +53,7 @@ func Login(w http.ResponseWriter, r *http.Request) { db.Where("username = ?", username).First(&result) if result.Username == "" { - http.Error(w, "login failed", http.StatusUnauthorized) + authFailed(w, "login") return } @@ -61,12 +61,17 @@ func Login(w http.ResponseWriter, r *http.Request) { err = bcrypt.CompareHashAndPassword([]byte(result.HashedPassword), []byte(password)) if err != nil { - basicAuthFailed(w, "user") + authFailed(w, "login") return } - token := makeToken(username) + token, err := makeToken(username) + if err != nil { + internalServerError(w, err.Error()) + + return + } http.SetCookie(w, &http.Cookie{ HttpOnly: true, @@ -86,7 +91,7 @@ func Login(w http.ResponseWriter, r *http.Request) { }) if err != nil { - http.Error(w, "internal server error", http.StatusInternalServerError) + internalServerError(w, "internal server error") return } @@ -113,7 +118,7 @@ func Logout(w http.ResponseWriter, r *http.Request) { HTTPResponse: 200, }) if err != nil { - http.Error(w, "internal server error", http.StatusInternalServerError) + internalServerError(w, "internal server error") return } diff --git a/internal/api/jwt.go b/internal/api/jwt.go index fbe900a..abb5d11 100644 --- a/internal/api/jwt.go +++ b/internal/api/jwt.go @@ -1,7 +1,7 @@ package api import ( - "log" + "fmt" "git.freecumextremist.com/grumbulon/pomme/internal" "github.com/go-chi/jwtauth/v5" @@ -15,11 +15,10 @@ func init() { } } -func makeToken(username string) string { - _, tokenString, err := tokenAuth.Encode(map[string]interface{}{"username": username, "admin": "false"}) - if err != nil { - log.Fatalln(err) +func makeToken(username string) (tokenString string, err error) { + if _, tokenString, err = tokenAuth.Encode(map[string]interface{}{"username": username, "admin": "false"}); err == nil { + return } - return tokenString + return "", fmt.Errorf("unable to generate JWT: %w", err) } diff --git a/internal/api/users.go b/internal/api/users.go index 4281808..f91ec5e 100644 --- a/internal/api/users.go +++ b/internal/api/users.go @@ -17,14 +17,14 @@ import ( func NewUser(w http.ResponseWriter, r *http.Request) { db, ok := r.Context().Value(keyPrincipalContextID).(*gorm.DB) if !ok { - http.Error(w, "internal server error", http.StatusInternalServerError) + internalServerError(w, "internal server error") } var result internal.User err := r.ParseForm() if err != nil { - http.Error(w, "Unable to parse request", http.StatusInternalServerError) + internalServerError(w, "unable to parse request") return } @@ -38,7 +38,7 @@ func NewUser(w http.ResponseWriter, r *http.Request) { password := r.Form.Get("password") if password == "" { - http.Error(w, "No password entered", http.StatusInternalServerError) + internalServerError(w, "no password provided") return } @@ -46,21 +46,26 @@ func NewUser(w http.ResponseWriter, r *http.Request) { db.Where("username = ?", username).First(&result) if result.Username != "" { - http.Error(w, "User already exists", http.StatusInternalServerError) + internalServerError(w, "user already exists") return } hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + authFailed(w, "login") return } db.Create(&internal.User{Username: username, HashedPassword: string(hashedPassword)}) - token := makeToken(username) + token, err := makeToken(username) + if err != nil { + internalServerError(w, "internal server error") + + return + } http.SetCookie(w, &http.Cookie{ HttpOnly: true, @@ -83,7 +88,7 @@ func NewUser(w http.ResponseWriter, r *http.Request) { }) if err != nil { - http.Error(w, "internal server error", http.StatusInternalServerError) + internalServerError(w, "internal server error") return } diff --git a/internal/api/zone.go b/internal/api/zone.go index 0a65e67..2b8c8c5 100644 --- a/internal/api/zone.go +++ b/internal/api/zone.go @@ -50,20 +50,20 @@ func ReceiveFile(w http.ResponseWriter, r *http.Request) { name := strings.Split(header.Filename, ".") if _, err = io.Copy(&buf, file); err != nil { - http.Error(w, "internal server error", http.StatusInternalServerError) + internalServerError(w, "internal server error") return } if err = util.MakeLocal(name[0], claims["username"].(string), buf); err != nil { - http.Error(w, "internal server error", http.StatusInternalServerError) + internalServerError(w, "internal server error") return } db, ok := r.Context().Value(keyPrincipalContextID).(*gorm.DB) if !ok { - http.Error(w, "internal server error", http.StatusInternalServerError) + internalServerError(w, "internal server error") return } @@ -87,7 +87,7 @@ func ZoneFiles(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() if err != nil { - http.Error(w, "Unable to parse request", http.StatusInternalServerError) + internalServerError(w, "unable to parse request") return } @@ -95,14 +95,14 @@ func ZoneFiles(w http.ResponseWriter, r *http.Request) { filename := r.Form.Get("filename") if filename == "" { - http.Error(w, "No filename parsed", http.StatusInternalServerError) + internalServerError(w, "no filename parsed") return } db, ok := r.Context().Value(keyPrincipalContextID).(*gorm.DB) if !ok { - http.Error(w, "internal server error", http.StatusInternalServerError) + internalServerError(w, "internal server error") return } @@ -115,7 +115,7 @@ func ZoneFiles(w http.ResponseWriter, r *http.Request) { }).First(&result) if result == (internal.ZoneRequest{}) { - http.Error(w, "Internal server error", http.StatusInternalServerError) + internalServerError(w, "internal server error") return } @@ -123,7 +123,7 @@ func ZoneFiles(w http.ResponseWriter, r *http.Request) { zoneFile := newZoneRequest(result.RawFileName, claims["username"].(string)) if err := zoneFile.Parse(); err != nil { - http.Error(w, "Unable to parse zonefile", http.StatusInternalServerError) + internalServerError(w, fmt.Sprintf("unable to parse zonefile: %v", err)) return }