diff --git a/internal/api/api.go b/internal/api/api.go index a5fa1e9..86e3fe8 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -51,7 +51,7 @@ func API() http.Handler { render.JSON(w, r, resp) }), )) - api.Use(setDBMiddleware) + api.With(setDBMiddleware).Post("/create", NewUser) api.With(setDBMiddleware).Post("/login", Login) api.Post("/logout", Logout) diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 02b5dc1..5f87f59 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -21,6 +21,7 @@ import ( "github.com/go-chi/jwtauth/v5" "github.com/stretchr/testify/assert" "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" ) type response struct { @@ -37,7 +38,7 @@ type accountTest struct { } func makeTestToken(username string) (tokenString string, err error) { - claim := map[string]interface{}{"username": username, "admin": false} + claim := map[string]any{"username": username, "admin": false} jwtauth.SetExpiry(claim, time.Now().Add(time.Minute)) @@ -72,13 +73,10 @@ func TestAPI(t *testing.T) { defer ts.Close() tester := Init(ts.URL) + // test mode + mode = "test" + db, err, ok := db.InitDb(config.DB, mode) - c, err := internal.ReadConfig() - if err != nil { - assert.NotNil(t, err) - } - - db, err, ok := db.InitDb(c.TestDB) if err != nil && !ok { assert.NotNil(t, err) } @@ -94,11 +92,14 @@ func TestAPI(t *testing.T) { tester.TestLogin(t) tester.TestLogout(t) tester.TestUpload(t) + tester.CleanUp(db) } func Init(url string) accountTest { user := autoUname() + user += "-testUser" + token, err := makeTestToken(user) if err != nil { return accountTest{} @@ -128,8 +129,6 @@ func (a *accountTest) TestMakeAccount(t *testing.T) { if req, err := http.NewRequest(http.MethodPost, a.url+`api/create`, strings.NewReader(form.Encode())); err == nil { req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("User-Agent", "pomme-api-test-slave") - resp, err := client.Do(req) if err != nil { assert.NotNil(t, err) @@ -159,7 +158,6 @@ func (a *accountTest) TestLogin(t *testing.T) { if req, err := http.NewRequest(http.MethodPost, a.url+`/api/login`, strings.NewReader(form.Encode())); err == nil { req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("User-Agent", "pomme-api-test-slave") resp, err := client.Do(req) if err != nil { @@ -188,7 +186,6 @@ func (a *accountTest) TestLogout(t *testing.T) { if req, err := http.NewRequest(http.MethodPost, a.url+`/api/logout`, strings.NewReader(form.Encode())); err == nil { req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("User-Agent", "pomme-api-test-slave") resp, err := client.Do(req) if err != nil { @@ -297,7 +294,6 @@ func (a *accountTest) TestUpload(t *testing.T) { } req.Header.Add("Authorization", `Bearer:`+a.token) - req.Header.Add("User-Agent", "pomme-api-test-slave") resp, err := client.Do(req) if err != nil { @@ -316,3 +312,22 @@ func (a *accountTest) TestUpload(t *testing.T) { }) } } + +func (a *accountTest) CleanUp(db *gorm.DB) { + var ( + user internal.User + req internal.ZoneRequest + ) + + db.Where("username = ?", a.username).Delete(&user) + + db.Where("user = ?", a.username).Delete(&req) + + if err := os.Remove("pomme-test.sqlite"); err != nil { + l := newResponder(Response[any]{ + Message: "unable to clean up test DB", + Err: err.Error(), + }) + l.writeLogEntry() + } +} diff --git a/internal/api/auth.go b/internal/api/auth.go index 3f11f04..0a4015b 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -29,14 +29,12 @@ func Login(w http.ResponseWriter, r *http.Request) { var result internal.User if _, err := r.Cookie("jwt"); err == nil { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - - w.WriteHeader(http.StatusOK) - - resp := internal.Response{ - Message: "Already logged in", - } - render.JSON(w, r, resp) + logger := newResponder(Response[any]{ + Message: "already logged in", + Status: http.StatusOK, + }) + logger.apiError(w, r) + logger.writeLogEntry() return } diff --git a/internal/api/helpers.go b/internal/api/helpers.go index 45280f6..750a2f4 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -13,6 +13,8 @@ import ( "gorm.io/gorm" ) +var mode string + // setDBMiddleware is the http Handler func for the GORM middleware with context. func setDBMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -20,6 +22,7 @@ func setDBMiddleware(next http.Handler) http.Handler { pommeDB *gorm.DB ok bool ) + c, err := internal.ReadConfig() if err != nil { logger := newResponder(Response[any]{ @@ -29,12 +32,7 @@ func setDBMiddleware(next http.Handler) http.Handler { logger.writeLogEntry() } - switch r.Header.Get("User-Agent") { - case "pomme-api-test-slave": - pommeDB, err, ok = db.InitDb(c.TestDB) - default: - pommeDB, err, ok = db.InitDb(c.DB) - } + pommeDB, err, ok = db.InitDb(c.DB, mode) if err != nil && !ok { logger := newResponder(Response[any]{ diff --git a/internal/api/users.go b/internal/api/users.go index 44d644b..fb12155 100644 --- a/internal/api/users.go +++ b/internal/api/users.go @@ -15,6 +15,8 @@ import ( // NewUser takes a POST request and user form and creates a user in the database. func NewUser(w http.ResponseWriter, r *http.Request) { + var result internal.User + db, ok := r.Context().Value(keyPrincipalContextID).(*gorm.DB) if !ok { logger := newResponder(Response[any]{ @@ -26,8 +28,6 @@ func NewUser(w http.ResponseWriter, r *http.Request) { return } - var result internal.User - err := r.ParseForm() if err != nil { logger := newResponder(Response[any]{ @@ -55,6 +55,7 @@ func NewUser(w http.ResponseWriter, r *http.Request) { Status: http.StatusInternalServerError, }) logger.apiError(w, r) + logger.writeLogEntry() return } @@ -67,6 +68,7 @@ func NewUser(w http.ResponseWriter, r *http.Request) { Status: http.StatusInternalServerError, }) logger.apiError(w, r) + logger.writeLogEntry() return } diff --git a/internal/api/zone.go b/internal/api/zone.go index 907ca89..68906e7 100644 --- a/internal/api/zone.go +++ b/internal/api/zone.go @@ -37,6 +37,8 @@ import ( // // @Router /api/upload [post] func ReceiveFile(w http.ResponseWriter, r *http.Request) { + var result internal.User + _, claims, _ := jwtauth.FromContext(r.Context()) r.Body = http.MaxBytesReader(w, r.Body, 1*1024*1024) // approx 1 mb max upload @@ -108,6 +110,19 @@ func ReceiveFile(w http.ResponseWriter, r *http.Request) { return } + // check if request is coming from user not in the DB but has a valid JWT + db.Where("username = ?", claims["username"].(string)).First(&result) + + if result.Username == "" { + logger := newResponder(Response[any]{ + Message: "user does not exist", + Status: http.StatusInternalServerError, + }) + logger.apiError(w, r) + + return + } + db.Create( &ZoneRequest{ User: claims["username"].(string), diff --git a/internal/db/db.go b/internal/db/db.go index b4aba7a..7a68dd2 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -9,8 +9,13 @@ import ( ) // InitDb is the init function for the database. -func InitDb(path string) (db *gorm.DB, err error, ok bool) { +func InitDb(path, mode string) (db *gorm.DB, err error, ok bool) { ok = true + + if mode == "test" { + path = "pomme-test.sqlite" + } + db, err = gorm.Open(sqlite.Open(path), &gorm.Config{}) if err != nil {