internal/api: better multi-site support; validate events against allowed domains

Signed-off-by: NotAShelf <raf@notashelf.dev>
Change-Id: Iff1ced4966b4d42cfd6dfefb0cfd97696a6a6964
This commit is contained in:
raf 2026-03-01 14:27:20 +03:00
commit 18fe1a8234
Signed by: NotAShelf
GPG key ID: 29D95B64378DB4BF
10 changed files with 542 additions and 35 deletions

View file

@ -43,6 +43,10 @@ func NewMetricsAggregator(pathRegistry *PathRegistry, eventRegistry *CustomEvent
labels = append(labels, "referrer")
}
if cfg.Site.Collect.Domain {
labels = append(labels, "domain")
}
pageviews := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "web_pageviews_total",
@ -143,7 +147,7 @@ func sanitizeLabel(label string) string {
}
// Records a pageview with the configured dimensions
func (m *MetricsAggregator) RecordPageview(path, country, device, referrer string) {
func (m *MetricsAggregator) RecordPageview(path, country, device, referrer, domain string) {
// Build label values in the same order as label names
labels := prometheus.Labels{"path": sanitizeLabel(path)}
@ -159,6 +163,10 @@ func (m *MetricsAggregator) RecordPageview(path, country, device, referrer strin
labels["referrer"] = sanitizeLabel(referrer)
}
if m.cfg.Site.Collect.Domain {
labels["domain"] = sanitizeLabel(domain)
}
m.pageviews.With(labels).Inc()
}
@ -193,7 +201,7 @@ func (m *MetricsAggregator) AddUnique(ip, userAgent string) {
}
m.estimator.Add(ip, userAgent)
// Note: Gauge is updated in background goroutine, not here
// NOTE: Gauge is updated in background goroutine, not here
}
// Registers all metrics with the provided Prometheus registry

View file

@ -25,7 +25,7 @@ func TestMetricsAggregator_RecordPageview(t *testing.T) {
agg := NewMetricsAggregator(registry, NewCustomEventRegistry(100), cfg)
// Record pageview with all dimensions
agg.RecordPageview("/home", "US", "desktop", "google.com")
agg.RecordPageview("/home", "US", "desktop", "google.com", "")
// Verify metric was recorded
expected := `
@ -54,7 +54,7 @@ func TestMetricsAggregator_RecordPageview_MinimalDimensions(t *testing.T) {
agg := NewMetricsAggregator(registry, NewCustomEventRegistry(100), cfg)
// Record pageview with only path
agg.RecordPageview("/home", "", "", "")
agg.RecordPageview("/home", "", "", "", "")
// Verify metric was recorded
expected := `
@ -176,7 +176,7 @@ func TestMetricsAggregator_MustRegister(t *testing.T) {
agg.MustRegister(promRegistry)
// Record some metrics to ensure they show up
agg.RecordPageview("/test", "", "", "")
agg.RecordPageview("/test", "", "", "", "")
agg.RecordPathOverflow()
// Verify metrics can be gathered

View file

@ -39,14 +39,22 @@ func ParseEvent(body io.Reader) (*Event, error) {
return &event, nil
}
// Validate checks if the event is valid for the given domain
func (e *Event) Validate(expectedDomain string) error {
// Validate checks if the event is valid for the given domains
func (e *Event) Validate(allowedDomains []string) error {
if e.Domain == "" {
return fmt.Errorf("domain required")
}
if e.Domain != expectedDomain {
return fmt.Errorf("domain mismatch")
// Check if domain is in allowed list
allowed := false
for _, domain := range allowedDomains {
if e.Domain == domain {
allowed = true
break
}
}
if !allowed {
return fmt.Errorf("domain not allowed")
}
if e.Path == "" {

View file

@ -115,7 +115,7 @@ func TestValidateEvent(t *testing.T) {
tests := []struct {
name string
event Event
domain string
domains []string
wantErr bool
}{
{
@ -124,7 +124,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "example.com",
Path: "/home",
},
domain: "example.com",
domains: []string{"example.com"},
wantErr: false,
},
{
@ -134,7 +134,7 @@ func TestValidateEvent(t *testing.T) {
Path: "/signup",
Event: "signup",
},
domain: "example.com",
domains: []string{"example.com"},
wantErr: false,
},
{
@ -143,7 +143,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "wrong.com",
Path: "/home",
},
domain: "example.com",
domains: []string{"example.com"},
wantErr: true,
},
{
@ -152,7 +152,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "",
Path: "/home",
},
domain: "example.com",
domains: []string{"example.com"},
wantErr: true,
},
{
@ -161,7 +161,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "example.com",
Path: "",
},
domain: "example.com",
domains: []string{"example.com"},
wantErr: true,
},
{
@ -170,7 +170,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "example.com",
Path: "/" + strings.Repeat("a", 3000),
},
domain: "example.com",
domains: []string{"example.com"},
wantErr: true,
},
{
@ -180,7 +180,7 @@ func TestValidateEvent(t *testing.T) {
Path: "/home",
Referrer: strings.Repeat("a", 3000),
},
domain: "example.com",
domains: []string{"example.com"},
wantErr: true,
},
{
@ -189,14 +189,41 @@ func TestValidateEvent(t *testing.T) {
Domain: "example.com",
Path: "/" + strings.Repeat("a", 2000),
},
domain: "example.com",
domains: []string{"example.com"},
wantErr: false,
},
{
name: "multi-site valid domain 1",
event: Event{
Domain: "site1.com",
Path: "/home",
},
domains: []string{"site1.com", "site2.com"},
wantErr: false,
},
{
name: "multi-site valid domain 2",
event: Event{
Domain: "site2.com",
Path: "/about",
},
domains: []string{"site1.com", "site2.com"},
wantErr: false,
},
{
name: "multi-site invalid domain",
event: Event{
Domain: "site3.com",
Path: "/home",
},
domains: []string{"site1.com", "site2.com"},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.event.Validate(tt.domain)
err := tt.event.Validate(tt.domains)
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}

View file

@ -96,7 +96,7 @@ func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// Validate event
if err := event.Validate(h.cfg.Site.Domain); err != nil {
if err := event.Validate(h.cfg.Site.Domains); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
@ -134,7 +134,7 @@ func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Referrer classification
if h.cfg.Site.Collect.Referrer == "domain" {
refDomain := normalize.ExtractReferrerDomain(event.Referrer, h.cfg.Site.Domain)
refDomain := normalize.ExtractReferrerDomain(event.Referrer, event.Domain)
if refDomain != "" {
accepted := h.refRegistry.Add(refDomain)
if accepted == "other" {
@ -144,11 +144,17 @@ func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
// Domain tracking (if enabled for multi-site analytics)
var domain string
if h.cfg.Site.Collect.Domain {
domain = event.Domain
}
// FIXME: Country would be extracted from IP here. For now, we skip country extraction
// because I have neither the time nor the patience to look into it. Return later.
// Record pageview
h.metricsAgg.RecordPageview(normalizedPath, country, device, referrer)
h.metricsAgg.RecordPageview(normalizedPath, country, device, referrer, domain)
}
// Return success

View file

@ -15,7 +15,7 @@ import (
func TestIngestionHandler_Pageview(t *testing.T) {
cfg := config.Config{
Site: config.SiteConfig{
Domain: "example.com",
Domains: []string{"example.com"},
Collect: config.CollectConfig{
Pageviews: true,
Country: true,
@ -62,7 +62,7 @@ func TestIngestionHandler_Pageview(t *testing.T) {
func TestIngestionHandler_CustomEvent(t *testing.T) {
cfg := config.Config{
Site: config.SiteConfig{
Domain: "example.com",
Domains: []string{"example.com"},
Collect: config.CollectConfig{
Pageviews: true,
},
@ -99,7 +99,7 @@ func TestIngestionHandler_CustomEvent(t *testing.T) {
func TestIngestionHandler_WrongDomain(t *testing.T) {
cfg := config.Config{
Site: config.SiteConfig{
Domain: "example.com",
Domains: []string{"example.com"},
Collect: config.CollectConfig{
Pageviews: true,
},
@ -133,7 +133,7 @@ func TestIngestionHandler_WrongDomain(t *testing.T) {
func TestIngestionHandler_MethodNotAllowed(t *testing.T) {
cfg := config.Config{
Site: config.SiteConfig{
Domain: "example.com",
Domains: []string{"example.com"},
},
Limits: config.LimitsConfig{
MaxPaths: 100,
@ -160,7 +160,7 @@ func TestIngestionHandler_MethodNotAllowed(t *testing.T) {
func TestIngestionHandler_InvalidJSON(t *testing.T) {
cfg := config.Config{
Site: config.SiteConfig{
Domain: "example.com",
Domains: []string{"example.com"},
},
Limits: config.LimitsConfig{
MaxPaths: 100,
@ -190,7 +190,7 @@ func TestIngestionHandler_InvalidJSON(t *testing.T) {
func TestIngestionHandler_DeviceClassification(t *testing.T) {
cfg := config.Config{
Site: config.SiteConfig{
Domain: "example.com",
Domains: []string{"example.com"},
Collect: config.CollectConfig{
Pageviews: true,
Device: true,

View file

@ -17,7 +17,7 @@ type Config struct {
// Site-specific settings
type SiteConfig struct {
Domain string `yaml:"domain"`
Domains []string `yaml:"domains"` // list of allowed domains
SaltRotation string `yaml:"salt_rotation"`
Collect CollectConfig `yaml:"collect"`
CustomEvents []string `yaml:"custom_events"`
@ -31,6 +31,7 @@ type CollectConfig struct {
Country bool `yaml:"country"`
Device bool `yaml:"device"`
Referrer string `yaml:"referrer"`
Domain bool `yaml:"domain"` // track domain as metric dimension (for multi-site)
}
// Path normalization options
@ -106,8 +107,8 @@ func Load(path string) (*Config, error) {
// Checks required fields and sets defaults
func (c *Config) Validate() error {
// Site validation
if c.Site.Domain == "" {
return fmt.Errorf("site.domain is required")
if len(c.Site.Domains) == 0 {
return fmt.Errorf("site.domains is required")
}
// Validate salt_rotation

View file

@ -10,8 +10,8 @@ func TestLoadConfig_ValidFile(t *testing.T) {
t.Fatalf("expected no error, got %v", err)
}
if cfg.Site.Domain != "example.com" {
t.Errorf("expected domain 'example.com', got '%s'", cfg.Site.Domain)
if len(cfg.Site.Domains) == 0 || cfg.Site.Domains[0] != "example.com" {
t.Errorf("expected domains to contain 'example.com', got %v", cfg.Site.Domains)
}
if cfg.Site.SaltRotation != "daily" {
@ -29,7 +29,7 @@ func TestLoadConfig_MissingFile(t *testing.T) {
func TestValidate_MaxPathsRequired(t *testing.T) {
cfg := &Config{
Site: SiteConfig{
Domain: "example.com",
Domains: []string{"example.com"},
SaltRotation: "daily",
},
Limits: LimitsConfig{