treewide: rewrite everything in Rust

Signed-off-by: NotAShelf <raf@notashelf.dev>
Change-Id: I786da853078e1013bb8f463ed9e9869c6a6a6964
This commit is contained in:
raf 2026-05-11 12:08:49 +03:00
commit ea96477830
Signed by: NotAShelf
GPG key ID: 29D95B64378DB4BF
43 changed files with 5993 additions and 4594 deletions

191
src/cli.rs Normal file
View file

@ -0,0 +1,191 @@
use clap::Parser;
use tokio::net::TcpListener;
use tracing_subscriber::{EnvFilter, fmt};
use crate::{
config::Config,
db::Db,
discovery::Discovery,
health::Prober,
mesh,
metrics,
router::Router,
server,
};
#[derive(Debug, Parser)]
#[command(name = "ncro", version, about = "Nix Cache Route Optimizer")]
pub struct Args {
#[arg(short, long, env = "NCRO_CONFIG")]
pub config: Option<String>,
}
pub async fn run() -> anyhow::Result<()> {
let args = Args::parse();
let cfg = Config::load(args.config.as_deref())?;
cfg.validate()?;
init_logging(&cfg.logging.level, &cfg.logging.format);
let _ = metrics::get();
let db = Db::open(&cfg.cache.db_path, cfg.cache.max_entries).await?;
let prober = Prober::new(cfg.cache.latency_alpha);
prober.init_upstreams(&cfg.upstreams).await;
for row in db.load_all_health().await.unwrap_or_default() {
prober
.seed(
&row.url,
row.ema_latency,
row.consecutive_fails,
row.total_queries,
)
.await;
}
let db_for_health = db.clone();
prober
.set_health_persistence(move |url, ema, fails, queries| {
let db = db_for_health.clone();
tokio::spawn(async move {
let _ = db
.save_health(
&url,
ema,
i64::from(fails),
i64::try_from(queries).unwrap_or(i64::MAX),
)
.await;
});
})
.await;
for upstream in &cfg.upstreams {
let prober = prober.clone();
let url = upstream.url.clone();
tokio::spawn(async move {
prober.probe_upstream(url).await;
});
}
let router = Router::new(
db.clone(),
prober.clone(),
cfg.cache.ttl.0,
std::time::Duration::from_secs(5),
cfg.cache.negative_ttl.0,
);
for upstream in &cfg.upstreams {
if !upstream.public_key.is_empty() {
router
.set_upstream_key(upstream.url.clone(), upstream.public_key.clone())
.await?;
}
}
let (stop_tx, stop_rx) = tokio::sync::watch::channel(false);
let probe_prober = prober.clone();
let probe_stop = stop_rx.clone();
tokio::spawn(async move {
probe_prober
.run_probe_loop(std::time::Duration::from_secs(30), probe_stop)
.await;
});
let db_for_expiry = db.clone();
let mut expiry_stop = stop_rx.clone();
tokio::spawn(async move {
let mut ticker = tokio::time::interval(std::time::Duration::from_secs(300));
loop {
tokio::select! {
_ = expiry_stop.changed() => return,
_ = ticker.tick() => {
let _ = db_for_expiry.expire_old_routes().await;
let _ = db_for_expiry.expire_negatives().await;
if let Ok(count) = db_for_expiry.route_count().await { metrics::get().route_entries.set(count); }
}
}
}
});
if cfg.discovery.enabled {
let discovery = Discovery::new(cfg.discovery.clone(), prober.clone())?;
let discovery_stop = stop_rx.clone();
tokio::spawn(async move {
let _ = discovery.run(discovery_stop).await;
});
}
if cfg.mesh.enabled {
let node = mesh::Node::new(&cfg.mesh.private_key_path).await?;
tracing::info!(
node_id = node.id(),
public_key = hex::encode(node.public_key()),
"mesh node identity"
);
let allowed = cfg
.mesh
.peers
.iter()
.filter_map(|p| hex::decode(&p.public_key).ok()?.try_into().ok())
.collect::<Vec<[u8; 32]>>();
mesh::listen_and_serve(
&cfg.mesh.bind_addr,
db.clone(),
allowed,
stop_rx.clone(),
)
.await?;
let peers = cfg
.mesh
.peers
.iter()
.map(|p| p.addr.clone())
.collect::<Vec<_>>();
tokio::spawn(mesh::run_gossip_loop(
node,
db.clone(),
peers,
cfg.mesh.gossip_interval.0,
stop_rx.clone(),
));
}
let app = server::app(
router,
prober,
db,
cfg.upstreams.clone(),
cfg.server.cache_priority,
);
let listener =
TcpListener::bind(normalize_listen(&cfg.server.listen)).await?;
tracing::info!(
addr = cfg.server.listen,
upstreams = cfg.upstreams.len(),
version = env!("CARGO_PKG_VERSION"),
"ncro listening"
);
let server = axum::serve(listener, app).with_graceful_shutdown(async move {
let _ = tokio::signal::ctrl_c().await;
});
let result = server.await;
let _ = stop_tx.send(true);
result?;
Ok(())
}
fn init_logging(level: &str, format_name: &str) {
let filter =
EnvFilter::try_new(level).unwrap_or_else(|_| EnvFilter::new("info"));
if format_name == "text" {
fmt().with_env_filter(filter).init();
} else {
fmt().json().with_env_filter(filter).init();
}
}
fn normalize_listen(listen: &str) -> String {
if listen.starts_with(':') {
format!("0.0.0.0{listen}")
} else {
listen.to_string()
}
}

336
src/config.rs Normal file
View file

@ -0,0 +1,336 @@
use std::{env, fs, time::Duration};
use serde::{Deserialize, Deserializer};
use thiserror::Error;
use url::Url;
#[derive(Debug, Error)]
pub enum ConfigError {
#[error("read config: {0}")]
Read(#[from] std::io::Error),
#[error("parse config: {0}")]
Parse(#[from] serde_yaml::Error),
#[error("{0}")]
Validation(String),
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn loads_defaults() -> Result<(), ConfigError> {
let cfg = Config::load(None)?;
assert_eq!(cfg.server.listen, ":8080");
assert_eq!(cfg.cache.max_entries, 100_000);
assert_eq!(cfg.upstreams.len(), 1);
cfg.validate()?;
Ok(())
}
#[test]
fn parses_duration_yaml() -> Result<(), serde_yaml::Error> {
let cfg: Config = serde_yaml::from_str(
"server:\n read_timeout: 30s\ncache:\n ttl: 2h\n",
)?;
assert_eq!(cfg.server.read_timeout.0, Duration::from_secs(30));
assert_eq!(cfg.cache.ttl.0, Duration::from_secs(7200));
Ok(())
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HumanDuration(pub Duration);
impl Default for HumanDuration {
fn default() -> Self {
Self(Duration::ZERO)
}
}
impl<'de> Deserialize<'de> for HumanDuration {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
humantime_serde::deserialize(deserializer).map(Self)
}
}
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct UpstreamConfig {
pub url: String,
pub priority: i32,
pub public_key: String,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct ServerConfig {
pub listen: String,
pub read_timeout: HumanDuration,
pub write_timeout: HumanDuration,
pub cache_priority: i32,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
listen: ":8080".to_string(),
read_timeout: HumanDuration(Duration::from_secs(30)),
write_timeout: HumanDuration(Duration::from_secs(30)),
cache_priority: 30,
}
}
}
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct CacheConfig {
pub db_path: String,
pub max_entries: i64,
pub ttl: HumanDuration,
pub negative_ttl: HumanDuration,
pub latency_alpha: f64,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
db_path: "/var/lib/ncro/routes.db".to_string(),
max_entries: 100_000,
ttl: HumanDuration(Duration::from_secs(60 * 60)),
negative_ttl: HumanDuration(Duration::from_secs(10 * 60)),
latency_alpha: 0.3,
}
}
}
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct PeerConfig {
pub addr: String,
pub public_key: String,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct MeshConfig {
pub enabled: bool,
pub bind_addr: String,
pub peers: Vec<PeerConfig>,
#[serde(rename = "private_key")]
pub private_key_path: String,
pub gossip_interval: HumanDuration,
}
impl Default for MeshConfig {
fn default() -> Self {
Self {
enabled: false,
bind_addr: "0.0.0.0:7946".to_string(),
peers: Vec::new(),
private_key_path: String::new(),
gossip_interval: HumanDuration(Duration::from_secs(30)),
}
}
}
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct DiscoveryConfig {
pub enabled: bool,
pub service_name: String,
pub domain: String,
pub discovery_time: HumanDuration,
pub priority: i32,
}
impl Default for DiscoveryConfig {
fn default() -> Self {
Self {
enabled: false,
service_name: "_nix-serve._tcp".to_string(),
domain: "local".to_string(),
discovery_time: HumanDuration(Duration::from_secs(5)),
priority: 20,
}
}
}
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct LoggingConfig {
pub level: String,
pub format: String,
}
impl Default for LoggingConfig {
fn default() -> Self {
Self {
level: "info".to_string(),
format: "json".to_string(),
}
}
}
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct Config {
pub server: ServerConfig,
pub upstreams: Vec<UpstreamConfig>,
pub cache: CacheConfig,
pub mesh: MeshConfig,
pub discovery: DiscoveryConfig,
pub logging: LoggingConfig,
}
impl Default for Config {
fn default() -> Self {
Self {
server: ServerConfig::default(),
upstreams: vec![UpstreamConfig {
url: "https://cache.nixos.org".to_string(),
priority: 10,
public_key: String::new(),
}],
cache: CacheConfig::default(),
mesh: MeshConfig::default(),
discovery: DiscoveryConfig::default(),
logging: LoggingConfig::default(),
}
}
}
impl Config {
pub fn load(path: Option<&str>) -> Result<Self, ConfigError> {
let mut cfg = if let Some(path) = path.filter(|p| !p.is_empty()) {
let data = fs::read_to_string(path)?;
serde_yaml::from_str::<Self>(&data)?
} else {
Self::default()
};
if let Ok(v) = env::var("NCRO_LISTEN")
&& !v.is_empty()
{
cfg.server.listen = v;
}
if let Ok(v) = env::var("NCRO_DB_PATH")
&& !v.is_empty()
{
cfg.cache.db_path = v;
}
if let Ok(v) = env::var("NCRO_LOG_LEVEL")
&& !v.is_empty()
{
cfg.logging.level = v;
}
Ok(cfg)
}
pub fn validate(&self) -> Result<(), ConfigError> {
if self.upstreams.is_empty() {
return Err(ConfigError::Validation(
"at least one upstream is required".to_string(),
));
}
for (i, upstream) in self.upstreams.iter().enumerate() {
if upstream.url.is_empty() {
return Err(ConfigError::Validation(format!(
"upstream[{i}]: URL is empty"
)));
}
Url::parse(&upstream.url).map_err(|err| {
ConfigError::Validation(format!(
"upstream[{i}]: invalid URL {:?}: {err}",
upstream.url
))
})?;
if !upstream.public_key.is_empty() && !upstream.public_key.contains(':') {
return Err(ConfigError::Validation(format!(
"upstream[{i}]: public_key must be in 'name:base64(key)' Nix format"
)));
}
}
if self.server.listen.is_empty() {
return Err(ConfigError::Validation(
"server.listen is empty".to_string(),
));
}
if self.server.cache_priority < 1 {
return Err(ConfigError::Validation(format!(
"server.cache_priority must be >= 1, got {}",
self.server.cache_priority
)));
}
if self.cache.latency_alpha <= 0.0 || self.cache.latency_alpha >= 1.0 {
return Err(ConfigError::Validation(format!(
"cache.latency_alpha must be between 0 and 1 exclusive, got {}",
self.cache.latency_alpha
)));
}
if self.cache.ttl.0.is_zero() {
return Err(ConfigError::Validation(
"cache.ttl must be positive".to_string(),
));
}
if self.cache.negative_ttl.0.is_zero() {
return Err(ConfigError::Validation(
"cache.negative_ttl must be positive".to_string(),
));
}
if self.cache.max_entries <= 0 {
return Err(ConfigError::Validation(
"cache.max_entries must be positive".to_string(),
));
}
if self.mesh.enabled && self.mesh.peers.is_empty() {
return Err(ConfigError::Validation(
"mesh.enabled is true but no peers configured".to_string(),
));
}
for (i, peer) in self.mesh.peers.iter().enumerate() {
if peer.addr.is_empty() {
return Err(ConfigError::Validation(format!(
"mesh.peers[{i}]: addr is empty"
)));
}
if !peer.public_key.is_empty() {
let bytes = hex::decode(&peer.public_key).map_err(|_| {
ConfigError::Validation(format!(
"mesh.peers[{i}]: public_key must be a hex-encoded 32-byte \
ed25519 key"
))
})?;
if bytes.len() != 32 {
return Err(ConfigError::Validation(format!(
"mesh.peers[{i}]: public_key must be a hex-encoded 32-byte \
ed25519 key"
)));
}
}
}
if self.discovery.enabled {
if self.discovery.service_name.is_empty() {
return Err(ConfigError::Validation(
"discovery.service_name is required when discovery is enabled"
.to_string(),
));
}
if self.discovery.domain.is_empty() {
return Err(ConfigError::Validation(
"discovery.domain is required when discovery is enabled".to_string(),
));
}
if self.discovery.discovery_time.0.is_zero() {
return Err(ConfigError::Validation(
"discovery.discovery_time must be positive".to_string(),
));
}
}
Ok(())
}
}

394
src/db.rs Normal file
View file

@ -0,0 +1,394 @@
use std::{path::Path, time::Duration};
use chrono::{DateTime, TimeZone, Utc};
use sqlx::{
Row,
SqlitePool,
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum DbError {
#[error("sqlite: {0}")]
Sqlx(#[from] sqlx::Error),
#[error("create database directory: {0}")]
CreateDir(#[from] std::io::Error),
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct RouteEntry {
pub store_path: String,
pub upstream_url: String,
pub latency_ms: f64,
pub latency_ema: f64,
pub last_verified: DateTime<Utc>,
pub query_count: u32,
pub failure_count: u32,
pub ttl: DateTime<Utc>,
pub nar_hash: String,
pub nar_size: u64,
pub nar_url: String,
}
impl RouteEntry {
pub fn is_valid(&self) -> bool {
Utc::now() < self.ttl
}
}
#[derive(Debug, Clone)]
pub struct HealthRow {
pub url: String,
pub ema_latency: f64,
pub consecutive_fails: i64,
pub total_queries: i64,
}
#[derive(Clone)]
pub struct Db {
pool: SqlitePool,
max_entries: i64,
}
impl Db {
pub async fn open(path: &str, max_entries: i64) -> Result<Self, DbError> {
if path != ":memory:"
&& let Some(parent) = Path::new(path).parent()
{
tokio::fs::create_dir_all(parent).await?;
}
let options = if path == ":memory:" {
SqliteConnectOptions::new().filename(path)
} else {
SqliteConnectOptions::new()
.filename(path)
.create_if_missing(true)
}
.journal_mode(SqliteJournalMode::Wal)
.busy_timeout(Duration::from_secs(5));
let pool = SqlitePoolOptions::new()
.max_connections(1)
.connect_with(options)
.await?;
migrate(&pool).await?;
Ok(Self { pool, max_entries })
}
pub async fn get_route(
&self,
store_path: &str,
) -> Result<Option<RouteEntry>, DbError> {
let row = sqlx::query(
r"SELECT store_path, upstream_url, latency_ms, latency_ema, query_count, failure_count,
last_verified, ttl, nar_hash, nar_size, nar_url
FROM routes WHERE store_path = ?",
)
.bind(store_path)
.fetch_optional(&self.pool)
.await?;
Ok(row.as_ref().map(row_to_route))
}
pub async fn get_route_by_nar_url(
&self,
nar_url: &str,
) -> Result<Option<RouteEntry>, DbError> {
let row = sqlx::query(
r"SELECT store_path, upstream_url, latency_ms, latency_ema, query_count, failure_count,
last_verified, ttl, nar_hash, nar_size, nar_url
FROM routes WHERE nar_url = ? AND ttl > ?",
)
.bind(nar_url)
.bind(Utc::now().timestamp())
.fetch_optional(&self.pool)
.await?;
Ok(row.as_ref().map(row_to_route))
}
pub async fn set_route(&self, entry: &RouteEntry) -> Result<(), DbError> {
sqlx::query(
r"INSERT INTO routes
(store_path, upstream_url, latency_ms, latency_ema, query_count, failure_count,
last_verified, ttl, nar_hash, nar_size, nar_url)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(store_path) DO UPDATE SET
upstream_url = excluded.upstream_url,
latency_ms = excluded.latency_ms,
latency_ema = excluded.latency_ema,
query_count = excluded.query_count,
failure_count = excluded.failure_count,
last_verified = excluded.last_verified,
ttl = excluded.ttl,
nar_hash = excluded.nar_hash,
nar_size = excluded.nar_size,
nar_url = excluded.nar_url",
)
.bind(&entry.store_path)
.bind(&entry.upstream_url)
.bind(entry.latency_ms)
.bind(entry.latency_ema)
.bind(i64::from(entry.query_count))
.bind(i64::from(entry.failure_count))
.bind(entry.last_verified.timestamp())
.bind(entry.ttl.timestamp())
.bind(&entry.nar_hash)
.bind(i64::try_from(entry.nar_size).unwrap_or(i64::MAX))
.bind(&entry.nar_url)
.execute(&self.pool)
.await?;
self.evict_if_needed().await
}
pub async fn expire_old_routes(&self) -> Result<(), DbError> {
sqlx::query("DELETE FROM routes WHERE ttl < ?")
.bind(Utc::now().timestamp())
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn list_recent_routes(
&self,
n: i64,
) -> Result<Vec<RouteEntry>, DbError> {
let rows = sqlx::query(
r"SELECT store_path, upstream_url, latency_ms, latency_ema, query_count, failure_count,
last_verified, ttl, nar_hash, nar_size, nar_url
FROM routes WHERE ttl > ? ORDER BY last_verified DESC LIMIT ?",
)
.bind(Utc::now().timestamp())
.bind(n)
.fetch_all(&self.pool)
.await?;
Ok(rows.iter().map(row_to_route).collect())
}
pub async fn route_count(&self) -> Result<i64, DbError> {
Ok(
sqlx::query("SELECT COUNT(*) FROM routes")
.fetch_one(&self.pool)
.await?
.get::<i64, _>(0),
)
}
pub async fn set_negative(
&self,
store_path: &str,
ttl: Duration,
) -> Result<(), DbError> {
sqlx::query(
r"INSERT INTO negative_cache (store_path, expires_at) VALUES (?, ?)
ON CONFLICT(store_path) DO UPDATE SET expires_at = excluded.expires_at",
)
.bind(store_path)
.bind((Utc::now() + chrono::Duration::from_std(ttl).unwrap_or_default()).timestamp())
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn is_negative(&self, store_path: &str) -> Result<bool, DbError> {
Ok(
sqlx::query(
"SELECT EXISTS(SELECT 1 FROM negative_cache WHERE store_path = ? AND \
expires_at > ?)",
)
.bind(store_path)
.bind(Utc::now().timestamp())
.fetch_one(&self.pool)
.await?
.get::<i64, _>(0)
!= 0,
)
}
pub async fn expire_negatives(&self) -> Result<(), DbError> {
sqlx::query("DELETE FROM negative_cache WHERE expires_at < ?")
.bind(Utc::now().timestamp())
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn save_health(
&self,
url: &str,
ema: f64,
consecutive_fails: i64,
total_queries: i64,
) -> Result<(), DbError> {
sqlx::query(
r"INSERT INTO upstream_health (url, ema_latency, consecutive_fails, total_queries)
VALUES (?, ?, ?, ?)
ON CONFLICT(url) DO UPDATE SET
ema_latency = excluded.ema_latency,
consecutive_fails = excluded.consecutive_fails,
total_queries = excluded.total_queries",
)
.bind(url)
.bind(ema)
.bind(consecutive_fails)
.bind(total_queries)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn load_all_health(&self) -> Result<Vec<HealthRow>, DbError> {
let rows = sqlx::query(
"SELECT url, ema_latency, consecutive_fails, total_queries FROM \
upstream_health",
)
.fetch_all(&self.pool)
.await?;
Ok(
rows
.into_iter()
.map(|row| {
HealthRow {
url: row.get("url"),
ema_latency: row.get("ema_latency"),
consecutive_fails: row.get("consecutive_fails"),
total_queries: row.get("total_queries"),
}
})
.collect(),
)
}
async fn evict_if_needed(&self) -> Result<(), DbError> {
sqlx::query(
r"DELETE FROM routes WHERE store_path IN (
SELECT store_path FROM routes ORDER BY last_verified ASC
LIMIT MAX(0, (SELECT COUNT(*) FROM routes) - ?)
)",
)
.bind(self.max_entries)
.execute(&self.pool)
.await?;
Ok(())
}
}
async fn migrate(pool: &SqlitePool) -> Result<(), DbError> {
sqlx::query(
r"CREATE TABLE IF NOT EXISTS routes (
store_path TEXT PRIMARY KEY,
upstream_url TEXT NOT NULL,
latency_ms REAL NOT NULL DEFAULT 0,
latency_ema REAL NOT NULL DEFAULT 0,
query_count INTEGER NOT NULL DEFAULT 1,
failure_count INTEGER NOT NULL DEFAULT 0,
last_verified INTEGER NOT NULL DEFAULT 0,
ttl INTEGER NOT NULL,
nar_hash TEXT NOT NULL DEFAULT '',
nar_size INTEGER NOT NULL DEFAULT 0,
nar_url TEXT NOT NULL DEFAULT '',
created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now'))
)",
)
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_routes_ttl ON routes(ttl)")
.execute(pool)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_routes_last_verified ON \
routes(last_verified)",
)
.execute(pool)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_routes_nar_url ON routes(nar_url)",
)
.execute(pool)
.await?;
sqlx::query(
r"CREATE TABLE IF NOT EXISTS upstream_health (
url TEXT PRIMARY KEY,
ema_latency REAL NOT NULL DEFAULT 0,
consecutive_fails INTEGER NOT NULL DEFAULT 0,
total_queries INTEGER NOT NULL DEFAULT 0
)",
)
.execute(pool)
.await?;
sqlx::query(
r"CREATE TABLE IF NOT EXISTS negative_cache (
store_path TEXT PRIMARY KEY,
expires_at INTEGER NOT NULL
)",
)
.execute(pool)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_negative_expires ON \
negative_cache(expires_at)",
)
.execute(pool)
.await?;
Ok(())
}
fn row_to_route(row: &sqlx::sqlite::SqliteRow) -> RouteEntry {
let query_count = row.get::<i64, _>("query_count");
let failure_count = row.get::<i64, _>("failure_count");
let nar_size = row.get::<i64, _>("nar_size");
RouteEntry {
store_path: row.get("store_path"),
upstream_url: row.get("upstream_url"),
latency_ms: row.get("latency_ms"),
latency_ema: row.get("latency_ema"),
query_count: u32::try_from(query_count).unwrap_or_default(),
failure_count: u32::try_from(failure_count).unwrap_or_default(),
last_verified: Utc
.timestamp_opt(row.get("last_verified"), 0)
.single()
.unwrap_or_else(Utc::now),
ttl: Utc
.timestamp_opt(row.get("ttl"), 0)
.single()
.unwrap_or_else(Utc::now),
nar_hash: row.get("nar_hash"),
nar_size: u64::try_from(nar_size).unwrap_or_default(),
nar_url: row.get("nar_url"),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn route_roundtrip_and_negative_cache() -> Result<(), DbError> {
let db = Db::open(":memory:", 100).await?;
let now = Utc::now();
let entry = RouteEntry {
store_path: "abc123".into(),
upstream_url: "https://cache.nixos.org".into(),
latency_ms: 10.0,
latency_ema: 10.0,
last_verified: now,
query_count: 1,
failure_count: 0,
ttl: now + chrono::Duration::hours(1),
nar_hash: "sha256:abc".into(),
nar_size: 42,
nar_url: "nar/abc.nar.xz".into(),
};
db.set_route(&entry).await?;
let got = db
.get_route("abc123")
.await?
.ok_or(sqlx::Error::RowNotFound)?;
assert_eq!(got.upstream_url, entry.upstream_url);
assert!(db.get_route_by_nar_url("nar/abc.nar.xz").await?.is_some());
db.set_negative("missing", Duration::from_secs(60)).await?;
assert!(db.is_negative("missing").await?);
Ok(())
}
}

74
src/discovery.rs Normal file
View file

@ -0,0 +1,74 @@
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use mdns_sd::{ServiceDaemon, ServiceEvent};
use tokio::sync::{Mutex, watch};
use crate::{config::DiscoveryConfig, health::Prober};
pub struct Discovery {
cfg: DiscoveryConfig,
prober: Prober,
daemon: ServiceDaemon,
peers: Arc<Mutex<HashMap<String, (String, Instant)>>>,
}
impl Discovery {
pub fn new(cfg: DiscoveryConfig, prober: Prober) -> anyhow::Result<Self> {
Ok(Self {
cfg,
prober,
daemon: ServiceDaemon::new()?,
peers: Arc::new(Mutex::new(HashMap::new())),
})
}
pub async fn run(
self,
mut stop: watch::Receiver<bool>,
) -> anyhow::Result<()> {
let service = format!(
"{}.{}.",
self.cfg.service_name.trim_end_matches('.'),
self.cfg.domain.trim_end_matches('.')
);
let receiver = self.daemon.browse(&service)?;
let peers = Arc::clone(&self.peers);
let prober = self.prober.clone();
let priority = self.cfg.priority;
let mut cleanup = tokio::time::interval(Duration::from_secs(10));
let expiration = if self.cfg.discovery_time.0.is_zero() {
Duration::from_secs(30)
} else {
self.cfg.discovery_time.0 * 3
};
loop {
tokio::select! {
_ = stop.changed() => { let _ = self.daemon.shutdown(); return Ok(()); }
_ = cleanup.tick() => {
let stale = {
let mut guard = peers.lock().await;
let now = Instant::now();
let stale = guard.iter().filter(|(_, (_, seen))| now.duration_since(*seen) > expiration).map(|(k, (u, _))| (k.clone(), u.clone())).collect::<Vec<_>>();
for (key, _) in &stale { guard.remove(key); }
stale
};
for (_, url) in stale { tracing::info!(url, "removing stale peer"); prober.remove_upstream(&url).await; }
}
event = tokio::task::spawn_blocking({ let receiver = receiver.clone(); move || receiver.recv_timeout(Duration::from_millis(500)).ok() }) => {
if let Ok(Some(ServiceEvent::ServiceResolved(info))) = event {
let Some(addr) = info.get_addresses().iter().next().map(mdns_sd::ScopedIp::to_ip_addr) else { continue; };
let url = format!("http://{}", std::net::SocketAddr::new(addr, info.get_port()));
let key = info.get_fullname().to_string();
let is_new = peers.lock().await.insert(key, (url.clone(), Instant::now())).is_none();
if is_new { tracing::info!(url, "discovered nix-serve instance"); prober.add_upstream(url, priority).await; }
}
}
}
}
}
}

310
src/health.rs Normal file
View file

@ -0,0 +1,310 @@
use std::{
cmp::Ordering,
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::RwLock;
use crate::config::UpstreamConfig;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
Active,
Degraded,
Down,
}
impl Status {
pub const fn as_str(self) -> &'static str {
match self {
Self::Active => "ACTIVE",
Self::Degraded => "DEGRADED",
Self::Down => "DOWN",
}
}
}
#[derive(Debug, Clone)]
pub struct UpstreamHealth {
pub url: String,
pub priority: i32,
pub ema_latency: f64,
pub last_probe: Option<Instant>,
pub consecutive_fails: u32,
pub total_queries: u64,
pub status: Status,
}
impl UpstreamHealth {
const fn new(url: String, priority: i32) -> Self {
Self {
url,
priority,
ema_latency: 0.0,
last_probe: None,
consecutive_fails: 0,
total_queries: 0,
status: Status::Active,
}
}
}
type PersistHealth = Arc<dyn Fn(String, f64, u32, u64) + Send + Sync>;
#[derive(Clone)]
pub struct Prober {
inner: Arc<ProberInner>,
}
struct ProberInner {
alpha: f64,
table: RwLock<HashMap<String, UpstreamHealth>>,
client: reqwest::Client,
persist_health: RwLock<Option<PersistHealth>>,
}
impl Prober {
pub fn new(alpha: f64) -> Self {
Self {
inner: Arc::new(ProberInner {
alpha,
table: RwLock::new(HashMap::new()),
client: reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.unwrap_or_else(|_| reqwest::Client::new()),
persist_health: RwLock::new(None),
}),
}
}
pub async fn init_upstreams(&self, upstreams: &[UpstreamConfig]) {
let mut table = self.inner.table.write().await;
for upstream in upstreams {
table.entry(upstream.url.clone()).or_insert_with(|| {
UpstreamHealth::new(upstream.url.clone(), upstream.priority)
});
}
}
#[allow(clippy::significant_drop_tightening)]
pub async fn seed(
&self,
url: &str,
ema_latency: f64,
consecutive_fails: i64,
total_queries: i64,
) {
{
let mut table = self.inner.table.write().await;
let Some(health) = table.get_mut(url) else {
return;
};
health.ema_latency = ema_latency;
health.total_queries =
u64::try_from(total_queries.max(0)).unwrap_or_default();
health.consecutive_fails =
u32::try_from(consecutive_fails.max(0)).unwrap_or(u32::MAX);
health.status = compute_status(health.consecutive_fails);
}
}
pub async fn set_health_persistence<F>(&self, f: F)
where
F: Fn(String, f64, u32, u64) + Send + Sync + 'static,
{
*self.inner.persist_health.write().await = Some(Arc::new(f));
}
#[allow(clippy::significant_drop_tightening)]
pub async fn record_latency(&self, url: &str, ms: f64) {
let snapshot = {
let mut table = self.inner.table.write().await;
let Some(health) = table.get_mut(url) else {
return;
};
if health.total_queries == 0 {
health.ema_latency = ms;
} else {
health.ema_latency = self
.inner
.alpha
.mul_add(ms, (1.0 - self.inner.alpha) * health.ema_latency);
}
health.consecutive_fails = 0;
health.total_queries += 1;
health.status = Status::Active;
health.last_probe = Some(Instant::now());
(
health.url.clone(),
health.ema_latency,
health.consecutive_fails,
health.total_queries,
)
};
let callback = self.inner.persist_health.read().await.clone();
if let Some(callback) = callback {
tokio::spawn(async move {
callback(snapshot.0, snapshot.1, snapshot.2, snapshot.3);
});
}
}
#[allow(clippy::significant_drop_tightening)]
pub async fn record_failure(&self, url: &str) {
let snapshot = {
let mut table = self.inner.table.write().await;
let Some(health) = table.get_mut(url) else {
return;
};
health.consecutive_fails += 1;
health.status = compute_status(health.consecutive_fails);
(
health.url.clone(),
health.ema_latency,
health.consecutive_fails,
health.total_queries,
)
};
let callback = self.inner.persist_health.read().await.clone();
if let Some(callback) = callback {
tokio::spawn(async move {
callback(snapshot.0, snapshot.1, snapshot.2, snapshot.3);
});
}
}
pub async fn get_health(&self, url: &str) -> Option<UpstreamHealth> {
self.inner.table.read().await.get(url).cloned()
}
pub async fn sorted_by_latency(&self) -> Vec<UpstreamHealth> {
let mut result = self
.inner
.table
.read()
.await
.values()
.cloned()
.collect::<Vec<_>>();
result.sort_by(|a, b| {
match (a.status == Status::Down, b.status == Status::Down) {
(true, false) => return Ordering::Greater,
(false, true) => return Ordering::Less,
_ => {},
}
if b.ema_latency > 0.0
&& ((a.ema_latency - b.ema_latency).abs() / b.ema_latency) < 0.10
&& a.priority != b.priority
{
return a.priority.cmp(&b.priority);
}
a.ema_latency
.partial_cmp(&b.ema_latency)
.unwrap_or(Ordering::Equal)
});
result
}
pub async fn probe_upstream(&self, url: String) {
if !self.inner.table.read().await.contains_key(&url) {
return;
}
let start = Instant::now();
let ok = self
.inner
.client
.head(format!("{url}/nix-cache-info"))
.send()
.await
.map(|resp| resp.status().as_u16() == 200)
.unwrap_or(false);
if ok {
self
.record_latency(&url, start.elapsed().as_secs_f64() * 1000.0)
.await;
} else {
self.record_failure(&url).await;
}
}
pub async fn run_probe_loop(
&self,
interval: Duration,
mut stop: tokio::sync::watch::Receiver<bool>,
) {
let mut ticker = tokio::time::interval(interval);
loop {
tokio::select! {
_ = stop.changed() => return,
_ = ticker.tick() => {
let urls = self.inner.table.read().await.keys().cloned().collect::<Vec<_>>();
for url in urls {
let prober = self.clone();
tokio::spawn(async move { prober.probe_upstream(url).await; });
}
}
}
}
}
pub async fn add_upstream(&self, url: String, priority: i32) {
let inserted = self
.inner
.table
.write()
.await
.insert(url.clone(), UpstreamHealth::new(url.clone(), priority))
.is_none();
if inserted {
let prober = self.clone();
tokio::spawn(async move {
prober.probe_upstream(url).await;
});
}
}
pub async fn remove_upstream(&self, url: &str) {
self.inner.table.write().await.remove(url);
}
}
const fn compute_status(consecutive_fails: u32) -> Status {
match consecutive_fails {
10.. => Status::Down,
3.. => Status::Degraded,
_ => Status::Active,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn ema_and_status_progression() -> Result<(), Box<dyn std::error::Error>>
{
let p = Prober::new(0.3);
p.add_upstream("https://example.com".into(), 1).await;
p.record_latency("https://example.com", 100.0).await;
p.record_latency("https://example.com", 50.0).await;
let h = p
.get_health("https://example.com")
.await
.ok_or("missing health")?;
assert!((84.0..=86.0).contains(&h.ema_latency));
for _ in 0..10 {
p.record_failure("https://example.com").await;
}
assert_eq!(
p.get_health("https://example.com")
.await
.ok_or("missing health")?
.status,
Status::Down
);
Ok(())
}
}

15
src/main.rs Normal file
View file

@ -0,0 +1,15 @@
mod cli;
mod config;
mod db;
mod discovery;
mod health;
mod mesh;
mod metrics;
mod narinfo;
mod router;
mod server;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
cli::run().await
}

223
src/mesh.rs Normal file
View file

@ -0,0 +1,223 @@
use std::{path::Path, sync::Arc};
use chrono::Utc;
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::{net::UdpSocket, time::Duration};
use crate::db::{Db, RouteEntry};
const MAX_PACKET_SIZE: usize = 65_536;
const HEADER_SIZE: usize = 96;
type DecodedPacket<'a> = (&'a [u8], &'a [u8], &'a [u8], Message);
#[derive(Debug, Error)]
pub enum MeshError {
#[error("io: {0}")]
Io(#[from] std::io::Error),
#[error("msgpack: {0}")]
Encode(#[from] rmp_serde::encode::Error),
#[error("decode msgpack: {0}")]
Decode(#[from] rmp_serde::decode::Error),
#[error("packet too short: {0} bytes")]
PacketTooShort(usize),
#[error("invalid signature")]
InvalidSignature,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum MsgType {
Announce = 1,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub r#type: MsgType,
pub node_id: String,
pub timestamp: i64,
pub routes: Vec<RouteEntry>,
}
#[derive(Clone)]
pub struct Node {
signing_key: Arc<SigningKey>,
}
impl Node {
pub async fn new(key_path: &str) -> Result<Self, MeshError> {
if key_path.is_empty() {
return Ok(Self {
signing_key: Arc::new(SigningKey::generate(&mut OsRng)),
});
}
if let Ok(data) = tokio::fs::read(key_path).await
&& (data.len() == 32 || data.len() == 64)
{
let Ok(bytes) = <[u8; 32]>::try_from(&data[..32]) else {
return Err(MeshError::InvalidSignature);
};
return Ok(Self {
signing_key: Arc::new(SigningKey::from_bytes(&bytes)),
});
}
if let Some(parent) = Path::new(key_path).parent() {
tokio::fs::create_dir_all(parent).await?;
}
let key = SigningKey::generate(&mut OsRng);
tokio::fs::write(key_path, key.to_bytes()).await?;
Ok(Self {
signing_key: Arc::new(key),
})
}
pub fn id(&self) -> String {
hex::encode(&self.public_key()[..8])
}
pub fn public_key(&self) -> [u8; 32] {
self.signing_key.verifying_key().to_bytes()
}
pub fn sign(&self, msg: &Message) -> Result<(Vec<u8>, Vec<u8>), MeshError> {
let body = rmp_serde::to_vec(msg)?;
Ok((
body.clone(),
self.signing_key.sign(&body).to_bytes().to_vec(),
))
}
}
pub fn verify(pubkey: &[u8], body: &[u8], sig: &[u8]) -> Result<(), MeshError> {
let pubkey: [u8; 32] =
pubkey.try_into().map_err(|_| MeshError::InvalidSignature)?;
let sig: [u8; 64] =
sig.try_into().map_err(|_| MeshError::InvalidSignature)?;
VerifyingKey::from_bytes(&pubkey)
.map_err(|_| MeshError::InvalidSignature)?
.verify(body, &Signature::from_bytes(&sig))
.map_err(|_| MeshError::InvalidSignature)
}
pub async fn listen_and_serve(
addr: &str,
db: Db,
allowed_keys: Vec<[u8; 32]>,
stop: tokio::sync::watch::Receiver<bool>,
) -> Result<(), MeshError> {
let socket = UdpSocket::bind(addr).await?;
tokio::spawn(async move {
let mut stop = stop;
let mut buf = vec![0; MAX_PACKET_SIZE];
loop {
tokio::select! {
_ = stop.changed() => return,
recv = socket.recv_from(&mut buf) => {
let Ok((n, src)) = recv else { return; };
match decode_packet(&buf[..n]) {
Ok((pubkey, sig, body, msg)) => {
if !allowed_keys.is_empty() && !allowed_keys.iter().any(|k| k.as_slice() == pubkey) {
tracing::warn!(?src, "mesh: rejecting packet from unknown sender");
continue;
}
if let Err(err) = verify(pubkey, body, sig) {
tracing::warn!(?src, error = %err, "mesh: signature verification failed");
continue;
}
if msg.r#type == MsgType::Announce && !msg.routes.is_empty() {
merge_routes(&db, msg.routes).await;
}
}
Err(err) => tracing::warn!(?src, error = %err, "mesh: malformed packet"),
}
}
}
}
});
Ok(())
}
async fn merge_routes(db: &Db, incoming: Vec<RouteEntry>) {
let now = Utc::now();
for route in incoming.into_iter().filter(|route| route.ttl > now) {
let should_set = match db.get_route(&route.store_path).await {
Ok(Some(existing)) if route.latency_ema > existing.latency_ema => false,
Ok(Some(existing))
if route.latency_ema.total_cmp(&existing.latency_ema).is_eq()
&& route.last_verified <= existing.last_verified =>
{
false
},
Ok(_) => true,
Err(err) => {
tracing::warn!(error = %err, store = route.store_path, "mesh: route lookup failed");
false
},
};
if should_set && let Err(err) = db.set_route(&route).await {
tracing::warn!(error = %err, store = route.store_path, "mesh: route merge failed");
}
}
}
pub async fn announce(
peer_addr: &str,
node: &Node,
routes: Vec<RouteEntry>,
) -> Result<(), MeshError> {
let msg = Message {
r#type: MsgType::Announce,
node_id: node.id(),
timestamp: Utc::now().timestamp_nanos_opt().unwrap_or_default(),
routes,
};
let packet = encode_packet(node, &msg)?;
let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.send_to(&packet, peer_addr).await?;
Ok(())
}
pub async fn run_gossip_loop(
node: Node,
db: Db,
peers: Vec<String>,
interval: Duration,
mut stop: tokio::sync::watch::Receiver<bool>,
) {
let mut ticker = tokio::time::interval(interval);
loop {
tokio::select! {
_ = stop.changed() => return,
_ = ticker.tick() => {
let Ok(routes) = db.list_recent_routes(100).await else { continue; };
if routes.is_empty() { continue; }
for peer in &peers {
let peer = peer.clone();
let node = node.clone();
let routes = routes.clone();
tokio::spawn(async move { let _ = announce(&peer, &node, routes).await; });
}
}
}
}
}
fn encode_packet(node: &Node, msg: &Message) -> Result<Vec<u8>, MeshError> {
let (body, sig) = node.sign(msg)?;
let mut packet = Vec::with_capacity(HEADER_SIZE + body.len());
packet.extend_from_slice(&node.public_key());
packet.extend_from_slice(&sig);
packet.extend_from_slice(&body);
Ok(packet)
}
fn decode_packet(packet: &[u8]) -> Result<DecodedPacket<'_>, MeshError> {
if packet.len() < HEADER_SIZE {
return Err(MeshError::PacketTooShort(packet.len()));
}
let pubkey = &packet[..32];
let sig = &packet[32..HEADER_SIZE];
let body = &packet[HEADER_SIZE..];
let msg = rmp_serde::from_slice(body)?;
Ok((pubkey, sig, body, msg))
}

109
src/metrics.rs Normal file
View file

@ -0,0 +1,109 @@
use std::sync::OnceLock;
use prometheus::{
Encoder,
HistogramOpts,
HistogramVec,
IntCounter,
IntCounterVec,
IntGauge,
Opts,
Registry,
TextEncoder,
};
pub struct Metrics {
registry: Registry,
pub narinfo_cache_hits: IntCounter,
pub narinfo_cache_misses: IntCounter,
pub narinfo_requests: IntCounterVec,
pub nar_requests: IntCounter,
pub upstream_race_wins: IntCounterVec,
pub route_entries: IntGauge,
pub upstream_latency: HistogramVec,
}
static METRICS: OnceLock<Metrics> = OnceLock::new();
#[expect(
clippy::expect_used,
reason = "metric names and labels are static constants validated during \
startup"
)]
pub fn get() -> &'static Metrics {
METRICS.get_or_init(|| {
let registry = Registry::new();
let narinfo_cache_hits = IntCounter::new(
"ncro_narinfo_cache_hits_total",
"Narinfo requests served from route cache.",
)
.expect("valid metric");
let narinfo_cache_misses = IntCounter::new(
"ncro_narinfo_cache_misses_total",
"Narinfo requests requiring upstream race.",
)
.expect("valid metric");
let narinfo_requests = IntCounterVec::new(
Opts::new("ncro_narinfo_requests_total", "Narinfo requests by status."),
&["status"],
)
.expect("valid metric");
let nar_requests =
IntCounter::new("ncro_nar_requests_total", "NAR streaming requests.")
.expect("valid metric");
let upstream_race_wins = IntCounterVec::new(
Opts::new(
"ncro_upstream_race_wins_total",
"Times each upstream won the narinfo race.",
),
&["upstream"],
)
.expect("valid metric");
let route_entries = IntGauge::new(
"ncro_route_entries",
"Current number of route entries in SQLite.",
)
.expect("valid metric");
let upstream_latency = HistogramVec::new(
HistogramOpts::new(
"ncro_upstream_latency_seconds",
"Upstream narinfo race latency.",
),
&["upstream"],
)
.expect("valid metric");
for collector in [
Box::new(narinfo_cache_hits.clone())
as Box<dyn prometheus::core::Collector>,
Box::new(narinfo_cache_misses.clone()),
Box::new(narinfo_requests.clone()),
Box::new(nar_requests.clone()),
Box::new(upstream_race_wins.clone()),
Box::new(route_entries.clone()),
Box::new(upstream_latency.clone()),
] {
registry.register(collector).expect("register metric");
}
Metrics {
registry,
narinfo_cache_hits,
narinfo_cache_misses,
narinfo_requests,
nar_requests,
upstream_race_wins,
route_entries,
upstream_latency,
}
})
}
pub fn gather() -> String {
let mut buf = Vec::new();
let encoder = TextEncoder::new();
if encoder.encode(&get().registry.gather(), &mut buf).is_err() {
return String::new();
}
String::from_utf8_lossy(&buf).into_owned()
}

207
src/narinfo.rs Normal file
View file

@ -0,0 +1,207 @@
use std::io::{BufRead, BufReader, Read};
use base64::{Engine, engine::general_purpose::STANDARD};
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum NarInfoError {
#[error("read narinfo: {0}")]
Io(#[from] std::io::Error),
#[error("malformed line: {0:?}")]
MalformedLine(String),
#[error("missing StorePath")]
MissingStorePath,
#[error("{field}: {source}")]
ParseInt {
field: &'static str,
source: std::num::ParseIntError,
},
#[error("invalid public key {input:?}: missing ':'")]
MissingPublicKeySeparator { input: String },
#[error("invalid public key {input:?}: {source}")]
InvalidPublicKeyBase64 {
input: String,
source: base64::DecodeError,
},
#[error("invalid public key size {got}, want 32")]
InvalidPublicKeySize { got: usize },
}
#[cfg(test)]
mod tests {
use ed25519_dalek::{Signer, SigningKey};
use rand::rngs::OsRng;
use super::*;
#[test]
fn parses_realistic_narinfo() -> Result<(), NarInfoError> {
let input = "StorePath: /nix/store/abc-hello\nURL: \
nar/abc.nar.xz\nCompression: xz\nFileSize: 42\nNarHash: \
sha256:abc\nNarSize: 123\nReferences: abc-hello dep\nSig: \
key:sig=\n";
let ni = NarInfo::parse(input.as_bytes())?;
assert_eq!(ni.store_path, "/nix/store/abc-hello");
assert_eq!(ni.url, "nar/abc.nar.xz");
assert_eq!(ni.references.len(), 2);
Ok(())
}
#[test]
fn verifies_roundtrip_signature() -> Result<(), NarInfoError> {
let signing = SigningKey::generate(&mut OsRng);
let mut ni = NarInfo {
store_path: "/nix/store/abc-test".into(),
nar_hash: "sha256:abc".into(),
nar_size: 12,
references: vec!["abc-test".into()],
..Default::default()
};
let sig = signing.sign(ni.fingerprint().as_bytes());
let pubkey = format!(
"test:{}",
STANDARD.encode(signing.verifying_key().to_bytes())
);
ni.sig = vec![format!("test:{}", STANDARD.encode(sig.to_bytes()))];
assert!(ni.verify(&pubkey)?);
Ok(())
}
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NarInfo {
pub store_path: String,
pub url: String,
pub compression: String,
pub file_hash: String,
pub file_size: u64,
pub nar_hash: String,
pub nar_size: u64,
pub references: Vec<String>,
pub deriver: String,
pub sig: Vec<String>,
pub ca: String,
}
pub fn parse_public_key(
input: &str,
) -> Result<(String, VerifyingKey), NarInfoError> {
let (name, b64) = input.split_once(':').ok_or_else(|| {
NarInfoError::MissingPublicKeySeparator {
input: input.to_string(),
}
})?;
if name.is_empty() {
return Err(NarInfoError::MissingPublicKeySeparator {
input: input.to_string(),
});
}
let raw = STANDARD.decode(b64).map_err(|source| {
NarInfoError::InvalidPublicKeyBase64 {
input: input.to_string(),
source,
}
})?;
let bytes: [u8; 32] = raw.try_into().map_err(|raw: Vec<u8>| {
NarInfoError::InvalidPublicKeySize { got: raw.len() }
})?;
let key = VerifyingKey::from_bytes(&bytes)
.map_err(|_| NarInfoError::InvalidPublicKeySize { got: bytes.len() })?;
Ok((name.to_string(), key))
}
impl NarInfo {
pub fn parse(reader: impl Read) -> Result<Self, NarInfoError> {
let mut narinfo = Self::default();
for line in BufReader::new(reader).lines() {
let line = line?;
if line.is_empty() {
continue;
}
let (key, value) = line
.split_once(": ")
.ok_or_else(|| NarInfoError::MalformedLine(line.clone()))?;
match key {
"StorePath" => narinfo.store_path = value.to_string(),
"URL" => narinfo.url = value.to_string(),
"Compression" => narinfo.compression = value.to_string(),
"FileHash" => narinfo.file_hash = value.to_string(),
"FileSize" => {
narinfo.file_size = value.parse().map_err(|source| {
NarInfoError::ParseInt {
field: "FileSize",
source,
}
})?;
},
"NarHash" => narinfo.nar_hash = value.to_string(),
"NarSize" => {
narinfo.nar_size = value.parse().map_err(|source| {
NarInfoError::ParseInt {
field: "NarSize",
source,
}
})?;
},
"References" => {
if !value.is_empty() {
narinfo.references =
value.split_whitespace().map(str::to_string).collect();
}
},
"Deriver" => narinfo.deriver = value.to_string(),
"Sig" => narinfo.sig.push(value.to_string()),
"CA" => narinfo.ca = value.to_string(),
_ => {},
}
}
if narinfo.store_path.is_empty() {
return Err(NarInfoError::MissingStorePath);
}
Ok(narinfo)
}
pub fn fingerprint(&self) -> String {
let refs = self
.references
.iter()
.map(|reference| {
if reference.starts_with("/nix/store/") {
reference.clone()
} else {
format!("/nix/store/{reference}")
}
})
.collect::<Vec<_>>()
.join(",");
format!(
"1;{};{};{};{}",
self.store_path, self.nar_hash, self.nar_size, refs
)
}
pub fn verify(&self, public_key: &str) -> Result<bool, NarInfoError> {
let (key_name, key) = parse_public_key(public_key)?;
let fingerprint = self.fingerprint();
for sig_line in &self.sig {
let Some((name, b64)) = sig_line.split_once(':') else {
continue;
};
if name != key_name {
continue;
}
let Ok(raw) = STANDARD.decode(b64) else {
continue;
};
let Ok(bytes) = <[u8; 64]>::try_from(raw.as_slice()) else {
continue;
};
let signature = Signature::from_bytes(&bytes);
if key.verify(fingerprint.as_bytes(), &signature).is_ok() {
return Ok(true);
}
}
Ok(false)
}
}

309
src/router.rs Normal file
View file

@ -0,0 +1,309 @@
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use chrono::Utc;
use futures_util::{StreamExt, stream::FuturesUnordered};
use thiserror::Error;
use tokio::sync::{Mutex, RwLock};
use crate::{
db::{Db, RouteEntry},
health::{Prober, Status},
metrics,
narinfo::NarInfo,
};
#[derive(Debug, Error)]
pub enum RouterError {
#[error("not found in any upstream")]
NotFound,
#[error("all upstreams unavailable")]
UpstreamUnavailable,
#[error("no candidates for {0:?}")]
NoCandidates(String),
#[error("narinfo signature verification failed")]
SignatureVerificationFailed,
#[error(transparent)]
Db(#[from] crate::db::DbError),
}
#[derive(Debug, Clone)]
pub struct ResolveResult {
pub url: String,
pub latency_ms: f64,
pub cache_hit: bool,
pub narinfo_bytes: Option<Vec<u8>>,
}
#[derive(Clone)]
pub struct Router {
inner: Arc<RouterInner>,
}
struct RouterInner {
db: Db,
prober: Prober,
route_ttl: Duration,
race_timeout: Duration,
negative_ttl: Duration,
client: reqwest::Client,
upstream_keys: RwLock<HashMap<String, String>>,
inflight: Mutex<HashMap<String, Arc<Mutex<()>>>>,
}
#[derive(Debug)]
struct RaceResult {
url: String,
latency_ms: f64,
}
impl Router {
pub fn new(
db: Db,
prober: Prober,
route_ttl: Duration,
race_timeout: Duration,
negative_ttl: Duration,
) -> Self {
Self {
inner: Arc::new(RouterInner {
db,
prober,
route_ttl,
race_timeout,
negative_ttl,
client: reqwest::Client::builder()
.timeout(race_timeout)
.build()
.unwrap_or_else(|_| reqwest::Client::new()),
upstream_keys: RwLock::new(HashMap::new()),
inflight: Mutex::new(HashMap::new()),
}),
}
}
pub async fn set_upstream_key(
&self,
url: String,
public_key: String,
) -> Result<(), crate::narinfo::NarInfoError> {
crate::narinfo::parse_public_key(&public_key)?;
self
.inner
.upstream_keys
.write()
.await
.insert(url, public_key);
Ok(())
}
pub async fn resolve(
&self,
store_hash: &str,
candidates: &[String],
) -> Result<ResolveResult, RouterError> {
if self.inner.db.is_negative(store_hash).await? {
return Err(RouterError::NotFound);
}
if let Some(result) = self.valid_cached_route(store_hash).await? {
return Ok(result);
}
metrics::get().narinfo_cache_misses.inc();
let lock = {
let mut inflight = self.inner.inflight.lock().await;
Arc::clone(
inflight
.entry(store_hash.to_string())
.or_insert_with(|| Arc::new(Mutex::new(()))),
)
};
let _guard = lock.lock().await;
if let Some(result) = self.valid_cached_route(store_hash).await? {
self.inner.inflight.lock().await.remove(store_hash);
return Ok(result);
}
let result = self.race(store_hash, candidates).await;
if matches!(result, Err(RouterError::NotFound)) {
let _ = self
.inner
.db
.set_negative(store_hash, self.inner.negative_ttl)
.await;
}
self.inner.inflight.lock().await.remove(store_hash);
result
}
async fn valid_cached_route(
&self,
store_hash: &str,
) -> Result<Option<ResolveResult>, RouterError> {
let Some(entry) = self.inner.db.get_route(store_hash).await? else {
return Ok(None);
};
if !entry.is_valid() {
return Ok(None);
}
let health = self.inner.prober.get_health(&entry.upstream_url).await;
if !health.as_ref().is_none_or(|h| h.status == Status::Active) {
return Ok(None);
}
metrics::get().narinfo_cache_hits.inc();
Ok(Some(ResolveResult {
url: entry.upstream_url,
latency_ms: entry.latency_ema,
cache_hit: true,
narinfo_bytes: None,
}))
}
async fn race(
&self,
store_hash: &str,
candidates: &[String],
) -> Result<ResolveResult, RouterError> {
if candidates.is_empty() {
return Err(RouterError::NoCandidates(store_hash.to_string()));
}
let mut handles = FuturesUnordered::new();
for upstream in candidates {
let upstream = upstream.clone();
let store_hash = store_hash.to_string();
let client = self.inner.client.clone();
handles.push(tokio::spawn(async move {
let start = Instant::now();
let res = client
.head(format!("{upstream}/{store_hash}.narinfo"))
.send()
.await;
match res {
Ok(resp) if resp.status().is_success() => {
Ok(RaceResult {
url: upstream,
latency_ms: start.elapsed().as_secs_f64() * 1000.0,
})
},
Ok(_) => Err(false),
Err(_) => Err(true),
}
}));
}
let mut net_errs = 0;
let mut not_founds = 0;
let mut winner: Option<RaceResult> = None;
let deadline = tokio::time::sleep(self.inner.race_timeout);
tokio::pin!(deadline);
while !handles.is_empty() {
tokio::select! {
() = &mut deadline => break,
joined = handles.next() => {
match joined {
Some(Ok(Ok(res))) => if winner.as_ref().is_none_or(|w| res.latency_ms < w.latency_ms) { winner = Some(res); },
Some(Ok(Err(true)) | Err(_)) => net_errs += 1,
Some(Ok(Err(false))) => not_founds += 1,
None => break,
}
}
}
}
let Some(winner) = winner else {
return if net_errs > 0 && not_founds == 0 {
Err(RouterError::UpstreamUnavailable)
} else {
Err(RouterError::NotFound)
};
};
metrics::get()
.upstream_race_wins
.with_label_values(&[&winner.url])
.inc();
metrics::get()
.upstream_latency
.with_label_values(&[&winner.url])
.observe(winner.latency_ms / 1000.0);
let (body, nar_url, nar_hash, nar_size) =
self.fetch_narinfo(&winner.url, store_hash).await?;
let ema = self
.inner
.prober
.get_health(&winner.url)
.await
.map_or(winner.latency_ms, |h| {
0.3f64.mul_add(winner.latency_ms, 0.7 * h.ema_latency)
});
self
.inner
.prober
.record_latency(&winner.url, winner.latency_ms)
.await;
let now = Utc::now();
self
.inner
.db
.set_route(&RouteEntry {
store_path: store_hash.to_string(),
upstream_url: winner.url.clone(),
latency_ms: winner.latency_ms,
latency_ema: ema,
last_verified: now,
query_count: 1,
failure_count: 0,
ttl: now
+ chrono::Duration::from_std(self.inner.route_ttl)
.unwrap_or_default(),
nar_hash,
nar_size,
nar_url,
})
.await?;
Ok(ResolveResult {
url: winner.url,
latency_ms: winner.latency_ms,
cache_hit: false,
narinfo_bytes: body,
})
}
async fn fetch_narinfo(
&self,
upstream: &str,
store_hash: &str,
) -> Result<(Option<Vec<u8>>, String, String, u64), RouterError> {
let Ok(resp) = self
.inner
.client
.get(format!("{upstream}/{store_hash}.narinfo"))
.send()
.await
else {
return Ok((None, String::new(), String::new(), 0));
};
if !resp.status().is_success() {
return Ok((None, String::new(), String::new(), 0));
}
let Ok(bytes) = resp.bytes().await else {
return Ok((None, String::new(), String::new(), 0));
};
let body = bytes.to_vec();
let Ok(parsed) = NarInfo::parse(body.as_slice()) else {
return Ok((Some(body), String::new(), String::new(), 0));
};
if let Some(pubkey) = self.inner.upstream_keys.read().await.get(upstream)
&& !parsed.verify(pubkey).unwrap_or(false)
{
tracing::warn!(
upstream,
store = store_hash,
"narinfo signature verification failed"
);
return Err(RouterError::SignatureVerificationFailed);
}
Ok((Some(body), parsed.url, parsed.nar_hash, parsed.nar_size))
}
}

301
src/server.rs Normal file
View file

@ -0,0 +1,301 @@
use std::sync::Arc;
use axum::{
Router as AxumRouter,
body::Body,
extract::{Path, State},
http::{HeaderMap, HeaderName, HeaderValue, Method, Request, StatusCode},
response::{IntoResponse, Response},
routing::get,
};
use bytes::Bytes;
use futures_util::TryStreamExt;
use serde::Serialize;
use crate::{
config::UpstreamConfig,
db::Db,
health::{Prober, Status},
metrics,
router::{Router, RouterError},
};
#[derive(Clone)]
pub struct AppState {
router: Router,
prober: Prober,
db: Db,
upstreams: Vec<UpstreamConfig>,
client: reqwest::Client,
cache_priority: i32,
}
pub fn app(
router: Router,
prober: Prober,
db: Db,
upstreams: Vec<UpstreamConfig>,
cache_priority: i32,
) -> AxumRouter {
let state = AppState {
router,
prober,
db,
upstreams,
client: reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(60))
.build()
.unwrap_or_else(|_| reqwest::Client::new()),
cache_priority,
};
AxumRouter::new()
.route("/nix-cache-info", get(cache_info).head(cache_info))
.route("/health", get(health))
.route("/metrics", get(metrics_endpoint))
.route("/{hash}.narinfo", get(narinfo).head(narinfo))
.route("/nar/{*path}", get(nar).head(nar))
.with_state(Arc::new(state))
}
async fn cache_info(State(state): State<Arc<AppState>>) -> Response {
(
[("content-type", "text/plain")],
format!(
"StoreDir: /nix/store\nWantMassQuery: 1\nPriority: {}\n",
state.cache_priority
),
)
.into_response()
}
#[derive(Serialize)]
struct HealthResponse {
status: String,
upstreams: Vec<UpstreamStatus>,
}
#[derive(Serialize)]
struct UpstreamStatus {
url: String,
status: String,
latency_ms: f64,
consecutive_fails: u32,
}
async fn health(State(state): State<Arc<AppState>>) -> Response {
let sorted = state.prober.sorted_by_latency().await;
let down_count = sorted.iter().filter(|h| h.status == Status::Down).count();
let any_degraded = sorted.iter().any(|h| h.status == Status::Degraded);
let status = if !sorted.is_empty() && down_count == sorted.len() {
"down"
} else if down_count > 0 || any_degraded {
"degraded"
} else {
"ok"
};
axum::Json(HealthResponse {
status: status.to_string(),
upstreams: sorted
.into_iter()
.map(|h| {
UpstreamStatus {
url: h.url,
status: h.status.as_str().to_string(),
latency_ms: h.ema_latency,
consecutive_fails: h.consecutive_fails,
}
})
.collect(),
})
.into_response()
}
async fn metrics_endpoint() -> Response {
(
[("content-type", "text/plain; version=0.0.4")],
metrics::gather(),
)
.into_response()
}
async fn narinfo(
State(state): State<Arc<AppState>>,
Path(hash): Path<String>,
req: Request<Body>,
) -> Response {
let candidates = upstream_urls(&state).await;
match state.router.resolve(&hash, &candidates).await {
Ok(result) => {
tracing::info!(
hash,
upstream = result.url,
cache_hit = result.cache_hit,
latency_ms = result.latency_ms,
"narinfo routed"
);
metrics::get()
.narinfo_requests
.with_label_values(&["200"])
.inc();
if let Some(bytes) = result.narinfo_bytes {
return (
StatusCode::OK,
[("content-type", "text/x-nix-narinfo")],
Bytes::from(bytes),
)
.into_response();
}
proxy(
&state.client,
req.method().clone(),
req.headers(),
format!("{}{}", result.url, req.uri().path()),
)
.await
},
Err(RouterError::NotFound) => {
metrics::get()
.narinfo_requests
.with_label_values(&["error"])
.inc();
StatusCode::NOT_FOUND.into_response()
},
Err(err) => {
tracing::warn!(hash, error = %err, "narinfo resolve failed");
metrics::get()
.narinfo_requests
.with_label_values(&["error"])
.inc();
(StatusCode::BAD_GATEWAY, "upstream unavailable").into_response()
},
}
}
async fn nar(
State(state): State<Arc<AppState>>,
req: Request<Body>,
) -> Response {
metrics::get().nar_requests.inc();
let nar_url = req.uri().path().trim_start_matches('/').to_string();
if let Ok(Some(entry)) = state.db.get_route_by_nar_url(&nar_url).await
&& entry.is_valid()
&& let Some(resp) = try_nar_upstream(
&state.client,
req.method().clone(),
req.headers(),
&entry.upstream_url,
req.uri().path(),
)
.await
{
return resp;
}
for h in state.prober.sorted_by_latency().await {
if h.status == Status::Down {
continue;
}
if let Some(resp) = try_nar_upstream(
&state.client,
req.method().clone(),
req.headers(),
&h.url,
req.uri().path(),
)
.await
{
return resp;
}
}
StatusCode::NOT_FOUND.into_response()
}
async fn upstream_urls(state: &AppState) -> Vec<String> {
let urls = state
.prober
.sorted_by_latency()
.await
.into_iter()
.filter(|h| h.status != Status::Down)
.map(|h| h.url)
.collect::<Vec<_>>();
if urls.is_empty() {
state.upstreams.iter().map(|u| u.url.clone()).collect()
} else {
urls
}
}
async fn try_nar_upstream(
client: &reqwest::Client,
method: Method,
headers: &HeaderMap,
upstream: &str,
path: &str,
) -> Option<Response> {
let resp =
upstream_request(client, method, headers, format!("{upstream}{path}"))
.await
.ok()?;
if resp.status() == reqwest::StatusCode::NOT_FOUND {
return None;
}
Some(response_from_reqwest(resp))
}
async fn proxy(
client: &reqwest::Client,
method: Method,
headers: &HeaderMap,
url: String,
) -> Response {
match upstream_request(client, method, headers, url).await {
Ok(resp) => response_from_reqwest(resp),
Err(err) => {
tracing::warn!(error = %err, "upstream request failed");
(StatusCode::BAD_GATEWAY, "upstream error").into_response()
},
}
}
async fn upstream_request(
client: &reqwest::Client,
method: Method,
headers: &HeaderMap,
url: String,
) -> reqwest::Result<reqwest::Response> {
let mut req = client.request(method, url);
for name in ["accept", "accept-encoding", "range"] {
if let Some(value) = headers.get(name) {
req = req.header(name, value);
}
}
req.send().await
}
fn response_from_reqwest(resp: reqwest::Response) -> Response {
let status = StatusCode::from_u16(resp.status().as_u16())
.unwrap_or(StatusCode::BAD_GATEWAY);
let headers = resp.headers().clone();
let stream = resp.bytes_stream().map_err(std::io::Error::other);
let mut out = Response::builder().status(status);
for name in [
"content-type",
"content-length",
"content-encoding",
"x-nix-signature",
"cache-control",
"last-modified",
] {
if let Some(value) = headers.get(name)
&& let (Ok(header_name), Ok(header_value)) = (
HeaderName::from_bytes(name.as_bytes()),
HeaderValue::from_bytes(value.as_bytes()),
)
{
out = out.header(header_name, header_value);
}
}
out
.body(Body::from_stream(stream))
.unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
}