diff --git a/cmd/watchdog/main.go b/cmd/watchdog/main.go index 221328a..fd81b9e 100644 --- a/cmd/watchdog/main.go +++ b/cmd/watchdog/main.go @@ -25,18 +25,23 @@ type versionInfo struct { BuildDate string `json:"buildDate"` } +func getVersionInfo() versionInfo { + data, err := os.ReadFile("version.json") + if err != nil { + return versionInfo{} + } + var v versionInfo + if err := json.Unmarshal(data, &v); err != nil { + return versionInfo{} + } + return v +} + func getVersion() string { if version != "" { return version } - data, err := os.ReadFile("version.json") - if err != nil { - return "dev" - } - var v versionInfo - if err := json.Unmarshal(data, &v); err != nil { - return "dev" - } + v := getVersionInfo() if v.Version != "" { return v.Version } @@ -47,14 +52,7 @@ func getCommit() string { if commit != "" { return commit } - data, err := os.ReadFile("version.json") - if err != nil { - return "none" - } - var v versionInfo - if err := json.Unmarshal(data, &v); err != nil { - return "none" - } + v := getVersionInfo() if v.Commit != "" { return v.Commit } diff --git a/internal/api/handler.go b/internal/api/handler.go index 4cb3905..851c0ea 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -5,6 +5,7 @@ import ( "net" "net/http" "strings" + "time" "notashelf.dev/watchdog/internal/aggregate" "notashelf.dev/watchdog/internal/config" @@ -14,14 +15,16 @@ import ( // Handles incoming analytics events type IngestionHandler struct { - cfg *config.Config - domainMap map[string]bool // O(1) domain validation - pathNorm *normalize.PathNormalizer - pathRegistry *aggregate.PathRegistry - refRegistry *normalize.ReferrerRegistry - metricsAgg *aggregate.MetricsAggregator - rateLimiter *ratelimit.TokenBucket - rng *rand.Rand + cfg *config.Config + domainMap map[string]bool // O(1) domain validation + corsOriginMap map[string]bool // O(1) CORS origin validation + pathNorm *normalize.PathNormalizer + pathRegistry *aggregate.PathRegistry + refRegistry *normalize.ReferrerRegistry + metricsAgg *aggregate.MetricsAggregator + rateLimiter *ratelimit.TokenBucket + rng *rand.Rand + trustedNetworks []*net.IPNet // Pre-parsed CIDR networks } // Creates a new ingestion handler @@ -47,15 +50,40 @@ func NewIngestionHandler( domainMap[domain] = true } + // Build CORS origin map for O(1) lookup + corsOriginMap := make(map[string]bool, len(cfg.Security.CORS.AllowedOrigins)) + for _, origin := range cfg.Security.CORS.AllowedOrigins { + corsOriginMap[origin] = true + } + + // Pre-parse trusted proxy CIDRs to avoid re-parsing on each request + trustedNetworks := make([]*net.IPNet, 0, len(cfg.Security.TrustedProxies)) + for _, cidr := range cfg.Security.TrustedProxies { + if _, network, err := net.ParseCIDR(cidr); err == nil { + trustedNetworks = append(trustedNetworks, network) + } else if ip := net.ParseIP(cidr); ip != nil { + // Single IP - create a /32 or /128 network + var mask net.IPMask + if ip.To4() != nil { + mask = net.CIDRMask(32, 32) + } else { + mask = net.CIDRMask(128, 128) + } + trustedNetworks = append(trustedNetworks, &net.IPNet{IP: ip, Mask: mask}) + } + } + return &IngestionHandler{ - cfg: cfg, - domainMap: domainMap, - pathNorm: pathNorm, - pathRegistry: pathRegistry, - refRegistry: refRegistry, - metricsAgg: metricsAgg, - rateLimiter: limiter, - rng: rand.New(rand.NewSource(42)), + cfg: cfg, + domainMap: domainMap, + corsOriginMap: corsOriginMap, + pathNorm: pathNorm, + pathRegistry: pathRegistry, + refRegistry: refRegistry, + metricsAgg: metricsAgg, + rateLimiter: limiter, + rng: rand.New(rand.NewSource(time.Now().UnixNano())), + trustedNetworks: trustedNetworks, } } @@ -213,13 +241,8 @@ func (h *IngestionHandler) extractIP(r *http.Request) string { // Check if we should trust proxy headers trustProxy := false - if len(h.cfg.Security.TrustedProxies) > 0 { - for _, trustedCIDR := range h.cfg.Security.TrustedProxies { - if h.ipInCIDR(remoteIP, trustedCIDR) { - trustProxy = true - break - } - } + if len(h.trustedNetworks) > 0 { + trustProxy = h.ipInNetworks(remoteIP, h.trustedNetworks) } // If not trusting proxy, return direct IP @@ -233,10 +256,9 @@ func (h *IngestionHandler) extractIP(r *http.Request) string { ips := strings.Split(xff, ",") for i := len(ips) - 1; i >= 0; i-- { ip := strings.TrimSpace(ips[i]) - if !h.ipInCIDR(ip, "0.0.0.0/0") { - continue + if testIP := net.ParseIP(ip); testIP != nil { + return ip } - return ip } } @@ -249,26 +271,20 @@ func (h *IngestionHandler) extractIP(r *http.Request) string { return remoteIP } -// Checks if an IP address is within a CIDR range -func (h *IngestionHandler) ipInCIDR(ip, cidr string) bool { - // Parse the IP address +// Checks if an IP address is within any of the trusted networks +func (h *IngestionHandler) ipInNetworks(ip string, networks []*net.IPNet) bool { testIP := net.ParseIP(ip) if testIP == nil { return false } - // Parse the CIDR - _, network, err := net.ParseCIDR(cidr) - if err != nil { - // If it's not a CIDR, try as a single IP - cidrIP := net.ParseIP(cidr) - if cidrIP == nil { - return false + for _, network := range networks { + if network.Contains(testIP) { + return true } - return testIP.Equal(cidrIP) } - return network.Contains(testIP) + return false } // Classifies device using both screen width and User-Agent parsing