Compare commits

..

No commits in common. "main" and "v1.0.0" have entirely different histories.

20 changed files with 227 additions and 712 deletions

View file

@ -3,21 +3,8 @@ name: Build & Test with Go
on: on:
push: push:
branches: [main] branches: [main]
paths:
- '**.go'
- 'go.mod'
- 'go.sum'
- '.github/workflows/go.yml'
- 'testdata/**'
pull_request: pull_request:
branches: [main] branches: [main]
paths:
- '**.go'
- 'go.mod'
- 'go.sum'
- '.github/workflows/go.yml'
- 'testdata/**'
jobs: jobs:
test: test:

View file

@ -39,7 +39,7 @@ jobs:
GOARCH: ${{ matrix.goarch }} GOARCH: ${{ matrix.goarch }}
- name: Upload artifact - name: Upload artifact
uses: actions/upload-artifact@v7 uses: actions/upload-artifact@v4
with: with:
name: watchdog-${{ matrix.goos }}-${{ matrix.goarch }} name: watchdog-${{ matrix.goos }}-${{ matrix.goarch }}
path: watchdog-* path: watchdog-*
@ -50,17 +50,17 @@ jobs:
permissions: permissions:
contents: write contents: write
steps: steps:
- uses: actions/checkout@v6
- name: Download all artifacts - name: Download all artifacts
uses: actions/download-artifact@v8 uses: actions/download-artifact@v8
with: with:
path: artifacts path: artifacts
pattern: watchdog-*
merge-multiple: true
- name: Create GitHub Release - name: Create GitHub Release
uses: softprops/action-gh-release@v2 uses: softprops/action-gh-release@v2
with: with:
files: artifacts/watchdog-* files: artifacts/**/*.tar.gz artifacts/**/*.zip
generate_release_notes: true generate_release_notes: true
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View file

@ -25,23 +25,18 @@ type versionInfo struct {
BuildDate string `json:"buildDate"` BuildDate string `json:"buildDate"`
} }
// getVersionInfo loads build metadata from the "version.json" file, resolved
// relative to the process working directory.
//
// This is a best-effort lookup: if the file cannot be read or does not contain
// valid JSON, the zero-value versionInfo is returned and no error is reported.
func getVersionInfo() versionInfo {
	raw, readErr := os.ReadFile("version.json")
	if readErr != nil {
		return versionInfo{}
	}

	var info versionInfo
	if json.Unmarshal(raw, &info) != nil {
		// Discard anything that was partially decoded.
		return versionInfo{}
	}
	return info
}
func getVersion() string { func getVersion() string {
if version != "" { if version != "" {
return version return version
} }
v := getVersionInfo() data, err := os.ReadFile("version.json")
if err != nil {
return "dev"
}
var v versionInfo
if err := json.Unmarshal(data, &v); err != nil {
return "dev"
}
if v.Version != "" { if v.Version != "" {
return v.Version return v.Version
} }
@ -52,7 +47,14 @@ func getCommit() string {
if commit != "" { if commit != "" {
return commit return commit
} }
v := getVersionInfo() data, err := os.ReadFile("version.json")
if err != nil {
return "none"
}
var v versionInfo
if err := json.Unmarshal(data, &v); err != nil {
return "none"
}
if v.Commit != "" { if v.Commit != "" {
return v.Commit return v.Commit
} }

View file

@ -9,10 +9,8 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"syscall" "syscall"
"time"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
@ -22,7 +20,6 @@ import (
"notashelf.dev/watchdog/internal/health" "notashelf.dev/watchdog/internal/health"
"notashelf.dev/watchdog/internal/limits" "notashelf.dev/watchdog/internal/limits"
"notashelf.dev/watchdog/internal/normalize" "notashelf.dev/watchdog/internal/normalize"
"notashelf.dev/watchdog/internal/ratelimit"
) )
func Run(cfg *config.Config) error { func Run(cfg *config.Config) error {
@ -77,7 +74,7 @@ func Run(cfg *config.Config) error {
// Setup routes // Setup routes
mux := http.NewServeMux() mux := http.NewServeMux()
// Metrics endpoint with optional basic auth and rate limiting // Metrics endpoint with optional basic auth
metricsHandler := promhttp.HandlerFor(promRegistry, promhttp.HandlerOpts{ metricsHandler := promhttp.HandlerFor(promRegistry, promhttp.HandlerOpts{
EnableOpenMetrics: true, EnableOpenMetrics: true,
}) })
@ -90,20 +87,6 @@ func Run(cfg *config.Config) error {
) )
} }
// Add rate limiting to metrics endpoint
if cfg.Limits.MaxMetricsPerMinute > 0 {
metricsRateLimiter := ratelimit.NewTokenBucket(
cfg.Limits.MaxMetricsPerMinute,
cfg.Limits.MaxMetricsPerMinute,
time.Minute,
)
metricsHandler = rateLimitMiddleware(metricsHandler, metricsRateLimiter)
}
// Add response size limit to metrics endpoint (10MB max)
const maxMetricsResponseSize = 10 * 1024 * 1024 // 10MB
metricsHandler = responseSizeLimitMiddleware(metricsHandler, maxMetricsResponseSize)
mux.Handle(cfg.Server.MetricsPath, metricsHandler) mux.Handle(cfg.Server.MetricsPath, metricsHandler)
// Ingestion endpoint // Ingestion endpoint
@ -184,69 +167,6 @@ func basicAuth(next http.Handler, username, password string) http.Handler {
}) })
} }
// rateLimitMiddleware guards next with limiter. Each request asks the limiter
// for permission via Allow; a refused request is answered with
// 429 Too Many Requests and never reaches next.
func rateLimitMiddleware(next http.Handler, limiter *ratelimit.TokenBucket) http.Handler {
	guarded := func(w http.ResponseWriter, r *http.Request) {
		if limiter.Allow() {
			next.ServeHTTP(w, r)
			return
		}
		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
	}
	return http.HandlerFunc(guarded)
}
// limitedResponseWriter wraps an http.ResponseWriter and refuses to emit more
// than maxSize bytes of response body.
type limitedResponseWriter struct {
	http.ResponseWriter
	maxSize       int  // maximum number of body bytes allowed
	written       int  // body bytes successfully written so far
	limitExceeded bool // set once a write has been rejected
	headerSent    bool // set once the status line/headers have been committed
}

// WriteHeader records that the response header has been committed and then
// forwards the status code to the wrapped writer.
func (w *limitedResponseWriter) WriteHeader(statusCode int) {
	w.headerSent = true
	w.ResponseWriter.WriteHeader(statusCode)
}

// Write forwards p to the wrapped writer unless doing so would push the total
// body size past maxSize.
//
// When the limit would be exceeded, nothing from p is written and an error is
// returned; every later Write fails as well. If the header has not been sent
// yet, the client receives a 500 response carrying an
// "X-Response-Truncated: true" header. If the header (and possibly part of the
// body) is already on the wire, the status and headers can no longer be
// changed, so no error page is appended to the partial body.
func (w *limitedResponseWriter) Write(p []byte) (int, error) {
	if w.limitExceeded {
		return 0, fmt.Errorf("response size limit exceeded")
	}

	if w.written+len(p) > w.maxSize {
		w.limitExceeded = true
		if !w.headerSent {
			w.headerSent = true
			w.Header().Set("X-Response-Truncated", "true")
			// Use the wrapped writer directly: going through w would hit the
			// limitExceeded guard above and drop the error body.
			http.Error(w.ResponseWriter, "Response size limit exceeded", http.StatusInternalServerError)
		}
		return 0, fmt.Errorf("response size limit exceeded: %d bytes", w.maxSize)
	}

	// The first body write implicitly commits the header.
	w.headerSent = true
	n, err := w.ResponseWriter.Write(p)
	w.written += n
	return n, err
}
// responseSizeLimitMiddleware runs next behind a limitedResponseWriter so that
// the response body it produces is capped at maxSize bytes. A fresh wrapper is
// created for every request.
func responseSizeLimitMiddleware(next http.Handler, maxSize int) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(&limitedResponseWriter{ResponseWriter: w, maxSize: maxSize}, r)
	}
	return http.HandlerFunc(serve)
}
// Sanitizes a path for logging to prevent log injection attacks. Uses `strconv.Quote`
// to escape control characters and special bytes, then strips the surrounding quotes.
//
// Output longer than 200 bytes is shortened and suffixed with "..." to prevent
// log flooding. The cut is always made on a token boundary, so it never splits
// a multi-byte UTF-8 character or an escape sequence such as `\n` or `\x00`;
// the kept prefix may therefore be a few bytes shorter than 200.
func sanitizePathForLog(path string) string {
	escaped := strconv.Quote(path)
	if len(escaped) >= 2 && escaped[0] == '"' && escaped[len(escaped)-1] == '"' {
		escaped = escaped[1 : len(escaped)-1]
	}

	// Limit path length to prevent log flooding
	const maxLen = 200
	if len(escaped) <= maxLen {
		return escaped
	}

	// Find the last token start at or before maxLen. A token is either one
	// (possibly multi-byte) rune or one complete backslash escape.
	cut := 0
	skip := 0 // runes left in the escape sequence currently being skipped
	for i, r := range escaped {
		if skip > 0 {
			skip--
			continue
		}
		if i > maxLen {
			break
		}
		cut = i
		if r == '\\' && i+1 < len(escaped) {
			// strconv.Quote emits \xNN, \uNNNN, \UNNNNNNNN, or a two-character
			// escape (\n, \t, \\, \", ...). All escape characters are ASCII.
			switch escaped[i+1] {
			case 'x':
				skip = 3
			case 'u':
				skip = 5
			case 'U':
				skip = 9
			default:
				skip = 1
			}
		}
	}
	return escaped[:cut] + "..."
}
// Creates a file server that only serves whitelisted files. Blocks dotfiles, .git, .env, etc. // Creates a file server that only serves whitelisted files. Blocks dotfiles, .git, .env, etc.
// TODO: I need to hook this up to eris somehow so I can just forward the paths that are being // TODO: I need to hook this up to eris somehow so I can just forward the paths that are being
// scanned despite not being on a whitelist. Would be a good way of detecting scrapers, maybe. // scanned despite not being on a whitelist. Would be a good way of detecting scrapers, maybe.
@ -259,7 +179,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha
// Block directory listings // Block directory listings
if strings.HasSuffix(path, "/") { if strings.HasSuffix(path, "/") {
blockedRequests.WithLabelValues("directory_listing").Inc() blockedRequests.WithLabelValues("directory_listing").Inc()
log.Printf("Blocked directory listing attempt: %s from %s", sanitizePathForLog(path), r.RemoteAddr) log.Printf("Blocked directory listing attempt: %s from %s", path, r.RemoteAddr)
http.NotFound(w, r) http.NotFound(w, r)
return return
} }
@ -268,7 +188,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha
for segment := range strings.SplitSeq(path, "/") { for segment := range strings.SplitSeq(path, "/") {
if strings.HasPrefix(segment, ".") { if strings.HasPrefix(segment, ".") {
blockedRequests.WithLabelValues("dotfile").Inc() blockedRequests.WithLabelValues("dotfile").Inc()
log.Printf("Blocked dotfile access: %s from %s", sanitizePathForLog(path), r.RemoteAddr) log.Printf("Blocked dotfile access: %s from %s", path, r.RemoteAddr)
http.NotFound(w, r) http.NotFound(w, r)
return return
} }
@ -279,7 +199,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha
strings.HasSuffix(lower, ".bak") || strings.HasSuffix(lower, ".bak") ||
strings.HasSuffix(lower, "~") { strings.HasSuffix(lower, "~") {
blockedRequests.WithLabelValues("sensitive_file").Inc() blockedRequests.WithLabelValues("sensitive_file").Inc()
log.Printf("Blocked sensitive file access: %s from %s", sanitizePathForLog(path), r.RemoteAddr) log.Printf("Blocked sensitive file access: %s from %s", path, r.RemoteAddr)
http.NotFound(w, r) http.NotFound(w, r)
return return
} }
@ -289,7 +209,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha
ext := strings.ToLower(filepath.Ext(path)) ext := strings.ToLower(filepath.Ext(path))
if ext != ".js" && ext != ".html" && ext != ".css" { if ext != ".js" && ext != ".html" && ext != ".css" {
blockedRequests.WithLabelValues("invalid_extension").Inc() blockedRequests.WithLabelValues("invalid_extension").Inc()
log.Printf("Blocked invalid extension: %s from %s", sanitizePathForLog(path), r.RemoteAddr) log.Printf("Blocked invalid extension: %s from %s", path, r.RemoteAddr)
http.NotFound(w, r) http.NotFound(w, r)
return return
} }

View file

@ -1,116 +0,0 @@
package watchdog
import (
"strings"
"testing"
)
// TestSanitizePathForLog runs sanitizePathForLog over a table of inputs and
// checks both the exact escaped output and an upper bound on its length.
func TestSanitizePathForLog(t *testing.T) {
	type testCase struct {
		name   string
		input  string
		want   string
		maxLen int // upper bound on the length of the sanitized output
	}

	cases := []testCase{
		{"normal path", "/web/beacon.js", "/web/beacon.js", 200},
		{"path with newlines", "/test\nmalicious", `/test\nmalicious`, 200},
		{"path with carriage return", "/test\rmalicious", `/test\rmalicious`, 200},
		{"path with tabs", "/test\tmalicious", `/test\tmalicious`, 200},
		{"path with null bytes", "/test\x00malicious", `/test\x00malicious`, 200},
		{"path with quotes", `/test"malicious`, `/test\"malicious`, 200},
		{"path with backslash", `/test\malicious`, `/test\\malicious`, 200},
		{"control characters", "/test\x01\x02\x1fmalicious", `/test\x01\x02\x1fmalicious`, 200},
		// 200 chars + "..." = 203
		{"truncation at 200 chars", "/" + strings.Repeat("a", 250), "/" + strings.Repeat("a", 199) + "...", 203},
		{"empty string", "", "", 200},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := sanitizePathForLog(tc.input)
			if got != tc.want {
				t.Errorf("sanitizePathForLog(%q) = %q, want %q", tc.input, got, tc.want)
			}
			if n := len(got); n > tc.maxLen {
				t.Errorf("sanitizePathForLog(%q) length = %d, exceeds max %d", tc.input, n, tc.maxLen)
			}
		})
	}
}
// TestSanitizePathForLog_LogInjectionPrevention feeds typical log-injection
// payloads through sanitizePathForLog and verifies that no literal newline,
// carriage return, or NUL byte survives in the output.
func TestSanitizePathForLog_LogInjectionPrevention(t *testing.T) {
	payloads := []string{
		"/api\nINFO: fake log entry",
		"/test\r\nERROR: fake error",
		"/.git/config\x00", // null byte injection
	}

	for _, raw := range payloads {
		got := sanitizePathForLog(raw)

		// Line breaks must appear only in escaped form.
		if strings.ContainsAny(got, "\n\r") {
			t.Errorf("sanitizePathForLog(%q) contains literal newlines: %q", raw, got)
		}

		// NUL bytes must appear only in escaped form.
		if strings.IndexByte(got, 0) >= 0 {
			t.Errorf("sanitizePathForLog(%q) contains literal null byte: %q", raw, got)
		}
	}
}
// BenchmarkSanitizePathForLog measures sanitizePathForLog on a short path that
// contains several characters requiring escaping.
func BenchmarkSanitizePathForLog(b *testing.B) {
	const sample = "/test/path/with\nnewlines\rand\ttabs"

	for b.Loop() {
		_ = sanitizePathForLog(sample)
	}
}

View file

@ -63,8 +63,6 @@ limits:
max_custom_events: 100 max_custom_events: 100
# Maximum events per minute (rate limiting, 0 = unlimited) # Maximum events per minute (rate limiting, 0 = unlimited)
max_events_per_minute: 10000 max_events_per_minute: 10000
# Maximum metrics endpoint requests per minute (rate limiting, 0 = unlimited, default: 30)
max_metrics_per_minute: 60
# Device classification breakpoints (screen width in pixels) # Device classification breakpoints (screen width in pixels)
device_breakpoints: device_breakpoints:
@ -87,8 +85,6 @@ security:
- "*" # Or specific domains: ["https://example.com", "https://www.example.com"] - "*" # Or specific domains: ["https://example.com", "https://www.example.com"]
# Basic authentication for /metrics endpoint # Basic authentication for /metrics endpoint
# Password can also be set via environment variable:
# - `$ export WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD=your-secret-password`
metrics_auth: metrics_auth:
enabled: false enabled: false
username: "admin" username: "admin"

View file

@ -6,17 +6,12 @@ Grafana.
> [!IMPORTANT] > [!IMPORTANT]
> >
> **Why you need a time-series database:** > **Why you need Prometheus:**
> >
> - Watchdog exposes _current state_ (counters, gauges) > - Watchdog exposes _current state_ (counters, gauges)
> - A TSDB _scrapes periodically_ and _stores time-series data_ > - Prometheus _scrapes periodically_ and _stores time-series data_
> - Grafana _visualizes_ the historical data > - Grafana _visualizes_ the historical data from Prometheus
> - Grafana cannot directly scrape Prometheus `/metrics` endpoints > - Grafana cannot directly scrape Prometheus `/metrics` endpoints
>
> **Compatible databases:**
>
> - [Prometheus](#prometheus-setup),
> - [VictoriaMetrics](#victoriametrics), or any Prometheus-compatible scraper
## Prometheus Setup ## Prometheus Setup
@ -129,7 +124,7 @@ For multiple Watchdog instances:
datasources.settings.datasources = [{ datasources.settings.datasources = [{
name = "Prometheus"; name = "Prometheus";
type = "prometheus"; type = "prometheus";
url = "http://localhost:9090"; # Or "http://localhost:8428" for VictoriaMetrics url = "http://localhost:9090";
isDefault = true; isDefault = true;
}]; }];
}; };
@ -233,34 +228,7 @@ sum by (instance) (rate(web_pageviews_total[5m]))
### VictoriaMetrics ### VictoriaMetrics
VictoriaMetrics is a fast, cost-effective monitoring solution and time-series Drop-in Prometheus replacement with better performance and compression:
database that is 100% compatible with Prometheus exposition format. Watchdog's
`/metrics` endpoint can be scraped directly by VictoriaMetrics without requiring
Prometheus.
#### Direct Scraping (Recommended)
VictoriaMetrics single-node mode can scrape Watchdog directly using standard
Prometheus scrape configuration:
**Configuration file (`/etc/victoriametrics/scrape.yml`):**
```yaml
scrape_configs:
- job_name: "watchdog"
static_configs:
- targets: ["localhost:8080"]
scrape_interval: 15s
metrics_path: /metrics
```
**Run VictoriaMetrics:**
```bash
victoria-metrics -promscrape.config=/etc/victoriametrics/scrape.yml
```
**NixOS configuration:**
```nix ```nix
{ {
@ -268,84 +236,15 @@ victoria-metrics -promscrape.config=/etc/victoriametrics/scrape.yml
enable = true; enable = true;
listenAddress = ":8428"; listenAddress = ":8428";
retentionPeriod = "12month"; retentionPeriod = "12month";
# Define scrape configs directly. 'prometheusConfig' is the configuration for
# Prometheus-style metrics endpoints, which Watchdog exports.
prometheusConfig = {
scrape_configs = [
{
job_name = "watchdog";
scrape_interval = "15s";
static_configs = [{
targets = [ "localhost:8080" ]; # replace the port
}];
}
];
};
};
}
```
#### Using `vmagent`
Alternatively, for distributed setups or when you need more advanced features
like relabeling, you may use `vmagent`:
```nix
{
services.vmagent = {
enable = true;
remoteWriteUrl = "http://localhost:8428/api/v1/write";
prometheusConfig = {
scrape_configs = [
{
job_name = "watchdog";
static_configs = [{
targets = [ "localhost:8080" ];
}];
}
];
};
}; };
services.victoriametrics = { # Configure Prometheus to remote-write to VictoriaMetrics
enable = true;
listenAddress = ":8428";
};
}
```
#### Prometheus Remote Write
If you are migrating from Prometheus, if you need PromQL compatibility, or if
you just really like using Prometheus for some inexplicable reason, you may keep
Prometheus as the scraper and have it remote-write to VictoriaMetrics.
```nix
{
services.prometheus = { services.prometheus = {
enable = true; enable = true;
port = 9090;
scrapeConfigs = [
{
job_name = "watchdog";
static_configs = [{
targets = [ "localhost:8080" ];
}];
}
];
remoteWrite = [{ remoteWrite = [{
url = "http://localhost:8428/api/v1/write"; url = "http://localhost:8428/api/v1/write";
}]; }];
}; };
services.victoriametrics = {
enable = true;
listenAddress = ":8428";
};
} }
``` ```
@ -375,10 +274,10 @@ metrics:
## Monitoring the Monitoring ## Monitoring the Monitoring
Monitor your scraper: Monitor Prometheus itself:
```promql ```promql
# Scrape success rate # Prometheus scrape success rate
up{job="watchdog"} up{job="watchdog"}
# Scrape duration # Scrape duration
@ -388,9 +287,14 @@ scrape_duration_seconds{job="watchdog"}
time() - timestamp(up{job="watchdog"}) time() - timestamp(up{job="watchdog"})
``` ```
For VictoriaMetrics, you can also monitor ingestion stats: ## Additional Recommendations
```bash 1. **Retention**: Set `--storage.tsdb.retention.time=30d` or longer based on
# VM internal metrics disk space
curl http://localhost:8428/metrics | grep vm_rows_inserted_total 2. **Backups**: Back up `/var/lib/prometheus` periodically (or whatever your
``` state directory is)
3. **Alerting**: Configure Prometheus alerting rules for critical metrics
4. **High Availability**: Run multiple Prometheus instances with identical
configs
5. **Remote Storage**: For long-term storage, use Thanos, Cortex, or
VictoriaMetrics

View file

@ -54,20 +54,3 @@ func (r *CustomEventRegistry) OverflowCount() int {
defer r.mu.RUnlock() defer r.mu.RUnlock()
return r.overflowCount return r.overflowCount
} }
// Contains reports whether eventName is present in the registry. The lookup is
// performed while holding the registry's read lock.
func (r *CustomEventRegistry) Contains(eventName string) bool {
	r.mu.RLock()
	_, found := r.events[eventName]
	r.mu.RUnlock()
	return found
}
// Count reports how many distinct event names the registry currently holds.
// The size is read while holding the registry's read lock.
func (r *CustomEventRegistry) Count() int {
	r.mu.RLock()
	total := len(r.events)
	r.mu.RUnlock()
	return total
}

View file

@ -162,12 +162,8 @@ func (m *MetricsAggregator) RecordPageview(path, country, device, referrer, doma
labels := prometheus.Labels{"path": sanitizeLabel(path)} labels := prometheus.Labels{"path": sanitizeLabel(path)}
if m.cfg.Site.Collect.Country { if m.cfg.Site.Collect.Country {
if country == "" {
labels["country"] = "unknown"
} else {
labels["country"] = sanitizeLabel(country) labels["country"] = sanitizeLabel(country)
} }
}
if m.cfg.Site.Collect.Device { if m.cfg.Site.Collect.Device {
labels["device"] = sanitizeLabel(device) labels["device"] = sanitizeLabel(device)

View file

@ -6,7 +6,6 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"os" "os"
"strings"
"sync" "sync"
"time" "time"
@ -18,18 +17,15 @@ type UniquesEstimator struct {
hll *hyperloglog.Sketch hll *hyperloglog.Sketch
salt string salt string
rotation string // "daily" or "hourly" rotation string // "daily" or "hourly"
saltKey string // cached time key to avoid regeneration
mu sync.Mutex mu sync.Mutex
} }
// Creates a new unique visitor estimator // Creates a new unique visitor estimator
func NewUniquesEstimator(rotation string) *UniquesEstimator { func NewUniquesEstimator(rotation string) *UniquesEstimator {
now := time.Now()
return &UniquesEstimator{ return &UniquesEstimator{
hll: hyperloglog.New(), hll: hyperloglog.New(),
salt: generateSalt(now, rotation), salt: generateSalt(time.Now(), rotation),
rotation: rotation, rotation: rotation,
saltKey: getSaltKey(now, rotation),
} }
} }
@ -40,13 +36,11 @@ func (u *UniquesEstimator) Add(ip, userAgent string) {
defer u.mu.Unlock() defer u.mu.Unlock()
// Check if we need to rotate to a new period // Check if we need to rotate to a new period
now := time.Now() currentSalt := generateSalt(time.Now(), u.rotation)
currentKey := getSaltKey(now, u.rotation) if currentSalt != u.salt {
if currentKey != u.saltKey {
// Reset HLL for new period // Reset HLL for new period
u.hll = hyperloglog.New() u.hll = hyperloglog.New()
u.salt = generateSaltFromKey(currentKey) u.salt = currentSalt
u.saltKey = currentKey
} }
// Hash visitor with salt to prevent cross-period tracking // Hash visitor with salt to prevent cross-period tracking
@ -61,36 +55,24 @@ func (u *UniquesEstimator) Estimate() uint64 {
return u.hll.Estimate() return u.hll.Estimate()
} }
// getSaltKey returns the un-hashed period identifier for t, evaluated in UTC.
// For "hourly" rotation the key has hour resolution ("2006-01-02T15"); any
// other rotation value yields a day-resolution key ("2006-01-02").
func getSaltKey(t time.Time, rotation string) string {
	layout := "2006-01-02"
	if rotation == "hourly" {
		layout = "2006-01-02T15"
	}
	return t.UTC().Format(layout)
}
// generateSaltFromKey derives a salt from an already-computed period key: the
// hex-encoded SHA-256 digest of "watchdog-salt-" followed by key.
func generateSaltFromKey(key string) string {
	hasher := sha256.New()
	hasher.Write([]byte("watchdog-salt-" + key))
	return hex.EncodeToString(hasher.Sum(nil))
}
// Generates a deterministic salt based on the rotation mode // Generates a deterministic salt based on the rotation mode
// Daily: same day = same salt, different day = different salt // Daily: same day = same salt, different day = different salt
// Hourly: same hour = same salt, different hour = different salt // Hourly: same hour = same salt, different hour = different salt
func generateSalt(t time.Time, rotation string) string { func generateSalt(t time.Time, rotation string) string {
return generateSaltFromKey(getSaltKey(t, rotation)) var key string
if rotation == "hourly" {
key = t.UTC().Format("2006-01-02T15")
} else {
key = t.UTC().Format("2006-01-02")
}
h := sha256.Sum256([]byte("watchdog-salt-" + key))
return hex.EncodeToString(h[:])
} }
// Creates a privacy-preserving hash of visitor identity // Creates a privacy-preserving hash of visitor identity
func hashVisitor(ip, userAgent, salt string) string { func hashVisitor(ip, userAgent, salt string) string {
var sb strings.Builder combined := ip + "|" + userAgent + "|" + salt
sb.WriteString(ip) h := sha256.Sum256([]byte(combined))
sb.WriteString("|")
sb.WriteString(userAgent)
sb.WriteString("|")
sb.WriteString(salt)
h := sha256.Sum256([]byte(sb.String()))
return hex.EncodeToString(h[:]) return hex.EncodeToString(h[:])
} }
@ -140,20 +122,16 @@ func (u *UniquesEstimator) Load(path string) error {
} }
savedSalt := string(parts[0]) savedSalt := string(parts[0])
now := time.Now() currentSalt := generateSalt(time.Now(), u.rotation)
currentKey := getSaltKey(now, u.rotation)
currentSalt := generateSaltFromKey(currentKey)
// Only restore if it's the same period // Only restore if it's the same period
if savedSalt == currentSalt { if savedSalt == currentSalt {
u.salt = savedSalt u.salt = savedSalt
u.saltKey = currentKey
return u.hll.UnmarshalBinary(parts[1]) return u.hll.UnmarshalBinary(parts[1])
} }
// Different period, start fresh // Different period, start fresh
u.hll = hyperloglog.New() u.hll = hyperloglog.New()
u.salt = currentSalt u.salt = currentSalt
u.saltKey = currentKey
return nil return nil
} }

View file

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"slices"
"notashelf.dev/watchdog/internal/limits" "notashelf.dev/watchdog/internal/limits"
) )
@ -39,8 +40,40 @@ func ParseEvent(body io.Reader) (*Event, error) {
return &event, nil return &event, nil
} }
// Validate checks if the event is valid using a domain map // Validate checks if the event is valid for the given domains
func (e *Event) Validate(allowedDomains map[string]bool) error { func (e *Event) Validate(allowedDomains []string) error {
if e.Domain == "" {
return fmt.Errorf("domain required")
}
// Check if domain is in allowed list
allowed := slices.Contains(allowedDomains, e.Domain)
if !allowed {
return fmt.Errorf("domain not allowed")
}
if e.Path == "" {
return fmt.Errorf("path required")
}
if len(e.Path) > limits.MaxPathLen {
return fmt.Errorf("path too long")
}
if len(e.Referrer) > limits.MaxRefLen {
return fmt.Errorf("referrer too long")
}
// Validate screen width is in reasonable range
if e.Width < 0 || e.Width > limits.MaxWidth {
return fmt.Errorf("invalid width")
}
return nil
}
// ValidateWithMap checks if the event is valid using a domain map (O(1) lookup)
func (e *Event) ValidateWithMap(allowedDomains map[string]bool) error {
if e.Domain == "" { if e.Domain == "" {
return fmt.Errorf("domain required") return fmt.Errorf("domain required")
} }

View file

@ -115,7 +115,7 @@ func TestValidateEvent(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
event Event event Event
domains map[string]bool domains []string
wantErr bool wantErr bool
}{ }{
{ {
@ -124,7 +124,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "example.com", Domain: "example.com",
Path: "/home", Path: "/home",
}, },
domains: map[string]bool{"example.com": true}, domains: []string{"example.com"},
wantErr: false, wantErr: false,
}, },
{ {
@ -134,7 +134,7 @@ func TestValidateEvent(t *testing.T) {
Path: "/signup", Path: "/signup",
Event: "signup", Event: "signup",
}, },
domains: map[string]bool{"example.com": true}, domains: []string{"example.com"},
wantErr: false, wantErr: false,
}, },
{ {
@ -143,7 +143,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "wrong.com", Domain: "wrong.com",
Path: "/home", Path: "/home",
}, },
domains: map[string]bool{"example.com": true}, domains: []string{"example.com"},
wantErr: true, wantErr: true,
}, },
{ {
@ -152,7 +152,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "", Domain: "",
Path: "/home", Path: "/home",
}, },
domains: map[string]bool{"example.com": true}, domains: []string{"example.com"},
wantErr: true, wantErr: true,
}, },
{ {
@ -161,7 +161,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "example.com", Domain: "example.com",
Path: "", Path: "",
}, },
domains: map[string]bool{"example.com": true}, domains: []string{"example.com"},
wantErr: true, wantErr: true,
}, },
{ {
@ -170,7 +170,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "example.com", Domain: "example.com",
Path: "/" + strings.Repeat("a", 3000), Path: "/" + strings.Repeat("a", 3000),
}, },
domains: map[string]bool{"example.com": true}, domains: []string{"example.com"},
wantErr: true, wantErr: true,
}, },
{ {
@ -180,7 +180,7 @@ func TestValidateEvent(t *testing.T) {
Path: "/home", Path: "/home",
Referrer: strings.Repeat("a", 3000), Referrer: strings.Repeat("a", 3000),
}, },
domains: map[string]bool{"example.com": true}, domains: []string{"example.com"},
wantErr: true, wantErr: true,
}, },
{ {
@ -189,7 +189,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "example.com", Domain: "example.com",
Path: "/" + strings.Repeat("a", 2000), Path: "/" + strings.Repeat("a", 2000),
}, },
domains: map[string]bool{"example.com": true}, domains: []string{"example.com"},
wantErr: false, wantErr: false,
}, },
{ {
@ -198,7 +198,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "site1.com", Domain: "site1.com",
Path: "/home", Path: "/home",
}, },
domains: map[string]bool{"site1.com": true, "site2.com": true}, domains: []string{"site1.com", "site2.com"},
wantErr: false, wantErr: false,
}, },
{ {
@ -207,7 +207,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "site2.com", Domain: "site2.com",
Path: "/about", Path: "/about",
}, },
domains: map[string]bool{"site1.com": true, "site2.com": true}, domains: []string{"site1.com", "site2.com"},
wantErr: false, wantErr: false,
}, },
{ {
@ -216,7 +216,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "site3.com", Domain: "site3.com",
Path: "/home", Path: "/home",
}, },
domains: map[string]bool{"site1.com": true, "site2.com": true}, domains: []string{"site1.com", "site2.com"},
wantErr: true, wantErr: true,
}, },
} }
@ -231,7 +231,24 @@ func TestValidateEvent(t *testing.T) {
} }
} }
func BenchmarkValidate(b *testing.B) { func BenchmarkValidate_SliceLookup(b *testing.B) {
// Simulate multi-site with 50 domains
domains := make([]string, 50)
for i := range 50 {
domains[i] = strings.Repeat("site", i) + ".com"
}
event := Event{
Domain: domains[49], // Worst case - last in list
Path: "/test",
}
for b.Loop() {
_ = event.Validate(domains)
}
}
func BenchmarkValidate_MapLookup(b *testing.B) {
// Simulate multi-site with 50 domains // Simulate multi-site with 50 domains
domainMap := make(map[string]bool, 50) domainMap := make(map[string]bool, 50)
for i := range 50 { for i := range 50 {
@ -239,11 +256,11 @@ func BenchmarkValidate(b *testing.B) {
} }
event := Event{ event := Event{
Domain: strings.Repeat("site", 49) + ".com", Domain: strings.Repeat("site", 49) + ".com", // any position
Path: "/test", Path: "/test",
} }
for b.Loop() { for b.Loop() {
_ = event.Validate(domainMap) _ = event.ValidateWithMap(domainMap)
} }
} }

View file

@ -1,14 +1,10 @@
package api package api
import ( import (
"crypto/rand" "math/rand"
"encoding/hex"
"fmt"
mrand "math/rand"
"net" "net"
"net/http" "net/http"
"strings" "strings"
"time"
"notashelf.dev/watchdog/internal/aggregate" "notashelf.dev/watchdog/internal/aggregate"
"notashelf.dev/watchdog/internal/config" "notashelf.dev/watchdog/internal/config"
@ -20,14 +16,12 @@ import (
type IngestionHandler struct { type IngestionHandler struct {
cfg *config.Config cfg *config.Config
domainMap map[string]bool // O(1) domain validation domainMap map[string]bool // O(1) domain validation
corsOriginMap map[string]bool // O(1) CORS origin validation
pathNorm *normalize.PathNormalizer pathNorm *normalize.PathNormalizer
pathRegistry *aggregate.PathRegistry pathRegistry *aggregate.PathRegistry
refRegistry *normalize.ReferrerRegistry refRegistry *normalize.ReferrerRegistry
metricsAgg *aggregate.MetricsAggregator metricsAgg *aggregate.MetricsAggregator
rateLimiter *ratelimit.TokenBucket rateLimiter *ratelimit.TokenBucket
rng *mrand.Rand rng *rand.Rand
trustedNetworks []*net.IPNet // pre-parsed CIDR networks
} }
// Creates a new ingestion handler // Creates a new ingestion handler
@ -53,51 +47,19 @@ func NewIngestionHandler(
domainMap[domain] = true domainMap[domain] = true
} }
// Build CORS origin map for O(1) lookup
corsOriginMap := make(map[string]bool, len(cfg.Security.CORS.AllowedOrigins))
for _, origin := range cfg.Security.CORS.AllowedOrigins {
corsOriginMap[origin] = true
}
// Pre-parse trusted proxy CIDRs to avoid re-parsing on each request
trustedNetworks := make([]*net.IPNet, 0, len(cfg.Security.TrustedProxies))
for _, cidr := range cfg.Security.TrustedProxies {
if _, network, err := net.ParseCIDR(cidr); err == nil {
trustedNetworks = append(trustedNetworks, network)
} else if ip := net.ParseIP(cidr); ip != nil {
// Single IP - create a /32 or /128 network
var mask net.IPMask
if ip.To4() != nil {
mask = net.CIDRMask(32, 32)
} else {
mask = net.CIDRMask(128, 128)
}
trustedNetworks = append(trustedNetworks, &net.IPNet{IP: ip, Mask: mask})
}
}
return &IngestionHandler{ return &IngestionHandler{
cfg: cfg, cfg: cfg,
domainMap: domainMap, domainMap: domainMap,
corsOriginMap: corsOriginMap,
pathNorm: pathNorm, pathNorm: pathNorm,
pathRegistry: pathRegistry, pathRegistry: pathRegistry,
refRegistry: refRegistry, refRegistry: refRegistry,
metricsAgg: metricsAgg, metricsAgg: metricsAgg,
rateLimiter: limiter, rateLimiter: limiter,
rng: mrand.New(mrand.NewSource(time.Now().UnixNano())), rng: rand.New(rand.NewSource(42)),
trustedNetworks: trustedNetworks,
} }
} }
func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Generate or extract request ID for tracing
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
requestID = generateRequestID()
}
w.Header().Set("X-Request-ID", requestID)
// Handle CORS preflight // Handle CORS preflight
if r.Method == http.MethodOptions { if r.Method == http.MethodOptions {
h.handleCORS(w, r) h.handleCORS(w, r)
@ -141,8 +103,8 @@ func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
// Validate event via map lookup // Validate event via map lookup (also O(1))
if err := event.Validate(h.domainMap); err != nil { if err := event.ValidateWithMap(h.domainMap); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest) http.Error(w, "Bad request", http.StatusBadRequest)
return return
} }
@ -219,8 +181,13 @@ func (h *IngestionHandler) handleCORS(w http.ResponseWriter, r *http.Request) {
} }
// Check if origin is allowed // Check if origin is allowed
// This uses map so that it's O(1) allowed := false
allowed := h.corsOriginMap["*"] || h.corsOriginMap[origin] for _, allowedOrigin := range h.cfg.Security.CORS.AllowedOrigins {
if allowedOrigin == "*" || allowedOrigin == origin {
allowed = true
break
}
}
if allowed { if allowed {
if origin == "*" { if origin == "*" {
@ -246,8 +213,13 @@ func (h *IngestionHandler) extractIP(r *http.Request) string {
// Check if we should trust proxy headers // Check if we should trust proxy headers
trustProxy := false trustProxy := false
if len(h.trustedNetworks) > 0 { if len(h.cfg.Security.TrustedProxies) > 0 {
trustProxy = h.ipInNetworks(remoteIP, h.trustedNetworks) for _, trustedCIDR := range h.cfg.Security.TrustedProxies {
if h.ipInCIDR(remoteIP, trustedCIDR) {
trustProxy = true
break
}
}
} }
// If not trusting proxy, return direct IP // If not trusting proxy, return direct IP
@ -261,43 +233,42 @@ func (h *IngestionHandler) extractIP(r *http.Request) string {
ips := strings.Split(xff, ",") ips := strings.Split(xff, ",")
for i := len(ips) - 1; i >= 0; i-- { for i := len(ips) - 1; i >= 0; i-- {
ip := strings.TrimSpace(ips[i]) ip := strings.TrimSpace(ips[i])
if testIP := net.ParseIP(ip); testIP != nil { if !h.ipInCIDR(ip, "0.0.0.0/0") {
// Only accept this IP if it's NOT from a trusted proxy continue
if !h.ipInNetworks(ip, h.trustedNetworks) { }
return ip return ip
} }
} }
}
}
// Check X-Real-IP header // Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" { if xri := r.Header.Get("X-Real-IP"); xri != "" {
// Validate the IP format and ensure it's not from a trusted proxy
if testIP := net.ParseIP(xri); testIP != nil {
if !h.ipInNetworks(xri, h.trustedNetworks) {
return xri return xri
} }
}
}
// Fall back to RemoteAddr // Fall back to RemoteAddr
return remoteIP return remoteIP
} }
// Checks if an IP address is within any of the trusted networks // Checks if an IP address is within a CIDR range
func (h *IngestionHandler) ipInNetworks(ip string, networks []*net.IPNet) bool { func (h *IngestionHandler) ipInCIDR(ip, cidr string) bool {
// Parse the IP address
testIP := net.ParseIP(ip) testIP := net.ParseIP(ip)
if testIP == nil { if testIP == nil {
return false return false
} }
for _, network := range networks { // Parse the CIDR
if network.Contains(testIP) { _, network, err := net.ParseCIDR(cidr)
return true if err != nil {
// If it's not a CIDR, try as a single IP
cidrIP := net.ParseIP(cidr)
if cidrIP == nil {
return false
} }
return testIP.Equal(cidrIP)
} }
return false return network.Contains(testIP)
} }
// Classifies device using both screen width and User-Agent parsing // Classifies device using both screen width and User-Agent parsing
@ -340,17 +311,3 @@ func (h *IngestionHandler) classifyDevice(width int, userAgent string) string {
return "unknown" return "unknown"
} }
// generateRequestID returns an identifier used to trace a single request.
//
// The ID is built from 8 random bytes (64 bits) and rendered as 16 lowercase
// hex characters, which keeps it short enough for logs and headers while
// leaving 2^64 possible values.
//
// If the random source reports an error, the current Unix time in
// nanoseconds (formatted as a decimal string) is returned instead.
func generateRequestID() string {
	var raw [8]byte
	_, err := rand.Read(raw[:])
	if err == nil {
		return hex.EncodeToString(raw[:])
	}
	// Random source unavailable: fall back to a timestamp-based ID.
	return fmt.Sprintf("%d", time.Now().UnixNano())
}

View file

@ -218,107 +218,6 @@ func newTestHandler(cfg *config.Config) *IngestionHandler {
return NewIngestionHandler(cfg, pathNorm, pathRegistry, refRegistry, metricsAgg) return NewIngestionHandler(cfg, pathNorm, pathRegistry, refRegistry, metricsAgg)
} }
func TestExtractIP(t *testing.T) {
cfg := &config.Config{
Site: config.SiteConfig{
Domains: []string{"example.com"},
},
Limits: config.LimitsConfig{
MaxPaths: 100,
MaxSources: 50,
},
Security: config.SecurityConfig{
TrustedProxies: []string{"10.0.0.0/8", "192.168.1.1"},
},
}
h := newTestHandler(cfg)
tests := []struct {
name string
remoteAddr string
headers map[string]string
want string
}{
{
name: "direct connection no proxy",
remoteAddr: "192.168.1.100:12345",
headers: map[string]string{},
want: "192.168.1.100",
},
{
name: "trusted proxy with X-Forwarded-For",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{
"X-Forwarded-For": "203.0.113.1, 10.0.0.5",
},
want: "203.0.113.1",
},
{
name: "trusted proxy with X-Real-IP",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{
"X-Real-IP": "203.0.113.2",
},
want: "203.0.113.2",
},
{
name: "X-Real-IP from trusted network should be ignored",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{
"X-Real-IP": "10.0.0.50", // trusted network, should fall back
},
want: "10.0.0.1", // falls back to remoteAddr
},
{
name: "X-Real-IP invalid IP should be ignored",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{
"X-Real-IP": "not-an-ip",
},
want: "10.0.0.1", // falls back
},
{
name: "untrusted proxy X-Forwarded-For ignored",
remoteAddr: "203.0.113.50:12345",
headers: map[string]string{
"X-Forwarded-For": "1.2.3.4",
},
want: "203.0.113.50", // uses remoteAddr, ignores header
},
{
name: "untrusted proxy X-Real-IP ignored",
remoteAddr: "203.0.113.50:12345",
headers: map[string]string{
"X-Real-IP": "1.2.3.4",
},
want: "203.0.113.50", // uses remoteAddr, ignores header
},
{
name: "X-Forwarded-For all trusted falls back",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{
"X-Forwarded-For": "10.0.0.2, 10.0.0.3",
},
want: "10.0.0.1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/api/event", nil)
req.RemoteAddr = tt.remoteAddr
for k, v := range tt.headers {
req.Header.Set(k, v)
}
got := h.extractIP(req)
if got != tt.want {
t.Errorf("extractIP() = %q, want %q", got, tt.want)
}
})
}
}
func TestClassifyDevice_UA(t *testing.T) { func TestClassifyDevice_UA(t *testing.T) {
cfg := &config.Config{ cfg := &config.Config{
Limits: config.LimitsConfig{ Limits: config.LimitsConfig{

View file

@ -50,7 +50,6 @@ type LimitsConfig struct {
MaxSources int `yaml:"max_sources"` MaxSources int `yaml:"max_sources"`
MaxCustomEvents int `yaml:"max_custom_events"` MaxCustomEvents int `yaml:"max_custom_events"`
DeviceBreakpoints DeviceBreaks `yaml:"device_breakpoints"` DeviceBreakpoints DeviceBreaks `yaml:"device_breakpoints"`
MaxMetricsPerMinute int `yaml:"max_metrics_per_minute"` // rate limit for metrics endpoint
} }
// Device classification breakpoints // Device classification breakpoints
@ -73,11 +72,10 @@ type CORSConfig struct {
} }
// Authentication for metrics endpoint // Authentication for metrics endpoint
// Password can be set via environment variable: WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD
type AuthConfig struct { type AuthConfig struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
Username string `yaml:"username"` Username string `yaml:"username"`
Password string `yaml:"password"` // can use env var WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD Password string `yaml:"password"`
} }
// Server endpoints // Server endpoints
@ -151,10 +149,6 @@ func (c *Config) Validate() error {
c.Limits.MaxCustomEvents = 100 // Default c.Limits.MaxCustomEvents = 100 // Default
} }
if c.Limits.MaxMetricsPerMinute <= 0 {
c.Limits.MaxMetricsPerMinute = 30 // Default: 30 requests per minute
}
if c.Site.Path.MaxSegments < 0 { if c.Site.Path.MaxSegments < 0 {
return fmt.Errorf("site.path.max_segments cannot be negative") return fmt.Errorf("site.path.max_segments cannot be negative")
} }

View file

@ -48,36 +48,38 @@ func (n *PathNormalizer) Normalize(path string) string {
path = "/" + path path = "/" + path
} }
// Process segments in-place to minimize allocations
// Split into segments, first element is *always* empty for paths starting with '/' // Split into segments, first element is *always* empty for paths starting with '/'
segments := strings.Split(path, "/") segments := strings.Split(path, "/")
if len(segments) > 0 && segments[0] == "" {
// Process segments in a single pass: remove empty, resolve . and .. segments = segments[1:]
writeIdx := 0
for i := 0; i < len(segments); i++ {
seg := segments[i]
// Skip empty segments (from double slashes or leading /)
if seg == "" {
continue
} }
// Remove empty segments (from double slashes)
filtered := make([]string, 0, len(segments))
for _, seg := range segments {
if seg != "" {
filtered = append(filtered, seg)
}
}
segments = filtered
// Resolve . and .. segments
resolved := make([]string, 0, len(segments))
for _, seg := range segments {
if seg == "." { if seg == "." {
// Skip current directory // Skip current directory
continue continue
} else if seg == ".." { } else if seg == ".." {
// Go up one level if possible // Go up one level if possible
if writeIdx > 0 { if len(resolved) > 0 {
writeIdx-- resolved = resolved[:len(resolved)-1]
} }
// If already at root, skip .. // If already at root, skip ..
} else { } else {
// Keep this segment resolved = append(resolved, seg)
segments[writeIdx] = seg
writeIdx++
} }
} }
segments = segments[:writeIdx] segments = resolved
// Collapse numeric segments // Collapse numeric segments
if n.cfg.CollapseNumericSegments { if n.cfg.CollapseNumericSegments {

View file

@ -1,7 +1,6 @@
package normalize package normalize
import ( import (
"net"
"net/url" "net/url"
"strings" "strings"
@ -22,21 +21,35 @@ func isInternalHost(hostname string) bool {
return true return true
} }
// Check if hostname is an IP address
if ip := net.ParseIP(hostname); ip != nil {
// Private IPv4 ranges (RFC1918) // Private IPv4 ranges (RFC1918)
if ip.IsPrivate() { if strings.HasPrefix(hostname, "10.") ||
strings.HasPrefix(hostname, "192.168.") ||
strings.HasPrefix(hostname, "172.16.") ||
strings.HasPrefix(hostname, "172.17.") ||
strings.HasPrefix(hostname, "172.18.") ||
strings.HasPrefix(hostname, "172.19.") ||
strings.HasPrefix(hostname, "172.20.") ||
strings.HasPrefix(hostname, "172.21.") ||
strings.HasPrefix(hostname, "172.22.") ||
strings.HasPrefix(hostname, "172.23.") ||
strings.HasPrefix(hostname, "172.24.") ||
strings.HasPrefix(hostname, "172.25.") ||
strings.HasPrefix(hostname, "172.26.") ||
strings.HasPrefix(hostname, "172.27.") ||
strings.HasPrefix(hostname, "172.28.") ||
strings.HasPrefix(hostname, "172.29.") ||
strings.HasPrefix(hostname, "172.30.") ||
strings.HasPrefix(hostname, "172.31.") {
return true return true
} }
// Additional localhost checks for IP formats
if ip.IsLoopback() { // IPv6 loopback and local
if strings.HasPrefix(hostname, "::1") ||
strings.HasPrefix(hostname, "fe80::") ||
strings.HasPrefix(hostname, "fc00::") ||
strings.HasPrefix(hostname, "fd00::") {
return true return true
} }
// Link-local addresses
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true
}
}
return false return false
} }

View file

@ -56,20 +56,3 @@ func (r *ReferrerRegistry) OverflowCount() int {
defer r.mu.RUnlock() defer r.mu.RUnlock()
return r.overflowCount return r.overflowCount
} }
// Contains reports whether source is present as a key in the registry's
// sources map. The lookup is performed while holding the read lock.
func (r *ReferrerRegistry) Contains(source string) bool {
	r.mu.RLock()
	_, found := r.sources[source]
	r.mu.RUnlock()
	return found
}
// Count reports how many distinct sources the registry currently holds.
// The size is read while holding the read lock.
func (r *ReferrerRegistry) Count() int {
	r.mu.RLock()
	total := len(r.sources)
	r.mu.RUnlock()
	return total
}

View file

@ -108,36 +108,6 @@ func TestExtractReferrerDomain(t *testing.T) {
siteDomain: "example.com", siteDomain: "example.com",
want: "internal", want: "internal",
}, },
{
name: "private IP 172.16.x (RFC1918)",
referrer: "http://172.16.0.1/page",
siteDomain: "example.com",
want: "internal",
},
{
name: "private IP 172.31.x (RFC1918 upper bound)",
referrer: "http://172.31.255.1/page",
siteDomain: "example.com",
want: "internal",
},
{
name: "private IP 172.20.x (middle of range)",
referrer: "http://172.20.50.100/page",
siteDomain: "example.com",
want: "internal",
},
{
name: "public IP 172.15.x (just outside private range)",
referrer: "http://172.15.0.1/page",
siteDomain: "example.com",
want: "other", // not internal, but invalid TLD
},
{
name: "public IP 172.32.x (just outside private range)",
referrer: "http://172.32.0.1/page",
siteDomain: "example.com",
want: "other", // not internal, but invalid TLD
},
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -195,9 +195,6 @@ func TestEndToEnd_GracefulShutdown(t *testing.T) {
MaxSources: 50, MaxSources: 50,
MaxCustomEvents: 10, MaxCustomEvents: 10,
}, },
Server: config.ServerConfig{
StatePath: "/tmp/watchdog-test.state",
},
} }
pathRegistry := aggregate.NewPathRegistry(cfg.Limits.MaxPaths) pathRegistry := aggregate.NewPathRegistry(cfg.Limits.MaxPaths)
@ -226,12 +223,12 @@ func TestEndToEnd_GracefulShutdown(t *testing.T) {
} }
// Verify state file was created // Verify state file was created
if _, err := os.Stat("/tmp/watchdog-test.state"); os.IsNotExist(err) { if _, err := os.Stat("/tmp/watchdog-hll.state"); os.IsNotExist(err) {
t.Error("HLL state file was not created") t.Error("HLL state file was not created")
} }
// Cleanup // Cleanup
os.Remove("/tmp/watchdog-test.state") os.Remove("/tmp/watchdog-hll.state")
} }
func TestEndToEnd_InvalidRequests(t *testing.T) { func TestEndToEnd_InvalidRequests(t *testing.T) {