ncro/src/health.rs
NotAShelf ea96477830
treewide: rewrite everything in Rust
Signed-off-by: NotAShelf <raf@notashelf.dev>
Change-Id: I786da853078e1013bb8f463ed9e9869c6a6a6964
2026-05-11 13:28:29 +03:00

310 lines
7.7 KiB
Rust

use std::{
cmp::Ordering,
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::RwLock;
use crate::config::UpstreamConfig;
/// Coarse health classification of an upstream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
    /// The upstream is considered healthy.
    Active,
    /// The upstream has accumulated some consecutive failures.
    Degraded,
    /// The upstream has accumulated many consecutive failures.
    Down,
}

impl Status {
    /// Returns the upper-case label for this status.
    ///
    /// The labels are `"ACTIVE"`, `"DEGRADED"` and `"DOWN"`.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Down => "DOWN",
            Self::Degraded => "DEGRADED",
            Self::Active => "ACTIVE",
        }
    }
}
/// Health bookkeeping for a single upstream, as tracked by the prober.
#[derive(Debug, Clone)]
pub struct UpstreamHealth {
/// Upstream URL; the same string is used as the key in the prober's table.
pub url: String,
/// Tie-break value used when ordering upstreams with similar latency;
/// lower values sort first.
pub priority: i32,
/// Exponential moving average of recorded latencies, in milliseconds.
/// Stays at `0.0` until a latency has been recorded or seeded.
pub ema_latency: f64,
/// Instant of the most recent recorded latency sample; not updated when a
/// failure is recorded. `None` until the first sample.
pub last_probe: Option<Instant>,
/// Number of failures recorded since the last latency sample.
pub consecutive_fails: u32,
/// Number of latency samples recorded (or the seeded value).
pub total_queries: u64,
/// Current classification; set to `Active` on a latency sample and
/// recomputed from `consecutive_fails` when a failure is recorded or the
/// entry is seeded.
pub status: Status,
}
impl UpstreamHealth {
    /// Creates a fresh entry for `url` with the given `priority`.
    ///
    /// The entry starts as `Status::Active` with zeroed counters, a latency
    /// average of `0.0` and no probe timestamp.
    const fn new(url: String, priority: i32) -> Self {
        Self {
            status: Status::Active,
            last_probe: None,
            ema_latency: 0.0,
            total_queries: 0,
            consecutive_fails: 0,
            priority,
            url,
        }
    }
}
/// Shared callback used to persist a health snapshot.
///
/// It is invoked with `(url, ema_latency, consecutive_fails, total_queries)`
/// after a latency sample or a failure has been recorded.
type PersistHealth = Arc<dyn Fn(String, f64, u32, u64) + Send + Sync>;
/// Handle to the upstream health tracker.
///
/// The state lives behind an `Arc`, so clones of a `Prober` share the same
/// table, HTTP client and persistence callback.
#[derive(Clone)]
pub struct Prober {
// Shared state; see `ProberInner`.
inner: Arc<ProberInner>,
}
/// State shared by all clones of a `Prober`.
struct ProberInner {
/// Weight given to the newest sample when updating the latency average
/// (`alpha * sample + (1 - alpha) * previous`).
alpha: f64,
/// Health entries keyed by upstream URL.
table: RwLock<HashMap<String, UpstreamHealth>>,
/// HTTP client used to probe upstreams.
client: reqwest::Client,
/// Optional callback invoked after each recorded latency or failure.
persist_health: RwLock<Option<PersistHealth>>,
}
impl Prober {
/// Creates a prober with an empty upstream table and no persistence callback.
///
/// `alpha` is the weight applied to the newest latency sample when the
/// moving average is updated. The HTTP client is built with a 10 second
/// timeout; if the builder fails, a default `reqwest::Client` is used.
pub fn new(alpha: f64) -> Self {
    let client = match reqwest::Client::builder()
        .timeout(Duration::from_secs(10))
        .build()
    {
        Ok(configured) => configured,
        Err(_) => reqwest::Client::new(),
    };

    let state = ProberInner {
        alpha,
        table: RwLock::new(HashMap::new()),
        client,
        persist_health: RwLock::new(None),
    };

    Self {
        inner: Arc::new(state),
    }
}
/// Registers every configured upstream that is not already tracked.
///
/// Entries that already exist in the table are left untouched, so their
/// recorded health is preserved.
pub async fn init_upstreams(&self, upstreams: &[UpstreamConfig]) {
    let mut table = self.inner.table.write().await;
    for cfg in upstreams {
        if let std::collections::hash_map::Entry::Vacant(slot) =
            table.entry(cfg.url.clone())
        {
            let fresh = UpstreamHealth::new(cfg.url.clone(), cfg.priority);
            slot.insert(fresh);
        }
    }
}
/// Overwrites the stored health of `url` with previously saved values.
///
/// Does nothing when `url` is not tracked. Negative counters are treated as
/// zero, and a `consecutive_fails` value too large for `u32` saturates at
/// `u32::MAX`. The status is recomputed from the resulting failure count.
#[allow(clippy::significant_drop_tightening)]
pub async fn seed(
    &self,
    url: &str,
    ema_latency: f64,
    consecutive_fails: i64,
    total_queries: i64,
) {
    // Convert the signed inputs before taking the write lock.
    let fails = u32::try_from(consecutive_fails.max(0)).unwrap_or(u32::MAX);
    let queries = u64::try_from(total_queries.max(0)).unwrap_or_default();

    let mut table = self.inner.table.write().await;
    if let Some(entry) = table.get_mut(url) {
        entry.ema_latency = ema_latency;
        entry.total_queries = queries;
        entry.consecutive_fails = fails;
        entry.status = compute_status(fails);
    }
}
pub async fn set_health_persistence<F>(&self, f: F)
where
F: Fn(String, f64, u32, u64) + Send + Sync + 'static,
{
*self.inner.persist_health.write().await = Some(Arc::new(f));
}
/// Records a latency sample of `ms` milliseconds for `url`.
///
/// Does nothing when `url` is not tracked. The first sample (when
/// `total_queries` is zero) becomes the average directly; later samples are
/// blended in as `alpha * ms + (1 - alpha) * previous`. The failure counter
/// is reset, the status becomes `Active`, and the probe timestamp is set to
/// now. If a persistence callback is installed, it is invoked from a spawned
/// task with the updated values.
#[allow(clippy::significant_drop_tightening)]
pub async fn record_latency(&self, url: &str, ms: f64) {
    let alpha = self.inner.alpha;

    let (key, latency, fails, queries) = {
        let mut table = self.inner.table.write().await;
        let Some(entry) = table.get_mut(url) else {
            return;
        };

        entry.ema_latency = if entry.total_queries == 0 {
            ms
        } else {
            alpha.mul_add(ms, (1.0 - alpha) * entry.ema_latency)
        };
        entry.consecutive_fails = 0;
        entry.total_queries += 1;
        entry.status = Status::Active;
        entry.last_probe = Some(Instant::now());

        (
            entry.url.clone(),
            entry.ema_latency,
            entry.consecutive_fails,
            entry.total_queries,
        )
    };

    // Clone the callback out so the read lock is released before spawning.
    let hook = self.inner.persist_health.read().await.clone();
    if let Some(hook) = hook {
        tokio::spawn(async move {
            hook(key, latency, fails, queries);
        });
    }
}
/// Records one failed request or probe for `url`.
///
/// Does nothing when `url` is not tracked. The consecutive-failure counter
/// is incremented and the status recomputed from it; the latency average,
/// query count and probe timestamp are left unchanged. If a persistence
/// callback is installed, it is invoked from a spawned task with the updated
/// values.
#[allow(clippy::significant_drop_tightening)]
pub async fn record_failure(&self, url: &str) {
    let (key, latency, fails, queries) = {
        let mut table = self.inner.table.write().await;
        let Some(health) = table.get_mut(url) else {
            return;
        };

        // `seed` can store `u32::MAX` here. A plain `+= 1` would then panic
        // in debug builds and wrap to 0 in release builds, which would flip
        // a long-dead upstream back to `Active`. Saturate instead.
        health.consecutive_fails = health.consecutive_fails.saturating_add(1);
        health.status = compute_status(health.consecutive_fails);

        (
            health.url.clone(),
            health.ema_latency,
            health.consecutive_fails,
            health.total_queries,
        )
    };

    // Clone the callback out so the read lock is released before spawning.
    let callback = self.inner.persist_health.read().await.clone();
    if let Some(callback) = callback {
        tokio::spawn(async move {
            callback(key, latency, fails, queries);
        });
    }
}
/// Returns a copy of the health entry for `url`, or `None` if it is not
/// tracked.
pub async fn get_health(&self, url: &str) -> Option<UpstreamHealth> {
    let table = self.inner.table.read().await;
    table.get(url).cloned()
}
/// Returns a snapshot of all upstreams, most preferred first.
///
/// Ordering rules:
///
/// 1. Upstreams whose status is not `Down` come before those that are.
/// 2. Within each of those groups, lower `ema_latency` comes first.
/// 3. Upstreams with a positive latency that is within 10% of the fastest
///    member of their cluster are reordered by ascending `priority`;
///    members with equal priority keep their latency order.
///
/// The "within 10%" relation is neither symmetric nor transitive, so it
/// cannot be expressed as a comparator: `sort_by` needs a total order and
/// may panic or produce an unspecified order without one. The priority
/// preference is therefore applied as a separate pass over a list that has
/// already been sorted with a total order. `f64::total_cmp` keeps that
/// order well-defined even if a latency is NaN.
pub async fn sorted_by_latency(&self) -> Vec<UpstreamHealth> {
    let mut ranked: Vec<UpstreamHealth> =
        self.inner.table.read().await.values().cloned().collect();

    // Pass 1: a genuine total order on (is_down, latency).
    ranked.sort_by(|a, b| {
        (a.status == Status::Down)
            .cmp(&(b.status == Status::Down))
            .then_with(|| a.ema_latency.total_cmp(&b.ema_latency))
    });

    // Pass 2: walk clusters anchored at their fastest member and let
    // priority decide the order inside each cluster.
    let mut start = 0;
    while start < ranked.len() {
        let anchor_down = ranked[start].status == Status::Down;
        let anchor_latency = ranked[start].ema_latency;

        // Entries with a zero (never measured) or otherwise non-positive
        // latency never join a cluster, so they always form a cluster of
        // one; hence the `max(1)`.
        let len = ranked[start..]
            .iter()
            .take_while(|h| {
                (h.status == Status::Down) == anchor_down
                    && h.ema_latency > 0.0
                    && (h.ema_latency - anchor_latency) / h.ema_latency < 0.10
            })
            .count()
            .max(1);

        // Stable sort, so equal priorities keep their latency order.
        ranked[start..start + len].sort_by_key(|h| h.priority);
        start += len;
    }

    ranked
}
/// Probes `url` once and records the outcome.
///
/// Does nothing when `url` is not tracked. Otherwise a `HEAD` request is
/// sent to `<url>/nix-cache-info`. A `200` response records the elapsed
/// time in milliseconds as a latency sample; any other status or a request
/// error records a failure.
pub async fn probe_upstream(&self, url: String) {
    if !self.inner.table.read().await.contains_key(&url) {
        return;
    }

    // Trim trailing slashes so that an upstream configured as
    // `https://host/` is not probed at `https://host//nix-cache-info`,
    // which some servers treat as a different path.
    let endpoint = format!("{}/nix-cache-info", url.trim_end_matches('/'));

    let start = Instant::now();
    let ok = self
        .inner
        .client
        .head(endpoint)
        .send()
        .await
        .map(|resp| resp.status().as_u16() == 200)
        .unwrap_or(false);

    if ok {
        self.record_latency(&url, start.elapsed().as_secs_f64() * 1000.0)
            .await;
    } else {
        self.record_failure(&url).await;
    }
}
/// Probes every tracked upstream on each tick of `interval` until stopped.
///
/// On each tick the current set of URLs is snapshotted and one detached
/// task per URL is spawned to run `probe_upstream`. The loop returns as
/// soon as `stop.changed()` resolves.
///
/// NOTE(review): `tokio::time::interval` is documented to panic on a zero
/// period — confirm callers never pass `Duration::ZERO`.
pub async fn run_probe_loop(
    &self,
    interval: Duration,
    mut stop: tokio::sync::watch::Receiver<bool>,
) {
    let mut timer = tokio::time::interval(interval);
    loop {
        tokio::select! {
            _ = stop.changed() => return,
            _ = timer.tick() => {
                // Snapshot the keys so the read lock is not held while
                // probes run.
                let targets: Vec<String> = {
                    let table = self.inner.table.read().await;
                    table.keys().cloned().collect()
                };
                for target in targets {
                    let worker = self.clone();
                    tokio::spawn(async move {
                        worker.probe_upstream(target).await;
                    });
                }
            }
        }
    }
}
pub async fn add_upstream(&self, url: String, priority: i32) {
let inserted = self
.inner
.table
.write()
.await
.insert(url.clone(), UpstreamHealth::new(url.clone(), priority))
.is_none();
if inserted {
let prober = self.clone();
tokio::spawn(async move {
prober.probe_upstream(url).await;
});
}
}
/// Stops tracking `url`; does nothing if it is not tracked.
pub async fn remove_upstream(&self, url: &str) {
    let mut table = self.inner.table.write().await;
    table.remove(url);
}
}
/// Maps a consecutive-failure count to a status.
///
/// Ten or more failures yield `Down`, three to nine yield `Degraded`, and
/// fewer than three yield `Active`.
const fn compute_status(consecutive_fails: u32) -> Status {
    if consecutive_fails >= 10 {
        Status::Down
    } else if consecutive_fails >= 3 {
        Status::Degraded
    } else {
        Status::Active
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Upstream used by the tests below.
    const UPSTREAM: &str = "https://example.com";

    /// Two latency samples blend into the expected average, and ten
    /// consecutive failures mark the upstream as down.
    #[tokio::test]
    async fn ema_and_status_progression() -> Result<(), Box<dyn std::error::Error>> {
        let prober = Prober::new(0.3);
        prober.add_upstream(UPSTREAM.into(), 1).await;

        for sample in [100.0, 50.0] {
            prober.record_latency(UPSTREAM, sample).await;
        }
        let after_samples = prober
            .get_health(UPSTREAM)
            .await
            .ok_or("missing health")?;
        // 0.3 * 50 + 0.7 * 100 = 85
        assert!((84.0..=86.0).contains(&after_samples.ema_latency));

        for _ in 0..10 {
            prober.record_failure(UPSTREAM).await;
        }
        let after_failures = prober
            .get_health(UPSTREAM)
            .await
            .ok_or("missing health")?;
        assert_eq!(after_failures.status, Status::Down);

        Ok(())
    }
}