network: implement basic ratelimiting

Still pretty barebones, all things considered. I want to revisit this once multi-site support is in.
This commit is contained in:
raf 2025-05-02 12:52:24 +03:00
commit ff3f80adda
Signed by: NotAShelf
GPG key ID: 29D95B64378DB4BF
4 changed files with 217 additions and 2 deletions

View file

@ -92,6 +92,32 @@ pub struct Args {
help = "Log format: plain, pretty, json, pretty-json"
)]
pub log_format: String,
#[clap(long, help = "Enable rate limiting for connections from the same IP")]
pub rate_limit_enabled: bool,
#[clap(long, default_value = "60", help = "Rate limit window in seconds")]
pub rate_limit_window: u64,
#[clap(
long,
default_value = "30",
help = "Maximum number of connections allowed per IP in the rate limit window"
)]
pub rate_limit_max: usize,
#[clap(
long,
default_value = "100",
help = "Connection attempts threshold before considering for IP blocking"
)]
pub rate_limit_block_threshold: usize,
#[clap(
long,
help = "Send a 429 response for rate limited connections instead of dropping connection"
)]
pub rate_limit_slow_response: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
@ -162,6 +188,11 @@ pub struct Config {
pub config_dir: String,
pub cache_dir: String,
pub log_format: LogFormat,
pub rate_limit_enabled: bool,
pub rate_limit_window_seconds: u64,
pub rate_limit_max_connections: usize,
pub rate_limit_block_threshold: usize,
pub rate_limit_slow_response: bool,
}
impl Default for Config {
@ -224,6 +255,11 @@ impl Default for Config {
config_dir: "./conf".to_string(),
cache_dir: "./cache".to_string(),
log_format: LogFormat::Pretty,
rate_limit_enabled: true,
rate_limit_window_seconds: 60,
rate_limit_max_connections: 30,
rate_limit_block_threshold: 100,
rate_limit_slow_response: true,
}
}
}
@ -297,6 +333,12 @@ impl Config {
data_dir,
config_dir,
cache_dir,
log_format: LogFormat::Pretty,
rate_limit_enabled: args.rate_limit_enabled,
rate_limit_window_seconds: args.rate_limit_window,
rate_limit_max_connections: args.rate_limit_max,
rate_limit_block_threshold: args.rate_limit_block_threshold,
rate_limit_slow_response: args.rate_limit_slow_response,
..Default::default()
}
}
@ -425,6 +467,11 @@ mod tests {
config_file: None,
log_level: "debug".to_string(),
log_format: "pretty".to_string(),
rate_limit_enabled: true,
rate_limit_window: 30,
rate_limit_max: 20,
rate_limit_block_threshold: 50,
rate_limit_slow_response: true,
};
let config = Config::from_args(&args);
@ -441,6 +488,11 @@ mod tests {
assert_eq!(config.data_dir, "/tmp/eris/data");
assert_eq!(config.config_dir, "/tmp/eris/conf");
assert_eq!(config.cache_dir, "/tmp/eris/cache");
assert!(config.rate_limit_enabled);
assert_eq!(config.rate_limit_window_seconds, 30);
assert_eq!(config.rate_limit_max_connections, 20);
assert_eq!(config.rate_limit_block_threshold, 50);
assert!(config.rate_limit_slow_response);
}
#[test]
@ -537,12 +589,22 @@ mod tests {
let loaded_json = Config::load_from_file(&json_path).unwrap();
assert_eq!(loaded_json.listen_addr, config.listen_addr);
assert_eq!(loaded_json.min_delay, config.min_delay);
assert_eq!(loaded_json.rate_limit_enabled, config.rate_limit_enabled);
assert_eq!(
loaded_json.rate_limit_max_connections,
config.rate_limit_max_connections
);
// Test TOML serialization and deserialization
config.save_to_file(&toml_path).unwrap();
let loaded_toml = Config::load_from_file(&toml_path).unwrap();
assert_eq!(loaded_toml.listen_addr, config.listen_addr);
assert_eq!(loaded_toml.min_delay, config.min_delay);
assert_eq!(loaded_toml.rate_limit_enabled, config.rate_limit_enabled);
assert_eq!(
loaded_toml.rate_limit_max_connections,
config.rate_limit_max_connections
);
// Clean up
let _ = std::fs::remove_file(json_path);

View file

@ -24,6 +24,11 @@ lazy_static! {
register_counter_vec!("eris_path_hits_total", "Hits by path", &["path"]).unwrap();
pub static ref UA_HITS: CounterVec =
register_counter_vec!("eris_ua_hits_total", "Hits by user agent", &["user_agent"]).unwrap();
pub static ref RATE_LIMITED_CONNECTIONS: Counter = register_counter!(
"eris_rate_limited_total",
"Number of connections rejected due to rate limiting"
)
.unwrap();
}
// Prometheus metrics endpoint
@ -76,6 +81,7 @@ mod tests {
UA_HITS.with_label_values(&["TestBot/1.0"]).inc();
BLOCKED_IPS.set(5.0);
ACTIVE_CONNECTIONS.set(3.0);
RATE_LIMITED_CONNECTIONS.inc();
// Create test app
let app =
@ -96,6 +102,7 @@ mod tests {
assert!(body_str.contains("eris_ua_hits_total"));
assert!(body_str.contains("eris_blocked_ips"));
assert!(body_str.contains("eris_active_connections"));
assert!(body_str.contains("eris_rate_limited_total"));
}
#[actix_web::test]

View file

@ -12,13 +12,28 @@ use tokio::time::sleep;
use crate::config::Config;
use crate::lua::{EventContext, EventType, ScriptManager};
use crate::markov::MarkovGenerator;
use crate::metrics::{ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, UA_HITS};
use crate::metrics::{
ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, RATE_LIMITED_CONNECTIONS, UA_HITS,
};
use crate::state::BotState;
use crate::utils::{
choose_response_type, extract_all_headers, extract_header_value, extract_path_from_request,
find_header_end, generate_session_id, get_timestamp,
};
mod rate_limiter;
use rate_limiter::RateLimiter;
// Global rate limiter instance.
// Default is 30 connections per IP in a 60 second window
// XXX: This may duplicate work done by a fronting proxy — e.g. NGINX already
// implements rate limiting — but we have no way of knowing whether the
// middleman we hand connections to (sometimes the same middleman we received
// them from) enforces any rate limiting of its own.
// NOTE(review): the window (60s) and cap (30) here are hard-coded to the CLI
// defaults; `Config::rate_limit_window_seconds` and
// `rate_limit_max_connections` are parsed from args but never reach this
// instance — presumably unintended; consider initializing from Config at
// startup (e.g. via `std::sync::OnceLock`) — TODO confirm.
lazy_static::lazy_static! {
static ref RATE_LIMITER: RateLimiter = RateLimiter::new(60, 30);
}
// Main connection handler.
// Decides whether to tarpit or proxy
pub async fn handle_connection(
@ -29,7 +44,7 @@ pub async fn handle_connection(
script_manager: Arc<ScriptManager>,
) {
// Get peer information
let peer_addr = match stream.peer_addr() {
let peer_addr: IpAddr = match stream.peer_addr() {
Ok(addr) => addr.ip(),
Err(e) => {
log::debug!("Failed to get peer address: {e}");
@ -44,6 +59,52 @@ pub async fn handle_connection(
return;
}
// Apply rate limiting before any further processing
if config.rate_limit_enabled && !RATE_LIMITER.check_rate_limit(peer_addr).await {
log::info!("Rate limited connection from {peer_addr}");
RATE_LIMITED_CONNECTIONS.inc();
// Optionally, add the IP to a temporary block list
// if it's constantly hitting the rate limit
let connection_count = RATE_LIMITER.get_connection_count(&peer_addr);
if connection_count > config.rate_limit_block_threshold {
log::warn!(
"IP {peer_addr} exceeding rate limit with {connection_count} connection attempts, considering for blocking"
);
// Trigger a blocked event for Lua scripts
let rate_limit_ctx = EventContext {
event_type: EventType::BlockIP,
ip: Some(peer_addr.to_string()),
path: None,
user_agent: None,
request_headers: None,
content: None,
timestamp: get_timestamp(),
session_id: None,
};
script_manager.trigger_event(&rate_limit_ctx);
}
// Either send a slow response or just close connection
if config.rate_limit_slow_response {
// Send a simple 429 Too Many Requests response. If the bots actually respected
// HTTP error codes, the internet would be a mildly better place.
let response = "HTTP/1.1 429 Too Many Requests\r\n\
Content-Type: text/plain\r\n\
Retry-After: 60\r\n\
Connection: close\r\n\
\r\n\
Rate limit exceeded. Please try again later.";
let _ = stream.write_all(response.as_bytes()).await;
let _ = stream.flush().await;
}
let _ = stream.shutdown().await;
return;
}
// Check if Lua scripts allow this connection
if !script_manager.on_connection(&peer_addr.to_string()) {
log::debug!("Connection rejected by Lua script: {peer_addr}");

View file

@ -0,0 +1,85 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
/// Sliding-window rate limiter keyed by client IP.
///
/// Each IP maps to the timestamps of its recent connection attempts; an
/// attempt is allowed while fewer than `max_connections` timestamps fall
/// inside the last `window_seconds`. A global sweep of stale entries runs
/// at most once per `cleanup_interval` so idle IPs don't accumulate forever.
pub struct RateLimiter {
    /// Connection timestamps per IP; pruned lazily on access and by sweeps.
    connections: Arc<Mutex<HashMap<IpAddr, Vec<Instant>>>>,
    /// Length of the sliding window, in seconds.
    window_seconds: u64,
    /// Maximum connection attempts allowed per IP within one window.
    max_connections: usize,
    /// Minimum time between global sweeps of stale entries.
    cleanup_interval: Duration,
    /// When the last global sweep ran. Kept behind a mutex because the
    /// methods only take `&self`; a plain `Instant` could never be updated,
    /// which made the sweep run on *every* call once the first interval
    /// elapsed.
    last_cleanup: Mutex<Instant>,
}

impl RateLimiter {
    /// Creates a limiter allowing `max_connections` attempts per IP in any
    /// `window_seconds`-long window.
    pub fn new(window_seconds: u64, max_connections: usize) -> Self {
        Self {
            connections: Arc::new(Mutex::new(HashMap::new())),
            window_seconds,
            max_connections,
            cleanup_interval: Duration::from_secs(60),
            last_cleanup: Mutex::new(Instant::now()),
        }
    }

    /// Records a connection attempt from `ip`. Returns `true` if the attempt
    /// is within the rate limit, `false` if the IP has exhausted its window.
    pub async fn check_rate_limit(&self, ip: IpAddr) -> bool {
        let now = Instant::now();
        let window = Duration::from_secs(self.window_seconds);
        let mut connections = self.connections.lock().await;

        // Run the global sweep at most once per `cleanup_interval`, and
        // record when it ran. (Previously `last_cleanup` was never written,
        // so every call after the first interval paid for a full sweep.)
        {
            let mut last_cleanup = self.last_cleanup.lock().await;
            if now.duration_since(*last_cleanup) > self.cleanup_interval {
                self.cleanup_old_entries(&mut connections, now, window);
                *last_cleanup = now;
            }
        }

        // Prune this IP's history to the current window, then check the cap.
        // The entry API also applies the cap uniformly to first-time IPs.
        let times = connections.entry(ip).or_default();
        times.retain(|time| now.duration_since(*time) < window);
        if times.len() >= self.max_connections {
            log::debug!("Rate limit exceeded for IP: {ip}");
            return false;
        }
        times.push(now);
        true
    }

    /// Drops timestamps older than `window` for every IP and removes IPs
    /// whose history becomes empty, bounding the map's memory use.
    fn cleanup_old_entries(
        &self,
        connections: &mut HashMap<IpAddr, Vec<Instant>>,
        now: Instant,
        window: Duration,
    ) {
        for times in connections.values_mut() {
            times.retain(|time| now.duration_since(*time) < window);
        }
        connections.retain(|_, times| !times.is_empty());
    }

    /// Best-effort count of `ip`'s tracked connection attempts. Returns 0 if
    /// the lock is currently contended — callers use this only as a heuristic
    /// for escalating to a block, so an occasional stale 0 is acceptable.
    pub fn get_connection_count(&self, ip: &IpAddr) -> usize {
        self.connections
            .try_lock()
            .ok()
            .and_then(|connections| connections.get(ip).map(|times| times.len()))
            .unwrap_or(0)
    }
}