package api

import (
	"bytes"
	"fmt"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"dns.froth.zone/pomme/internal"
	"dns.froth.zone/pomme/internal/db"
	"github.com/go-chi/chi/v5"
	"github.com/go-chi/jwtauth/v5"
	"github.com/stretchr/testify/assert"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)

// accountTest bundles the credentials, JWT and server base URL shared by the
// route tests in this file.
type accountTest struct {
	username string
	password string
	token    string
	url      string
}

// makeTestToken encodes a JWT for username with the "admin" claim set to
// false and an expiry one minute from now.
//
// On failure it returns an empty token and the wrapped encoding error.
func makeTestToken(username string) (tokenString string, err error) {
	claim := map[string]any{"username": username, "admin": false}
	jwtauth.SetExpiry(claim, time.Now().Add(time.Minute))

	_, tokenString, err = tokenAuth.Encode(claim)
	if err != nil {
		return "", fmt.Errorf("unable to generate JWT: %w", err)
	}

	return tokenString, nil
}

// TestAPI mounts the API on an httptest server, creates a test user in the
// database and then runs the open, protected and rate-limit route tests
// against it.
func TestAPI(t *testing.T) {
	config, err := internal.ReadConfig()
	if err != nil {
		t.Fatalf("reading config: %v", err)
	}

	pomme := chi.NewRouter()
	pomme.Mount("/api", API())

	s := &http.Server{
		ReadTimeout:  3 * time.Second,
		WriteTimeout: 15 * time.Second,
		Addr:         ":" + config.Port,
		Handler:      pomme,
	}

	ts := httptest.NewUnstartedServer(pomme)
	ts.Config = s
	ts.Start()
	defer ts.Close()

	tester := Init(ts.URL)
	// Init returns a zero accountTest when the token cannot be generated.
	if tester.token == "" {
		t.Fatal("unable to create a JWT for the test account")
	}

	// test mode
	mode = "test"

	database, err, ok := db.InitDb(config.DB, mode)
	if err != nil && !ok {
		t.Fatalf("initialising test database: %v", err)
	}
	// Deferred so the test database is removed even if setup below aborts.
	defer tester.CleanUp(database)

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(tester.password), bcrypt.DefaultCost)
	if err != nil {
		t.Fatalf("hashing test password: %v", err)
	}

	user := internal.User{Username: tester.username, HashedPassword: string(hashedPassword)}
	if err := database.Create(&user).Error; err != nil {
		t.Fatalf("creating test user: %v", err)
	}

	tester.TestOpenRoutes(t)
	tester.TestProtectedRoutes(t)
	tester.TestRateLimit(t)
}

// Init builds the test account for the server at url: a generated username
// suffixed with "-testUser", a fixed password and a freshly signed JWT.
//
// If the token cannot be generated, the zero accountTest is returned.
func Init(url string) accountTest {
	user := autoUname() + "-testUser"

	token, err := makeTestToken(user)
	if err != nil {
		return accountTest{}
	}

	return accountTest{
		username: user,
		password: "merde",
		url:      url,
		token:    token,
	}
}

// expectedValues holds what a test case expects from the server.
type expectedValues struct {
	response int
}

// TestOpenRoutes posts url-encoded username/password forms to the login,
// create and logout routes and checks the returned status codes.
func (a *accountTest) TestOpenRoutes(t *testing.T) {
	testCases := []struct {
		name     string
		route    string
		password string
		username string
		expected expectedValues
	}{
		{
			name:     "Should fail to login with bad username",
			route:    "login",
			username: "a.username",
			password: a.password,
			expected: expectedValues{response: http.StatusUnauthorized},
		},
		{
			name:     "Should fail to login with bad password",
			route:    "login",
			username: a.username,
			password: "a.password",
			expected: expectedValues{response: http.StatusUnauthorized},
		},
		{
			name:     "Should fail to login with bad password or username",
			route:    "login",
			username: "apple",
			password: "a.password",
			expected: expectedValues{response: http.StatusUnauthorized},
		},
		{
			name:     "Should fail to login with empty form",
			route:    "login",
			expected: expectedValues{response: http.StatusInternalServerError},
		},
		{
			name:     "Should login successfully",
			route:    "login",
			username: a.username,
			password: a.password,
			expected: expectedValues{response: http.StatusOK},
		},
		{
			name:     "Should fail to create account with empty form",
			route:    "create",
			expected: expectedValues{response: http.StatusInternalServerError},
		},
		{
			name:     "Should create account with empty username",
			route:    "create",
			password: "asdf",
			expected: expectedValues{response: http.StatusCreated},
		},
		{
			// NOTE(review): this case sends the same empty form as
			// "Should fail to create account with empty form"; it probably
			// should set a username — confirm the intended input.
			name:     "Should fail to create account with empty password",
			route:    "create",
			expected: expectedValues{response: http.StatusInternalServerError},
		},
		{
			name:     "Should log out successfully",
			route:    "logout",
			username: a.username,
			expected: expectedValues{response: http.StatusOK},
		},
	}

	client := &http.Client{Timeout: 30 * time.Second}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			form := url.Values{}
			form.Add("username", tc.username)
			form.Add("password", tc.password)

			req, err := http.NewRequest(http.MethodPost, a.url+`/api/`+tc.route, strings.NewReader(form.Encode()))
			if err != nil {
				t.Fatalf("building request: %v", err)
			}
			req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

			resp, err := client.Do(req)
			if err != nil {
				t.Fatalf("sending request: %v", err)
			}
			defer resp.Body.Close() //nolint: errcheck

			assert.Equal(t, tc.expected.response, resp.StatusCode)
		})
	}
}

// TestRateLimit sends up to 100 empty POST requests to a route and requires
// that at least one of them is answered with the expected status code.
func (a *accountTest) TestRateLimit(t *testing.T) {
	const attempts = 100

	testCases := []struct {
		name     string
		route    string
		expected expectedValues
	}{
		{
			name:     "Should rate limit open routes",
			route:    "login",
			expected: expectedValues{response: http.StatusTooManyRequests},
		},
		{
			name:     "Should rate limit private routes",
			route:    "upload",
			expected: expectedValues{response: http.StatusTooManyRequests},
		},
	}

	client := &http.Client{Timeout: 30 * time.Second}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			form := url.Values{}
			limited := false

			// Stop as soon as the expected status shows up; the remaining
			// requests would add nothing to the check.
			for i := 0; i < attempts && !limited; i++ {
				req, err := http.NewRequest(http.MethodPost, a.url+`/api/`+tc.route, strings.NewReader(form.Encode()))
				if err != nil {
					t.Fatalf("building request: %v", err)
				}
				req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

				resp, err := client.Do(req)
				if err != nil {
					t.Fatalf("sending request: %v", err)
				}
				limited = resp.StatusCode == tc.expected.response
				// Closed per iteration rather than deferred so responses do
				// not pile up until the subtest returns.
				_ = resp.Body.Close()
			}

			assert.True(t, limited, "never received status %d in %d requests", tc.expected.response, attempts)
		})
	}
}

// TestProtectedRoutes uploads a temporary file as a multipart form to the
// upload route, authenticated with the account's JWT, and checks the returned
// status code.
func (a *accountTest) TestProtectedRoutes(t *testing.T) {
	testCases := []struct {
		name         string
		contentType  string
		route        string
		fileContents []byte
		expected     expectedValues
	}{
		{
			name:        "Should fail to upload an empty file",
			contentType: "audio/aac",
			route:       "upload",
			expected:    expectedValues{response: http.StatusInternalServerError},
		},
		{
			name:         "Should upload a valid file",
			contentType:  "multipart/form-data",
			fileContents: zonebytes,
			route:        "upload",
			expected:     expectedValues{response: http.StatusCreated},
		},
	}

	client := &http.Client{Timeout: 30 * time.Second}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			buf := new(bytes.Buffer)
			w := multipart.NewWriter(buf)

			f, err := os.CreateTemp(".", "zonefile")
			if err != nil {
				t.Fatalf("creating temp file: %v", err)
			}
			defer os.Remove(f.Name()) //nolint: errcheck

			// Only the path is needed from here on; release the handle.
			if err := f.Close(); err != nil {
				t.Fatalf("closing temp file: %v", err)
			}

			if len(tc.fileContents) > 0 {
				if err := os.WriteFile(f.Name(), tc.fileContents, 0o600); err != nil {
					t.Fatalf("writing temp file: %v", err)
				}
			}

			part, err := w.CreateFormFile("file", filepath.Base(f.Name()))
			if err != nil {
				t.Fatalf("creating form file: %v", err)
			}

			b, err := os.ReadFile(f.Name())
			if err != nil {
				t.Fatalf("reading temp file: %v", err)
			}

			if _, err := part.Write(b); err != nil {
				t.Fatalf("writing form file: %v", err)
			}

			if err := w.Close(); err != nil {
				t.Fatalf("closing multipart writer: %v", err)
			}

			req, err := http.NewRequest(http.MethodPost, a.url+`/api/`+tc.route, buf)
			if err != nil {
				t.Fatalf("building request: %v", err)
			}

			// A multipart case needs the writer's boundary in the header;
			// any other content type is sent exactly as the case gives it.
			contentType := tc.contentType
			if contentType == "multipart/form-data" {
				contentType = w.FormDataContentType()
			}
			req.Header.Add("Content-Type", contentType)
			req.Header.Add("Authorization", `Bearer:`+a.token)

			resp, err := client.Do(req)
			if err != nil {
				t.Fatalf("sending request: %v", err)
			}
			defer resp.Body.Close() //nolint: errcheck

			assert.Equal(t, tc.expected.response, resp.StatusCode)
		})
	}
}

// CleanUp closes the connection behind db, if there is one, and removes the
// "pomme-test.sqlite" file. A failed removal is written to the log instead of
// failing the test.
func (a *accountTest) CleanUp(db *gorm.DB) {
	if db != nil {
		// Best effort: the file is removed below whether or not this works.
		if sqlDB, err := db.DB(); err == nil {
			_ = sqlDB.Close()
		}
	}

	if err := os.Remove("pomme-test.sqlite"); err != nil {
		l := newResponder(Response[any]{
			Message: "unable to clean up test DB",
			Err:     err.Error(),
		})
		l.writeLogEntry()
	}
}