Compare commits

..

No commits in common. "main" and "v1.0.0" have entirely different histories.

20 changed files with 227 additions and 712 deletions

View file

@ -3,21 +3,8 @@ name: Build & Test with Go
on:
push:
branches: [main]
paths:
- '**.go'
- 'go.mod'
- 'go.sum'
- '.github/workflows/go.yml'
- 'testdata/**'
pull_request:
branches: [main]
paths:
- '**.go'
- 'go.mod'
- 'go.sum'
- '.github/workflows/go.yml'
- 'testdata/**'
jobs:
test:

View file

@ -39,7 +39,7 @@ jobs:
GOARCH: ${{ matrix.goarch }}
- name: Upload artifact
uses: actions/upload-artifact@v7
uses: actions/upload-artifact@v4
with:
name: watchdog-${{ matrix.goos }}-${{ matrix.goarch }}
path: watchdog-*
@ -50,17 +50,17 @@ jobs:
permissions:
contents: write
steps:
- uses: actions/checkout@v6
- name: Download all artifacts
uses: actions/download-artifact@v8
with:
path: artifacts
pattern: watchdog-*
merge-multiple: true
- name: Create GitHub Release
uses: softprops/action-gh-release@v2
with:
files: artifacts/watchdog-*
files: artifacts/**/*.tar.gz artifacts/**/*.zip
generate_release_notes: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View file

@ -25,23 +25,18 @@ type versionInfo struct {
BuildDate string `json:"buildDate"`
}
func getVersionInfo() versionInfo {
data, err := os.ReadFile("version.json")
if err != nil {
return versionInfo{}
}
var v versionInfo
if err := json.Unmarshal(data, &v); err != nil {
return versionInfo{}
}
return v
}
func getVersion() string {
if version != "" {
return version
}
v := getVersionInfo()
data, err := os.ReadFile("version.json")
if err != nil {
return "dev"
}
var v versionInfo
if err := json.Unmarshal(data, &v); err != nil {
return "dev"
}
if v.Version != "" {
return v.Version
}
@ -52,7 +47,14 @@ func getCommit() string {
if commit != "" {
return commit
}
v := getVersionInfo()
data, err := os.ReadFile("version.json")
if err != nil {
return "none"
}
var v versionInfo
if err := json.Unmarshal(data, &v); err != nil {
return "none"
}
if v.Commit != "" {
return v.Commit
}

View file

@ -9,10 +9,8 @@ import (
"os"
"os/signal"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
@ -22,7 +20,6 @@ import (
"notashelf.dev/watchdog/internal/health"
"notashelf.dev/watchdog/internal/limits"
"notashelf.dev/watchdog/internal/normalize"
"notashelf.dev/watchdog/internal/ratelimit"
)
func Run(cfg *config.Config) error {
@ -77,7 +74,7 @@ func Run(cfg *config.Config) error {
// Setup routes
mux := http.NewServeMux()
// Metrics endpoint with optional basic auth and rate limiting
// Metrics endpoint with optional basic auth
metricsHandler := promhttp.HandlerFor(promRegistry, promhttp.HandlerOpts{
EnableOpenMetrics: true,
})
@ -90,20 +87,6 @@ func Run(cfg *config.Config) error {
)
}
// Add rate limiting to metrics endpoint
if cfg.Limits.MaxMetricsPerMinute > 0 {
metricsRateLimiter := ratelimit.NewTokenBucket(
cfg.Limits.MaxMetricsPerMinute,
cfg.Limits.MaxMetricsPerMinute,
time.Minute,
)
metricsHandler = rateLimitMiddleware(metricsHandler, metricsRateLimiter)
}
// Add response size limit to metrics endpoint (10MB max)
const maxMetricsResponseSize = 10 * 1024 * 1024 // 10MB
metricsHandler = responseSizeLimitMiddleware(metricsHandler, maxMetricsResponseSize)
mux.Handle(cfg.Server.MetricsPath, metricsHandler)
// Ingestion endpoint
@ -184,69 +167,6 @@ func basicAuth(next http.Handler, username, password string) http.Handler {
})
}
// Wraps a handler with rate limiting
func rateLimitMiddleware(next http.Handler, limiter *ratelimit.TokenBucket) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !limiter.Allow() {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
// Wraps http.ResponseWriter to enforce max response size
type limitedResponseWriter struct {
http.ResponseWriter
maxSize int
written int
limitExceeded bool
}
func (w *limitedResponseWriter) Write(p []byte) (int, error) {
if w.limitExceeded {
return 0, fmt.Errorf("response size limit exceeded")
}
if w.written+len(p) > w.maxSize {
w.limitExceeded = true
w.Header().Set("X-Response-Truncated", "true")
http.Error(w.ResponseWriter, "Response size limit exceeded", http.StatusInternalServerError)
return 0, fmt.Errorf("response size limit exceeded: %d bytes", w.maxSize)
}
n, err := w.ResponseWriter.Write(p)
w.written += n
return n, err
}
// Wraps a handler with response size limiting
func responseSizeLimitMiddleware(next http.Handler, maxSize int) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
limited := &limitedResponseWriter{
ResponseWriter: w,
maxSize: maxSize,
}
next.ServeHTTP(limited, r)
})
}
// Sanitizes a path for logging to prevent log injection attacks. Uses `strconv.Quote`
// to properly escape control characters and special bytes.
func sanitizePathForLog(path string) string {
escaped := strconv.Quote(path)
if len(escaped) >= 2 && escaped[0] == '"' && escaped[len(escaped)-1] == '"' {
escaped = escaped[1 : len(escaped)-1]
}
// Limit path length to prevent log flooding
const maxLen = 200
if len(escaped) > maxLen {
return escaped[:maxLen] + "..."
}
return escaped
}
// Creates a file server that only serves whitelisted files. Blocks dotfiles, .git, .env, etc.
// TODO: I need to hook this up to eris somehow so I can just forward the paths that are being
// scanned despite not being on a whitelist. Would be a good way of detecting scrapers, maybe.
@ -259,7 +179,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha
// Block directory listings
if strings.HasSuffix(path, "/") {
blockedRequests.WithLabelValues("directory_listing").Inc()
log.Printf("Blocked directory listing attempt: %s from %s", sanitizePathForLog(path), r.RemoteAddr)
log.Printf("Blocked directory listing attempt: %s from %s", path, r.RemoteAddr)
http.NotFound(w, r)
return
}
@ -268,7 +188,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha
for segment := range strings.SplitSeq(path, "/") {
if strings.HasPrefix(segment, ".") {
blockedRequests.WithLabelValues("dotfile").Inc()
log.Printf("Blocked dotfile access: %s from %s", sanitizePathForLog(path), r.RemoteAddr)
log.Printf("Blocked dotfile access: %s from %s", path, r.RemoteAddr)
http.NotFound(w, r)
return
}
@ -279,7 +199,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha
strings.HasSuffix(lower, ".bak") ||
strings.HasSuffix(lower, "~") {
blockedRequests.WithLabelValues("sensitive_file").Inc()
log.Printf("Blocked sensitive file access: %s from %s", sanitizePathForLog(path), r.RemoteAddr)
log.Printf("Blocked sensitive file access: %s from %s", path, r.RemoteAddr)
http.NotFound(w, r)
return
}
@ -289,7 +209,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha
ext := strings.ToLower(filepath.Ext(path))
if ext != ".js" && ext != ".html" && ext != ".css" {
blockedRequests.WithLabelValues("invalid_extension").Inc()
log.Printf("Blocked invalid extension: %s from %s", sanitizePathForLog(path), r.RemoteAddr)
log.Printf("Blocked invalid extension: %s from %s", path, r.RemoteAddr)
http.NotFound(w, r)
return
}

View file

@ -1,116 +0,0 @@
package watchdog
import (
"strings"
"testing"
)
func TestSanitizePathForLog(t *testing.T) {
tests := []struct {
name string
input string
want string
maxLen int // expected max length check
}{
{
name: "normal path",
input: "/web/beacon.js",
want: "/web/beacon.js",
maxLen: 200,
},
{
name: "path with newlines",
input: "/test\nmalicious",
want: `/test\nmalicious`,
maxLen: 200,
},
{
name: "path with carriage return",
input: "/test\rmalicious",
want: `/test\rmalicious`,
maxLen: 200,
},
{
name: "path with tabs",
input: "/test\tmalicious",
want: `/test\tmalicious`,
maxLen: 200,
},
{
name: "path with null bytes",
input: "/test\x00malicious",
want: `/test\x00malicious`,
maxLen: 200,
},
{
name: "path with quotes",
input: `/test"malicious`,
want: `/test\"malicious`,
maxLen: 200,
},
{
name: "path with backslash",
input: `/test\malicious`,
want: `/test\\malicious`,
maxLen: 200,
},
{
name: "control characters",
input: "/test\x01\x02\x1fmalicious",
want: `/test\x01\x02\x1fmalicious`,
maxLen: 200,
},
{
name: "truncation at 200 chars",
input: "/" + strings.Repeat("a", 250),
want: "/" + strings.Repeat("a", 199) + "...",
maxLen: 203, // 200 chars + "..." = 203
},
{
name: "empty string",
input: "",
want: "",
maxLen: 200,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := sanitizePathForLog(tt.input)
if got != tt.want {
t.Errorf("sanitizePathForLog(%q) = %q, want %q", tt.input, got, tt.want)
}
if len(got) > tt.maxLen {
t.Errorf("sanitizePathForLog(%q) length = %d, exceeds max %d", tt.input, len(got), tt.maxLen)
}
})
}
}
func TestSanitizePathForLog_LogInjectionPrevention(t *testing.T) {
// Log injection attempts should be neutralized
maliciousPaths := []string{
"/api\nINFO: fake log entry",
"/test\r\nERROR: fake error",
"/.git/config\x00", // null byte injection
}
for _, path := range maliciousPaths {
sanitized := sanitizePathForLog(path)
// Check that newlines are escaped, not literal
if strings.Contains(sanitized, "\n") || strings.Contains(sanitized, "\r") {
t.Errorf("sanitizePathForLog(%q) contains literal newlines: %q", path, sanitized)
}
// Check that null bytes are escaped
if strings.Contains(sanitized, "\x00") {
t.Errorf("sanitizePathForLog(%q) contains literal null byte: %q", path, sanitized)
}
}
}
func BenchmarkSanitizePathForLog(b *testing.B) {
path := "/test/path/with\nnewlines\rand\ttabs"
for b.Loop() {
_ = sanitizePathForLog(path)
}
}

View file

@ -63,8 +63,6 @@ limits:
max_custom_events: 100
# Maximum events per minute (rate limiting, 0 = unlimited)
max_events_per_minute: 10000
# Maximum metrics endpoint requests per minute (rate limiting, 0 = unlimited, default: 30)
max_metrics_per_minute: 60
# Device classification breakpoints (screen width in pixels)
device_breakpoints:
@ -87,8 +85,6 @@ security:
- "*" # Or specific domains: ["https://example.com", "https://www.example.com"]
# Basic authentication for /metrics endpoint
# Password can also be set via environment variable:
# - `$ export WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD=your-secret-password`
metrics_auth:
enabled: false
username: "admin"

View file

@ -6,17 +6,12 @@ Grafana.
> [!IMPORTANT]
>
> **Why you need a time-series database:**
> **Why you need Prometheus:**
>
> - Watchdog exposes _current state_ (counters, gauges)
> - A TSDB _scrapes periodically_ and _stores time-series data_
> - Grafana _visualizes_ the historical data
> - Prometheus _scrapes periodically_ and _stores time-series data_
> - Grafana _visualizes_ the historical data from Prometheus
> - Grafana cannot directly scrape Prometheus `/metrics` endpoints
>
> **Compatible databases:**
>
> - [Prometheus](#prometheus-setup),
> - [VictoriaMetrics](#victoriametrics), or any Prometheus-compatible scraper
## Prometheus Setup
@ -129,7 +124,7 @@ For multiple Watchdog instances:
datasources.settings.datasources = [{
name = "Prometheus";
type = "prometheus";
url = "http://localhost:9090"; # Or "http://localhost:8428" for VictoriaMetrics
url = "http://localhost:9090";
isDefault = true;
}];
};
@ -233,34 +228,7 @@ sum by (instance) (rate(web_pageviews_total[5m]))
### VictoriaMetrics
VictoriaMetrics is a fast, cost-effective monitoring solution and time-series
database that is 100% compatible with Prometheus exposition format. Watchdog's
`/metrics` endpoint can be scraped directly by VictoriaMetrics without requiring
Prometheus.
#### Direct Scraping (Recommended)
VictoriaMetrics single-node mode can scrape Watchdog directly using standard
Prometheus scrape configuration:
**Configuration file (`/etc/victoriametrics/scrape.yml`):**
```yaml
scrape_configs:
- job_name: "watchdog"
static_configs:
- targets: ["localhost:8080"]
scrape_interval: 15s
metrics_path: /metrics
```
**Run VictoriaMetrics:**
```bash
victoria-metrics -promscrape.config=/etc/victoriametrics/scrape.yml
```
**NixOS configuration:**
Drop-in Prometheus replacement with better performance and compression:
```nix
{
@ -268,84 +236,15 @@ victoria-metrics -promscrape.config=/etc/victoriametrics/scrape.yml
enable = true;
listenAddress = ":8428";
retentionPeriod = "12month";
# Define scrape configs directly. 'prometheusConfig' is the configuration for
# Prometheus-style metrics endpoints, which Watchdog exports.
prometheusConfig = {
scrape_configs = [
{
job_name = "watchdog";
scrape_interval = "15s";
static_configs = [{
targets = [ "localhost:8080" ]; # replace the port
}];
}
];
};
};
}
```
#### Using `vmagent`
Alternatively, for distributed setups or when you need more advanced features
like relabeling, you may use `vmagent`:
```nix
{
services.vmagent = {
enable = true;
remoteWriteUrl = "http://localhost:8428/api/v1/write";
prometheusConfig = {
scrape_configs = [
{
job_name = "watchdog";
static_configs = [{
targets = [ "localhost:8080" ];
}];
}
];
};
};
services.victoriametrics = {
enable = true;
listenAddress = ":8428";
};
}
```
#### Prometheus Remote Write
If you are migrating from Prometheus, or if you need PromQL compatibility, or if
you just really like using Prometheus for some inexplicable reason you may keep
Prometheus but use VictoriaMetrics to remote-write.
```nix
{
# Configure Prometheus to remote-write to VictoriaMetrics
services.prometheus = {
enable = true;
port = 9090;
scrapeConfigs = [
{
job_name = "watchdog";
static_configs = [{
targets = [ "localhost:8080" ];
}];
}
];
remoteWrite = [{
url = "http://localhost:8428/api/v1/write";
}];
};
services.victoriametrics = {
enable = true;
listenAddress = ":8428";
};
}
```
@ -375,10 +274,10 @@ metrics:
## Monitoring the Monitoring
Monitor your scraper:
Monitor Prometheus itself:
```promql
# Scrape success rate
# Prometheus scrape success rate
up{job="watchdog"}
# Scrape duration
@ -388,9 +287,14 @@ scrape_duration_seconds{job="watchdog"}
time() - timestamp(up{job="watchdog"})
```
For VictoriaMetrics, you can also monitor ingestion stats:
## Additional Recommendations
```bash
# VM internal metrics
curl http://localhost:8428/metrics | grep vm_rows_inserted_total
```
1. **Retention**: Set `--storage.tsdb.retention.time=30d` or longer based on
disk space
2. **Backups**: Back up `/var/lib/prometheus` periodically (or whatever your
state directory is)
3. **Alerting**: Configure Prometheus alerting rules for critical metrics
4. **High Availability**: Run multiple Prometheus instances with identical
configs
5. **Remote Storage**: For long-term storage, use Thanos, Cortex, or
VictoriaMetrics

View file

@ -54,20 +54,3 @@ func (r *CustomEventRegistry) OverflowCount() int {
defer r.mu.RUnlock()
return r.overflowCount
}
// Contains checks if an event name exists in the registry.
func (r *CustomEventRegistry) Contains(eventName string) bool {
r.mu.RLock()
defer r.mu.RUnlock()
_, exists := r.events[eventName]
return exists
}
// Count returns the number of unique events in the registry.
func (r *CustomEventRegistry) Count() int {
r.mu.RLock()
defer r.mu.RUnlock()
return len(r.events)
}

View file

@ -162,11 +162,7 @@ func (m *MetricsAggregator) RecordPageview(path, country, device, referrer, doma
labels := prometheus.Labels{"path": sanitizeLabel(path)}
if m.cfg.Site.Collect.Country {
if country == "" {
labels["country"] = "unknown"
} else {
labels["country"] = sanitizeLabel(country)
}
labels["country"] = sanitizeLabel(country)
}
if m.cfg.Site.Collect.Device {

View file

@ -6,7 +6,6 @@ import (
"encoding/hex"
"fmt"
"os"
"strings"
"sync"
"time"
@ -18,18 +17,15 @@ type UniquesEstimator struct {
hll *hyperloglog.Sketch
salt string
rotation string // "daily" or "hourly"
saltKey string // cached time key to avoid regeneration
mu sync.Mutex
}
// Creates a new unique visitor estimator
func NewUniquesEstimator(rotation string) *UniquesEstimator {
now := time.Now()
return &UniquesEstimator{
hll: hyperloglog.New(),
salt: generateSalt(now, rotation),
salt: generateSalt(time.Now(), rotation),
rotation: rotation,
saltKey: getSaltKey(now, rotation),
}
}
@ -40,13 +36,11 @@ func (u *UniquesEstimator) Add(ip, userAgent string) {
defer u.mu.Unlock()
// Check if we need to rotate to a new period
now := time.Now()
currentKey := getSaltKey(now, u.rotation)
if currentKey != u.saltKey {
currentSalt := generateSalt(time.Now(), u.rotation)
if currentSalt != u.salt {
// Reset HLL for new period
u.hll = hyperloglog.New()
u.salt = generateSaltFromKey(currentKey)
u.saltKey = currentKey
u.salt = currentSalt
}
// Hash visitor with salt to prevent cross-period tracking
@ -61,36 +55,24 @@ func (u *UniquesEstimator) Estimate() uint64 {
return u.hll.Estimate()
}
// Returns the time-based key for salt generation without hashing
func getSaltKey(t time.Time, rotation string) string {
if rotation == "hourly" {
return t.UTC().Format("2006-01-02T15")
}
return t.UTC().Format("2006-01-02")
}
// Creates a salt from a pre-computed key
func generateSaltFromKey(key string) string {
h := sha256.Sum256([]byte("watchdog-salt-" + key))
return hex.EncodeToString(h[:])
}
// Generates a deterministic salt based on the rotation mode
// Daily: same day = same salt, different day = different salt
// Hourly: same hour = same salt, different hour = different salt
func generateSalt(t time.Time, rotation string) string {
return generateSaltFromKey(getSaltKey(t, rotation))
var key string
if rotation == "hourly" {
key = t.UTC().Format("2006-01-02T15")
} else {
key = t.UTC().Format("2006-01-02")
}
h := sha256.Sum256([]byte("watchdog-salt-" + key))
return hex.EncodeToString(h[:])
}
// Creates a privacy-preserving hash of visitor identity
func hashVisitor(ip, userAgent, salt string) string {
var sb strings.Builder
sb.WriteString(ip)
sb.WriteString("|")
sb.WriteString(userAgent)
sb.WriteString("|")
sb.WriteString(salt)
h := sha256.Sum256([]byte(sb.String()))
combined := ip + "|" + userAgent + "|" + salt
h := sha256.Sum256([]byte(combined))
return hex.EncodeToString(h[:])
}
@ -140,20 +122,16 @@ func (u *UniquesEstimator) Load(path string) error {
}
savedSalt := string(parts[0])
now := time.Now()
currentKey := getSaltKey(now, u.rotation)
currentSalt := generateSaltFromKey(currentKey)
currentSalt := generateSalt(time.Now(), u.rotation)
// Only restore if it's the same period
if savedSalt == currentSalt {
u.salt = savedSalt
u.saltKey = currentKey
return u.hll.UnmarshalBinary(parts[1])
}
// Different period, start fresh
u.hll = hyperloglog.New()
u.salt = currentSalt
u.saltKey = currentKey
return nil
}

View file

@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"io"
"slices"
"notashelf.dev/watchdog/internal/limits"
)
@ -39,8 +40,40 @@ func ParseEvent(body io.Reader) (*Event, error) {
return &event, nil
}
// Validate checks if the event is valid using a domain map
func (e *Event) Validate(allowedDomains map[string]bool) error {
// Validate checks if the event is valid for the given domains
func (e *Event) Validate(allowedDomains []string) error {
if e.Domain == "" {
return fmt.Errorf("domain required")
}
// Check if domain is in allowed list
allowed := slices.Contains(allowedDomains, e.Domain)
if !allowed {
return fmt.Errorf("domain not allowed")
}
if e.Path == "" {
return fmt.Errorf("path required")
}
if len(e.Path) > limits.MaxPathLen {
return fmt.Errorf("path too long")
}
if len(e.Referrer) > limits.MaxRefLen {
return fmt.Errorf("referrer too long")
}
// Validate screen width is in reasonable range
if e.Width < 0 || e.Width > limits.MaxWidth {
return fmt.Errorf("invalid width")
}
return nil
}
// ValidateWithMap checks if the event is valid using a domain map (O(1) lookup)
func (e *Event) ValidateWithMap(allowedDomains map[string]bool) error {
if e.Domain == "" {
return fmt.Errorf("domain required")
}

View file

@ -115,7 +115,7 @@ func TestValidateEvent(t *testing.T) {
tests := []struct {
name string
event Event
domains map[string]bool
domains []string
wantErr bool
}{
{
@ -124,7 +124,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "example.com",
Path: "/home",
},
domains: map[string]bool{"example.com": true},
domains: []string{"example.com"},
wantErr: false,
},
{
@ -134,7 +134,7 @@ func TestValidateEvent(t *testing.T) {
Path: "/signup",
Event: "signup",
},
domains: map[string]bool{"example.com": true},
domains: []string{"example.com"},
wantErr: false,
},
{
@ -143,7 +143,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "wrong.com",
Path: "/home",
},
domains: map[string]bool{"example.com": true},
domains: []string{"example.com"},
wantErr: true,
},
{
@ -152,7 +152,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "",
Path: "/home",
},
domains: map[string]bool{"example.com": true},
domains: []string{"example.com"},
wantErr: true,
},
{
@ -161,7 +161,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "example.com",
Path: "",
},
domains: map[string]bool{"example.com": true},
domains: []string{"example.com"},
wantErr: true,
},
{
@ -170,7 +170,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "example.com",
Path: "/" + strings.Repeat("a", 3000),
},
domains: map[string]bool{"example.com": true},
domains: []string{"example.com"},
wantErr: true,
},
{
@ -180,7 +180,7 @@ func TestValidateEvent(t *testing.T) {
Path: "/home",
Referrer: strings.Repeat("a", 3000),
},
domains: map[string]bool{"example.com": true},
domains: []string{"example.com"},
wantErr: true,
},
{
@ -189,7 +189,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "example.com",
Path: "/" + strings.Repeat("a", 2000),
},
domains: map[string]bool{"example.com": true},
domains: []string{"example.com"},
wantErr: false,
},
{
@ -198,7 +198,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "site1.com",
Path: "/home",
},
domains: map[string]bool{"site1.com": true, "site2.com": true},
domains: []string{"site1.com", "site2.com"},
wantErr: false,
},
{
@ -207,7 +207,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "site2.com",
Path: "/about",
},
domains: map[string]bool{"site1.com": true, "site2.com": true},
domains: []string{"site1.com", "site2.com"},
wantErr: false,
},
{
@ -216,7 +216,7 @@ func TestValidateEvent(t *testing.T) {
Domain: "site3.com",
Path: "/home",
},
domains: map[string]bool{"site1.com": true, "site2.com": true},
domains: []string{"site1.com", "site2.com"},
wantErr: true,
},
}
@ -231,7 +231,24 @@ func TestValidateEvent(t *testing.T) {
}
}
func BenchmarkValidate(b *testing.B) {
func BenchmarkValidate_SliceLookup(b *testing.B) {
// Simulate multi-site with 50 domains
domains := make([]string, 50)
for i := range 50 {
domains[i] = strings.Repeat("site", i) + ".com"
}
event := Event{
Domain: domains[49], // Worst case - last in list
Path: "/test",
}
for b.Loop() {
_ = event.Validate(domains)
}
}
func BenchmarkValidate_MapLookup(b *testing.B) {
// Simulate multi-site with 50 domains
domainMap := make(map[string]bool, 50)
for i := range 50 {
@ -239,11 +256,11 @@ func BenchmarkValidate(b *testing.B) {
}
event := Event{
Domain: strings.Repeat("site", 49) + ".com",
Domain: strings.Repeat("site", 49) + ".com", // any position
Path: "/test",
}
for b.Loop() {
_ = event.Validate(domainMap)
_ = event.ValidateWithMap(domainMap)
}
}

View file

@ -1,14 +1,10 @@
package api
import (
"crypto/rand"
"encoding/hex"
"fmt"
mrand "math/rand"
"math/rand"
"net"
"net/http"
"strings"
"time"
"notashelf.dev/watchdog/internal/aggregate"
"notashelf.dev/watchdog/internal/config"
@ -18,16 +14,14 @@ import (
// Handles incoming analytics events
type IngestionHandler struct {
cfg *config.Config
domainMap map[string]bool // O(1) domain validation
corsOriginMap map[string]bool // O(1) CORS origin validation
pathNorm *normalize.PathNormalizer
pathRegistry *aggregate.PathRegistry
refRegistry *normalize.ReferrerRegistry
metricsAgg *aggregate.MetricsAggregator
rateLimiter *ratelimit.TokenBucket
rng *mrand.Rand
trustedNetworks []*net.IPNet // pre-parsed CIDR networks
cfg *config.Config
domainMap map[string]bool // O(1) domain validation
pathNorm *normalize.PathNormalizer
pathRegistry *aggregate.PathRegistry
refRegistry *normalize.ReferrerRegistry
metricsAgg *aggregate.MetricsAggregator
rateLimiter *ratelimit.TokenBucket
rng *rand.Rand
}
// Creates a new ingestion handler
@ -53,51 +47,19 @@ func NewIngestionHandler(
domainMap[domain] = true
}
// Build CORS origin map for O(1) lookup
corsOriginMap := make(map[string]bool, len(cfg.Security.CORS.AllowedOrigins))
for _, origin := range cfg.Security.CORS.AllowedOrigins {
corsOriginMap[origin] = true
}
// Pre-parse trusted proxy CIDRs to avoid re-parsing on each request
trustedNetworks := make([]*net.IPNet, 0, len(cfg.Security.TrustedProxies))
for _, cidr := range cfg.Security.TrustedProxies {
if _, network, err := net.ParseCIDR(cidr); err == nil {
trustedNetworks = append(trustedNetworks, network)
} else if ip := net.ParseIP(cidr); ip != nil {
// Single IP - create a /32 or /128 network
var mask net.IPMask
if ip.To4() != nil {
mask = net.CIDRMask(32, 32)
} else {
mask = net.CIDRMask(128, 128)
}
trustedNetworks = append(trustedNetworks, &net.IPNet{IP: ip, Mask: mask})
}
}
return &IngestionHandler{
cfg: cfg,
domainMap: domainMap,
corsOriginMap: corsOriginMap,
pathNorm: pathNorm,
pathRegistry: pathRegistry,
refRegistry: refRegistry,
metricsAgg: metricsAgg,
rateLimiter: limiter,
rng: mrand.New(mrand.NewSource(time.Now().UnixNano())),
trustedNetworks: trustedNetworks,
cfg: cfg,
domainMap: domainMap,
pathNorm: pathNorm,
pathRegistry: pathRegistry,
refRegistry: refRegistry,
metricsAgg: metricsAgg,
rateLimiter: limiter,
rng: rand.New(rand.NewSource(42)),
}
}
func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Generate or extract request ID for tracing
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
requestID = generateRequestID()
}
w.Header().Set("X-Request-ID", requestID)
// Handle CORS preflight
if r.Method == http.MethodOptions {
h.handleCORS(w, r)
@ -141,8 +103,8 @@ func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Validate event via map lookup
if err := event.Validate(h.domainMap); err != nil {
// Validate event via map lookup (also O(1))
if err := event.ValidateWithMap(h.domainMap); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
@ -219,8 +181,13 @@ func (h *IngestionHandler) handleCORS(w http.ResponseWriter, r *http.Request) {
}
// Check if origin is allowed
// This uses map so that it's O(1)
allowed := h.corsOriginMap["*"] || h.corsOriginMap[origin]
allowed := false
for _, allowedOrigin := range h.cfg.Security.CORS.AllowedOrigins {
if allowedOrigin == "*" || allowedOrigin == origin {
allowed = true
break
}
}
if allowed {
if origin == "*" {
@ -246,8 +213,13 @@ func (h *IngestionHandler) extractIP(r *http.Request) string {
// Check if we should trust proxy headers
trustProxy := false
if len(h.trustedNetworks) > 0 {
trustProxy = h.ipInNetworks(remoteIP, h.trustedNetworks)
if len(h.cfg.Security.TrustedProxies) > 0 {
for _, trustedCIDR := range h.cfg.Security.TrustedProxies {
if h.ipInCIDR(remoteIP, trustedCIDR) {
trustProxy = true
break
}
}
}
// If not trusting proxy, return direct IP
@ -261,43 +233,42 @@ func (h *IngestionHandler) extractIP(r *http.Request) string {
ips := strings.Split(xff, ",")
for i := len(ips) - 1; i >= 0; i-- {
ip := strings.TrimSpace(ips[i])
if testIP := net.ParseIP(ip); testIP != nil {
// Only accept this IP if it's NOT from a trusted proxy
if !h.ipInNetworks(ip, h.trustedNetworks) {
return ip
}
if !h.ipInCIDR(ip, "0.0.0.0/0") {
continue
}
return ip
}
}
// Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
// Validate the IP format and ensure it's not from a trusted proxy
if testIP := net.ParseIP(xri); testIP != nil {
if !h.ipInNetworks(xri, h.trustedNetworks) {
return xri
}
}
return xri
}
// Fall back to RemoteAddr
return remoteIP
}
// Checks if an IP address is within any of the trusted networks
func (h *IngestionHandler) ipInNetworks(ip string, networks []*net.IPNet) bool {
// Checks if an IP address is within a CIDR range
func (h *IngestionHandler) ipInCIDR(ip, cidr string) bool {
// Parse the IP address
testIP := net.ParseIP(ip)
if testIP == nil {
return false
}
for _, network := range networks {
if network.Contains(testIP) {
return true
// Parse the CIDR
_, network, err := net.ParseCIDR(cidr)
if err != nil {
// If it's not a CIDR, try as a single IP
cidrIP := net.ParseIP(cidr)
if cidrIP == nil {
return false
}
return testIP.Equal(cidrIP)
}
return false
return network.Contains(testIP)
}
// Classifies device using both screen width and User-Agent parsing
@ -340,17 +311,3 @@ func (h *IngestionHandler) classifyDevice(width int, userAgent string) string {
return "unknown"
}
// Creates a unique request ID for tracing.
// Uses 8 bytes (64 bits) of randomness which produces 16 hex characters.
// 2^64 possible IDs (~18 quintillion) provides sufficient uniqueness for
// request tracing while keeping IDs reasonably short in logs and headers.
func generateRequestID() string {
// 8 bytes = 64 bits = 16 hex chars = 2^64 possible IDs
b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
// Fallback to timestamp if crypto/rand fails
return fmt.Sprintf("%d", time.Now().UnixNano())
}
return hex.EncodeToString(b)
}

View file

@ -218,107 +218,6 @@ func newTestHandler(cfg *config.Config) *IngestionHandler {
return NewIngestionHandler(cfg, pathNorm, pathRegistry, refRegistry, metricsAgg)
}
func TestExtractIP(t *testing.T) {
cfg := &config.Config{
Site: config.SiteConfig{
Domains: []string{"example.com"},
},
Limits: config.LimitsConfig{
MaxPaths: 100,
MaxSources: 50,
},
Security: config.SecurityConfig{
TrustedProxies: []string{"10.0.0.0/8", "192.168.1.1"},
},
}
h := newTestHandler(cfg)
tests := []struct {
name string
remoteAddr string
headers map[string]string
want string
}{
{
name: "direct connection no proxy",
remoteAddr: "192.168.1.100:12345",
headers: map[string]string{},
want: "192.168.1.100",
},
{
name: "trusted proxy with X-Forwarded-For",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{
"X-Forwarded-For": "203.0.113.1, 10.0.0.5",
},
want: "203.0.113.1",
},
{
name: "trusted proxy with X-Real-IP",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{
"X-Real-IP": "203.0.113.2",
},
want: "203.0.113.2",
},
{
name: "X-Real-IP from trusted network should be ignored",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{
"X-Real-IP": "10.0.0.50", // trusted network, should fall back
},
want: "10.0.0.1", // falls back to remoteAddr
},
{
name: "X-Real-IP invalid IP should be ignored",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{
"X-Real-IP": "not-an-ip",
},
want: "10.0.0.1", // falls back
},
{
name: "untrusted proxy X-Forwarded-For ignored",
remoteAddr: "203.0.113.50:12345",
headers: map[string]string{
"X-Forwarded-For": "1.2.3.4",
},
want: "203.0.113.50", // uses remoteAddr, ignores header
},
{
name: "untrusted proxy X-Real-IP ignored",
remoteAddr: "203.0.113.50:12345",
headers: map[string]string{
"X-Real-IP": "1.2.3.4",
},
want: "203.0.113.50", // uses remoteAddr, ignores header
},
{
name: "X-Forwarded-For all trusted falls back",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{
"X-Forwarded-For": "10.0.0.2, 10.0.0.3",
},
want: "10.0.0.1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/api/event", nil)
req.RemoteAddr = tt.remoteAddr
for k, v := range tt.headers {
req.Header.Set(k, v)
}
got := h.extractIP(req)
if got != tt.want {
t.Errorf("extractIP() = %q, want %q", got, tt.want)
}
})
}
}
func TestClassifyDevice_UA(t *testing.T) {
cfg := &config.Config{
Limits: config.LimitsConfig{

View file

@ -45,12 +45,11 @@ type PathConfig struct {
// Cardinality limits
type LimitsConfig struct {
MaxPaths int `yaml:"max_paths"`
MaxEventsPerMinute int `yaml:"max_events_per_minute"`
MaxSources int `yaml:"max_sources"`
MaxCustomEvents int `yaml:"max_custom_events"`
DeviceBreakpoints DeviceBreaks `yaml:"device_breakpoints"`
MaxMetricsPerMinute int `yaml:"max_metrics_per_minute"` // rate limit for metrics endpoint
MaxPaths int `yaml:"max_paths"`
MaxEventsPerMinute int `yaml:"max_events_per_minute"`
MaxSources int `yaml:"max_sources"`
MaxCustomEvents int `yaml:"max_custom_events"`
DeviceBreakpoints DeviceBreaks `yaml:"device_breakpoints"`
}
// Device classification breakpoints
@ -73,11 +72,10 @@ type CORSConfig struct {
}
// Authentication for metrics endpoint
// Password can be set via environment variable: WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD
type AuthConfig struct {
Enabled bool `yaml:"enabled"`
Username string `yaml:"username"`
Password string `yaml:"password"` // can use env var WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD
Password string `yaml:"password"`
}
// Server endpoints
@ -151,10 +149,6 @@ func (c *Config) Validate() error {
c.Limits.MaxCustomEvents = 100 // Default
}
if c.Limits.MaxMetricsPerMinute <= 0 {
c.Limits.MaxMetricsPerMinute = 30 // Default: 30 requests per minute
}
if c.Site.Path.MaxSegments < 0 {
return fmt.Errorf("site.path.max_segments cannot be negative")
}

View file

@ -48,36 +48,38 @@ func (n *PathNormalizer) Normalize(path string) string {
path = "/" + path
}
// Process segments in-place to minimize allocations
// Split into segments, first element is *always* empty for paths starting with '/'
segments := strings.Split(path, "/")
if len(segments) > 0 && segments[0] == "" {
segments = segments[1:]
}
// Process segments in a single pass: remove empty, resolve . and ..
writeIdx := 0
for i := 0; i < len(segments); i++ {
seg := segments[i]
// Skip empty segments (from double slashes or leading /)
if seg == "" {
continue
// Remove empty segments (from double slashes)
filtered := make([]string, 0, len(segments))
for _, seg := range segments {
if seg != "" {
filtered = append(filtered, seg)
}
}
segments = filtered
// Resolve . and .. segments
resolved := make([]string, 0, len(segments))
for _, seg := range segments {
if seg == "." {
// Skip current directory
continue
} else if seg == ".." {
// Go up one level if possible
if writeIdx > 0 {
writeIdx--
if len(resolved) > 0 {
resolved = resolved[:len(resolved)-1]
}
// If already at root, skip ..
} else {
// Keep this segment
segments[writeIdx] = seg
writeIdx++
resolved = append(resolved, seg)
}
}
segments = segments[:writeIdx]
segments = resolved
// Collapse numeric segments
if n.cfg.CollapseNumericSegments {

View file

@ -1,7 +1,6 @@
package normalize
import (
"net"
"net/url"
"strings"
@ -22,20 +21,34 @@ func isInternalHost(hostname string) bool {
return true
}
// Check if hostname is an IP address
if ip := net.ParseIP(hostname); ip != nil {
// Private IPv4 ranges (RFC1918)
if ip.IsPrivate() {
return true
}
// Additional localhost checks for IP formats
if ip.IsLoopback() {
return true
}
// Link-local addresses
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true
}
// Private IPv4 ranges (RFC1918)
if strings.HasPrefix(hostname, "10.") ||
strings.HasPrefix(hostname, "192.168.") ||
strings.HasPrefix(hostname, "172.16.") ||
strings.HasPrefix(hostname, "172.17.") ||
strings.HasPrefix(hostname, "172.18.") ||
strings.HasPrefix(hostname, "172.19.") ||
strings.HasPrefix(hostname, "172.20.") ||
strings.HasPrefix(hostname, "172.21.") ||
strings.HasPrefix(hostname, "172.22.") ||
strings.HasPrefix(hostname, "172.23.") ||
strings.HasPrefix(hostname, "172.24.") ||
strings.HasPrefix(hostname, "172.25.") ||
strings.HasPrefix(hostname, "172.26.") ||
strings.HasPrefix(hostname, "172.27.") ||
strings.HasPrefix(hostname, "172.28.") ||
strings.HasPrefix(hostname, "172.29.") ||
strings.HasPrefix(hostname, "172.30.") ||
strings.HasPrefix(hostname, "172.31.") {
return true
}
// IPv6 loopback and local
if strings.HasPrefix(hostname, "::1") ||
strings.HasPrefix(hostname, "fe80::") ||
strings.HasPrefix(hostname, "fc00::") ||
strings.HasPrefix(hostname, "fd00::") {
return true
}
return false

View file

@ -56,20 +56,3 @@ func (r *ReferrerRegistry) OverflowCount() int {
defer r.mu.RUnlock()
return r.overflowCount
}
// Contains reports whether the given source has been recorded in the registry.
func (r *ReferrerRegistry) Contains(source string) bool {
	r.mu.RLock()
	defer r.mu.RUnlock()
	_, ok := r.sources[source]
	return ok
}
// Count reports how many unique sources are currently stored in the registry.
func (r *ReferrerRegistry) Count() int {
	r.mu.RLock()
	n := len(r.sources)
	r.mu.RUnlock()
	return n
}

View file

@ -108,36 +108,6 @@ func TestExtractReferrerDomain(t *testing.T) {
siteDomain: "example.com",
want: "internal",
},
{
name: "private IP 172.16.x (RFC1918)",
referrer: "http://172.16.0.1/page",
siteDomain: "example.com",
want: "internal",
},
{
name: "private IP 172.31.x (RFC1918 upper bound)",
referrer: "http://172.31.255.1/page",
siteDomain: "example.com",
want: "internal",
},
{
name: "private IP 172.20.x (middle of range)",
referrer: "http://172.20.50.100/page",
siteDomain: "example.com",
want: "internal",
},
{
name: "public IP 172.15.x (just outside private range)",
referrer: "http://172.15.0.1/page",
siteDomain: "example.com",
want: "other", // not internal, but invalid TLD
},
{
name: "public IP 172.32.x (just outside private range)",
referrer: "http://172.32.0.1/page",
siteDomain: "example.com",
want: "other", // not internal, but invalid TLD
},
}
for _, tt := range tests {

View file

@ -195,9 +195,6 @@ func TestEndToEnd_GracefulShutdown(t *testing.T) {
MaxSources: 50,
MaxCustomEvents: 10,
},
Server: config.ServerConfig{
StatePath: "/tmp/watchdog-test.state",
},
}
pathRegistry := aggregate.NewPathRegistry(cfg.Limits.MaxPaths)
@ -226,12 +223,12 @@ func TestEndToEnd_GracefulShutdown(t *testing.T) {
}
// Verify state file was created
if _, err := os.Stat("/tmp/watchdog-test.state"); os.IsNotExist(err) {
if _, err := os.Stat("/tmp/watchdog-hll.state"); os.IsNotExist(err) {
t.Error("HLL state file was not created")
}
// Cleanup
os.Remove("/tmp/watchdog-test.state")
os.Remove("/tmp/watchdog-hll.state")
}
func TestEndToEnd_InvalidRequests(t *testing.T) {