diff --git a/go.mod b/go.mod index 6029104..ff5eb22 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,5 @@ module notashelf.dev/watchdog go 1.25.5 require gopkg.in/yaml.v3 v3.0.1 + +require golang.org/x/net v0.51.0 diff --git a/go.sum b/go.sum index a62c313..00815c1 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= +golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/config/config.go b/internal/config/config.go index cddab60..499d545 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -44,6 +44,7 @@ type PathConfig struct { type LimitsConfig struct { MaxPaths int `yaml:"max_paths"` MaxEventsPerMinute int `yaml:"max_events_per_minute"` + MaxSources int `yaml:"max_sources"` } // Server endpoints @@ -90,6 +91,11 @@ func (c *Config) Validate() error { return fmt.Errorf("limits.max_paths must be greater than 0") } + // Validate max_sources is positive + if c.Limits.MaxSources <= 0 { + return fmt.Errorf("limits.max_sources must be greater than 0") + } + // Set server defaults if not provided if c.Server.ListenAddr == "" { c.Server.ListenAddr = ":8080" diff --git a/internal/normalize/referrer.go b/internal/normalize/referrer.go new file mode 100644 index 0000000..a785463 --- /dev/null +++ b/internal/normalize/referrer.go @@ -0,0 +1,97 @@ +package normalize + +import ( + "net/url" + "strings" + + "golang.org/x/net/publicsuffix" +) + +// Returns true for localhost, loopback IPs, and private IPs. +func isInternalHost(hostname string) bool { + if hostname == "" { + return false + } + + // Localhost variants + if hostname == "localhost" || + strings.HasPrefix(hostname, "localhost.") || + strings.HasPrefix(hostname, "127.") || + hostname == "::1" { + return true + } + + // Private IPv4 ranges (RFC1918) + if strings.HasPrefix(hostname, "10.") || + strings.HasPrefix(hostname, "192.168.") || + strings.HasPrefix(hostname, "172.16.") || + strings.HasPrefix(hostname, "172.17.") || + strings.HasPrefix(hostname, "172.18.") || + strings.HasPrefix(hostname, "172.19.") || + strings.HasPrefix(hostname, "172.20.") || + strings.HasPrefix(hostname, "172.21.") || + strings.HasPrefix(hostname, "172.22.") || + strings.HasPrefix(hostname, "172.23.") || + strings.HasPrefix(hostname, "172.24.") || + strings.HasPrefix(hostname, "172.25.") || + strings.HasPrefix(hostname, "172.26.") || + strings.HasPrefix(hostname, "172.27.") || + strings.HasPrefix(hostname, "172.28.") || + strings.HasPrefix(hostname, "172.29.") || + strings.HasPrefix(hostname, "172.30.") || + strings.HasPrefix(hostname, "172.31.") { + return true + } + + // IPv6 loopback and local + if strings.HasPrefix(hostname, "::1") || + strings.HasPrefix(hostname, "fe80::") || + strings.HasPrefix(hostname, "fc00::") || + strings.HasPrefix(hostname, "fd00::") { + return true + } + + return false +} + +// Extracts the eTLD+1 domain from a referrer URL. +// Returns "direct" for empty or same-domain referrers. +// Returns empty string for invalid URLs. +func ExtractReferrerDomain(referrer, siteDomain string) string { + if referrer == "" { + return "direct" + } + + u, err := url.Parse(referrer) + if err != nil { + return "" + } + + hostname := strings.ToLower(u.Hostname()) + hostname = strings.TrimSuffix(hostname, ".") // remove trailing dot + if hostname == "" { + return "" + } + + // Check for internal/localhost traffic + if isInternalHost(hostname) { + return "internal" + } + + // Same domain check + siteDomainLower := strings.ToLower(siteDomain) + if hostname == siteDomainLower || strings.HasSuffix(hostname, "."+siteDomainLower) { + return "direct" + } + + // Extract eTLD+1 (effective top-level domain + 1 label); e.g. + // - "www.google.co.uk" -> "google.co.uk" + // - "news.ycombinator.com" -> "ycombinator.com" + eTLDPlus1, err := publicsuffix.EffectiveTLDPlusOne(hostname) + if err != nil { + // If public suffix lookup fails, use hostname as-is + return hostname + } + + return eTLDPlus1 +} diff --git a/internal/normalize/referrer_registry.go b/internal/normalize/referrer_registry.go new file mode 100644 index 0000000..b573309 --- /dev/null +++ b/internal/normalize/referrer_registry.go @@ -0,0 +1,48 @@ +package normalize + +import "sync" + +// Bounded set of observed referrer domains. +type ReferrerRegistry struct { + mu sync.RWMutex + sources map[string]struct{} + maxSources int + overflowCount int +} + +func NewReferrerRegistry(maxSources int) *ReferrerRegistry { + return &ReferrerRegistry{ + sources: make(map[string]struct{}, maxSources), + maxSources: maxSources, + } +} + +// Attempt to add a referrer source to the registry. Returns the source to use ("other" if rejected). +func (r *ReferrerRegistry) Add(source string) string { + if source == "direct" || source == "internal" { + return source + } + + r.mu.Lock() + defer r.mu.Unlock() + + // Already exists + if _, exists := r.sources[source]; exists { + return source + } + + // Check limit + if len(r.sources) >= r.maxSources { + r.overflowCount++ + return "other" + } + + r.sources[source] = struct{}{} + return source +} + +func (r *ReferrerRegistry) OverflowCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + return r.overflowCount +} diff --git a/internal/normalize/referrer_test.go b/internal/normalize/referrer_test.go new file mode 100644 index 0000000..04a8d0a --- /dev/null +++ b/internal/normalize/referrer_test.go @@ -0,0 +1,200 @@ +package normalize + +import ( + "fmt" + "testing" +) + +func TestExtractReferrerDomain(t *testing.T) { + tests := []struct { + name string + referrer string + siteDomain string + want string + }{ + { + name: "empty referrer", + referrer: "", + siteDomain: "example.com", + want: "direct", + }, + { + name: "same domain", + referrer: "https://example.com/page", + siteDomain: "example.com", + want: "direct", + }, + { + name: "subdomain is direct", + referrer: "https://blog.example.com/post", + siteDomain: "example.com", + want: "direct", + }, + { + name: "google search", + referrer: "https://www.google.com/search?q=test", + siteDomain: "example.com", + want: "google.com", + }, + { + name: "google country domain", + referrer: "https://www.google.co.uk/search", + siteDomain: "example.com", + want: "google.co.uk", + }, + { + name: "hacker news", + referrer: "https://news.ycombinator.com/item?id=123", + siteDomain: "example.com", + want: "ycombinator.com", + }, + { + name: "twitter short link", + referrer: "https://t.co/abc123", + siteDomain: "example.com", + want: "t.co", + }, + { + name: "github", + referrer: "https://github.com/user/repo", + siteDomain: "example.com", + want: "github.com", + }, + { + name: "invalid url", + referrer: "not-a-url", + siteDomain: "example.com", + want: "", + }, + { + name: "case insensitive", + referrer: "https://WWW.GOOGLE.COM/search", + siteDomain: "EXAMPLE.COM", + want: "google.com", + }, + { + name: "trailing dot normalized", + referrer: "https://example.com./page", + siteDomain: "test.com", + want: "example.com", + }, + { + name: "localhost", + referrer: "http://localhost:8080/page", + siteDomain: "example.com", + want: "internal", + }, + { + name: "loopback IPv4", + referrer: "http://127.0.0.1/page", + siteDomain: "example.com", + want: "internal", + }, + { + name: "loopback IPv6", + referrer: "http://[::1]/page", + siteDomain: "example.com", + want: "internal", + }, + { + name: "private IP 192.168", + referrer: "http://192.168.1.1/page", + siteDomain: "example.com", + want: "internal", + }, + { + name: "private IP 10.x", + referrer: "http://10.0.0.1/page", + siteDomain: "example.com", + want: "internal", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ExtractReferrerDomain(tt.referrer, tt.siteDomain) + if got != tt.want { + t.Errorf("ExtractReferrerDomain(%q, %q) = %q, want %q", + tt.referrer, tt.siteDomain, got, tt.want) + } + }) + } +} + +func TestReferrerRegistry(t *testing.T) { + registry := NewReferrerRegistry(3) + + // Add sources within limit + if got := registry.Add("google.com"); got != "google.com" { + t.Errorf("expected google.com, got %s", got) + } + if got := registry.Add("github.com"); got != "github.com" { + t.Errorf("expected github.com, got %s", got) + } + if got := registry.Add("reddit.com"); got != "reddit.com" { + t.Errorf("expected reddit.com, got %s", got) + } + + // Adding same source again should succeed + if got := registry.Add("google.com"); got != "google.com" { + t.Errorf("expected google.com, got %s", got) + } + + // Exceeding limit should return "other" + if got := registry.Add("twitter.com"); got != "other" { + t.Errorf("expected other, got %s", got) + } + + // Direct always works + if got := registry.Add("direct"); got != "direct" { + t.Errorf("expected direct, got %s", got) + } + + // Internal always works + if got := registry.Add("internal"); got != "internal" { + t.Errorf("expected internal, got %s", got) + } + + // Overflow count should be 1 + if registry.OverflowCount() != 1 { + t.Errorf("expected overflow count 1, got %d", registry.OverflowCount()) + } +} + +func TestReferrerRegistryConcurrentOverflow(t *testing.T) { + registry := NewReferrerRegistry(10) + + // Use channels to coordinate goroutines + const numGoroutines = 50 + const sourcesPerGoroutine = 5 + done := make(chan bool, numGoroutines) + + // Launch goroutines that race to add sources + for i := range numGoroutines { + go func(id int) { + for j := range sourcesPerGoroutine { + source := fmt.Sprintf("source-%d-%d.com", id, j) + registry.Add(source) + } + done <- true + }(i) + } + + // Wait for all goroutines + for range numGoroutines { + <-done + } + + // Registry should have exactly 10 sources (limit) + // Overflow should be (50 * 5) - 10 = 240 rejections + // But since same sources might be added multiple times, + // we just verify: overflow > 0 and total attempts tracked + if registry.OverflowCount() == 0 { + t.Error("expected some overflow with concurrent adds") + } + + // Verify adding more sources still returns "other" + if got := registry.Add("new-source.com"); got != "other" { + t.Errorf("expected other after overflow, got %s", got) + } +}