diff --git a/internal/api/zone.go b/internal/api/zone.go index b517576..f28ac83 100644 --- a/internal/api/zone.go +++ b/internal/api/zone.go @@ -6,6 +6,7 @@ import ( "log" "net/http" "os" + "path/filepath" "strings" "git.freecumextremist.com/grumbulon/pomme/internal" @@ -51,12 +52,14 @@ func ReceiveFile(w http.ResponseWriter, r *http.Request) { b, err := io.ReadAll(file) if err != nil { - log.Fatalln(err) + APIError(w, r, genericResponseFields{"message": "internal server error", "status": http.StatusInternalServerError, "error": err.Error()}) + + return } ok := validateContentType(file) if !ok { - http.Error(w, "file must be text/plain", http.StatusUnsupportedMediaType) + APIError(w, r, genericResponseFields{"message": "file must be text/plain", "status": http.StatusUnsupportedMediaType}) return } @@ -268,38 +271,49 @@ func (zone *ZoneRequest) Save() error { c, err := internal.ReadConfig() if err != nil { logHandler(genericResponseFields{"error": err.Error(), "message": "no config file defined"}) + return fmt.Errorf("unable to parse directory: %w", err) } var path string = fmt.Sprintf("%s/%s/", c.ZoneDir, zone.RawFileName) + var tmpPath string = fmt.Sprintf("/tmp/tmpfile-%s-%s", zone.RawFileName, zone.User) - if os.MkdirAll(path, 0755); err != nil { + if err = os.MkdirAll(path, 0o750); err != nil { logHandler(genericResponseFields{"error": err.Error(), "message": "unable to make directory for zone files"}) + return fmt.Errorf("unable to make zone directory: %w", err) } - if _, err = os.Create(path + zone.RawFileName); err != nil { + if _, err = os.Create(filepath.Clean(path + zone.RawFileName)); err != nil { logHandler(genericResponseFields{"error": err.Error(), "message": "unable to save zonefile to directory"}) + return fmt.Errorf("unable to save zonefile to directory: %w", err) } - f, err := os.Open(tmpPath) + f, err := os.Open(filepath.Clean(tmpPath)) if err != nil { logHandler(genericResponseFields{"error": err.Error(), "message": "unable to save zonefile to directory"}) + return fmt.Errorf("unable to save zonefile to directory: %w", err) } - defer f.Close() + defer func() { + if err = f.Close(); err != nil { + logHandler(genericResponseFields{"message": "Error closing file", "error": err.Error()}) + } + }() b, err := io.ReadAll(f) if err != nil { logHandler(genericResponseFields{"error": err.Error(), "message": "unable to save zonefile to directory"}) + return fmt.Errorf("unable to save zonefile to directory: %w", err) } - if os.WriteFile(path+zone.RawFileName, b, 0666); err != nil { + if err = os.WriteFile(path+zone.RawFileName, b, 0o600); err != nil { logHandler(genericResponseFields{"error": err.Error(), "message": "unable to save zonefile to directory"}) + return fmt.Errorf("unable to save zonefile to directory: %w", err) } diff --git a/internal/util/fs.go b/internal/util/fs.go index 6c04f8f..d2048e4 100644 --- a/internal/util/fs.go +++ b/internal/util/fs.go @@ -1,17 +1,20 @@ package util import ( + "errors" "fmt" "os" ) +var errEmptyFile = errors.New("will not save empty file to FS") + func MakeLocal(filename, username string, buf []byte) error { if _, err := os.Stat(fmt.Sprintf("/tmp/tmpfile-%s-%s", filename, username)); !os.IsNotExist(err) { return fmt.Errorf("file %s already exists: %w", filename, err) } if len(buf) == 0 { - return fmt.Errorf("will not save empty file: %s to FS", filename) + return errEmptyFile } f, err := os.Create("/tmp/tmpfile-" + filename + "-" + username) //nolint: gosec