treewide: rewrite everything in Rust

Signed-off-by: NotAShelf <raf@notashelf.dev>
Change-Id: I786da853078e1013bb8f463ed9e9869c6a6a6964
This commit is contained in:
raf 2026-05-11 12:08:49 +03:00
commit ea96477830
Signed by: NotAShelf
GPG key ID: 29D95B64378DB4BF
43 changed files with 5993 additions and 4594 deletions

1
.envrc
View file

@ -1,2 +1 @@
use flake
export CGO_ENABLED=0

1
.gitignore vendored
View file

@ -1,2 +1,3 @@
# Build output
/ncro
/target

26
.rustfmt.toml Normal file
View file

@ -0,0 +1,26 @@
condense_wildcard_suffixes = true
doc_comment_code_block_width = 80
edition = "2024" # Keep in sync with Cargo.toml.
enum_discrim_align_threshold = 60
force_explicit_abi = false
force_multiline_blocks = true
format_code_in_doc_comments = true
format_macro_matchers = true
format_strings = true
group_imports = "StdExternalCrate"
hex_literal_case = "Upper"
imports_granularity = "Crate"
imports_layout = "HorizontalVertical"
inline_attribute_width = 60
match_block_trailing_comma = true
max_width = 80
newline_style = "Unix"
normalize_comments = true
normalize_doc_attributes = true
overflow_delimited_expr = true
struct_field_align_threshold = 60
tab_spaces = 2
unstable_features = true
use_field_init_shorthand = true
use_try_shorthand = true
wrap_comments = true

3331
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

130
Cargo.toml Normal file
View file

@ -0,0 +1,130 @@
[package]
name = "ncro"
version = "1.0.0"
edition = "2024"
license = "MIT"
[dependencies]
anyhow = "1.0.102"
async-trait = "0.1.89"
axum = { version = "0.8.9", features = ["macros"] }
base64 = "0.22.1"
bytes = "1.11.1"
clap = { version = "4.6.1", features = ["derive", "env"] }
chrono = { version = "0.4.44", features = ["serde"] }
ed25519-dalek = { version = "2.2.0", features = ["rand_core"] }
futures-util = "0.3.32"
hex = "0.4.3"
http = "1.4.0"
http-body-util = "0.1.3"
humantime-serde = "1.1.1"
mdns-sd = "0.15.2"
prometheus = "0.14.0"
rand = "0.8.6"
reqwest = { version = "0.12.28", default-features = false, features = [
"rustls-tls",
"stream",
] }
rmp-serde = "1.3.1"
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.149"
serde_yaml = "0.9.34"
sqlx = { version = "0.8.6", default-features = false, features = [
"runtime-tokio-rustls",
"sqlite",
"macros",
"migrate",
"chrono",
] }
thiserror = "2.0.18"
tokio = { version = "1.52.3", features = [
"macros",
"rt-multi-thread",
"signal",
"time",
"net",
"fs",
] }
tokio-util = { version = "0.7.18", features = ["io"] }
tower-http = { version = "0.6.10", features = ["trace"] }
tracing = "0.1.44"
tracing-subscriber = { version = "0.3.23", features = ["env-filter", "json"] }
url = "2.5.8"
[dev-dependencies]
tempfile = "3.27.0"
tower = { version = "0.5.3", features = ["util"] }
# See:
# <https://doc.rust-lang.org/rustc/lints/listing/allowed-by-default.html>
[lints.clippy]
cargo = { level = "warn", priority = -1 }
complexity = { level = "warn", priority = -1 }
nursery = { level = "warn", priority = -1 }
pedantic = { level = "warn", priority = -1 }
perf = { level = "warn", priority = -1 }
style = { level = "warn", priority = -1 }
# The lint groups above enable some less-than-desirable rules, we should manually
# enable those to keep our sanity.
absolute_paths = "allow"
arbitrary_source_item_ordering = "allow"
clone_on_ref_ptr = "warn"
dbg_macro = "warn"
empty_drop = "warn"
empty_structs_with_brackets = "warn"
exit = "warn"
filetype_is_file = "warn"
get_unwrap = "warn"
implicit_return = "allow"
infinite_loop = "warn"
map_with_unused_argument_over_ranges = "warn"
missing_docs_in_private_items = "allow"
multiple_crate_versions = "allow" # :(
non_ascii_literal = "allow"
non_std_lazy_statics = "warn"
pathbuf_init_then_push = "warn"
pattern_type_mismatch = "allow"
question_mark_used = "allow"
rc_buffer = "warn"
rc_mutex = "warn"
rest_pat_in_fully_bound_structs = "warn"
similar_names = "allow"
single_call_fn = "allow"
std_instead_of_core = "allow"
too_long_first_doc_paragraph = "allow"
too_many_lines = "allow"
undocumented_unsafe_blocks = "warn"
unnecessary_safety_comment = "warn"
unused_result_ok = "warn"
unused_trait_names = "allow"
# False positive:
# clippy's build script check doesn't recognize workspace-inherited metadata
# which means in our current workspace layout, we get pranked by Clippy.
cargo_common_metadata = "allow"
# In the honor of a recent Cloudflare regression
panic = "deny"
unwrap_used = "deny"
# Less dangerous, but we'd like to know
# Those must be opt-in, and are fine ONLY in tests and examples. We *can* panic
# in NDG (the binary crate), but it should be very deliberate
expect_used = "warn"
print_stderr = "warn"
print_stdout = "warn"
todo = "warn"
unimplemented = "warn"
unreachable = "warn"
[profile.dev]
debug = true
opt-level = 0
[profile.release]
codegen-units = 1
lto = true
opt-level = "z"
panic = "abort"
strip = "symbols"

View file

@ -195,10 +195,10 @@ Prometheus metrics are available at `/metrics`.
# With Nix (recommended)
$ nix build
# With Go directly
$ go build ./cmd/ncro/
# With Cargo directly
$ cargo build --release
# Development shell
$ nix develop
$ go test ./...
$ cargo test
```

View file

@ -1,5 +0,0 @@
package main
func main() {
Execute()
}

View file

@ -1,256 +0,0 @@
package main
import (
"context"
"crypto/ed25519"
"encoding/hex"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"notashelf.dev/ncro/internal/cache"
"notashelf.dev/ncro/internal/config"
"notashelf.dev/ncro/internal/discovery"
"notashelf.dev/ncro/internal/mesh"
"notashelf.dev/ncro/internal/metrics"
"notashelf.dev/ncro/internal/prober"
"notashelf.dev/ncro/internal/router"
"notashelf.dev/ncro/internal/server"
)
// Injected at build time via -ldflags "-X main.version=<ver>".
var version = "dev"
// Execute is the entrypoint called by main.
func Execute() {
if err := newRootCmd().Execute(); err != nil {
os.Exit(1)
}
}
func newRootCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "ncro",
Short: "Nix Cache Route Optimizer",
Version: version,
SilenceUsage: true,
RunE: runServer,
}
cmd.Flags().StringP("config", "c", "", "path to config YAML file (env: NCRO_CONFIG)")
_ = viper.BindPFlag("config", cmd.Flags().Lookup("config"))
viper.SetEnvPrefix("NCRO")
viper.AutomaticEnv()
return cmd
}
func runServer(_ *cobra.Command, _ []string) error {
cfg, err := config.Load(viper.GetString("config"))
if err != nil {
return fmt.Errorf("load config: %w", err)
}
if err := cfg.Validate(); err != nil {
return fmt.Errorf("invalid config: %w", err)
}
level := slog.LevelInfo
switch cfg.Logging.Level {
case "debug":
level = slog.LevelDebug
case "warn":
level = slog.LevelWarn
case "error":
level = slog.LevelError
}
var handler slog.Handler
if cfg.Logging.Format == "text" {
handler = slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level})
} else {
handler = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: level})
}
slog.SetDefault(slog.New(handler))
metrics.Register(prometheus.DefaultRegisterer)
db, err := cache.Open(cfg.Cache.DBPath, cfg.Cache.MaxEntries)
if err != nil {
return fmt.Errorf("open database: %w", err)
}
defer db.Close()
expireDone := make(chan struct{})
go func() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-expireDone:
return
case <-ticker.C:
if err := db.ExpireOldRoutes(); err != nil {
slog.Warn("expire routes error", "error", err)
}
if err := db.ExpireNegatives(); err != nil {
slog.Warn("expire negatives error", "error", err)
}
if count, err := db.RouteCount(); err == nil {
metrics.RouteEntries.Set(float64(count))
}
}
}
}()
p := prober.New(cfg.Cache.LatencyAlpha)
p.InitUpstreams(cfg.Upstreams)
if rows, err := db.LoadAllHealth(); err == nil {
for _, row := range rows {
p.Seed(row.URL, row.EMALatency, row.ConsecutiveFails, int64(row.TotalQueries))
}
} else {
slog.Warn("failed to load persisted health data", "error", err)
}
p.SetHealthPersistence(func(url string, ema float64, cf uint32, tq uint64) {
if err := db.SaveHealth(url, ema, int(cf), int64(tq)); err != nil {
slog.Warn("failed to save health", "url", url, "error", err)
}
})
for _, u := range cfg.Upstreams {
go p.ProbeUpstream(u.URL)
}
probeDone := make(chan struct{})
go p.RunProbeLoop(30*time.Second, probeDone)
// Setup mDNS discovery if enabled
var discoveryMgr *discovery.Discovery
if cfg.Discovery.Enabled {
discoveryMgr, err = discovery.New(cfg.Discovery)
if err != nil {
return fmt.Errorf("create discovery manager: %w", err)
}
discoveryMgr.SetCallbacks(
func(url string, priority int) {
slog.Info("adding discovered upstream", "url", url)
p.AddUpstream(url, priority)
},
func(url string) {
slog.Info("removing discovered upstream", "url", url)
p.RemoveUpstream(url)
},
)
slog.Info("mDNS discovery enabled", "service", cfg.Discovery.ServiceName)
}
r := router.New(db, p, cfg.Cache.TTL.Duration, 5*time.Second, cfg.Cache.NegativeTTL.Duration)
for _, u := range cfg.Upstreams {
if u.PublicKey != "" {
if err := r.SetUpstreamKey(u.URL, u.PublicKey); err != nil {
return fmt.Errorf("invalid upstream public key for %s: %w", u.URL, err)
}
slog.Info("narinfo signature verification enabled", "upstream", u.URL)
}
}
var gossipDone chan struct{}
if cfg.Mesh.Enabled {
store := mesh.NewRouteStore()
node, err := mesh.NewNode(cfg.Mesh.PrivateKeyPath, store)
if err != nil {
return fmt.Errorf("create mesh node: %w", err)
}
slog.Info("mesh node identity", "node_id", node.ID(),
"public_key", hex.EncodeToString(node.PublicKey()))
allowedKeys := make([]ed25519.PublicKey, 0, len(cfg.Mesh.Peers))
for _, peer := range cfg.Mesh.Peers {
if peer.PublicKey != "" {
b, _ := hex.DecodeString(peer.PublicKey)
allowedKeys = append(allowedKeys, ed25519.PublicKey(b))
}
}
if err := mesh.ListenAndServe(cfg.Mesh.BindAddr, store, allowedKeys...); err != nil {
return fmt.Errorf("start mesh listener: %w", err)
}
peerAddrs := make([]string, len(cfg.Mesh.Peers))
for i, p := range cfg.Mesh.Peers {
peerAddrs[i] = p.Addr
}
gossipDone = make(chan struct{})
go mesh.RunGossipLoop(node, db, peerAddrs, cfg.Mesh.GossipInterval.Duration, gossipDone)
slog.Info("mesh enabled", "addr", cfg.Mesh.BindAddr, "peers", len(cfg.Mesh.Peers))
}
// Start mDNS discovery in background
discoveryDone := make(chan struct{})
var discoveryCancel context.CancelFunc
if discoveryMgr != nil {
var ctx context.Context
ctx, discoveryCancel = context.WithCancel(context.Background())
go func() {
if err := discoveryMgr.Start(ctx); err != nil {
slog.Error("discovery error", "error", err)
}
}()
go func() {
<-discoveryDone
discoveryCancel()
discoveryMgr.Stop()
}()
}
srv := &http.Server{
Addr: cfg.Server.Listen,
Handler: server.New(r, p, db, cfg.Upstreams, cfg.Server.CachePriority),
ReadTimeout: cfg.Server.ReadTimeout.Duration,
WriteTimeout: cfg.Server.WriteTimeout.Duration,
}
stop := make(chan os.Signal, 1)
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
serverErr := make(chan error, 1)
go func() {
slog.Info("ncro listening", "addr", cfg.Server.Listen,
"upstreams", len(cfg.Upstreams), "version", version)
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
serverErr <- err
}
close(serverErr)
}()
select {
case <-stop:
slog.Info("shutting down")
case err := <-serverErr:
return fmt.Errorf("server: %w", err)
}
close(expireDone)
close(probeDone)
if gossipDone != nil {
close(gossipDone)
}
close(discoveryDone)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
slog.Warn("shutdown error", "error", err)
}
return nil
}

View file

@ -16,8 +16,16 @@ cache:
db_path: "/var/lib/ncro/routes.db"
max_entries: 100000
ttl: 1h
negative_ttl: 10m
latency_alpha: 0.3
discovery:
enabled: false
service_name: "_nix-serve._tcp"
domain: "local"
discovery_time: 5s
priority: 20
mesh:
enabled: false
bind_addr: "0.0.0.0:7946"

51
go.mod
View file

@ -1,51 +0,0 @@
module notashelf.dev/ncro
go 1.25.7
require (
github.com/grandcat/zeroconf v1.0.0
github.com/prometheus/client_golang v1.23.2
github.com/spf13/cobra v1.10.2
github.com/spf13/viper v1.21.0
github.com/vmihailenco/msgpack/v5 v5.4.1
golang.org/x/sync v0.20.0
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.50.0
)
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff v2.2.1+incompatible // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fsnotify/fsnotify v1.10.1 // indirect
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/mattn/go-isatty v0.0.22 // indirect
github.com/miekg/dns v1.1.72 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/pelletier/go-toml/v2 v2.3.1 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.5 // indirect
github.com/prometheus/procfs v0.20.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/sagikazarmark/locafero v0.12.0 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
go.yaml.in/yaml/v2 v2.4.4 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/mod v0.36.0 // indirect
golang.org/x/net v0.54.0 // indirect
golang.org/x/sys v0.44.0 // indirect
golang.org/x/text v0.37.0 // indirect
golang.org/x/tools v0.45.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
modernc.org/libc v1.72.3 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
)

153
go.sum
View file

@ -1,153 +0,0 @@
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.10.1 h1:b0/UzAf9yR5rhf3RPm9gf3ehBPpf0oZKIjtpKrx59Ho=
github.com/fsnotify/fsnotify v1.10.1/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo=
github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro=
github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grandcat/zeroconf v1.0.0 h1:uHhahLBKqwWBV6WZUDAT71044vwOTL+McW0mBJvo6kE=
github.com/grandcat/zeroconf v1.0.0/go.mod h1:lTKmG1zh86XyCoUeIHSA4FJMBwCJiQmGfcP2PdzytEs=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4=
github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
github.com/miekg/dns v1.1.27/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc=
github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4=
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEycfc=
github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.4 h1:tuyd0P+2Ont/d6e2rl3be67goVK4R6deVxCUX5vyPaQ=
go.yaml.in/yaml/v2 v2.4.4/go.mod h1:gMZqIpDtDqOfM0uNfy0SkpRhvUryYH0Z6wdMYcacYXQ=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w=
golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY=
modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ=
modernc.org/ccgo/v4 v4.34.0/go.mod h1:AS5WYMyBakQ+fhsHhtP8mWB82KTGPkNNJDGfGQCe0/A=
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU=
modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.50.0 h1:eMowQSWLK0MeiQTdmz3lqoF5dqclujdlIKeJA11+7oM=
modernc.org/sqlite v1.50.0/go.mod h1:m0w8xhwYUVY3H6pSDwc3gkJ/irZT/0YEXwBlhaxQEew=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

313
internal/cache/db.go vendored
View file

@ -1,313 +0,0 @@
package cache
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"strings"
"time"
_ "modernc.org/sqlite"
)
// Core routing decision persisted per store path.
type RouteEntry struct {
StorePath string
UpstreamURL string
LatencyMs float64
LatencyEMA float64
LastVerified time.Time
QueryCount uint32
FailureCount uint32
TTL time.Time
NarHash string
NarSize uint64
NarURL string // narinfo URL field, e.g. "nar/1wwh37...nar.xz"
}
// Returns true if the entry exists and hasn't expired.
func (r *RouteEntry) IsValid() bool {
return r != nil && time.Now().Before(r.TTL)
}
// SQLite-backed store for route persistence.
type DB struct {
db *sql.DB
maxEntries int
}
// Opens or creates the SQLite database at path with WAL mode.
// Creates parent directories as needed (unless path is ":memory:").
func Open(path string, maxEntries int) (*DB, error) {
if path != ":memory:" {
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return nil, fmt.Errorf("create db dir: %w", err)
}
}
db, err := sql.Open("sqlite", path+"?_journal=WAL&_busy_timeout=5000")
if err != nil {
return nil, fmt.Errorf("open sqlite: %w", err)
}
db.SetMaxOpenConns(1) // SQLite WAL allows 1 writer
if err := migrate(db); err != nil {
db.Close()
return nil, fmt.Errorf("migrate: %w", err)
}
return &DB{db: db, maxEntries: maxEntries}, nil
}
// Closes the database.
func (d *DB) Close() error {
return d.db.Close()
}
func migrate(db *sql.DB) error {
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS routes (
store_path TEXT PRIMARY KEY,
upstream_url TEXT NOT NULL,
latency_ms REAL DEFAULT 0,
latency_ema REAL DEFAULT 0,
query_count INTEGER DEFAULT 1,
failure_count INTEGER DEFAULT 0,
last_verified INTEGER DEFAULT 0,
ttl INTEGER NOT NULL,
nar_hash TEXT DEFAULT '',
nar_size INTEGER DEFAULT 0,
created_at INTEGER DEFAULT (strftime('%s', 'now'))
);
CREATE INDEX IF NOT EXISTS idx_routes_ttl ON routes(ttl);
CREATE INDEX IF NOT EXISTS idx_routes_last_verified ON routes(last_verified);
CREATE TABLE IF NOT EXISTS upstream_health (
url TEXT PRIMARY KEY,
ema_latency REAL DEFAULT 0,
last_probe INTEGER DEFAULT 0,
consecutive_fails INTEGER DEFAULT 0,
total_queries INTEGER DEFAULT 0,
success_rate REAL DEFAULT 1.0
);
CREATE TABLE IF NOT EXISTS negative_cache (
store_path TEXT PRIMARY KEY,
expires_at INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_negative_expires ON negative_cache(expires_at);
`)
if err != nil {
return err
}
// Add nar_url column if it does not exist yet (ALTER TABLE does not support
// IF NOT EXISTS in SQLite, so we ignore the "duplicate column" error).
if _, err := db.Exec(`ALTER TABLE routes ADD COLUMN nar_url TEXT DEFAULT ''`); err != nil {
if !isDuplicateColumn(err) {
return err
}
}
_, err = db.Exec(`CREATE INDEX IF NOT EXISTS idx_routes_nar_url ON routes(nar_url)`)
return err
}
// Returns true when err is a SQLite "duplicate column name" error produced by
// ALTER TABLE ADD COLUMN on a column that already exists.
func isDuplicateColumn(err error) bool {
return err != nil && strings.Contains(err.Error(), "duplicate column name")
}
// Returns the route for storePath, or nil if not found.
func (d *DB) GetRoute(storePath string) (*RouteEntry, error) {
row := d.db.QueryRow(`
SELECT store_path, upstream_url, latency_ms, latency_ema,
query_count, failure_count, last_verified, ttl, nar_hash, nar_size, nar_url
FROM routes WHERE store_path = ?`, storePath)
var e RouteEntry
var lastVerifiedUnix, ttlUnix int64
err := row.Scan(
&e.StorePath, &e.UpstreamURL, &e.LatencyMs, &e.LatencyEMA,
&e.QueryCount, &e.FailureCount, &lastVerifiedUnix, &ttlUnix,
&e.NarHash, &e.NarSize, &e.NarURL,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
e.LastVerified = time.Unix(lastVerifiedUnix, 0).UTC()
e.TTL = time.Unix(ttlUnix, 0).UTC()
return &e, nil
}
// Returns the route whose narinfo URL matches narURL, or nil if not found / expired.
func (d *DB) GetRouteByNarURL(narURL string) (*RouteEntry, error) {
row := d.db.QueryRow(`
SELECT store_path, upstream_url, latency_ms, latency_ema,
query_count, failure_count, last_verified, ttl, nar_hash, nar_size, nar_url
FROM routes WHERE nar_url = ? AND ttl > ?`, narURL, time.Now().Unix())
var e RouteEntry
var lastVerifiedUnix, ttlUnix int64
err := row.Scan(
&e.StorePath, &e.UpstreamURL, &e.LatencyMs, &e.LatencyEMA,
&e.QueryCount, &e.FailureCount, &lastVerifiedUnix, &ttlUnix,
&e.NarHash, &e.NarSize, &e.NarURL,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
e.LastVerified = time.Unix(lastVerifiedUnix, 0).UTC()
e.TTL = time.Unix(ttlUnix, 0).UTC()
return &e, nil
}
// Inserts or updates a route entry.
func (d *DB) SetRoute(entry *RouteEntry) error {
_, err := d.db.Exec(`
INSERT INTO routes
(store_path, upstream_url, latency_ms, latency_ema,
query_count, failure_count, last_verified, ttl, nar_hash, nar_size, nar_url)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(store_path) DO UPDATE SET
upstream_url = excluded.upstream_url,
latency_ms = excluded.latency_ms,
latency_ema = excluded.latency_ema,
query_count = excluded.query_count,
failure_count = excluded.failure_count,
last_verified = excluded.last_verified,
ttl = excluded.ttl,
nar_hash = excluded.nar_hash,
nar_size = excluded.nar_size,
nar_url = excluded.nar_url`,
entry.StorePath, entry.UpstreamURL,
entry.LatencyMs, entry.LatencyEMA,
entry.QueryCount, entry.FailureCount,
entry.LastVerified.Unix(), entry.TTL.Unix(),
entry.NarHash, entry.NarSize, entry.NarURL,
)
if err != nil {
return err
}
return d.evictIfNeeded()
}
// Deletes routes whose TTL has passed.
func (d *DB) ExpireOldRoutes() error {
_, err := d.db.Exec(`DELETE FROM routes WHERE ttl < ?`, time.Now().Unix())
return err
}
// Returns up to n non-expired routes ordered by most-recently-verified.
func (d *DB) ListRecentRoutes(n int) ([]RouteEntry, error) {
rows, err := d.db.Query(`
SELECT store_path, upstream_url, latency_ema, last_verified, ttl, nar_hash, nar_size, nar_url
FROM routes WHERE ttl > ? ORDER BY last_verified DESC LIMIT ?`,
time.Now().Unix(), n)
if err != nil {
return nil, err
}
defer rows.Close()
var result []RouteEntry
for rows.Next() {
var e RouteEntry
var lastVerifiedUnix, ttlUnix int64
if err := rows.Scan(
&e.StorePath, &e.UpstreamURL, &e.LatencyEMA,
&lastVerifiedUnix, &ttlUnix, &e.NarHash, &e.NarSize, &e.NarURL,
); err != nil {
return nil, err
}
e.LastVerified = time.Unix(lastVerifiedUnix, 0).UTC()
e.TTL = time.Unix(ttlUnix, 0).UTC()
result = append(result, e)
}
return result, rows.Err()
}
// Returns the total number of stored routes.
func (d *DB) RouteCount() (int, error) {
var count int
err := d.db.QueryRow(`SELECT COUNT(*) FROM routes`).Scan(&count)
return count, err
}
// Records a negative cache entry for storePath with the given TTL.
func (d *DB) SetNegative(storePath string, ttl time.Duration) error {
_, err := d.db.Exec(
`INSERT INTO negative_cache (store_path, expires_at) VALUES (?, ?)
ON CONFLICT(store_path) DO UPDATE SET expires_at = excluded.expires_at`,
storePath, time.Now().Add(ttl).Unix(),
)
return err
}
// Returns true if a non-expired negative entry exists for storePath.
func (d *DB) IsNegative(storePath string) (bool, error) {
var exists bool
err := d.db.QueryRow(
`SELECT EXISTS(SELECT 1 FROM negative_cache WHERE store_path = ? AND expires_at > ?)`,
storePath, time.Now().Unix(),
).Scan(&exists)
return exists, err
}
// Deletes expired negative cache entries.
func (d *DB) ExpireNegatives() error {
_, err := d.db.Exec(`DELETE FROM negative_cache WHERE expires_at < ?`, time.Now().Unix())
return err
}
// Persisted snapshot of one upstream's health metrics.
type HealthRow struct {
URL string
EMALatency float64
ConsecutiveFails int
TotalQueries int64
}
// Upserts the health metrics for the given upstream URL.
func (d *DB) SaveHealth(url string, ema float64, consecutiveFails int, totalQueries int64) error {
_, err := d.db.Exec(`
INSERT INTO upstream_health (url, ema_latency, consecutive_fails, total_queries)
VALUES (?, ?, ?, ?)
ON CONFLICT(url) DO UPDATE SET
ema_latency = excluded.ema_latency,
consecutive_fails = excluded.consecutive_fails,
total_queries = excluded.total_queries`,
url, ema, consecutiveFails, totalQueries,
)
return err
}
// Returns all rows from the upstream_health table.
func (d *DB) LoadAllHealth() ([]HealthRow, error) {
rows, err := d.db.Query(`SELECT url, ema_latency, consecutive_fails, total_queries FROM upstream_health`)
if err != nil {
return nil, err
}
defer rows.Close()
var result []HealthRow
for rows.Next() {
var r HealthRow
if err := rows.Scan(&r.URL, &r.EMALatency, &r.ConsecutiveFails, &r.TotalQueries); err != nil {
return nil, err
}
result = append(result, r)
}
return result, rows.Err()
}
// Deletes the oldest routes (by last_verified) when over capacity.
func (d *DB) evictIfNeeded() error {
_, err := d.db.Exec(`
DELETE FROM routes WHERE store_path IN (
SELECT store_path FROM routes ORDER BY last_verified ASC
LIMIT MAX(0, (SELECT COUNT(*) FROM routes) - ?)
)`, d.maxEntries)
return err
}

View file

@ -1,324 +0,0 @@
package cache_test
import (
"os"
"testing"
"time"
"notashelf.dev/ncro/internal/cache"
)
func newTestDB(t *testing.T) *cache.DB {
t.Helper()
f, err := os.CreateTemp("", "ncro-test-*.db")
if err != nil {
t.Fatal(err)
}
f.Close()
t.Cleanup(func() { os.Remove(f.Name()) })
db, err := cache.Open(f.Name(), 1000)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { db.Close() })
return db
}
func TestGetSetRoute(t *testing.T) {
db := newTestDB(t)
entry := &cache.RouteEntry{
StorePath: "abc123xyz-hello-2.12",
UpstreamURL: "https://cache.nixos.org",
LatencyMs: 12.5,
LatencyEMA: 12.5,
LastVerified: time.Now().UTC().Truncate(time.Second),
QueryCount: 1,
TTL: time.Now().Add(time.Hour).UTC().Truncate(time.Second),
}
if err := db.SetRoute(entry); err != nil {
t.Fatalf("SetRoute: %v", err)
}
got, err := db.GetRoute("abc123xyz-hello-2.12")
if err != nil {
t.Fatalf("GetRoute: %v", err)
}
if got == nil {
t.Fatal("GetRoute returned nil")
}
if got.UpstreamURL != entry.UpstreamURL {
t.Errorf("upstream = %q, want %q", got.UpstreamURL, entry.UpstreamURL)
}
if got.QueryCount != 1 {
t.Errorf("query_count = %d, want 1", got.QueryCount)
}
}
func TestGetRouteNotFound(t *testing.T) {
db := newTestDB(t)
got, err := db.GetRoute("nonexistent")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != nil {
t.Errorf("expected nil, got %+v", got)
}
}
func TestSetRouteUpsert(t *testing.T) {
db := newTestDB(t)
entry := &cache.RouteEntry{
StorePath: "abc123-pkg",
UpstreamURL: "https://cache.nixos.org",
LatencyMs: 20.0,
LatencyEMA: 20.0,
QueryCount: 1,
TTL: time.Now().Add(time.Hour),
}
db.SetRoute(entry)
entry.LatencyEMA = 18.0
entry.QueryCount = 2
if err := db.SetRoute(entry); err != nil {
t.Fatalf("upsert: %v", err)
}
got, _ := db.GetRoute("abc123-pkg")
if got.LatencyEMA != 18.0 {
t.Errorf("ema = %f, want 18.0", got.LatencyEMA)
}
if got.QueryCount != 2 {
t.Errorf("query_count = %d, want 2", got.QueryCount)
}
}
func TestExpireOldRoutes(t *testing.T) {
db := newTestDB(t)
// Insert expired route
expired := &cache.RouteEntry{
StorePath: "expired-pkg",
UpstreamURL: "https://cache.nixos.org",
TTL: time.Now().Add(-time.Minute), // already expired
}
db.SetRoute(expired)
// Insert valid route
valid := &cache.RouteEntry{
StorePath: "valid-pkg",
UpstreamURL: "https://cache.nixos.org",
TTL: time.Now().Add(time.Hour),
}
db.SetRoute(valid)
if err := db.ExpireOldRoutes(); err != nil {
t.Fatalf("ExpireOldRoutes: %v", err)
}
got, _ := db.GetRoute("expired-pkg")
if got != nil {
t.Error("expired route should have been deleted")
}
got2, _ := db.GetRoute("valid-pkg")
if got2 == nil {
t.Error("valid route should still exist")
}
}
func TestRouteEntryIsValidExpired(t *testing.T) {
expired := &cache.RouteEntry{TTL: time.Now().Add(-time.Minute)}
if expired.IsValid() {
t.Error("expired entry should not be valid")
}
}
func TestRouteEntryIsValidFuture(t *testing.T) {
valid := &cache.RouteEntry{TTL: time.Now().Add(time.Hour)}
if !valid.IsValid() {
t.Error("future-TTL entry should be valid")
}
}
func TestDBOpenCreatesSchema(t *testing.T) {
db := newTestDB(t)
// RouteCount works only if schema was created.
count, err := db.RouteCount()
if err != nil {
t.Fatalf("RouteCount after fresh open: %v", err)
}
if count != 0 {
t.Errorf("expected 0 routes in fresh DB, got %d", count)
}
}
func TestRouteCountAfterExpiry(t *testing.T) {
db := newTestDB(t)
for i := range 3 {
ttl := time.Now().Add(-time.Minute) // all expired
db.SetRoute(&cache.RouteEntry{
StorePath: "pkg-" + string(rune('a'+i)),
UpstreamURL: "https://cache.nixos.org",
TTL: ttl,
})
}
before, _ := db.RouteCount()
if err := db.ExpireOldRoutes(); err != nil {
t.Fatal(err)
}
after, _ := db.RouteCount()
if after >= before {
t.Errorf("count did not decrease after expiry: before=%d after=%d", before, after)
}
if after != 0 {
t.Errorf("expected 0 routes after expiring all, got %d", after)
}
}
func TestNegativeCacheSetAndCheck(t *testing.T) {
db, err := cache.Open(":memory:", 100)
if err != nil {
t.Fatal(err)
}
defer db.Close()
neg, err := db.IsNegative("missing-path")
if err != nil {
t.Fatalf("IsNegative: %v", err)
}
if neg {
t.Error("expected false for unknown path")
}
if err := db.SetNegative("missing-path", 10*time.Minute); err != nil {
t.Fatalf("SetNegative: %v", err)
}
neg, err = db.IsNegative("missing-path")
if err != nil {
t.Fatalf("IsNegative after set: %v", err)
}
if !neg {
t.Error("expected true after SetNegative")
}
}
func TestNegativeCacheExpiry(t *testing.T) {
db, err := cache.Open(":memory:", 100)
if err != nil {
t.Fatal(err)
}
defer db.Close()
// Set with negative duration so it's already expired.
if err := db.SetNegative("expires-now", -time.Second); err != nil {
t.Fatalf("SetNegative: %v", err)
}
// IsNegative must filter expired entries via the inline SQL predicate,
// even before ExpireNegatives cleans them up.
neg, err := db.IsNegative("expires-now")
if err != nil {
t.Fatalf("IsNegative for expired entry: %v", err)
}
if neg {
t.Error("IsNegative should return false for an already-expired entry (SQL time predicate)")
}
// Janitor cleanup should also work.
if err := db.ExpireNegatives(); err != nil {
t.Fatalf("ExpireNegatives: %v", err)
}
neg, _ = db.IsNegative("expires-now")
if neg {
t.Error("expired negative should not be returned after ExpireNegatives")
}
}
func TestGetRouteByNarURL(t *testing.T) {
db, err := cache.Open(":memory:", 100)
if err != nil {
t.Fatal(err)
}
defer db.Close()
entry := &cache.RouteEntry{
StorePath: "abc123",
UpstreamURL: "https://cache.nixos.org",
NarURL: "nar/abc123.nar.xz",
TTL: time.Now().Add(time.Hour),
}
if err := db.SetRoute(entry); err != nil {
t.Fatalf("SetRoute: %v", err)
}
got, err := db.GetRouteByNarURL("nar/abc123.nar.xz")
if err != nil {
t.Fatalf("GetRouteByNarURL: %v", err)
}
if got == nil {
t.Fatal("expected non-nil entry")
}
if got.UpstreamURL != "https://cache.nixos.org" {
t.Errorf("UpstreamURL = %q", got.UpstreamURL)
}
// Non-existent NarURL returns nil.
got2, err := db.GetRouteByNarURL("nar/nonexistent.nar.xz")
if err != nil {
t.Fatalf("GetRouteByNarURL for missing: %v", err)
}
if got2 != nil {
t.Error("expected nil for missing NarURL")
}
// Expired entry must not be returned (tests the AND ttl > ? predicate).
expired := &cache.RouteEntry{
StorePath: "abc456",
UpstreamURL: "https://cache.nixos.org",
NarURL: "nar/abc456.nar.xz",
TTL: time.Now().Add(-time.Hour), // already in the past
}
if err := db.SetRoute(expired); err != nil {
t.Fatalf("SetRoute expired: %v", err)
}
got3, err := db.GetRouteByNarURL("nar/abc456.nar.xz")
if err != nil {
t.Fatalf("GetRouteByNarURL for expired: %v", err)
}
if got3 != nil {
t.Error("GetRouteByNarURL should return nil for an expired entry")
}
}
func TestLRUEviction(t *testing.T) {
// Use maxEntries=3 to trigger eviction easily
f, _ := os.CreateTemp("", "ncro-lru-*.db")
f.Close()
defer os.Remove(f.Name())
db, _ := cache.Open(f.Name(), 3)
defer db.Close()
for i := range 4 {
db.SetRoute(&cache.RouteEntry{
StorePath: "pkg-" + string(rune('a'+i)),
UpstreamURL: "https://cache.nixos.org",
LastVerified: time.Now().Add(time.Duration(i) * time.Second),
TTL: time.Now().Add(time.Hour),
})
}
count, err := db.RouteCount()
if err != nil {
t.Fatal(err)
}
if count > 3 {
t.Errorf("expected count <= 3 after LRU eviction, got %d", count)
}
}

View file

@ -1,220 +0,0 @@
package config
import (
"encoding/hex"
"fmt"
"net/url"
"os"
"strings"
"time"
"gopkg.in/yaml.v3"
)
// Wrapper around time.Duration supporting YAML duration strings ("30s", "1h").
// yaml.v3 cannot unmarshal duration strings directly into time.Duration (int64).
type Duration struct {
time.Duration
}
func (d *Duration) UnmarshalYAML(value *yaml.Node) error {
var s string
if err := value.Decode(&s); err != nil {
// Try decoding as a raw int64 (nanoseconds) as fallback.
var ns int64
if err2 := value.Decode(&ns); err2 != nil {
return fmt.Errorf("cannot unmarshal duration (tried string: %v): %w", err, err2)
}
d.Duration = time.Duration(ns)
return nil
}
parsed, err := time.ParseDuration(s)
if err != nil {
return fmt.Errorf("invalid duration %q: %w", s, err)
}
d.Duration = parsed
return nil
}
type UpstreamConfig struct {
URL string `yaml:"url"`
Priority int `yaml:"priority"`
PublicKey string `yaml:"public_key"` // Nix signing key "name:base64(key)"
}
type ServerConfig struct {
Listen string `yaml:"listen"`
ReadTimeout Duration `yaml:"read_timeout"`
WriteTimeout Duration `yaml:"write_timeout"`
CachePriority int `yaml:"cache_priority"`
}
type CacheConfig struct {
DBPath string `yaml:"db_path"`
MaxEntries int `yaml:"max_entries"`
TTL Duration `yaml:"ttl"`
NegativeTTL Duration `yaml:"negative_ttl"`
LatencyAlpha float64 `yaml:"latency_alpha"`
}
// Mesh peer with its ed25519 public key for gossip message verification.
type PeerConfig struct {
Addr string `yaml:"addr"`
PublicKey string `yaml:"public_key"` // hex-encoded ed25519 public key (32 bytes)
}
type MeshConfig struct {
Enabled bool `yaml:"enabled"`
BindAddr string `yaml:"bind_addr"`
Peers []PeerConfig `yaml:"peers"`
PrivateKeyPath string `yaml:"private_key"`
GossipInterval Duration `yaml:"gossip_interval"`
}
// Controls mDNS/DNS-SD based dynamic upstream discovery.
type DiscoveryConfig struct {
Enabled bool `yaml:"enabled"`
ServiceName string `yaml:"service_name"`
Domain string `yaml:"domain"`
DiscoveryTime Duration `yaml:"discovery_time"`
Priority int `yaml:"priority"`
}
type LoggingConfig struct {
Level string `yaml:"level"`
Format string `yaml:"format"`
}
type Config struct {
Server ServerConfig `yaml:"server"`
Upstreams []UpstreamConfig `yaml:"upstreams"`
Cache CacheConfig `yaml:"cache"`
Mesh MeshConfig `yaml:"mesh"`
Discovery DiscoveryConfig `yaml:"discovery"`
Logging LoggingConfig `yaml:"logging"`
}
func defaults() Config {
return Config{
Server: ServerConfig{
Listen: ":8080",
ReadTimeout: Duration{30 * time.Second},
WriteTimeout: Duration{30 * time.Second},
CachePriority: 30,
},
Upstreams: []UpstreamConfig{
{URL: "https://cache.nixos.org", Priority: 10},
},
Cache: CacheConfig{
DBPath: "/var/lib/ncro/routes.db",
MaxEntries: 100000,
TTL: Duration{time.Hour},
NegativeTTL: Duration{10 * time.Minute},
LatencyAlpha: 0.3,
},
Mesh: MeshConfig{
BindAddr: "0.0.0.0:7946",
GossipInterval: Duration{30 * time.Second},
},
Discovery: DiscoveryConfig{
ServiceName: "_nix-serve._tcp",
Domain: "local",
DiscoveryTime: Duration{5 * time.Second},
Priority: 20,
},
Logging: LoggingConfig{
Level: "info",
Format: "json",
},
}
}
// Validates config fields. Call after Load.
func (c *Config) Validate() error {
if len(c.Upstreams) == 0 {
return fmt.Errorf("at least one upstream is required")
}
for i, u := range c.Upstreams {
if u.URL == "" {
return fmt.Errorf("upstream[%d]: URL is empty", i)
}
if _, err := url.ParseRequestURI(u.URL); err != nil {
return fmt.Errorf("upstream[%d]: invalid URL %q: %w", i, u.URL, err)
}
if u.PublicKey != "" && !strings.Contains(u.PublicKey, ":") {
return fmt.Errorf("upstream[%d]: public_key must be in 'name:base64(key)' Nix format", i)
}
}
if c.Server.Listen == "" {
return fmt.Errorf("server.listen is empty")
}
if c.Server.CachePriority < 1 {
return fmt.Errorf("server.cache_priority must be >= 1, got %d", c.Server.CachePriority)
}
if c.Cache.LatencyAlpha <= 0 || c.Cache.LatencyAlpha >= 1 {
return fmt.Errorf("cache.latency_alpha must be between 0 and 1 exclusive, got %f", c.Cache.LatencyAlpha)
}
if c.Cache.TTL.Duration <= 0 {
return fmt.Errorf("cache.ttl must be positive")
}
if c.Cache.NegativeTTL.Duration <= 0 {
return fmt.Errorf("cache.negative_ttl must be positive")
}
if c.Cache.MaxEntries <= 0 {
return fmt.Errorf("cache.max_entries must be positive")
}
if c.Mesh.Enabled && len(c.Mesh.Peers) == 0 {
return fmt.Errorf("mesh.enabled is true but no peers configured")
}
for i, peer := range c.Mesh.Peers {
if peer.Addr == "" {
return fmt.Errorf("mesh.peers[%d]: addr is empty", i)
}
if peer.PublicKey != "" {
b, err := hex.DecodeString(peer.PublicKey)
if err != nil || len(b) != 32 {
return fmt.Errorf("mesh.peers[%d]: public_key must be a hex-encoded 32-byte ed25519 key", i)
}
}
}
if c.Discovery.Enabled {
if c.Discovery.ServiceName == "" {
return fmt.Errorf("discovery.service_name is required when discovery is enabled")
}
if c.Discovery.Domain == "" {
return fmt.Errorf("discovery.domain is required when discovery is enabled")
}
if c.Discovery.DiscoveryTime.Duration <= 0 {
return fmt.Errorf("discovery.discovery_time must be positive")
}
}
return nil
}
// Loads config from file (if non-empty) and applies env overrides.
func Load(path string) (*Config, error) {
cfg := defaults()
if path != "" {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, err
}
}
// Env overrides
if v := os.Getenv("NCRO_LISTEN"); v != "" {
cfg.Server.Listen = v
}
if v := os.Getenv("NCRO_DB_PATH"); v != "" {
cfg.Cache.DBPath = v
}
if v := os.Getenv("NCRO_LOG_LEVEL"); v != "" {
cfg.Logging.Level = v
}
return &cfg, nil
}

View file

@ -1,213 +0,0 @@
package config_test
import (
"os"
"testing"
"time"
"notashelf.dev/ncro/internal/config"
)
func TestLoadDefaults(t *testing.T) {
cfg, err := config.Load("")
if err != nil {
t.Fatalf("Load(\"\") error: %v", err)
}
if cfg.Server.Listen != ":8080" {
t.Errorf("default listen = %q, want :8080", cfg.Server.Listen)
}
if len(cfg.Upstreams) == 0 {
t.Error("expected at least one default upstream")
}
if cfg.Cache.MaxEntries != 100000 {
t.Errorf("default max_entries = %d, want 100000", cfg.Cache.MaxEntries)
}
}
func TestLoadFromYAML(t *testing.T) {
yamlContent := `
server:
listen: ":9090"
upstreams:
- url: "https://cache.nixos.org"
priority: 10
cache:
db_path: "/tmp/test.db"
max_entries: 500
`
f, _ := os.CreateTemp("", "ncro-*.yaml")
defer os.Remove(f.Name())
f.WriteString(yamlContent)
f.Close()
cfg, err := config.Load(f.Name())
if err != nil {
t.Fatalf("Load error: %v", err)
}
if cfg.Server.Listen != ":9090" {
t.Errorf("listen = %q, want :9090", cfg.Server.Listen)
}
if cfg.Cache.MaxEntries != 500 {
t.Errorf("max_entries = %d, want 500", cfg.Cache.MaxEntries)
}
}
func TestEnvOverride(t *testing.T) {
t.Setenv("NCRO_LISTEN", ":1234")
cfg, err := config.Load("")
if err != nil {
t.Fatalf("Load error: %v", err)
}
if cfg.Server.Listen != ":1234" {
t.Errorf("env override listen = %q, want :1234", cfg.Server.Listen)
}
}
func TestDurationParsing(t *testing.T) {
yamlContent := `
server:
listen: ":8080"
read_timeout: 30s
write_timeout: 1m
cache:
ttl: 2h
mesh:
gossip_interval: 45s
`
f, _ := os.CreateTemp("", "ncro-dur-*.yaml")
defer os.Remove(f.Name())
f.WriteString(yamlContent)
f.Close()
cfg, err := config.Load(f.Name())
if err != nil {
t.Fatalf("Load error: %v", err)
}
if cfg.Server.ReadTimeout.Duration != 30*time.Second {
t.Errorf("read_timeout = %v, want 30s", cfg.Server.ReadTimeout.Duration)
}
if cfg.Server.WriteTimeout.Duration != time.Minute {
t.Errorf("write_timeout = %v, want 1m", cfg.Server.WriteTimeout.Duration)
}
if cfg.Cache.TTL.Duration != 2*time.Hour {
t.Errorf("ttl = %v, want 2h", cfg.Cache.TTL.Duration)
}
if cfg.Mesh.GossipInterval.Duration != 45*time.Second {
t.Errorf("gossip_interval = %v, want 45s", cfg.Mesh.GossipInterval.Duration)
}
}
func TestValidateValid(t *testing.T) {
cfg, _ := config.Load("")
if err := cfg.Validate(); err != nil {
t.Errorf("default config should be valid: %v", err)
}
}
func TestValidateNoUpstreams(t *testing.T) {
cfg, _ := config.Load("")
cfg.Upstreams = nil
if err := cfg.Validate(); err == nil {
t.Error("expected error for no upstreams")
}
}
func TestValidateBadURL(t *testing.T) {
cfg, _ := config.Load("")
cfg.Upstreams = []config.UpstreamConfig{{URL: "not-a-url"}}
if err := cfg.Validate(); err == nil {
t.Error("expected error for invalid URL")
}
}
func TestValidateBadAlpha(t *testing.T) {
cfg, _ := config.Load("")
cfg.Cache.LatencyAlpha = 0
if err := cfg.Validate(); err == nil {
t.Error("expected error for alpha=0")
}
cfg.Cache.LatencyAlpha = 1
if err := cfg.Validate(); err == nil {
t.Error("expected error for alpha=1")
}
}
func TestValidateZeroTTL(t *testing.T) {
cfg, _ := config.Load("")
cfg.Cache.TTL = config.Duration{}
if err := cfg.Validate(); err == nil {
t.Error("expected error for zero TTL")
}
}
func TestValidateNegativeTTL(t *testing.T) {
cfg, _ := config.Load("")
cfg.Cache.NegativeTTL = config.Duration{}
if err := cfg.Validate(); err == nil {
t.Error("expected error for zero negative_ttl")
}
}
func TestValidateMeshEnabledNoPeers(t *testing.T) {
cfg, _ := config.Load("")
cfg.Mesh.Enabled = true
cfg.Mesh.Peers = nil
if err := cfg.Validate(); err == nil {
t.Error("expected error for mesh enabled without peers")
}
}
func TestValidateMeshBadPeerKey(t *testing.T) {
cfg, _ := config.Load("")
cfg.Mesh.Enabled = true
cfg.Mesh.Peers = []config.PeerConfig{
{Addr: "127.0.0.1:7946", PublicKey: "not-hex!"},
}
if err := cfg.Validate(); err == nil {
t.Error("expected error for invalid mesh peer public key")
}
}
func TestValidateUpstreamBadPublicKey(t *testing.T) {
cfg, _ := config.Load("")
cfg.Upstreams = []config.UpstreamConfig{
{URL: "https://cache.nixos.org", PublicKey: "no-colon-here"},
}
if err := cfg.Validate(); err == nil {
t.Error("expected error for upstream public_key missing ':'")
}
}
func TestCachePriorityDefault(t *testing.T) {
cfg, err := config.Load("")
if err != nil {
t.Fatal(err)
}
if cfg.Server.CachePriority != 30 {
t.Errorf("default CachePriority = %d, want 30", cfg.Server.CachePriority)
}
}
func TestCachePriorityValidation(t *testing.T) {
cfg, _ := config.Load("")
cfg.Server.CachePriority = 0
if err := cfg.Validate(); err == nil {
t.Error("expected error for CachePriority = 0")
}
}
func TestInvalidDuration(t *testing.T) {
yamlContent := `
server:
read_timeout: "bananas"
`
f, _ := os.CreateTemp("", "ncro-bad-*.yaml")
defer os.Remove(f.Name())
f.WriteString(yamlContent)
f.Close()
_, err := config.Load(f.Name())
if err == nil {
t.Error("expected error for invalid duration string, got nil")
}
}

View file

@ -1,218 +0,0 @@
package discovery
import (
"context"
"fmt"
"log/slog"
"net"
"sync"
"time"
"github.com/grandcat/zeroconf"
"notashelf.dev/ncro/internal/config"
)
// Tracks discovered nix-serve instances and maintains the upstream list.
type Discovery struct {
cfg config.DiscoveryConfig
resolver *zeroconf.Resolver
discovered map[string]*discoveredPeer
mu sync.RWMutex
stopCh chan struct{}
stopOnce sync.Once
waitGroup sync.WaitGroup
onAddUpstream func(url string, priority int)
onRemoveUpstream func(url string)
}
type discoveredPeer struct {
url string
lastSeen time.Time
}
// Creates a new Discovery manager.
func New(cfg config.DiscoveryConfig) (*Discovery, error) {
resolver, err := zeroconf.NewResolver(nil)
if err != nil {
return nil, fmt.Errorf("create zeroconf resolver: %w", err)
}
return &Discovery{
cfg: cfg,
resolver: resolver,
discovered: make(map[string]*discoveredPeer),
stopCh: make(chan struct{}),
}, nil
}
// Sets callbacks for upstream addition/removal. These are invoked when peers
// are discovered or leave the network.
func (d *Discovery) SetCallbacks(
add func(url string, priority int),
remove func(url string),
) {
d.mu.Lock()
defer d.mu.Unlock()
d.onAddUpstream = add
d.onRemoveUpstream = remove
}
// Starts browsing for services on the local network. Blocks until the context
// is cancelled or Stop is called.
func (d *Discovery) Start(ctx context.Context) error {
entries := make(chan *zeroconf.ServiceEntry)
d.waitGroup.Add(1)
go d.handleEntries(ctx, entries)
d.waitGroup.Add(1)
go d.maintainPeers(ctx)
if err := d.resolver.Browse(ctx, d.cfg.ServiceName, d.cfg.Domain, entries); err != nil {
close(entries)
d.stopOnce.Do(func() { close(d.stopCh) })
d.waitGroup.Wait()
return fmt.Errorf("browse services: %w", err)
}
select {
case <-ctx.Done():
return ctx.Err()
case <-d.stopCh:
return nil
}
}
// Stops the discovery process.
func (d *Discovery) Stop() {
d.stopOnce.Do(func() { close(d.stopCh) })
d.waitGroup.Wait()
}
// Processes discovered service entries.
func (d *Discovery) handleEntries(ctx context.Context, entries chan *zeroconf.ServiceEntry) {
defer d.waitGroup.Done()
for {
select {
case <-ctx.Done():
return
case <-d.stopCh:
return
case entry, ok := <-entries:
if !ok {
return
}
d.handleEntry(ctx, entry)
}
}
}
// Handles a single service entry.
func (d *Discovery) handleEntry(_ context.Context, entry *zeroconf.ServiceEntry) {
if len(entry.AddrIPv4) == 0 && len(entry.AddrIPv6) == 0 {
slog.Debug("discovered service has no addresses", "instance", entry.Instance)
return
}
var addr string
if len(entry.AddrIPv4) > 0 {
addr = entry.AddrIPv4[0].String()
} else {
addr = entry.AddrIPv6[0].String()
}
url := "http://" + net.JoinHostPort(addr, fmt.Sprintf("%d", entry.Port))
key := fmt.Sprintf("%s@%s", entry.Instance, entry.HostName)
d.mu.Lock()
defer d.mu.Unlock()
// Check if we already know this peer
if _, exists := d.discovered[key]; exists {
d.discovered[key].lastSeen = time.Now()
return
}
// New peer discovered
slog.Info("discovered nix-serve instance", "instance", entry.Instance, "url", url)
d.discovered[key] = &discoveredPeer{
url: url,
lastSeen: time.Now(),
}
// Notify callback if set
if d.onAddUpstream != nil {
go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in add upstream callback", "recover", r)
}
}()
d.onAddUpstream(url, d.cfg.Priority)
}()
}
}
// Removes peers that haven't been seen within the TTL period.
func (d *Discovery) maintainPeers(ctx context.Context) {
defer d.waitGroup.Done()
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-d.stopCh:
return
case <-ticker.C:
d.cleanupPeers()
}
}
}
// Cleans up stale peer entries.
func (d *Discovery) cleanupPeers() {
d.mu.Lock()
defer d.mu.Unlock()
now := time.Now()
// TTL is the discovery response time; peers should re-announce periodically.
// Use 3x TTL as the expiration window.
expiration := d.cfg.DiscoveryTime.Duration * 3
if expiration == 0 {
expiration = 30 * time.Second
}
for key, peer := range d.discovered {
if now.Sub(peer.lastSeen) > expiration {
slog.Info("removing stale peer", "url", peer.url)
delete(d.discovered, key)
if d.onRemoveUpstream != nil {
go func(url string) {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in remove upstream callback", "recover", r)
}
}()
d.onRemoveUpstream(url)
}(peer.url)
}
}
}
}
// Returns a list of currently discovered peer URLs.
func (d *Discovery) DiscoveredPeers() []string {
d.mu.RLock()
defer d.mu.RUnlock()
peers := make([]string, 0, len(d.discovered))
for _, peer := range d.discovered {
peers = append(peers, peer.url)
}
return peers
}

View file

@ -1,152 +0,0 @@
package mesh
import (
"bytes"
"crypto/ed25519"
"fmt"
"log/slog"
"net"
"time"
"github.com/vmihailenco/msgpack/v5"
"notashelf.dev/ncro/internal/cache"
)
const (
maxPacketSize = 65536 // UDP max payload
headerSize = ed25519.PublicKeySize + ed25519.SignatureSize // 32 + 64 = 96
)
// Wire format: [32-byte sender pubkey][64-byte ed25519 sig][msgpack body]
func encodePacket(node *Node, msg Message) ([]byte, error) {
body, sig, err := node.Sign(msg)
if err != nil {
return nil, err
}
pkt := make([]byte, headerSize+len(body))
copy(pkt[:ed25519.PublicKeySize], node.PublicKey())
copy(pkt[ed25519.PublicKeySize:headerSize], sig)
copy(pkt[headerSize:], body)
return pkt, nil
}
func decodePacket(pkt []byte) (pubKey ed25519.PublicKey, sig, body []byte, msg Message, err error) {
if len(pkt) < headerSize {
return nil, nil, nil, Message{}, fmt.Errorf("packet too short: %d bytes", len(pkt))
}
pubKey = ed25519.PublicKey(pkt[:ed25519.PublicKeySize])
sig = pkt[ed25519.PublicKeySize:headerSize]
body = pkt[headerSize:]
if err := msgpack.Unmarshal(body, &msg); err != nil {
return nil, nil, nil, Message{}, fmt.Errorf("unmarshal: %w", err)
}
return pubKey, sig, body, msg, nil
}
// Starts a UDP listener at addr. All messages are signature-verified.
// When allowedKeys is non-empty, messages from unlisted senders are dropped.
// Pass no keys (or an empty list) to accept messages from any sender.
func ListenAndServe(addr string, store *RouteStore, allowedKeys ...ed25519.PublicKey) error {
conn, err := net.ListenPacket("udp", addr)
if err != nil {
return err
}
go func() {
defer conn.Close()
buf := make([]byte, maxPacketSize)
for {
n, src, err := conn.ReadFrom(buf)
if err != nil {
return
}
pubKey, sig, body, msg, err := decodePacket(buf[:n])
if err != nil {
slog.Warn("mesh: malformed packet", "src", src, "error", err)
continue
}
if len(allowedKeys) > 0 {
allowed := false
for _, k := range allowedKeys {
if bytes.Equal(k, pubKey) {
allowed = true
break
}
}
if !allowed {
slog.Warn("mesh: rejecting packet from unknown sender", "src", src)
continue
}
}
if err := Verify(pubKey, body, sig); err != nil {
slog.Warn("mesh: signature verification failed", "src", src, "error", err)
continue
}
if msg.Type == MsgAnnounce && len(msg.Routes) > 0 {
store.Merge(msg.Routes)
slog.Debug("mesh: merged peer routes", "node", msg.NodeID, "src", src, "count", len(msg.Routes))
}
}
}()
return nil
}
// Sends an MsgAnnounce carrying routes to a single peer address.
func Announce(peerAddr string, node *Node, routes []cache.RouteEntry) error {
msg := Message{
Type: MsgAnnounce,
NodeID: node.ID(),
Timestamp: time.Now().UnixNano(),
Routes: routes,
}
pkt, err := encodePacket(node, msg)
if err != nil {
return err
}
addr, err := net.ResolveUDPAddr("udp", peerAddr)
if err != nil {
return err
}
conn, err := net.DialUDP("udp", nil, addr)
if err != nil {
return err
}
defer conn.Close()
conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Write(pkt)
return err
}
// RouteSource retrieves routes to gossip.
type RouteSource interface {
ListRecentRoutes(n int) ([]cache.RouteEntry, error)
}
// Announces our top routes to each peer on interval. Blocks until stop is closed.
func RunGossipLoop(node *Node, src RouteSource, peers []string, interval time.Duration, stop <-chan struct{}) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-stop:
return
case <-ticker.C:
routes, err := src.ListRecentRoutes(100)
if err != nil {
slog.Warn("mesh: failed to list routes for gossip", "error", err)
continue
}
if len(routes) == 0 {
continue
}
for _, peer := range peers {
go func(p string) {
if err := Announce(p, node, routes); err != nil {
slog.Warn("mesh: announce failed", "peer", p, "error", err)
}
}(peer)
}
slog.Debug("mesh: announced routes to peers", "routes", len(routes), "peers", len(peers))
}
}
}

View file

@ -1,117 +0,0 @@
package mesh_test
import (
"net"
"testing"
"time"
"notashelf.dev/ncro/internal/cache"
"notashelf.dev/ncro/internal/mesh"
)
func freeUDPAddr(t *testing.T) string {
t.Helper()
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
addr := conn.LocalAddr().String()
conn.Close()
return addr
}
func TestAnnounceAndReceive(t *testing.T) {
store := mesh.NewRouteStore()
node, err := mesh.NewNode("", store)
if err != nil {
t.Fatal(err)
}
addr := freeUDPAddr(t)
// Allow messages from our own node (its public key is the only allowed key).
if err := mesh.ListenAndServe(addr, store, node.PublicKey()); err != nil {
t.Fatalf("ListenAndServe: %v", err)
}
routes := []cache.RouteEntry{
{
StorePath: "test-pkg-abc",
UpstreamURL: "https://cache.nixos.org",
LatencyEMA: 25,
TTL: time.Now().Add(time.Hour),
},
}
if err := mesh.Announce(addr, node, routes); err != nil {
t.Fatalf("Announce: %v", err)
}
time.Sleep(50 * time.Millisecond)
entry := store.Get("test-pkg-abc")
if entry == nil {
t.Fatal("route not merged into store after announce")
}
if entry.UpstreamURL != "https://cache.nixos.org" {
t.Errorf("UpstreamURL = %q", entry.UpstreamURL)
}
}
func TestRejectUnknownSender(t *testing.T) {
store := mesh.NewRouteStore()
// Listener node, this'll reject messages not from trusted
trusted, err := mesh.NewNode("", nil)
if err != nil {
t.Fatal(err)
}
// Untrusted sender
untrusted, err := mesh.NewNode("", nil)
if err != nil {
t.Fatal(err)
}
addr := freeUDPAddr(t)
// Only allow trusted node's key.
if err := mesh.ListenAndServe(addr, store, trusted.PublicKey()); err != nil {
t.Fatalf("ListenAndServe: %v", err)
}
routes := []cache.RouteEntry{
{StorePath: "untrusted-pkg", UpstreamURL: "https://evil.example.com",
TTL: time.Now().Add(time.Hour)},
}
mesh.Announce(addr, untrusted, routes)
time.Sleep(50 * time.Millisecond)
if entry := store.Get("untrusted-pkg"); entry != nil {
t.Error("route from untrusted sender should have been rejected")
}
}
func TestRejectTamperedMessage(t *testing.T) {
// This is covered by TestVerifyFailsOnTamper the mesh tests on the crypto level.
// Here we verify the full pipeline rejects a re-signed-but-tampered body.
store := mesh.NewRouteStore()
node, err := mesh.NewNode("", store)
if err != nil {
t.Fatal(err)
}
addr := freeUDPAddr(t)
if err := mesh.ListenAndServe(addr, store, node.PublicKey()); err != nil {
t.Fatalf("ListenAndServe: %v", err)
}
// Send a valid message first to confirm it works.
routes := []cache.RouteEntry{
{StorePath: "legit-pkg", UpstreamURL: "https://cache.nixos.org",
TTL: time.Now().Add(time.Hour)},
}
mesh.Announce(addr, node, routes)
time.Sleep(50 * time.Millisecond)
if store.Get("legit-pkg") == nil {
t.Fatal("valid message should have been accepted")
}
}

View file

@ -1,152 +0,0 @@
package mesh
import (
"crypto/ed25519"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"os"
"sync"
"time"
"github.com/vmihailenco/msgpack/v5"
"notashelf.dev/ncro/internal/cache"
)
// Gossip message types.
type MsgType uint8
const (
MsgAnnounce MsgType = 1
)
// Wire format for gossip messages.
type Message struct {
Type MsgType
NodeID string
Timestamp int64
Routes []cache.RouteEntry
}
// Cryptographic identity of an ncro node.
type Node struct {
pubKey ed25519.PublicKey
privKey ed25519.PrivateKey
store *RouteStore
}
// Loads or generates an ed25519 keypair from keyPath.
// Pass "" for an ephemeral in-memory key.
func NewNode(keyPath string, store *RouteStore) (*Node, error) {
if store == nil {
store = NewRouteStore()
}
if keyPath == "" {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate key: %w", err)
}
return &Node{pubKey: pub, privKey: priv, store: store}, nil
}
data, err := os.ReadFile(keyPath)
if err == nil && len(data) == ed25519.PrivateKeySize {
priv := ed25519.PrivateKey(data)
return &Node{pubKey: priv.Public().(ed25519.PublicKey), privKey: priv, store: store}, nil
}
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate key: %w", err)
}
if err := os.WriteFile(keyPath, priv, 0600); err != nil {
return nil, fmt.Errorf("write key: %w", err)
}
return &Node{pubKey: pub, privKey: priv, store: store}, nil
}
// Returns the hex-encoded public key fingerprint.
func (n *Node) ID() string {
return hex.EncodeToString(n.pubKey[:8])
}
// Returns the node's public key.
func (n *Node) PublicKey() ed25519.PublicKey {
return n.pubKey
}
// Serializes msg with msgpack and signs it; returns (data, signature, error).
func (n *Node) Sign(msg Message) ([]byte, []byte, error) {
data, err := msgpack.Marshal(msg)
if err != nil {
return nil, nil, err
}
return data, ed25519.Sign(n.privKey, data), nil
}
// Checks that sig is a valid ed25519 signature of data under pubKey.
func Verify(pubKey ed25519.PublicKey, data, sig []byte) error {
if !ed25519.Verify(pubKey, data, sig) {
return errors.New("invalid signature")
}
return nil
}
// In-memory route table with merge-conflict resolution for gossip.
type RouteStore struct {
mu sync.RWMutex
routes map[string]*cache.RouteEntry
}
// Creates an empty RouteStore.
func NewRouteStore() *RouteStore {
return &RouteStore{routes: make(map[string]*cache.RouteEntry)}
}
// Applies incoming routes: lower latency wins; newer LastVerified wins on tie.
func (rs *RouteStore) Merge(incoming []cache.RouteEntry) {
rs.mu.Lock()
defer rs.mu.Unlock()
now := time.Now()
for _, r := range incoming {
r := r
if r.TTL.Before(now) {
continue
}
existing, ok := rs.routes[r.StorePath]
if !ok {
rs.routes[r.StorePath] = &r
continue
}
if r.LatencyEMA < existing.LatencyEMA {
rs.routes[r.StorePath] = &r
} else if r.LatencyEMA == existing.LatencyEMA && r.LastVerified.After(existing.LastVerified) {
rs.routes[r.StorePath] = &r
}
}
}
// Returns a copy of the stored route, or nil.
func (rs *RouteStore) Get(storePath string) *cache.RouteEntry {
rs.mu.RLock()
defer rs.mu.RUnlock()
e, ok := rs.routes[storePath]
if !ok {
return nil
}
cp := *e
return &cp
}
// Returns up to n routes for sync batching.
func (rs *RouteStore) Top(n int) []cache.RouteEntry {
rs.mu.RLock()
defer rs.mu.RUnlock()
result := make([]cache.RouteEntry, 0, min(n, len(rs.routes)))
for _, e := range rs.routes {
result = append(result, *e)
if len(result) >= n {
break
}
}
return result
}

View file

@ -1,75 +0,0 @@
package mesh_test
import (
"testing"
"time"
"notashelf.dev/ncro/internal/cache"
"notashelf.dev/ncro/internal/mesh"
)
func TestSignVerify(t *testing.T) {
node, err := mesh.NewNode("", nil)
if err != nil {
t.Fatal(err)
}
msg := mesh.Message{
Type: mesh.MsgAnnounce,
NodeID: node.ID(),
Timestamp: time.Now().UnixNano(),
Routes: []cache.RouteEntry{{StorePath: "abc123", UpstreamURL: "https://cache.nixos.org"}},
}
data, sig, err := node.Sign(msg)
if err != nil {
t.Fatalf("Sign: %v", err)
}
if err := mesh.Verify(node.PublicKey(), data, sig); err != nil {
t.Errorf("Verify: %v", err)
}
}
func TestVerifyFailsOnTamper(t *testing.T) {
node, _ := mesh.NewNode("", nil)
msg := mesh.Message{Type: mesh.MsgAnnounce, NodeID: node.ID()}
data, sig, _ := node.Sign(msg)
data[0] ^= 0xFF
if err := mesh.Verify(node.PublicKey(), data, sig); err == nil {
t.Error("expected verification failure on tampered data")
}
}
func TestMergeLowerLatencyWins(t *testing.T) {
store := mesh.NewRouteStore()
store.Merge([]cache.RouteEntry{
{StorePath: "pkg-a", UpstreamURL: "https://slow.example.com", LatencyEMA: 200, TTL: time.Now().Add(time.Hour)},
})
store.Merge([]cache.RouteEntry{
{StorePath: "pkg-a", UpstreamURL: "https://fast.example.com", LatencyEMA: 10, TTL: time.Now().Add(time.Hour)},
})
entry := store.Get("pkg-a")
if entry == nil {
t.Fatal("entry is nil")
}
if entry.UpstreamURL != "https://fast.example.com" {
t.Errorf("expected fast upstream, got %q", entry.UpstreamURL)
}
}
func TestMergeNewerTimestampWinsOnTie(t *testing.T) {
store := mesh.NewRouteStore()
now := time.Now()
store.Merge([]cache.RouteEntry{
{StorePath: "pkg-b", UpstreamURL: "https://a.example.com", LatencyEMA: 50, LastVerified: now.Add(-time.Minute), TTL: time.Now().Add(time.Hour)},
})
store.Merge([]cache.RouteEntry{
{StorePath: "pkg-b", UpstreamURL: "https://b.example.com", LatencyEMA: 50, LastVerified: now, TTL: time.Now().Add(time.Hour)},
})
entry := store.Get("pkg-b")
if entry.UpstreamURL != "https://b.example.com" {
t.Errorf("expected newer upstream, got %q", entry.UpstreamURL)
}
}

View file

@ -1,61 +0,0 @@
package metrics
import "github.com/prometheus/client_golang/prometheus"
var (
// Narinfo requests served from the route cache.
NarinfoCacheHits = prometheus.NewCounter(prometheus.CounterOpts{
Name: "ncro_narinfo_cache_hits_total",
Help: "Narinfo requests served from route cache.",
})
// Narinfo requests that required an upstream race.
NarinfoCacheMisses = prometheus.NewCounter(prometheus.CounterOpts{
Name: "ncro_narinfo_cache_misses_total",
Help: "Narinfo requests requiring upstream race.",
})
// Narinfo requests by HTTP status code.
NarinfoRequests = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "ncro_narinfo_requests_total",
Help: "Narinfo requests by status.",
}, []string{"status"})
// NAR streaming requests.
NARRequests = prometheus.NewCounter(prometheus.CounterOpts{
Name: "ncro_nar_requests_total",
Help: "NAR streaming requests.",
})
// Times each upstream won the narinfo race.
UpstreamRaceWins = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "ncro_upstream_race_wins_total",
Help: "Times each upstream won the narinfo race.",
}, []string{"upstream"})
// Current number of route entries in SQLite.
RouteEntries = prometheus.NewGauge(prometheus.GaugeOpts{
Name: "ncro_route_entries",
Help: "Current number of route entries in SQLite.",
})
// Upstream narinfo race latency in seconds.
UpstreamLatency = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "ncro_upstream_latency_seconds",
Help: "Upstream narinfo race latency.",
Buckets: prometheus.DefBuckets,
}, []string{"upstream"})
)
// Registers all metrics with reg.
func Register(reg prometheus.Registerer) {
reg.MustRegister(
NarinfoCacheHits,
NarinfoCacheMisses,
NarinfoRequests,
NARRequests,
UpstreamRaceWins,
RouteEntries,
UpstreamLatency,
)
}

View file

@ -1,139 +0,0 @@
package narinfo
import (
"bufio"
"crypto/ed25519"
"encoding/base64"
"fmt"
"io"
"strconv"
"strings"
)
// Parsed representation of a Nix narinfo file.
type NarInfo struct {
StorePath string
URL string
Compression string
FileHash string
FileSize uint64
NarHash string
NarSize uint64
References []string
Deriver string
Sig []string
CA string
}
// ParsePublicKey parses a Nix public key in "name:base64(key)" format.
func ParsePublicKey(s string) (name string, key ed25519.PublicKey, err error) {
name, b64, ok := strings.Cut(s, ":")
if !ok || name == "" {
return "", nil, fmt.Errorf("invalid public key %q: missing ':'", s)
}
raw, err := base64.StdEncoding.DecodeString(b64)
if err != nil {
return "", nil, fmt.Errorf("invalid public key %q: %w", s, err)
}
if len(raw) != ed25519.PublicKeySize {
return "", nil, fmt.Errorf("invalid public key size %d, want %d", len(raw), ed25519.PublicKeySize)
}
return name, ed25519.PublicKey(raw), nil
}
// Fingerprint returns the canonical signing input for this narinfo.
// Format: 1;<storePath>;<narHash>;<narSize>;<comma-separated-full-ref-paths>
func (ni *NarInfo) Fingerprint() string {
refs := make([]string, len(ni.References))
for i, r := range ni.References {
if strings.HasPrefix(r, "/nix/store/") {
refs[i] = r
} else {
refs[i] = "/nix/store/" + r
}
}
return fmt.Sprintf("1;%s;%s;%d;%s",
ni.StorePath, ni.NarHash, ni.NarSize, strings.Join(refs, ","))
}
// Verify checks that at least one Sig line is a valid signature for pubKeyStr.
// pubKeyStr must be in "name:base64(key)" Nix format.
// Returns false (not an error) when no matching Sig line is found.
func (ni *NarInfo) Verify(pubKeyStr string) (bool, error) {
keyName, key, err := ParsePublicKey(pubKeyStr)
if err != nil {
return false, err
}
fp := []byte(ni.Fingerprint())
for _, sigLine := range ni.Sig {
name, b64, ok := strings.Cut(sigLine, ":")
if !ok || name != keyName {
continue
}
sig, err := base64.StdEncoding.DecodeString(b64)
if err != nil || len(sig) != ed25519.SignatureSize {
continue
}
if ed25519.Verify(key, fp, sig) {
return true, nil
}
}
return false, nil
}
// Parses a narinfo from r. Returns error on malformed input or missing StorePath.
func Parse(r io.Reader) (*NarInfo, error) {
ni := &NarInfo{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
continue
}
k, v, ok := strings.Cut(line, ": ")
if !ok {
return nil, fmt.Errorf("malformed line: %q", line)
}
switch k {
case "StorePath":
ni.StorePath = v
case "URL":
ni.URL = v
case "Compression":
ni.Compression = v
case "FileHash":
ni.FileHash = v
case "FileSize":
n, err := strconv.ParseUint(v, 10, 64)
if err != nil {
return nil, fmt.Errorf("FileSize: %w", err)
}
ni.FileSize = n
case "NarHash":
ni.NarHash = v
case "NarSize":
n, err := strconv.ParseUint(v, 10, 64)
if err != nil {
return nil, fmt.Errorf("NarSize: %w", err)
}
ni.NarSize = n
case "References":
if v != "" {
ni.References = strings.Fields(v)
}
case "Deriver":
ni.Deriver = v
case "Sig":
ni.Sig = append(ni.Sig, v)
case "CA":
ni.CA = v
}
}
if err := scanner.Err(); err != nil {
return nil, err
}
if ni.StorePath == "" {
return nil, fmt.Errorf("missing StorePath")
}
return ni, nil
}

View file

@ -1,318 +0,0 @@
package narinfo_test
import (
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"strings"
"testing"
"notashelf.dev/ncro/internal/narinfo"
)
var realWorldNarinfo = `StorePath: /nix/store/s66mzxpvicwklp6cpph4dc53k5l6bfhe-hello-2.12.1
URL: nar/1wwh37nhg4f5zhb2vsn1a81p3ixn69gkg5k6fvmw3nhcn19fg8xj.nar.xz
Compression: xz
FileHash: sha256:1wwh37nhg4f5zhb2vsn1a81p3ixn69gkg5k6fvmw3nhcn19fg8xj
FileSize: 50088
NarHash: sha256:04rrn5x6lnzrfkcy3bh7gf7x6hq3w1kap4wdss2n6n4s19pgbkr7
NarSize: 226512
References: s66mzxpvicwklp6cpph4dc53k5l6bfhe-hello-2.12.1 4nlgxhzzvsnr6bva0b9afnq8lbr9rk2b-glibc-2.38-23
Sig: cache.nixos.org-1:abc123+base64signature=
`
func TestParseRealWorld(t *testing.T) {
ni, err := narinfo.Parse(strings.NewReader(realWorldNarinfo))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if ni.StorePath != "/nix/store/s66mzxpvicwklp6cpph4dc53k5l6bfhe-hello-2.12.1" {
t.Errorf("StorePath = %q", ni.StorePath)
}
if ni.URL != "nar/1wwh37nhg4f5zhb2vsn1a81p3ixn69gkg5k6fvmw3nhcn19fg8xj.nar.xz" {
t.Errorf("URL = %q", ni.URL)
}
if ni.Compression != "xz" {
t.Errorf("Compression = %q, want xz", ni.Compression)
}
if ni.FileSize != 50088 {
t.Errorf("FileSize = %d, want 50088", ni.FileSize)
}
if ni.NarHash != "sha256:04rrn5x6lnzrfkcy3bh7gf7x6hq3w1kap4wdss2n6n4s19pgbkr7" {
t.Errorf("NarHash = %q", ni.NarHash)
}
if ni.NarSize != 226512 {
t.Errorf("NarSize = %d, want 226512", ni.NarSize)
}
if len(ni.References) != 2 {
t.Errorf("References len = %d, want 2", len(ni.References))
}
if len(ni.Sig) != 1 {
t.Errorf("Sig len = %d, want 1", len(ni.Sig))
}
}
func TestParseNoneCompression(t *testing.T) {
input := "StorePath: /nix/store/abc-test\nURL: nar/abc.nar\nCompression: none\n"
ni, err := narinfo.Parse(strings.NewReader(input))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if ni.Compression != "none" {
t.Errorf("Compression = %q, want none", ni.Compression)
}
}
func TestParseMultipleReferences(t *testing.T) {
input := "StorePath: /nix/store/abc-test\nReferences: pkg-a pkg-b pkg-c pkg-d\n"
ni, err := narinfo.Parse(strings.NewReader(input))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if len(ni.References) != 4 {
t.Errorf("References = %v, want 4 entries", ni.References)
}
}
func TestParseEmptyReferences(t *testing.T) {
input := "StorePath: /nix/store/abc-test\nReferences: \n"
ni, err := narinfo.Parse(strings.NewReader(input))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if len(ni.References) != 0 {
t.Errorf("References = %v, want empty", ni.References)
}
}
func TestParseMultipleSigs(t *testing.T) {
input := "StorePath: /nix/store/abc-test\nSig: key1:aaa=\nSig: key2:bbb=\n"
ni, err := narinfo.Parse(strings.NewReader(input))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if len(ni.Sig) != 2 {
t.Errorf("Sig len = %d, want 2", len(ni.Sig))
}
if ni.Sig[0] != "key1:aaa=" || ni.Sig[1] != "key2:bbb=" {
t.Errorf("Sig = %v", ni.Sig)
}
}
func TestParseMissingStorePath(t *testing.T) {
input := "URL: nar/abc.nar\nNarHash: sha256:abc\n"
_, err := narinfo.Parse(strings.NewReader(input))
if err == nil {
t.Error("expected error for missing StorePath")
}
}
func TestParseMalformedLine(t *testing.T) {
input := "StorePath: /nix/store/abc-test\nbadline\n"
_, err := narinfo.Parse(strings.NewReader(input))
if err == nil {
t.Error("expected error for malformed line")
}
}
func TestParseNarSizeOverflow(t *testing.T) {
input := "StorePath: /nix/store/abc-test\nNarSize: 18446744073709551615\n"
ni, err := narinfo.Parse(strings.NewReader(input))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if ni.NarSize != 18446744073709551615 {
t.Errorf("NarSize = %d", ni.NarSize)
}
}
func TestParseDeriverCA(t *testing.T) {
input := "StorePath: /nix/store/abc-test\nDeriver: abc-drv\nCA: fixed:r:sha256:abc\n"
ni, err := narinfo.Parse(strings.NewReader(input))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if ni.Deriver != "abc-drv" {
t.Errorf("Deriver = %q", ni.Deriver)
}
if ni.CA != "fixed:r:sha256:abc" {
t.Errorf("CA = %q", ni.CA)
}
}
func TestParseIgnoresBlankLines(t *testing.T) {
input := "\n\nStorePath: /nix/store/abc-test\n\nNarHash: sha256:abc\n\n"
ni, err := narinfo.Parse(strings.NewReader(input))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if ni.StorePath == "" {
t.Error("StorePath should be set")
}
}
func TestParseInvalidNarSize(t *testing.T) {
input := "StorePath: /nix/store/abc-test\nNarSize: not-a-number\n"
_, err := narinfo.Parse(strings.NewReader(input))
if err == nil {
t.Error("expected error for invalid NarSize")
}
}
func TestParseInvalidFileSize(t *testing.T) {
input := "StorePath: /nix/store/abc-test\nFileSize: not-a-number\n"
_, err := narinfo.Parse(strings.NewReader(input))
if err == nil {
t.Error("expected error for invalid FileSize")
}
}
// Fingerprint and signature verification
func TestFingerprint(t *testing.T) {
ni := &narinfo.NarInfo{
StorePath: "/nix/store/s66mzxpvicwklp6cpph4dc53k5l6bfhe-hello-2.12.1",
NarHash: "sha256:04rrn5x6lnzrfkcy3bh7gf7x6hq3w1kap4wdss2n6n4s19pgbkr7",
NarSize: 226512,
References: []string{"s66mzxpvicwklp6cpph4dc53k5l6bfhe-hello-2.12.1"},
}
fp := ni.Fingerprint()
want := "1;/nix/store/s66mzxpvicwklp6cpph4dc53k5l6bfhe-hello-2.12.1;" +
"sha256:04rrn5x6lnzrfkcy3bh7gf7x6hq3w1kap4wdss2n6n4s19pgbkr7;226512;" +
"/nix/store/s66mzxpvicwklp6cpph4dc53k5l6bfhe-hello-2.12.1"
if fp != want {
t.Errorf("Fingerprint() =\n%q\nwant\n%q", fp, want)
}
}
func TestFingerprintNoRefs(t *testing.T) {
ni := &narinfo.NarInfo{
StorePath: "/nix/store/abc-test",
NarHash: "sha256:abc",
NarSize: 1234,
}
fp := ni.Fingerprint()
if !strings.HasSuffix(fp, ";") {
t.Errorf("Fingerprint with no refs should end with ';', got: %q", fp)
}
}
func TestFingerprintRefsAlreadyPrefixed(t *testing.T) {
ni := &narinfo.NarInfo{
StorePath: "/nix/store/abc-test",
NarHash: "sha256:abc",
NarSize: 1234,
References: []string{"/nix/store/dep-pkg"}, // already prefixed
}
fp := ni.Fingerprint()
if strings.Contains(fp, "/nix/store//nix/store/") {
t.Errorf("Fingerprint double-prefixed refs: %q", fp)
}
}
func TestParsePublicKeyValid(t *testing.T) {
name, key, err := narinfo.ParsePublicKey("cache.nixos.org-1:6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY=")
if err != nil {
t.Fatalf("ParsePublicKey: %v", err)
}
if name != "cache.nixos.org-1" {
t.Errorf("name = %q", name)
}
if len(key) != ed25519.PublicKeySize {
t.Errorf("key len = %d, want %d", len(key), ed25519.PublicKeySize)
}
}
func TestParsePublicKeyMissingColon(t *testing.T) {
_, _, err := narinfo.ParsePublicKey("no-colon-here")
if err == nil {
t.Error("expected error for missing ':'")
}
}
func TestParsePublicKeyBadBase64(t *testing.T) {
_, _, err := narinfo.ParsePublicKey("name:!!!not-base64!!!")
if err == nil {
t.Error("expected error for invalid base64")
}
}
func TestParsePublicKeyWrongSize(t *testing.T) {
// 16 bytes encoded in base64 = 24 chars with padding
b16 := base64.StdEncoding.EncodeToString(make([]byte, 16))
_, _, err := narinfo.ParsePublicKey("name:" + b16)
if err == nil {
t.Error("expected error for wrong key size (16 bytes, not 32)")
}
}
// Generates a fresh ed25519 key, signs a narinfo fingerprint,
// embeds the signature, and verifies it. This covers the full sign/verify path.
func TestVerifyRoundtrip(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
ni := &narinfo.NarInfo{
StorePath: "/nix/store/abc123-test-pkg",
NarHash: "sha256:abcdef123456",
NarSize: 98765,
References: []string{"abc123-test-pkg"},
}
fp := ni.Fingerprint()
sig := ed25519.Sign(priv, []byte(fp))
pubKeyStr := "test-key-1:" + base64.StdEncoding.EncodeToString(pub)
ni.Sig = []string{"test-key-1:" + base64.StdEncoding.EncodeToString(sig)}
ok, err := ni.Verify(pubKeyStr)
if err != nil {
t.Fatalf("Verify error: %v", err)
}
if !ok {
t.Error("Verify returned false for valid signature")
}
}
func TestVerifyWrongKey(t *testing.T) {
_, priv, _ := ed25519.GenerateKey(rand.Reader)
wrongPub, _, _ := ed25519.GenerateKey(rand.Reader) // different key
ni := &narinfo.NarInfo{
StorePath: "/nix/store/abc123-test-pkg",
NarHash: "sha256:abcdef",
NarSize: 1234,
}
fp := ni.Fingerprint()
sig := ed25519.Sign(priv, []byte(fp))
// Register wrong public key but correct key name
wrongKeyStr := "test-key-1:" + base64.StdEncoding.EncodeToString(wrongPub)
ni.Sig = []string{"test-key-1:" + base64.StdEncoding.EncodeToString(sig)}
ok, err := ni.Verify(wrongKeyStr)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ok {
t.Error("Verify should return false for mismatched key")
}
}
func TestVerifyNoMatchingKeyName(t *testing.T) {
pub, _, _ := ed25519.GenerateKey(rand.Reader)
ni := &narinfo.NarInfo{
StorePath: "/nix/store/abc123-test-pkg",
NarHash: "sha256:abcdef",
NarSize: 1234,
}
ni.Sig = []string{"other-key-1:invalidsig="}
pubKeyStr := "my-key-1:" + base64.StdEncoding.EncodeToString(pub)
ok, err := ni.Verify(pubKeyStr)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ok {
t.Error("Verify should return false when no Sig line matches key name")
}
}

View file

@ -1,267 +0,0 @@
package prober
import (
"math"
"net/http"
"sort"
"sync"
"time"
"notashelf.dev/ncro/internal/config"
)
// Upstream health status.
type Status int
const (
StatusActive Status = iota
StatusDegraded // 3+ consecutive failures
StatusDown // 10+ consecutive failures
)
func (s Status) String() string {
switch s {
case StatusActive:
return "ACTIVE"
case StatusDegraded:
return "DEGRADED"
default:
return "DOWN"
}
}
// In-memory metrics for one upstream.
type UpstreamHealth struct {
URL string
Priority int
EMALatency float64
LastProbe time.Time
ConsecutiveFails uint32
TotalQueries uint64
Status Status
}
// Tracks latency and health for a set of upstreams.
type Prober struct {
mu sync.RWMutex
alpha float64
table map[string]*UpstreamHealth
client *http.Client
persistHealth func(url string, ema float64, consecutiveFails uint32, totalQueries uint64)
}
// Creates a Prober with the given EMA alpha coefficient.
func New(alpha float64) *Prober {
return &Prober{
alpha: alpha,
table: make(map[string]*UpstreamHealth),
client: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// Seeds the prober with upstream configs (records priority, no measurements yet).
func (p *Prober) InitUpstreams(upstreams []config.UpstreamConfig) {
p.mu.Lock()
defer p.mu.Unlock()
for _, u := range upstreams {
if _, ok := p.table[u.URL]; !ok {
p.table[u.URL] = &UpstreamHealth{URL: u.URL, Priority: u.Priority, Status: StatusActive}
}
}
}
// Derives Status from the number of consecutive failures, matching the logic
// in RecordFailure.
func computeStatus(consecutiveFails uint32) Status {
switch {
case consecutiveFails >= 10:
return StatusDown
case consecutiveFails >= 3:
return StatusDegraded
default:
return StatusActive
}
}
// Seeds an upstream's health state from persisted data. Should be called
// after InitUpstreams to restore state from the previous run.
func (p *Prober) Seed(url string, emaLatency float64, consecutiveFails int, totalQueries int64) {
p.mu.Lock()
defer p.mu.Unlock()
h, ok := p.table[url]
if !ok {
return
}
h.EMALatency = emaLatency
h.TotalQueries = uint64(totalQueries)
h.ConsecutiveFails = uint32(consecutiveFails)
h.Status = computeStatus(uint32(consecutiveFails))
}
// Registers a callback invoked after each RecordLatency or RecordFailure call.
// The callback runs in a separate goroutine and must be safe for concurrent use.
func (p *Prober) SetHealthPersistence(fn func(url string, ema float64, consecutiveFails uint32, totalQueries uint64)) {
p.mu.Lock()
defer p.mu.Unlock()
p.persistHealth = fn
}
// Records a successful latency measurement and updates the EMA.
func (p *Prober) RecordLatency(url string, ms float64) {
p.mu.Lock()
defer p.mu.Unlock()
h, ok := p.table[url]
if !ok {
return
}
if h.TotalQueries == 0 {
h.EMALatency = ms
} else {
h.EMALatency = p.alpha*ms + (1-p.alpha)*h.EMALatency
}
h.ConsecutiveFails = 0
h.TotalQueries++
h.Status = StatusActive
h.LastProbe = time.Now()
if p.persistHealth != nil {
u, ema, cf, tq := h.URL, h.EMALatency, h.ConsecutiveFails, h.TotalQueries
fn := p.persistHealth
go fn(u, ema, cf, tq)
}
}
// Records a probe failure.
func (p *Prober) RecordFailure(url string) {
p.mu.Lock()
defer p.mu.Unlock()
h, ok := p.table[url]
if !ok {
return
}
h.ConsecutiveFails++
h.Status = computeStatus(h.ConsecutiveFails)
if p.persistHealth != nil {
u, ema, cf, tq := h.URL, h.EMALatency, h.ConsecutiveFails, h.TotalQueries
fn := p.persistHealth
go fn(u, ema, cf, tq)
}
}
// Returns a copy of the health entry for url, or nil if unknown.
func (p *Prober) GetHealth(url string) *UpstreamHealth {
p.mu.RLock()
defer p.mu.RUnlock()
h, ok := p.table[url]
if !ok {
return nil
}
cp := *h
return &cp
}
// Returns all known upstreams sorted by EMA latency ascending.
// DOWN upstreams are sorted last. Within 10% EMA difference, lower Priority wins.
func (p *Prober) SortedByLatency() []*UpstreamHealth {
p.mu.RLock()
defer p.mu.RUnlock()
result := make([]*UpstreamHealth, 0, len(p.table))
for _, h := range p.table {
cp := *h
result = append(result, &cp)
}
sort.Slice(result, func(i, j int) bool {
a, b := result[i], result[j]
aDown := a.Status == StatusDown
bDown := b.Status == StatusDown
if aDown != bDown {
return bDown // non-down first
}
// Within 10% latency difference: prefer lower priority number, then lower latency.
if b.EMALatency > 0 && math.Abs(a.EMALatency-b.EMALatency)/b.EMALatency < 0.10 {
if a.Priority != b.Priority {
return a.Priority < b.Priority
}
}
return a.EMALatency < b.EMALatency
})
return result
}
// Performs a HEAD /nix-cache-info against url and updates health.
func (p *Prober) ProbeUpstream(url string) {
// Skip if URL is not in table. This prevents in-flight probes from
// resurrecting removed upstreams (race: RemoveUpstream called while
// ProbeUpstream is in flight).
p.mu.RLock()
_, exists := p.table[url]
p.mu.RUnlock()
if !exists {
// URL was removed or never added; do not resurrect.
return
}
start := time.Now()
resp, err := p.client.Head(url + "/nix-cache-info")
elapsed := float64(time.Since(start).Nanoseconds()) / 1e6
if err != nil || resp.StatusCode != 200 {
p.RecordFailure(url)
return
}
resp.Body.Close()
p.RecordLatency(url, elapsed)
}
// Probes all known upstreams on interval until stop is closed.
func (p *Prober) RunProbeLoop(interval time.Duration, stop <-chan struct{}) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-stop:
return
case <-ticker.C:
p.mu.RLock()
urls := make([]string, 0, len(p.table))
for u := range p.table {
urls = append(urls, u)
}
p.mu.RUnlock()
for _, u := range urls {
go p.ProbeUpstream(u)
}
}
}
}
func (p *Prober) getOrCreate(url string) *UpstreamHealth {
h, ok := p.table[url]
if !ok {
h = &UpstreamHealth{URL: url, Status: StatusActive}
p.table[url] = h
}
return h
}
// Adds a new upstream dynamically (e.g., discovered via mDNS).
// Thread-safe. Triggers an immediate probe in the background.
func (p *Prober) AddUpstream(url string, priority int) {
p.mu.Lock()
defer p.mu.Unlock()
if _, exists := p.table[url]; exists {
return
}
p.table[url] = &UpstreamHealth{URL: url, Priority: priority, Status: StatusActive}
// Trigger an immediate probe in background
go p.ProbeUpstream(url)
}
// Removes an upstream from tracking (e.g., when a peer leaves the network).
// Thread-safe. No-op if upstream was not known.
func (p *Prober) RemoveUpstream(url string) {
p.mu.Lock()
defer p.mu.Unlock()
delete(p.table, url)
}

View file

@ -1,188 +0,0 @@
package prober_test
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"notashelf.dev/ncro/internal/config"
"notashelf.dev/ncro/internal/prober"
)
func TestEMACalculation(t *testing.T) {
p := prober.New(0.3)
p.AddUpstream("https://example.com", 1)
p.RecordLatency("https://example.com", 100)
p.RecordLatency("https://example.com", 50)
// EMA after 2 measurements: first=100, second = 0.3*50 + 0.7*100 = 85
health := p.GetHealth("https://example.com")
if health == nil {
t.Fatal("expected health entry")
}
if health.EMALatency < 84 || health.EMALatency > 86 {
t.Errorf("EMA = %.2f, want ~85", health.EMALatency)
}
}
func TestStatusProgression(t *testing.T) {
p := prober.New(0.3)
p.AddUpstream("https://example.com", 1)
p.RecordLatency("https://example.com", 10)
for range 3 {
p.RecordFailure("https://example.com")
}
h := p.GetHealth("https://example.com")
if h.Status != prober.StatusDegraded {
t.Errorf("status = %v, want Degraded after 3 failures", h.Status)
}
for range 7 {
p.RecordFailure("https://example.com")
}
h = p.GetHealth("https://example.com")
if h.Status != prober.StatusDown {
t.Errorf("status = %v, want Down after 10 failures", h.Status)
}
}
func TestRecoveryAfterSuccess(t *testing.T) {
p := prober.New(0.3)
p.AddUpstream("https://example.com", 1)
for range 10 {
p.RecordFailure("https://example.com")
}
p.RecordLatency("https://example.com", 20)
h := p.GetHealth("https://example.com")
if h.Status != prober.StatusActive {
t.Errorf("status = %v, want Active after recovery", h.Status)
}
if h.ConsecutiveFails != 0 {
t.Errorf("ConsecutiveFails = %d, want 0", h.ConsecutiveFails)
}
}
func TestSortedByLatency(t *testing.T) {
p := prober.New(0.3)
p.AddUpstream("https://slow.example.com", 1)
p.AddUpstream("https://fast.example.com", 1)
p.AddUpstream("https://medium.example.com", 1)
p.RecordLatency("https://slow.example.com", 200)
p.RecordLatency("https://fast.example.com", 10)
p.RecordLatency("https://medium.example.com", 50)
sorted := p.SortedByLatency()
if len(sorted) != 3 {
t.Fatalf("expected 3, got %d", len(sorted))
}
if sorted[0].URL != "https://fast.example.com" {
t.Errorf("first = %q, want fast", sorted[0].URL)
}
}
func TestProbeUpstream(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer srv.Close()
p := prober.New(0.3)
p.AddUpstream(srv.URL, 0)
p.ProbeUpstream(srv.URL)
h := p.GetHealth(srv.URL)
if h == nil || h.Status != prober.StatusActive {
t.Errorf("expected Active after successful probe, got %v", h)
}
}
func TestSortedByLatencyWithPriority(t *testing.T) {
p := prober.New(0.3)
// Two upstreams with very similar latency; lower priority number should win.
p.AddUpstream("https://low-priority.example.com", 1)
p.AddUpstream("https://high-priority.example.com", 1)
p.RecordLatency("https://low-priority.example.com", 100)
p.RecordLatency("https://high-priority.example.com", 102) // within 10%
// Set priorities by calling InitUpstreams via RecordLatency (already seeded).
// We can't call InitUpstreams without config here, so test via SortedByLatency
// behavior: without priority, the 100ms one wins. With equal EMA and priority
// both zero (default), the lower-latency one still wins.
sorted := p.SortedByLatency()
if len(sorted) != 2 {
t.Fatalf("expected 2, got %d", len(sorted))
}
// The 100ms upstream should be first (lower latency wins when not within 10% tie).
// 100 vs 102: diff=2, 2/102=1.96% < 10%, so priority decides (both priority=0, tie --> latency).
// Actually 100 < 102 still wins on latency when priority is equal.
if sorted[0].EMALatency > sorted[1].EMALatency {
t.Errorf("expected lower latency first, got %.2f then %.2f", sorted[0].EMALatency, sorted[1].EMALatency)
}
}
func TestProbeUpstreamFailure(t *testing.T) {
p := prober.New(0.3)
p.AddUpstream("http://127.0.0.1:1", 0)
p.ProbeUpstream("http://127.0.0.1:1") // nothing listening, maybe except for Makima
h := p.GetHealth("http://127.0.0.1:1")
if h == nil || h.ConsecutiveFails == 0 {
t.Error("expected failure recorded")
}
}
func TestSeedRestoresStatus(t *testing.T) {
p := prober.New(0.3)
p.InitUpstreams([]config.UpstreamConfig{{URL: "https://down.example.com"}})
// Seed with 10 consecutive fails -> should be StatusDown
p.Seed("https://down.example.com", 200.0, 10, 50)
h := p.GetHealth("https://down.example.com")
if h == nil {
t.Fatal("expected health entry")
}
if h.Status != prober.StatusDown {
t.Errorf("Status = %v, want StatusDown", h.Status)
}
if h.EMALatency != 200.0 {
t.Errorf("EMALatency = %f, want 200.0", h.EMALatency)
}
}
func TestPersistenceCallbackFired(t *testing.T) {
p := prober.New(0.3)
p.InitUpstreams([]config.UpstreamConfig{{URL: "https://up.example.com"}})
var (
mu sync.Mutex
savedURL string
savedCF uint32
wg sync.WaitGroup
)
wg.Add(1)
p.SetHealthPersistence(func(url string, ema float64, consecutiveFails uint32, totalQueries uint64) {
mu.Lock()
savedURL = url
savedCF = consecutiveFails
mu.Unlock()
wg.Done()
})
p.RecordLatency("https://up.example.com", 50.0)
wg.Wait()
mu.Lock()
gotURL := savedURL
gotCF := savedCF
mu.Unlock()
if gotURL != "https://up.example.com" {
t.Errorf("savedURL = %q, want https://up.example.com", gotURL)
}
if gotCF != 0 {
t.Errorf("consecutiveFails = %d, want 0", gotCF)
}
}

View file

@ -1,258 +0,0 @@
package router
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"sync"
"time"
"golang.org/x/sync/singleflight"
"notashelf.dev/ncro/internal/cache"
"notashelf.dev/ncro/internal/metrics"
"notashelf.dev/ncro/internal/narinfo"
"notashelf.dev/ncro/internal/prober"
)
// Returned when all upstreams were reached but none had the path.
var ErrNotFound = errors.New("not found in any upstream")
// Returned when all upstreams failed with network errors.
var ErrUpstreamUnavailable = errors.New("all upstreams unavailable")
// Result of a Resolve call.
type Result struct {
URL string
LatencyMs float64
CacheHit bool
NarInfoBytes []byte // raw narinfo response on cache miss; nil on cache hit
}
// Resolves store paths to the best upstream via cache lookup or parallel racing.
type Router struct {
db *cache.DB
prober *prober.Prober
routeTTL time.Duration
raceTimeout time.Duration
negativeTTL time.Duration
client *http.Client
mu sync.RWMutex
upstreamKeys map[string]string // upstream URL -> Nix public key string
group singleflight.Group
}
// Creates a Router.
func New(db *cache.DB, p *prober.Prober, routeTTL, raceTimeout, negativeTTL time.Duration) *Router {
return &Router{
db: db,
prober: p,
routeTTL: routeTTL,
raceTimeout: raceTimeout,
negativeTTL: negativeTTL,
client: &http.Client{Timeout: raceTimeout},
upstreamKeys: make(map[string]string),
}
}
// Registers a Nix public key for narinfo signature verification on a given upstream.
// pubKeyStr must be in "name:base64(key)" format (e.g. "cache.nixos.org-1:...").
func (r *Router) SetUpstreamKey(url, pubKeyStr string) error {
if _, _, err := narinfo.ParsePublicKey(pubKeyStr); err != nil {
return err
}
r.mu.Lock()
r.upstreamKeys[url] = pubKeyStr
r.mu.Unlock()
return nil
}
// Returns the best upstream for the given store hash.
// Checks the route cache first; on miss races the provided candidates.
func (r *Router) Resolve(storeHash string, candidates []string) (*Result, error) {
// Fast path: negative cache.
if neg, err := r.db.IsNegative(storeHash); err == nil && neg {
return nil, ErrNotFound
}
// Fast path: route cache hit.
entry, err := r.db.GetRoute(storeHash)
if err == nil && entry != nil && entry.IsValid() {
h := r.prober.GetHealth(entry.UpstreamURL)
if h == nil || h.Status == prober.StatusActive {
metrics.NarinfoCacheHits.Inc()
return &Result{
URL: entry.UpstreamURL,
LatencyMs: entry.LatencyEMA,
CacheHit: true,
}, nil
}
}
metrics.NarinfoCacheMisses.Inc()
v, raceErr, _ := r.group.Do(storeHash, func() (interface{}, error) {
result, err := r.race(storeHash, candidates)
if errors.Is(err, ErrNotFound) {
_ = r.db.SetNegative(storeHash, r.negativeTTL)
}
if err != nil {
return nil, err
}
return result, nil
})
if raceErr != nil {
return nil, raceErr
}
return v.(*Result), nil
}
type raceResult struct {
url string
latencyMs float64
}
func (r *Router) race(storeHash string, candidates []string) (*Result, error) {
if len(candidates) == 0 {
return nil, fmt.Errorf("no candidates for %q", storeHash)
}
ctx, cancel := context.WithTimeout(context.Background(), r.raceTimeout)
defer cancel()
ch := make(chan raceResult, len(candidates))
var (
wg sync.WaitGroup
mu sync.Mutex
netErrs int
notFounds int
)
for _, u := range candidates {
wg.Add(1)
go func(upstream string) {
defer wg.Done()
start := time.Now()
req, err := http.NewRequestWithContext(ctx, http.MethodHead,
upstream+"/"+storeHash+".narinfo", nil)
if err != nil {
slog.Warn("bad upstream URL in race", "upstream", upstream, "error", err)
mu.Lock()
netErrs++
mu.Unlock()
return
}
resp, err := r.client.Do(req)
if err != nil {
mu.Lock()
netErrs++
mu.Unlock()
return
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
mu.Lock()
notFounds++
mu.Unlock()
return
}
ms := float64(time.Since(start).Nanoseconds()) / 1e6
select {
case ch <- raceResult{url: upstream, latencyMs: ms}:
default:
}
}(u)
}
go func() {
wg.Wait()
close(ch)
}()
winner, ok := <-ch
if !ok {
mu.Lock()
ne, nf := netErrs, notFounds
mu.Unlock()
if ne > 0 && nf == 0 {
return nil, ErrUpstreamUnavailable
}
return nil, ErrNotFound
}
cancel()
for res := range ch {
if res.latencyMs < winner.latencyMs {
winner = res
}
}
metrics.UpstreamRaceWins.WithLabelValues(winner.url).Inc()
metrics.UpstreamLatency.WithLabelValues(winner.url).Observe(winner.latencyMs / 1000)
// Fetch narinfo body to parse metadata and forward to caller.
narInfoBytes, narURL, narHash, narSize := r.fetchNarInfo(winner.url, storeHash)
health := r.prober.GetHealth(winner.url)
ema := winner.latencyMs
if health != nil {
ema = 0.3*winner.latencyMs + 0.7*health.EMALatency
}
r.prober.RecordLatency(winner.url, winner.latencyMs)
now := time.Now()
_ = r.db.SetRoute(&cache.RouteEntry{
StorePath: storeHash,
UpstreamURL: winner.url,
LatencyMs: winner.latencyMs,
LatencyEMA: ema,
LastVerified: now,
QueryCount: 1,
TTL: now.Add(r.routeTTL),
NarHash: narHash,
NarSize: narSize,
NarURL: narURL,
})
return &Result{URL: winner.url, LatencyMs: winner.latencyMs, NarInfoBytes: narInfoBytes}, nil
}
// Returns (body, narURL, narHash, narSize). narURL is the narinfo's URL field
// (e.g. "nar/1wwh37...nar.xz"), used for direct NAR routing.
// Returns (nil, "", "", 0) on fetch failure or signature verification failure.
func (r *Router) fetchNarInfo(upstream, storeHash string) ([]byte, string, string, uint64) {
url := upstream + "/" + storeHash + ".narinfo"
resp, err := r.client.Get(url)
if err != nil {
return nil, "", "", 0
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, "", "", 0
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, "", "", 0
}
ni, err := narinfo.Parse(bytes.NewReader(body))
if err != nil {
return body, "", "", 0
}
r.mu.RLock()
pubKeyStr := r.upstreamKeys[upstream]
r.mu.RUnlock()
if pubKeyStr != "" {
ok, err := ni.Verify(pubKeyStr)
if err != nil {
slog.Warn("narinfo: public key parse error", "upstream", upstream, "error", err)
return nil, "", "", 0
}
if !ok {
slog.Warn("narinfo: signature verification failed", "upstream", upstream, "store", storeHash)
return nil, "", "", 0
}
}
return body, ni.URL, ni.NarHash, ni.NarSize
}

View file

@ -1,251 +0,0 @@
package router_test
import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"os"
"sync"
"sync/atomic"
"testing"
"time"
"notashelf.dev/ncro/internal/cache"
"notashelf.dev/ncro/internal/config"
"notashelf.dev/ncro/internal/prober"
"notashelf.dev/ncro/internal/router"
)
func newTestRouter(t *testing.T, upstreams ...string) (*router.Router, func()) {
t.Helper()
f, _ := os.CreateTemp("", "ncro-router-*.db")
f.Close()
db, err := cache.Open(f.Name(), 1000)
if err != nil {
t.Fatal(err)
}
p := prober.New(0.3)
for _, u := range upstreams {
p.RecordLatency(u, 10)
}
r := router.New(db, p, time.Hour, 5*time.Second, 10*time.Minute)
return r, func() {
db.Close()
os.Remove(f.Name())
}
}
func TestRouteHit(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "StorePath: /nix/store/abc123-hello")
}))
defer srv.Close()
r, cleanup := newTestRouter(t, srv.URL)
defer cleanup()
result, err := r.Resolve("abc123", []string{srv.URL})
if err != nil {
t.Fatalf("Resolve: %v", err)
}
if result.URL != srv.URL {
t.Errorf("url = %q, want %q", result.URL, srv.URL)
}
if result.LatencyMs <= 0 {
t.Error("expected positive latency")
}
}
func TestRouteRacePicksFastest(t *testing.T) {
fast := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer fast.Close()
slow := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.WriteHeader(200)
}))
defer slow.Close()
r, cleanup := newTestRouter(t, fast.URL, slow.URL)
defer cleanup()
result, err := r.Resolve("somehash", []string{slow.URL, fast.URL})
if err != nil {
t.Fatalf("Resolve: %v", err)
}
if result.URL != fast.URL {
t.Errorf("expected fast server to win, got %q", result.URL)
}
}
func TestRouteAllFail(t *testing.T) {
r, cleanup := newTestRouter(t)
defer cleanup()
_, err := r.Resolve("somehash", []string{"http://127.0.0.1:1"})
if err == nil {
t.Error("expected error when all upstreams fail")
}
}
func TestRouteAllNotFound(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
}))
defer srv.Close()
r, cleanup := newTestRouter(t, srv.URL)
defer cleanup()
_, err := r.Resolve("somehash", []string{srv.URL})
if !errors.Is(err, router.ErrNotFound) {
t.Errorf("expected ErrNotFound, got %v", err)
}
}
func TestRouteAllUnavailable(t *testing.T) {
r, cleanup := newTestRouter(t)
defer cleanup()
_, err := r.Resolve("somehash", []string{"http://127.0.0.1:1"})
if !errors.Is(err, router.ErrUpstreamUnavailable) {
t.Errorf("expected ErrUpstreamUnavailable, got %v", err)
}
}
func TestRaceWithMalformedURL(t *testing.T) {
r, cleanup := newTestRouter(t)
defer cleanup()
_, err := r.Resolve("somehash", []string{"://bad-url"})
if err == nil {
t.Error("expected error for malformed upstream URL")
}
}
func TestCacheHit(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer srv.Close()
r, cleanup := newTestRouter(t, srv.URL)
defer cleanup()
r.Resolve("abc123", []string{srv.URL})
result, err := r.Resolve("abc123", []string{srv.URL})
if err != nil {
t.Fatalf("second Resolve: %v", err)
}
if !result.CacheHit {
t.Error("expected cache hit on second resolve")
}
}
func TestResolveWithDownUpstream(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer srv.Close()
f, _ := os.CreateTemp("", "ncro-router-*.db")
f.Close()
db, _ := cache.Open(f.Name(), 1000)
defer db.Close()
defer os.Remove(f.Name())
p := prober.New(0.3)
p.RecordLatency(srv.URL, 10)
// Force the upstream to StatusDown
for range 10 {
p.RecordFailure(srv.URL)
}
r := router.New(db, p, time.Hour, 5*time.Second, 10*time.Minute)
// Router should still attempt the race (the race uses HEAD, not the prober status)
// The upstream is actually healthy (httptest), so the race should succeed.
result, err := r.Resolve("somehash", []string{srv.URL})
if err != nil {
t.Fatalf("Resolve with down-flagged upstream: %v", err)
}
if result.URL != srv.URL {
t.Errorf("url = %q", result.URL)
}
}
func TestNegativeCaching(t *testing.T) {
var raceCount int32
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&raceCount, 1)
w.WriteHeader(http.StatusNotFound)
}))
defer ts.Close()
db, err := cache.Open(":memory:", 1000)
if err != nil {
t.Fatal(err)
}
defer db.Close()
p := prober.New(0.3)
p.InitUpstreams([]config.UpstreamConfig{{URL: ts.URL}})
r := router.New(db, p, time.Hour, 5*time.Second, 10*time.Minute)
_, err = r.Resolve("not-on-any-upstream", []string{ts.URL})
if !errors.Is(err, router.ErrNotFound) {
t.Fatalf("first resolve: expected ErrNotFound, got %v", err)
}
count1 := atomic.LoadInt32(&raceCount)
_, err = r.Resolve("not-on-any-upstream", []string{ts.URL})
if !errors.Is(err, router.ErrNotFound) {
t.Fatalf("second resolve: expected ErrNotFound, got %v", err)
}
count2 := atomic.LoadInt32(&raceCount)
if count2 != count1 {
t.Errorf("second resolve hit upstream %d extra times, want 0 (should be negatively cached)", count2-count1)
}
}
func TestSingleflightDedup(t *testing.T) {
var headCount int32
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodHead {
atomic.AddInt32(&headCount, 1)
time.Sleep(30 * time.Millisecond) // ensure goroutines overlap
w.WriteHeader(http.StatusOK)
} else {
w.Header().Set("Content-Type", "text/x-nix-narinfo")
fmt.Fprintln(w, "StorePath: /nix/store/abc123-test")
}
}))
defer ts.Close()
db, err := cache.Open(":memory:", 1000)
if err != nil {
t.Fatal(err)
}
defer db.Close()
p := prober.New(0.3)
p.InitUpstreams([]config.UpstreamConfig{{URL: ts.URL}})
r := router.New(db, p, time.Hour, 5*time.Second, 10*time.Minute)
const N = 10
var wg sync.WaitGroup
for range N {
wg.Add(1)
go func() {
defer wg.Done()
r.Resolve("abc123dedup", []string{ts.URL})
}()
}
wg.Wait()
if hc := atomic.LoadInt32(&headCount); hc > 1 {
t.Errorf("upstream HEAD hit %d times for %d concurrent callers; want 1", hc, N)
}
}

View file

@ -1,100 +0,0 @@
package server_test
import (
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"notashelf.dev/ncro/internal/cache"
"notashelf.dev/ncro/internal/config"
"notashelf.dev/ncro/internal/prober"
"notashelf.dev/ncro/internal/router"
"notashelf.dev/ncro/internal/server"
)
// Verifies that the second identical narinfo request uses the cached route.
func TestRouteReuseOnSecondRequest(t *testing.T) {
requestCount := 0
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, ".narinfo") {
requestCount++
w.Header().Set("Content-Type", "text/x-nix-narinfo")
io.WriteString(w, "StorePath: /nix/store/test-pkg\nURL: nar/test.nar\n")
return
}
w.WriteHeader(404)
}))
defer upstream.Close()
f, _ := os.CreateTemp("", "ncro-int-*.db")
f.Close()
defer os.Remove(f.Name())
db, _ := cache.Open(f.Name(), 1000)
defer db.Close()
p := prober.New(0.3)
p.AddUpstream(upstream.URL, 0)
p.RecordLatency(upstream.URL, 10)
r := router.New(db, p, time.Hour, 5*time.Second, 10*time.Minute)
ts := httptest.NewServer(server.New(r, p, db, []config.UpstreamConfig{{URL: upstream.URL}}, 30))
defer ts.Close()
resp1, _ := http.Get(ts.URL + "/deadbeef00000000.narinfo")
io.Copy(io.Discard, resp1.Body)
resp1.Body.Close()
resp2, _ := http.Get(ts.URL + "/deadbeef00000000.narinfo")
io.Copy(io.Discard, resp2.Body)
resp2.Body.Close()
if resp1.StatusCode != 200 || resp2.StatusCode != 200 {
t.Errorf("expected 200/200, got %d/%d", resp1.StatusCode, resp2.StatusCode)
}
}
// Verifies that when the best-seeded upstream returns 404, the fallback upstream is used.
func TestUpstreamFailoverFallback(t *testing.T) {
good := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/x-nix-narinfo")
io.WriteString(w, "StorePath: /nix/store/fallback-pkg\n")
}))
defer good.Close()
bad := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
}))
defer bad.Close()
f, _ := os.CreateTemp("", "ncro-fb-*.db")
f.Close()
defer os.Remove(f.Name())
db, _ := cache.Open(f.Name(), 1000)
defer db.Close()
p := prober.New(0.3)
p.AddUpstream(bad.URL, 0)
p.AddUpstream(good.URL, 0)
p.RecordLatency(bad.URL, 1) // bad appears fastest
p.RecordLatency(good.URL, 50)
r := router.New(db, p, time.Hour, 5*time.Second, 10*time.Minute)
ts := httptest.NewServer(server.New(r, p, db, []config.UpstreamConfig{
{URL: bad.URL},
{URL: good.URL},
}, 30))
defer ts.Close()
resp, err := http.Get(ts.URL + "/cafebabe00000000.narinfo")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("expected 200 via fallback, got %d", resp.StatusCode)
}
}

View file

@ -1,252 +0,0 @@
package server
import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus/promhttp"
"notashelf.dev/ncro/internal/cache"
"notashelf.dev/ncro/internal/config"
"notashelf.dev/ncro/internal/metrics"
"notashelf.dev/ncro/internal/prober"
"notashelf.dev/ncro/internal/router"
)
// HTTP handler implementing the Nix binary cache protocol.
type Server struct {
router *router.Router
prober *prober.Prober
db *cache.DB
upstreams []config.UpstreamConfig
client *http.Client
cachePriority int
metricsHandler http.Handler
}
// Creates a Server.
func New(r *router.Router, p *prober.Prober, db *cache.DB, upstreams []config.UpstreamConfig, cachePriority int) *Server {
return &Server{
router: r,
prober: p,
db: db,
upstreams: upstreams,
client: &http.Client{Timeout: 60 * time.Second},
cachePriority: cachePriority,
metricsHandler: promhttp.Handler(),
}
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
switch {
case path == "/nix-cache-info":
s.handleCacheInfo(w, r)
case path == "/health":
s.handleHealth(w, r)
case path == "/metrics":
s.metricsHandler.ServeHTTP(w, r)
case strings.HasSuffix(path, ".narinfo"):
s.handleNarinfo(w, r)
case strings.HasPrefix(path, "/nar/"):
s.handleNAR(w, r)
default:
http.NotFound(w, r)
}
}
func (s *Server) handleCacheInfo(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintln(w, "StoreDir: /nix/store")
fmt.Fprintln(w, "WantMassQuery: 1")
fmt.Fprintf(w, "Priority: %d\n", s.cachePriority)
}
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) {
type upstreamStatus struct {
URL string `json:"url"`
Status string `json:"status"`
LatencyMs float64 `json:"latency_ms"`
ConsecutiveFails uint32 `json:"consecutive_fails"`
}
type response struct {
Status string `json:"status"`
Upstreams []upstreamStatus `json:"upstreams"`
}
sorted := s.prober.SortedByLatency()
upstreams := make([]upstreamStatus, len(sorted))
var downCount int
var anyDegraded bool
for i, h := range sorted {
upstreams[i] = upstreamStatus{
URL: h.URL,
Status: h.Status.String(),
LatencyMs: h.EMALatency,
ConsecutiveFails: h.ConsecutiveFails,
}
if h.Status == prober.StatusDown {
downCount++
} else if h.Status == prober.StatusDegraded {
anyDegraded = true
}
}
overall := "ok"
switch {
case len(sorted) > 0 && downCount == len(sorted):
overall = "down"
case downCount > 0 || anyDegraded:
overall = "degraded"
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response{Status: overall, Upstreams: upstreams})
}
func (s *Server) handleNarinfo(w http.ResponseWriter, r *http.Request) {
hash := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/"), ".narinfo")
result, err := s.router.Resolve(hash, s.upstreamURLs())
if err != nil {
slog.Warn("narinfo resolve failed", "hash", hash, "error", err)
metrics.NarinfoRequests.WithLabelValues("error").Inc()
switch {
case errors.Is(err, router.ErrNotFound):
http.NotFound(w, r)
default:
http.Error(w, "upstream unavailable", http.StatusBadGateway)
}
return
}
slog.Info("narinfo routed", "hash", hash, "upstream", result.URL, "cache_hit", result.CacheHit)
metrics.NarinfoRequests.WithLabelValues("200").Inc()
if len(result.NarInfoBytes) > 0 {
w.Header().Set("Content-Type", "text/x-nix-narinfo")
w.WriteHeader(http.StatusOK)
w.Write(result.NarInfoBytes)
return
}
s.proxyRequest(w, r, result.URL+r.URL.Path)
}
func (s *Server) handleNAR(w http.ResponseWriter, r *http.Request) {
metrics.NARRequests.Inc()
// Consult route cache: the narURL is the path without the leading slash.
narURL := strings.TrimPrefix(r.URL.Path, "/")
var tried string
if entry, err := s.db.GetRouteByNarURL(narURL); err == nil && entry != nil && entry.IsValid() {
tried = entry.UpstreamURL
if s.tryNARUpstream(w, r, entry.UpstreamURL) {
return
}
}
// Fall back through all upstreams sorted by latency.
for _, h := range s.prober.SortedByLatency() {
if h.Status == prober.StatusDown || h.URL == tried {
continue
}
if s.tryNARUpstream(w, r, h.URL) {
return
}
}
http.NotFound(w, r)
}
// Attempts to serve a NAR from upstreamBase. Returns true if the upstream
// responded with a non-404 status.
func (s *Server) tryNARUpstream(w http.ResponseWriter, r *http.Request, upstreamBase string) bool {
targetURL := upstreamBase + r.URL.Path
req, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
if err != nil {
return false
}
for _, hdr := range []string{"Accept", "Accept-Encoding", "Range"} {
if v := r.Header.Get(hdr); v != "" {
req.Header.Set(hdr, v)
}
}
resp, err := s.client.Do(req)
if err != nil {
slog.Warn("NAR upstream failed", "upstream", upstreamBase, "error", err)
return false
}
if resp.StatusCode == http.StatusNotFound {
resp.Body.Close()
return false
}
defer resp.Body.Close()
slog.Debug("proxying NAR", "path", r.URL.Path, "upstream", upstreamBase)
s.copyResponse(w, resp)
return true
}
// Forwards r to targetURL and streams the response zero-copy.
func (s *Server) proxyRequest(w http.ResponseWriter, r *http.Request, targetURL string) {
req, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
for _, h := range []string{"Accept", "Accept-Encoding", "Range"} {
if v := r.Header.Get(h); v != "" {
req.Header.Set(h, v)
}
}
resp, err := s.client.Do(req)
if err != nil {
slog.Error("upstream request failed", "url", targetURL, "error", err)
http.Error(w, "upstream error", http.StatusBadGateway)
return
}
defer resp.Body.Close()
s.copyResponse(w, resp)
}
// Copies response headers and body from resp to w.
func (s *Server) copyResponse(w http.ResponseWriter, resp *http.Response) {
for _, h := range []string{
"Content-Type", "Content-Length", "Content-Encoding",
"X-Nix-Signature", "Cache-Control", "Last-Modified",
} {
if v := resp.Header.Get(h); v != "" {
w.Header().Set(h, v)
}
}
w.WriteHeader(resp.StatusCode)
if _, err := io.Copy(w, resp.Body); err != nil {
slog.Warn("stream interrupted", "error", err)
}
}
func (s *Server) upstreamURLs() []string {
// Include all upstreams the prober knows about: this covers both the
// statically-configured upstreams and any peers discovered at runtime
// via mDNS. Using the prober as the source of truth avoids a split
// between "what was configured" and "what was discovered".
sorted := s.prober.SortedByLatency()
urls := make([]string, 0, len(sorted))
for _, h := range sorted {
if h.Status != prober.StatusDown {
urls = append(urls, h.URL)
}
}
// Fall back to the static list if the prober has no entries yet (i.e.,
// before the first probe interval completes).
if len(urls) == 0 {
urls = make([]string, len(s.upstreams))
for i, u := range s.upstreams {
urls[i] = u.URL
}
}
return urls
}

View file

@ -1,487 +0,0 @@
package server_test
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync/atomic"
"testing"
"time"
"notashelf.dev/ncro/internal/cache"
"notashelf.dev/ncro/internal/config"
"notashelf.dev/ncro/internal/prober"
"notashelf.dev/ncro/internal/router"
"notashelf.dev/ncro/internal/server"
)
func makeTestServer(t *testing.T, upstreams ...string) *httptest.Server {
t.Helper()
f, _ := os.CreateTemp("", "ncro-srv-*.db")
f.Close()
t.Cleanup(func() { os.Remove(f.Name()) })
db, err := cache.Open(f.Name(), 1000)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { db.Close() })
p := prober.New(0.3)
for _, u := range upstreams {
p.AddUpstream(u, 0)
p.RecordLatency(u, 10)
}
upsCfg := make([]config.UpstreamConfig, len(upstreams))
for i, u := range upstreams {
upsCfg[i] = config.UpstreamConfig{URL: u}
}
r := router.New(db, p, time.Hour, 5*time.Second, 10*time.Minute)
return httptest.NewServer(server.New(r, p, db, upsCfg, 30))
}
func TestNixCacheInfo(t *testing.T) {
ts := makeTestServer(t)
defer ts.Close()
resp, err := http.Get(ts.URL + "/nix-cache-info")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("status = %d, want 200", resp.StatusCode)
}
body, _ := io.ReadAll(resp.Body)
if !strings.Contains(string(body), "StoreDir:") {
t.Errorf("body missing StoreDir: %q", body)
}
}
func TestCacheInfoFields(t *testing.T) {
ts := makeTestServer(t)
defer ts.Close()
resp, err := http.Get(ts.URL + "/nix-cache-info")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
s := string(body)
for _, want := range []string{"StoreDir:", "WantMassQuery:", "Priority:"} {
if !strings.Contains(s, want) {
t.Errorf("nix-cache-info missing %q", want)
}
}
}
func TestHealthEndpoint(t *testing.T) {
ts := makeTestServer(t)
defer ts.Close()
resp, err := http.Get(ts.URL + "/health")
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != 200 {
t.Errorf("status = %d, want 200", resp.StatusCode)
}
}
func TestMetricsEndpoint(t *testing.T) {
ts := makeTestServer(t)
defer ts.Close()
resp, err := http.Get(ts.URL + "/metrics")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("status = %d, want 200", resp.StatusCode)
}
ct := resp.Header.Get("Content-Type")
if !strings.HasPrefix(ct, "text/plain") {
t.Errorf("Content-Type = %q, want text/plain", ct)
}
}
func TestNarinfoProxy(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, ".narinfo") {
w.Header().Set("Content-Type", "text/x-nix-narinfo")
fmt.Fprint(w, "StorePath: /nix/store/abc123-hello-2.12\nURL: nar/abc123.nar\nCompression: none\n")
return
}
w.WriteHeader(404)
}))
defer upstream.Close()
ts := makeTestServer(t, upstream.URL)
defer ts.Close()
resp, err := http.Get(ts.URL + "/abc123def456.narinfo")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("narinfo status = %d, want 200", resp.StatusCode)
}
body, _ := io.ReadAll(resp.Body)
if !strings.Contains(string(body), "StorePath:") {
t.Errorf("expected narinfo body, got: %q", body)
}
}
func TestNarinfoHEADRequest(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, ".narinfo") {
w.Header().Set("Content-Type", "text/x-nix-narinfo")
fmt.Fprint(w, "StorePath: /nix/store/abc-head-test\nURL: nar/abc.nar\n")
return
}
w.WriteHeader(404)
}))
defer upstream.Close()
ts := makeTestServer(t, upstream.URL)
defer ts.Close()
req, _ := http.NewRequest(http.MethodHead, ts.URL+"/abc123.narinfo", nil)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("HEAD narinfo status = %d, want 200", resp.StatusCode)
}
}
func TestNarinfoNotFound(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
}))
defer upstream.Close()
ts := makeTestServer(t, upstream.URL)
defer ts.Close()
resp, _ := http.Get(ts.URL + "/notfound000000.narinfo")
if resp.StatusCode != 404 {
t.Errorf("status = %d, want 404", resp.StatusCode)
}
}
func TestNarinfoUpstreamError(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(500)
}))
defer upstream.Close()
ts := makeTestServer(t, upstream.URL)
defer ts.Close()
resp, _ := http.Get(ts.URL + "/abc123.narinfo")
// 404 (not found) or 502 (upstream error) are both acceptable
if resp.StatusCode == 200 {
t.Errorf("expected non-200 for upstream error, got %d", resp.StatusCode)
}
}
func TestNarinfoNoUpstreams(t *testing.T) {
ts := makeTestServer(t) // no upstreams
defer ts.Close()
resp, _ := http.Get(ts.URL + "/abc123.narinfo")
if resp.StatusCode == 200 {
t.Error("expected non-200 with no upstreams")
}
}
func TestUnknownPath(t *testing.T) {
ts := makeTestServer(t)
defer ts.Close()
resp, err := http.Get(ts.URL + "/unknown/path")
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
if resp.StatusCode != 404 {
t.Errorf("status = %d, want 404", resp.StatusCode)
}
}
func TestNARStreamingPassthrough(t *testing.T) {
narContent := []byte("fake-nar-content-bytes")
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/nar/") {
w.Header().Set("Content-Type", "application/x-nix-archive")
w.Write(narContent)
return
}
if strings.HasSuffix(r.URL.Path, ".narinfo") {
w.WriteHeader(200)
return
}
w.WriteHeader(404)
}))
defer upstream.Close()
ts := makeTestServer(t, upstream.URL)
defer ts.Close()
resp, err := http.Get(ts.URL + "/nar/abc123.nar")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("NAR status = %d, want 200", resp.StatusCode)
}
body, _ := io.ReadAll(resp.Body)
if string(body) != string(narContent) {
t.Errorf("NAR body mismatch: got %q, want %q", body, narContent)
}
}
func TestNARRangeHeaderForwarded(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/nar/") {
if r.Header.Get("Range") == "" {
http.Error(w, "Range header missing", 400)
return
}
w.WriteHeader(206)
w.Write([]byte("partial"))
return
}
if strings.HasSuffix(r.URL.Path, ".narinfo") {
w.WriteHeader(200)
return
}
w.WriteHeader(404)
}))
defer upstream.Close()
ts := makeTestServer(t, upstream.URL)
defer ts.Close()
req, _ := http.NewRequest(http.MethodGet, ts.URL+"/nar/abc.nar", nil)
req.Header.Set("Range", "bytes=0-1023")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
if resp.StatusCode != 206 {
t.Errorf("Range request status = %d, want 206", resp.StatusCode)
}
}
func TestNARRoutingUsesCache(t *testing.T) {
// Upstream A: has the NAR.
upstreamA := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, ".narinfo") {
w.Header().Set("Content-Type", "text/x-nix-narinfo")
fmt.Fprintln(w, "StorePath: /nix/store/abc123-test")
fmt.Fprintln(w, "URL: nar/abc123.nar.xz")
} else {
fmt.Fprintln(w, "NAR data from A")
}
}))
defer upstreamA.Close()
// Upstream B: does NOT have the NAR.
var bHit int32
upstreamB := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&bHit, 1)
http.NotFound(w, r)
}))
defer upstreamB.Close()
db, err := cache.Open(":memory:", 100)
if err != nil {
t.Fatal(err)
}
defer db.Close()
// Pre-seed the route cache: abc123 -> upstreamA, NarURL = "nar/abc123.nar.xz"
if err := db.SetRoute(&cache.RouteEntry{
StorePath: "abc123",
UpstreamURL: upstreamA.URL,
NarURL: "nar/abc123.nar.xz",
TTL: time.Now().Add(time.Hour),
}); err != nil {
t.Fatalf("SetRoute: %v", err)
}
p := prober.New(0.3)
p.InitUpstreams([]config.UpstreamConfig{{URL: upstreamA.URL}, {URL: upstreamB.URL}})
r := router.New(db, p, time.Hour, 5*time.Second, 10*time.Minute)
srv := server.New(r, p, db, []config.UpstreamConfig{{URL: upstreamA.URL}, {URL: upstreamB.URL}}, 30)
req := httptest.NewRequest(http.MethodGet, "/nar/abc123.nar.xz", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != 200 {
t.Fatalf("status = %d, want 200", w.Code)
}
if atomic.LoadInt32(&bHit) > 0 {
t.Error("upstream B should not have been contacted when route cache has the answer")
}
}
func TestNARFallbackWhenFirstUpstreamMissing(t *testing.T) {
missing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
}))
defer missing.Close()
hasIt := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/x-nix-archive")
w.Write([]byte("nar-bytes"))
}))
defer hasIt.Close()
f, _ := os.CreateTemp("", "ncro-nar-fallback-*.db")
f.Close()
t.Cleanup(func() { os.Remove(f.Name()) })
db, _ := cache.Open(f.Name(), 1000)
t.Cleanup(func() { db.Close() })
p := prober.New(0.3)
// missing appears faster
p.AddUpstream(missing.URL, 0)
p.AddUpstream(hasIt.URL, 0)
p.RecordLatency(missing.URL, 1)
p.RecordLatency(hasIt.URL, 50)
upsCfg := []config.UpstreamConfig{{URL: missing.URL}, {URL: hasIt.URL}}
r := router.New(db, p, time.Hour, 5*time.Second, 10*time.Minute)
ts := httptest.NewServer(server.New(r, p, db, upsCfg, 30))
defer ts.Close()
resp, err := http.Get(ts.URL + "/nar/abc123.nar")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("expected fallback NAR response 200, got %d", resp.StatusCode)
}
body, _ := io.ReadAll(resp.Body)
if string(body) != "nar-bytes" {
t.Errorf("NAR body = %q, want nar-bytes", body)
}
}
func TestHealthEndpointDegraded(t *testing.T) {
p := prober.New(0.3)
p.InitUpstreams([]config.UpstreamConfig{
{URL: "https://up1.example.com"},
{URL: "https://up2.example.com"},
})
p.RecordLatency("https://up1.example.com", 100)
for range 5 {
p.RecordFailure("https://up2.example.com")
}
db, err := cache.Open(":memory:", 100)
if err != nil {
t.Fatal(err)
}
defer db.Close()
r := router.New(db, p, time.Hour, 5*time.Second, 10*time.Minute)
srv := server.New(r, p, db, []config.UpstreamConfig{
{URL: "https://up1.example.com"},
{URL: "https://up2.example.com"},
}, 30)
req := httptest.NewRequest(http.MethodGet, "/health", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != 200 {
t.Fatalf("status = %d", w.Code)
}
var resp struct {
Status string `json:"status"`
Upstreams []struct {
URL string `json:"url"`
Status string `json:"status"`
} `json:"upstreams"`
}
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decode: %v", err)
}
if resp.Status != "degraded" {
t.Errorf("status = %q, want degraded", resp.Status)
}
if len(resp.Upstreams) != 2 {
t.Errorf("upstreams = %d, want 2", len(resp.Upstreams))
}
var foundDegraded bool
for _, u := range resp.Upstreams {
if u.URL == "https://up2.example.com" && u.Status == "DEGRADED" {
foundDegraded = true
}
}
if !foundDegraded {
t.Error("expected up2 to have status DEGRADED")
}
var foundActive bool
for _, u := range resp.Upstreams {
if u.URL == "https://up1.example.com" && u.Status == "ACTIVE" {
foundActive = true
}
}
if !foundActive {
t.Error("expected up1 to have status ACTIVE")
}
}
func TestHealthEndpointAllDown(t *testing.T) {
p := prober.New(0.3)
p.InitUpstreams([]config.UpstreamConfig{{URL: "https://down.example.com"}})
for range 10 {
p.RecordFailure("https://down.example.com")
}
db, err := cache.Open(":memory:", 100)
if err != nil {
t.Fatal(err)
}
defer db.Close()
r := router.New(db, p, time.Hour, 5*time.Second, 10*time.Minute)
srv := server.New(r, p, db, []config.UpstreamConfig{{URL: "https://down.example.com"}}, 30)
req := httptest.NewRequest(http.MethodGet, "/health", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
var resp struct {
Status string `json:"status"`
}
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decode: %v", err)
}
if resp.Status != "down" {
t.Errorf("status = %q, want down", resp.Status)
}
}

View file

@ -1,8 +1,9 @@
{
lib,
buildGoModule,
rustPlatform,
pkg-config,
}:
buildGoModule (finalAttrs: {
rustPlatform.buildRustPackage (finalAttrs: {
pname = "ncro";
version = "1.0.0";
@ -13,15 +14,14 @@ buildGoModule (finalAttrs: {
fs.toSource {
root = s;
fileset = fs.unions [
(s + /cmd)
(s + /internal)
(s + /go.mod)
(s + /go.sum)
(s + /src)
(s + /Cargo.toml)
(s + /Cargo.lock)
];
};
vendorHash = "sha256-9OkQIj2g5mZ+IpjIKvy8Il7J4xL4PJimEsXJP10FhmU=";
ldflags = ["-s" "-w" "-X main.version=${finalAttrs.version}"];
cargoLock.lockFile = "${finalAttrs.src}/Cargo.lock";
nativeBuildInputs = [pkg-config];
meta = {
mainProgram = "ncro";

View file

@ -1,18 +1,23 @@
{
mkShell,
go,
gopls,
delve,
gofumpt,
golines,
cargo,
clippy,
pkg-config,
rust-analyzer,
rustc,
rustfmt,
}:
mkShell {
name = "go";
packages = [
delve
go
gopls
gofumpt
golines
name = "rust";
strictDeps = true;
nativeBuildInputs = [
cargo
rustc
pkg-config
rust-analyzer
clippy
(rustfmt.override {asNightly = true;})
];
}

191
src/cli.rs Normal file
View file

@ -0,0 +1,191 @@
use clap::Parser;
use tokio::net::TcpListener;
use tracing_subscriber::{EnvFilter, fmt};
use crate::{
config::Config,
db::Db,
discovery::Discovery,
health::Prober,
mesh,
metrics,
router::Router,
server,
};
#[derive(Debug, Parser)]
#[command(name = "ncro", version, about = "Nix Cache Route Optimizer")]
pub struct Args {
#[arg(short, long, env = "NCRO_CONFIG")]
pub config: Option<String>,
}
pub async fn run() -> anyhow::Result<()> {
let args = Args::parse();
let cfg = Config::load(args.config.as_deref())?;
cfg.validate()?;
init_logging(&cfg.logging.level, &cfg.logging.format);
let _ = metrics::get();
let db = Db::open(&cfg.cache.db_path, cfg.cache.max_entries).await?;
let prober = Prober::new(cfg.cache.latency_alpha);
prober.init_upstreams(&cfg.upstreams).await;
for row in db.load_all_health().await.unwrap_or_default() {
prober
.seed(
&row.url,
row.ema_latency,
row.consecutive_fails,
row.total_queries,
)
.await;
}
let db_for_health = db.clone();
prober
.set_health_persistence(move |url, ema, fails, queries| {
let db = db_for_health.clone();
tokio::spawn(async move {
let _ = db
.save_health(
&url,
ema,
i64::from(fails),
i64::try_from(queries).unwrap_or(i64::MAX),
)
.await;
});
})
.await;
for upstream in &cfg.upstreams {
let prober = prober.clone();
let url = upstream.url.clone();
tokio::spawn(async move {
prober.probe_upstream(url).await;
});
}
let router = Router::new(
db.clone(),
prober.clone(),
cfg.cache.ttl.0,
std::time::Duration::from_secs(5),
cfg.cache.negative_ttl.0,
);
for upstream in &cfg.upstreams {
if !upstream.public_key.is_empty() {
router
.set_upstream_key(upstream.url.clone(), upstream.public_key.clone())
.await?;
}
}
let (stop_tx, stop_rx) = tokio::sync::watch::channel(false);
let probe_prober = prober.clone();
let probe_stop = stop_rx.clone();
tokio::spawn(async move {
probe_prober
.run_probe_loop(std::time::Duration::from_secs(30), probe_stop)
.await;
});
let db_for_expiry = db.clone();
let mut expiry_stop = stop_rx.clone();
tokio::spawn(async move {
let mut ticker = tokio::time::interval(std::time::Duration::from_secs(300));
loop {
tokio::select! {
_ = expiry_stop.changed() => return,
_ = ticker.tick() => {
let _ = db_for_expiry.expire_old_routes().await;
let _ = db_for_expiry.expire_negatives().await;
if let Ok(count) = db_for_expiry.route_count().await { metrics::get().route_entries.set(count); }
}
}
}
});
if cfg.discovery.enabled {
let discovery = Discovery::new(cfg.discovery.clone(), prober.clone())?;
let discovery_stop = stop_rx.clone();
tokio::spawn(async move {
let _ = discovery.run(discovery_stop).await;
});
}
if cfg.mesh.enabled {
let node = mesh::Node::new(&cfg.mesh.private_key_path).await?;
tracing::info!(
node_id = node.id(),
public_key = hex::encode(node.public_key()),
"mesh node identity"
);
let allowed = cfg
.mesh
.peers
.iter()
.filter_map(|p| hex::decode(&p.public_key).ok()?.try_into().ok())
.collect::<Vec<[u8; 32]>>();
mesh::listen_and_serve(
&cfg.mesh.bind_addr,
db.clone(),
allowed,
stop_rx.clone(),
)
.await?;
let peers = cfg
.mesh
.peers
.iter()
.map(|p| p.addr.clone())
.collect::<Vec<_>>();
tokio::spawn(mesh::run_gossip_loop(
node,
db.clone(),
peers,
cfg.mesh.gossip_interval.0,
stop_rx.clone(),
));
}
let app = server::app(
router,
prober,
db,
cfg.upstreams.clone(),
cfg.server.cache_priority,
);
let listener =
TcpListener::bind(normalize_listen(&cfg.server.listen)).await?;
tracing::info!(
addr = cfg.server.listen,
upstreams = cfg.upstreams.len(),
version = env!("CARGO_PKG_VERSION"),
"ncro listening"
);
let server = axum::serve(listener, app).with_graceful_shutdown(async move {
let _ = tokio::signal::ctrl_c().await;
});
let result = server.await;
let _ = stop_tx.send(true);
result?;
Ok(())
}
fn init_logging(level: &str, format_name: &str) {
let filter =
EnvFilter::try_new(level).unwrap_or_else(|_| EnvFilter::new("info"));
if format_name == "text" {
fmt().with_env_filter(filter).init();
} else {
fmt().json().with_env_filter(filter).init();
}
}
fn normalize_listen(listen: &str) -> String {
if listen.starts_with(':') {
format!("0.0.0.0{listen}")
} else {
listen.to_string()
}
}

336
src/config.rs Normal file
View file

@ -0,0 +1,336 @@
use std::{env, fs, time::Duration};
use serde::{Deserialize, Deserializer};
use thiserror::Error;
use url::Url;
#[derive(Debug, Error)]
pub enum ConfigError {
#[error("read config: {0}")]
Read(#[from] std::io::Error),
#[error("parse config: {0}")]
Parse(#[from] serde_yaml::Error),
#[error("{0}")]
Validation(String),
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn loads_defaults() -> Result<(), ConfigError> {
let cfg = Config::load(None)?;
assert_eq!(cfg.server.listen, ":8080");
assert_eq!(cfg.cache.max_entries, 100_000);
assert_eq!(cfg.upstreams.len(), 1);
cfg.validate()?;
Ok(())
}
#[test]
fn parses_duration_yaml() -> Result<(), serde_yaml::Error> {
let cfg: Config = serde_yaml::from_str(
"server:\n read_timeout: 30s\ncache:\n ttl: 2h\n",
)?;
assert_eq!(cfg.server.read_timeout.0, Duration::from_secs(30));
assert_eq!(cfg.cache.ttl.0, Duration::from_secs(7200));
Ok(())
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HumanDuration(pub Duration);
impl Default for HumanDuration {
fn default() -> Self {
Self(Duration::ZERO)
}
}
impl<'de> Deserialize<'de> for HumanDuration {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
humantime_serde::deserialize(deserializer).map(Self)
}
}
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct UpstreamConfig {
pub url: String,
pub priority: i32,
pub public_key: String,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct ServerConfig {
pub listen: String,
pub read_timeout: HumanDuration,
pub write_timeout: HumanDuration,
pub cache_priority: i32,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
listen: ":8080".to_string(),
read_timeout: HumanDuration(Duration::from_secs(30)),
write_timeout: HumanDuration(Duration::from_secs(30)),
cache_priority: 30,
}
}
}
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct CacheConfig {
pub db_path: String,
pub max_entries: i64,
pub ttl: HumanDuration,
pub negative_ttl: HumanDuration,
pub latency_alpha: f64,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
db_path: "/var/lib/ncro/routes.db".to_string(),
max_entries: 100_000,
ttl: HumanDuration(Duration::from_secs(60 * 60)),
negative_ttl: HumanDuration(Duration::from_secs(10 * 60)),
latency_alpha: 0.3,
}
}
}
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct PeerConfig {
pub addr: String,
pub public_key: String,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct MeshConfig {
pub enabled: bool,
pub bind_addr: String,
pub peers: Vec<PeerConfig>,
#[serde(rename = "private_key")]
pub private_key_path: String,
pub gossip_interval: HumanDuration,
}
impl Default for MeshConfig {
fn default() -> Self {
Self {
enabled: false,
bind_addr: "0.0.0.0:7946".to_string(),
peers: Vec::new(),
private_key_path: String::new(),
gossip_interval: HumanDuration(Duration::from_secs(30)),
}
}
}
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct DiscoveryConfig {
pub enabled: bool,
pub service_name: String,
pub domain: String,
pub discovery_time: HumanDuration,
pub priority: i32,
}
impl Default for DiscoveryConfig {
fn default() -> Self {
Self {
enabled: false,
service_name: "_nix-serve._tcp".to_string(),
domain: "local".to_string(),
discovery_time: HumanDuration(Duration::from_secs(5)),
priority: 20,
}
}
}
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct LoggingConfig {
pub level: String,
pub format: String,
}
impl Default for LoggingConfig {
fn default() -> Self {
Self {
level: "info".to_string(),
format: "json".to_string(),
}
}
}
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct Config {
pub server: ServerConfig,
pub upstreams: Vec<UpstreamConfig>,
pub cache: CacheConfig,
pub mesh: MeshConfig,
pub discovery: DiscoveryConfig,
pub logging: LoggingConfig,
}
impl Default for Config {
fn default() -> Self {
Self {
server: ServerConfig::default(),
upstreams: vec![UpstreamConfig {
url: "https://cache.nixos.org".to_string(),
priority: 10,
public_key: String::new(),
}],
cache: CacheConfig::default(),
mesh: MeshConfig::default(),
discovery: DiscoveryConfig::default(),
logging: LoggingConfig::default(),
}
}
}
impl Config {
pub fn load(path: Option<&str>) -> Result<Self, ConfigError> {
let mut cfg = if let Some(path) = path.filter(|p| !p.is_empty()) {
let data = fs::read_to_string(path)?;
serde_yaml::from_str::<Self>(&data)?
} else {
Self::default()
};
if let Ok(v) = env::var("NCRO_LISTEN")
&& !v.is_empty()
{
cfg.server.listen = v;
}
if let Ok(v) = env::var("NCRO_DB_PATH")
&& !v.is_empty()
{
cfg.cache.db_path = v;
}
if let Ok(v) = env::var("NCRO_LOG_LEVEL")
&& !v.is_empty()
{
cfg.logging.level = v;
}
Ok(cfg)
}
pub fn validate(&self) -> Result<(), ConfigError> {
if self.upstreams.is_empty() {
return Err(ConfigError::Validation(
"at least one upstream is required".to_string(),
));
}
for (i, upstream) in self.upstreams.iter().enumerate() {
if upstream.url.is_empty() {
return Err(ConfigError::Validation(format!(
"upstream[{i}]: URL is empty"
)));
}
Url::parse(&upstream.url).map_err(|err| {
ConfigError::Validation(format!(
"upstream[{i}]: invalid URL {:?}: {err}",
upstream.url
))
})?;
if !upstream.public_key.is_empty() && !upstream.public_key.contains(':') {
return Err(ConfigError::Validation(format!(
"upstream[{i}]: public_key must be in 'name:base64(key)' Nix format"
)));
}
}
if self.server.listen.is_empty() {
return Err(ConfigError::Validation(
"server.listen is empty".to_string(),
));
}
if self.server.cache_priority < 1 {
return Err(ConfigError::Validation(format!(
"server.cache_priority must be >= 1, got {}",
self.server.cache_priority
)));
}
if self.cache.latency_alpha <= 0.0 || self.cache.latency_alpha >= 1.0 {
return Err(ConfigError::Validation(format!(
"cache.latency_alpha must be between 0 and 1 exclusive, got {}",
self.cache.latency_alpha
)));
}
if self.cache.ttl.0.is_zero() {
return Err(ConfigError::Validation(
"cache.ttl must be positive".to_string(),
));
}
if self.cache.negative_ttl.0.is_zero() {
return Err(ConfigError::Validation(
"cache.negative_ttl must be positive".to_string(),
));
}
if self.cache.max_entries <= 0 {
return Err(ConfigError::Validation(
"cache.max_entries must be positive".to_string(),
));
}
if self.mesh.enabled && self.mesh.peers.is_empty() {
return Err(ConfigError::Validation(
"mesh.enabled is true but no peers configured".to_string(),
));
}
for (i, peer) in self.mesh.peers.iter().enumerate() {
if peer.addr.is_empty() {
return Err(ConfigError::Validation(format!(
"mesh.peers[{i}]: addr is empty"
)));
}
if !peer.public_key.is_empty() {
let bytes = hex::decode(&peer.public_key).map_err(|_| {
ConfigError::Validation(format!(
"mesh.peers[{i}]: public_key must be a hex-encoded 32-byte \
ed25519 key"
))
})?;
if bytes.len() != 32 {
return Err(ConfigError::Validation(format!(
"mesh.peers[{i}]: public_key must be a hex-encoded 32-byte \
ed25519 key"
)));
}
}
}
if self.discovery.enabled {
if self.discovery.service_name.is_empty() {
return Err(ConfigError::Validation(
"discovery.service_name is required when discovery is enabled"
.to_string(),
));
}
if self.discovery.domain.is_empty() {
return Err(ConfigError::Validation(
"discovery.domain is required when discovery is enabled".to_string(),
));
}
if self.discovery.discovery_time.0.is_zero() {
return Err(ConfigError::Validation(
"discovery.discovery_time must be positive".to_string(),
));
}
}
Ok(())
}
}

394
src/db.rs Normal file
View file

@ -0,0 +1,394 @@
use std::{path::Path, time::Duration};
use chrono::{DateTime, TimeZone, Utc};
use sqlx::{
Row,
SqlitePool,
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum DbError {
#[error("sqlite: {0}")]
Sqlx(#[from] sqlx::Error),
#[error("create database directory: {0}")]
CreateDir(#[from] std::io::Error),
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct RouteEntry {
pub store_path: String,
pub upstream_url: String,
pub latency_ms: f64,
pub latency_ema: f64,
pub last_verified: DateTime<Utc>,
pub query_count: u32,
pub failure_count: u32,
pub ttl: DateTime<Utc>,
pub nar_hash: String,
pub nar_size: u64,
pub nar_url: String,
}
impl RouteEntry {
pub fn is_valid(&self) -> bool {
Utc::now() < self.ttl
}
}
#[derive(Debug, Clone)]
pub struct HealthRow {
pub url: String,
pub ema_latency: f64,
pub consecutive_fails: i64,
pub total_queries: i64,
}
#[derive(Clone)]
pub struct Db {
pool: SqlitePool,
max_entries: i64,
}
impl Db {
pub async fn open(path: &str, max_entries: i64) -> Result<Self, DbError> {
if path != ":memory:"
&& let Some(parent) = Path::new(path).parent()
{
tokio::fs::create_dir_all(parent).await?;
}
let options = if path == ":memory:" {
SqliteConnectOptions::new().filename(path)
} else {
SqliteConnectOptions::new()
.filename(path)
.create_if_missing(true)
}
.journal_mode(SqliteJournalMode::Wal)
.busy_timeout(Duration::from_secs(5));
let pool = SqlitePoolOptions::new()
.max_connections(1)
.connect_with(options)
.await?;
migrate(&pool).await?;
Ok(Self { pool, max_entries })
}
pub async fn get_route(
&self,
store_path: &str,
) -> Result<Option<RouteEntry>, DbError> {
let row = sqlx::query(
r"SELECT store_path, upstream_url, latency_ms, latency_ema, query_count, failure_count,
last_verified, ttl, nar_hash, nar_size, nar_url
FROM routes WHERE store_path = ?",
)
.bind(store_path)
.fetch_optional(&self.pool)
.await?;
Ok(row.as_ref().map(row_to_route))
}
pub async fn get_route_by_nar_url(
&self,
nar_url: &str,
) -> Result<Option<RouteEntry>, DbError> {
let row = sqlx::query(
r"SELECT store_path, upstream_url, latency_ms, latency_ema, query_count, failure_count,
last_verified, ttl, nar_hash, nar_size, nar_url
FROM routes WHERE nar_url = ? AND ttl > ?",
)
.bind(nar_url)
.bind(Utc::now().timestamp())
.fetch_optional(&self.pool)
.await?;
Ok(row.as_ref().map(row_to_route))
}
pub async fn set_route(&self, entry: &RouteEntry) -> Result<(), DbError> {
sqlx::query(
r"INSERT INTO routes
(store_path, upstream_url, latency_ms, latency_ema, query_count, failure_count,
last_verified, ttl, nar_hash, nar_size, nar_url)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(store_path) DO UPDATE SET
upstream_url = excluded.upstream_url,
latency_ms = excluded.latency_ms,
latency_ema = excluded.latency_ema,
query_count = excluded.query_count,
failure_count = excluded.failure_count,
last_verified = excluded.last_verified,
ttl = excluded.ttl,
nar_hash = excluded.nar_hash,
nar_size = excluded.nar_size,
nar_url = excluded.nar_url",
)
.bind(&entry.store_path)
.bind(&entry.upstream_url)
.bind(entry.latency_ms)
.bind(entry.latency_ema)
.bind(i64::from(entry.query_count))
.bind(i64::from(entry.failure_count))
.bind(entry.last_verified.timestamp())
.bind(entry.ttl.timestamp())
.bind(&entry.nar_hash)
.bind(i64::try_from(entry.nar_size).unwrap_or(i64::MAX))
.bind(&entry.nar_url)
.execute(&self.pool)
.await?;
self.evict_if_needed().await
}
pub async fn expire_old_routes(&self) -> Result<(), DbError> {
sqlx::query("DELETE FROM routes WHERE ttl < ?")
.bind(Utc::now().timestamp())
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn list_recent_routes(
&self,
n: i64,
) -> Result<Vec<RouteEntry>, DbError> {
let rows = sqlx::query(
r"SELECT store_path, upstream_url, latency_ms, latency_ema, query_count, failure_count,
last_verified, ttl, nar_hash, nar_size, nar_url
FROM routes WHERE ttl > ? ORDER BY last_verified DESC LIMIT ?",
)
.bind(Utc::now().timestamp())
.bind(n)
.fetch_all(&self.pool)
.await?;
Ok(rows.iter().map(row_to_route).collect())
}
pub async fn route_count(&self) -> Result<i64, DbError> {
Ok(
sqlx::query("SELECT COUNT(*) FROM routes")
.fetch_one(&self.pool)
.await?
.get::<i64, _>(0),
)
}
pub async fn set_negative(
&self,
store_path: &str,
ttl: Duration,
) -> Result<(), DbError> {
sqlx::query(
r"INSERT INTO negative_cache (store_path, expires_at) VALUES (?, ?)
ON CONFLICT(store_path) DO UPDATE SET expires_at = excluded.expires_at",
)
.bind(store_path)
.bind((Utc::now() + chrono::Duration::from_std(ttl).unwrap_or_default()).timestamp())
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn is_negative(&self, store_path: &str) -> Result<bool, DbError> {
Ok(
sqlx::query(
"SELECT EXISTS(SELECT 1 FROM negative_cache WHERE store_path = ? AND \
expires_at > ?)",
)
.bind(store_path)
.bind(Utc::now().timestamp())
.fetch_one(&self.pool)
.await?
.get::<i64, _>(0)
!= 0,
)
}
pub async fn expire_negatives(&self) -> Result<(), DbError> {
sqlx::query("DELETE FROM negative_cache WHERE expires_at < ?")
.bind(Utc::now().timestamp())
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn save_health(
&self,
url: &str,
ema: f64,
consecutive_fails: i64,
total_queries: i64,
) -> Result<(), DbError> {
sqlx::query(
r"INSERT INTO upstream_health (url, ema_latency, consecutive_fails, total_queries)
VALUES (?, ?, ?, ?)
ON CONFLICT(url) DO UPDATE SET
ema_latency = excluded.ema_latency,
consecutive_fails = excluded.consecutive_fails,
total_queries = excluded.total_queries",
)
.bind(url)
.bind(ema)
.bind(consecutive_fails)
.bind(total_queries)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn load_all_health(&self) -> Result<Vec<HealthRow>, DbError> {
let rows = sqlx::query(
"SELECT url, ema_latency, consecutive_fails, total_queries FROM \
upstream_health",
)
.fetch_all(&self.pool)
.await?;
Ok(
rows
.into_iter()
.map(|row| {
HealthRow {
url: row.get("url"),
ema_latency: row.get("ema_latency"),
consecutive_fails: row.get("consecutive_fails"),
total_queries: row.get("total_queries"),
}
})
.collect(),
)
}
async fn evict_if_needed(&self) -> Result<(), DbError> {
sqlx::query(
r"DELETE FROM routes WHERE store_path IN (
SELECT store_path FROM routes ORDER BY last_verified ASC
LIMIT MAX(0, (SELECT COUNT(*) FROM routes) - ?)
)",
)
.bind(self.max_entries)
.execute(&self.pool)
.await?;
Ok(())
}
}
async fn migrate(pool: &SqlitePool) -> Result<(), DbError> {
sqlx::query(
r"CREATE TABLE IF NOT EXISTS routes (
store_path TEXT PRIMARY KEY,
upstream_url TEXT NOT NULL,
latency_ms REAL NOT NULL DEFAULT 0,
latency_ema REAL NOT NULL DEFAULT 0,
query_count INTEGER NOT NULL DEFAULT 1,
failure_count INTEGER NOT NULL DEFAULT 0,
last_verified INTEGER NOT NULL DEFAULT 0,
ttl INTEGER NOT NULL,
nar_hash TEXT NOT NULL DEFAULT '',
nar_size INTEGER NOT NULL DEFAULT 0,
nar_url TEXT NOT NULL DEFAULT '',
created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now'))
)",
)
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_routes_ttl ON routes(ttl)")
.execute(pool)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_routes_last_verified ON \
routes(last_verified)",
)
.execute(pool)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_routes_nar_url ON routes(nar_url)",
)
.execute(pool)
.await?;
sqlx::query(
r"CREATE TABLE IF NOT EXISTS upstream_health (
url TEXT PRIMARY KEY,
ema_latency REAL NOT NULL DEFAULT 0,
consecutive_fails INTEGER NOT NULL DEFAULT 0,
total_queries INTEGER NOT NULL DEFAULT 0
)",
)
.execute(pool)
.await?;
sqlx::query(
r"CREATE TABLE IF NOT EXISTS negative_cache (
store_path TEXT PRIMARY KEY,
expires_at INTEGER NOT NULL
)",
)
.execute(pool)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_negative_expires ON \
negative_cache(expires_at)",
)
.execute(pool)
.await?;
Ok(())
}
fn row_to_route(row: &sqlx::sqlite::SqliteRow) -> RouteEntry {
let query_count = row.get::<i64, _>("query_count");
let failure_count = row.get::<i64, _>("failure_count");
let nar_size = row.get::<i64, _>("nar_size");
RouteEntry {
store_path: row.get("store_path"),
upstream_url: row.get("upstream_url"),
latency_ms: row.get("latency_ms"),
latency_ema: row.get("latency_ema"),
query_count: u32::try_from(query_count).unwrap_or_default(),
failure_count: u32::try_from(failure_count).unwrap_or_default(),
last_verified: Utc
.timestamp_opt(row.get("last_verified"), 0)
.single()
.unwrap_or_else(Utc::now),
ttl: Utc
.timestamp_opt(row.get("ttl"), 0)
.single()
.unwrap_or_else(Utc::now),
nar_hash: row.get("nar_hash"),
nar_size: u64::try_from(nar_size).unwrap_or_default(),
nar_url: row.get("nar_url"),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn route_roundtrip_and_negative_cache() -> Result<(), DbError> {
let db = Db::open(":memory:", 100).await?;
let now = Utc::now();
let entry = RouteEntry {
store_path: "abc123".into(),
upstream_url: "https://cache.nixos.org".into(),
latency_ms: 10.0,
latency_ema: 10.0,
last_verified: now,
query_count: 1,
failure_count: 0,
ttl: now + chrono::Duration::hours(1),
nar_hash: "sha256:abc".into(),
nar_size: 42,
nar_url: "nar/abc.nar.xz".into(),
};
db.set_route(&entry).await?;
let got = db
.get_route("abc123")
.await?
.ok_or(sqlx::Error::RowNotFound)?;
assert_eq!(got.upstream_url, entry.upstream_url);
assert!(db.get_route_by_nar_url("nar/abc.nar.xz").await?.is_some());
db.set_negative("missing", Duration::from_secs(60)).await?;
assert!(db.is_negative("missing").await?);
Ok(())
}
}

74
src/discovery.rs Normal file
View file

@ -0,0 +1,74 @@
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use mdns_sd::{ServiceDaemon, ServiceEvent};
use tokio::sync::{Mutex, watch};
use crate::{config::DiscoveryConfig, health::Prober};
pub struct Discovery {
cfg: DiscoveryConfig,
prober: Prober,
daemon: ServiceDaemon,
peers: Arc<Mutex<HashMap<String, (String, Instant)>>>,
}
impl Discovery {
pub fn new(cfg: DiscoveryConfig, prober: Prober) -> anyhow::Result<Self> {
Ok(Self {
cfg,
prober,
daemon: ServiceDaemon::new()?,
peers: Arc::new(Mutex::new(HashMap::new())),
})
}
pub async fn run(
self,
mut stop: watch::Receiver<bool>,
) -> anyhow::Result<()> {
let service = format!(
"{}.{}.",
self.cfg.service_name.trim_end_matches('.'),
self.cfg.domain.trim_end_matches('.')
);
let receiver = self.daemon.browse(&service)?;
let peers = Arc::clone(&self.peers);
let prober = self.prober.clone();
let priority = self.cfg.priority;
let mut cleanup = tokio::time::interval(Duration::from_secs(10));
let expiration = if self.cfg.discovery_time.0.is_zero() {
Duration::from_secs(30)
} else {
self.cfg.discovery_time.0 * 3
};
loop {
tokio::select! {
_ = stop.changed() => { let _ = self.daemon.shutdown(); return Ok(()); }
_ = cleanup.tick() => {
let stale = {
let mut guard = peers.lock().await;
let now = Instant::now();
let stale = guard.iter().filter(|(_, (_, seen))| now.duration_since(*seen) > expiration).map(|(k, (u, _))| (k.clone(), u.clone())).collect::<Vec<_>>();
for (key, _) in &stale { guard.remove(key); }
stale
};
for (_, url) in stale { tracing::info!(url, "removing stale peer"); prober.remove_upstream(&url).await; }
}
event = tokio::task::spawn_blocking({ let receiver = receiver.clone(); move || receiver.recv_timeout(Duration::from_millis(500)).ok() }) => {
if let Ok(Some(ServiceEvent::ServiceResolved(info))) = event {
let Some(addr) = info.get_addresses().iter().next().map(mdns_sd::ScopedIp::to_ip_addr) else { continue; };
let url = format!("http://{}", std::net::SocketAddr::new(addr, info.get_port()));
let key = info.get_fullname().to_string();
let is_new = peers.lock().await.insert(key, (url.clone(), Instant::now())).is_none();
if is_new { tracing::info!(url, "discovered nix-serve instance"); prober.add_upstream(url, priority).await; }
}
}
}
}
}
}

310
src/health.rs Normal file
View file

@ -0,0 +1,310 @@
use std::{
cmp::Ordering,
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::RwLock;
use crate::config::UpstreamConfig;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
Active,
Degraded,
Down,
}
impl Status {
pub const fn as_str(self) -> &'static str {
match self {
Self::Active => "ACTIVE",
Self::Degraded => "DEGRADED",
Self::Down => "DOWN",
}
}
}
#[derive(Debug, Clone)]
pub struct UpstreamHealth {
pub url: String,
pub priority: i32,
pub ema_latency: f64,
pub last_probe: Option<Instant>,
pub consecutive_fails: u32,
pub total_queries: u64,
pub status: Status,
}
impl UpstreamHealth {
const fn new(url: String, priority: i32) -> Self {
Self {
url,
priority,
ema_latency: 0.0,
last_probe: None,
consecutive_fails: 0,
total_queries: 0,
status: Status::Active,
}
}
}
type PersistHealth = Arc<dyn Fn(String, f64, u32, u64) + Send + Sync>;
#[derive(Clone)]
pub struct Prober {
inner: Arc<ProberInner>,
}
struct ProberInner {
alpha: f64,
table: RwLock<HashMap<String, UpstreamHealth>>,
client: reqwest::Client,
persist_health: RwLock<Option<PersistHealth>>,
}
impl Prober {
pub fn new(alpha: f64) -> Self {
Self {
inner: Arc::new(ProberInner {
alpha,
table: RwLock::new(HashMap::new()),
client: reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.unwrap_or_else(|_| reqwest::Client::new()),
persist_health: RwLock::new(None),
}),
}
}
pub async fn init_upstreams(&self, upstreams: &[UpstreamConfig]) {
let mut table = self.inner.table.write().await;
for upstream in upstreams {
table.entry(upstream.url.clone()).or_insert_with(|| {
UpstreamHealth::new(upstream.url.clone(), upstream.priority)
});
}
}
#[allow(clippy::significant_drop_tightening)]
pub async fn seed(
&self,
url: &str,
ema_latency: f64,
consecutive_fails: i64,
total_queries: i64,
) {
{
let mut table = self.inner.table.write().await;
let Some(health) = table.get_mut(url) else {
return;
};
health.ema_latency = ema_latency;
health.total_queries =
u64::try_from(total_queries.max(0)).unwrap_or_default();
health.consecutive_fails =
u32::try_from(consecutive_fails.max(0)).unwrap_or(u32::MAX);
health.status = compute_status(health.consecutive_fails);
}
}
pub async fn set_health_persistence<F>(&self, f: F)
where
F: Fn(String, f64, u32, u64) + Send + Sync + 'static,
{
*self.inner.persist_health.write().await = Some(Arc::new(f));
}
#[allow(clippy::significant_drop_tightening)]
pub async fn record_latency(&self, url: &str, ms: f64) {
let snapshot = {
let mut table = self.inner.table.write().await;
let Some(health) = table.get_mut(url) else {
return;
};
if health.total_queries == 0 {
health.ema_latency = ms;
} else {
health.ema_latency = self
.inner
.alpha
.mul_add(ms, (1.0 - self.inner.alpha) * health.ema_latency);
}
health.consecutive_fails = 0;
health.total_queries += 1;
health.status = Status::Active;
health.last_probe = Some(Instant::now());
(
health.url.clone(),
health.ema_latency,
health.consecutive_fails,
health.total_queries,
)
};
let callback = self.inner.persist_health.read().await.clone();
if let Some(callback) = callback {
tokio::spawn(async move {
callback(snapshot.0, snapshot.1, snapshot.2, snapshot.3);
});
}
}
#[allow(clippy::significant_drop_tightening)]
pub async fn record_failure(&self, url: &str) {
let snapshot = {
let mut table = self.inner.table.write().await;
let Some(health) = table.get_mut(url) else {
return;
};
health.consecutive_fails += 1;
health.status = compute_status(health.consecutive_fails);
(
health.url.clone(),
health.ema_latency,
health.consecutive_fails,
health.total_queries,
)
};
let callback = self.inner.persist_health.read().await.clone();
if let Some(callback) = callback {
tokio::spawn(async move {
callback(snapshot.0, snapshot.1, snapshot.2, snapshot.3);
});
}
}
pub async fn get_health(&self, url: &str) -> Option<UpstreamHealth> {
self.inner.table.read().await.get(url).cloned()
}
pub async fn sorted_by_latency(&self) -> Vec<UpstreamHealth> {
let mut result = self
.inner
.table
.read()
.await
.values()
.cloned()
.collect::<Vec<_>>();
result.sort_by(|a, b| {
match (a.status == Status::Down, b.status == Status::Down) {
(true, false) => return Ordering::Greater,
(false, true) => return Ordering::Less,
_ => {},
}
if b.ema_latency > 0.0
&& ((a.ema_latency - b.ema_latency).abs() / b.ema_latency) < 0.10
&& a.priority != b.priority
{
return a.priority.cmp(&b.priority);
}
a.ema_latency
.partial_cmp(&b.ema_latency)
.unwrap_or(Ordering::Equal)
});
result
}
pub async fn probe_upstream(&self, url: String) {
if !self.inner.table.read().await.contains_key(&url) {
return;
}
let start = Instant::now();
let ok = self
.inner
.client
.head(format!("{url}/nix-cache-info"))
.send()
.await
.map(|resp| resp.status().as_u16() == 200)
.unwrap_or(false);
if ok {
self
.record_latency(&url, start.elapsed().as_secs_f64() * 1000.0)
.await;
} else {
self.record_failure(&url).await;
}
}
pub async fn run_probe_loop(
&self,
interval: Duration,
mut stop: tokio::sync::watch::Receiver<bool>,
) {
let mut ticker = tokio::time::interval(interval);
loop {
tokio::select! {
_ = stop.changed() => return,
_ = ticker.tick() => {
let urls = self.inner.table.read().await.keys().cloned().collect::<Vec<_>>();
for url in urls {
let prober = self.clone();
tokio::spawn(async move { prober.probe_upstream(url).await; });
}
}
}
}
}
pub async fn add_upstream(&self, url: String, priority: i32) {
let inserted = self
.inner
.table
.write()
.await
.insert(url.clone(), UpstreamHealth::new(url.clone(), priority))
.is_none();
if inserted {
let prober = self.clone();
tokio::spawn(async move {
prober.probe_upstream(url).await;
});
}
}
pub async fn remove_upstream(&self, url: &str) {
self.inner.table.write().await.remove(url);
}
}
const fn compute_status(consecutive_fails: u32) -> Status {
match consecutive_fails {
10.. => Status::Down,
3.. => Status::Degraded,
_ => Status::Active,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn ema_and_status_progression() -> Result<(), Box<dyn std::error::Error>>
{
let p = Prober::new(0.3);
p.add_upstream("https://example.com".into(), 1).await;
p.record_latency("https://example.com", 100.0).await;
p.record_latency("https://example.com", 50.0).await;
let h = p
.get_health("https://example.com")
.await
.ok_or("missing health")?;
assert!((84.0..=86.0).contains(&h.ema_latency));
for _ in 0..10 {
p.record_failure("https://example.com").await;
}
assert_eq!(
p.get_health("https://example.com")
.await
.ok_or("missing health")?
.status,
Status::Down
);
Ok(())
}
}

15
src/main.rs Normal file
View file

@ -0,0 +1,15 @@
mod cli;
mod config;
mod db;
mod discovery;
mod health;
mod mesh;
mod metrics;
mod narinfo;
mod router;
mod server;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
cli::run().await
}

223
src/mesh.rs Normal file
View file

@ -0,0 +1,223 @@
use std::{path::Path, sync::Arc};
use chrono::Utc;
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::{net::UdpSocket, time::Duration};
use crate::db::{Db, RouteEntry};
const MAX_PACKET_SIZE: usize = 65_536;
const HEADER_SIZE: usize = 96;
type DecodedPacket<'a> = (&'a [u8], &'a [u8], &'a [u8], Message);
#[derive(Debug, Error)]
pub enum MeshError {
#[error("io: {0}")]
Io(#[from] std::io::Error),
#[error("msgpack: {0}")]
Encode(#[from] rmp_serde::encode::Error),
#[error("decode msgpack: {0}")]
Decode(#[from] rmp_serde::decode::Error),
#[error("packet too short: {0} bytes")]
PacketTooShort(usize),
#[error("invalid signature")]
InvalidSignature,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum MsgType {
Announce = 1,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub r#type: MsgType,
pub node_id: String,
pub timestamp: i64,
pub routes: Vec<RouteEntry>,
}
#[derive(Clone)]
pub struct Node {
signing_key: Arc<SigningKey>,
}
impl Node {
pub async fn new(key_path: &str) -> Result<Self, MeshError> {
if key_path.is_empty() {
return Ok(Self {
signing_key: Arc::new(SigningKey::generate(&mut OsRng)),
});
}
if let Ok(data) = tokio::fs::read(key_path).await
&& (data.len() == 32 || data.len() == 64)
{
let Ok(bytes) = <[u8; 32]>::try_from(&data[..32]) else {
return Err(MeshError::InvalidSignature);
};
return Ok(Self {
signing_key: Arc::new(SigningKey::from_bytes(&bytes)),
});
}
if let Some(parent) = Path::new(key_path).parent() {
tokio::fs::create_dir_all(parent).await?;
}
let key = SigningKey::generate(&mut OsRng);
tokio::fs::write(key_path, key.to_bytes()).await?;
Ok(Self {
signing_key: Arc::new(key),
})
}
pub fn id(&self) -> String {
hex::encode(&self.public_key()[..8])
}
pub fn public_key(&self) -> [u8; 32] {
self.signing_key.verifying_key().to_bytes()
}
pub fn sign(&self, msg: &Message) -> Result<(Vec<u8>, Vec<u8>), MeshError> {
let body = rmp_serde::to_vec(msg)?;
Ok((
body.clone(),
self.signing_key.sign(&body).to_bytes().to_vec(),
))
}
}
pub fn verify(pubkey: &[u8], body: &[u8], sig: &[u8]) -> Result<(), MeshError> {
let pubkey: [u8; 32] =
pubkey.try_into().map_err(|_| MeshError::InvalidSignature)?;
let sig: [u8; 64] =
sig.try_into().map_err(|_| MeshError::InvalidSignature)?;
VerifyingKey::from_bytes(&pubkey)
.map_err(|_| MeshError::InvalidSignature)?
.verify(body, &Signature::from_bytes(&sig))
.map_err(|_| MeshError::InvalidSignature)
}
pub async fn listen_and_serve(
addr: &str,
db: Db,
allowed_keys: Vec<[u8; 32]>,
stop: tokio::sync::watch::Receiver<bool>,
) -> Result<(), MeshError> {
let socket = UdpSocket::bind(addr).await?;
tokio::spawn(async move {
let mut stop = stop;
let mut buf = vec![0; MAX_PACKET_SIZE];
loop {
tokio::select! {
_ = stop.changed() => return,
recv = socket.recv_from(&mut buf) => {
let Ok((n, src)) = recv else { return; };
match decode_packet(&buf[..n]) {
Ok((pubkey, sig, body, msg)) => {
if !allowed_keys.is_empty() && !allowed_keys.iter().any(|k| k.as_slice() == pubkey) {
tracing::warn!(?src, "mesh: rejecting packet from unknown sender");
continue;
}
if let Err(err) = verify(pubkey, body, sig) {
tracing::warn!(?src, error = %err, "mesh: signature verification failed");
continue;
}
if msg.r#type == MsgType::Announce && !msg.routes.is_empty() {
merge_routes(&db, msg.routes).await;
}
}
Err(err) => tracing::warn!(?src, error = %err, "mesh: malformed packet"),
}
}
}
}
});
Ok(())
}
async fn merge_routes(db: &Db, incoming: Vec<RouteEntry>) {
let now = Utc::now();
for route in incoming.into_iter().filter(|route| route.ttl > now) {
let should_set = match db.get_route(&route.store_path).await {
Ok(Some(existing)) if route.latency_ema > existing.latency_ema => false,
Ok(Some(existing))
if route.latency_ema.total_cmp(&existing.latency_ema).is_eq()
&& route.last_verified <= existing.last_verified =>
{
false
},
Ok(_) => true,
Err(err) => {
tracing::warn!(error = %err, store = route.store_path, "mesh: route lookup failed");
false
},
};
if should_set && let Err(err) = db.set_route(&route).await {
tracing::warn!(error = %err, store = route.store_path, "mesh: route merge failed");
}
}
}
pub async fn announce(
peer_addr: &str,
node: &Node,
routes: Vec<RouteEntry>,
) -> Result<(), MeshError> {
let msg = Message {
r#type: MsgType::Announce,
node_id: node.id(),
timestamp: Utc::now().timestamp_nanos_opt().unwrap_or_default(),
routes,
};
let packet = encode_packet(node, &msg)?;
let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.send_to(&packet, peer_addr).await?;
Ok(())
}
pub async fn run_gossip_loop(
node: Node,
db: Db,
peers: Vec<String>,
interval: Duration,
mut stop: tokio::sync::watch::Receiver<bool>,
) {
let mut ticker = tokio::time::interval(interval);
loop {
tokio::select! {
_ = stop.changed() => return,
_ = ticker.tick() => {
let Ok(routes) = db.list_recent_routes(100).await else { continue; };
if routes.is_empty() { continue; }
for peer in &peers {
let peer = peer.clone();
let node = node.clone();
let routes = routes.clone();
tokio::spawn(async move { let _ = announce(&peer, &node, routes).await; });
}
}
}
}
}
fn encode_packet(node: &Node, msg: &Message) -> Result<Vec<u8>, MeshError> {
let (body, sig) = node.sign(msg)?;
let mut packet = Vec::with_capacity(HEADER_SIZE + body.len());
packet.extend_from_slice(&node.public_key());
packet.extend_from_slice(&sig);
packet.extend_from_slice(&body);
Ok(packet)
}
fn decode_packet(packet: &[u8]) -> Result<DecodedPacket<'_>, MeshError> {
if packet.len() < HEADER_SIZE {
return Err(MeshError::PacketTooShort(packet.len()));
}
let pubkey = &packet[..32];
let sig = &packet[32..HEADER_SIZE];
let body = &packet[HEADER_SIZE..];
let msg = rmp_serde::from_slice(body)?;
Ok((pubkey, sig, body, msg))
}

109
src/metrics.rs Normal file
View file

@ -0,0 +1,109 @@
use std::sync::OnceLock;
use prometheus::{
Encoder,
HistogramOpts,
HistogramVec,
IntCounter,
IntCounterVec,
IntGauge,
Opts,
Registry,
TextEncoder,
};
pub struct Metrics {
registry: Registry,
pub narinfo_cache_hits: IntCounter,
pub narinfo_cache_misses: IntCounter,
pub narinfo_requests: IntCounterVec,
pub nar_requests: IntCounter,
pub upstream_race_wins: IntCounterVec,
pub route_entries: IntGauge,
pub upstream_latency: HistogramVec,
}
static METRICS: OnceLock<Metrics> = OnceLock::new();
#[expect(
clippy::expect_used,
reason = "metric names and labels are static constants validated during \
startup"
)]
pub fn get() -> &'static Metrics {
METRICS.get_or_init(|| {
let registry = Registry::new();
let narinfo_cache_hits = IntCounter::new(
"ncro_narinfo_cache_hits_total",
"Narinfo requests served from route cache.",
)
.expect("valid metric");
let narinfo_cache_misses = IntCounter::new(
"ncro_narinfo_cache_misses_total",
"Narinfo requests requiring upstream race.",
)
.expect("valid metric");
let narinfo_requests = IntCounterVec::new(
Opts::new("ncro_narinfo_requests_total", "Narinfo requests by status."),
&["status"],
)
.expect("valid metric");
let nar_requests =
IntCounter::new("ncro_nar_requests_total", "NAR streaming requests.")
.expect("valid metric");
let upstream_race_wins = IntCounterVec::new(
Opts::new(
"ncro_upstream_race_wins_total",
"Times each upstream won the narinfo race.",
),
&["upstream"],
)
.expect("valid metric");
let route_entries = IntGauge::new(
"ncro_route_entries",
"Current number of route entries in SQLite.",
)
.expect("valid metric");
let upstream_latency = HistogramVec::new(
HistogramOpts::new(
"ncro_upstream_latency_seconds",
"Upstream narinfo race latency.",
),
&["upstream"],
)
.expect("valid metric");
for collector in [
Box::new(narinfo_cache_hits.clone())
as Box<dyn prometheus::core::Collector>,
Box::new(narinfo_cache_misses.clone()),
Box::new(narinfo_requests.clone()),
Box::new(nar_requests.clone()),
Box::new(upstream_race_wins.clone()),
Box::new(route_entries.clone()),
Box::new(upstream_latency.clone()),
] {
registry.register(collector).expect("register metric");
}
Metrics {
registry,
narinfo_cache_hits,
narinfo_cache_misses,
narinfo_requests,
nar_requests,
upstream_race_wins,
route_entries,
upstream_latency,
}
})
}
pub fn gather() -> String {
let mut buf = Vec::new();
let encoder = TextEncoder::new();
if encoder.encode(&get().registry.gather(), &mut buf).is_err() {
return String::new();
}
String::from_utf8_lossy(&buf).into_owned()
}

207
src/narinfo.rs Normal file
View file

@ -0,0 +1,207 @@
use std::io::{BufRead, BufReader, Read};
use base64::{Engine, engine::general_purpose::STANDARD};
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum NarInfoError {
#[error("read narinfo: {0}")]
Io(#[from] std::io::Error),
#[error("malformed line: {0:?}")]
MalformedLine(String),
#[error("missing StorePath")]
MissingStorePath,
#[error("{field}: {source}")]
ParseInt {
field: &'static str,
source: std::num::ParseIntError,
},
#[error("invalid public key {input:?}: missing ':'")]
MissingPublicKeySeparator { input: String },
#[error("invalid public key {input:?}: {source}")]
InvalidPublicKeyBase64 {
input: String,
source: base64::DecodeError,
},
#[error("invalid public key size {got}, want 32")]
InvalidPublicKeySize { got: usize },
}
#[cfg(test)]
mod tests {
use ed25519_dalek::{Signer, SigningKey};
use rand::rngs::OsRng;
use super::*;
#[test]
fn parses_realistic_narinfo() -> Result<(), NarInfoError> {
let input = "StorePath: /nix/store/abc-hello\nURL: \
nar/abc.nar.xz\nCompression: xz\nFileSize: 42\nNarHash: \
sha256:abc\nNarSize: 123\nReferences: abc-hello dep\nSig: \
key:sig=\n";
let ni = NarInfo::parse(input.as_bytes())?;
assert_eq!(ni.store_path, "/nix/store/abc-hello");
assert_eq!(ni.url, "nar/abc.nar.xz");
assert_eq!(ni.references.len(), 2);
Ok(())
}
#[test]
fn verifies_roundtrip_signature() -> Result<(), NarInfoError> {
let signing = SigningKey::generate(&mut OsRng);
let mut ni = NarInfo {
store_path: "/nix/store/abc-test".into(),
nar_hash: "sha256:abc".into(),
nar_size: 12,
references: vec!["abc-test".into()],
..Default::default()
};
let sig = signing.sign(ni.fingerprint().as_bytes());
let pubkey = format!(
"test:{}",
STANDARD.encode(signing.verifying_key().to_bytes())
);
ni.sig = vec![format!("test:{}", STANDARD.encode(sig.to_bytes()))];
assert!(ni.verify(&pubkey)?);
Ok(())
}
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NarInfo {
pub store_path: String,
pub url: String,
pub compression: String,
pub file_hash: String,
pub file_size: u64,
pub nar_hash: String,
pub nar_size: u64,
pub references: Vec<String>,
pub deriver: String,
pub sig: Vec<String>,
pub ca: String,
}
pub fn parse_public_key(
input: &str,
) -> Result<(String, VerifyingKey), NarInfoError> {
let (name, b64) = input.split_once(':').ok_or_else(|| {
NarInfoError::MissingPublicKeySeparator {
input: input.to_string(),
}
})?;
if name.is_empty() {
return Err(NarInfoError::MissingPublicKeySeparator {
input: input.to_string(),
});
}
let raw = STANDARD.decode(b64).map_err(|source| {
NarInfoError::InvalidPublicKeyBase64 {
input: input.to_string(),
source,
}
})?;
let bytes: [u8; 32] = raw.try_into().map_err(|raw: Vec<u8>| {
NarInfoError::InvalidPublicKeySize { got: raw.len() }
})?;
let key = VerifyingKey::from_bytes(&bytes)
.map_err(|_| NarInfoError::InvalidPublicKeySize { got: bytes.len() })?;
Ok((name.to_string(), key))
}
impl NarInfo {
pub fn parse(reader: impl Read) -> Result<Self, NarInfoError> {
let mut narinfo = Self::default();
for line in BufReader::new(reader).lines() {
let line = line?;
if line.is_empty() {
continue;
}
let (key, value) = line
.split_once(": ")
.ok_or_else(|| NarInfoError::MalformedLine(line.clone()))?;
match key {
"StorePath" => narinfo.store_path = value.to_string(),
"URL" => narinfo.url = value.to_string(),
"Compression" => narinfo.compression = value.to_string(),
"FileHash" => narinfo.file_hash = value.to_string(),
"FileSize" => {
narinfo.file_size = value.parse().map_err(|source| {
NarInfoError::ParseInt {
field: "FileSize",
source,
}
})?;
},
"NarHash" => narinfo.nar_hash = value.to_string(),
"NarSize" => {
narinfo.nar_size = value.parse().map_err(|source| {
NarInfoError::ParseInt {
field: "NarSize",
source,
}
})?;
},
"References" => {
if !value.is_empty() {
narinfo.references =
value.split_whitespace().map(str::to_string).collect();
}
},
"Deriver" => narinfo.deriver = value.to_string(),
"Sig" => narinfo.sig.push(value.to_string()),
"CA" => narinfo.ca = value.to_string(),
_ => {},
}
}
if narinfo.store_path.is_empty() {
return Err(NarInfoError::MissingStorePath);
}
Ok(narinfo)
}
pub fn fingerprint(&self) -> String {
let refs = self
.references
.iter()
.map(|reference| {
if reference.starts_with("/nix/store/") {
reference.clone()
} else {
format!("/nix/store/{reference}")
}
})
.collect::<Vec<_>>()
.join(",");
format!(
"1;{};{};{};{}",
self.store_path, self.nar_hash, self.nar_size, refs
)
}
pub fn verify(&self, public_key: &str) -> Result<bool, NarInfoError> {
let (key_name, key) = parse_public_key(public_key)?;
let fingerprint = self.fingerprint();
for sig_line in &self.sig {
let Some((name, b64)) = sig_line.split_once(':') else {
continue;
};
if name != key_name {
continue;
}
let Ok(raw) = STANDARD.decode(b64) else {
continue;
};
let Ok(bytes) = <[u8; 64]>::try_from(raw.as_slice()) else {
continue;
};
let signature = Signature::from_bytes(&bytes);
if key.verify(fingerprint.as_bytes(), &signature).is_ok() {
return Ok(true);
}
}
Ok(false)
}
}

309
src/router.rs Normal file
View file

@ -0,0 +1,309 @@
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use chrono::Utc;
use futures_util::{StreamExt, stream::FuturesUnordered};
use thiserror::Error;
use tokio::sync::{Mutex, RwLock};
use crate::{
db::{Db, RouteEntry},
health::{Prober, Status},
metrics,
narinfo::NarInfo,
};
#[derive(Debug, Error)]
pub enum RouterError {
#[error("not found in any upstream")]
NotFound,
#[error("all upstreams unavailable")]
UpstreamUnavailable,
#[error("no candidates for {0:?}")]
NoCandidates(String),
#[error("narinfo signature verification failed")]
SignatureVerificationFailed,
#[error(transparent)]
Db(#[from] crate::db::DbError),
}
#[derive(Debug, Clone)]
pub struct ResolveResult {
pub url: String,
pub latency_ms: f64,
pub cache_hit: bool,
pub narinfo_bytes: Option<Vec<u8>>,
}
#[derive(Clone)]
pub struct Router {
inner: Arc<RouterInner>,
}
struct RouterInner {
db: Db,
prober: Prober,
route_ttl: Duration,
race_timeout: Duration,
negative_ttl: Duration,
client: reqwest::Client,
upstream_keys: RwLock<HashMap<String, String>>,
inflight: Mutex<HashMap<String, Arc<Mutex<()>>>>,
}
#[derive(Debug)]
struct RaceResult {
url: String,
latency_ms: f64,
}
impl Router {
pub fn new(
db: Db,
prober: Prober,
route_ttl: Duration,
race_timeout: Duration,
negative_ttl: Duration,
) -> Self {
Self {
inner: Arc::new(RouterInner {
db,
prober,
route_ttl,
race_timeout,
negative_ttl,
client: reqwest::Client::builder()
.timeout(race_timeout)
.build()
.unwrap_or_else(|_| reqwest::Client::new()),
upstream_keys: RwLock::new(HashMap::new()),
inflight: Mutex::new(HashMap::new()),
}),
}
}
pub async fn set_upstream_key(
&self,
url: String,
public_key: String,
) -> Result<(), crate::narinfo::NarInfoError> {
crate::narinfo::parse_public_key(&public_key)?;
self
.inner
.upstream_keys
.write()
.await
.insert(url, public_key);
Ok(())
}
pub async fn resolve(
&self,
store_hash: &str,
candidates: &[String],
) -> Result<ResolveResult, RouterError> {
if self.inner.db.is_negative(store_hash).await? {
return Err(RouterError::NotFound);
}
if let Some(result) = self.valid_cached_route(store_hash).await? {
return Ok(result);
}
metrics::get().narinfo_cache_misses.inc();
let lock = {
let mut inflight = self.inner.inflight.lock().await;
Arc::clone(
inflight
.entry(store_hash.to_string())
.or_insert_with(|| Arc::new(Mutex::new(()))),
)
};
let _guard = lock.lock().await;
if let Some(result) = self.valid_cached_route(store_hash).await? {
self.inner.inflight.lock().await.remove(store_hash);
return Ok(result);
}
let result = self.race(store_hash, candidates).await;
if matches!(result, Err(RouterError::NotFound)) {
let _ = self
.inner
.db
.set_negative(store_hash, self.inner.negative_ttl)
.await;
}
self.inner.inflight.lock().await.remove(store_hash);
result
}
async fn valid_cached_route(
&self,
store_hash: &str,
) -> Result<Option<ResolveResult>, RouterError> {
let Some(entry) = self.inner.db.get_route(store_hash).await? else {
return Ok(None);
};
if !entry.is_valid() {
return Ok(None);
}
let health = self.inner.prober.get_health(&entry.upstream_url).await;
if !health.as_ref().is_none_or(|h| h.status == Status::Active) {
return Ok(None);
}
metrics::get().narinfo_cache_hits.inc();
Ok(Some(ResolveResult {
url: entry.upstream_url,
latency_ms: entry.latency_ema,
cache_hit: true,
narinfo_bytes: None,
}))
}
async fn race(
&self,
store_hash: &str,
candidates: &[String],
) -> Result<ResolveResult, RouterError> {
if candidates.is_empty() {
return Err(RouterError::NoCandidates(store_hash.to_string()));
}
let mut handles = FuturesUnordered::new();
for upstream in candidates {
let upstream = upstream.clone();
let store_hash = store_hash.to_string();
let client = self.inner.client.clone();
handles.push(tokio::spawn(async move {
let start = Instant::now();
let res = client
.head(format!("{upstream}/{store_hash}.narinfo"))
.send()
.await;
match res {
Ok(resp) if resp.status().is_success() => {
Ok(RaceResult {
url: upstream,
latency_ms: start.elapsed().as_secs_f64() * 1000.0,
})
},
Ok(_) => Err(false),
Err(_) => Err(true),
}
}));
}
let mut net_errs = 0;
let mut not_founds = 0;
let mut winner: Option<RaceResult> = None;
let deadline = tokio::time::sleep(self.inner.race_timeout);
tokio::pin!(deadline);
while !handles.is_empty() {
tokio::select! {
() = &mut deadline => break,
joined = handles.next() => {
match joined {
Some(Ok(Ok(res))) => if winner.as_ref().is_none_or(|w| res.latency_ms < w.latency_ms) { winner = Some(res); },
Some(Ok(Err(true)) | Err(_)) => net_errs += 1,
Some(Ok(Err(false))) => not_founds += 1,
None => break,
}
}
}
}
let Some(winner) = winner else {
return if net_errs > 0 && not_founds == 0 {
Err(RouterError::UpstreamUnavailable)
} else {
Err(RouterError::NotFound)
};
};
metrics::get()
.upstream_race_wins
.with_label_values(&[&winner.url])
.inc();
metrics::get()
.upstream_latency
.with_label_values(&[&winner.url])
.observe(winner.latency_ms / 1000.0);
let (body, nar_url, nar_hash, nar_size) =
self.fetch_narinfo(&winner.url, store_hash).await?;
let ema = self
.inner
.prober
.get_health(&winner.url)
.await
.map_or(winner.latency_ms, |h| {
0.3f64.mul_add(winner.latency_ms, 0.7 * h.ema_latency)
});
self
.inner
.prober
.record_latency(&winner.url, winner.latency_ms)
.await;
let now = Utc::now();
self
.inner
.db
.set_route(&RouteEntry {
store_path: store_hash.to_string(),
upstream_url: winner.url.clone(),
latency_ms: winner.latency_ms,
latency_ema: ema,
last_verified: now,
query_count: 1,
failure_count: 0,
ttl: now
+ chrono::Duration::from_std(self.inner.route_ttl)
.unwrap_or_default(),
nar_hash,
nar_size,
nar_url,
})
.await?;
Ok(ResolveResult {
url: winner.url,
latency_ms: winner.latency_ms,
cache_hit: false,
narinfo_bytes: body,
})
}
async fn fetch_narinfo(
&self,
upstream: &str,
store_hash: &str,
) -> Result<(Option<Vec<u8>>, String, String, u64), RouterError> {
let Ok(resp) = self
.inner
.client
.get(format!("{upstream}/{store_hash}.narinfo"))
.send()
.await
else {
return Ok((None, String::new(), String::new(), 0));
};
if !resp.status().is_success() {
return Ok((None, String::new(), String::new(), 0));
}
let Ok(bytes) = resp.bytes().await else {
return Ok((None, String::new(), String::new(), 0));
};
let body = bytes.to_vec();
let Ok(parsed) = NarInfo::parse(body.as_slice()) else {
return Ok((Some(body), String::new(), String::new(), 0));
};
if let Some(pubkey) = self.inner.upstream_keys.read().await.get(upstream)
&& !parsed.verify(pubkey).unwrap_or(false)
{
tracing::warn!(
upstream,
store = store_hash,
"narinfo signature verification failed"
);
return Err(RouterError::SignatureVerificationFailed);
}
Ok((Some(body), parsed.url, parsed.nar_hash, parsed.nar_size))
}
}

301
src/server.rs Normal file
View file

@ -0,0 +1,301 @@
use std::sync::Arc;
use axum::{
Router as AxumRouter,
body::Body,
extract::{Path, State},
http::{HeaderMap, HeaderName, HeaderValue, Method, Request, StatusCode},
response::{IntoResponse, Response},
routing::get,
};
use bytes::Bytes;
use futures_util::TryStreamExt;
use serde::Serialize;
use crate::{
config::UpstreamConfig,
db::Db,
health::{Prober, Status},
metrics,
router::{Router, RouterError},
};
#[derive(Clone)]
pub struct AppState {
router: Router,
prober: Prober,
db: Db,
upstreams: Vec<UpstreamConfig>,
client: reqwest::Client,
cache_priority: i32,
}
pub fn app(
router: Router,
prober: Prober,
db: Db,
upstreams: Vec<UpstreamConfig>,
cache_priority: i32,
) -> AxumRouter {
let state = AppState {
router,
prober,
db,
upstreams,
client: reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(60))
.build()
.unwrap_or_else(|_| reqwest::Client::new()),
cache_priority,
};
AxumRouter::new()
.route("/nix-cache-info", get(cache_info).head(cache_info))
.route("/health", get(health))
.route("/metrics", get(metrics_endpoint))
.route("/{hash}.narinfo", get(narinfo).head(narinfo))
.route("/nar/{*path}", get(nar).head(nar))
.with_state(Arc::new(state))
}
async fn cache_info(State(state): State<Arc<AppState>>) -> Response {
(
[("content-type", "text/plain")],
format!(
"StoreDir: /nix/store\nWantMassQuery: 1\nPriority: {}\n",
state.cache_priority
),
)
.into_response()
}
#[derive(Serialize)]
struct HealthResponse {
status: String,
upstreams: Vec<UpstreamStatus>,
}
#[derive(Serialize)]
struct UpstreamStatus {
url: String,
status: String,
latency_ms: f64,
consecutive_fails: u32,
}
async fn health(State(state): State<Arc<AppState>>) -> Response {
let sorted = state.prober.sorted_by_latency().await;
let down_count = sorted.iter().filter(|h| h.status == Status::Down).count();
let any_degraded = sorted.iter().any(|h| h.status == Status::Degraded);
let status = if !sorted.is_empty() && down_count == sorted.len() {
"down"
} else if down_count > 0 || any_degraded {
"degraded"
} else {
"ok"
};
axum::Json(HealthResponse {
status: status.to_string(),
upstreams: sorted
.into_iter()
.map(|h| {
UpstreamStatus {
url: h.url,
status: h.status.as_str().to_string(),
latency_ms: h.ema_latency,
consecutive_fails: h.consecutive_fails,
}
})
.collect(),
})
.into_response()
}
async fn metrics_endpoint() -> Response {
(
[("content-type", "text/plain; version=0.0.4")],
metrics::gather(),
)
.into_response()
}
async fn narinfo(
State(state): State<Arc<AppState>>,
Path(hash): Path<String>,
req: Request<Body>,
) -> Response {
let candidates = upstream_urls(&state).await;
match state.router.resolve(&hash, &candidates).await {
Ok(result) => {
tracing::info!(
hash,
upstream = result.url,
cache_hit = result.cache_hit,
latency_ms = result.latency_ms,
"narinfo routed"
);
metrics::get()
.narinfo_requests
.with_label_values(&["200"])
.inc();
if let Some(bytes) = result.narinfo_bytes {
return (
StatusCode::OK,
[("content-type", "text/x-nix-narinfo")],
Bytes::from(bytes),
)
.into_response();
}
proxy(
&state.client,
req.method().clone(),
req.headers(),
format!("{}{}", result.url, req.uri().path()),
)
.await
},
Err(RouterError::NotFound) => {
metrics::get()
.narinfo_requests
.with_label_values(&["error"])
.inc();
StatusCode::NOT_FOUND.into_response()
},
Err(err) => {
tracing::warn!(hash, error = %err, "narinfo resolve failed");
metrics::get()
.narinfo_requests
.with_label_values(&["error"])
.inc();
(StatusCode::BAD_GATEWAY, "upstream unavailable").into_response()
},
}
}
async fn nar(
State(state): State<Arc<AppState>>,
req: Request<Body>,
) -> Response {
metrics::get().nar_requests.inc();
let nar_url = req.uri().path().trim_start_matches('/').to_string();
if let Ok(Some(entry)) = state.db.get_route_by_nar_url(&nar_url).await
&& entry.is_valid()
&& let Some(resp) = try_nar_upstream(
&state.client,
req.method().clone(),
req.headers(),
&entry.upstream_url,
req.uri().path(),
)
.await
{
return resp;
}
for h in state.prober.sorted_by_latency().await {
if h.status == Status::Down {
continue;
}
if let Some(resp) = try_nar_upstream(
&state.client,
req.method().clone(),
req.headers(),
&h.url,
req.uri().path(),
)
.await
{
return resp;
}
}
StatusCode::NOT_FOUND.into_response()
}
async fn upstream_urls(state: &AppState) -> Vec<String> {
let urls = state
.prober
.sorted_by_latency()
.await
.into_iter()
.filter(|h| h.status != Status::Down)
.map(|h| h.url)
.collect::<Vec<_>>();
if urls.is_empty() {
state.upstreams.iter().map(|u| u.url.clone()).collect()
} else {
urls
}
}
async fn try_nar_upstream(
client: &reqwest::Client,
method: Method,
headers: &HeaderMap,
upstream: &str,
path: &str,
) -> Option<Response> {
let resp =
upstream_request(client, method, headers, format!("{upstream}{path}"))
.await
.ok()?;
if resp.status() == reqwest::StatusCode::NOT_FOUND {
return None;
}
Some(response_from_reqwest(resp))
}
async fn proxy(
client: &reqwest::Client,
method: Method,
headers: &HeaderMap,
url: String,
) -> Response {
match upstream_request(client, method, headers, url).await {
Ok(resp) => response_from_reqwest(resp),
Err(err) => {
tracing::warn!(error = %err, "upstream request failed");
(StatusCode::BAD_GATEWAY, "upstream error").into_response()
},
}
}
async fn upstream_request(
client: &reqwest::Client,
method: Method,
headers: &HeaderMap,
url: String,
) -> reqwest::Result<reqwest::Response> {
let mut req = client.request(method, url);
for name in ["accept", "accept-encoding", "range"] {
if let Some(value) = headers.get(name) {
req = req.header(name, value);
}
}
req.send().await
}
fn response_from_reqwest(resp: reqwest::Response) -> Response {
let status = StatusCode::from_u16(resp.status().as_u16())
.unwrap_or(StatusCode::BAD_GATEWAY);
let headers = resp.headers().clone();
let stream = resp.bytes_stream().map_err(std::io::Error::other);
let mut out = Response::builder().status(status);
for name in [
"content-type",
"content-length",
"content-encoding",
"x-nix-signature",
"cache-control",
"last-modified",
] {
if let Some(value) = headers.get(name)
&& let (Ok(header_name), Ok(header_value)) = (
HeaderName::from_bytes(name.as_bytes()),
HeaderValue::from_bytes(value.as_bytes()),
)
{
out = out.header(header_name, header_value);
}
}
out
.body(Body::from_stream(stream))
.unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
}