added read DB path from config, add DB paths to sample config, added render library for JSON responses, removed plaintext password from User struct, made error handler funcs return json and their calls to include http.Request, and made API tests use httptest server

This commit is contained in:
grumbulon 2023-01-30 19:49:52 -05:00
parent 849f5d28fa
commit 5e8ba819bc
11 changed files with 146 additions and 88 deletions

1
.gitignore vendored
View file

@ -24,5 +24,6 @@ go.work
pomme
test.db
test.sqlite
pomme.sqlite
.dccache
config.yaml

View file

@ -1,3 +1,5 @@
server: example.com # does nothing yet
hashingsecret: PasswordHashingSecret
port: 3008 # port the server runs on
db: pomme.sqlite
testdb: test.sqlite

2
go.mod
View file

@ -15,6 +15,7 @@ require (
require (
github.com/KyleBanks/depth v1.2.1 // indirect
github.com/ajg/form v1.5.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-openapi/jsonpointer v0.19.5 // indirect
@ -35,6 +36,7 @@ require (
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.1.0 // indirect
github.com/glebarez/go-sqlite v1.20.0 // indirect
github.com/go-chi/httplog v0.2.5
github.com/go-chi/render v1.0.2
github.com/goccy/go-json v0.10.0 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect

4
go.sum
View file

@ -2,6 +2,8 @@ github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc
github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE=
github.com/adrg/xdg v0.4.0 h1:RzRqFcjH4nE5C6oTAxhBtoE2IRyjBSa62SCbyPidvls=
github.com/adrg/xdg v0.4.0/go.mod h1:N6ag73EX4wyxeaoeHctc1mas01KZgsj5tYiAIwqJE/E=
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY=
@ -29,6 +31,8 @@ github.com/go-chi/httprate v0.7.1 h1:d5kXARdms2PREQfU4pHvq44S6hJ1hPu4OXLeBKmCKWs
github.com/go-chi/httprate v0.7.1/go.mod h1:6GOYBSwnpra4CQfAKXu8sQZg+nZ0M1g9QnyFvxrAB8A=
github.com/go-chi/jwtauth/v5 v5.1.0 h1:wJyf2YZ/ohPvNJBwPOzZaQbyzwgMZZceE1m8FOzXLeA=
github.com/go-chi/jwtauth/v5 v5.1.0/go.mod h1:MA93hc1au3tAQwCKry+fI4LqJ5MIVN4XSsglOo+lSc8=
github.com/go-chi/render v1.0.2 h1:4ER/udB0+fMWB2Jlf15RV3F4A2FDuYi/9f+lFttR/Lg=
github.com/go-chi/render v1.0.2/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0=
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY=
github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=

View file

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"time"
@ -13,6 +14,8 @@ import (
"github.com/go-chi/httplog"
"github.com/go-chi/httprate"
"github.com/go-chi/jwtauth/v5"
"github.com/go-chi/render"
"gorm.io/gorm"
)
type key int
@ -24,32 +27,56 @@ const (
// setDBMiddleware is the http Handler func for the GORM middleware with context.
func setDBMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
db := db.InitDb()
var pommeDB *gorm.DB
c, err := internal.ReadConfig()
if err != nil {
log.Printf("No config file defined: %v", err)
}
switch r.Header.Get("User-Agent") {
case "pomme-api-test-slave":
pommeDB = db.InitDb(c.TestDB)
default:
pommeDB = db.InitDb(c.DB)
}
timeoutContext, cancelContext := context.WithTimeout(context.Background(), time.Second)
ctx := context.WithValue(r.Context(), keyPrincipalContextID, db.WithContext(timeoutContext))
ctx := context.WithValue(r.Context(), keyPrincipalContextID, pommeDB.WithContext(timeoutContext))
defer cancelContext()
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// handlers for very common errors.
func authFailed(w http.ResponseWriter, realm string) {
func authFailed(w http.ResponseWriter, r *http.Request, realm string) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Realm="%s"`, realm))
w.WriteHeader(http.StatusUnauthorized)
resp := internal.Response{
Message: fmt.Sprintf(`Login failed -- Realm="%s"`, realm),
HTTPResponse: http.StatusUnauthorized,
}
render.JSON(w, r, resp)
}
func internalServerError(w http.ResponseWriter, errMsg string) {
func internalServerError(w http.ResponseWriter, r *http.Request, errMsg string) {
logger := httplog.NewLogger("Pomme", httplog.Options{
JSON: true,
})
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Add("Internal Server Error", errMsg)
w.WriteHeader(http.StatusInternalServerError)
resp := internal.Response{
Message: errMsg,
HTTPResponse: http.StatusInternalServerError,
}
render.JSON(w, r, resp)
logger.Error().Msg(errMsg)
}
@ -72,7 +99,7 @@ func API() http.Handler {
Message: "API rate limit exceded",
})
if err != nil {
internalServerError(w, "internal server error")
internalServerError(w, r, "internal server error")
return
}
@ -100,7 +127,7 @@ func API() http.Handler {
Message: "API rate limit exceded",
})
if err != nil {
internalServerError(w, "internal server error")
internalServerError(w, r, "internal server error")
return
}

View file

@ -3,15 +3,18 @@ package api
import (
"encoding/json"
"io"
"log"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"git.freecumextremist.com/grumbulon/pomme/internal"
"git.freecumextremist.com/grumbulon/pomme/internal/db"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
)
type response struct {
@ -23,18 +26,59 @@ type response struct {
type accountTest struct {
username string
password string
url string
}
func TestInit(t *testing.T) {
tester := accountTest{
username: autoUname(),
password: "merde",
func TestAPI(t *testing.T) {
config, err := internal.ReadConfig()
if err != nil {
panic(err)
}
pomme := chi.NewRouter()
pomme.Mount("/api", API())
s := &http.Server{
ReadTimeout: 3 * time.Second,
WriteTimeout: 15 * time.Second,
Addr: ":" + config.Port,
Handler: pomme,
}
ts := httptest.NewUnstartedServer(pomme)
ts.Config = s
ts.Start()
defer ts.Close()
tester := Init(ts.URL)
c, err := internal.ReadConfig()
if err != nil {
assert.NotNil(t, err)
}
db := db.InitDb(c.TestDB)
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(tester.password), bcrypt.DefaultCost)
if err != nil {
return
}
db.Create(&internal.User{Username: tester.username, HashedPassword: string(hashedPassword)})
tester.TestMakeAccount(t)
tester.TestLogin(t)
tester.TestLogout(t)
tester.CleanUpDb()
}
func Init(url string) accountTest {
return accountTest{
username: autoUname(),
password: "merde",
url: url,
}
}
func (a *accountTest) TestMakeAccount(t *testing.T) {
@ -48,9 +92,11 @@ func (a *accountTest) TestMakeAccount(t *testing.T) {
form.Add("password", a.password)
if req, err := http.NewRequest(http.MethodPost, "http://localhost:3010/api/create", strings.NewReader(form.Encode())); err == nil {
if req, err := http.NewRequest(http.MethodPost, a.url+`api/create`, strings.NewReader(form.Encode())); err == nil {
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("User-Agent", "pomme-api-test-slave")
resp, err := client.Do(req)
if err != nil {
assert.NotNil(t, err)
@ -63,7 +109,6 @@ func (a *accountTest) TestMakeAccount(t *testing.T) {
assert.NotNil(t, err)
}
log.Println(target)
assert.Equal(t, http.StatusCreated, target.Status)
}
}
@ -79,8 +124,9 @@ func (a *accountTest) TestLogin(t *testing.T) {
form.Add("password", a.password)
if req, err := http.NewRequest(http.MethodPost, "http://localhost:3010/api/login", strings.NewReader(form.Encode())); err == nil {
if req, err := http.NewRequest(http.MethodPost, a.url+`/api/login`, strings.NewReader(form.Encode())); err == nil {
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("User-Agent", "pomme-api-test-slave")
resp, err := client.Do(req)
if err != nil {
@ -94,7 +140,6 @@ func (a *accountTest) TestLogin(t *testing.T) {
assert.NotNil(t, err)
}
log.Println(target)
assert.Equal(t, http.StatusOK, target.Status)
}
}
@ -108,8 +153,9 @@ func (a *accountTest) TestLogout(t *testing.T) {
form.Add("username", a.username)
if req, err := http.NewRequest(http.MethodPost, "http://localhost:3010/api/logout", strings.NewReader(form.Encode())); err == nil {
if req, err := http.NewRequest(http.MethodPost, a.url+`/api/logout`, strings.NewReader(form.Encode())); err == nil {
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("User-Agent", "pomme-api-test-slave")
resp, err := client.Do(req)
if err != nil {
@ -123,20 +169,6 @@ func (a *accountTest) TestLogout(t *testing.T) {
assert.NotNil(t, err)
}
log.Println(target)
assert.Equal(t, http.StatusOK, target.Status)
}
}
// currently does not work.
func (a *accountTest) CleanUpDb() {
var user internal.User
db := db.InitDb()
db.Where("username = ?", a.username).First(&user)
if user.Username != "" {
db.Delete(&user, user.ID)
}
}

View file

@ -1,11 +1,11 @@
package api
import (
"encoding/json"
"net/http"
"time"
"git.freecumextremist.com/grumbulon/pomme/internal"
"github.com/go-chi/render"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
@ -48,14 +48,20 @@ func Login(w http.ResponseWriter, r *http.Request) {
var result internal.User
if _, err := r.Cookie("jwt"); err == nil {
http.Error(w, "Logged in", http.StatusCreated)
w.Header().Set("Content-Type", "application/json")
resp := internal.Response{
Message: "Already logged in",
HTTPResponse: http.StatusOK,
}
render.JSON(w, r, resp)
return
}
err := r.ParseForm()
if err != nil {
internalServerError(w, "unable to parse request")
internalServerError(w, r, "unable to parse request")
return
}
@ -65,20 +71,20 @@ func Login(w http.ResponseWriter, r *http.Request) {
password := r.Form.Get("password")
if username == "" {
internalServerError(w, "no username provided") // this should prob be handled by the frontend
internalServerError(w, r, "no username provided") // this should prob be handled by the frontend
return
}
if password == "" {
internalServerError(w, "no password provided") // this should prob be handled by the frontend
internalServerError(w, r, "no password provided") // this should prob be handled by the frontend
return
}
db, ok := r.Context().Value(keyPrincipalContextID).(*gorm.DB)
if !ok {
internalServerError(w, "DB connection failed")
internalServerError(w, r, "DB connection failed")
return
}
@ -86,7 +92,7 @@ func Login(w http.ResponseWriter, r *http.Request) {
db.Where("username = ?", username).First(&result)
if result.Username == "" {
authFailed(w, "login")
authFailed(w, r, "authentication")
return
}
@ -94,14 +100,14 @@ func Login(w http.ResponseWriter, r *http.Request) {
err = bcrypt.CompareHashAndPassword([]byte(result.HashedPassword), []byte(password))
if err != nil {
authFailed(w, "login")
authFailed(w, r, "authentication")
return
}
token, err := makeToken(username)
if err != nil {
internalServerError(w, err.Error())
internalServerError(w, r, err.Error())
return
}
@ -116,20 +122,12 @@ func Login(w http.ResponseWriter, r *http.Request) {
Name: "jwt", // Must be named "jwt" or else the token cannot be searched for by jwtauth.Verifier.
Value: token,
})
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(
internal.Response{
Message: "Successfully logged in",
HTTPResponse: http.StatusOK,
})
if err != nil {
internalServerError(w, "internal server error")
return
resp := internal.Response{
Message: "Successfully logged in",
HTTPResponse: http.StatusOK,
}
http.Redirect(w, r, "/", http.StatusSeeOther)
render.JSON(w, r, resp)
}
// Logout destroys a users JWT cookie.
@ -143,18 +141,9 @@ func Logout(w http.ResponseWriter, r *http.Request) {
Value: "",
})
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(
internal.Response{
Message: "Successfully logged out",
HTTPResponse: http.StatusOK,
})
if err != nil {
internalServerError(w, "internal server error")
return
resp := internal.Response{
Message: "Successfully logged out",
HTTPResponse: http.StatusOK,
}
http.Redirect(w, r, "/", http.StatusSeeOther)
render.JSON(w, r, resp)
}

View file

@ -17,14 +17,14 @@ import (
func NewUser(w http.ResponseWriter, r *http.Request) {
db, ok := r.Context().Value(keyPrincipalContextID).(*gorm.DB)
if !ok {
internalServerError(w, "internal server error")
internalServerError(w, r, "internal server error")
}
var result internal.User
err := r.ParseForm()
if err != nil {
internalServerError(w, "unable to parse request")
internalServerError(w, r, "unable to parse request")
return
}
@ -38,7 +38,7 @@ func NewUser(w http.ResponseWriter, r *http.Request) {
password := r.Form.Get("password")
if password == "" {
internalServerError(w, "no password provided")
internalServerError(w, r, "no password provided")
return
}
@ -46,14 +46,14 @@ func NewUser(w http.ResponseWriter, r *http.Request) {
db.Where("username = ?", username).First(&result)
if result.Username != "" {
internalServerError(w, "user already exists")
internalServerError(w, r, "user already exists")
return
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
authFailed(w, "login")
authFailed(w, r, "login")
return
}
@ -62,7 +62,7 @@ func NewUser(w http.ResponseWriter, r *http.Request) {
token, err := makeToken(username)
if err != nil {
internalServerError(w, "internal server error")
internalServerError(w, r, "internal server error")
return
}
@ -79,7 +79,7 @@ func NewUser(w http.ResponseWriter, r *http.Request) {
})
w.WriteHeader(http.StatusCreated)
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(
internal.Response{
Username: username,
@ -88,7 +88,7 @@ func NewUser(w http.ResponseWriter, r *http.Request) {
})
if err != nil {
internalServerError(w, "internal server error")
internalServerError(w, r, "internal server error")
return
}

View file

@ -60,7 +60,7 @@ func ReceiveFile(w http.ResponseWriter, r *http.Request) {
file, header, err := r.FormFile("file")
if err != nil {
internalServerError(w, fmt.Sprintf("file upload failed: %v", err))
internalServerError(w, r, fmt.Sprintf("file upload failed: %v", err))
return
}
@ -77,20 +77,20 @@ func ReceiveFile(w http.ResponseWriter, r *http.Request) {
name := strings.Split(header.Filename, ".")
if _, err = io.Copy(&buf, file); err != nil {
internalServerError(w, "internal server error")
internalServerError(w, r, "internal server error")
return
}
if err = util.MakeLocal(name[0], claims["username"].(string), buf); err != nil {
internalServerError(w, err.Error())
internalServerError(w, r, err.Error())
return
}
db, ok := r.Context().Value(keyPrincipalContextID).(*gorm.DB)
if !ok {
internalServerError(w, "internal server error")
internalServerError(w, r, "internal server error")
return
}
@ -115,7 +115,7 @@ func ReceiveFile(w http.ResponseWriter, r *http.Request) {
})
if err != nil {
internalServerError(w, "internal server error")
internalServerError(w, r, "internal server error")
return
}
@ -147,7 +147,7 @@ func ZoneFiles(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
internalServerError(w, "unable to parse request")
internalServerError(w, r, "unable to parse request")
return
}
@ -155,14 +155,14 @@ func ZoneFiles(w http.ResponseWriter, r *http.Request) {
filename := r.Form.Get("filename")
if filename == "" {
internalServerError(w, "no filename parsed")
internalServerError(w, r, "no filename parsed")
return
}
db, ok := r.Context().Value(keyPrincipalContextID).(*gorm.DB)
if !ok {
internalServerError(w, "internal server error")
internalServerError(w, r, "internal server error")
return
}
@ -175,7 +175,7 @@ func ZoneFiles(w http.ResponseWriter, r *http.Request) {
}).First(&result)
if result == (internal.ZoneRequest{}) {
internalServerError(w, "internal server error")
internalServerError(w, r, "internal server error")
return
}
@ -183,7 +183,7 @@ func ZoneFiles(w http.ResponseWriter, r *http.Request) {
zoneFile := newZoneRequest(result.RawFileName, claims["username"].(string))
if err := zoneFile.Parse(); err != nil {
internalServerError(w, fmt.Sprintf("unable to parse zonefile: %v", err))
internalServerError(w, r, fmt.Sprintf("unable to parse zonefile: %v", err))
return
}

View file

@ -14,7 +14,6 @@ import (
type User struct {
gorm.Model
Username string
Password string
HashedPassword string
}
@ -44,6 +43,8 @@ type Config struct {
Server string
HashingSecret string
Port string
DB string
TestDB string
}
var config Config

View file

@ -7,8 +7,8 @@ import (
)
// InitDb is the init function for the database.
func InitDb() *gorm.DB {
db, err := gorm.Open(sqlite.Open("test.sqlite"), &gorm.Config{})
func InitDb(path string) *gorm.DB {
db, err := gorm.Open(sqlite.Open(path), &gorm.Config{})
if err != nil {
panic("failed to connect database")
}