From 987ddd92cc1ea0836efd200224aa46e245476ec5 Mon Sep 17 00:00:00 2001 From: NotAShelf Date: Sun, 1 Mar 2026 20:01:32 +0300 Subject: [PATCH] internal/aggregate: make shutdown context-aware with proper goroutine sync Adds `WaitGroup` to track background goroutine and makes Shutdown respect context deadlines Signed-off-by: NotAShelf Change-Id: Ia7f074725717f037412dacb93e34105b6a6a6964 --- cmd/watchdog/root.go | 7 +- internal/aggregate/metrics.go | 29 ++++++ internal/aggregate/metrics_test.go | 159 +++++++++++++++++++++++++++++ internal/aggregate/uniques.go | 19 ++-- 4 files changed, 205 insertions(+), 9 deletions(-) diff --git a/cmd/watchdog/root.go b/cmd/watchdog/root.go index 1cc89cf..e64e158 100644 --- a/cmd/watchdog/root.go +++ b/cmd/watchdog/root.go @@ -31,9 +31,14 @@ func Run(cfg *config.Config) error { eventRegistry := aggregate.NewCustomEventRegistry(cfg.Limits.MaxCustomEvents) metricsAgg := aggregate.NewMetricsAggregator(pathRegistry, eventRegistry, cfg) - // HLL state persistence is handled automatically if salt_rotation is configured + // Load HLL state from previous run if it exists if cfg.Site.SaltRotation != "" { log.Println("HLL state persistence enabled") + if err := metricsAgg.LoadState(); err != nil { + log.Printf("Could not load HLL state (might be first run): %v", err) + } else { + log.Println("HLL state restored from previous run") + } } // Register Prometheus metrics diff --git a/internal/aggregate/metrics.go b/internal/aggregate/metrics.go index 4472a40..a9ba3ad 100644 --- a/internal/aggregate/metrics.go +++ b/internal/aggregate/metrics.go @@ -2,7 +2,9 @@ package aggregate import ( "context" + "fmt" "regexp" + "sync" "time" "github.com/prometheus/client_golang/prometheus" @@ -24,6 +26,7 @@ type MetricsAggregator struct { dailyUniques prometheus.Gauge estimator *UniquesEstimator stopChan chan struct{} + wg sync.WaitGroup } // Creates a new metrics aggregator with dynamic labels based on config @@ -111,6 +114,7 @@ func NewMetricsAggregator( // Start 
background goroutine to update HLL gauge periodically if cfg.Site.SaltRotation != "" { + m.wg.Add(1) go m.updateUniquesGauge() } @@ -120,6 +124,7 @@ func NewMetricsAggregator( // Background goroutine to update the unique visitors gauge every 10 seconds // instead of on every request. This should help with performance. func (m *MetricsAggregator) updateUniquesGauge() { + defer m.wg.Done() ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() @@ -218,9 +223,33 @@ func (m *MetricsAggregator) MustRegister(reg prometheus.Registerer) { reg.MustRegister(m.dailyUniques) } +// LoadState restores HLL state from disk if it exists +func (m *MetricsAggregator) LoadState() error { + if m.cfg.Site.SaltRotation == "" { + return nil // State persistence not enabled + } + return m.estimator.Load(m.cfg.Server.StatePath) +} + // Shutdown performs graceful shutdown operations func (m *MetricsAggregator) Shutdown(ctx context.Context) error { + // Signal goroutine to stop m.Stop() + + // Wait for goroutine to finish, respecting context deadline + done := make(chan struct{}) + go func() { + m.wg.Wait() + close(done) + }() + + select { + case <-done: + // Goroutine finished successfully + case <-ctx.Done(): + return fmt.Errorf("shutdown timeout: %w", ctx.Err()) + } + // Persist HLL state if configured if m.cfg.Site.SaltRotation != "" { return m.estimator.Save(m.cfg.Server.StatePath) diff --git a/internal/aggregate/metrics_test.go b/internal/aggregate/metrics_test.go index f04d725..5630d8e 100644 --- a/internal/aggregate/metrics_test.go +++ b/internal/aggregate/metrics_test.go @@ -2,9 +2,11 @@ package aggregate import ( "context" + "errors" "os" "strings" "testing" + "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" @@ -278,3 +280,160 @@ func TestMetricsAggregator_Shutdown_DefaultStatePath(t *testing.T) { t.Logf("Shutdown returned error (might be expected): %v", err) } } + +func 
TestMetricsAggregator_Shutdown_RespectsContext(t *testing.T) { + registry := NewPathRegistry(100) + tmpDir := t.TempDir() + + cfg := config.Config{ + Site: config.SiteConfig{ + SaltRotation: "daily", + Collect: config.CollectConfig{ + Pageviews: true, + }, + }, + Server: config.ServerConfig{ + StatePath: tmpDir + "/hll.state", + }, + } + + agg := NewMetricsAggregator(registry, NewCustomEventRegistry(100), &cfg) + + // Create a context with very short timeout + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + + // Wait for context to expire + time.Sleep(10 * time.Millisecond) + + // Shutdown should respect context timeout + err := agg.Shutdown(ctx) + if err == nil { + t.Error("expected context deadline exceeded error, got nil") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf( + "expected context.DeadlineExceeded, got %v", + err, + ) + } +} + +func TestMetricsAggregator_Shutdown_WaitsForGoroutine(t *testing.T) { + registry := NewPathRegistry(100) + tmpDir := t.TempDir() + + cfg := config.Config{ + Site: config.SiteConfig{ + SaltRotation: "daily", + Collect: config.CollectConfig{ + Pageviews: true, + }, + }, + Server: config.ServerConfig{ + StatePath: tmpDir + "/hll.state", + }, + } + + agg := NewMetricsAggregator(registry, NewCustomEventRegistry(100), &cfg) + + // Give the goroutine time to start + time.Sleep(10 * time.Millisecond) + + // Track if goroutine is still running + done := make(chan struct{}) + go func() { + ctx := context.Background() + agg.Shutdown(ctx) + close(done) + }() + + // Shutdown should complete quickly (goroutine should stop) + select { + case <-done: + // Success - shutdown completed + case <-time.After(1 * time.Second): + t.Fatal("Shutdown did not complete within timeout - goroutine not stopping") + } +} + +func TestMetricsAggregator_LoadState(t *testing.T) { + tmpDir := t.TempDir() + statePath := tmpDir + "/hll.state" + + cfg := config.Config{ + Site: config.SiteConfig{ + 
SaltRotation: "daily", + Collect: config.CollectConfig{ + Pageviews: true, + }, + }, + Server: config.ServerConfig{ + StatePath: statePath, + }, + } + + // Create first aggregator and add some visitors + registry1 := NewPathRegistry(100) + agg1 := NewMetricsAggregator(registry1, NewCustomEventRegistry(100), &cfg) + agg1.AddUnique("192.168.1.1", "Mozilla/5.0") + agg1.AddUnique("192.168.1.2", "Mozilla/5.0") + agg1.AddUnique("192.168.1.3", "Mozilla/5.0") + + // Get estimate before shutdown + estimate1 := agg1.estimator.Estimate() + if estimate1 == 0 { + t.Fatal("expected non-zero estimate before shutdown") + } + + // Shutdown to save state + ctx := context.Background() + if err := agg1.Shutdown(ctx); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + // Verify state file was created + if _, err := os.Stat(statePath); os.IsNotExist(err) { + t.Fatal("state file was not created") + } + + // Create second aggregator and load state + registry2 := NewPathRegistry(100) + agg2 := NewMetricsAggregator(registry2, NewCustomEventRegistry(100), &cfg) + + // Load should restore the state + if err := agg2.LoadState(); err != nil { + t.Fatalf("LoadState failed: %v", err) + } + + // Estimate should match (approximately - HLL is probabilistic) + estimate2 := agg2.estimator.Estimate() + if estimate2 != estimate1 { + t.Errorf("expected estimate %d after load, got %d", estimate1, estimate2) + } +} + +func TestMetricsAggregator_LoadState_NoFile(t *testing.T) { + tmpDir := t.TempDir() + statePath := tmpDir + "/nonexistent.state" + + cfg := config.Config{ + Site: config.SiteConfig{ + SaltRotation: "daily", + Collect: config.CollectConfig{ + Pageviews: true, + }, + }, + Server: config.ServerConfig{ + StatePath: statePath, + }, + } + + registry := NewPathRegistry(100) + agg := NewMetricsAggregator(registry, NewCustomEventRegistry(100), &cfg) + + // LoadState should not error if file doesn't exist (first run) + if err := agg.LoadState(); err != nil { + t.Errorf("LoadState should not 
error on missing file, got: %v", err) + } +} diff --git a/internal/aggregate/uniques.go b/internal/aggregate/uniques.go index 74ec2da..1c7b6de 100644 --- a/internal/aggregate/uniques.go +++ b/internal/aggregate/uniques.go @@ -12,14 +12,14 @@ import ( "github.com/axiomhq/hyperloglog" ) -// UniquesEstimator tracks unique visitors using HyperLogLog with daily salt rotation +// Tracks unique visitors using HyperLogLog with daily salt rotation type UniquesEstimator struct { hll *hyperloglog.Sketch currentDay string mu sync.Mutex } -// NewUniquesEstimator creates a new unique visitor estimator +// Creates a new unique visitor estimator func NewUniquesEstimator() *UniquesEstimator { return &UniquesEstimator{ hll: hyperloglog.New(), @@ -53,7 +53,7 @@ func (u *UniquesEstimator) Estimate() uint64 { return u.hll.Estimate() } -// dailySalt generates a deterministic salt based on the current date +// Generates a deterministic salt based on the current date // Same day = same salt, different day = different salt func dailySalt(t time.Time) string { // Use UTC to ensure consistent rotation regardless of timezone @@ -62,21 +62,21 @@ func dailySalt(t time.Time) string { return hex.EncodeToString(h[:]) } -// hashVisitor creates a privacy-preserving hash of visitor identity +// Creates a privacy-preserving hash of visitor identity func hashVisitor(ip, userAgent, salt string) string { combined := ip + "|" + userAgent + "|" + salt h := sha256.Sum256([]byte(combined)) return hex.EncodeToString(h[:]) } -// CurrentSalt returns the current salt for testing +// Returns the current salt for testing func (u *UniquesEstimator) CurrentSalt() string { u.mu.Lock() defer u.mu.Unlock() return u.currentDay } -// DailySalt is exported for testing +// Exported for testing func DailySalt(t time.Time) string { return dailySalt(t) } @@ -99,7 +99,10 @@ func (u *UniquesEstimator) Save(path string) error { func (u *UniquesEstimator) Load(path string) error { data, err := os.ReadFile(path) if err != nil { - 
return err // File not existing is OK (first run) + if os.IsNotExist(err) { + return nil // file not existing is OK (first run) + } + return err // other errors should be reported } u.mu.Lock() @@ -120,7 +123,7 @@ func (u *UniquesEstimator) Load(path string) error { return u.hll.UnmarshalBinary(parts[1]) } - // Different day - start fresh + // Different day, start fresh u.hll = hyperloglog.New() u.currentDay = today return nil