// sanitizePathForLog renders an untrusted request path safe to embed in a
// single log line.
//
// The path is escaped with strconv.Quote (surrounding quotes removed), so
// control characters, quotes, backslashes and invalid UTF-8 bytes appear as
// visible escape sequences instead of raw bytes that could forge log entries.
//
// Escaped output longer than maxLen bytes is shortened and suffixed with
// "...". The cut is placed on a token boundary: it never leaves half of an
// escape sequence (a dangling `\` or a partial `\x0`) and never splits a
// multi-byte UTF-8 rune, so the result may be a few bytes shorter than maxLen.
func sanitizePathForLog(path string) string {
	// Limit path length to prevent log flooding.
	const maxLen = 200

	quoted := strconv.Quote(path)
	// strconv.Quote always wraps its result in double quotes; drop them.
	escaped := quoted[1 : len(quoted)-1]
	if len(escaped) <= maxLen {
		return escaped
	}
	return escaped[:escapedCutPoint(escaped, maxLen)] + "..."
}

// escapedCutPoint returns the largest index <= limit at which escaped (the
// body of a strconv.Quote result) can be cut without splitting an escape
// sequence or a multi-byte rune. It scans at most limit bytes.
func escapedCutPoint(escaped string, limit int) int {
	if limit > len(escaped) {
		limit = len(escaped)
	}

	i := 0
	for i < len(escaped) {
		width := 1
		switch c := escaped[i]; {
		case c == '\\' && i+1 < len(escaped):
			// Escape widths produced by strconv.Quote: \xHH, \uHHHH,
			// \UHHHHHHHH, and two-byte forms such as \n, \\ and \".
			switch escaped[i+1] {
			case 'x':
				width = 4
			case 'u':
				width = 6
			case 'U':
				width = 10
			default:
				width = 2
			}
		case c >= 0x80:
			// Lead byte of a multi-byte rune: absorb its continuation bytes
			// (bit pattern 10xxxxxx).
			for i+width < len(escaped) && escaped[i+width]&0xC0 == 0x80 {
				width++
			}
		}
		if i+width > limit {
			return i
		}
		i += width
	}
	return i
}
@@ -179,7 +197,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha // Block directory listings if strings.HasSuffix(path, "/") { blockedRequests.WithLabelValues("directory_listing").Inc() - log.Printf("Blocked directory listing attempt: %s from %s", path, r.RemoteAddr) + log.Printf("Blocked directory listing attempt: %s from %s", sanitizePathForLog(path), r.RemoteAddr) http.NotFound(w, r) return } @@ -188,7 +206,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha for segment := range strings.SplitSeq(path, "/") { if strings.HasPrefix(segment, ".") { blockedRequests.WithLabelValues("dotfile").Inc() - log.Printf("Blocked dotfile access: %s from %s", path, r.RemoteAddr) + log.Printf("Blocked dotfile access: %s from %s", sanitizePathForLog(path), r.RemoteAddr) http.NotFound(w, r) return } @@ -199,7 +217,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha strings.HasSuffix(lower, ".bak") || strings.HasSuffix(lower, "~") { blockedRequests.WithLabelValues("sensitive_file").Inc() - log.Printf("Blocked sensitive file access: %s from %s", path, r.RemoteAddr) + log.Printf("Blocked sensitive file access: %s from %s", sanitizePathForLog(path), r.RemoteAddr) http.NotFound(w, r) return } @@ -209,7 +227,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha ext := strings.ToLower(filepath.Ext(path)) if ext != ".js" && ext != ".html" && ext != ".css" { blockedRequests.WithLabelValues("invalid_extension").Inc() - log.Printf("Blocked invalid extension: %s from %s", path, r.RemoteAddr) + log.Printf("Blocked invalid extension: %s from %s", sanitizePathForLog(path), r.RemoteAddr) http.NotFound(w, r) return } diff --git a/cmd/watchdog/root_test.go b/cmd/watchdog/root_test.go new file mode 100644 index 0000000..f434364 --- /dev/null +++ b/cmd/watchdog/root_test.go @@ -0,0 +1,116 @@ +package watchdog + +import ( + "strings" + "testing" +) + +func 
TestSanitizePathForLog(t *testing.T) { + tests := []struct { + name string + input string + want string + maxLen int // expected max length check + }{ + { + name: "normal path", + input: "/web/beacon.js", + want: "/web/beacon.js", + maxLen: 200, + }, + { + name: "path with newlines", + input: "/test\nmalicious", + want: `/test\nmalicious`, + maxLen: 200, + }, + { + name: "path with carriage return", + input: "/test\rmalicious", + want: `/test\rmalicious`, + maxLen: 200, + }, + { + name: "path with tabs", + input: "/test\tmalicious", + want: `/test\tmalicious`, + maxLen: 200, + }, + { + name: "path with null bytes", + input: "/test\x00malicious", + want: `/test\x00malicious`, + maxLen: 200, + }, + { + name: "path with quotes", + input: `/test"malicious`, + want: `/test\"malicious`, + maxLen: 200, + }, + { + name: "path with backslash", + input: `/test\malicious`, + want: `/test\\malicious`, + maxLen: 200, + }, + { + name: "control characters", + input: "/test\x01\x02\x1fmalicious", + want: `/test\x01\x02\x1fmalicious`, + maxLen: 200, + }, + { + name: "truncation at 200 chars", + input: "/" + strings.Repeat("a", 250), + want: "/" + strings.Repeat("a", 199) + "...", + maxLen: 203, // 200 chars + "..." 
= 203 + }, + { + name: "empty string", + input: "", + want: "", + maxLen: 200, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sanitizePathForLog(tt.input) + if got != tt.want { + t.Errorf("sanitizePathForLog(%q) = %q, want %q", tt.input, got, tt.want) + } + if len(got) > tt.maxLen { + t.Errorf("sanitizePathForLog(%q) length = %d, exceeds max %d", tt.input, len(got), tt.maxLen) + } + }) + } +} + +func TestSanitizePathForLog_LogInjectionPrevention(t *testing.T) { + // Log injection attempts should be neutralized + maliciousPaths := []string{ + "/api\nINFO: fake log entry", + "/test\r\nERROR: fake error", + "/.git/config\x00", // null byte injection + } + + for _, path := range maliciousPaths { + sanitized := sanitizePathForLog(path) + // Check that newlines are escaped, not literal + if strings.Contains(sanitized, "\n") || strings.Contains(sanitized, "\r") { + t.Errorf("sanitizePathForLog(%q) contains literal newlines: %q", path, sanitized) + } + // Check that null bytes are escaped + if strings.Contains(sanitized, "\x00") { + t.Errorf("sanitizePathForLog(%q) contains literal null byte: %q", path, sanitized) + } + } +} + +func BenchmarkSanitizePathForLog(b *testing.B) { + path := "/test/path/with\nnewlines\rand\ttabs" + for b.Loop() { + _ = sanitizePathForLog(path) + } +} diff --git a/internal/api/handler.go b/internal/api/handler.go index 1b20c85..fe4298d 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -1,7 +1,10 @@ package api import ( - "math/rand" + "crypto/rand" + "encoding/hex" + "fmt" + mrand "math/rand" "net" "net/http" "strings" @@ -23,7 +26,7 @@ type IngestionHandler struct { refRegistry *normalize.ReferrerRegistry metricsAgg *aggregate.MetricsAggregator rateLimiter *ratelimit.TokenBucket - rng *rand.Rand + rng *mrand.Rand trustedNetworks []*net.IPNet // pre-parsed CIDR networks } @@ -82,12 +85,19 @@ func NewIngestionHandler( refRegistry: refRegistry, metricsAgg: metricsAgg, rateLimiter: limiter, 
// generateRequestID creates a unique request ID for tracing.
//
// The ID is 8 bytes read from crypto/rand, hex-encoded into 16 lowercase
// hexadecimal characters. If the random source reports an error, the current
// time in nanoseconds is used instead, formatted as 16 zero-padded hex digits
// so that fallback IDs have the same shape as regular ones.
func generateRequestID() string {
	var raw [8]byte
	if _, err := rand.Read(raw[:]); err != nil {
		// Defensive only: recent Go releases document crypto/rand.Read as
		// never returning an error. Keep the fallback fixed-width hex so
		// consumers see one ID format.
		return fmt.Sprintf("%016x", uint64(time.Now().UnixNano()))
	}
	return hex.EncodeToString(raw[:])
}