// Source: pomme/internal/api/api_test.go (Go, 360 lines, 7.5 KiB)

package api
import (
"bytes"
"fmt"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"time"
"dns.froth.zone/pomme/internal"
"dns.froth.zone/pomme/internal/db"
"github.com/go-chi/chi/v5"
"github.com/go-chi/jwtauth/v5"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// accountTest bundles the credentials, bearer token and base URL that the
// API test suites use to talk to the server under test.
type accountTest struct {
	// username and password are the login credentials of the test account.
	username, password string
	// token is the JWT sent on authenticated requests.
	token string
	// url is the base URL of the test server (no trailing "/api").
	url string
}
// makeTestToken signs a JWT for username with the package's tokenAuth. The
// claims carry the username, admin=false and an expiry one minute from now.
// On failure it returns an empty token and a wrapped error.
func makeTestToken(username string) (tokenString string, err error) {
	claims := map[string]any{
		"username": username,
		"admin":    false,
	}
	jwtauth.SetExpiry(claims, time.Now().Add(time.Minute))

	_, signed, encodeErr := tokenAuth.Encode(claims)
	if encodeErr != nil {
		return "", fmt.Errorf("unable to generate JWT: %w", encodeErr)
	}
	return signed, nil
}
// TestAPI is the entry point of the API test suite. It mounts the API router
// on an httptest server, switches the package into test mode, creates a
// throwaway user in the test database and then runs the open-route,
// protected-route and rate-limit suites against the server.
//
// Any failure during setup aborts the test immediately: continuing with a
// missing config, database handle or test account would only produce nil
// dereferences further down.
func TestAPI(t *testing.T) {
	config, err := internal.ReadConfig()
	if err != nil {
		t.Fatalf("reading config: %v", err)
	}

	pomme := chi.NewRouter()
	pomme.Mount("/api", API())

	s := &http.Server{
		ReadTimeout:  3 * time.Second,
		WriteTimeout: 15 * time.Second,
		Addr:         ":" + config.Port,
		Handler:      pomme,
	}
	ts := httptest.NewUnstartedServer(pomme)
	ts.Config = s
	ts.Start()
	defer ts.Close()

	tester := Init(ts.URL)
	// Init reports a token-generation failure by returning the zero value.
	if tester.username == "" {
		t.Fatal("unable to initialise the test account")
	}

	// test mode
	mode = "test"

	// Named database (not db) so the local does not shadow the db package.
	database, err, _ := db.InitDb(config.DB, mode)
	if err != nil {
		t.Fatalf("initialising test database: %v", err)
	}
	// Deferred so the test database is cleaned up even when a suite aborts.
	defer tester.CleanUp(database)

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(tester.password), bcrypt.DefaultCost)
	if err != nil {
		t.Fatalf("hashing test password: %v", err)
	}

	user := internal.User{Username: tester.username, HashedPassword: string(hashedPassword)}
	if err := database.Create(&user).Error; err != nil {
		t.Fatalf("creating test user: %v", err)
	}

	tester.TestOpenRoutes(t)
	tester.TestProtectedRoutes(t)
	tester.TestRateLimit(t)
}
// Init builds the accountTest used by the suite: a generated username with
// the "-testUser" suffix, a fixed password, a freshly signed token and the
// given server base URL. If the token cannot be generated it returns the
// zero-value accountTest.
func Init(url string) accountTest {
	username := autoUname() + "-testUser"

	token, err := makeTestToken(username)
	if err != nil {
		return accountTest{}
	}

	return accountTest{
		username: username,
		password: "merde",
		url:      url,
		token:    token,
	}
}
// expectedValues describes what a table-driven test case expects back from
// the server.
type expectedValues struct {
	// response is the expected HTTP status code.
	response int
}
// TestOpenRoutes posts url-encoded username/password forms to the
// unauthenticated endpoints (login, create, logout) of the server at a.url
// and checks the HTTP status code of each response.
//
// A request that cannot be built or sent fails the sub-test instead of being
// silently skipped, and every response body is closed.
func (a *accountTest) TestOpenRoutes(t *testing.T) {
	testCases := []struct {
		name     string
		route    string
		password string
		username string
		expected expectedValues
	}{
		{
			name:     "Should fail to login with bad username",
			route:    "login",
			username: "a.username",
			password: a.password,
			expected: expectedValues{
				response: http.StatusUnauthorized,
			},
		},
		{
			name:     "Should fail to login with bad password",
			route:    "login",
			username: a.username,
			password: "a.password",
			expected: expectedValues{
				response: http.StatusUnauthorized,
			},
		},
		{
			name:     "Should fail to login with bad password or username",
			route:    "login",
			username: "apple",
			password: "a.password",
			expected: expectedValues{
				response: http.StatusUnauthorized,
			},
		},
		{
			name:  "Should fail to login with empty form",
			route: "login",
			expected: expectedValues{
				response: http.StatusInternalServerError,
			},
		},
		{
			name:     "Should login successfully",
			route:    "login",
			username: a.username,
			password: a.password,
			expected: expectedValues{
				response: http.StatusOK,
			},
		},
		{
			name:  "Should fail to create account with empty form",
			route: "create",
			expected: expectedValues{
				response: http.StatusInternalServerError,
			},
		},
		{
			name:     "Should create account with empty username",
			route:    "create",
			password: "asdf",
			expected: expectedValues{
				response: http.StatusCreated,
			},
		},
		{
			// NOTE(review): this case sends neither username nor password, so
			// it is identical to the "empty form" case above — confirm intent.
			name:  "Should fail to create account with empty password",
			route: "create",
			expected: expectedValues{
				response: http.StatusInternalServerError,
			},
		},
		{
			name:     "Should log out successfully",
			route:    "logout",
			username: a.username,
			expected: expectedValues{
				response: http.StatusOK,
			},
		},
	}

	// One client for all cases; the timeout keeps a hung server from
	// stalling the whole test run.
	client := http.Client{Timeout: 30 * time.Second}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			form := url.Values{}
			form.Add("username", tc.username)
			form.Add("password", tc.password)

			req, err := http.NewRequest(http.MethodPost, a.url+`/api/`+tc.route, strings.NewReader(form.Encode()))
			if err != nil {
				t.Fatalf("building request: %v", err)
			}
			req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

			resp, err := client.Do(req)
			if err != nil {
				t.Fatalf("sending request: %v", err)
			}
			defer resp.Body.Close() //nolint: errcheck

			assert.Equal(t, tc.expected.response, resp.StatusCode)
		})
	}
}
// TestRateLimit hammers a route with up to maxAttempts empty POST requests
// and requires that the server answers with the expected status
// (429 Too Many Requests) at least once.
//
// The previous version only asserted inside an `if status == expected`
// branch, so it could never fail; the outcome is now tracked and asserted
// after the loop.
func (a *accountTest) TestRateLimit(t *testing.T) {
	testCases := []struct {
		name     string
		route    string
		expected expectedValues
	}{
		{
			name:  "Should rate limit open routes",
			route: "login",
			expected: expectedValues{
				response: http.StatusTooManyRequests,
			},
		},
		{
			// NOTE(review): sent without an Authorization header, as before;
			// this presumes the limiter runs ahead of authentication — confirm.
			name:  "Should rate limit private routes",
			route: "upload",
			expected: expectedValues{
				response: http.StatusTooManyRequests,
			},
		},
	}

	const maxAttempts = 100

	client := http.Client{Timeout: 30 * time.Second}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			limited := false

			for i := 0; i < maxAttempts && !limited; i++ {
				// A fresh request per attempt: a request body is consumed by
				// sending it, so reusing one request is only safe by accident.
				req, err := http.NewRequest(http.MethodPost, a.url+`/api/`+tc.route, strings.NewReader(url.Values{}.Encode()))
				if err != nil {
					t.Fatalf("building request: %v", err)
				}
				req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

				resp, err := client.Do(req)
				if err != nil {
					t.Fatalf("sending request %d: %v", i, err)
				}
				limited = resp.StatusCode == tc.expected.response
				// Closed here rather than deferred so bodies do not pile up
				// until the sub-test returns.
				resp.Body.Close() //nolint: errcheck
			}

			assert.True(t, limited, "expected status %d within %d requests to %s", tc.expected.response, maxAttempts, tc.route)
		})
	}
}
// TestProtectedRoutes exercises the authenticated upload endpoint. Each case
// writes tc.fileContents to a temporary file, wraps that file in a
// multipart form under the "file" field, posts it with the test account's
// token and checks the HTTP status code of the response.
//
// Cases are driven by their table data (fileContents, contentType) rather
// than by comparing test names, and every setup error fails the sub-test
// instead of being swallowed.
func (a *accountTest) TestProtectedRoutes(t *testing.T) {
	testCases := []struct {
		name         string
		contentType  string
		route        string
		fileContents []byte
		expected     expectedValues
	}{
		{
			// Empty file, sent with a non-multipart Content-Type header.
			name:        "Should fail to upload an empty file",
			contentType: "audio/aac",
			route:       "upload",
			expected: expectedValues{
				response: http.StatusInternalServerError,
			},
		},
		{
			name:         "Should upload a valid file",
			contentType:  "multipart/form-data",
			fileContents: zonebytes,
			route:        "upload",
			expected: expectedValues{
				response: http.StatusCreated,
			},
		},
	}

	client := http.Client{Timeout: 30 * time.Second}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			buf := new(bytes.Buffer)
			w := multipart.NewWriter(buf)

			f, err := os.CreateTemp(".", "zonefile")
			if err != nil {
				t.Fatalf("creating temp file: %v", err)
			}
			defer os.Remove(f.Name()) //nolint: errcheck

			if len(tc.fileContents) > 0 {
				if _, err := f.Write(tc.fileContents); err != nil {
					f.Close() //nolint: errcheck
					t.Fatalf("writing temp file: %v", err)
				}
			}
			// Close the handle returned by CreateTemp; it was leaked before.
			if err := f.Close(); err != nil {
				t.Fatalf("closing temp file: %v", err)
			}

			part, err := w.CreateFormFile("file", filepath.Base(f.Name()))
			if err != nil {
				t.Fatalf("creating form file: %v", err)
			}
			b, err := os.ReadFile(f.Name())
			if err != nil {
				t.Fatalf("reading temp file: %v", err)
			}
			if _, err := part.Write(b); err != nil {
				t.Fatalf("writing form file: %v", err)
			}
			// Close writes the trailing boundary; the body is invalid without it.
			if err := w.Close(); err != nil {
				t.Fatalf("closing multipart writer: %v", err)
			}

			req, err := http.NewRequest(http.MethodPost, a.url+`/api/`+tc.route, buf)
			if err != nil {
				t.Fatalf("building request: %v", err)
			}

			// A real multipart request needs the writer's boundary parameter;
			// any other content type is sent verbatim.
			contentType := tc.contentType
			if contentType == "multipart/form-data" {
				contentType = w.FormDataContentType()
			}
			req.Header.Add("Content-Type", contentType)
			// NOTE(review): "Bearer:" (colon, no space) is kept exactly as it
			// was; the usual form is "Bearer <token>" — confirm what the
			// server's token extractor accepts before changing it.
			req.Header.Add("Authorization", `Bearer:`+a.token)

			resp, err := client.Do(req)
			if err != nil {
				t.Fatalf("sending request: %v", err)
			}
			defer resp.Body.Close() //nolint: errcheck

			assert.Equal(t, tc.expected.response, resp.StatusCode)
		})
	}
}
// CleanUp releases the test database: it closes the connection behind db
// (when db is non-nil) and then removes the "pomme-test.sqlite" file.
// Failures are logged through the package responder rather than returned,
// so cleanup stays best-effort.
//
// The handle is closed first because removing a database file that is still
// open fails on some platforms and leaks the connection on all of them.
func (a *accountTest) CleanUp(db *gorm.DB) {
	logFailure := func(message string, err error) {
		l := newResponder(Response[any]{
			Message: message,
			Err:     err.Error(),
		})
		l.writeLogEntry()
	}

	if db != nil {
		sqlDB, err := db.DB()
		if err != nil {
			logFailure("unable to access test DB connection", err)
		} else if err := sqlDB.Close(); err != nil {
			logFailure("unable to close test DB", err)
		}
	}

	if err := os.Remove("pomme-test.sqlite"); err != nil {
		logFailure("unable to clean up test DB", err)
	}
}