internal/api: resolve IPv6 handling; prevent XFF spoofing & add rate limiting
Signed-off-by: NotAShelf <raf@notashelf.dev> Change-Id: Ibe415a133bbc8bd533a21ed1ccd44cf36a6a6964
This commit is contained in:
parent
8187608b38
commit
7e1ef845e8
1 changed files with 148 additions and 25 deletions
|
|
@ -1,13 +1,15 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"math/rand"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"notashelf.dev/watchdog/internal/aggregate"
|
"notashelf.dev/watchdog/internal/aggregate"
|
||||||
"notashelf.dev/watchdog/internal/config"
|
"notashelf.dev/watchdog/internal/config"
|
||||||
"notashelf.dev/watchdog/internal/normalize"
|
"notashelf.dev/watchdog/internal/normalize"
|
||||||
|
"notashelf.dev/watchdog/internal/ratelimit"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Handles incoming analytics events
|
// Handles incoming analytics events
|
||||||
|
|
@ -17,6 +19,8 @@ type IngestionHandler struct {
|
||||||
pathRegistry *aggregate.PathRegistry
|
pathRegistry *aggregate.PathRegistry
|
||||||
refRegistry *normalize.ReferrerRegistry
|
refRegistry *normalize.ReferrerRegistry
|
||||||
metricsAgg *aggregate.MetricsAggregator
|
metricsAgg *aggregate.MetricsAggregator
|
||||||
|
rateLimiter *ratelimit.TokenBucket
|
||||||
|
rng *rand.Rand
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a new ingestion handler
|
// Creates a new ingestion handler
|
||||||
|
|
@ -27,33 +31,72 @@ func NewIngestionHandler(
|
||||||
refRegistry *normalize.ReferrerRegistry,
|
refRegistry *normalize.ReferrerRegistry,
|
||||||
metricsAgg *aggregate.MetricsAggregator,
|
metricsAgg *aggregate.MetricsAggregator,
|
||||||
) *IngestionHandler {
|
) *IngestionHandler {
|
||||||
|
var limiter *ratelimit.TokenBucket
|
||||||
|
if cfg.Limits.MaxEventsPerMinute > 0 {
|
||||||
|
limiter = ratelimit.NewTokenBucket(
|
||||||
|
cfg.Limits.MaxEventsPerMinute,
|
||||||
|
cfg.Limits.MaxEventsPerMinute,
|
||||||
|
60_000_000_000, // 1 minute in nanoseconds
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
return &IngestionHandler{
|
return &IngestionHandler{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
pathNorm: pathNorm,
|
pathNorm: pathNorm,
|
||||||
pathRegistry: pathRegistry,
|
pathRegistry: pathRegistry,
|
||||||
refRegistry: refRegistry,
|
refRegistry: refRegistry,
|
||||||
metricsAgg: metricsAgg,
|
metricsAgg: metricsAgg,
|
||||||
|
rateLimiter: limiter,
|
||||||
|
rng: rand.New(rand.NewSource(42)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Handle CORS preflight
|
||||||
|
if r.Method == http.MethodOptions {
|
||||||
|
h.handleCORS(w, r)
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply CORS headers to actual request
|
||||||
|
h.handleCORS(w, r)
|
||||||
|
|
||||||
// Only accept POST requests
|
// Only accept POST requests
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check rate limit
|
||||||
|
if h.rateLimiter != nil && !h.rateLimiter.Allow() {
|
||||||
|
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply sampling
|
||||||
|
if h.cfg.Site.Sampling < 1.0 {
|
||||||
|
if h.rng.Float64() > h.cfg.Site.Sampling {
|
||||||
|
// Sampled out, return success but don't track
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check context cancellation
|
||||||
|
if r.Context().Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Parse event from request body
|
// Parse event from request body
|
||||||
event, err := ParseEvent(r.Body)
|
event, err := ParseEvent(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to parse event: %v", err)
|
|
||||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate event
|
// Validate event
|
||||||
if err := event.Validate(h.cfg.Site.Domain); err != nil {
|
if err := event.Validate(h.cfg.Site.Domain); err != nil {
|
||||||
log.Printf("Event validation failed: %v", err)
|
|
||||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -63,16 +106,14 @@ func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
if !h.pathRegistry.Add(normalizedPath) {
|
if !h.pathRegistry.Add(normalizedPath) {
|
||||||
// Path was rejected due to cardinality limit
|
// Path was rejected due to cardinality limit
|
||||||
h.metricsAgg.RecordPathOverflow()
|
h.metricsAgg.RecordPathOverflow()
|
||||||
log.Printf("Path overflow: rejected %s", normalizedPath)
|
|
||||||
|
|
||||||
// Still return success to client
|
// Still return success to client
|
||||||
w.WriteHeader(http.StatusNoContent)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract visitor identity for unique tracking
|
// Extract visitor identity for unique tracking
|
||||||
ip := extractIP(r)
|
ip := h.extractIP(r)
|
||||||
userAgent := r.Header.Get("User-Agent")
|
userAgent := r.Header.Get("User-Agent")
|
||||||
|
|
||||||
// Track unique visitor if salt rotation is enabled
|
// Track unique visitor if salt rotation is enabled
|
||||||
|
|
@ -88,14 +129,18 @@ func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// Device classification
|
// Device classification
|
||||||
if h.cfg.Site.Collect.Device {
|
if h.cfg.Site.Collect.Device {
|
||||||
device = classifyDevice(event.Width)
|
device = h.classifyDevice(event.Width)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Referrer classification
|
// Referrer classification
|
||||||
if h.cfg.Site.Collect.Referrer == "domain" {
|
if h.cfg.Site.Collect.Referrer == "domain" {
|
||||||
refDomain := normalize.ExtractReferrerDomain(event.Referrer, h.cfg.Site.Domain)
|
refDomain := normalize.ExtractReferrerDomain(event.Referrer, h.cfg.Site.Domain)
|
||||||
if refDomain != "" {
|
if refDomain != "" {
|
||||||
referrer = h.refRegistry.Add(refDomain)
|
accepted := h.refRegistry.Add(refDomain)
|
||||||
|
if accepted == "other" {
|
||||||
|
h.metricsAgg.RecordReferrerOverflow()
|
||||||
|
}
|
||||||
|
referrer = accepted
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -110,15 +155,74 @@ func (h *IngestionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusNoContent)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractIP extracts the client IP from the request
|
// Adds CORS headers if enabled in config
|
||||||
// Checks X-Forwarded-For and X-Real-IP headers for proxied requests
|
func (h *IngestionHandler) handleCORS(w http.ResponseWriter, r *http.Request) {
|
||||||
func extractIP(r *http.Request) string {
|
if !h.cfg.Security.CORS.Enabled {
|
||||||
// Check X-Forwarded-For header (may contain multiple IPs)
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
if origin == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if origin is allowed
|
||||||
|
allowed := false
|
||||||
|
for _, allowedOrigin := range h.cfg.Security.CORS.AllowedOrigins {
|
||||||
|
if allowedOrigin == "*" || allowedOrigin == origin {
|
||||||
|
allowed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if allowed {
|
||||||
|
if origin == "*" {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
|
} else {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
|
}
|
||||||
|
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
||||||
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||||
|
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extracts the client IP from the request. Only trusts proxy headers if the source
|
||||||
|
// IP is in the trusted_proxies list
|
||||||
|
func (h *IngestionHandler) extractIP(r *http.Request) string {
|
||||||
|
// Get the direct connection IP
|
||||||
|
remoteIP, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
// RemoteAddr might not have port (shouldn't happen, but handle it anyway)
|
||||||
|
remoteIP = r.RemoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we should trust proxy headers
|
||||||
|
trustProxy := false
|
||||||
|
if len(h.cfg.Security.TrustedProxies) > 0 {
|
||||||
|
for _, trustedCIDR := range h.cfg.Security.TrustedProxies {
|
||||||
|
if h.ipInCIDR(remoteIP, trustedCIDR) {
|
||||||
|
trustProxy = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not trusting proxy, return direct IP
|
||||||
|
if !trustProxy {
|
||||||
|
return remoteIP
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check X-Forwarded-For header
|
||||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
// Take the first IP in the list
|
// Take the rightmost IP that's not from a trusted proxy
|
||||||
ips := strings.Split(xff, ",")
|
ips := strings.Split(xff, ",")
|
||||||
if len(ips) > 0 {
|
for i := len(ips) - 1; i >= 0; i-- {
|
||||||
return strings.TrimSpace(ips[0])
|
ip := strings.TrimSpace(ips[i])
|
||||||
|
if !h.ipInCIDR(ip, "0.0.0.0/0") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return ip
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -128,24 +232,43 @@ func extractIP(r *http.Request) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fall back to RemoteAddr
|
// Fall back to RemoteAddr
|
||||||
ip := r.RemoteAddr
|
return remoteIP
|
||||||
// Strip port if present
|
|
||||||
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
|
||||||
ip = ip[:idx]
|
|
||||||
}
|
|
||||||
return ip
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Classifies screen width into device categories
|
// Checks if an IP address is within a CIDR range
|
||||||
func classifyDevice(width int) string {
|
func (h *IngestionHandler) ipInCIDR(ip, cidr string) bool {
|
||||||
// FIXME: probably not the best logic for this...
|
// Parse the IP address
|
||||||
|
testIP := net.ParseIP(ip)
|
||||||
|
if testIP == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the CIDR
|
||||||
|
_, network, err := net.ParseCIDR(cidr)
|
||||||
|
if err != nil {
|
||||||
|
// If it's not a CIDR, try as a single IP
|
||||||
|
cidrIP := net.ParseIP(cidr)
|
||||||
|
if cidrIP == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return testIP.Equal(cidrIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
return network.Contains(testIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Classifies screen width into device categories using configured breakpoints
|
||||||
|
// FIXME: we need a more robust mechanism for classifying devices. Breakpoints
|
||||||
|
// are the only ones I can think of *right now* but I'm positive there are better
|
||||||
|
// mechanisms. We'll get to this later.
|
||||||
|
func (h *IngestionHandler) classifyDevice(width int) string {
|
||||||
if width == 0 {
|
if width == 0 {
|
||||||
return "unknown"
|
return "unknown"
|
||||||
}
|
}
|
||||||
if width < 768 {
|
if width < h.cfg.Limits.DeviceBreakpoints.Mobile {
|
||||||
return "mobile"
|
return "mobile"
|
||||||
}
|
}
|
||||||
if width < 1024 {
|
if width < h.cfg.Limits.DeviceBreakpoints.Tablet {
|
||||||
return "tablet"
|
return "tablet"
|
||||||
}
|
}
|
||||||
return "desktop"
|
return "desktop"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue