diff --git a/internal/api/api.go b/internal/api/api.go index e0026ef..50cd295 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -35,6 +35,7 @@ func API() http.Handler { api.Use(jwtauth.Authenticator) api.With(setDBMiddleware).Post("/upload", ReceiveFile) api.With(setDBMiddleware).Post("/parse", ParseZoneFiles) + api.With(setDBMiddleware).Post("/save", SaveZoneFiles) }) // Open routes diff --git a/internal/api/helpers.go b/internal/api/helpers.go index cd26208..12e20f2 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -48,7 +48,7 @@ func setDBMiddleware(next http.Handler) http.Handler { }) } -func APIError[T map[string]any](w http.ResponseWriter, r *http.Request, v map[string]any) { +func APIError[T ~map[string]any](w http.ResponseWriter, r *http.Request, v T) { w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("Content-Type", "application/json; charset=utf-8") diff --git a/internal/api/types.go b/internal/api/types.go index a56a1cc..7293855 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -6,8 +6,6 @@ const ( keyPrincipalContextID key = iota ) -type genericResponseFields map[string]any - type key int // ZoneRequest represents a Zone file request. @@ -28,3 +26,12 @@ type Zone struct { type GenericResponse[T map[string]any] struct { Response map[string]any `json:"response,omitempty"` } + +type genericResponseFields map[string]any + +type ndr interface { + Parse() error + Save() error +} + +var _ ndr = (*ZoneRequest)(nil) diff --git a/internal/api/zone.go b/internal/api/zone.go index 489d719..b517576 100644 --- a/internal/api/zone.go +++ b/internal/api/zone.go @@ -1,7 +1,6 @@ package api import ( - "bytes" "fmt" "io" "log" @@ -39,8 +38,6 @@ import ( func ReceiveFile(w http.ResponseWriter, r *http.Request) { _, claims, _ := jwtauth.FromContext(r.Context()) - var buf bytes.Buffer - r.Body = http.MaxBytesReader(w, r.Body, 1*1024*1024) // approx 1 mb max upload file, header, err := r.FormFile("file") @@ -52,6 +49,11 @@ func ReceiveFile(w http.ResponseWriter, r *http.Request) { defer file.Close() //nolint: errcheck + b, err := io.ReadAll(file) + if err != nil { + log.Fatalln(err) + } + ok := validateContentType(file) if !ok { http.Error(w, "file must be text/plain", http.StatusUnsupportedMediaType) @@ -61,13 +63,7 @@ func ReceiveFile(w http.ResponseWriter, r *http.Request) { name := strings.Split(header.Filename, ".") - if _, err = io.Copy(&buf, file); err != nil { - APIError(w, r, genericResponseFields{"message": "internal server error", "status": http.StatusInternalServerError, "error": err.Error()}) - - return - } - - if err = util.MakeLocal(name[0], claims["username"].(string), buf); err != nil { + if err = util.MakeLocal(name[0], claims["username"].(string), b); err != nil { APIError(w, r, genericResponseFields{"message": "internal server error", "status": http.StatusInternalServerError, "error": err.Error()}) return @@ -89,8 +85,6 @@ func ReceiveFile(w http.ResponseWriter, r *http.Request) { }, }) - buf.Reset() - w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) @@ -161,7 +155,7 @@ func ParseZoneFiles(w http.ResponseWriter, r *http.Request) { return } - zoneFile := newZoneRequest(result.RawFileName, claims["username"].(string)) + zoneFile := newNDSRequest(result.RawFileName, claims["username"].(string)) if err := zoneFile.Parse(); err != nil { APIError(w, r, genericResponseFields{"message": "Unable to parse zonefile", "status": http.StatusInternalServerError, "error": err.Error()}) @@ -180,7 +174,66 @@ func ParseZoneFiles(w http.ResponseWriter, r *http.Request) { render.JSON(w, r, resp) } -func newZoneRequest(filename string, user string) *ZoneRequest { +func SaveZoneFiles(w http.ResponseWriter, r *http.Request) { + var result internal.ZoneRequest + + _, claims, _ := jwtauth.FromContext(r.Context()) + + err := r.ParseForm() + if err != nil { + APIError(w, r, genericResponseFields{"message": "internal server error", "status": http.StatusInternalServerError, "error": err.Error()}) + + return + } + + filename := r.Form.Get("filename") + + if filename == "" { + APIError(w, r, genericResponseFields{"message": "no filename provided", "status": http.StatusInternalServerError}) + + return + } + + db, ok := r.Context().Value(keyPrincipalContextID).(*gorm.DB) + if !ok { + APIError(w, r, genericResponseFields{"message": "internal server error", "status": http.StatusInternalServerError, "error": "unable to connect to DB"}) + + return + } + + db.Where(ZoneRequest{ + Zone: &Zone{ + RawFileName: filename, + }, + User: claims["username"].(string), + }).First(&result) + + if result == (internal.ZoneRequest{}) { + APIError(w, r, genericResponseFields{"message": "internal server error", "status": http.StatusInternalServerError}) + + return + } + + zoneFile := newNDSRequest(result.RawFileName, claims["username"].(string)) + + if err := zoneFile.Save(); err != nil { + APIError(w, r, genericResponseFields{"message": "Unable to save zonefile", "status": http.StatusInternalServerError, "error": err.Error()}) + + return + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + + w.WriteHeader(http.StatusCreated) + + resp := internal.Response{ + Message: "Successfully saved zonefile", + } + + render.JSON(w, r, resp) +} + +func newNDSRequest(filename string, user string) ndr { dat, err := os.ReadFile(fmt.Sprintf("/tmp/tmpfile-%s-%s", filename, user)) if err != nil { return &ZoneRequest{} @@ -211,6 +264,48 @@ func (zone *ZoneRequest) Parse() error { return nil } +func (zone *ZoneRequest) Save() error { + c, err := internal.ReadConfig() + if err != nil { + logHandler(genericResponseFields{"error": err.Error(), "message": "no config file defined"}) + return fmt.Errorf("unable to parse directory: %w", err) + } + + var path string = fmt.Sprintf("%s/%s/", c.ZoneDir, zone.RawFileName) + var tmpPath string = fmt.Sprintf("/tmp/tmpfile-%s-%s", zone.RawFileName, zone.User) + + if os.MkdirAll(path, 0755); err != nil { + logHandler(genericResponseFields{"error": err.Error(), "message": "unable to make directory for zone files"}) + return fmt.Errorf("unable to make zone directory: %w", err) + } + + if _, err = os.Create(path + zone.RawFileName); err != nil { + logHandler(genericResponseFields{"error": err.Error(), "message": "unable to save zonefile to directory"}) + return fmt.Errorf("unable to save zonefile to directory: %w", err) + } + + f, err := os.Open(tmpPath) + if err != nil { + logHandler(genericResponseFields{"error": err.Error(), "message": "unable to save zonefile to directory"}) + return fmt.Errorf("unable to save zonefile to directory: %w", err) + } + + defer f.Close() + + b, err := io.ReadAll(f) + if err != nil { + logHandler(genericResponseFields{"error": err.Error(), "message": "unable to save zonefile to directory"}) + return fmt.Errorf("unable to save zonefile to directory: %w", err) + } + + if os.WriteFile(path+zone.RawFileName, b, 0666); err != nil { + logHandler(genericResponseFields{"error": err.Error(), "message": "unable to save zonefile to directory"}) + return fmt.Errorf("unable to save zonefile to directory: %w", err) + } + + return nil +} + func validateContentType(file io.Reader) bool { bytes, err := io.ReadAll(file) if err != nil { diff --git a/internal/types.go b/internal/types.go index 5d049e8..97e99c9 100644 --- a/internal/types.go +++ b/internal/types.go @@ -35,6 +35,7 @@ type Config struct { Port string DB string TestDB string + ZoneDir string } // SwaggerGenericResponse[T] diff --git a/internal/util/fs.go b/internal/util/fs.go index c46b092..14e6403 100644 --- a/internal/util/fs.go +++ b/internal/util/fs.go @@ -1,18 +1,15 @@ package util import ( - "bytes" "fmt" "os" ) -func MakeLocal(filename, username string, buf bytes.Buffer) error { +func MakeLocal(filename, username string, buf []byte) error { if _, err := os.Stat(fmt.Sprintf("/tmp/tmpfile-%s-%s", filename, username)); !os.IsNotExist(err) { return fmt.Errorf("file %s already exists: %w", filename, err) } - defer buf.Reset() - f, err := os.Create("/tmp/tmpfile-" + filename + "-" + username) //nolint: gosec // this is set to nolint because I am doing everything os.CreateTemp does but here since I don't like some of the limitations if err != nil { @@ -26,7 +23,7 @@ func MakeLocal(filename, username string, buf bytes.Buffer) error { } }() - err = os.WriteFile(f.Name(), buf.Bytes(), 0o600) + err = os.WriteFile(f.Name(), buf, 0o600) if err != nil { return fmt.Errorf("failed to write file locally: %w", err)