Signed-off-by: NotAShelf <raf@notashelf.dev> Change-Id: I786da853078e1013bb8f463ed9e9869c6a6a6964
310 lines
7.7 KiB
Rust
310 lines
7.7 KiB
Rust
use std::{
|
|
cmp::Ordering,
|
|
collections::HashMap,
|
|
sync::Arc,
|
|
time::{Duration, Instant},
|
|
};
|
|
|
|
use tokio::sync::RwLock;
|
|
|
|
use crate::config::UpstreamConfig;
|
|
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
pub enum Status {
|
|
Active,
|
|
Degraded,
|
|
Down,
|
|
}
|
|
|
|
impl Status {
|
|
pub const fn as_str(self) -> &'static str {
|
|
match self {
|
|
Self::Active => "ACTIVE",
|
|
Self::Degraded => "DEGRADED",
|
|
Self::Down => "DOWN",
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct UpstreamHealth {
|
|
pub url: String,
|
|
pub priority: i32,
|
|
pub ema_latency: f64,
|
|
pub last_probe: Option<Instant>,
|
|
pub consecutive_fails: u32,
|
|
pub total_queries: u64,
|
|
pub status: Status,
|
|
}
|
|
|
|
impl UpstreamHealth {
|
|
const fn new(url: String, priority: i32) -> Self {
|
|
Self {
|
|
url,
|
|
priority,
|
|
ema_latency: 0.0,
|
|
last_probe: None,
|
|
consecutive_fails: 0,
|
|
total_queries: 0,
|
|
status: Status::Active,
|
|
}
|
|
}
|
|
}
|
|
|
|
type PersistHealth = Arc<dyn Fn(String, f64, u32, u64) + Send + Sync>;
|
|
|
|
#[derive(Clone)]
|
|
pub struct Prober {
|
|
inner: Arc<ProberInner>,
|
|
}
|
|
|
|
struct ProberInner {
|
|
alpha: f64,
|
|
table: RwLock<HashMap<String, UpstreamHealth>>,
|
|
client: reqwest::Client,
|
|
persist_health: RwLock<Option<PersistHealth>>,
|
|
}
|
|
|
|
impl Prober {
|
|
pub fn new(alpha: f64) -> Self {
|
|
Self {
|
|
inner: Arc::new(ProberInner {
|
|
alpha,
|
|
table: RwLock::new(HashMap::new()),
|
|
client: reqwest::Client::builder()
|
|
.timeout(Duration::from_secs(10))
|
|
.build()
|
|
.unwrap_or_else(|_| reqwest::Client::new()),
|
|
persist_health: RwLock::new(None),
|
|
}),
|
|
}
|
|
}
|
|
|
|
pub async fn init_upstreams(&self, upstreams: &[UpstreamConfig]) {
|
|
let mut table = self.inner.table.write().await;
|
|
for upstream in upstreams {
|
|
table.entry(upstream.url.clone()).or_insert_with(|| {
|
|
UpstreamHealth::new(upstream.url.clone(), upstream.priority)
|
|
});
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::significant_drop_tightening)]
|
|
pub async fn seed(
|
|
&self,
|
|
url: &str,
|
|
ema_latency: f64,
|
|
consecutive_fails: i64,
|
|
total_queries: i64,
|
|
) {
|
|
{
|
|
let mut table = self.inner.table.write().await;
|
|
let Some(health) = table.get_mut(url) else {
|
|
return;
|
|
};
|
|
health.ema_latency = ema_latency;
|
|
health.total_queries =
|
|
u64::try_from(total_queries.max(0)).unwrap_or_default();
|
|
health.consecutive_fails =
|
|
u32::try_from(consecutive_fails.max(0)).unwrap_or(u32::MAX);
|
|
health.status = compute_status(health.consecutive_fails);
|
|
}
|
|
}
|
|
|
|
pub async fn set_health_persistence<F>(&self, f: F)
|
|
where
|
|
F: Fn(String, f64, u32, u64) + Send + Sync + 'static,
|
|
{
|
|
*self.inner.persist_health.write().await = Some(Arc::new(f));
|
|
}
|
|
|
|
#[allow(clippy::significant_drop_tightening)]
|
|
pub async fn record_latency(&self, url: &str, ms: f64) {
|
|
let snapshot = {
|
|
let mut table = self.inner.table.write().await;
|
|
let Some(health) = table.get_mut(url) else {
|
|
return;
|
|
};
|
|
if health.total_queries == 0 {
|
|
health.ema_latency = ms;
|
|
} else {
|
|
health.ema_latency = self
|
|
.inner
|
|
.alpha
|
|
.mul_add(ms, (1.0 - self.inner.alpha) * health.ema_latency);
|
|
}
|
|
health.consecutive_fails = 0;
|
|
health.total_queries += 1;
|
|
health.status = Status::Active;
|
|
health.last_probe = Some(Instant::now());
|
|
(
|
|
health.url.clone(),
|
|
health.ema_latency,
|
|
health.consecutive_fails,
|
|
health.total_queries,
|
|
)
|
|
};
|
|
let callback = self.inner.persist_health.read().await.clone();
|
|
if let Some(callback) = callback {
|
|
tokio::spawn(async move {
|
|
callback(snapshot.0, snapshot.1, snapshot.2, snapshot.3);
|
|
});
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::significant_drop_tightening)]
|
|
pub async fn record_failure(&self, url: &str) {
|
|
let snapshot = {
|
|
let mut table = self.inner.table.write().await;
|
|
let Some(health) = table.get_mut(url) else {
|
|
return;
|
|
};
|
|
health.consecutive_fails += 1;
|
|
health.status = compute_status(health.consecutive_fails);
|
|
(
|
|
health.url.clone(),
|
|
health.ema_latency,
|
|
health.consecutive_fails,
|
|
health.total_queries,
|
|
)
|
|
};
|
|
let callback = self.inner.persist_health.read().await.clone();
|
|
if let Some(callback) = callback {
|
|
tokio::spawn(async move {
|
|
callback(snapshot.0, snapshot.1, snapshot.2, snapshot.3);
|
|
});
|
|
}
|
|
}
|
|
|
|
pub async fn get_health(&self, url: &str) -> Option<UpstreamHealth> {
|
|
self.inner.table.read().await.get(url).cloned()
|
|
}
|
|
|
|
pub async fn sorted_by_latency(&self) -> Vec<UpstreamHealth> {
|
|
let mut result = self
|
|
.inner
|
|
.table
|
|
.read()
|
|
.await
|
|
.values()
|
|
.cloned()
|
|
.collect::<Vec<_>>();
|
|
result.sort_by(|a, b| {
|
|
match (a.status == Status::Down, b.status == Status::Down) {
|
|
(true, false) => return Ordering::Greater,
|
|
(false, true) => return Ordering::Less,
|
|
_ => {},
|
|
}
|
|
if b.ema_latency > 0.0
|
|
&& ((a.ema_latency - b.ema_latency).abs() / b.ema_latency) < 0.10
|
|
&& a.priority != b.priority
|
|
{
|
|
return a.priority.cmp(&b.priority);
|
|
}
|
|
a.ema_latency
|
|
.partial_cmp(&b.ema_latency)
|
|
.unwrap_or(Ordering::Equal)
|
|
});
|
|
result
|
|
}
|
|
|
|
pub async fn probe_upstream(&self, url: String) {
|
|
if !self.inner.table.read().await.contains_key(&url) {
|
|
return;
|
|
}
|
|
let start = Instant::now();
|
|
let ok = self
|
|
.inner
|
|
.client
|
|
.head(format!("{url}/nix-cache-info"))
|
|
.send()
|
|
.await
|
|
.map(|resp| resp.status().as_u16() == 200)
|
|
.unwrap_or(false);
|
|
if ok {
|
|
self
|
|
.record_latency(&url, start.elapsed().as_secs_f64() * 1000.0)
|
|
.await;
|
|
} else {
|
|
self.record_failure(&url).await;
|
|
}
|
|
}
|
|
|
|
pub async fn run_probe_loop(
|
|
&self,
|
|
interval: Duration,
|
|
mut stop: tokio::sync::watch::Receiver<bool>,
|
|
) {
|
|
let mut ticker = tokio::time::interval(interval);
|
|
loop {
|
|
tokio::select! {
|
|
_ = stop.changed() => return,
|
|
_ = ticker.tick() => {
|
|
let urls = self.inner.table.read().await.keys().cloned().collect::<Vec<_>>();
|
|
for url in urls {
|
|
let prober = self.clone();
|
|
tokio::spawn(async move { prober.probe_upstream(url).await; });
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn add_upstream(&self, url: String, priority: i32) {
|
|
let inserted = self
|
|
.inner
|
|
.table
|
|
.write()
|
|
.await
|
|
.insert(url.clone(), UpstreamHealth::new(url.clone(), priority))
|
|
.is_none();
|
|
if inserted {
|
|
let prober = self.clone();
|
|
tokio::spawn(async move {
|
|
prober.probe_upstream(url).await;
|
|
});
|
|
}
|
|
}
|
|
|
|
pub async fn remove_upstream(&self, url: &str) {
|
|
self.inner.table.write().await.remove(url);
|
|
}
|
|
}
|
|
|
|
const fn compute_status(consecutive_fails: u32) -> Status {
|
|
match consecutive_fails {
|
|
10.. => Status::Down,
|
|
3.. => Status::Degraded,
|
|
_ => Status::Active,
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[tokio::test]
|
|
async fn ema_and_status_progression() -> Result<(), Box<dyn std::error::Error>>
|
|
{
|
|
let p = Prober::new(0.3);
|
|
p.add_upstream("https://example.com".into(), 1).await;
|
|
p.record_latency("https://example.com", 100.0).await;
|
|
p.record_latency("https://example.com", 50.0).await;
|
|
let h = p
|
|
.get_health("https://example.com")
|
|
.await
|
|
.ok_or("missing health")?;
|
|
assert!((84.0..=86.0).contains(&h.ema_latency));
|
|
for _ in 0..10 {
|
|
p.record_failure("https://example.com").await;
|
|
}
|
|
assert_eq!(
|
|
p.get_health("https://example.com")
|
|
.await
|
|
.ok_or("missing health")?
|
|
.status,
|
|
Status::Down
|
|
);
|
|
Ok(())
|
|
}
|
|
}
|