diff --git a/main.go b/main.go index c190f17..955bf6a 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,8 @@ import ( "net/http" "os" "path/filepath" + "regexp" + "strings" "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" @@ -20,8 +22,11 @@ type Config struct { UploadsDir string `json:"uploads_dir"` } -var config Config -var logger *logrus.Logger +var ( + config Config + logger *logrus.Logger + filenameRegex *regexp.Regexp +) func main() { // Initialize logger @@ -37,6 +42,9 @@ func main() { logger.Fatalf("Error loading config file: %s", err) } + // Initialize filename regular expression + filenameRegex = regexp.MustCompile(`^[a-zA-Z0-9-_\.]+$`) + // Initialize router router := chi.NewRouter() router.Use(middleware.Logger) @@ -47,6 +55,7 @@ func main() { // Start server address := fmt.Sprintf(":%s", config.ServicePort) logger.Infof("Starting CDN server on port %s...", config.ServicePort) + logger.Infof("Serving files from %s", config.UploadsDir) err = http.ListenAndServe(address, router) if err != nil { logger.Fatalf("Server error: %s", err) @@ -117,14 +126,29 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { } // Parse the uploaded file - file, header, err := r.FormFile("file") + err := r.ParseMultipartForm(32 << 20) // Max file size: 32MB if err != nil { logger.Error("Bad Request:", err) http.Error(w, "Bad Request", http.StatusBadRequest) return } + + file, handler, err := r.FormFile("file") + if err != nil { + logger.Error("No file provided in the request:", err) + http.Error(w, "No file provided in the request", http.StatusBadRequest) + return + } defer file.Close() + // Validate filename + filename := sanitizeFilename(handler.Filename) + if !isValidFilename(filename) { + logger.Errorf("Invalid filename: %s", handler.Filename) + http.Error(w, "Invalid filename", http.StatusBadRequest) + return + } + // Create the uploads directory if it doesn't exist err = os.MkdirAll(config.UploadsDir, os.ModePerm) if err != nil { @@ -134,7 +158,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { } // Create a new file in the uploads directory - dst, err := os.Create(filepath.Join(config.UploadsDir, header.Filename)) + dst, err := os.Create(filepath.Join(config.UploadsDir, filename)) if err != nil { logger.Error("Internal Server Error:", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) @@ -150,7 +174,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { return } - logger.Infof("File uploaded successfully: %s", header.Filename) + logger.Infof("File uploaded successfully: %s", filename) fmt.Fprintf(w, "File uploaded successfully!") } @@ -158,3 +182,11 @@ func checkAuthentication(r *http.Request) bool { username, password, ok := r.BasicAuth() return ok && username == config.Username && password == config.Password } + +func sanitizeFilename(filename string) string { + return strings.TrimSpace(filename) +} + +func isValidFilename(filename string) bool { + return filenameRegex.MatchString(filename) +}