diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 7f6119f..07b3aea 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -3,21 +3,8 @@ name: Build & Test with Go on: push: branches: [main] - paths: - - '**.go' - - 'go.mod' - - 'go.sum' - - '.github/workflows/go.yml' - - 'testdata/**' - pull_request: branches: [main] - paths: - - '**.go' - - 'go.mod' - - 'go.sum' - - '.github/workflows/go.yml' - - 'testdata/**' jobs: test: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 60c97d6..0390525 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -39,7 +39,7 @@ jobs: GOARCH: ${{ matrix.goarch }} - name: Upload artifact - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@v4 with: name: watchdog-${{ matrix.goos }}-${{ matrix.goarch }} path: watchdog-* @@ -50,17 +50,17 @@ jobs: permissions: contents: write steps: + - uses: actions/checkout@v6 + - name: Download all artifacts uses: actions/download-artifact@v8 with: path: artifacts - pattern: watchdog-* - merge-multiple: true - name: Create GitHub Release uses: softprops/action-gh-release@v2 with: - files: artifacts/watchdog-* + files: artifacts/**/*.tar.gz artifacts/**/*.zip generate_release_notes: true env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/cmd/watchdog/main.go b/cmd/watchdog/main.go index fd81b9e..221328a 100644 --- a/cmd/watchdog/main.go +++ b/cmd/watchdog/main.go @@ -25,23 +25,18 @@ type versionInfo struct { BuildDate string `json:"buildDate"` } -func getVersionInfo() versionInfo { - data, err := os.ReadFile("version.json") - if err != nil { - return versionInfo{} - } - var v versionInfo - if err := json.Unmarshal(data, &v); err != nil { - return versionInfo{} - } - return v -} - func getVersion() string { if version != "" { return version } - v := getVersionInfo() + data, err := os.ReadFile("version.json") + if err != nil { + return "dev" + } + var v versionInfo + if err := json.Unmarshal(data, &v); err != nil { + return "dev" + } if v.Version != "" { return v.Version } @@ -52,7 +47,14 @@ func getCommit() string { if commit != "" { return commit } - v := getVersionInfo() + data, err := os.ReadFile("version.json") + if err != nil { + return "none" + } + var v versionInfo + if err := json.Unmarshal(data, &v); err != nil { + return "none" + } if v.Commit != "" { return v.Commit } diff --git a/cmd/watchdog/root.go b/cmd/watchdog/root.go index 96fdbd4..7e8ff22 100644 --- a/cmd/watchdog/root.go +++ b/cmd/watchdog/root.go @@ -9,10 +9,8 @@ import ( "os" "os/signal" "path/filepath" - "strconv" "strings" "syscall" - "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -22,7 +20,6 @@ import ( "notashelf.dev/watchdog/internal/health" "notashelf.dev/watchdog/internal/limits" "notashelf.dev/watchdog/internal/normalize" - "notashelf.dev/watchdog/internal/ratelimit" ) func Run(cfg *config.Config) error { @@ -77,7 +74,7 @@ func Run(cfg *config.Config) error { // Setup routes mux := http.NewServeMux() - // Metrics endpoint with optional basic auth and rate limiting + // Metrics endpoint with optional basic auth metricsHandler := promhttp.HandlerFor(promRegistry, promhttp.HandlerOpts{ EnableOpenMetrics: true, }) @@ -90,20 +87,6 @@ func Run(cfg *config.Config) error { ) } - // Add rate limiting to metrics endpoint - if cfg.Limits.MaxMetricsPerMinute > 0 { - metricsRateLimiter := ratelimit.NewTokenBucket( - cfg.Limits.MaxMetricsPerMinute, - cfg.Limits.MaxMetricsPerMinute, - time.Minute, - ) - metricsHandler = rateLimitMiddleware(metricsHandler, metricsRateLimiter) - } - - // Add response size limit to metrics endpoint (10MB max) - const maxMetricsResponseSize = 10 * 1024 * 1024 // 10MB - metricsHandler = responseSizeLimitMiddleware(metricsHandler, maxMetricsResponseSize) - mux.Handle(cfg.Server.MetricsPath, metricsHandler) // Ingestion endpoint @@ -184,69 +167,6 @@ func basicAuth(next http.Handler, username, password string) http.Handler { }) } -// Wraps a handler with rate limiting -func rateLimitMiddleware(next http.Handler, limiter *ratelimit.TokenBucket) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !limiter.Allow() { - http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) - return - } - next.ServeHTTP(w, r) - }) -} - -// Wraps http.ResponseWriter to enforce max response size -type limitedResponseWriter struct { - http.ResponseWriter - maxSize int - written int - limitExceeded bool -} - -func (w *limitedResponseWriter) Write(p []byte) (int, error) { - if w.limitExceeded { - return 0, fmt.Errorf("response size limit exceeded") - } - - if w.written+len(p) > w.maxSize { - w.limitExceeded = true - w.Header().Set("X-Response-Truncated", "true") - http.Error(w.ResponseWriter, "Response size limit exceeded", http.StatusInternalServerError) - return 0, fmt.Errorf("response size limit exceeded: %d bytes", w.maxSize) - } - n, err := w.ResponseWriter.Write(p) - w.written += n - return n, err -} - -// Wraps a handler with response size limiting -func responseSizeLimitMiddleware(next http.Handler, maxSize int) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - limited := &limitedResponseWriter{ - ResponseWriter: w, - maxSize: maxSize, - } - next.ServeHTTP(limited, r) - }) -} - -// Sanitizes a path for logging to prevent log injection attacks. Uses `strconv.Quote` -// to properly escape control characters and special bytes. -func sanitizePathForLog(path string) string { - escaped := strconv.Quote(path) - if len(escaped) >= 2 && escaped[0] == '"' && escaped[len(escaped)-1] == '"' { - escaped = escaped[1 : len(escaped)-1] - } - - // Limit path length to prevent log flooding - const maxLen = 200 - if len(escaped) > maxLen { - return escaped[:maxLen] + "..." - } - - return escaped -} - // Creates a file server that only serves whitelisted files. Blocks dotfiles, .git, .env, etc. // TODO: I need to hook this up to eris somehow so I can just forward the paths that are being // scanned despite not being on a whitelist. Would be a good way of detecting scrapers, maybe. @@ -259,7 +179,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha // Block directory listings if strings.HasSuffix(path, "/") { blockedRequests.WithLabelValues("directory_listing").Inc() - log.Printf("Blocked directory listing attempt: %s from %s", sanitizePathForLog(path), r.RemoteAddr) + log.Printf("Blocked directory listing attempt: %s from %s", path, r.RemoteAddr) http.NotFound(w, r) return } @@ -268,7 +188,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha for segment := range strings.SplitSeq(path, "/") { if strings.HasPrefix(segment, ".") { blockedRequests.WithLabelValues("dotfile").Inc() - log.Printf("Blocked dotfile access: %s from %s", sanitizePathForLog(path), r.RemoteAddr) + log.Printf("Blocked dotfile access: %s from %s", path, r.RemoteAddr) http.NotFound(w, r) return } @@ -279,7 +199,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha strings.HasSuffix(lower, ".bak") || strings.HasSuffix(lower, "~") { blockedRequests.WithLabelValues("sensitive_file").Inc() - log.Printf("Blocked sensitive file access: %s from %s", sanitizePathForLog(path), r.RemoteAddr) + log.Printf("Blocked sensitive file access: %s from %s", path, r.RemoteAddr) http.NotFound(w, r) return } @@ -289,7 +209,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha ext := strings.ToLower(filepath.Ext(path)) if ext != ".js" && ext != ".html" && ext != ".css" { blockedRequests.WithLabelValues("invalid_extension").Inc() - log.Printf("Blocked invalid extension: %s from %s", sanitizePathForLog(path), r.RemoteAddr) + log.Printf("Blocked invalid extension: %s from %s", path, r.RemoteAddr) http.NotFound(w, r) return } diff --git a/cmd/watchdog/root_test.go b/cmd/watchdog/root_test.go deleted file mode 100644 index f434364..0000000 --- a/cmd/watchdog/root_test.go +++ /dev/null @@ -1,116 +0,0 @@ -package watchdog - -import ( - "strings" - "testing" -) - -func TestSanitizePathForLog(t *testing.T) { - tests := []struct { - name string - input string - want string - maxLen int // expected max length check - }{ - { - name: "normal path", - input: "/web/beacon.js", - want: "/web/beacon.js", - maxLen: 200, - }, - { - name: "path with newlines", - input: "/test\nmalicious", - want: `/test\nmalicious`, - maxLen: 200, - }, - { - name: "path with carriage return", - input: "/test\rmalicious", - want: `/test\rmalicious`, - maxLen: 200, - }, - { - name: "path with tabs", - input: "/test\tmalicious", - want: `/test\tmalicious`, - maxLen: 200, - }, - { - name: "path with null bytes", - input: "/test\x00malicious", - want: `/test\x00malicious`, - maxLen: 200, - }, - { - name: "path with quotes", - input: `/test"malicious`, - want: `/test\"malicious`, - maxLen: 200, - }, - { - name: "path with backslash", - input: `/test\malicious`, - want: `/test\\malicious`, - maxLen: 200, - }, - { - name: "control characters", - input: "/test\x01\x02\x1fmalicious", - want: `/test\x01\x02\x1fmalicious`, - maxLen: 200, - }, - { - name: "truncation at 200 chars", - input: "/" + strings.Repeat("a", 250), - want: "/" + strings.Repeat("a", 199) + "...", - maxLen: 203, // 200 chars + "..." = 203 - }, - { - name: "empty string", - input: "", - want: "", - maxLen: 200, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := sanitizePathForLog(tt.input) - if got != tt.want { - t.Errorf("sanitizePathForLog(%q) = %q, want %q", tt.input, got, tt.want) - } - if len(got) > tt.maxLen { - t.Errorf("sanitizePathForLog(%q) length = %d, exceeds max %d", tt.input, len(got), tt.maxLen) - } - }) - } -} - -func TestSanitizePathForLog_LogInjectionPrevention(t *testing.T) { - // Log injection attempts should be neutralized - maliciousPaths := []string{ - "/api\nINFO: fake log entry", - "/test\r\nERROR: fake error", - "/.git/config\x00", // null byte injection - } - - for _, path := range maliciousPaths { - sanitized := sanitizePathForLog(path) - // Check that newlines are escaped, not literal - if strings.Contains(sanitized, "\n") || strings.Contains(sanitized, "\r") { - t.Errorf("sanitizePathForLog(%q) contains literal newlines: %q", path, sanitized) - } - // Check that null bytes are escaped - if strings.Contains(sanitized, "\x00") { - t.Errorf("sanitizePathForLog(%q) contains literal null byte: %q", path, sanitized) - } - } -} - -func BenchmarkSanitizePathForLog(b *testing.B) { - path := "/test/path/with\nnewlines\rand\ttabs" - for b.Loop() { - _ = sanitizePathForLog(path) - } -} diff --git a/config.example.yaml b/config.example.yaml index 87c06d8..c3ad980 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -63,8 +63,6 @@ limits: max_custom_events: 100 # Maximum events per minute (rate limiting, 0 = unlimited) max_events_per_minute: 10000 - # Maximum metrics endpoint requests per minute (rate limiting, 0 = unlimited, default: 30) - max_metrics_per_minute: 60 # Device classification breakpoints (screen width in pixels) device_breakpoints: @@ -87,8 +85,6 @@ security: - "*" # Or specific domains: ["https://example.com", "https://www.example.com"] # Basic authentication for /metrics endpoint - # Password can also be set via environment variable: - # - `$ export WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD=your-secret-password` metrics_auth: enabled: false username: "admin" diff --git a/docs/observability.md b/docs/observability.md index 4796d2b..1f3b9dd 100644 --- a/docs/observability.md +++ b/docs/observability.md @@ -6,17 +6,12 @@ Grafana. > [!IMPORTANT] > -> **Why you need a time-series database:** +> **Why you need Prometheus:** > > - Watchdog exposes _current state_ (counters, gauges) -> - A TSDB _scrapes periodically_ and _stores time-series data_ -> - Grafana _visualizes_ the historical data +> - Prometheus _scrapes periodically_ and _stores time-series data_ +> - Grafana _visualizes_ the historical data from Prometheus > - Grafana cannot directly scrape Prometheus `/metrics` endpoints -> -> **Compatible databases:** -> -> - [Prometheus](#prometheus-setup), -> - [VictoriaMetrics](#victoriametrics), or any Prometheus-compatible scraper ## Prometheus Setup @@ -129,7 +124,7 @@ For multiple Watchdog instances: datasources.settings.datasources = [{ name = "Prometheus"; type = "prometheus"; - url = "http://localhost:9090"; # Or "http://localhost:8428" for VictoriaMetrics + url = "http://localhost:9090"; isDefault = true; }]; }; @@ -233,34 +228,7 @@ sum by (instance) (rate(web_pageviews_total[5m])) ### VictoriaMetrics -VictoriaMetrics is a fast, cost-effective monitoring solution and time-series -database that is 100% compatible with Prometheus exposition format. Watchdog's -`/metrics` endpoint can be scraped directly by VictoriaMetrics without requiring -Prometheus. - -#### Direct Scraping (Recommended) - -VictoriaMetrics single-node mode can scrape Watchdog directly using standard -Prometheus scrape configuration: - -**Configuration file (`/etc/victoriametrics/scrape.yml`):** - -```yaml -scrape_configs: - - job_name: "watchdog" - static_configs: - - targets: ["localhost:8080"] - scrape_interval: 15s - metrics_path: /metrics -``` - -**Run VictoriaMetrics:** - -```bash -victoria-metrics -promscrape.config=/etc/victoriametrics/scrape.yml -``` - -**NixOS configuration:** +Drop-in Prometheus replacement with better performance and compression: ```nix { @@ -268,84 +236,15 @@ victoria-metrics -promscrape.config=/etc/victoriametrics/scrape.yml enable = true; listenAddress = ":8428"; retentionPeriod = "12month"; - - # Define scrape configs directly. 'prometheusConfig' is the configuration for - # Prometheus-style metrics endpoints, which Watchdog exports. - prometheusConfig = { - scrape_configs = [ - { - job_name = "watchdog"; - scrape_interval = "15s"; - static_configs = [{ - targets = [ "localhost:8080" ]; # replace the port - }]; - } - ]; - }; - }; -} -``` - -#### Using `vmagent` - -Alternatively, for distributed setups or when you need more advanced features -like relabeling, you may use `vmagent`: - -```nix -{ - services.vmagent = { - enable = true; - remoteWriteUrl = "http://localhost:8428/api/v1/write"; - - prometheusConfig = { - scrape_configs = [ - { - job_name = "watchdog"; - static_configs = [{ - targets = [ "localhost:8080" ]; - }]; - } - ]; - }; }; - services.victoriametrics = { - enable = true; - listenAddress = ":8428"; - }; -} -``` - -#### Prometheus Remote Write - -If you are migrating from Prometheus, or if you need PromQL compatibility, or if -you just really like using Prometheus for some inexplicable reason you may keep -Prometheus but use VictoriaMetrics to remote-write. - -```nix -{ + # Configure Prometheus to remote-write to VictoriaMetrics services.prometheus = { enable = true; - port = 9090; - - scrapeConfigs = [ - { - job_name = "watchdog"; - static_configs = [{ - targets = [ "localhost:8080" ]; - }]; - } - ]; - remoteWrite = [{ url = "http://localhost:8428/api/v1/write"; }]; }; - - services.victoriametrics = { - enable = true; - listenAddress = ":8428"; - }; } ``` @@ -375,10 +274,10 @@ metrics: ## Monitoring the Monitoring -Monitor your scraper: +Monitor Prometheus itself: ```promql -# Scrape success rate +# Prometheus scrape success rate up{job="watchdog"} # Scrape duration @@ -388,9 +287,14 @@ scrape_duration_seconds{job="watchdog"} time() - timestamp(up{job="watchdog"}) ``` -For VictoriaMetrics, you can also monitor ingestion stats: +## Additional Recommendations -```bash -# VM internal metrics -curl http://localhost:8428/metrics | grep vm_rows_inserted_total -``` +1. **Retention**: Set `--storage.tsdb.retention.time=30d` or longer based on + disk space +2. **Backups**: Back up `/var/lib/prometheus` periodically (or whatever your + state directory is) +3. **Alerting**: Configure Prometheus alerting rules for critical metrics +4. **High Availability**: Run multiple Prometheus instances with identical + configs +5. **Remote Storage**: For long-term storage, use Thanos, Cortex, or + VictoriaMetrics diff --git a/internal/aggregate/custom_events.go b/internal/aggregate/custom_events.go index 4593c72..cc0e92a 100644 --- a/internal/aggregate/custom_events.go +++ b/internal/aggregate/custom_events.go @@ -54,20 +54,3 @@ func (r *CustomEventRegistry) OverflowCount() int { defer r.mu.RUnlock() return r.overflowCount } - -// Contains checks if an event name exists in the registry. -func (r *CustomEventRegistry) Contains(eventName string) bool { - r.mu.RLock() - defer r.mu.RUnlock() - - _, exists := r.events[eventName] - return exists -} - -// Count returns the number of unique events in the registry. -func (r *CustomEventRegistry) Count() int { - r.mu.RLock() - defer r.mu.RUnlock() - - return len(r.events) -} diff --git a/internal/aggregate/metrics.go b/internal/aggregate/metrics.go index 49c1047..14ef690 100644 --- a/internal/aggregate/metrics.go +++ b/internal/aggregate/metrics.go @@ -162,11 +162,7 @@ func (m *MetricsAggregator) RecordPageview(path, country, device, referrer, doma labels := prometheus.Labels{"path": sanitizeLabel(path)} if m.cfg.Site.Collect.Country { - if country == "" { - labels["country"] = "unknown" - } else { - labels["country"] = sanitizeLabel(country) - } + labels["country"] = sanitizeLabel(country) } if m.cfg.Site.Collect.Device { diff --git a/internal/aggregate/uniques.go b/internal/aggregate/uniques.go index c3c0955..c606ecd 100644 --- a/internal/aggregate/uniques.go +++ b/internal/aggregate/uniques.go @@ -6,7 +6,6 @@ import ( "encoding/hex" "fmt" "os" - "strings" "sync" "time" @@ -18,18 +17,15 @@ type UniquesEstimator struct { hll *hyperloglog.Sketch salt string rotation string // "daily" or "hourly" - saltKey string // cached time key to avoid regeneration mu sync.Mutex } // Creates a new unique visitor estimator func NewUniquesEstimator(rotation string) *UniquesEstimator { - now := time.Now() return &UniquesEstimator{ hll: hyperloglog.New(), - salt: generateSalt(now, rotation), + salt: generateSalt(time.Now(), rotation), rotation: rotation, - saltKey: getSaltKey(now, rotation), } } @@ -40,13 +36,11 @@ func (u *UniquesEstimator) Add(ip, userAgent string) { defer u.mu.Unlock() // Check if we need to rotate to a new period - now := time.Now() - currentKey := getSaltKey(now, u.rotation) - if currentKey != u.saltKey { + currentSalt := generateSalt(time.Now(), u.rotation) + if currentSalt != u.salt { // Reset HLL for new period u.hll = hyperloglog.New() - u.salt = generateSaltFromKey(currentKey) - u.saltKey = currentKey + u.salt = currentSalt } // Hash visitor with salt to prevent cross-period tracking @@ -61,36 +55,24 @@ func (u *UniquesEstimator) Estimate() uint64 { return u.hll.Estimate() } -// Returns the time-based key for salt generation without hashing -func getSaltKey(t time.Time, rotation string) string { - if rotation == "hourly" { - return t.UTC().Format("2006-01-02T15") - } - return t.UTC().Format("2006-01-02") -} - -// Creates a salt from a pre-computed key -func generateSaltFromKey(key string) string { - h := sha256.Sum256([]byte("watchdog-salt-" + key)) - return hex.EncodeToString(h[:]) -} - // Generates a deterministic salt based on the rotation mode // Daily: same day = same salt, different day = different salt // Hourly: same hour = same salt, different hour = different salt func generateSalt(t time.Time, rotation string) string { - return generateSaltFromKey(getSaltKey(t, rotation)) + var key string + if rotation == "hourly" { + key = t.UTC().Format("2006-01-02T15") + } else { + key = t.UTC().Format("2006-01-02") + } + h := sha256.Sum256([]byte("watchdog-salt-" + key)) + return hex.EncodeToString(h[:]) } // Creates a privacy-preserving hash of visitor identity func hashVisitor(ip, userAgent, salt string) string { - var sb strings.Builder - sb.WriteString(ip) - sb.WriteString("|") - sb.WriteString(userAgent) - sb.WriteString("|") - sb.WriteString(salt) - h := sha256.Sum256([]byte(sb.String())) + combined := ip + "|" + userAgent + "|" + salt + h := sha256.Sum256([]byte(combined)) return hex.EncodeToString(h[:]) } @@ -140,20 +122,16 @@ func (u *UniquesEstimator) Load(path string) error { } savedSalt := string(parts[0]) - now := time.Now() - currentKey := getSaltKey(now, u.rotation) - currentSalt := generateSaltFromKey(currentKey) + currentSalt := generateSalt(time.Now(), u.rotation) // Only restore if it's the same period if savedSalt == currentSalt { u.salt = savedSalt - u.saltKey = currentKey return u.hll.UnmarshalBinary(parts[1]) } // Different period, start fresh u.hll = hyperloglog.New() u.salt = currentSalt - u.saltKey = currentKey return nil } diff --git a/internal/api/event.go b/internal/api/event.go index 7f71ce9..7cb828f 100644 --- a/internal/api/event.go +++ b/internal/api/event.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io" + "slices" "notashelf.dev/watchdog/internal/limits" ) @@ -39,8 +40,40 @@ func ParseEvent(body io.Reader) (*Event, error) { return &event, nil } -// Validate checks if the event is valid using a domain map -func (e *Event) Validate(allowedDomains map[string]bool) error { +// Validate checks if the event is valid for the given domains +func (e *Event) Validate(allowedDomains []string) error { + if e.Domain == "" { + return fmt.Errorf("domain required") + } + + // Check if domain is in allowed list + allowed := slices.Contains(allowedDomains, e.Domain) + if !allowed { + return fmt.Errorf("domain not allowed") + } + + if e.Path == "" { + return fmt.Errorf("path required") + } + + if len(e.Path) > limits.MaxPathLen { + return fmt.Errorf("path too long") + } + + if len(e.Referrer) > limits.MaxRefLen { + return fmt.Errorf("referrer too long") + } + + // Validate screen width is in reasonable range + if e.Width < 0 || e.Width > limits.MaxWidth { + return fmt.Errorf("invalid width") + } + + return nil +} + +// ValidateWithMap checks if the event is valid using a domain map (O(1) lookup) +func (e *Event) ValidateWithMap(allowedDomains map[string]bool) error { if e.Domain == "" { return fmt.Errorf("domain required") } diff --git a/internal/api/event_test.go b/internal/api/event_test.go index f14d0e2..0a075fe 100644 --- a/internal/api/event_test.go +++ b/internal/api/event_test.go @@ -115,7 +115,7 @@ func TestValidateEvent(t *testing.T) { tests := []struct { name string event Event - domains map[string]bool + domains []string wantErr bool }{ { @@ -124,7 +124,7 @@ func TestValidateEvent(t *testing.T) { Domain: "example.com", Path: "/home", }, - domains: map[string]bool{"example.com": true}, + domains: []string{"example.com"}, wantErr: false, }, { @@ -134,7 +134,7 @@ func TestValidateEvent(t *testing.T) { Path: "/signup", Event: "signup", }, - domains: map[string]bool{"example.com": true}, + domains: []string{"example.com"}, wantErr: false, }, { @@ -143,7 +143,7 @@ func TestValidateEvent(t *testing.T) { Domain: "wrong.com", Path: "/home", }, - domains: map[string]bool{"example.com": true}, + domains: []string{"example.com"}, wantErr: true, }, { @@ -152,7 +152,7 @@ func TestValidateEvent(t *testing.T) { Domain: "", Path: "/home", }, - domains: map[string]bool{"example.com": true}, + domains: []string{"example.com"}, wantErr: true, }, { @@ -161,7 +161,7 @@ func TestValidateEvent(t *testing.T) { Domain: "example.com", Path: "", }, - domains: map[string]bool{"example.com": true}, + domains: []string{"example.com"}, wantErr: true, }, { @@ -170,7 +170,7 @@ func TestValidateEvent(t *testing.T) { Domain: "example.com", Path: "/" + strings.Repeat("a", 3000), }, - domains: map[string]bool{"example.com": true}, + domains: []string{"example.com"}, wantErr: true, }, { @@ -180,7 +180,7 @@ func TestValidateEvent(t *testing.T) { Path: "/home", Referrer: strings.Repeat("a", 3000), }, - domains: map[string]bool{"example.com": true}, + domains: []string{"example.com"}, wantErr: true, }, { @@ -189,7 +189,7 @@ func TestValidateEvent(t *testing.T) { Domain: "example.com", Path: "/" + strings.Repeat("a", 2000), }, - domains: map[string]bool{"example.com": true}, + domains: []string{"example.com"}, wantErr: false, }, { @@ -198,7 +198,7 @@ func TestValidateEvent(t *testing.T) { Domain: "site1.com", Path: "/home", }, - domains: map[string]bool{"site1.com": true, "site2.com": true}, + domains: []string{"site1.com", "site2.com"}, wantErr: false, }, { @@ -207,7 +207,7 @@ func TestValidateEvent(t *testing.T) { Domain: "site2.com", Path: "/about", }, - domains: map[string]bool{"site1.com": true, "site2.com": true}, + domains: []string{"site1.com", "site2.com"}, wantErr: false, }, { @@ -216,7 +216,7 @@ func TestValidateEvent(t *testing.T) { Domain: "site3.com", Path: "/home", }, - domains: map[string]bool{"site1.com": true, "site2.com": true}, + domains: []string{"site1.com", "site2.com"}, wantErr: true, }, } @@ -231,7 +231,24 @@ func TestValidateEvent(t *testing.T) { } } -func BenchmarkValidate(b *testing.B) { +func BenchmarkValidate_SliceLookup(b *testing.B) { + // Simulate multi-site with 50 domains + domains := make([]string, 50) + for i := range 50 { + domains[i] = strings.Repeat("site", i) + ".com" + } + + event := Event{ + Domain: domains[49], // Worst case - last in list + Path: "/test", + } + + for b.Loop() { + _ = event.Validate(domains) + } +} + +func BenchmarkValidate_MapLookup(b *testing.B) { // Simulate multi-site with 50 domains domainMap := make(map[string]bool, 50) for i := range 50 { @@ -239,11 +256,11 @@ func BenchmarkValidate(b *testing.B) { } event := Event{ - Domain: strings.Repeat("site", 49) + ".com", + Domain: strings.Repeat("site", 49) + ".com", // any position Path: "/test", } for b.Loop() { - _ = event.Validate(domainMap) + _ = event.ValidateWithMap(domainMap) } } diff --git a/internal/api/handler.go b/internal/api/handler.go index ff068da..4cb3905 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -1,14 +1,10 @@ package api import ( - "crypto/rand" - "encoding/hex" - "fmt" - mrand "math/rand" + "math/rand" "net" "net/http" "strings" - "time" "notashelf.dev/watchdog/internal/aggregate" "notashelf.dev/watchdog/internal/config" @@ -18,16 +14,14 @@ import ( // Handles incoming analytics events type IngestionHandler struct { - cfg *config.Config - domainMap map[string]bool // O(1) domain validation - corsOriginMap map[string]bool // O(1) CORS origin validation - pathNorm *normalize.PathNormalizer - pathRegistry *aggregate.PathRegistry - refRegistry *normalize.ReferrerRegistry - metricsAgg *aggregate.MetricsAggregator - rateLimiter *ratelimit.TokenBucket - rng *mrand.Rand - trustedNetworks []*net.IPNet // pre-parsed CIDR networks + cfg *config.Config + domainMap map[string]bool // O(1) domain validation + pathNorm *normalize.PathNormalizer + pathRegistry *aggregate.PathRegistry + refRegistry *normalize.ReferrerRegistry + metricsAgg *aggregate.MetricsAggregator + rateLimiter *ratelimit.TokenBucket + rng *rand.Rand } // Creates a new ingestion handler @@ -53,51 +47,19 @@ func NewIngestionHandler( domainMap[domain] = true } - // Build CORS origin map for O(1) lookup - corsOriginMap := make(map[string]bool, len(cfg.Security.CORS.AllowedOrigins)) - for _, origin := range cfg.Security.CORS.AllowedOrigins { - corsOriginMap[origin] = true - } - - // Pre-parse trusted proxy CIDRs to avoid re-parsing on each request - trustedNetworks := make([]*net.IPNet, 0, len(cfg.Security.TrustedProxies)) - for _, cidr := range cfg.Security.TrustedProxies { - if _, network, err := net.ParseCIDR(cidr); err == nil { - trustedNetworks = append(trustedNetworks, network) - } else if ip := net.ParseIP(cidr); ip != nil { - // Single IP - create a /32 or /128 network - var mask net.IPMask - if ip.To4() != nil { - mask = net.CIDRMask(32, 32) - } else { - mask = net.CIDRMask(128, 128) - } - trustedNetworks = append(trustedNetworks, &net.IPNet{IP: ip, Mask: mask}) - } - } - return &IngestionHandler{ - cfg: cfg, - domainMap: domainMap, - corsOriginMap: corsOriginMap, - pathNorm: pathNorm, - pathRegistry: pathRegistry, - refRegistry: refRegistry, - metricsAgg: metricsAgg, - rateLimiter: limiter, - rng: mrand.New(mrand.NewSource(time.Now().UnixNano())), - trustedNetworks: trustedNetworks, + cfg: cfg, + domainMap: domainMap, + pathNorm: pathNorm, + pathRegistry: pathRegistry, + refRegistry: refRegistry, + metricsAgg: metricsAgg, + rateLimiter: limiter, + rng: rand.New(rand.NewSource(42)), } } func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Generate or extract request ID for tracing - requestID := r.Header.Get("X-Request-ID") - if requestID == "" { - requestID = generateRequestID() - } - w.Header().Set("X-Request-ID", requestID) - // Handle CORS preflight if r.Method == http.MethodOptions { h.handleCORS(w, r) @@ -141,8 +103,8 @@ func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Validate event via map lookup - if err := event.Validate(h.domainMap); err != nil { + // Validate event via map lookup (also O(1)) + if err := event.ValidateWithMap(h.domainMap); err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return } @@ -219,8 +181,13 @@ func (h *IngestionHandler) handleCORS(w http.ResponseWriter, r *http.Request) { } // Check if origin is allowed - // This uses map so that it's O(1) - allowed := h.corsOriginMap["*"] || h.corsOriginMap[origin] + allowed := false + for _, allowedOrigin := range h.cfg.Security.CORS.AllowedOrigins { + if allowedOrigin == "*" || allowedOrigin == origin { + allowed = true + break + } + } if allowed { if origin == "*" { @@ -246,8 +213,13 @@ func (h *IngestionHandler) extractIP(r *http.Request) string { // Check if we should trust proxy headers trustProxy := false - if len(h.trustedNetworks) > 0 { - trustProxy = h.ipInNetworks(remoteIP, h.trustedNetworks) + if len(h.cfg.Security.TrustedProxies) > 0 { + for _, trustedCIDR := range h.cfg.Security.TrustedProxies { + if h.ipInCIDR(remoteIP, trustedCIDR) { + trustProxy = true + break + } + } } // If not trusting proxy, return direct IP @@ -261,43 +233,42 @@ func (h *IngestionHandler) extractIP(r *http.Request) string { ips := strings.Split(xff, ",") for i := len(ips) - 1; i >= 0; i-- { ip := strings.TrimSpace(ips[i]) - if testIP := net.ParseIP(ip); testIP != nil { - // Only accept this IP if it's NOT from a trusted proxy - if !h.ipInNetworks(ip, h.trustedNetworks) { - return ip - } + if !h.ipInCIDR(ip, "0.0.0.0/0") { + continue } + return ip } } // Check X-Real-IP header if xri := r.Header.Get("X-Real-IP"); xri != "" { - // Validate the IP format and ensure it's not from a trusted proxy - if testIP := net.ParseIP(xri); testIP != nil { - if !h.ipInNetworks(xri, h.trustedNetworks) { - return xri - } - } + return xri } // Fall back to RemoteAddr return remoteIP } -// Checks if an IP address is within any of the trusted networks -func (h *IngestionHandler) ipInNetworks(ip string, networks []*net.IPNet) bool { +// Checks if an IP address is within a CIDR range +func (h *IngestionHandler) ipInCIDR(ip, cidr string) bool { + // Parse the IP address testIP := net.ParseIP(ip) if testIP == nil { return false } - for _, network := range networks { - if network.Contains(testIP) { - return true + // Parse the CIDR + _, network, err := net.ParseCIDR(cidr) + if err != nil { + // If it's not a CIDR, try as a single IP + cidrIP := net.ParseIP(cidr) + if cidrIP == nil { + return false } + return testIP.Equal(cidrIP) } - return false + return network.Contains(testIP) } // Classifies device using both screen width and User-Agent parsing @@ -340,17 +311,3 @@ func (h *IngestionHandler) classifyDevice(width int, userAgent string) string { return "unknown" } - -// Creates a unique request ID for tracing. -// Uses 8 bytes (64 bits) of randomness which produces 16 hex characters. -// 2^64 possible IDs (~18 quintillion) provides sufficient uniqueness for -// request tracing while keeping IDs reasonably short in logs and headers. -func generateRequestID() string { - // 8 bytes = 64 bits = 16 hex chars = 2^64 possible IDs - b := make([]byte, 8) - if _, err := rand.Read(b); err != nil { - // Fallback to timestamp if crypto/rand fails - return fmt.Sprintf("%d", time.Now().UnixNano()) - } - return hex.EncodeToString(b) -} diff --git a/internal/api/handler_test.go b/internal/api/handler_test.go index 0ba35d1..cee1dbb 100644 --- a/internal/api/handler_test.go +++ b/internal/api/handler_test.go @@ -218,107 +218,6 @@ func newTestHandler(cfg *config.Config) *IngestionHandler { return NewIngestionHandler(cfg, pathNorm, pathRegistry, refRegistry, metricsAgg) } -func TestExtractIP(t *testing.T) { - cfg := &config.Config{ - Site: config.SiteConfig{ - Domains: []string{"example.com"}, - }, - Limits: config.LimitsConfig{ - MaxPaths: 100, - MaxSources: 50, - }, - Security: config.SecurityConfig{ - TrustedProxies: []string{"10.0.0.0/8", "192.168.1.1"}, - }, - } - h := newTestHandler(cfg) - - tests := []struct { - name string - remoteAddr string - headers map[string]string - want string - }{ - { - name: "direct connection no proxy", - remoteAddr: "192.168.1.100:12345", - headers: map[string]string{}, - want: "192.168.1.100", - }, - { - name: "trusted proxy with X-Forwarded-For", - remoteAddr: "10.0.0.1:12345", - headers: map[string]string{ - "X-Forwarded-For": "203.0.113.1, 10.0.0.5", - }, - want: "203.0.113.1", - }, - { - name: "trusted proxy with X-Real-IP", - remoteAddr: "10.0.0.1:12345", - headers: map[string]string{ - "X-Real-IP": "203.0.113.2", - }, - want: "203.0.113.2", - }, - { - name: "X-Real-IP from trusted network should be ignored", - remoteAddr: "10.0.0.1:12345", - headers: map[string]string{ - "X-Real-IP": "10.0.0.50", // trusted network, should fall back - }, - want: "10.0.0.1", // falls back to remoteAddr - }, - { - name: "X-Real-IP invalid IP should be ignored", - remoteAddr: "10.0.0.1:12345", - headers: map[string]string{ - "X-Real-IP": "not-an-ip", - }, - want: "10.0.0.1", // falls back - }, - { - name: "untrusted proxy X-Forwarded-For ignored", - remoteAddr: "203.0.113.50:12345", - headers: map[string]string{ - "X-Forwarded-For": "1.2.3.4", - }, - want: "203.0.113.50", // uses remoteAddr, ignores header - }, - { - name: "untrusted proxy X-Real-IP ignored", - remoteAddr: "203.0.113.50:12345", - headers: map[string]string{ - "X-Real-IP": "1.2.3.4", - }, - want: "203.0.113.50", // uses remoteAddr, ignores header - }, - { - name: "X-Forwarded-For all trusted falls back", - remoteAddr: "10.0.0.1:12345", - headers: map[string]string{ - "X-Forwarded-For": "10.0.0.2, 10.0.0.3", - }, - want: "10.0.0.1", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("POST", "/api/event", nil) - req.RemoteAddr = tt.remoteAddr - for k, v := range tt.headers { - req.Header.Set(k, v) - } - - got := h.extractIP(req) - if got != tt.want { - t.Errorf("extractIP() = %q, want %q", got, tt.want) - } - }) - } -} - func TestClassifyDevice_UA(t *testing.T) { cfg := &config.Config{ Limits: config.LimitsConfig{ diff --git a/internal/config/config.go b/internal/config/config.go index 158b58f..595306e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -45,12 +45,11 @@ type PathConfig struct { // Cardinality limits type LimitsConfig struct { - MaxPaths int `yaml:"max_paths"` - MaxEventsPerMinute int `yaml:"max_events_per_minute"` - MaxSources int `yaml:"max_sources"` - MaxCustomEvents int `yaml:"max_custom_events"` - DeviceBreakpoints DeviceBreaks `yaml:"device_breakpoints"` - MaxMetricsPerMinute int `yaml:"max_metrics_per_minute"` // rate limit for metrics endpoint + MaxPaths int `yaml:"max_paths"` + MaxEventsPerMinute int `yaml:"max_events_per_minute"` + MaxSources int `yaml:"max_sources"` + MaxCustomEvents int `yaml:"max_custom_events"` + DeviceBreakpoints DeviceBreaks `yaml:"device_breakpoints"` } // Device classification breakpoints @@ -73,11 +72,10 @@ type CORSConfig struct { } // Authentication for metrics endpoint -// Password can be set via environment variable: WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD type AuthConfig struct { Enabled bool `yaml:"enabled"` Username string `yaml:"username"` - Password string `yaml:"password"` // can use env var WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD + Password string `yaml:"password"` } // Server endpoints @@ -151,10 +149,6 @@ func (c *Config) Validate() error { c.Limits.MaxCustomEvents = 100 // Default } - if c.Limits.MaxMetricsPerMinute <= 0 { - c.Limits.MaxMetricsPerMinute = 30 // Default: 30 requests per minute - } - if c.Site.Path.MaxSegments < 0 { return fmt.Errorf("site.path.max_segments cannot be negative") } diff --git a/internal/normalize/path.go b/internal/normalize/path.go index 6dadd4f..9836b67 100644 --- a/internal/normalize/path.go +++ b/internal/normalize/path.go @@ -48,36 +48,38 @@ func (n *PathNormalizer) Normalize(path string) string { path = "/" + path } - // Process segments in-place to minimize allocations // Split into segments, first element is *always* empty for paths starting with '/' segments := strings.Split(path, "/") + if len(segments) > 0 && segments[0] == "" { + segments = segments[1:] + } - // Process segments in a single pass: remove empty, resolve . and .. - writeIdx := 0 - for i := 0; i < len(segments); i++ { - seg := segments[i] - - // Skip empty segments (from double slashes or leading /) - if seg == "" { - continue + // Remove empty segments (from double slashes) + filtered := make([]string, 0, len(segments)) + for _, seg := range segments { + if seg != "" { + filtered = append(filtered, seg) } + } + segments = filtered + // Resolve . and .. segments + resolved := make([]string, 0, len(segments)) + for _, seg := range segments { if seg == "." { // Skip current directory continue } else if seg == ".." { // Go up one level if possible - if writeIdx > 0 { - writeIdx-- + if len(resolved) > 0 { + resolved = resolved[:len(resolved)-1] } // If already at root, skip .. } else { - // Keep this segment - segments[writeIdx] = seg - writeIdx++ + resolved = append(resolved, seg) } } - segments = segments[:writeIdx] + segments = resolved // Collapse numeric segments if n.cfg.CollapseNumericSegments { diff --git a/internal/normalize/referrer.go b/internal/normalize/referrer.go index 1b43423..df0189d 100644 --- a/internal/normalize/referrer.go +++ b/internal/normalize/referrer.go @@ -1,7 +1,6 @@ package normalize import ( - "net" "net/url" "strings" @@ -22,20 +21,34 @@ func isInternalHost(hostname string) bool { return true } - // Check if hostname is an IP address - if ip := net.ParseIP(hostname); ip != nil { - // Private IPv4 ranges (RFC1918) - if ip.IsPrivate() { - return true - } - // Additional localhost checks for IP formats - if ip.IsLoopback() { - return true - } - // Link-local addresses - if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { - return true - } + // Private IPv4 ranges (RFC1918) + if strings.HasPrefix(hostname, "10.") || + strings.HasPrefix(hostname, "192.168.") || + strings.HasPrefix(hostname, "172.16.") || + strings.HasPrefix(hostname, "172.17.") || + strings.HasPrefix(hostname, "172.18.") || + strings.HasPrefix(hostname, "172.19.") || + strings.HasPrefix(hostname, "172.20.") || + strings.HasPrefix(hostname, "172.21.") || + strings.HasPrefix(hostname, "172.22.") || + strings.HasPrefix(hostname, "172.23.") || + strings.HasPrefix(hostname, "172.24.") || + strings.HasPrefix(hostname, "172.25.") || + strings.HasPrefix(hostname, "172.26.") || + strings.HasPrefix(hostname, "172.27.") || + strings.HasPrefix(hostname, "172.28.") || + strings.HasPrefix(hostname, "172.29.") || + strings.HasPrefix(hostname, "172.30.") || + strings.HasPrefix(hostname, "172.31.") { + return true + } + + // IPv6 loopback and local + if strings.HasPrefix(hostname, "::1") || + strings.HasPrefix(hostname, "fe80::") || + strings.HasPrefix(hostname, "fc00::") || + strings.HasPrefix(hostname, "fd00::") { + return true } return false diff --git a/internal/normalize/referrer_registry.go b/internal/normalize/referrer_registry.go index 4b7913c..daaf69c 100644 --- a/internal/normalize/referrer_registry.go +++ b/internal/normalize/referrer_registry.go @@ -56,20 +56,3 @@ func (r *ReferrerRegistry) OverflowCount() int { defer r.mu.RUnlock() return r.overflowCount } - -// Contains checks if a source exists in the registry. -func (r *ReferrerRegistry) Contains(source string) bool { - r.mu.RLock() - defer r.mu.RUnlock() - - _, exists := r.sources[source] - return exists -} - -// Count returns the number of unique sources in the registry. -func (r *ReferrerRegistry) Count() int { - r.mu.RLock() - defer r.mu.RUnlock() - - return len(r.sources) -} diff --git a/internal/normalize/referrer_test.go b/internal/normalize/referrer_test.go index 0fa2e21..04a8d0a 100644 --- a/internal/normalize/referrer_test.go +++ b/internal/normalize/referrer_test.go @@ -108,36 +108,6 @@ func TestExtractReferrerDomain(t *testing.T) { siteDomain: "example.com", want: "internal", }, - { - name: "private IP 172.16.x (RFC1918)", - referrer: "http://172.16.0.1/page", - siteDomain: "example.com", - want: "internal", - }, - { - name: "private IP 172.31.x (RFC1918 upper bound)", - referrer: "http://172.31.255.1/page", - siteDomain: "example.com", - want: "internal", - }, - { - name: "private IP 172.20.x (middle of range)", - referrer: "http://172.20.50.100/page", - siteDomain: "example.com", - want: "internal", - }, - { - name: "public IP 172.15.x (just outside private range)", - referrer: "http://172.15.0.1/page", - siteDomain: "example.com", - want: "other", // not internal, but invalid TLD - }, - { - name: "public IP 172.32.x (just outside private range)", - referrer: "http://172.32.0.1/page", - siteDomain: "example.com", - want: "other", // not internal, but invalid TLD - }, } for _, tt := range tests { diff --git a/test/integration_test.go b/test/integration_test.go index 9121e43..3bc7220 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -195,9 +195,6 @@ func TestEndToEnd_GracefulShutdown(t *testing.T) { MaxSources: 50, MaxCustomEvents: 10, }, - Server: config.ServerConfig{ - StatePath: "/tmp/watchdog-test.state", - }, } pathRegistry := aggregate.NewPathRegistry(cfg.Limits.MaxPaths) @@ -226,12 +223,12 @@ func TestEndToEnd_GracefulShutdown(t *testing.T) { } // Verify state file was created - if _, err := os.Stat("/tmp/watchdog-test.state"); os.IsNotExist(err) { + if _, err := os.Stat("/tmp/watchdog-hll.state"); os.IsNotExist(err) { t.Error("HLL state file was not created") } // Cleanup - os.Remove("/tmp/watchdog-test.state") + os.Remove("/tmp/watchdog-hll.state") } func TestEndToEnd_InvalidRequests(t *testing.T) {