pomme/internal/api/api.go
2023-01-20 18:29:33 -05:00

116 lines
3.2 KiB
Go

package api
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"git.freecumextremist.com/grumbulon/pomme/internal"
"git.freecumextremist.com/grumbulon/pomme/internal/db"
"github.com/go-chi/chi/v5"
"github.com/go-chi/httplog"
"github.com/go-chi/httprate"
"github.com/go-chi/jwtauth/v5"
)
// key is an unexported type for context keys defined in this package, so
// values stored under it cannot collide with keys set by other packages.
type key int

// keyPrincipalContextID is the context key under which setDBMiddleware stores
// the per-request database handle.
const keyPrincipalContextID key = iota
// setDBMiddleware is the http Handler func for the GORM middleware with context.
//
// For every request it obtains a database handle from db.InitDb, binds it to a
// context that expires after one second, and stores the result in the request
// context under keyPrincipalContextID before calling next.
//
// The timeout context is derived from the incoming request's context rather
// than context.Background, so database work is also abandoned when the request
// itself is cancelled instead of running on until the timeout fires.
func setDBMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Named conn (not db) so the imported db package is not shadowed.
		// NOTE(review): InitDb runs on every request; if it opens a new
		// connection or pool each time, consider initialising it once.
		conn := db.InitDb()

		// NOTE(review): the one-second budget starts here, before the handler
		// has read the request body — confirm that is enough for uploads.
		timeoutCtx, cancel := context.WithTimeout(r.Context(), time.Second)
		defer cancel()

		ctx := context.WithValue(r.Context(), keyPrincipalContextID, conn.WithContext(timeoutCtx))
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// handlers for very common errors.

// authFailed replies with 401 Unauthorized and an empty body.
//
// realm is advertised to the client in the WWW-Authenticate challenge. A
// challenge must start with an auth-scheme; Bearer is used because this API
// authenticates with JWT bearer tokens. The realm is emitted with %q so that
// quotes or backslashes inside it are escaped and cannot break out of the
// quoted-string.
func authFailed(w http.ResponseWriter, realm string) {
	h := w.Header()
	h.Set("Content-Type", "text/plain; charset=utf-8")
	h.Set("X-Content-Type-Options", "nosniff")
	h.Add("WWW-Authenticate", fmt.Sprintf("Bearer realm=%q", realm))
	w.WriteHeader(http.StatusUnauthorized)
}
// internalServerError logs errMsg and replies with 500 Internal Server Error.
//
// errMsg is written to the JSON log only. The client receives the generic
// status text as a plain-text body, so internal details passed by callers are
// not exposed. Earlier code tried to send errMsg in a header named
// "Internal Server Error"; that name contains spaces, which is not a valid
// header field name, so the message could never reach the client that way.
//
// http.Error sets Content-Type "text/plain; charset=utf-8" and
// X-Content-Type-Options "nosniff" before writing the status and body.
func internalServerError(w http.ResponseWriter, errMsg string) {
	// NOTE(review): a logger is built on every call; consider sharing one
	// package-wide if httplog.NewLogger turns out to be costly.
	logger := httplog.NewLogger("Pomme", httplog.Options{
		JSON: true,
	})
	logger.Error().Msg(errMsg)

	http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
// API subroute handler.
func API() http.Handler {
api := chi.NewRouter()
// Protected routes
api.Group(func(api chi.Router) {
api.Use(httprate.Limit(
10, // requests
10*time.Second, // per duration
httprate.WithKeyFuncs(httprate.KeyByIP, httprate.KeyByEndpoint),
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
err := json.NewEncoder(w).Encode(
internal.Response{
HTTPResponse: http.StatusTooManyRequests,
Message: "API rate limit exceded",
})
if err != nil {
internalServerError(w, "internal server error")
return
}
}),
))
api.Use(jwtauth.Verifier(tokenAuth))
api.Use(jwtauth.Authenticator)
api.With(setDBMiddleware).Post("/upload", ReceiveFile)
api.With(setDBMiddleware).Post("/parse", ZoneFiles)
})
// Open routes
api.Group(func(api chi.Router) {
api.Use(httprate.Limit(
5, // requests
5*time.Second, // per duration
httprate.WithKeyFuncs(httprate.KeyByIP, httprate.KeyByEndpoint),
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
err := json.NewEncoder(w).Encode(
internal.Response{
HTTPResponse: http.StatusTooManyRequests,
Message: "API rate limit exceded",
})
if err != nil {
internalServerError(w, "internal server error")
return
}
}),
))
api.Use(setDBMiddleware)
api.With(setDBMiddleware).Post("/create", NewUser)
api.With(setDBMiddleware).Post("/login", Login)
api.Post("/logout", Logout)
})
return api
}