diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 07b3aea..7f6119f 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -3,8 +3,21 @@ name: Build & Test with Go on: push: branches: [main] + paths: + - '**.go' + - 'go.mod' + - 'go.sum' + - '.github/workflows/go.yml' + - 'testdata/**' + pull_request: branches: [main] + paths: + - '**.go' + - 'go.mod' + - 'go.sum' + - '.github/workflows/go.yml' + - 'testdata/**' jobs: test: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 0390525..60c97d6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -39,7 +39,7 @@ jobs: GOARCH: ${{ matrix.goarch }} - name: Upload artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: watchdog-${{ matrix.goos }}-${{ matrix.goarch }} path: watchdog-* @@ -50,17 +50,17 @@ jobs: permissions: contents: write steps: - - uses: actions/checkout@v6 - - name: Download all artifacts uses: actions/download-artifact@v8 with: path: artifacts + pattern: watchdog-* + merge-multiple: true - name: Create GitHub Release uses: softprops/action-gh-release@v2 with: - files: artifacts/**/*.tar.gz artifacts/**/*.zip + files: artifacts/watchdog-* generate_release_notes: true env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/cmd/watchdog/main.go b/cmd/watchdog/main.go index 221328a..fd81b9e 100644 --- a/cmd/watchdog/main.go +++ b/cmd/watchdog/main.go @@ -25,18 +25,23 @@ type versionInfo struct { BuildDate string `json:"buildDate"` } +func getVersionInfo() versionInfo { + data, err := os.ReadFile("version.json") + if err != nil { + return versionInfo{} + } + var v versionInfo + if err := json.Unmarshal(data, &v); err != nil { + return versionInfo{} + } + return v +} + func getVersion() string { if version != "" { return version } - data, err := os.ReadFile("version.json") - if err != nil { - return "dev" - } - var v versionInfo - if err := json.Unmarshal(data, &v); err != nil { - return "dev" - } + v := getVersionInfo() if v.Version != "" { return v.Version } @@ -47,14 +52,7 @@ func getCommit() string { if commit != "" { return commit } - data, err := os.ReadFile("version.json") - if err != nil { - return "none" - } - var v versionInfo - if err := json.Unmarshal(data, &v); err != nil { - return "none" - } + v := getVersionInfo() if v.Commit != "" { return v.Commit } diff --git a/cmd/watchdog/root.go b/cmd/watchdog/root.go index 7e8ff22..96fdbd4 100644 --- a/cmd/watchdog/root.go +++ b/cmd/watchdog/root.go @@ -9,8 +9,10 @@ import ( "os" "os/signal" "path/filepath" + "strconv" "strings" "syscall" + "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -20,6 +22,7 @@ import ( "notashelf.dev/watchdog/internal/health" "notashelf.dev/watchdog/internal/limits" "notashelf.dev/watchdog/internal/normalize" + "notashelf.dev/watchdog/internal/ratelimit" ) func Run(cfg *config.Config) error { @@ -74,7 +77,7 @@ func Run(cfg *config.Config) error { // Setup routes mux := http.NewServeMux() - // Metrics endpoint with optional basic auth + // Metrics endpoint with optional basic auth and rate limiting metricsHandler := promhttp.HandlerFor(promRegistry, promhttp.HandlerOpts{ EnableOpenMetrics: true, }) @@ -87,6 +90,20 @@ func Run(cfg *config.Config) error { ) } + // Add rate limiting to metrics endpoint + if cfg.Limits.MaxMetricsPerMinute > 0 { + metricsRateLimiter := ratelimit.NewTokenBucket( + cfg.Limits.MaxMetricsPerMinute, + cfg.Limits.MaxMetricsPerMinute, + time.Minute, + ) + metricsHandler = rateLimitMiddleware(metricsHandler, metricsRateLimiter) + } + + // Add response size limit to metrics endpoint (10MB max) + const maxMetricsResponseSize = 10 * 1024 * 1024 // 10MB + metricsHandler = responseSizeLimitMiddleware(metricsHandler, maxMetricsResponseSize) + mux.Handle(cfg.Server.MetricsPath, metricsHandler) // Ingestion endpoint @@ -167,6 +184,69 @@ func basicAuth(next http.Handler, username, password string) http.Handler { }) } +// Wraps a handler with rate limiting +func rateLimitMiddleware(next http.Handler, limiter *ratelimit.TokenBucket) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !limiter.Allow() { + http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) +} + +// Wraps http.ResponseWriter to enforce max response size +type limitedResponseWriter struct { + http.ResponseWriter + maxSize int + written int + limitExceeded bool +} + +func (w *limitedResponseWriter) Write(p []byte) (int, error) { + if w.limitExceeded { + return 0, fmt.Errorf("response size limit exceeded") + } + + if w.written+len(p) > w.maxSize { + w.limitExceeded = true + w.Header().Set("X-Response-Truncated", "true") + http.Error(w.ResponseWriter, "Response size limit exceeded", http.StatusInternalServerError) + return 0, fmt.Errorf("response size limit exceeded: %d bytes", w.maxSize) + } + n, err := w.ResponseWriter.Write(p) + w.written += n + return n, err +} + +// Wraps a handler with response size limiting +func responseSizeLimitMiddleware(next http.Handler, maxSize int) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + limited := &limitedResponseWriter{ + ResponseWriter: w, + maxSize: maxSize, + } + next.ServeHTTP(limited, r) + }) +} + +// Sanitizes a path for logging to prevent log injection attacks. Uses `strconv.Quote` +// to properly escape control characters and special bytes. +func sanitizePathForLog(path string) string { + escaped := strconv.Quote(path) + if len(escaped) >= 2 && escaped[0] == '"' && escaped[len(escaped)-1] == '"' { + escaped = escaped[1 : len(escaped)-1] + } + + // Limit path length to prevent log flooding + const maxLen = 200 + if len(escaped) > maxLen { + return escaped[:maxLen] + "..." + } + + return escaped +} + // Creates a file server that only serves whitelisted files. Blocks dotfiles, .git, .env, etc. // TODO: I need to hook this up to eris somehow so I can just forward the paths that are being // scanned despite not being on a whitelist. Would be a good way of detecting scrapers, maybe. @@ -179,7 +259,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha // Block directory listings if strings.HasSuffix(path, "/") { blockedRequests.WithLabelValues("directory_listing").Inc() - log.Printf("Blocked directory listing attempt: %s from %s", path, r.RemoteAddr) + log.Printf("Blocked directory listing attempt: %s from %s", sanitizePathForLog(path), r.RemoteAddr) http.NotFound(w, r) return } @@ -188,7 +268,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha for segment := range strings.SplitSeq(path, "/") { if strings.HasPrefix(segment, ".") { blockedRequests.WithLabelValues("dotfile").Inc() - log.Printf("Blocked dotfile access: %s from %s", path, r.RemoteAddr) + log.Printf("Blocked dotfile access: %s from %s", sanitizePathForLog(path), r.RemoteAddr) http.NotFound(w, r) return } @@ -199,7 +279,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha strings.HasSuffix(lower, ".bak") || strings.HasSuffix(lower, "~") { blockedRequests.WithLabelValues("sensitive_file").Inc() - log.Printf("Blocked sensitive file access: %s from %s", path, r.RemoteAddr) + log.Printf("Blocked sensitive file access: %s from %s", sanitizePathForLog(path), r.RemoteAddr) http.NotFound(w, r) return } @@ -209,7 +289,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha ext := strings.ToLower(filepath.Ext(path)) if ext != ".js" && ext != ".html" && ext != ".css" { blockedRequests.WithLabelValues("invalid_extension").Inc() - log.Printf("Blocked invalid extension: %s from %s", path, r.RemoteAddr) + log.Printf("Blocked invalid extension: %s from %s", sanitizePathForLog(path), r.RemoteAddr) http.NotFound(w, r) return } diff --git a/cmd/watchdog/root_test.go b/cmd/watchdog/root_test.go new file mode 100644 index 0000000..f434364 --- /dev/null +++ b/cmd/watchdog/root_test.go @@ -0,0 +1,116 @@ +package watchdog + +import ( + "strings" + "testing" +) + +func TestSanitizePathForLog(t *testing.T) { + tests := []struct { + name string + input string + want string + maxLen int // expected max length check + }{ + { + name: "normal path", + input: "/web/beacon.js", + want: "/web/beacon.js", + maxLen: 200, + }, + { + name: "path with newlines", + input: "/test\nmalicious", + want: `/test\nmalicious`, + maxLen: 200, + }, + { + name: "path with carriage return", + input: "/test\rmalicious", + want: `/test\rmalicious`, + maxLen: 200, + }, + { + name: "path with tabs", + input: "/test\tmalicious", + want: `/test\tmalicious`, + maxLen: 200, + }, + { + name: "path with null bytes", + input: "/test\x00malicious", + want: `/test\x00malicious`, + maxLen: 200, + }, + { + name: "path with quotes", + input: `/test"malicious`, + want: `/test\"malicious`, + maxLen: 200, + }, + { + name: "path with backslash", + input: `/test\malicious`, + want: `/test\\malicious`, + maxLen: 200, + }, + { + name: "control characters", + input: "/test\x01\x02\x1fmalicious", + want: `/test\x01\x02\x1fmalicious`, + maxLen: 200, + }, + { + name: "truncation at 200 chars", + input: "/" + strings.Repeat("a", 250), + want: "/" + strings.Repeat("a", 199) + "...", + maxLen: 203, // 200 chars + "..." = 203 + }, + { + name: "empty string", + input: "", + want: "", + maxLen: 200, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sanitizePathForLog(tt.input) + if got != tt.want { + t.Errorf("sanitizePathForLog(%q) = %q, want %q", tt.input, got, tt.want) + } + if len(got) > tt.maxLen { + t.Errorf("sanitizePathForLog(%q) length = %d, exceeds max %d", tt.input, len(got), tt.maxLen) + } + }) + } +} + +func TestSanitizePathForLog_LogInjectionPrevention(t *testing.T) { + // Log injection attempts should be neutralized + maliciousPaths := []string{ + "/api\nINFO: fake log entry", + "/test\r\nERROR: fake error", + "/.git/config\x00", // null byte injection + } + + for _, path := range maliciousPaths { + sanitized := sanitizePathForLog(path) + // Check that newlines are escaped, not literal + if strings.Contains(sanitized, "\n") || strings.Contains(sanitized, "\r") { + t.Errorf("sanitizePathForLog(%q) contains literal newlines: %q", path, sanitized) + } + // Check that null bytes are escaped + if strings.Contains(sanitized, "\x00") { + t.Errorf("sanitizePathForLog(%q) contains literal null byte: %q", path, sanitized) + } + } +} + +func BenchmarkSanitizePathForLog(b *testing.B) { + path := "/test/path/with\nnewlines\rand\ttabs" + for b.Loop() { + _ = sanitizePathForLog(path) + } +} diff --git a/config.example.yaml b/config.example.yaml index c3ad980..87c06d8 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -63,6 +63,8 @@ limits: max_custom_events: 100 # Maximum events per minute (rate limiting, 0 = unlimited) max_events_per_minute: 10000 + # Maximum metrics endpoint requests per minute (rate limiting, 0 = unlimited, default: 30) + max_metrics_per_minute: 60 # Device classification breakpoints (screen width in pixels) device_breakpoints: @@ -85,6 +87,8 @@ security: - "*" # Or specific domains: ["https://example.com", "https://www.example.com"] # Basic authentication for /metrics endpoint + # Password can also be set via environment variable: + # - `$ export WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD=your-secret-password` metrics_auth: enabled: false username: "admin" diff --git a/docs/observability.md b/docs/observability.md index 1f3b9dd..4796d2b 100644 --- a/docs/observability.md +++ b/docs/observability.md @@ -6,12 +6,17 @@ Grafana. > [!IMPORTANT] > -> **Why you need Prometheus:** +> **Why you need a time-series database:** > > - Watchdog exposes _current state_ (counters, gauges) -> - Prometheus _scrapes periodically_ and _stores time-series data_ -> - Grafana _visualizes_ the historical data from Prometheus +> - A TSDB _scrapes periodically_ and _stores time-series data_ +> - Grafana _visualizes_ the historical data > - Grafana cannot directly scrape Prometheus `/metrics` endpoints +> +> **Compatible databases:** +> +> - [Prometheus](#prometheus-setup), +> - [VictoriaMetrics](#victoriametrics), or any Prometheus-compatible scraper ## Prometheus Setup @@ -124,7 +129,7 @@ For multiple Watchdog instances: datasources.settings.datasources = [{ name = "Prometheus"; type = "prometheus"; - url = "http://localhost:9090"; + url = "http://localhost:9090"; # Or "http://localhost:8428" for VictoriaMetrics isDefault = true; }]; }; @@ -228,7 +233,34 @@ sum by (instance) (rate(web_pageviews_total[5m])) ### VictoriaMetrics -Drop-in Prometheus replacement with better performance and compression: +VictoriaMetrics is a fast, cost-effective monitoring solution and time-series +database that is 100% compatible with Prometheus exposition format. Watchdog's +`/metrics` endpoint can be scraped directly by VictoriaMetrics without requiring +Prometheus. + +#### Direct Scraping (Recommended) + +VictoriaMetrics single-node mode can scrape Watchdog directly using standard +Prometheus scrape configuration: + +**Configuration file (`/etc/victoriametrics/scrape.yml`):** + +```yaml +scrape_configs: + - job_name: "watchdog" + static_configs: + - targets: ["localhost:8080"] + scrape_interval: 15s + metrics_path: /metrics +``` + +**Run VictoriaMetrics:** + +```bash +victoria-metrics -promscrape.config=/etc/victoriametrics/scrape.yml +``` + +**NixOS configuration:** ```nix { @@ -236,15 +268,84 @@ Drop-in Prometheus replacement with better performance and compression: enable = true; listenAddress = ":8428"; retentionPeriod = "12month"; + + # Define scrape configs directly. 'prometheusConfig' is the configuration for + # Prometheus-style metrics endpoints, which Watchdog exports. + prometheusConfig = { + scrape_configs = [ + { + job_name = "watchdog"; + scrape_interval = "15s"; + static_configs = [{ + targets = [ "localhost:8080" ]; # replace the port + }]; + } + ]; + }; + }; +} +``` + +#### Using `vmagent` + +Alternatively, for distributed setups or when you need more advanced features +like relabeling, you may use `vmagent`: + +```nix +{ + services.vmagent = { + enable = true; + remoteWriteUrl = "http://localhost:8428/api/v1/write"; + + prometheusConfig = { + scrape_configs = [ + { + job_name = "watchdog"; + static_configs = [{ + targets = [ "localhost:8080" ]; + }]; + } + ]; + }; }; - # Configure Prometheus to remote-write to VictoriaMetrics + services.victoriametrics = { + enable = true; + listenAddress = ":8428"; + }; +} +``` + +#### Prometheus Remote Write + +If you are migrating from Prometheus, or if you need PromQL compatibility, or if +you just really like using Prometheus for some inexplicable reason you may keep +Prometheus but use VictoriaMetrics to remote-write. + +```nix +{ services.prometheus = { enable = true; + port = 9090; + + scrapeConfigs = [ + { + job_name = "watchdog"; + static_configs = [{ + targets = [ "localhost:8080" ]; + }]; + } + ]; + remoteWrite = [{ url = "http://localhost:8428/api/v1/write"; }]; }; + + services.victoriametrics = { + enable = true; + listenAddress = ":8428"; + }; } ``` @@ -274,10 +375,10 @@ metrics: ## Monitoring the Monitoring -Monitor Prometheus itself: +Monitor your scraper: ```promql -# Prometheus scrape success rate +# Scrape success rate up{job="watchdog"} # Scrape duration @@ -287,14 +388,9 @@ scrape_duration_seconds{job="watchdog"} time() - timestamp(up{job="watchdog"}) ``` -## Additional Recommendations +For VictoriaMetrics, you can also monitor ingestion stats: -1. **Retention**: Set `--storage.tsdb.retention.time=30d` or longer based on - disk space -2. **Backups**: Back up `/var/lib/prometheus` periodically (or whatever your - state directory is) -3. **Alerting**: Configure Prometheus alerting rules for critical metrics -4. **High Availability**: Run multiple Prometheus instances with identical - configs -5. **Remote Storage**: For long-term storage, use Thanos, Cortex, or - VictoriaMetrics +```bash +# VM internal metrics +curl http://localhost:8428/metrics | grep vm_rows_inserted_total +``` diff --git a/internal/aggregate/custom_events.go b/internal/aggregate/custom_events.go index cc0e92a..4593c72 100644 --- a/internal/aggregate/custom_events.go +++ b/internal/aggregate/custom_events.go @@ -54,3 +54,20 @@ func (r *CustomEventRegistry) OverflowCount() int { defer r.mu.RUnlock() return r.overflowCount } + +// Contains checks if an event name exists in the registry. +func (r *CustomEventRegistry) Contains(eventName string) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + _, exists := r.events[eventName] + return exists +} + +// Count returns the number of unique events in the registry. +func (r *CustomEventRegistry) Count() int { + r.mu.RLock() + defer r.mu.RUnlock() + + return len(r.events) +} diff --git a/internal/aggregate/metrics.go b/internal/aggregate/metrics.go index 14ef690..49c1047 100644 --- a/internal/aggregate/metrics.go +++ b/internal/aggregate/metrics.go @@ -162,7 +162,11 @@ func (m *MetricsAggregator) RecordPageview(path, country, device, referrer, doma labels := prometheus.Labels{"path": sanitizeLabel(path)} if m.cfg.Site.Collect.Country { - labels["country"] = sanitizeLabel(country) + if country == "" { + labels["country"] = "unknown" + } else { + labels["country"] = sanitizeLabel(country) + } } if m.cfg.Site.Collect.Device { diff --git a/internal/aggregate/uniques.go b/internal/aggregate/uniques.go index c606ecd..c3c0955 100644 --- a/internal/aggregate/uniques.go +++ b/internal/aggregate/uniques.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "os" + "strings" "sync" "time" @@ -17,15 +18,18 @@ type UniquesEstimator struct { hll *hyperloglog.Sketch salt string rotation string // "daily" or "hourly" + saltKey string // cached time key to avoid regeneration mu sync.Mutex } // Creates a new unique visitor estimator func NewUniquesEstimator(rotation string) *UniquesEstimator { + now := time.Now() return &UniquesEstimator{ hll: hyperloglog.New(), - salt: generateSalt(time.Now(), rotation), + salt: generateSalt(now, rotation), rotation: rotation, + saltKey: getSaltKey(now, rotation), } } @@ -36,11 +40,13 @@ func (u *UniquesEstimator) Add(ip, userAgent string) { defer u.mu.Unlock() // Check if we need to rotate to a new period - currentSalt := generateSalt(time.Now(), u.rotation) - if currentSalt != u.salt { + now := time.Now() + currentKey := getSaltKey(now, u.rotation) + if currentKey != u.saltKey { // Reset HLL for new period u.hll = hyperloglog.New() - u.salt = currentSalt + u.salt = generateSaltFromKey(currentKey) + u.saltKey = currentKey } // Hash visitor with salt to prevent cross-period tracking @@ -55,24 +61,36 @@ func (u *UniquesEstimator) Estimate() uint64 { return u.hll.Estimate() } -// Generates a deterministic salt based on the rotation mode -// Daily: same day = same salt, different day = different salt -// Hourly: same hour = same salt, different hour = different salt -func generateSalt(t time.Time, rotation string) string { - var key string +// Returns the time-based key for salt generation without hashing +func getSaltKey(t time.Time, rotation string) string { if rotation == "hourly" { - key = t.UTC().Format("2006-01-02T15") - } else { - key = t.UTC().Format("2006-01-02") + return t.UTC().Format("2006-01-02T15") } + return t.UTC().Format("2006-01-02") +} + +// Creates a salt from a pre-computed key +func generateSaltFromKey(key string) string { h := sha256.Sum256([]byte("watchdog-salt-" + key)) return hex.EncodeToString(h[:]) } +// Generates a deterministic salt based on the rotation mode +// Daily: same day = same salt, different day = different salt +// Hourly: same hour = same salt, different hour = different salt +func generateSalt(t time.Time, rotation string) string { + return generateSaltFromKey(getSaltKey(t, rotation)) +} + // Creates a privacy-preserving hash of visitor identity func hashVisitor(ip, userAgent, salt string) string { - combined := ip + "|" + userAgent + "|" + salt - h := sha256.Sum256([]byte(combined)) + var sb strings.Builder + sb.WriteString(ip) + sb.WriteString("|") + sb.WriteString(userAgent) + sb.WriteString("|") + sb.WriteString(salt) + h := sha256.Sum256([]byte(sb.String())) return hex.EncodeToString(h[:]) } @@ -122,16 +140,20 @@ func (u *UniquesEstimator) Load(path string) error { } savedSalt := string(parts[0]) - currentSalt := generateSalt(time.Now(), u.rotation) + now := time.Now() + currentKey := getSaltKey(now, u.rotation) + currentSalt := generateSaltFromKey(currentKey) // Only restore if it's the same period if savedSalt == currentSalt { u.salt = savedSalt + u.saltKey = currentKey return u.hll.UnmarshalBinary(parts[1]) } // Different period, start fresh u.hll = hyperloglog.New() u.salt = currentSalt + u.saltKey = currentKey return nil } diff --git a/internal/api/event.go b/internal/api/event.go index 7cb828f..7f71ce9 100644 --- a/internal/api/event.go +++ b/internal/api/event.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "io" - "slices" "notashelf.dev/watchdog/internal/limits" ) @@ -40,40 +39,8 @@ func ParseEvent(body io.Reader) (*Event, error) { return &event, nil } -// Validate checks if the event is valid for the given domains -func (e *Event) Validate(allowedDomains []string) error { - if e.Domain == "" { - return fmt.Errorf("domain required") - } - - // Check if domain is in allowed list - allowed := slices.Contains(allowedDomains, e.Domain) - if !allowed { - return fmt.Errorf("domain not allowed") - } - - if e.Path == "" { - return fmt.Errorf("path required") - } - - if len(e.Path) > limits.MaxPathLen { - return fmt.Errorf("path too long") - } - - if len(e.Referrer) > limits.MaxRefLen { - return fmt.Errorf("referrer too long") - } - - // Validate screen width is in reasonable range - if e.Width < 0 || e.Width > limits.MaxWidth { - return fmt.Errorf("invalid width") - } - - return nil -} - -// ValidateWithMap checks if the event is valid using a domain map (O(1) lookup) -func (e *Event) ValidateWithMap(allowedDomains map[string]bool) error { +// Validate checks if the event is valid using a domain map +func (e *Event) Validate(allowedDomains map[string]bool) error { if e.Domain == "" { return fmt.Errorf("domain required") } diff --git a/internal/api/event_test.go b/internal/api/event_test.go index 0a075fe..f14d0e2 100644 --- a/internal/api/event_test.go +++ b/internal/api/event_test.go @@ -115,7 +115,7 @@ func TestValidateEvent(t *testing.T) { tests := []struct { name string event Event - domains []string + domains map[string]bool wantErr bool }{ { @@ -124,7 +124,7 @@ func TestValidateEvent(t *testing.T) { Domain: "example.com", Path: "/home", }, - domains: []string{"example.com"}, + domains: map[string]bool{"example.com": true}, wantErr: false, }, { @@ -134,7 +134,7 @@ func TestValidateEvent(t *testing.T) { Path: "/signup", Event: "signup", }, - domains: []string{"example.com"}, + domains: map[string]bool{"example.com": true}, wantErr: false, }, { @@ -143,7 +143,7 @@ func TestValidateEvent(t *testing.T) { Domain: "wrong.com", Path: "/home", }, - domains: []string{"example.com"}, + domains: map[string]bool{"example.com": true}, wantErr: true, }, { @@ -152,7 +152,7 @@ func TestValidateEvent(t *testing.T) { Domain: "", Path: "/home", }, - domains: []string{"example.com"}, + domains: map[string]bool{"example.com": true}, wantErr: true, }, { @@ -161,7 +161,7 @@ func TestValidateEvent(t *testing.T) { Domain: "example.com", Path: "", }, - domains: []string{"example.com"}, + domains: map[string]bool{"example.com": true}, wantErr: true, }, { @@ -170,7 +170,7 @@ func TestValidateEvent(t *testing.T) { Domain: "example.com", Path: "/" + strings.Repeat("a", 3000), }, - domains: []string{"example.com"}, + domains: map[string]bool{"example.com": true}, wantErr: true, }, { @@ -180,7 +180,7 @@ func TestValidateEvent(t *testing.T) { Path: "/home", Referrer: strings.Repeat("a", 3000), }, - domains: []string{"example.com"}, + domains: map[string]bool{"example.com": true}, wantErr: true, }, { @@ -189,7 +189,7 @@ func TestValidateEvent(t *testing.T) { Domain: "example.com", Path: "/" + strings.Repeat("a", 2000), }, - domains: []string{"example.com"}, + domains: map[string]bool{"example.com": true}, wantErr: false, }, { @@ -198,7 +198,7 @@ func TestValidateEvent(t *testing.T) { Domain: "site1.com", Path: "/home", }, - domains: []string{"site1.com", "site2.com"}, + domains: map[string]bool{"site1.com": true, "site2.com": true}, wantErr: false, }, { @@ -207,7 +207,7 @@ func TestValidateEvent(t *testing.T) { Domain: "site2.com", Path: "/about", }, - domains: []string{"site1.com", "site2.com"}, + domains: map[string]bool{"site1.com": true, "site2.com": true}, wantErr: false, }, { @@ -216,7 +216,7 @@ func TestValidateEvent(t *testing.T) { Domain: "site3.com", Path: "/home", }, - domains: []string{"site1.com", "site2.com"}, + domains: map[string]bool{"site1.com": true, "site2.com": true}, wantErr: true, }, } @@ -231,24 +231,7 @@ func TestValidateEvent(t *testing.T) { } } -func BenchmarkValidate_SliceLookup(b *testing.B) { - // Simulate multi-site with 50 domains - domains := make([]string, 50) - for i := range 50 { - domains[i] = strings.Repeat("site", i) + ".com" - } - - event := Event{ - Domain: domains[49], // Worst case - last in list - Path: "/test", - } - - for b.Loop() { - _ = event.Validate(domains) - } -} - -func BenchmarkValidate_MapLookup(b *testing.B) { +func BenchmarkValidate(b *testing.B) { // Simulate multi-site with 50 domains domainMap := make(map[string]bool, 50) for i := range 50 { @@ -256,11 +239,11 @@ func BenchmarkValidate_MapLookup(b *testing.B) { } event := Event{ - Domain: strings.Repeat("site", 49) + ".com", // any position + Domain: strings.Repeat("site", 49) + ".com", Path: "/test", } for b.Loop() { - _ = event.ValidateWithMap(domainMap) + _ = event.Validate(domainMap) } } diff --git a/internal/api/handler.go b/internal/api/handler.go index 4cb3905..ff068da 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -1,10 +1,14 @@ package api import ( - "math/rand" + "crypto/rand" + "encoding/hex" + "fmt" + mrand "math/rand" "net" "net/http" "strings" + "time" "notashelf.dev/watchdog/internal/aggregate" "notashelf.dev/watchdog/internal/config" @@ -14,14 +18,16 @@ import ( // Handles incoming analytics events type IngestionHandler struct { - cfg *config.Config - domainMap map[string]bool // O(1) domain validation - pathNorm *normalize.PathNormalizer - pathRegistry *aggregate.PathRegistry - refRegistry *normalize.ReferrerRegistry - metricsAgg *aggregate.MetricsAggregator - rateLimiter *ratelimit.TokenBucket - rng *rand.Rand + cfg *config.Config + domainMap map[string]bool // O(1) domain validation + corsOriginMap map[string]bool // O(1) CORS origin validation + pathNorm *normalize.PathNormalizer + pathRegistry *aggregate.PathRegistry + refRegistry *normalize.ReferrerRegistry + metricsAgg *aggregate.MetricsAggregator + rateLimiter *ratelimit.TokenBucket + rng *mrand.Rand + trustedNetworks []*net.IPNet // pre-parsed CIDR networks } // Creates a new ingestion handler @@ -47,19 +53,51 @@ func NewIngestionHandler( domainMap[domain] = true } + // Build CORS origin map for O(1) lookup + corsOriginMap := make(map[string]bool, len(cfg.Security.CORS.AllowedOrigins)) + for _, origin := range cfg.Security.CORS.AllowedOrigins { + corsOriginMap[origin] = true + } + + // Pre-parse trusted proxy CIDRs to avoid re-parsing on each request + trustedNetworks := make([]*net.IPNet, 0, len(cfg.Security.TrustedProxies)) + for _, cidr := range cfg.Security.TrustedProxies { + if _, network, err := net.ParseCIDR(cidr); err == nil { + trustedNetworks = append(trustedNetworks, network) + } else if ip := net.ParseIP(cidr); ip != nil { + // Single IP - create a /32 or /128 network + var mask net.IPMask + if ip.To4() != nil { + mask = net.CIDRMask(32, 32) + } else { + mask = net.CIDRMask(128, 128) + } + trustedNetworks = append(trustedNetworks, &net.IPNet{IP: ip, Mask: mask}) + } + } + return &IngestionHandler{ - cfg: cfg, - domainMap: domainMap, - pathNorm: pathNorm, - pathRegistry: pathRegistry, - refRegistry: refRegistry, - metricsAgg: metricsAgg, - rateLimiter: limiter, - rng: rand.New(rand.NewSource(42)), + cfg: cfg, + domainMap: domainMap, + corsOriginMap: corsOriginMap, + pathNorm: pathNorm, + pathRegistry: pathRegistry, + refRegistry: refRegistry, + metricsAgg: metricsAgg, + rateLimiter: limiter, + rng: mrand.New(mrand.NewSource(time.Now().UnixNano())), + trustedNetworks: trustedNetworks, } } func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Generate or extract request ID for tracing + requestID := r.Header.Get("X-Request-ID") + if requestID == "" { + requestID = generateRequestID() + } + w.Header().Set("X-Request-ID", requestID) + // Handle CORS preflight if r.Method == http.MethodOptions { h.handleCORS(w, r) @@ -103,8 +141,8 @@ func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Validate event via map lookup (also O(1)) - if err := event.ValidateWithMap(h.domainMap); err != nil { + // Validate event via map lookup + if err := event.Validate(h.domainMap); err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return } @@ -181,13 +219,8 @@ func (h *IngestionHandler) handleCORS(w http.ResponseWriter, r *http.Request) { } // Check if origin is allowed - allowed := false - for _, allowedOrigin := range h.cfg.Security.CORS.AllowedOrigins { - if allowedOrigin == "*" || allowedOrigin == origin { - allowed = true - break - } - } + // This uses map so that it's O(1) + allowed := h.corsOriginMap["*"] || h.corsOriginMap[origin] if allowed { if origin == "*" { @@ -213,13 +246,8 @@ func (h *IngestionHandler) extractIP(r *http.Request) string { // Check if we should trust proxy headers trustProxy := false - if len(h.cfg.Security.TrustedProxies) > 0 { - for _, trustedCIDR := range h.cfg.Security.TrustedProxies { - if h.ipInCIDR(remoteIP, trustedCIDR) { - trustProxy = true - break - } - } + if len(h.trustedNetworks) > 0 { + trustProxy = h.ipInNetworks(remoteIP, h.trustedNetworks) } // If not trusting proxy, return direct IP @@ -233,42 +261,43 @@ func (h *IngestionHandler) extractIP(r *http.Request) string { ips := strings.Split(xff, ",") for i := len(ips) - 1; i >= 0; i-- { ip := strings.TrimSpace(ips[i]) - if !h.ipInCIDR(ip, "0.0.0.0/0") { - continue + if testIP := net.ParseIP(ip); testIP != nil { + // Only accept this IP if it's NOT from a trusted proxy + if !h.ipInNetworks(ip, h.trustedNetworks) { + return ip + } } - return ip } } // Check X-Real-IP header if xri := r.Header.Get("X-Real-IP"); xri != "" { - return xri + // Validate the IP format and ensure it's not from a trusted proxy + if testIP := net.ParseIP(xri); testIP != nil { + if !h.ipInNetworks(xri, h.trustedNetworks) { + return xri + } + } } // Fall back to RemoteAddr return remoteIP } -// Checks if an IP address is within a CIDR range -func (h *IngestionHandler) ipInCIDR(ip, cidr string) bool { - // Parse the IP address +// Checks if an IP address is within any of the trusted networks +func (h *IngestionHandler) ipInNetworks(ip string, networks []*net.IPNet) bool { testIP := net.ParseIP(ip) if testIP == nil { return false } - // Parse the CIDR - _, network, err := net.ParseCIDR(cidr) - if err != nil { - // If it's not a CIDR, try as a single IP - cidrIP := net.ParseIP(cidr) - if cidrIP == nil { - return false + for _, network := range networks { + if network.Contains(testIP) { + return true } - return testIP.Equal(cidrIP) } - return network.Contains(testIP) + return false } // Classifies device using both screen width and User-Agent parsing @@ -311,3 +340,17 @@ func (h *IngestionHandler) classifyDevice(width int, userAgent string) string { return "unknown" } + +// Creates a unique request ID for tracing. +// Uses 8 bytes (64 bits) of randomness which produces 16 hex characters. +// 2^64 possible IDs (~18 quintillion) provides sufficient uniqueness for +// request tracing while keeping IDs reasonably short in logs and headers. +func generateRequestID() string { + // 8 bytes = 64 bits = 16 hex chars = 2^64 possible IDs + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + // Fallback to timestamp if crypto/rand fails + return fmt.Sprintf("%d", time.Now().UnixNano()) + } + return hex.EncodeToString(b) +} diff --git a/internal/api/handler_test.go b/internal/api/handler_test.go index cee1dbb..0ba35d1 100644 --- a/internal/api/handler_test.go +++ b/internal/api/handler_test.go @@ -218,6 +218,107 @@ func newTestHandler(cfg *config.Config) *IngestionHandler { return NewIngestionHandler(cfg, pathNorm, pathRegistry, refRegistry, metricsAgg) } +func TestExtractIP(t *testing.T) { + cfg := &config.Config{ + Site: config.SiteConfig{ + Domains: []string{"example.com"}, + }, + Limits: config.LimitsConfig{ + MaxPaths: 100, + MaxSources: 50, + }, + Security: config.SecurityConfig{ + TrustedProxies: []string{"10.0.0.0/8", "192.168.1.1"}, + }, + } + h := newTestHandler(cfg) + + tests := []struct { + name string + remoteAddr string + headers map[string]string + want string + }{ + { + name: "direct connection no proxy", + remoteAddr: "192.168.1.100:12345", + headers: map[string]string{}, + want: "192.168.1.100", + }, + { + name: "trusted proxy with X-Forwarded-For", + remoteAddr: "10.0.0.1:12345", + headers: map[string]string{ + "X-Forwarded-For": "203.0.113.1, 10.0.0.5", + }, + want: "203.0.113.1", + }, + { + name: "trusted proxy with X-Real-IP", + remoteAddr: "10.0.0.1:12345", + headers: map[string]string{ + "X-Real-IP": "203.0.113.2", + }, + want: "203.0.113.2", + }, + { + name: "X-Real-IP from trusted network should be ignored", + remoteAddr: "10.0.0.1:12345", + headers: map[string]string{ + "X-Real-IP": "10.0.0.50", // trusted network, should fall back + }, + want: "10.0.0.1", // falls back to remoteAddr + }, + { + name: "X-Real-IP invalid IP should be ignored", + remoteAddr: "10.0.0.1:12345", + headers: map[string]string{ + "X-Real-IP": "not-an-ip", + }, + want: "10.0.0.1", // falls back + }, + { + name: "untrusted proxy X-Forwarded-For ignored", + remoteAddr: "203.0.113.50:12345", + headers: map[string]string{ + "X-Forwarded-For": "1.2.3.4", + }, + want: "203.0.113.50", // uses remoteAddr, ignores header + }, + { + name: "untrusted proxy X-Real-IP ignored", + remoteAddr: "203.0.113.50:12345", + headers: map[string]string{ + "X-Real-IP": "1.2.3.4", + }, + want: "203.0.113.50", // uses remoteAddr, ignores header + }, + { + name: "X-Forwarded-For all trusted falls back", + remoteAddr: "10.0.0.1:12345", + headers: map[string]string{ + "X-Forwarded-For": "10.0.0.2, 10.0.0.3", + }, + want: "10.0.0.1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/api/event", nil) + req.RemoteAddr = tt.remoteAddr + for k, v := range tt.headers { + req.Header.Set(k, v) + } + + got := h.extractIP(req) + if got != tt.want { + t.Errorf("extractIP() = %q, want %q", got, tt.want) + } + }) + } +} + func TestClassifyDevice_UA(t *testing.T) { cfg := &config.Config{ Limits: config.LimitsConfig{ diff --git a/internal/config/config.go b/internal/config/config.go index 595306e..158b58f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -45,11 +45,12 @@ type PathConfig struct { // Cardinality limits type LimitsConfig struct { - MaxPaths int `yaml:"max_paths"` - MaxEventsPerMinute int `yaml:"max_events_per_minute"` - MaxSources int `yaml:"max_sources"` - MaxCustomEvents int `yaml:"max_custom_events"` - DeviceBreakpoints DeviceBreaks `yaml:"device_breakpoints"` + MaxPaths int `yaml:"max_paths"` + MaxEventsPerMinute int `yaml:"max_events_per_minute"` + MaxSources int `yaml:"max_sources"` + MaxCustomEvents int `yaml:"max_custom_events"` + DeviceBreakpoints DeviceBreaks `yaml:"device_breakpoints"` + MaxMetricsPerMinute int `yaml:"max_metrics_per_minute"` // rate limit for metrics endpoint } // Device classification breakpoints @@ -72,10 +73,11 @@ type CORSConfig struct { } // Authentication for metrics endpoint +// Password can be set via environment variable: WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD type AuthConfig struct { Enabled bool `yaml:"enabled"` Username string `yaml:"username"` - Password string `yaml:"password"` + Password string `yaml:"password"` // can use env var WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD } // Server endpoints @@ -149,6 +151,10 @@ func (c *Config) Validate() error { c.Limits.MaxCustomEvents = 100 // Default } + if c.Limits.MaxMetricsPerMinute <= 0 { + c.Limits.MaxMetricsPerMinute = 30 // Default: 30 requests per minute + } + if c.Site.Path.MaxSegments < 0 { return fmt.Errorf("site.path.max_segments cannot be negative") } diff --git a/internal/normalize/path.go b/internal/normalize/path.go index 9836b67..6dadd4f 100644 --- a/internal/normalize/path.go +++ b/internal/normalize/path.go @@ -48,38 +48,36 @@ func (n *PathNormalizer) Normalize(path string) string { path = "/" + path } + // Process segments in-place to minimize allocations // Split into segments, first element is *always* empty for paths starting with '/' segments := strings.Split(path, "/") - if len(segments) > 0 && segments[0] == "" { - segments = segments[1:] - } - // Remove empty segments (from double slashes) - filtered := make([]string, 0, len(segments)) - for _, seg := range segments { - if seg != "" { - filtered = append(filtered, seg) + // Process segments in a single pass: remove empty, resolve . and .. + writeIdx := 0 + for i := 0; i < len(segments); i++ { + seg := segments[i] + + // Skip empty segments (from double slashes or leading /) + if seg == "" { + continue } - } - segments = filtered - // Resolve . and .. segments - resolved := make([]string, 0, len(segments)) - for _, seg := range segments { if seg == "." { // Skip current directory continue } else if seg == ".." { // Go up one level if possible - if len(resolved) > 0 { - resolved = resolved[:len(resolved)-1] + if writeIdx > 0 { + writeIdx-- } // If already at root, skip .. } else { - resolved = append(resolved, seg) + // Keep this segment + segments[writeIdx] = seg + writeIdx++ } } - segments = resolved + segments = segments[:writeIdx] // Collapse numeric segments if n.cfg.CollapseNumericSegments { diff --git a/internal/normalize/referrer.go b/internal/normalize/referrer.go index df0189d..1b43423 100644 --- a/internal/normalize/referrer.go +++ b/internal/normalize/referrer.go @@ -1,6 +1,7 @@ package normalize import ( + "net" "net/url" "strings" @@ -21,34 +22,20 @@ func isInternalHost(hostname string) bool { return true } - // Private IPv4 ranges (RFC1918) - if strings.HasPrefix(hostname, "10.") || - strings.HasPrefix(hostname, "192.168.") || - strings.HasPrefix(hostname, "172.16.") || - strings.HasPrefix(hostname, "172.17.") || - strings.HasPrefix(hostname, "172.18.") || - strings.HasPrefix(hostname, "172.19.") || - strings.HasPrefix(hostname, "172.20.") || - strings.HasPrefix(hostname, "172.21.") || - strings.HasPrefix(hostname, "172.22.") || - strings.HasPrefix(hostname, "172.23.") || - strings.HasPrefix(hostname, "172.24.") || - strings.HasPrefix(hostname, "172.25.") || - strings.HasPrefix(hostname, "172.26.") || - strings.HasPrefix(hostname, "172.27.") || - strings.HasPrefix(hostname, "172.28.") || - strings.HasPrefix(hostname, "172.29.") || - strings.HasPrefix(hostname, "172.30.") || - strings.HasPrefix(hostname, "172.31.") { - return true - } - - // IPv6 loopback and local - if strings.HasPrefix(hostname, "::1") || - strings.HasPrefix(hostname, "fe80::") || - strings.HasPrefix(hostname, "fc00::") || - strings.HasPrefix(hostname, "fd00::") { - return true + // Check if hostname is an IP address + if ip := net.ParseIP(hostname); ip != nil { + // Private IPv4 ranges (RFC1918) + if ip.IsPrivate() { + return true + } + // Additional localhost checks for IP formats + if ip.IsLoopback() { + return true + } + // Link-local addresses + if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } } return false diff --git a/internal/normalize/referrer_registry.go b/internal/normalize/referrer_registry.go index daaf69c..4b7913c 100644 --- a/internal/normalize/referrer_registry.go +++ b/internal/normalize/referrer_registry.go @@ -56,3 +56,20 @@ func (r *ReferrerRegistry) OverflowCount() int { defer r.mu.RUnlock() return r.overflowCount } + +// Contains checks if a source exists in the registry. +func (r *ReferrerRegistry) Contains(source string) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + _, exists := r.sources[source] + return exists +} + +// Count returns the number of unique sources in the registry. +func (r *ReferrerRegistry) Count() int { + r.mu.RLock() + defer r.mu.RUnlock() + + return len(r.sources) +} diff --git a/internal/normalize/referrer_test.go b/internal/normalize/referrer_test.go index 04a8d0a..0fa2e21 100644 --- a/internal/normalize/referrer_test.go +++ b/internal/normalize/referrer_test.go @@ -108,6 +108,36 @@ func TestExtractReferrerDomain(t *testing.T) { siteDomain: "example.com", want: "internal", }, + { + name: "private IP 172.16.x (RFC1918)", + referrer: "http://172.16.0.1/page", + siteDomain: "example.com", + want: "internal", + }, + { + name: "private IP 172.31.x (RFC1918 upper bound)", + referrer: "http://172.31.255.1/page", + siteDomain: "example.com", + want: "internal", + }, + { + name: "private IP 172.20.x (middle of range)", + referrer: "http://172.20.50.100/page", + siteDomain: "example.com", + want: "internal", + }, + { + name: "public IP 172.15.x (just outside private range)", + referrer: "http://172.15.0.1/page", + siteDomain: "example.com", + want: "other", // not internal, but invalid TLD + }, + { + name: "public IP 172.32.x (just outside private range)", + referrer: "http://172.32.0.1/page", + siteDomain: "example.com", + want: "other", // not internal, but invalid TLD + }, } for _, tt := range tests { diff --git a/test/integration_test.go b/test/integration_test.go index 3bc7220..9121e43 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -195,6 +195,9 @@ func TestEndToEnd_GracefulShutdown(t *testing.T) { MaxSources: 50, MaxCustomEvents: 10, }, + Server: config.ServerConfig{ + StatePath: "/tmp/watchdog-test.state", + }, } pathRegistry := aggregate.NewPathRegistry(cfg.Limits.MaxPaths) @@ -223,12 +226,12 @@ func TestEndToEnd_GracefulShutdown(t *testing.T) { } // Verify state file was created - if _, err := os.Stat("/tmp/watchdog-hll.state"); os.IsNotExist(err) { + if _, err := os.Stat("/tmp/watchdog-test.state"); os.IsNotExist(err) { t.Error("HLL state file was not created") } // Cleanup - os.Remove("/tmp/watchdog-hll.state") + os.Remove("/tmp/watchdog-test.state") } func TestEndToEnd_InvalidRequests(t *testing.T) {