internal/aggregate: make shutdown context-aware with proper goroutine sync

Adds a `sync.WaitGroup` to track the background goroutine and makes Shutdown
respect context deadlines.

Signed-off-by: NotAShelf <raf@notashelf.dev>
Change-Id: Ia7f074725717f037412dacb93e34105b6a6a6964
This commit is contained in:
raf 2026-03-01 20:01:32 +03:00
commit 987ddd92cc
Signed by: NotAShelf
GPG key ID: 29D95B64378DB4BF
4 changed files with 205 additions and 9 deletions

View file

@ -2,7 +2,9 @@ package aggregate
import (
"context"
"fmt"
"regexp"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
@ -24,6 +26,7 @@ type MetricsAggregator struct {
dailyUniques prometheus.Gauge
estimator *UniquesEstimator
stopChan chan struct{}
wg sync.WaitGroup
}
// Creates a new metrics aggregator with dynamic labels based on config
@ -111,6 +114,7 @@ func NewMetricsAggregator(
// Start background goroutine to update HLL gauge periodically
if cfg.Site.SaltRotation != "" {
m.wg.Add(1)
go m.updateUniquesGauge()
}
@ -120,6 +124,7 @@ func NewMetricsAggregator(
// Background goroutine to update the unique visitors gauge every 10 seconds
// instead of on every request. This should help with performance.
func (m *MetricsAggregator) updateUniquesGauge() {
defer m.wg.Done()
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
@ -218,9 +223,33 @@ func (m *MetricsAggregator) MustRegister(reg prometheus.Registerer) {
reg.MustRegister(m.dailyUniques)
}
// LoadState restores the estimator state from cfg.Server.StatePath.
//
// When cfg.Site.SaltRotation is empty, state persistence is not enabled and
// LoadState returns nil without touching the estimator. Otherwise it returns
// whatever the estimator's Load reports.
func (m *MetricsAggregator) LoadState() error {
	persistenceEnabled := m.cfg.Site.SaltRotation != ""
	if !persistenceEnabled {
		return nil
	}

	return m.estimator.Load(m.cfg.Server.StatePath)
}
// Shutdown performs graceful shutdown operations
func (m *MetricsAggregator) Shutdown(ctx context.Context) error {
// Signal goroutine to stop
m.Stop()
// Wait for goroutine to finish, respecting context deadline
done := make(chan struct{})
go func() {
m.wg.Wait()
close(done)
}()
select {
case <-done:
// Goroutine finished successfully
case <-ctx.Done():
return fmt.Errorf("shutdown timeout: %w", ctx.Err())
}
// Persist HLL state if configured
if m.cfg.Site.SaltRotation != "" {
return m.estimator.Save(m.cfg.Server.StatePath)

View file

@ -2,9 +2,11 @@ package aggregate
import (
"context"
"errors"
"os"
"strings"
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
@ -278,3 +280,160 @@ func TestMetricsAggregator_Shutdown_DefaultStatePath(t *testing.T) {
t.Logf("Shutdown returned error (might be expected): %v", err)
}
}
// TestMetricsAggregator_Shutdown_RespectsContext checks that Shutdown returns
// an error wrapping context.DeadlineExceeded when it is given a context whose
// deadline has already passed.
//
// NOTE(review): Shutdown selects between "goroutine finished" and ctx.Done().
// If both are ready at the same time Go picks a case at random, so this test
// could still pass through the success branch on a fast machine. Consider
// checking ctx.Err() up front in Shutdown.
func TestMetricsAggregator_Shutdown_RespectsContext(t *testing.T) {
	registry := NewPathRegistry(100)
	tmpDir := t.TempDir()

	cfg := config.Config{
		Site: config.SiteConfig{
			SaltRotation: "daily",
			Collect: config.CollectConfig{
				Pageviews: true,
			},
		},
		Server: config.ServerConfig{
			StatePath: tmpDir + "/hll.state",
		},
	}

	agg := NewMetricsAggregator(registry, NewCustomEventRegistry(100), &cfg)

	// Create a context with a very short timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
	defer cancel()

	// Block until the deadline has actually fired rather than sleeping for an
	// arbitrary duration and hoping it was long enough.
	<-ctx.Done()

	// Shutdown should respect the expired context. errors.Is is false for a
	// nil error, so a single check covers both "no error" and "wrong error"
	// without reporting the same failure twice.
	err := agg.Shutdown(ctx)
	if !errors.Is(err, context.DeadlineExceeded) {
		t.Errorf("expected context.DeadlineExceeded, got %v", err)
	}
}
// TestMetricsAggregator_Shutdown_WaitsForGoroutine checks that Shutdown with
// a non-expiring context returns within one second and without error, i.e.
// that the background goroutine started by NewMetricsAggregator stops when
// asked to.
func TestMetricsAggregator_Shutdown_WaitsForGoroutine(t *testing.T) {
	registry := NewPathRegistry(100)
	tmpDir := t.TempDir()

	cfg := config.Config{
		Site: config.SiteConfig{
			SaltRotation: "daily",
			Collect: config.CollectConfig{
				Pageviews: true,
			},
		},
		Server: config.ServerConfig{
			StatePath: tmpDir + "/hll.state",
		},
	}

	agg := NewMetricsAggregator(registry, NewCustomEventRegistry(100), &cfg)

	// Give the goroutine time to start.
	time.Sleep(10 * time.Millisecond)

	// Run Shutdown in the background and hand its result back. The channel is
	// buffered so the goroutine can exit even if the test has already timed
	// out and nobody is receiving any more.
	errCh := make(chan error, 1)
	go func() {
		errCh <- agg.Shutdown(context.Background())
	}()

	// Shutdown should complete quickly (the goroutine should stop).
	select {
	case err := <-errCh:
		// Previously the error was discarded; a failed shutdown must fail the test.
		if err != nil {
			t.Fatalf("Shutdown failed: %v", err)
		}
	case <-time.After(1 * time.Second):
		t.Fatal("Shutdown did not complete within timeout - goroutine not stopping")
	}
}
// TestMetricsAggregator_LoadState checks the save/restore round trip: visitors
// added to one aggregator are persisted by Shutdown and show up in the
// estimate of a second aggregator after LoadState.
func TestMetricsAggregator_LoadState(t *testing.T) {
	tmpDir := t.TempDir()
	statePath := tmpDir + "/hll.state"

	cfg := config.Config{
		Site: config.SiteConfig{
			SaltRotation: "daily",
			Collect: config.CollectConfig{
				Pageviews: true,
			},
		},
		Server: config.ServerConfig{
			StatePath: statePath,
		},
	}

	// Create first aggregator and add some visitors.
	registry1 := NewPathRegistry(100)
	agg1 := NewMetricsAggregator(registry1, NewCustomEventRegistry(100), &cfg)
	agg1.AddUnique("192.168.1.1", "Mozilla/5.0")
	agg1.AddUnique("192.168.1.2", "Mozilla/5.0")
	agg1.AddUnique("192.168.1.3", "Mozilla/5.0")

	// Get estimate before shutdown.
	estimate1 := agg1.estimator.Estimate()
	if estimate1 == 0 {
		t.Fatal("expected non-zero estimate before shutdown")
	}

	// Shutdown to save state.
	ctx := context.Background()
	if err := agg1.Shutdown(ctx); err != nil {
		t.Fatalf("Shutdown failed: %v", err)
	}

	// Verify the state file was created. Any Stat error (not only
	// "does not exist") means the file is unusable, so report it.
	if _, err := os.Stat(statePath); err != nil {
		t.Fatalf("state file was not created: %v", err)
	}

	// Create second aggregator and load state.
	registry2 := NewPathRegistry(100)
	agg2 := NewMetricsAggregator(registry2, NewCustomEventRegistry(100), &cfg)
	// Stop the second aggregator's background goroutine when the test ends so
	// it does not outlive the test. Stop (not Shutdown) is used to avoid
	// rewriting the state file.
	defer agg2.Stop()

	// Load should restore the state.
	if err := agg2.LoadState(); err != nil {
		t.Fatalf("LoadState failed: %v", err)
	}

	// The comparison is exact: the second estimator is expected to hold the
	// same sketch that was saved, so it should produce the same estimate.
	estimate2 := agg2.estimator.Estimate()
	if estimate2 != estimate1 {
		t.Errorf("expected estimate %d after load, got %d", estimate1, estimate2)
	}
}
// TestMetricsAggregator_LoadState_NoFile checks that LoadState treats a
// missing state file as a normal first run and returns nil.
func TestMetricsAggregator_LoadState_NoFile(t *testing.T) {
	tmpDir := t.TempDir()
	statePath := tmpDir + "/nonexistent.state"

	cfg := config.Config{
		Site: config.SiteConfig{
			SaltRotation: "daily",
			Collect: config.CollectConfig{
				Pageviews: true,
			},
		},
		Server: config.ServerConfig{
			StatePath: statePath,
		},
	}

	registry := NewPathRegistry(100)
	agg := NewMetricsAggregator(registry, NewCustomEventRegistry(100), &cfg)
	// SaltRotation is set, so the aggregator started its background goroutine;
	// stop it when the test ends so it does not outlive the test. Stop (not
	// Shutdown) is used so that no state file gets written.
	defer agg.Stop()

	// LoadState should not error if the file doesn't exist (first run).
	if err := agg.LoadState(); err != nil {
		t.Errorf("LoadState should not error on missing file, got: %v", err)
	}
}

View file

@ -12,14 +12,14 @@ import (
"github.com/axiomhq/hyperloglog"
)
// UniquesEstimator tracks unique visitors using HyperLogLog with daily salt
// rotation.
type UniquesEstimator struct {
// hll is the HyperLogLog sketch; Estimate returns hll.Estimate().
hll *hyperloglog.Sketch
// currentDay is the day the sketch belongs to; Load replaces the sketch with
// a fresh one and updates this field when the stored state is from a
// different day.
currentDay string
// mu is held while CurrentSalt reads currentDay and while Load updates the
// estimator.
mu sync.Mutex
}
// NewUniquesEstimator creates a new unique visitor estimator
// Creates a new unique visitor estimator
func NewUniquesEstimator() *UniquesEstimator {
return &UniquesEstimator{
hll: hyperloglog.New(),
@ -53,7 +53,7 @@ func (u *UniquesEstimator) Estimate() uint64 {
return u.hll.Estimate()
}
// dailySalt generates a deterministic salt based on the current date
// Generates a deterministic salt based on the current date
// Same day = same salt, different day = different salt
func dailySalt(t time.Time) string {
// Use UTC to ensure consistent rotation regardless of timezone
@ -62,21 +62,21 @@ func dailySalt(t time.Time) string {
return hex.EncodeToString(h[:])
}
// hashVisitor derives a privacy-preserving identifier for a visitor.
//
// The IP, user agent and salt are joined with "|" separators, hashed with
// SHA-256, and returned as a lowercase hex string.
func hashVisitor(ip, userAgent, salt string) string {
	hasher := sha256.New()
	// Feed the fields into the hasher one by one, separated by "|", which
	// produces the same byte stream as hashing ip + "|" + userAgent + "|" + salt.
	// hash.Hash.Write never returns an error, so its result is not checked.
	for i, field := range [...]string{ip, userAgent, salt} {
		if i > 0 {
			hasher.Write([]byte("|"))
		}
		hasher.Write([]byte(field))
	}
	return hex.EncodeToString(hasher.Sum(nil))
}
// CurrentSalt returns the estimator's currentDay value; it exists for tests.
// The field is read while holding u.mu.
func (u *UniquesEstimator) CurrentSalt() string {
	u.mu.Lock()
	day := u.currentDay
	u.mu.Unlock()

	return day
}
// DailySalt exposes the unexported dailySalt helper for testing.
// It returns dailySalt(t) unchanged.
func DailySalt(t time.Time) string {
return dailySalt(t)
}
@ -99,7 +99,10 @@ func (u *UniquesEstimator) Save(path string) error {
func (u *UniquesEstimator) Load(path string) error {
data, err := os.ReadFile(path)
if err != nil {
return err // File not existing is OK (first run)
if os.IsNotExist(err) {
return nil // file not existing is OK (first run)
}
return err // other errors should be reported
}
u.mu.Lock()
@ -120,7 +123,7 @@ func (u *UniquesEstimator) Load(path string) error {
return u.hll.UnmarshalBinary(parts[1])
}
// Different day - start fresh
// Different day, start fresh
u.hll = hyperloglog.New()
u.currentDay = today
return nil