package api

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"time"

	"git.freecumextremist.com/grumbulon/pomme/internal"
	"git.freecumextremist.com/grumbulon/pomme/internal/db"
	"github.com/go-chi/chi/v5"
	"github.com/go-chi/httplog"
	"github.com/go-chi/httprate"
	"github.com/go-chi/jwtauth/v5"
	"github.com/go-chi/render"
	"gorm.io/gorm"
)

// key is the unexported type used for context keys set by this package, so
// they cannot collide with keys defined by other packages.
type key int

const (
	// keyPrincipalContextID is the context key under which setDBMiddleware
	// stores the request-scoped *gorm.DB handle.
	keyPrincipalContextID key = iota
)

// setDBMiddleware is the http Handler func for the GORM middleware with context.
//
// For every request it reads the configuration, opens a database handle, binds
// that handle to a one-second timeout context and stores it in the request
// context under keyPrincipalContextID before calling next. If the
// configuration cannot be read, the request is answered with a 500 and next is
// not called.
func setDBMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		c, err := internal.ReadConfig()
		if err != nil {
			// The config value is not meaningful when ReadConfig fails, so stop
			// here instead of opening a database with it.
			log.Printf("No config file defined: %v", err)
			internalServerError(w, r, "internal server error")

			return
		}

		// NOTE(review): db.InitDb is called on every request; whether it reuses
		// an existing connection is not visible from here -- confirm.
		var pommeDB *gorm.DB

		// NOTE(review): the test database is selected purely by a
		// client-controlled User-Agent value; confirm this is never reachable
		// in production deployments.
		switch r.Header.Get("User-Agent") {
		case "pomme-api-test-slave":
			pommeDB = db.InitDb(c.TestDB)
		default:
			pommeDB = db.InitDb(c.DB)
		}

		// Derive the timeout from the request context so that a cancelled
		// request also cancels any in-flight database work.
		timeoutContext, cancelContext := context.WithTimeout(r.Context(), time.Second)
		defer cancelContext()

		ctx := context.WithValue(r.Context(), keyPrincipalContextID, pommeDB.WithContext(timeoutContext))

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// authFailed writes a 401 Unauthorized JSON response for the given realm.
//
// The status is handed to render.Status rather than written directly, so that
// render.JSON can still set the JSON Content-Type before the headers are sent.
func authFailed(w http.ResponseWriter, r *http.Request, realm string) {
	w.Header().Set("X-Content-Type-Options", "nosniff")
	// NOTE(review): this challenge carries no auth scheme (e.g. `Bearer`);
	// left unchanged because clients may depend on the current value.
	w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Realm="%s"`, realm))

	resp := internal.Response{
		Message:      fmt.Sprintf(`Login failed -- Realm="%s"`, realm),
		HTTPResponse: http.StatusUnauthorized,
	}

	render.Status(r, http.StatusUnauthorized)
	render.JSON(w, r, resp)
}

// internalServerError writes a 500 Internal Server Error JSON response whose
// message is errMsg, then logs errMsg at error level.
//
// errMsg is sent to the client verbatim, so callers should pass a message that
// is safe to expose.
func internalServerError(w http.ResponseWriter, r *http.Request, errMsg string) {
	logger := httplog.NewLogger("Pomme", httplog.Options{
		JSON: true,
	})

	// The former "Internal Server Error" header is gone: a field name that
	// contains spaces is invalid and is never written to the wire.
	w.Header().Set("X-Content-Type-Options", "nosniff")

	resp := internal.Response{
		Message:      errMsg,
		HTTPResponse: http.StatusInternalServerError,
	}

	render.Status(r, http.StatusInternalServerError)
	render.JSON(w, r, resp)

	logger.Error().Msg(errMsg)
}

// API subroute handler.
func API() http.Handler { api := chi.NewRouter() // Protected routes api.Group(func(api chi.Router) { api.Use(httprate.Limit( 10, // requests 10*time.Second, // per duration httprate.WithKeyFuncs(httprate.KeyByIP, httprate.KeyByEndpoint), httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) err := json.NewEncoder(w).Encode( internal.Response{ HTTPResponse: http.StatusTooManyRequests, Message: "API rate limit exceded", }) if err != nil { internalServerError(w, r, "internal server error") return } }), )) api.Use(jwtauth.Verifier(tokenAuth)) api.Use(jwtauth.Authenticator) api.With(setDBMiddleware).Post("/upload", ReceiveFile) api.With(setDBMiddleware).Post("/parse", ParseZoneFiles) }) // Open routes api.Group(func(api chi.Router) { api.Use(httprate.Limit( 5, // requests 5*time.Second, // per duration httprate.WithKeyFuncs(httprate.KeyByIP, httprate.KeyByEndpoint), httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) err := json.NewEncoder(w).Encode( internal.Response{ HTTPResponse: http.StatusTooManyRequests, Message: "API rate limit exceded", }) if err != nil { internalServerError(w, r, "internal server error") return } }), )) api.Use(setDBMiddleware) api.With(setDBMiddleware).Post("/create", NewUser) api.With(setDBMiddleware).Post("/login", Login) api.Post("/logout", Logout) }) return api }