input validation

This commit is contained in:
NotAShelf 2023-06-03 20:42:30 +03:00
parent 754dbf51be
commit 0031af0057
No known key found for this signature in database
GPG key ID: F0D14CCB5ED5AA22

42
main.go
View file

@ -7,6 +7,8 @@ import (
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"strings"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/chi/middleware" "github.com/go-chi/chi/middleware"
@ -20,8 +22,11 @@ type Config struct {
UploadsDir string `json:"uploads_dir"` UploadsDir string `json:"uploads_dir"`
} }
var config Config var (
var logger *logrus.Logger config Config
logger *logrus.Logger
filenameRegex *regexp.Regexp
)
func main() { func main() {
// Initialize logger // Initialize logger
@ -37,6 +42,9 @@ func main() {
logger.Fatalf("Error loading config file: %s", err) logger.Fatalf("Error loading config file: %s", err)
} }
// Initialize filename regular expression
filenameRegex = regexp.MustCompile(`^[a-zA-Z0-9-_\.]+$`)
// Initialize router // Initialize router
router := chi.NewRouter() router := chi.NewRouter()
router.Use(middleware.Logger) router.Use(middleware.Logger)
@ -47,6 +55,7 @@ func main() {
// Start server // Start server
address := fmt.Sprintf(":%s", config.ServicePort) address := fmt.Sprintf(":%s", config.ServicePort)
logger.Infof("Starting CDN server on port %s...", config.ServicePort) logger.Infof("Starting CDN server on port %s...", config.ServicePort)
logger.Infof("Serving files from %s", config.UploadsDir)
err = http.ListenAndServe(address, router) err = http.ListenAndServe(address, router)
if err != nil { if err != nil {
logger.Fatalf("Server error: %s", err) logger.Fatalf("Server error: %s", err)
@ -117,14 +126,29 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
} }
// Parse the uploaded file // Parse the uploaded file
file, header, err := r.FormFile("file") err := r.ParseMultipartForm(32 << 20) // Max file size: 32MB
if err != nil { if err != nil {
logger.Error("Bad Request:", err) logger.Error("Bad Request:", err)
http.Error(w, "Bad Request", http.StatusBadRequest) http.Error(w, "Bad Request", http.StatusBadRequest)
return return
} }
file, handler, err := r.FormFile("file")
if err != nil {
logger.Error("No file provided in the request:", err)
http.Error(w, "No file provided in the request", http.StatusBadRequest)
return
}
defer file.Close() defer file.Close()
// Validate filename
filename := sanitizeFilename(handler.Filename)
if !isValidFilename(filename) {
logger.Errorf("Invalid filename: %s", handler.Filename)
http.Error(w, "Invalid filename", http.StatusBadRequest)
return
}
// Create the uploads directory if it doesn't exist // Create the uploads directory if it doesn't exist
err = os.MkdirAll(config.UploadsDir, os.ModePerm) err = os.MkdirAll(config.UploadsDir, os.ModePerm)
if err != nil { if err != nil {
@ -134,7 +158,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
} }
// Create a new file in the uploads directory // Create a new file in the uploads directory
dst, err := os.Create(filepath.Join(config.UploadsDir, header.Filename)) dst, err := os.Create(filepath.Join(config.UploadsDir, filename))
if err != nil { if err != nil {
logger.Error("Internal Server Error:", err) logger.Error("Internal Server Error:", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError) http.Error(w, "Internal Server Error", http.StatusInternalServerError)
@ -150,7 +174,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
return return
} }
logger.Infof("File uploaded successfully: %s", header.Filename) logger.Infof("File uploaded successfully: %s", filename)
fmt.Fprintf(w, "File uploaded successfully!") fmt.Fprintf(w, "File uploaded successfully!")
} }
@ -158,3 +182,11 @@ func checkAuthentication(r *http.Request) bool {
username, password, ok := r.BasicAuth() username, password, ok := r.BasicAuth()
return ok && username == config.Username && password == config.Password return ok && username == config.Username && password == config.Password
} }
func sanitizeFilename(filename string) string {
return strings.TrimSpace(filename)
}
func isValidFilename(filename string) bool {
return filenameRegex.MatchString(filename)
}