diff --git a/internal/cache/db.go b/internal/cache/db.go index 7c793a7..556b047 100644 --- a/internal/cache/db.go +++ b/internal/cache/db.go @@ -3,6 +3,7 @@ package cache import ( "database/sql" "fmt" + "strings" "time" _ "modernc.org/sqlite" @@ -20,6 +21,7 @@ type RouteEntry struct { TTL time.Time NarHash string NarSize uint64 + NarURL string // narinfo URL field, e.g. "nar/1wwh37...nar.xz" } // Returns true if the entry exists and hasn't expired. @@ -86,14 +88,31 @@ func migrate(db *sql.DB) error { ); CREATE INDEX IF NOT EXISTS idx_negative_expires ON negative_cache(expires_at); `) + if err != nil { + return err + } + // Add nar_url column if it does not exist yet (ALTER TABLE does not support + // IF NOT EXISTS in SQLite, so we ignore the "duplicate column" error). + if _, err := db.Exec(`ALTER TABLE routes ADD COLUMN nar_url TEXT DEFAULT ''`); err != nil { + if !isDuplicateColumn(err) { + return err + } + } + _, err = db.Exec(`CREATE INDEX IF NOT EXISTS idx_routes_nar_url ON routes(nar_url)`) return err } +// Returns true when err is a SQLite "duplicate column name" error produced by +// ALTER TABLE ADD COLUMN on a column that already exists. +func isDuplicateColumn(err error) bool { + return err != nil && strings.Contains(err.Error(), "duplicate column name") +} + // Returns the route for storePath, or nil if not found. func (d *DB) GetRoute(storePath string) (*RouteEntry, error) { row := d.db.QueryRow(` SELECT store_path, upstream_url, latency_ms, latency_ema, - query_count, failure_count, last_verified, ttl, nar_hash, nar_size + query_count, failure_count, last_verified, ttl, nar_hash, nar_size, nar_url FROM routes WHERE store_path = ?`, storePath) var e RouteEntry @@ -101,7 +120,32 @@ func (d *DB) GetRoute(storePath string) (*RouteEntry, error) { err := row.Scan( &e.StorePath, &e.UpstreamURL, &e.LatencyMs, &e.LatencyEMA, &e.QueryCount, &e.FailureCount, &lastVerifiedUnix, &ttlUnix, - &e.NarHash, &e.NarSize, + &e.NarHash, &e.NarSize, &e.NarURL, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + e.LastVerified = time.Unix(lastVerifiedUnix, 0).UTC() + e.TTL = time.Unix(ttlUnix, 0).UTC() + return &e, nil +} + +// Returns the route whose narinfo URL matches narURL, or nil if not found / expired. +func (d *DB) GetRouteByNarURL(narURL string) (*RouteEntry, error) { + row := d.db.QueryRow(` + SELECT store_path, upstream_url, latency_ms, latency_ema, + query_count, failure_count, last_verified, ttl, nar_hash, nar_size, nar_url + FROM routes WHERE nar_url = ? AND ttl > ?`, narURL, time.Now().Unix()) + + var e RouteEntry + var lastVerifiedUnix, ttlUnix int64 + err := row.Scan( + &e.StorePath, &e.UpstreamURL, &e.LatencyMs, &e.LatencyEMA, + &e.QueryCount, &e.FailureCount, &lastVerifiedUnix, &ttlUnix, + &e.NarHash, &e.NarSize, &e.NarURL, ) if err == sql.ErrNoRows { return nil, nil @@ -119,8 +163,8 @@ func (d *DB) SetRoute(entry *RouteEntry) error { _, err := d.db.Exec(` INSERT INTO routes (store_path, upstream_url, latency_ms, latency_ema, - query_count, failure_count, last_verified, ttl, nar_hash, nar_size) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + query_count, failure_count, last_verified, ttl, nar_hash, nar_size, nar_url) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(store_path) DO UPDATE SET upstream_url = excluded.upstream_url, latency_ms = excluded.latency_ms, @@ -130,12 +174,13 @@ func (d *DB) SetRoute(entry *RouteEntry) error { last_verified = excluded.last_verified, ttl = excluded.ttl, nar_hash = excluded.nar_hash, - nar_size = excluded.nar_size`, + nar_size = excluded.nar_size, + nar_url = excluded.nar_url`, entry.StorePath, entry.UpstreamURL, entry.LatencyMs, entry.LatencyEMA, entry.QueryCount, entry.FailureCount, entry.LastVerified.Unix(), entry.TTL.Unix(), - entry.NarHash, entry.NarSize, + entry.NarHash, entry.NarSize, entry.NarURL, ) if err != nil { return err diff --git a/internal/cache/db_test.go b/internal/cache/db_test.go index 7501e5d..8ec1f99 100644 --- a/internal/cache/db_test.go +++ b/internal/cache/db_test.go @@ -240,6 +240,44 @@ func TestNegativeCacheExpiry(t *testing.T) { } } +func TestGetRouteByNarURL(t *testing.T) { + db, err := cache.Open(":memory:", 100) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + entry := &cache.RouteEntry{ + StorePath: "abc123", + UpstreamURL: "https://cache.nixos.org", + NarURL: "nar/abc123.nar.xz", + TTL: time.Now().Add(time.Hour), + } + if err := db.SetRoute(entry); err != nil { + t.Fatalf("SetRoute: %v", err) + } + + got, err := db.GetRouteByNarURL("nar/abc123.nar.xz") + if err != nil { + t.Fatalf("GetRouteByNarURL: %v", err) + } + if got == nil { + t.Fatal("expected non-nil entry") + } + if got.UpstreamURL != "https://cache.nixos.org" { + t.Errorf("UpstreamURL = %q", got.UpstreamURL) + } + + // Non-existent NarURL returns nil. + got2, err := db.GetRouteByNarURL("nar/nonexistent.nar.xz") + if err != nil { + t.Fatalf("GetRouteByNarURL for missing: %v", err) + } + if got2 != nil { + t.Error("expected nil for missing NarURL") + } +} + func TestLRUEviction(t *testing.T) { // Use maxEntries=3 to trigger eviction easily f, _ := os.CreateTemp("", "ncro-lru-*.db") diff --git a/internal/router/router.go b/internal/router/router.go index 741b2de..3678d6e 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -193,7 +193,7 @@ func (r *Router) race(storeHash string, candidates []string) (*Result, error) { metrics.UpstreamLatency.WithLabelValues(winner.url).Observe(winner.latencyMs / 1000) // Fetch narinfo body to parse metadata and forward to caller. - narInfoBytes, narHash, narSize := r.fetchNarInfo(winner.url, storeHash) + narInfoBytes, narURL, narHash, narSize := r.fetchNarInfo(winner.url, storeHash) health := r.prober.GetHealth(winner.url) ema := winner.latencyMs @@ -213,31 +213,32 @@ func (r *Router) race(storeHash string, candidates []string) (*Result, error) { TTL: now.Add(r.routeTTL), NarHash: narHash, NarSize: narSize, + NarURL: narURL, }) return &Result{URL: winner.url, LatencyMs: winner.latencyMs, NarInfoBytes: narInfoBytes}, nil } -// Fetches narinfo content from upstream, verifies its signature if a key is -// configured for that upstream, and returns (body, narHash, narSize). -// Returns (nil, "", 0) if the fetch fails or signature verification fails. -func (r *Router) fetchNarInfo(upstream, storeHash string) ([]byte, string, uint64) { +// Returns (body, narURL, narHash, narSize). narURL is the narinfo's URL field +// (e.g. "nar/1wwh37...nar.xz"), used for direct NAR routing. +// Returns (nil, "", "", 0) on fetch failure or signature verification failure. +func (r *Router) fetchNarInfo(upstream, storeHash string) ([]byte, string, string, uint64) { url := upstream + "/" + storeHash + ".narinfo" resp, err := r.client.Get(url) if err != nil { - return nil, "", 0 + return nil, "", "", 0 } defer resp.Body.Close() if resp.StatusCode != 200 { - return nil, "", 0 + return nil, "", "", 0 } body, err := io.ReadAll(resp.Body) if err != nil { - return nil, "", 0 + return nil, "", "", 0 } ni, err := narinfo.Parse(bytes.NewReader(body)) if err != nil { - return body, "", 0 + return body, "", "", 0 } r.mu.RLock() pubKeyStr := r.upstreamKeys[upstream] @@ -246,12 +247,12 @@ func (r *Router) fetchNarInfo(upstream, storeHash string) ([]byte, string, uint6 ok, err := ni.Verify(pubKeyStr) if err != nil { slog.Warn("narinfo: public key parse error", "upstream", upstream, "error", err) - return nil, "", 0 + return nil, "", "", 0 } if !ok { slog.Warn("narinfo: signature verification failed", "upstream", upstream, "store", storeHash) - return nil, "", 0 + return nil, "", "", 0 } } - return body, ni.NarHash, ni.NarSize + return body, ni.URL, ni.NarHash, ni.NarSize }