mirror of
https://github.com/NotAShelf/watchdog.git
synced 2026-04-27 11:55:36 +00:00
Compare commits
21 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
3662adc61a |
|||
|
5fc6ef592f |
|||
|
c925cca321 |
|||
|
6fed378bb6 |
|||
|
d1d450ec11 |
|||
|
81e7168a61 |
|||
|
2cad815cd8 |
|||
|
9c8f91ef27 |
|||
|
ac24734e8f |
|||
|
fd3a832f7b |
|||
|
98611ca452 |
|||
|
42e1fe83c7 |
|||
|
ffa2af62be |
|||
|
d1181d38f0 |
|||
|
4189d14d65 |
|||
|
02c4f11619 |
|||
|
d2f28ded61 |
|||
|
0f38a062e9 |
|||
|
ad50debb62 |
|||
|
4cd0b3b0cf |
|||
|
83a7aac1c9 |
20 changed files with 712 additions and 227 deletions
13
.github/workflows/go.yml
vendored
13
.github/workflows/go.yml
vendored
|
|
@ -3,8 +3,21 @@ name: Build & Test with Go
|
|||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- '**.go'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/go.yml'
|
||||
- 'testdata/**'
|
||||
|
||||
pull_request:
|
||||
branches: [main]
|
||||
paths:
|
||||
- '**.go'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/go.yml'
|
||||
- 'testdata/**'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
|
|
|
|||
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
|
|
@ -39,7 +39,7 @@ jobs:
|
|||
GOARCH: ${{ matrix.goarch }}
|
||||
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: watchdog-${{ matrix.goos }}-${{ matrix.goarch }}
|
||||
path: watchdog-*
|
||||
|
|
@ -50,17 +50,17 @@ jobs:
|
|||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Download all artifacts
|
||||
uses: actions/download-artifact@v8
|
||||
with:
|
||||
path: artifacts
|
||||
pattern: watchdog-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Create GitHub Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
files: artifacts/**/*.tar.gz artifacts/**/*.zip
|
||||
files: artifacts/watchdog-*
|
||||
generate_release_notes: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
|
|
|||
|
|
@ -25,18 +25,23 @@ type versionInfo struct {
|
|||
BuildDate string `json:"buildDate"`
|
||||
}
|
||||
|
||||
func getVersionInfo() versionInfo {
|
||||
data, err := os.ReadFile("version.json")
|
||||
if err != nil {
|
||||
return versionInfo{}
|
||||
}
|
||||
var v versionInfo
|
||||
if err := json.Unmarshal(data, &v); err != nil {
|
||||
return versionInfo{}
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func getVersion() string {
|
||||
if version != "" {
|
||||
return version
|
||||
}
|
||||
data, err := os.ReadFile("version.json")
|
||||
if err != nil {
|
||||
return "dev"
|
||||
}
|
||||
var v versionInfo
|
||||
if err := json.Unmarshal(data, &v); err != nil {
|
||||
return "dev"
|
||||
}
|
||||
v := getVersionInfo()
|
||||
if v.Version != "" {
|
||||
return v.Version
|
||||
}
|
||||
|
|
@ -47,14 +52,7 @@ func getCommit() string {
|
|||
if commit != "" {
|
||||
return commit
|
||||
}
|
||||
data, err := os.ReadFile("version.json")
|
||||
if err != nil {
|
||||
return "none"
|
||||
}
|
||||
var v versionInfo
|
||||
if err := json.Unmarshal(data, &v); err != nil {
|
||||
return "none"
|
||||
}
|
||||
v := getVersionInfo()
|
||||
if v.Commit != "" {
|
||||
return v.Commit
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,8 +9,10 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
|
|
@ -20,6 +22,7 @@ import (
|
|||
"notashelf.dev/watchdog/internal/health"
|
||||
"notashelf.dev/watchdog/internal/limits"
|
||||
"notashelf.dev/watchdog/internal/normalize"
|
||||
"notashelf.dev/watchdog/internal/ratelimit"
|
||||
)
|
||||
|
||||
func Run(cfg *config.Config) error {
|
||||
|
|
@ -74,7 +77,7 @@ func Run(cfg *config.Config) error {
|
|||
// Setup routes
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Metrics endpoint with optional basic auth
|
||||
// Metrics endpoint with optional basic auth and rate limiting
|
||||
metricsHandler := promhttp.HandlerFor(promRegistry, promhttp.HandlerOpts{
|
||||
EnableOpenMetrics: true,
|
||||
})
|
||||
|
|
@ -87,6 +90,20 @@ func Run(cfg *config.Config) error {
|
|||
)
|
||||
}
|
||||
|
||||
// Add rate limiting to metrics endpoint
|
||||
if cfg.Limits.MaxMetricsPerMinute > 0 {
|
||||
metricsRateLimiter := ratelimit.NewTokenBucket(
|
||||
cfg.Limits.MaxMetricsPerMinute,
|
||||
cfg.Limits.MaxMetricsPerMinute,
|
||||
time.Minute,
|
||||
)
|
||||
metricsHandler = rateLimitMiddleware(metricsHandler, metricsRateLimiter)
|
||||
}
|
||||
|
||||
// Add response size limit to metrics endpoint (10MB max)
|
||||
const maxMetricsResponseSize = 10 * 1024 * 1024 // 10MB
|
||||
metricsHandler = responseSizeLimitMiddleware(metricsHandler, maxMetricsResponseSize)
|
||||
|
||||
mux.Handle(cfg.Server.MetricsPath, metricsHandler)
|
||||
|
||||
// Ingestion endpoint
|
||||
|
|
@ -167,6 +184,69 @@ func basicAuth(next http.Handler, username, password string) http.Handler {
|
|||
})
|
||||
}
|
||||
|
||||
// Wraps a handler with rate limiting
|
||||
func rateLimitMiddleware(next http.Handler, limiter *ratelimit.TokenBucket) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !limiter.Allow() {
|
||||
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Wraps http.ResponseWriter to enforce max response size
|
||||
type limitedResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
maxSize int
|
||||
written int
|
||||
limitExceeded bool
|
||||
}
|
||||
|
||||
func (w *limitedResponseWriter) Write(p []byte) (int, error) {
|
||||
if w.limitExceeded {
|
||||
return 0, fmt.Errorf("response size limit exceeded")
|
||||
}
|
||||
|
||||
if w.written+len(p) > w.maxSize {
|
||||
w.limitExceeded = true
|
||||
w.Header().Set("X-Response-Truncated", "true")
|
||||
http.Error(w.ResponseWriter, "Response size limit exceeded", http.StatusInternalServerError)
|
||||
return 0, fmt.Errorf("response size limit exceeded: %d bytes", w.maxSize)
|
||||
}
|
||||
n, err := w.ResponseWriter.Write(p)
|
||||
w.written += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Wraps a handler with response size limiting
|
||||
func responseSizeLimitMiddleware(next http.Handler, maxSize int) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
limited := &limitedResponseWriter{
|
||||
ResponseWriter: w,
|
||||
maxSize: maxSize,
|
||||
}
|
||||
next.ServeHTTP(limited, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Sanitizes a path for logging to prevent log injection attacks. Uses `strconv.Quote`
|
||||
// to properly escape control characters and special bytes.
|
||||
func sanitizePathForLog(path string) string {
|
||||
escaped := strconv.Quote(path)
|
||||
if len(escaped) >= 2 && escaped[0] == '"' && escaped[len(escaped)-1] == '"' {
|
||||
escaped = escaped[1 : len(escaped)-1]
|
||||
}
|
||||
|
||||
// Limit path length to prevent log flooding
|
||||
const maxLen = 200
|
||||
if len(escaped) > maxLen {
|
||||
return escaped[:maxLen] + "..."
|
||||
}
|
||||
|
||||
return escaped
|
||||
}
|
||||
|
||||
// Creates a file server that only serves whitelisted files. Blocks dotfiles, .git, .env, etc.
|
||||
// TODO: I need to hook this up to eris somehow so I can just forward the paths that are being
|
||||
// scanned despite not being on a whitelist. Would be a good way of detecting scrapers, maybe.
|
||||
|
|
@ -179,7 +259,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha
|
|||
// Block directory listings
|
||||
if strings.HasSuffix(path, "/") {
|
||||
blockedRequests.WithLabelValues("directory_listing").Inc()
|
||||
log.Printf("Blocked directory listing attempt: %s from %s", path, r.RemoteAddr)
|
||||
log.Printf("Blocked directory listing attempt: %s from %s", sanitizePathForLog(path), r.RemoteAddr)
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
|
@ -188,7 +268,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha
|
|||
for segment := range strings.SplitSeq(path, "/") {
|
||||
if strings.HasPrefix(segment, ".") {
|
||||
blockedRequests.WithLabelValues("dotfile").Inc()
|
||||
log.Printf("Blocked dotfile access: %s from %s", path, r.RemoteAddr)
|
||||
log.Printf("Blocked dotfile access: %s from %s", sanitizePathForLog(path), r.RemoteAddr)
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
|
@ -199,7 +279,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha
|
|||
strings.HasSuffix(lower, ".bak") ||
|
||||
strings.HasSuffix(lower, "~") {
|
||||
blockedRequests.WithLabelValues("sensitive_file").Inc()
|
||||
log.Printf("Blocked sensitive file access: %s from %s", path, r.RemoteAddr)
|
||||
log.Printf("Blocked sensitive file access: %s from %s", sanitizePathForLog(path), r.RemoteAddr)
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
|
@ -209,7 +289,7 @@ func safeFileServer(root string, blockedRequests *prometheus.CounterVec) http.Ha
|
|||
ext := strings.ToLower(filepath.Ext(path))
|
||||
if ext != ".js" && ext != ".html" && ext != ".css" {
|
||||
blockedRequests.WithLabelValues("invalid_extension").Inc()
|
||||
log.Printf("Blocked invalid extension: %s from %s", path, r.RemoteAddr)
|
||||
log.Printf("Blocked invalid extension: %s from %s", sanitizePathForLog(path), r.RemoteAddr)
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
|
|
|||
116
cmd/watchdog/root_test.go
Normal file
116
cmd/watchdog/root_test.go
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
package watchdog
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSanitizePathForLog(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
maxLen int // expected max length check
|
||||
}{
|
||||
{
|
||||
name: "normal path",
|
||||
input: "/web/beacon.js",
|
||||
want: "/web/beacon.js",
|
||||
maxLen: 200,
|
||||
},
|
||||
{
|
||||
name: "path with newlines",
|
||||
input: "/test\nmalicious",
|
||||
want: `/test\nmalicious`,
|
||||
maxLen: 200,
|
||||
},
|
||||
{
|
||||
name: "path with carriage return",
|
||||
input: "/test\rmalicious",
|
||||
want: `/test\rmalicious`,
|
||||
maxLen: 200,
|
||||
},
|
||||
{
|
||||
name: "path with tabs",
|
||||
input: "/test\tmalicious",
|
||||
want: `/test\tmalicious`,
|
||||
maxLen: 200,
|
||||
},
|
||||
{
|
||||
name: "path with null bytes",
|
||||
input: "/test\x00malicious",
|
||||
want: `/test\x00malicious`,
|
||||
maxLen: 200,
|
||||
},
|
||||
{
|
||||
name: "path with quotes",
|
||||
input: `/test"malicious`,
|
||||
want: `/test\"malicious`,
|
||||
maxLen: 200,
|
||||
},
|
||||
{
|
||||
name: "path with backslash",
|
||||
input: `/test\malicious`,
|
||||
want: `/test\\malicious`,
|
||||
maxLen: 200,
|
||||
},
|
||||
{
|
||||
name: "control characters",
|
||||
input: "/test\x01\x02\x1fmalicious",
|
||||
want: `/test\x01\x02\x1fmalicious`,
|
||||
maxLen: 200,
|
||||
},
|
||||
{
|
||||
name: "truncation at 200 chars",
|
||||
input: "/" + strings.Repeat("a", 250),
|
||||
want: "/" + strings.Repeat("a", 199) + "...",
|
||||
maxLen: 203, // 200 chars + "..." = 203
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
want: "",
|
||||
maxLen: 200,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := sanitizePathForLog(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("sanitizePathForLog(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
if len(got) > tt.maxLen {
|
||||
t.Errorf("sanitizePathForLog(%q) length = %d, exceeds max %d", tt.input, len(got), tt.maxLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizePathForLog_LogInjectionPrevention(t *testing.T) {
|
||||
// Log injection attempts should be neutralized
|
||||
maliciousPaths := []string{
|
||||
"/api\nINFO: fake log entry",
|
||||
"/test\r\nERROR: fake error",
|
||||
"/.git/config\x00", // null byte injection
|
||||
}
|
||||
|
||||
for _, path := range maliciousPaths {
|
||||
sanitized := sanitizePathForLog(path)
|
||||
// Check that newlines are escaped, not literal
|
||||
if strings.Contains(sanitized, "\n") || strings.Contains(sanitized, "\r") {
|
||||
t.Errorf("sanitizePathForLog(%q) contains literal newlines: %q", path, sanitized)
|
||||
}
|
||||
// Check that null bytes are escaped
|
||||
if strings.Contains(sanitized, "\x00") {
|
||||
t.Errorf("sanitizePathForLog(%q) contains literal null byte: %q", path, sanitized)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSanitizePathForLog(b *testing.B) {
|
||||
path := "/test/path/with\nnewlines\rand\ttabs"
|
||||
for b.Loop() {
|
||||
_ = sanitizePathForLog(path)
|
||||
}
|
||||
}
|
||||
|
|
@ -63,6 +63,8 @@ limits:
|
|||
max_custom_events: 100
|
||||
# Maximum events per minute (rate limiting, 0 = unlimited)
|
||||
max_events_per_minute: 10000
|
||||
# Maximum metrics endpoint requests per minute (rate limiting, 0 = unlimited, default: 30)
|
||||
max_metrics_per_minute: 60
|
||||
|
||||
# Device classification breakpoints (screen width in pixels)
|
||||
device_breakpoints:
|
||||
|
|
@ -85,6 +87,8 @@ security:
|
|||
- "*" # Or specific domains: ["https://example.com", "https://www.example.com"]
|
||||
|
||||
# Basic authentication for /metrics endpoint
|
||||
# Password can also be set via environment variable:
|
||||
# - `$ export WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD=your-secret-password`
|
||||
metrics_auth:
|
||||
enabled: false
|
||||
username: "admin"
|
||||
|
|
|
|||
|
|
@ -6,12 +6,17 @@ Grafana.
|
|||
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> **Why you need Prometheus:**
|
||||
> **Why you need a time-series database:**
|
||||
>
|
||||
> - Watchdog exposes _current state_ (counters, gauges)
|
||||
> - Prometheus _scrapes periodically_ and _stores time-series data_
|
||||
> - Grafana _visualizes_ the historical data from Prometheus
|
||||
> - A TSDB _scrapes periodically_ and _stores time-series data_
|
||||
> - Grafana _visualizes_ the historical data
|
||||
> - Grafana cannot directly scrape Prometheus `/metrics` endpoints
|
||||
>
|
||||
> **Compatible databases:**
|
||||
>
|
||||
> - [Prometheus](#prometheus-setup),
|
||||
> - [VictoriaMetrics](#victoriametrics), or any Prometheus-compatible scraper
|
||||
|
||||
## Prometheus Setup
|
||||
|
||||
|
|
@ -124,7 +129,7 @@ For multiple Watchdog instances:
|
|||
datasources.settings.datasources = [{
|
||||
name = "Prometheus";
|
||||
type = "prometheus";
|
||||
url = "http://localhost:9090";
|
||||
url = "http://localhost:9090"; # Or "http://localhost:8428" for VictoriaMetrics
|
||||
isDefault = true;
|
||||
}];
|
||||
};
|
||||
|
|
@ -228,7 +233,34 @@ sum by (instance) (rate(web_pageviews_total[5m]))
|
|||
|
||||
### VictoriaMetrics
|
||||
|
||||
Drop-in Prometheus replacement with better performance and compression:
|
||||
VictoriaMetrics is a fast, cost-effective monitoring solution and time-series
|
||||
database that is 100% compatible with Prometheus exposition format. Watchdog's
|
||||
`/metrics` endpoint can be scraped directly by VictoriaMetrics without requiring
|
||||
Prometheus.
|
||||
|
||||
#### Direct Scraping (Recommended)
|
||||
|
||||
VictoriaMetrics single-node mode can scrape Watchdog directly using standard
|
||||
Prometheus scrape configuration:
|
||||
|
||||
**Configuration file (`/etc/victoriametrics/scrape.yml`):**
|
||||
|
||||
```yaml
|
||||
scrape_configs:
|
||||
- job_name: "watchdog"
|
||||
static_configs:
|
||||
- targets: ["localhost:8080"]
|
||||
scrape_interval: 15s
|
||||
metrics_path: /metrics
|
||||
```
|
||||
|
||||
**Run VictoriaMetrics:**
|
||||
|
||||
```bash
|
||||
victoria-metrics -promscrape.config=/etc/victoriametrics/scrape.yml
|
||||
```
|
||||
|
||||
**NixOS configuration:**
|
||||
|
||||
```nix
|
||||
{
|
||||
|
|
@ -236,15 +268,84 @@ Drop-in Prometheus replacement with better performance and compression:
|
|||
enable = true;
|
||||
listenAddress = ":8428";
|
||||
retentionPeriod = "12month";
|
||||
|
||||
# Define scrape configs directly. 'prometheusConfig' is the configuration for
|
||||
# Prometheus-style metrics endpoints, which Watchdog exports.
|
||||
prometheusConfig = {
|
||||
scrape_configs = [
|
||||
{
|
||||
job_name = "watchdog";
|
||||
scrape_interval = "15s";
|
||||
static_configs = [{
|
||||
targets = [ "localhost:8080" ]; # replace the port
|
||||
}];
|
||||
}
|
||||
];
|
||||
};
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
#### Using `vmagent`
|
||||
|
||||
Alternatively, for distributed setups or when you need more advanced features
|
||||
like relabeling, you may use `vmagent`:
|
||||
|
||||
```nix
|
||||
{
|
||||
services.vmagent = {
|
||||
enable = true;
|
||||
remoteWriteUrl = "http://localhost:8428/api/v1/write";
|
||||
|
||||
prometheusConfig = {
|
||||
scrape_configs = [
|
||||
{
|
||||
job_name = "watchdog";
|
||||
static_configs = [{
|
||||
targets = [ "localhost:8080" ];
|
||||
}];
|
||||
}
|
||||
];
|
||||
};
|
||||
};
|
||||
|
||||
# Configure Prometheus to remote-write to VictoriaMetrics
|
||||
services.victoriametrics = {
|
||||
enable = true;
|
||||
listenAddress = ":8428";
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
#### Prometheus Remote Write
|
||||
|
||||
If you are migrating from Prometheus, or if you need PromQL compatibility, or if
|
||||
you just really like using Prometheus for some inexplicable reason you may keep
|
||||
Prometheus but use VictoriaMetrics to remote-write.
|
||||
|
||||
```nix
|
||||
{
|
||||
services.prometheus = {
|
||||
enable = true;
|
||||
port = 9090;
|
||||
|
||||
scrapeConfigs = [
|
||||
{
|
||||
job_name = "watchdog";
|
||||
static_configs = [{
|
||||
targets = [ "localhost:8080" ];
|
||||
}];
|
||||
}
|
||||
];
|
||||
|
||||
remoteWrite = [{
|
||||
url = "http://localhost:8428/api/v1/write";
|
||||
}];
|
||||
};
|
||||
|
||||
services.victoriametrics = {
|
||||
enable = true;
|
||||
listenAddress = ":8428";
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -274,10 +375,10 @@ metrics:
|
|||
|
||||
## Monitoring the Monitoring
|
||||
|
||||
Monitor Prometheus itself:
|
||||
Monitor your scraper:
|
||||
|
||||
```promql
|
||||
# Prometheus scrape success rate
|
||||
# Scrape success rate
|
||||
up{job="watchdog"}
|
||||
|
||||
# Scrape duration
|
||||
|
|
@ -287,14 +388,9 @@ scrape_duration_seconds{job="watchdog"}
|
|||
time() - timestamp(up{job="watchdog"})
|
||||
```
|
||||
|
||||
## Additional Recommendations
|
||||
For VictoriaMetrics, you can also monitor ingestion stats:
|
||||
|
||||
1. **Retention**: Set `--storage.tsdb.retention.time=30d` or longer based on
|
||||
disk space
|
||||
2. **Backups**: Back up `/var/lib/prometheus` periodically (or whatever your
|
||||
state directory is)
|
||||
3. **Alerting**: Configure Prometheus alerting rules for critical metrics
|
||||
4. **High Availability**: Run multiple Prometheus instances with identical
|
||||
configs
|
||||
5. **Remote Storage**: For long-term storage, use Thanos, Cortex, or
|
||||
VictoriaMetrics
|
||||
```bash
|
||||
# VM internal metrics
|
||||
curl http://localhost:8428/metrics | grep vm_rows_inserted_total
|
||||
```
|
||||
|
|
|
|||
|
|
@ -54,3 +54,20 @@ func (r *CustomEventRegistry) OverflowCount() int {
|
|||
defer r.mu.RUnlock()
|
||||
return r.overflowCount
|
||||
}
|
||||
|
||||
// Contains checks if an event name exists in the registry.
|
||||
func (r *CustomEventRegistry) Contains(eventName string) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
_, exists := r.events[eventName]
|
||||
return exists
|
||||
}
|
||||
|
||||
// Count returns the number of unique events in the registry.
|
||||
func (r *CustomEventRegistry) Count() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
return len(r.events)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -162,8 +162,12 @@ func (m *MetricsAggregator) RecordPageview(path, country, device, referrer, doma
|
|||
labels := prometheus.Labels{"path": sanitizeLabel(path)}
|
||||
|
||||
if m.cfg.Site.Collect.Country {
|
||||
if country == "" {
|
||||
labels["country"] = "unknown"
|
||||
} else {
|
||||
labels["country"] = sanitizeLabel(country)
|
||||
}
|
||||
}
|
||||
|
||||
if m.cfg.Site.Collect.Device {
|
||||
labels["device"] = sanitizeLabel(device)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -17,15 +18,18 @@ type UniquesEstimator struct {
|
|||
hll *hyperloglog.Sketch
|
||||
salt string
|
||||
rotation string // "daily" or "hourly"
|
||||
saltKey string // cached time key to avoid regeneration
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Creates a new unique visitor estimator
|
||||
func NewUniquesEstimator(rotation string) *UniquesEstimator {
|
||||
now := time.Now()
|
||||
return &UniquesEstimator{
|
||||
hll: hyperloglog.New(),
|
||||
salt: generateSalt(time.Now(), rotation),
|
||||
salt: generateSalt(now, rotation),
|
||||
rotation: rotation,
|
||||
saltKey: getSaltKey(now, rotation),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -36,11 +40,13 @@ func (u *UniquesEstimator) Add(ip, userAgent string) {
|
|||
defer u.mu.Unlock()
|
||||
|
||||
// Check if we need to rotate to a new period
|
||||
currentSalt := generateSalt(time.Now(), u.rotation)
|
||||
if currentSalt != u.salt {
|
||||
now := time.Now()
|
||||
currentKey := getSaltKey(now, u.rotation)
|
||||
if currentKey != u.saltKey {
|
||||
// Reset HLL for new period
|
||||
u.hll = hyperloglog.New()
|
||||
u.salt = currentSalt
|
||||
u.salt = generateSaltFromKey(currentKey)
|
||||
u.saltKey = currentKey
|
||||
}
|
||||
|
||||
// Hash visitor with salt to prevent cross-period tracking
|
||||
|
|
@ -55,24 +61,36 @@ func (u *UniquesEstimator) Estimate() uint64 {
|
|||
return u.hll.Estimate()
|
||||
}
|
||||
|
||||
// Generates a deterministic salt based on the rotation mode
|
||||
// Daily: same day = same salt, different day = different salt
|
||||
// Hourly: same hour = same salt, different hour = different salt
|
||||
func generateSalt(t time.Time, rotation string) string {
|
||||
var key string
|
||||
// Returns the time-based key for salt generation without hashing
|
||||
func getSaltKey(t time.Time, rotation string) string {
|
||||
if rotation == "hourly" {
|
||||
key = t.UTC().Format("2006-01-02T15")
|
||||
} else {
|
||||
key = t.UTC().Format("2006-01-02")
|
||||
return t.UTC().Format("2006-01-02T15")
|
||||
}
|
||||
return t.UTC().Format("2006-01-02")
|
||||
}
|
||||
|
||||
// Creates a salt from a pre-computed key
|
||||
func generateSaltFromKey(key string) string {
|
||||
h := sha256.Sum256([]byte("watchdog-salt-" + key))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// Generates a deterministic salt based on the rotation mode
|
||||
// Daily: same day = same salt, different day = different salt
|
||||
// Hourly: same hour = same salt, different hour = different salt
|
||||
func generateSalt(t time.Time, rotation string) string {
|
||||
return generateSaltFromKey(getSaltKey(t, rotation))
|
||||
}
|
||||
|
||||
// Creates a privacy-preserving hash of visitor identity
|
||||
func hashVisitor(ip, userAgent, salt string) string {
|
||||
combined := ip + "|" + userAgent + "|" + salt
|
||||
h := sha256.Sum256([]byte(combined))
|
||||
var sb strings.Builder
|
||||
sb.WriteString(ip)
|
||||
sb.WriteString("|")
|
||||
sb.WriteString(userAgent)
|
||||
sb.WriteString("|")
|
||||
sb.WriteString(salt)
|
||||
h := sha256.Sum256([]byte(sb.String()))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
|
|
@ -122,16 +140,20 @@ func (u *UniquesEstimator) Load(path string) error {
|
|||
}
|
||||
|
||||
savedSalt := string(parts[0])
|
||||
currentSalt := generateSalt(time.Now(), u.rotation)
|
||||
now := time.Now()
|
||||
currentKey := getSaltKey(now, u.rotation)
|
||||
currentSalt := generateSaltFromKey(currentKey)
|
||||
|
||||
// Only restore if it's the same period
|
||||
if savedSalt == currentSalt {
|
||||
u.salt = savedSalt
|
||||
u.saltKey = currentKey
|
||||
return u.hll.UnmarshalBinary(parts[1])
|
||||
}
|
||||
|
||||
// Different period, start fresh
|
||||
u.hll = hyperloglog.New()
|
||||
u.salt = currentSalt
|
||||
u.saltKey = currentKey
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
|
||||
"notashelf.dev/watchdog/internal/limits"
|
||||
)
|
||||
|
|
@ -40,40 +39,8 @@ func ParseEvent(body io.Reader) (*Event, error) {
|
|||
return &event, nil
|
||||
}
|
||||
|
||||
// Validate checks if the event is valid for the given domains
|
||||
func (e *Event) Validate(allowedDomains []string) error {
|
||||
if e.Domain == "" {
|
||||
return fmt.Errorf("domain required")
|
||||
}
|
||||
|
||||
// Check if domain is in allowed list
|
||||
allowed := slices.Contains(allowedDomains, e.Domain)
|
||||
if !allowed {
|
||||
return fmt.Errorf("domain not allowed")
|
||||
}
|
||||
|
||||
if e.Path == "" {
|
||||
return fmt.Errorf("path required")
|
||||
}
|
||||
|
||||
if len(e.Path) > limits.MaxPathLen {
|
||||
return fmt.Errorf("path too long")
|
||||
}
|
||||
|
||||
if len(e.Referrer) > limits.MaxRefLen {
|
||||
return fmt.Errorf("referrer too long")
|
||||
}
|
||||
|
||||
// Validate screen width is in reasonable range
|
||||
if e.Width < 0 || e.Width > limits.MaxWidth {
|
||||
return fmt.Errorf("invalid width")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateWithMap checks if the event is valid using a domain map (O(1) lookup)
|
||||
func (e *Event) ValidateWithMap(allowedDomains map[string]bool) error {
|
||||
// Validate checks if the event is valid using a domain map
|
||||
func (e *Event) Validate(allowedDomains map[string]bool) error {
|
||||
if e.Domain == "" {
|
||||
return fmt.Errorf("domain required")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ func TestValidateEvent(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
event Event
|
||||
domains []string
|
||||
domains map[string]bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
|
|
@ -124,7 +124,7 @@ func TestValidateEvent(t *testing.T) {
|
|||
Domain: "example.com",
|
||||
Path: "/home",
|
||||
},
|
||||
domains: []string{"example.com"},
|
||||
domains: map[string]bool{"example.com": true},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
|
|
@ -134,7 +134,7 @@ func TestValidateEvent(t *testing.T) {
|
|||
Path: "/signup",
|
||||
Event: "signup",
|
||||
},
|
||||
domains: []string{"example.com"},
|
||||
domains: map[string]bool{"example.com": true},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
|
|
@ -143,7 +143,7 @@ func TestValidateEvent(t *testing.T) {
|
|||
Domain: "wrong.com",
|
||||
Path: "/home",
|
||||
},
|
||||
domains: []string{"example.com"},
|
||||
domains: map[string]bool{"example.com": true},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
|
|
@ -152,7 +152,7 @@ func TestValidateEvent(t *testing.T) {
|
|||
Domain: "",
|
||||
Path: "/home",
|
||||
},
|
||||
domains: []string{"example.com"},
|
||||
domains: map[string]bool{"example.com": true},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
|
|
@ -161,7 +161,7 @@ func TestValidateEvent(t *testing.T) {
|
|||
Domain: "example.com",
|
||||
Path: "",
|
||||
},
|
||||
domains: []string{"example.com"},
|
||||
domains: map[string]bool{"example.com": true},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
|
|
@ -170,7 +170,7 @@ func TestValidateEvent(t *testing.T) {
|
|||
Domain: "example.com",
|
||||
Path: "/" + strings.Repeat("a", 3000),
|
||||
},
|
||||
domains: []string{"example.com"},
|
||||
domains: map[string]bool{"example.com": true},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
|
|
@ -180,7 +180,7 @@ func TestValidateEvent(t *testing.T) {
|
|||
Path: "/home",
|
||||
Referrer: strings.Repeat("a", 3000),
|
||||
},
|
||||
domains: []string{"example.com"},
|
||||
domains: map[string]bool{"example.com": true},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
|
|
@ -189,7 +189,7 @@ func TestValidateEvent(t *testing.T) {
|
|||
Domain: "example.com",
|
||||
Path: "/" + strings.Repeat("a", 2000),
|
||||
},
|
||||
domains: []string{"example.com"},
|
||||
domains: map[string]bool{"example.com": true},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
|
|
@ -198,7 +198,7 @@ func TestValidateEvent(t *testing.T) {
|
|||
Domain: "site1.com",
|
||||
Path: "/home",
|
||||
},
|
||||
domains: []string{"site1.com", "site2.com"},
|
||||
domains: map[string]bool{"site1.com": true, "site2.com": true},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
|
|
@ -207,7 +207,7 @@ func TestValidateEvent(t *testing.T) {
|
|||
Domain: "site2.com",
|
||||
Path: "/about",
|
||||
},
|
||||
domains: []string{"site1.com", "site2.com"},
|
||||
domains: map[string]bool{"site1.com": true, "site2.com": true},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
|
|
@ -216,7 +216,7 @@ func TestValidateEvent(t *testing.T) {
|
|||
Domain: "site3.com",
|
||||
Path: "/home",
|
||||
},
|
||||
domains: []string{"site1.com", "site2.com"},
|
||||
domains: map[string]bool{"site1.com": true, "site2.com": true},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
|
@ -231,24 +231,7 @@ func TestValidateEvent(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidate_SliceLookup(b *testing.B) {
|
||||
// Simulate multi-site with 50 domains
|
||||
domains := make([]string, 50)
|
||||
for i := range 50 {
|
||||
domains[i] = strings.Repeat("site", i) + ".com"
|
||||
}
|
||||
|
||||
event := Event{
|
||||
Domain: domains[49], // Worst case - last in list
|
||||
Path: "/test",
|
||||
}
|
||||
|
||||
for b.Loop() {
|
||||
_ = event.Validate(domains)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidate_MapLookup(b *testing.B) {
|
||||
func BenchmarkValidate(b *testing.B) {
|
||||
// Simulate multi-site with 50 domains
|
||||
domainMap := make(map[string]bool, 50)
|
||||
for i := range 50 {
|
||||
|
|
@ -256,11 +239,11 @@ func BenchmarkValidate_MapLookup(b *testing.B) {
|
|||
}
|
||||
|
||||
event := Event{
|
||||
Domain: strings.Repeat("site", 49) + ".com", // any position
|
||||
Domain: strings.Repeat("site", 49) + ".com",
|
||||
Path: "/test",
|
||||
}
|
||||
|
||||
for b.Loop() {
|
||||
_ = event.ValidateWithMap(domainMap)
|
||||
_ = event.Validate(domainMap)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"notashelf.dev/watchdog/internal/aggregate"
|
||||
"notashelf.dev/watchdog/internal/config"
|
||||
|
|
@ -16,12 +20,14 @@ import (
|
|||
type IngestionHandler struct {
|
||||
cfg *config.Config
|
||||
domainMap map[string]bool // O(1) domain validation
|
||||
corsOriginMap map[string]bool // O(1) CORS origin validation
|
||||
pathNorm *normalize.PathNormalizer
|
||||
pathRegistry *aggregate.PathRegistry
|
||||
refRegistry *normalize.ReferrerRegistry
|
||||
metricsAgg *aggregate.MetricsAggregator
|
||||
rateLimiter *ratelimit.TokenBucket
|
||||
rng *rand.Rand
|
||||
rng *mrand.Rand
|
||||
trustedNetworks []*net.IPNet // pre-parsed CIDR networks
|
||||
}
|
||||
|
||||
// Creates a new ingestion handler
|
||||
|
|
@ -47,19 +53,51 @@ func NewIngestionHandler(
|
|||
domainMap[domain] = true
|
||||
}
|
||||
|
||||
// Build CORS origin map for O(1) lookup
|
||||
corsOriginMap := make(map[string]bool, len(cfg.Security.CORS.AllowedOrigins))
|
||||
for _, origin := range cfg.Security.CORS.AllowedOrigins {
|
||||
corsOriginMap[origin] = true
|
||||
}
|
||||
|
||||
// Pre-parse trusted proxy CIDRs to avoid re-parsing on each request
|
||||
trustedNetworks := make([]*net.IPNet, 0, len(cfg.Security.TrustedProxies))
|
||||
for _, cidr := range cfg.Security.TrustedProxies {
|
||||
if _, network, err := net.ParseCIDR(cidr); err == nil {
|
||||
trustedNetworks = append(trustedNetworks, network)
|
||||
} else if ip := net.ParseIP(cidr); ip != nil {
|
||||
// Single IP - create a /32 or /128 network
|
||||
var mask net.IPMask
|
||||
if ip.To4() != nil {
|
||||
mask = net.CIDRMask(32, 32)
|
||||
} else {
|
||||
mask = net.CIDRMask(128, 128)
|
||||
}
|
||||
trustedNetworks = append(trustedNetworks, &net.IPNet{IP: ip, Mask: mask})
|
||||
}
|
||||
}
|
||||
|
||||
return &IngestionHandler{
|
||||
cfg: cfg,
|
||||
domainMap: domainMap,
|
||||
corsOriginMap: corsOriginMap,
|
||||
pathNorm: pathNorm,
|
||||
pathRegistry: pathRegistry,
|
||||
refRegistry: refRegistry,
|
||||
metricsAgg: metricsAgg,
|
||||
rateLimiter: limiter,
|
||||
rng: rand.New(rand.NewSource(42)),
|
||||
rng: mrand.New(mrand.NewSource(time.Now().UnixNano())),
|
||||
trustedNetworks: trustedNetworks,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Generate or extract request ID for tracing
|
||||
requestID := r.Header.Get("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = generateRequestID()
|
||||
}
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
|
||||
// Handle CORS preflight
|
||||
if r.Method == http.MethodOptions {
|
||||
h.handleCORS(w, r)
|
||||
|
|
@ -103,8 +141,8 @@ func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// Validate event via map lookup (also O(1))
|
||||
if err := event.ValidateWithMap(h.domainMap); err != nil {
|
||||
// Validate event via map lookup
|
||||
if err := event.Validate(h.domainMap); err != nil {
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
|
@ -181,13 +219,8 @@ func (h *IngestionHandler) handleCORS(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Check if origin is allowed
|
||||
allowed := false
|
||||
for _, allowedOrigin := range h.cfg.Security.CORS.AllowedOrigins {
|
||||
if allowedOrigin == "*" || allowedOrigin == origin {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// This uses map so that it's O(1)
|
||||
allowed := h.corsOriginMap["*"] || h.corsOriginMap[origin]
|
||||
|
||||
if allowed {
|
||||
if origin == "*" {
|
||||
|
|
@ -213,13 +246,8 @@ func (h *IngestionHandler) extractIP(r *http.Request) string {
|
|||
|
||||
// Check if we should trust proxy headers
|
||||
trustProxy := false
|
||||
if len(h.cfg.Security.TrustedProxies) > 0 {
|
||||
for _, trustedCIDR := range h.cfg.Security.TrustedProxies {
|
||||
if h.ipInCIDR(remoteIP, trustedCIDR) {
|
||||
trustProxy = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(h.trustedNetworks) > 0 {
|
||||
trustProxy = h.ipInNetworks(remoteIP, h.trustedNetworks)
|
||||
}
|
||||
|
||||
// If not trusting proxy, return direct IP
|
||||
|
|
@ -233,42 +261,43 @@ func (h *IngestionHandler) extractIP(r *http.Request) string {
|
|||
ips := strings.Split(xff, ",")
|
||||
for i := len(ips) - 1; i >= 0; i-- {
|
||||
ip := strings.TrimSpace(ips[i])
|
||||
if !h.ipInCIDR(ip, "0.0.0.0/0") {
|
||||
continue
|
||||
}
|
||||
if testIP := net.ParseIP(ip); testIP != nil {
|
||||
// Only accept this IP if it's NOT from a trusted proxy
|
||||
if !h.ipInNetworks(ip, h.trustedNetworks) {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
// Validate the IP format and ensure it's not from a trusted proxy
|
||||
if testIP := net.ParseIP(xri); testIP != nil {
|
||||
if !h.ipInNetworks(xri, h.trustedNetworks) {
|
||||
return xri
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
// Checks if an IP address is within a CIDR range
|
||||
func (h *IngestionHandler) ipInCIDR(ip, cidr string) bool {
|
||||
// Parse the IP address
|
||||
// Checks if an IP address is within any of the trusted networks
|
||||
func (h *IngestionHandler) ipInNetworks(ip string, networks []*net.IPNet) bool {
|
||||
testIP := net.ParseIP(ip)
|
||||
if testIP == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse the CIDR
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
// If it's not a CIDR, try as a single IP
|
||||
cidrIP := net.ParseIP(cidr)
|
||||
if cidrIP == nil {
|
||||
return false
|
||||
for _, network := range networks {
|
||||
if network.Contains(testIP) {
|
||||
return true
|
||||
}
|
||||
return testIP.Equal(cidrIP)
|
||||
}
|
||||
|
||||
return network.Contains(testIP)
|
||||
return false
|
||||
}
|
||||
|
||||
// Classifies device using both screen width and User-Agent parsing
|
||||
|
|
@ -311,3 +340,17 @@ func (h *IngestionHandler) classifyDevice(width int, userAgent string) string {
|
|||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// Creates a unique request ID for tracing.
|
||||
// Uses 8 bytes (64 bits) of randomness which produces 16 hex characters.
|
||||
// 2^64 possible IDs (~18 quintillion) provides sufficient uniqueness for
|
||||
// request tracing while keeping IDs reasonably short in logs and headers.
|
||||
func generateRequestID() string {
|
||||
// 8 bytes = 64 bits = 16 hex chars = 2^64 possible IDs
|
||||
b := make([]byte, 8)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to timestamp if crypto/rand fails
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -218,6 +218,107 @@ func newTestHandler(cfg *config.Config) *IngestionHandler {
|
|||
return NewIngestionHandler(cfg, pathNorm, pathRegistry, refRegistry, metricsAgg)
|
||||
}
|
||||
|
||||
func TestExtractIP(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Site: config.SiteConfig{
|
||||
Domains: []string{"example.com"},
|
||||
},
|
||||
Limits: config.LimitsConfig{
|
||||
MaxPaths: 100,
|
||||
MaxSources: 50,
|
||||
},
|
||||
Security: config.SecurityConfig{
|
||||
TrustedProxies: []string{"10.0.0.0/8", "192.168.1.1"},
|
||||
},
|
||||
}
|
||||
h := newTestHandler(cfg)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
headers map[string]string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "direct connection no proxy",
|
||||
remoteAddr: "192.168.1.100:12345",
|
||||
headers: map[string]string{},
|
||||
want: "192.168.1.100",
|
||||
},
|
||||
{
|
||||
name: "trusted proxy with X-Forwarded-For",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-For": "203.0.113.1, 10.0.0.5",
|
||||
},
|
||||
want: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "trusted proxy with X-Real-IP",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
headers: map[string]string{
|
||||
"X-Real-IP": "203.0.113.2",
|
||||
},
|
||||
want: "203.0.113.2",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP from trusted network should be ignored",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
headers: map[string]string{
|
||||
"X-Real-IP": "10.0.0.50", // trusted network, should fall back
|
||||
},
|
||||
want: "10.0.0.1", // falls back to remoteAddr
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP invalid IP should be ignored",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
headers: map[string]string{
|
||||
"X-Real-IP": "not-an-ip",
|
||||
},
|
||||
want: "10.0.0.1", // falls back
|
||||
},
|
||||
{
|
||||
name: "untrusted proxy X-Forwarded-For ignored",
|
||||
remoteAddr: "203.0.113.50:12345",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-For": "1.2.3.4",
|
||||
},
|
||||
want: "203.0.113.50", // uses remoteAddr, ignores header
|
||||
},
|
||||
{
|
||||
name: "untrusted proxy X-Real-IP ignored",
|
||||
remoteAddr: "203.0.113.50:12345",
|
||||
headers: map[string]string{
|
||||
"X-Real-IP": "1.2.3.4",
|
||||
},
|
||||
want: "203.0.113.50", // uses remoteAddr, ignores header
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For all trusted falls back",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-For": "10.0.0.2, 10.0.0.3",
|
||||
},
|
||||
want: "10.0.0.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/event", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
for k, v := range tt.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
got := h.extractIP(req)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractIP() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyDevice_UA(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Limits: config.LimitsConfig{
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ type LimitsConfig struct {
|
|||
MaxSources int `yaml:"max_sources"`
|
||||
MaxCustomEvents int `yaml:"max_custom_events"`
|
||||
DeviceBreakpoints DeviceBreaks `yaml:"device_breakpoints"`
|
||||
MaxMetricsPerMinute int `yaml:"max_metrics_per_minute"` // rate limit for metrics endpoint
|
||||
}
|
||||
|
||||
// Device classification breakpoints
|
||||
|
|
@ -72,10 +73,11 @@ type CORSConfig struct {
|
|||
}
|
||||
|
||||
// Authentication for metrics endpoint
|
||||
// Password can be set via environment variable: WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD
|
||||
type AuthConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Username string `yaml:"username"`
|
||||
Password string `yaml:"password"`
|
||||
Password string `yaml:"password"` // can use env var WATCHDOG_SECURITY_METRICS_AUTH_PASSWORD
|
||||
}
|
||||
|
||||
// Server endpoints
|
||||
|
|
@ -149,6 +151,10 @@ func (c *Config) Validate() error {
|
|||
c.Limits.MaxCustomEvents = 100 // Default
|
||||
}
|
||||
|
||||
if c.Limits.MaxMetricsPerMinute <= 0 {
|
||||
c.Limits.MaxMetricsPerMinute = 30 // Default: 30 requests per minute
|
||||
}
|
||||
|
||||
if c.Site.Path.MaxSegments < 0 {
|
||||
return fmt.Errorf("site.path.max_segments cannot be negative")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,38 +48,36 @@ func (n *PathNormalizer) Normalize(path string) string {
|
|||
path = "/" + path
|
||||
}
|
||||
|
||||
// Process segments in-place to minimize allocations
|
||||
// Split into segments, first element is *always* empty for paths starting with '/'
|
||||
segments := strings.Split(path, "/")
|
||||
if len(segments) > 0 && segments[0] == "" {
|
||||
segments = segments[1:]
|
||||
|
||||
// Process segments in a single pass: remove empty, resolve . and ..
|
||||
writeIdx := 0
|
||||
for i := 0; i < len(segments); i++ {
|
||||
seg := segments[i]
|
||||
|
||||
// Skip empty segments (from double slashes or leading /)
|
||||
if seg == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Remove empty segments (from double slashes)
|
||||
filtered := make([]string, 0, len(segments))
|
||||
for _, seg := range segments {
|
||||
if seg != "" {
|
||||
filtered = append(filtered, seg)
|
||||
}
|
||||
}
|
||||
segments = filtered
|
||||
|
||||
// Resolve . and .. segments
|
||||
resolved := make([]string, 0, len(segments))
|
||||
for _, seg := range segments {
|
||||
if seg == "." {
|
||||
// Skip current directory
|
||||
continue
|
||||
} else if seg == ".." {
|
||||
// Go up one level if possible
|
||||
if len(resolved) > 0 {
|
||||
resolved = resolved[:len(resolved)-1]
|
||||
if writeIdx > 0 {
|
||||
writeIdx--
|
||||
}
|
||||
// If already at root, skip ..
|
||||
} else {
|
||||
resolved = append(resolved, seg)
|
||||
// Keep this segment
|
||||
segments[writeIdx] = seg
|
||||
writeIdx++
|
||||
}
|
||||
}
|
||||
segments = resolved
|
||||
segments = segments[:writeIdx]
|
||||
|
||||
// Collapse numeric segments
|
||||
if n.cfg.CollapseNumericSegments {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package normalize
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
|
|
@ -21,35 +22,21 @@ func isInternalHost(hostname string) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// Check if hostname is an IP address
|
||||
if ip := net.ParseIP(hostname); ip != nil {
|
||||
// Private IPv4 ranges (RFC1918)
|
||||
if strings.HasPrefix(hostname, "10.") ||
|
||||
strings.HasPrefix(hostname, "192.168.") ||
|
||||
strings.HasPrefix(hostname, "172.16.") ||
|
||||
strings.HasPrefix(hostname, "172.17.") ||
|
||||
strings.HasPrefix(hostname, "172.18.") ||
|
||||
strings.HasPrefix(hostname, "172.19.") ||
|
||||
strings.HasPrefix(hostname, "172.20.") ||
|
||||
strings.HasPrefix(hostname, "172.21.") ||
|
||||
strings.HasPrefix(hostname, "172.22.") ||
|
||||
strings.HasPrefix(hostname, "172.23.") ||
|
||||
strings.HasPrefix(hostname, "172.24.") ||
|
||||
strings.HasPrefix(hostname, "172.25.") ||
|
||||
strings.HasPrefix(hostname, "172.26.") ||
|
||||
strings.HasPrefix(hostname, "172.27.") ||
|
||||
strings.HasPrefix(hostname, "172.28.") ||
|
||||
strings.HasPrefix(hostname, "172.29.") ||
|
||||
strings.HasPrefix(hostname, "172.30.") ||
|
||||
strings.HasPrefix(hostname, "172.31.") {
|
||||
if ip.IsPrivate() {
|
||||
return true
|
||||
}
|
||||
|
||||
// IPv6 loopback and local
|
||||
if strings.HasPrefix(hostname, "::1") ||
|
||||
strings.HasPrefix(hostname, "fe80::") ||
|
||||
strings.HasPrefix(hostname, "fc00::") ||
|
||||
strings.HasPrefix(hostname, "fd00::") {
|
||||
// Additional localhost checks for IP formats
|
||||
if ip.IsLoopback() {
|
||||
return true
|
||||
}
|
||||
// Link-local addresses
|
||||
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,3 +56,20 @@ func (r *ReferrerRegistry) OverflowCount() int {
|
|||
defer r.mu.RUnlock()
|
||||
return r.overflowCount
|
||||
}
|
||||
|
||||
// Contains checks if a source exists in the registry.
|
||||
func (r *ReferrerRegistry) Contains(source string) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
_, exists := r.sources[source]
|
||||
return exists
|
||||
}
|
||||
|
||||
// Count returns the number of unique sources in the registry.
|
||||
func (r *ReferrerRegistry) Count() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
return len(r.sources)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,6 +108,36 @@ func TestExtractReferrerDomain(t *testing.T) {
|
|||
siteDomain: "example.com",
|
||||
want: "internal",
|
||||
},
|
||||
{
|
||||
name: "private IP 172.16.x (RFC1918)",
|
||||
referrer: "http://172.16.0.1/page",
|
||||
siteDomain: "example.com",
|
||||
want: "internal",
|
||||
},
|
||||
{
|
||||
name: "private IP 172.31.x (RFC1918 upper bound)",
|
||||
referrer: "http://172.31.255.1/page",
|
||||
siteDomain: "example.com",
|
||||
want: "internal",
|
||||
},
|
||||
{
|
||||
name: "private IP 172.20.x (middle of range)",
|
||||
referrer: "http://172.20.50.100/page",
|
||||
siteDomain: "example.com",
|
||||
want: "internal",
|
||||
},
|
||||
{
|
||||
name: "public IP 172.15.x (just outside private range)",
|
||||
referrer: "http://172.15.0.1/page",
|
||||
siteDomain: "example.com",
|
||||
want: "other", // not internal, but invalid TLD
|
||||
},
|
||||
{
|
||||
name: "public IP 172.32.x (just outside private range)",
|
||||
referrer: "http://172.32.0.1/page",
|
||||
siteDomain: "example.com",
|
||||
want: "other", // not internal, but invalid TLD
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
|
|
|||
|
|
@ -195,6 +195,9 @@ func TestEndToEnd_GracefulShutdown(t *testing.T) {
|
|||
MaxSources: 50,
|
||||
MaxCustomEvents: 10,
|
||||
},
|
||||
Server: config.ServerConfig{
|
||||
StatePath: "/tmp/watchdog-test.state",
|
||||
},
|
||||
}
|
||||
|
||||
pathRegistry := aggregate.NewPathRegistry(cfg.Limits.MaxPaths)
|
||||
|
|
@ -223,12 +226,12 @@ func TestEndToEnd_GracefulShutdown(t *testing.T) {
|
|||
}
|
||||
|
||||
// Verify state file was created
|
||||
if _, err := os.Stat("/tmp/watchdog-hll.state"); os.IsNotExist(err) {
|
||||
if _, err := os.Stat("/tmp/watchdog-test.state"); os.IsNotExist(err) {
|
||||
t.Error("HLL state file was not created")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
os.Remove("/tmp/watchdog-hll.state")
|
||||
os.Remove("/tmp/watchdog-test.state")
|
||||
}
|
||||
|
||||
func TestEndToEnd_InvalidRequests(t *testing.T) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue