From 5e8ba819bc0dd0f4c4d938e7c48685b04bb7d323 Mon Sep 17 00:00:00 2001 From: grumbulon Date: Mon, 30 Jan 2023 19:49:52 -0500 Subject: [PATCH] added read DB path from config, add DB paths to sample config, added render library for JSON responses, removed plaintext password from User struct, made error handler funcs return json and their calls to include http.Request, and made API tests use httptest server --- .gitignore | 1 + config.sample.yaml | 2 + go.mod | 2 + go.sum | 4 ++ internal/api/api.go | 43 ++++++++++++++++---- internal/api/api_test.go | 82 +++++++++++++++++++++++++++------------ internal/api/auth.go | 57 +++++++++++---------------- internal/api/users.go | 16 ++++---- internal/api/zone.go | 20 +++++----- internal/configuration.go | 3 +- internal/db/db.go | 4 +- 11 files changed, 146 insertions(+), 88 deletions(-) diff --git a/.gitignore b/.gitignore index d46b0b2..6c5e775 100644 --- a/.gitignore +++ b/.gitignore @@ -24,5 +24,6 @@ go.work pomme test.db test.sqlite +pomme.sqlite .dccache config.yaml \ No newline at end of file diff --git a/config.sample.yaml b/config.sample.yaml index eef7035..94d008c 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -1,3 +1,5 @@ server: example.com # does nothing yet hashingsecret: PasswordHashingSecret port: 3008 # port the server runs on +db: pomme.sqlite +testdb: test.sqlite \ No newline at end of file diff --git a/go.mod b/go.mod index a6b3308..2fe0bfb 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( require ( github.com/KyleBanks/depth v1.2.1 // indirect + github.com/ajg/form v1.5.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-openapi/jsonpointer v0.19.5 // indirect @@ -35,6 +36,7 @@ require ( github.com/decred/dcrd/dcrec/secp256k1/v4 v4.1.0 // indirect github.com/glebarez/go-sqlite v1.20.0 // indirect github.com/go-chi/httplog v0.2.5 + github.com/go-chi/render v1.0.2 github.com/goccy/go-json v0.10.0 // indirect github.com/google/uuid v1.3.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect diff --git a/go.sum b/go.sum index b02d0dc..4017746 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE= github.com/adrg/xdg v0.4.0 h1:RzRqFcjH4nE5C6oTAxhBtoE2IRyjBSa62SCbyPidvls= github.com/adrg/xdg v0.4.0/go.mod h1:N6ag73EX4wyxeaoeHctc1mas01KZgsj5tYiAIwqJE/E= +github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= +github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY= @@ -29,6 +31,8 @@ github.com/go-chi/httprate v0.7.1 h1:d5kXARdms2PREQfU4pHvq44S6hJ1hPu4OXLeBKmCKWs github.com/go-chi/httprate v0.7.1/go.mod h1:6GOYBSwnpra4CQfAKXu8sQZg+nZ0M1g9QnyFvxrAB8A= github.com/go-chi/jwtauth/v5 v5.1.0 h1:wJyf2YZ/ohPvNJBwPOzZaQbyzwgMZZceE1m8FOzXLeA= github.com/go-chi/jwtauth/v5 v5.1.0/go.mod h1:MA93hc1au3tAQwCKry+fI4LqJ5MIVN4XSsglOo+lSc8= +github.com/go-chi/render v1.0.2 h1:4ER/udB0+fMWB2Jlf15RV3F4A2FDuYi/9f+lFttR/Lg= +github.com/go-chi/render v1.0.2/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= diff --git a/internal/api/api.go b/internal/api/api.go index d1068f6..b8f3f73 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log" "net/http" "time" @@ -13,6 +14,8 @@ import ( "github.com/go-chi/httplog" "github.com/go-chi/httprate" "github.com/go-chi/jwtauth/v5" + "github.com/go-chi/render" + "gorm.io/gorm" ) type key int @@ -24,32 +27,56 @@ const ( // setDBMiddleware is the http Handler func for the GORM middleware with context. func setDBMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - db := db.InitDb() + var pommeDB *gorm.DB + c, err := internal.ReadConfig() + if err != nil { + log.Printf("No config file defined: %v", err) + } + + switch r.Header.Get("User-Agent") { + case "pomme-api-test-slave": + pommeDB = db.InitDb(c.TestDB) + default: + pommeDB = db.InitDb(c.DB) + } + timeoutContext, cancelContext := context.WithTimeout(context.Background(), time.Second) - ctx := context.WithValue(r.Context(), keyPrincipalContextID, db.WithContext(timeoutContext)) + ctx := context.WithValue(r.Context(), keyPrincipalContextID, pommeDB.WithContext(timeoutContext)) defer cancelContext() next.ServeHTTP(w, r.WithContext(ctx)) }) } // handlers for very common errors. -func authFailed(w http.ResponseWriter, realm string) { +func authFailed(w http.ResponseWriter, r *http.Request, realm string) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Realm="%s"`, realm)) - w.WriteHeader(http.StatusUnauthorized) + + resp := internal.Response{ + Message: fmt.Sprintf(`Login failed -- Realm="%s"`, realm), + HTTPResponse: http.StatusUnauthorized, + } + + render.JSON(w, r, resp) } -func internalServerError(w http.ResponseWriter, errMsg string) { +func internalServerError(w http.ResponseWriter, r *http.Request, errMsg string) { logger := httplog.NewLogger("Pomme", httplog.Options{ JSON: true, }) - w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Add("Internal Server Error", errMsg) w.WriteHeader(http.StatusInternalServerError) + resp := internal.Response{ + Message: errMsg, + HTTPResponse: http.StatusInternalServerError, + } + + render.JSON(w, r, resp) + logger.Error().Msg(errMsg) } @@ -72,7 +99,7 @@ func API() http.Handler { Message: "API rate limit exceded", }) if err != nil { - internalServerError(w, "internal server error") + internalServerError(w, r, "internal server error") return } @@ -100,7 +127,7 @@ func API() http.Handler { Message: "API rate limit exceded", }) if err != nil { - internalServerError(w, "internal server error") + internalServerError(w, r, "internal server error") return } diff --git a/internal/api/api_test.go b/internal/api/api_test.go index fd0bf2c..44aafcb 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -3,15 +3,18 @@ package api import ( "encoding/json" "io" - "log" "net/http" + "net/http/httptest" "net/url" "strings" "testing" + "time" "git.freecumextremist.com/grumbulon/pomme/internal" "git.freecumextremist.com/grumbulon/pomme/internal/db" + "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" ) type response struct { @@ -23,18 +26,59 @@ type response struct { type accountTest struct { username string password string + url string } -func TestInit(t *testing.T) { - tester := accountTest{ - username: autoUname(), - password: "merde", +func TestAPI(t *testing.T) { + config, err := internal.ReadConfig() + if err != nil { + panic(err) } + pomme := chi.NewRouter() + + pomme.Mount("/api", API()) + + s := &http.Server{ + ReadTimeout: 3 * time.Second, + WriteTimeout: 15 * time.Second, + Addr: ":" + config.Port, + Handler: pomme, + } + + ts := httptest.NewUnstartedServer(pomme) + ts.Config = s + ts.Start() + + defer ts.Close() + + tester := Init(ts.URL) + + c, err := internal.ReadConfig() + if err != nil { + assert.NotNil(t, err) + } + + db := db.InitDb(c.TestDB) + + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(tester.password), bcrypt.DefaultCost) + if err != nil { + return + } + + db.Create(&internal.User{Username: tester.username, HashedPassword: string(hashedPassword)}) + tester.TestMakeAccount(t) tester.TestLogin(t) tester.TestLogout(t) - tester.CleanUpDb() +} + +func Init(url string) accountTest { + return accountTest{ + username: autoUname(), + password: "merde", + url: url, + } } func (a *accountTest) TestMakeAccount(t *testing.T) { @@ -48,9 +92,11 @@ func (a *accountTest) TestMakeAccount(t *testing.T) { form.Add("password", a.password) - if req, err := http.NewRequest(http.MethodPost, "http://localhost:3010/api/create", strings.NewReader(form.Encode())); err == nil { + if req, err := http.NewRequest(http.MethodPost, a.url+`api/create`, strings.NewReader(form.Encode())); err == nil { req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("User-Agent", "pomme-api-test-slave") + resp, err := client.Do(req) if err != nil { assert.NotNil(t, err) @@ -63,7 +109,6 @@ func (a *accountTest) TestMakeAccount(t *testing.T) { assert.NotNil(t, err) } - log.Println(target) assert.Equal(t, http.StatusCreated, target.Status) } } @@ -79,8 +124,9 @@ func (a *accountTest) TestLogin(t *testing.T) { form.Add("password", a.password) - if req, err := http.NewRequest(http.MethodPost, "http://localhost:3010/api/login", strings.NewReader(form.Encode())); err == nil { + if req, err := http.NewRequest(http.MethodPost, a.url+`/api/login`, strings.NewReader(form.Encode())); err == nil { req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("User-Agent", "pomme-api-test-slave") resp, err := client.Do(req) if err != nil { @@ -94,7 +140,6 @@ func (a *accountTest) TestLogin(t *testing.T) { assert.NotNil(t, err) } - log.Println(target) assert.Equal(t, http.StatusOK, target.Status) } } @@ -108,8 +153,9 @@ func (a *accountTest) TestLogout(t *testing.T) { form.Add("username", a.username) - if req, err := http.NewRequest(http.MethodPost, "http://localhost:3010/api/logout", strings.NewReader(form.Encode())); err == nil { + if req, err := http.NewRequest(http.MethodPost, a.url+`/api/logout`, strings.NewReader(form.Encode())); err == nil { req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("User-Agent", "pomme-api-test-slave") resp, err := client.Do(req) if err != nil { @@ -123,20 +169,6 @@ func (a *accountTest) TestLogout(t *testing.T) { assert.NotNil(t, err) } - log.Println(target) assert.Equal(t, http.StatusOK, target.Status) } } - -// currently does not work. -func (a *accountTest) CleanUpDb() { - var user internal.User - - db := db.InitDb() - - db.Where("username = ?", a.username).First(&user) - - if user.Username != "" { - db.Delete(&user, user.ID) - } -} diff --git a/internal/api/auth.go b/internal/api/auth.go index 0b863f8..ee6fce1 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -1,11 +1,11 @@ package api import ( - "encoding/json" "net/http" "time" "git.freecumextremist.com/grumbulon/pomme/internal" + "github.com/go-chi/render" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" ) @@ -48,14 +48,20 @@ func Login(w http.ResponseWriter, r *http.Request) { var result internal.User if _, err := r.Cookie("jwt"); err == nil { - http.Error(w, "Logged in", http.StatusCreated) + w.Header().Set("Content-Type", "application/json") + + resp := internal.Response{ + Message: "Already logged in", + HTTPResponse: http.StatusOK, + } + render.JSON(w, r, resp) return } err := r.ParseForm() if err != nil { - internalServerError(w, "unable to parse request") + internalServerError(w, r, "unable to parse request") return } @@ -65,20 +71,20 @@ func Login(w http.ResponseWriter, r *http.Request) { password := r.Form.Get("password") if username == "" { - internalServerError(w, "no username provided") // this should prob be handled by the frontend + internalServerError(w, r, "no username provided") // this should prob be handled by the frontend return } if password == "" { - internalServerError(w, "no password provided") // this should prob be handled by the frontend + internalServerError(w, r, "no password provided") // this should prob be handled by the frontend return } db, ok := r.Context().Value(keyPrincipalContextID).(*gorm.DB) if !ok { - internalServerError(w, "DB connection failed") + internalServerError(w, r, "DB connection failed") return } @@ -86,7 +92,7 @@ func Login(w http.ResponseWriter, r *http.Request) { db.Where("username = ?", username).First(&result) if result.Username == "" { - authFailed(w, "login") + authFailed(w, r, "authentication") return } @@ -94,14 +100,14 @@ func Login(w http.ResponseWriter, r *http.Request) { err = bcrypt.CompareHashAndPassword([]byte(result.HashedPassword), []byte(password)) if err != nil { - authFailed(w, "login") + authFailed(w, r, "authentication") return } token, err := makeToken(username) if err != nil { - internalServerError(w, err.Error()) + internalServerError(w, r, err.Error()) return } @@ -116,20 +122,12 @@ func Login(w http.ResponseWriter, r *http.Request) { Name: "jwt", // Must be named "jwt" or else the token cannot be searched for by jwtauth.Verifier. Value: token, }) - w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode( - internal.Response{ - Message: "Successfully logged in", - HTTPResponse: http.StatusOK, - }) - if err != nil { - internalServerError(w, "internal server error") - - return + resp := internal.Response{ + Message: "Successfully logged in", + HTTPResponse: http.StatusOK, } - - http.Redirect(w, r, "/", http.StatusSeeOther) + render.JSON(w, r, resp) } // Logout destroys a users JWT cookie. @@ -143,18 +141,9 @@ func Logout(w http.ResponseWriter, r *http.Request) { Value: "", }) - w.Header().Set("Content-Type", "application/json") - - err := json.NewEncoder(w).Encode( - internal.Response{ - Message: "Successfully logged out", - HTTPResponse: http.StatusOK, - }) - if err != nil { - internalServerError(w, "internal server error") - - return + resp := internal.Response{ + Message: "Successfully logged out", + HTTPResponse: http.StatusOK, } - - http.Redirect(w, r, "/", http.StatusSeeOther) + render.JSON(w, r, resp) } diff --git a/internal/api/users.go b/internal/api/users.go index f91ec5e..fbbe9c2 100644 --- a/internal/api/users.go +++ b/internal/api/users.go @@ -17,14 +17,14 @@ import ( func NewUser(w http.ResponseWriter, r *http.Request) { db, ok := r.Context().Value(keyPrincipalContextID).(*gorm.DB) if !ok { - internalServerError(w, "internal server error") + internalServerError(w, r, "internal server error") } var result internal.User err := r.ParseForm() if err != nil { - internalServerError(w, "unable to parse request") + internalServerError(w, r, "unable to parse request") return } @@ -38,7 +38,7 @@ func NewUser(w http.ResponseWriter, r *http.Request) { password := r.Form.Get("password") if password == "" { - internalServerError(w, "no password provided") + internalServerError(w, r, "no password provided") return } @@ -46,14 +46,14 @@ func NewUser(w http.ResponseWriter, r *http.Request) { db.Where("username = ?", username).First(&result) if result.Username != "" { - internalServerError(w, "user already exists") + internalServerError(w, r, "user already exists") return } hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - authFailed(w, "login") + authFailed(w, r, "login") return } @@ -62,7 +62,7 @@ func NewUser(w http.ResponseWriter, r *http.Request) { token, err := makeToken(username) if err != nil { - internalServerError(w, "internal server error") + internalServerError(w, r, "internal server error") return } @@ -79,7 +79,7 @@ func NewUser(w http.ResponseWriter, r *http.Request) { }) w.WriteHeader(http.StatusCreated) - w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode( internal.Response{ Username: username, @@ -88,7 +88,7 @@ func NewUser(w http.ResponseWriter, r *http.Request) { }) if err != nil { - internalServerError(w, "internal server error") + internalServerError(w, r, "internal server error") return } diff --git a/internal/api/zone.go b/internal/api/zone.go index e1512d4..dee1f8a 100644 --- a/internal/api/zone.go +++ b/internal/api/zone.go @@ -60,7 +60,7 @@ func ReceiveFile(w http.ResponseWriter, r *http.Request) { file, header, err := r.FormFile("file") if err != nil { - internalServerError(w, fmt.Sprintf("file upload failed: %v", err)) + internalServerError(w, r, fmt.Sprintf("file upload failed: %v", err)) return } @@ -77,20 +77,20 @@ func ReceiveFile(w http.ResponseWriter, r *http.Request) { name := strings.Split(header.Filename, ".") if _, err = io.Copy(&buf, file); err != nil { - internalServerError(w, "internal server error") + internalServerError(w, r, "internal server error") return } if err = util.MakeLocal(name[0], claims["username"].(string), buf); err != nil { - internalServerError(w, err.Error()) + internalServerError(w, r, err.Error()) return } db, ok := r.Context().Value(keyPrincipalContextID).(*gorm.DB) if !ok { - internalServerError(w, "internal server error") + internalServerError(w, r, "internal server error") return } @@ -115,7 +115,7 @@ func ReceiveFile(w http.ResponseWriter, r *http.Request) { }) if err != nil { - internalServerError(w, "internal server error") + internalServerError(w, r, "internal server error") return } @@ -147,7 +147,7 @@ func ZoneFiles(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() if err != nil { - internalServerError(w, "unable to parse request") + internalServerError(w, r, "unable to parse request") return } @@ -155,14 +155,14 @@ func ZoneFiles(w http.ResponseWriter, r *http.Request) { filename := r.Form.Get("filename") if filename == "" { - internalServerError(w, "no filename parsed") + internalServerError(w, r, "no filename parsed") return } db, ok := r.Context().Value(keyPrincipalContextID).(*gorm.DB) if !ok { - internalServerError(w, "internal server error") + internalServerError(w, r, "internal server error") return } @@ -175,7 +175,7 @@ func ZoneFiles(w http.ResponseWriter, r *http.Request) { }).First(&result) if result == (internal.ZoneRequest{}) { - internalServerError(w, "internal server error") + internalServerError(w, r, "internal server error") return } @@ -183,7 +183,7 @@ func ZoneFiles(w http.ResponseWriter, r *http.Request) { zoneFile := newZoneRequest(result.RawFileName, claims["username"].(string)) if err := zoneFile.Parse(); err != nil { - internalServerError(w, fmt.Sprintf("unable to parse zonefile: %v", err)) + internalServerError(w, r, fmt.Sprintf("unable to parse zonefile: %v", err)) return } diff --git a/internal/configuration.go b/internal/configuration.go index 189bd67..834794a 100644 --- a/internal/configuration.go +++ b/internal/configuration.go @@ -14,7 +14,6 @@ import ( type User struct { gorm.Model Username string - Password string HashedPassword string } @@ -44,6 +43,8 @@ type Config struct { Server string HashingSecret string Port string + DB string + TestDB string } var config Config diff --git a/internal/db/db.go b/internal/db/db.go index 7a6aa5d..5209b55 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -7,8 +7,8 @@ import ( ) // InitDb is the init function for the database. -func InitDb() *gorm.DB { - db, err := gorm.Open(sqlite.Open("test.sqlite"), &gorm.Config{}) +func InitDb(path string) *gorm.DB { + db, err := gorm.Open(sqlite.Open(path), &gorm.Config{}) if err != nil { panic("failed to connect database") }