diff --git a/src/config.rs b/src/config.rs index f4c40b2..74a1571 100644 --- a/src/config.rs +++ b/src/config.rs @@ -92,6 +92,32 @@ pub struct Args { help = "Log format: plain, pretty, json, pretty-json" )] pub log_format: String, + + #[clap(long, help = "Enable rate limiting for connections from the same IP")] + pub rate_limit_enabled: bool, + + #[clap(long, default_value = "60", help = "Rate limit window in seconds")] + pub rate_limit_window: u64, + + #[clap( + long, + default_value = "30", + help = "Maximum number of connections allowed per IP in the rate limit window" + )] + pub rate_limit_max: usize, + + #[clap( + long, + default_value = "100", + help = "Connection attempts threshold before considering for IP blocking" + )] + pub rate_limit_block_threshold: usize, + + #[clap( + long, + help = "Send a 429 response for rate limited connections instead of dropping connection" + )] + pub rate_limit_slow_response: bool, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] @@ -162,6 +188,11 @@ pub struct Config { pub config_dir: String, pub cache_dir: String, pub log_format: LogFormat, + pub rate_limit_enabled: bool, + pub rate_limit_window_seconds: u64, + pub rate_limit_max_connections: usize, + pub rate_limit_block_threshold: usize, + pub rate_limit_slow_response: bool, } impl Default for Config { @@ -224,6 +255,11 @@ impl Default for Config { config_dir: "./conf".to_string(), cache_dir: "./cache".to_string(), log_format: LogFormat::Pretty, + rate_limit_enabled: true, + rate_limit_window_seconds: 60, + rate_limit_max_connections: 30, + rate_limit_block_threshold: 100, + rate_limit_slow_response: true, } } } @@ -297,6 +333,12 @@ impl Config { data_dir, config_dir, cache_dir, + log_format: LogFormat::Pretty, + rate_limit_enabled: args.rate_limit_enabled, + rate_limit_window_seconds: args.rate_limit_window, + rate_limit_max_connections: args.rate_limit_max, + rate_limit_block_threshold: args.rate_limit_block_threshold, + rate_limit_slow_response: 
args.rate_limit_slow_response, ..Default::default() } } @@ -425,6 +467,11 @@ mod tests { config_file: None, log_level: "debug".to_string(), log_format: "pretty".to_string(), + rate_limit_enabled: true, + rate_limit_window: 30, + rate_limit_max: 20, + rate_limit_block_threshold: 50, + rate_limit_slow_response: true, }; let config = Config::from_args(&args); @@ -441,6 +488,11 @@ mod tests { assert_eq!(config.data_dir, "/tmp/eris/data"); assert_eq!(config.config_dir, "/tmp/eris/conf"); assert_eq!(config.cache_dir, "/tmp/eris/cache"); + assert!(config.rate_limit_enabled); + assert_eq!(config.rate_limit_window_seconds, 30); + assert_eq!(config.rate_limit_max_connections, 20); + assert_eq!(config.rate_limit_block_threshold, 50); + assert!(config.rate_limit_slow_response); } #[test] @@ -537,12 +589,22 @@ mod tests { let loaded_json = Config::load_from_file(&json_path).unwrap(); assert_eq!(loaded_json.listen_addr, config.listen_addr); assert_eq!(loaded_json.min_delay, config.min_delay); + assert_eq!(loaded_json.rate_limit_enabled, config.rate_limit_enabled); + assert_eq!( + loaded_json.rate_limit_max_connections, + config.rate_limit_max_connections + ); // Test TOML serialization and deserialization config.save_to_file(&toml_path).unwrap(); let loaded_toml = Config::load_from_file(&toml_path).unwrap(); assert_eq!(loaded_toml.listen_addr, config.listen_addr); assert_eq!(loaded_toml.min_delay, config.min_delay); + assert_eq!(loaded_toml.rate_limit_enabled, config.rate_limit_enabled); + assert_eq!( + loaded_toml.rate_limit_max_connections, + config.rate_limit_max_connections + ); // Clean up let _ = std::fs::remove_file(json_path); diff --git a/src/metrics.rs b/src/metrics.rs index 48bceb9..44c3fb4 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -24,6 +24,11 @@ lazy_static! 
{ register_counter_vec!("eris_path_hits_total", "Hits by path", &["path"]).unwrap(); pub static ref UA_HITS: CounterVec = register_counter_vec!("eris_ua_hits_total", "Hits by user agent", &["user_agent"]).unwrap(); + pub static ref RATE_LIMITED_CONNECTIONS: Counter = register_counter!( + "eris_rate_limited_total", + "Number of connections rejected due to rate limiting" + ) + .unwrap(); } // Prometheus metrics endpoint @@ -76,6 +81,7 @@ mod tests { UA_HITS.with_label_values(&["TestBot/1.0"]).inc(); BLOCKED_IPS.set(5.0); ACTIVE_CONNECTIONS.set(3.0); + RATE_LIMITED_CONNECTIONS.inc(); // Create test app let app = @@ -96,6 +102,7 @@ mod tests { assert!(body_str.contains("eris_ua_hits_total")); assert!(body_str.contains("eris_blocked_ips")); assert!(body_str.contains("eris_active_connections")); + assert!(body_str.contains("eris_rate_limited_total")); } #[actix_web::test] diff --git a/src/network.rs b/src/network/mod.rs similarity index 85% rename from src/network.rs rename to src/network/mod.rs index 56d6c65..0fd1bfa 100644 --- a/src/network.rs +++ b/src/network/mod.rs @@ -12,13 +12,28 @@ use tokio::time::sleep; use crate::config::Config; use crate::lua::{EventContext, EventType, ScriptManager}; use crate::markov::MarkovGenerator; -use crate::metrics::{ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, UA_HITS}; +use crate::metrics::{ + ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, RATE_LIMITED_CONNECTIONS, UA_HITS, +}; use crate::state::BotState; use crate::utils::{ choose_response_type, extract_all_headers, extract_header_value, extract_path_from_request, find_header_end, generate_session_id, get_timestamp, }; +mod rate_limiter; +use rate_limiter::RateLimiter; + +// Global rate limiter instance. +// Default is 30 connections per IP in a 60 second window +// XXX: This might add overhead of the proxy, e.g. NGINX already implements +// rate limiting. 
Though I don't think we have a way of knowing if the middleman +// we are handing the connections to (from the same middleman in some cases) has +// rate limiting. +lazy_static::lazy_static! { + static ref RATE_LIMITER: RateLimiter = RateLimiter::new(60, 30); +} + // Main connection handler. // Decides whether to tarpit or proxy pub async fn handle_connection( @@ -29,7 +44,7 @@ pub async fn handle_connection( script_manager: Arc, ) { // Get peer information - let peer_addr = match stream.peer_addr() { + let peer_addr: IpAddr = match stream.peer_addr() { Ok(addr) => addr.ip(), Err(e) => { log::debug!("Failed to get peer address: {e}"); @@ -44,6 +59,52 @@ pub async fn handle_connection( return; } + // Apply rate limiting before any further processing + if config.rate_limit_enabled && !RATE_LIMITER.check_rate_limit(peer_addr).await { + log::info!("Rate limited connection from {peer_addr}"); + RATE_LIMITED_CONNECTIONS.inc(); + + // Optionally, add the IP to a temporary block list + // if it's constantly hitting the rate limit + let connection_count = RATE_LIMITER.get_connection_count(&peer_addr); + if connection_count > config.rate_limit_block_threshold { + log::warn!( + "IP {peer_addr} exceeding rate limit with {connection_count} connection attempts, considering for blocking" + ); + + // Trigger a blocked event for Lua scripts + let rate_limit_ctx = EventContext { + event_type: EventType::BlockIP, + ip: Some(peer_addr.to_string()), + path: None, + user_agent: None, + request_headers: None, + content: None, + timestamp: get_timestamp(), + session_id: None, + }; + script_manager.trigger_event(&rate_limit_ctx); + } + + // Either send a slow response or just close connection + if config.rate_limit_slow_response { + // Send a simple 429 Too Many Requests response. If the bots actually respected + // HTTP error codes, the internet would be a mildly better place. 
+ let response = "HTTP/1.1 429 Too Many Requests\r\n\ + Content-Type: text/plain\r\n\ + Retry-After: 60\r\n\ + Connection: close\r\n\ + \r\n\ + Rate limit exceeded. Please try again later."; + + let _ = stream.write_all(response.as_bytes()).await; + let _ = stream.flush().await; + } + + let _ = stream.shutdown().await; + return; + } + // Check if Lua scripts allow this connection if !script_manager.on_connection(&peer_addr.to_string()) { log::debug!("Connection rejected by Lua script: {peer_addr}"); diff --git a/src/network/rate_limiter.rs b/src/network/rate_limiter.rs new file mode 100644 index 0000000..6150793 --- /dev/null +++ b/src/network/rate_limiter.rs @@ -0,0 +1,85 @@ +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Mutex; + +pub struct RateLimiter { + connections: Arc>>>, + window_seconds: u64, + max_connections: usize, + cleanup_interval: Duration, + last_cleanup: Instant, +} + +impl RateLimiter { + pub fn new(window_seconds: u64, max_connections: usize) -> Self { + Self { + connections: Arc::new(Mutex::new(HashMap::new())), + window_seconds, + max_connections, + cleanup_interval: Duration::from_secs(60), + last_cleanup: Instant::now(), + } + } + + pub async fn check_rate_limit(&self, ip: IpAddr) -> bool { + let now = Instant::now(); + let window = Duration::from_secs(self.window_seconds); + + let mut connections = self.connections.lock().await; + + // Periodically clean up old entries across all IPs + if now.duration_since(self.last_cleanup) > self.cleanup_interval { + self.cleanup_old_entries(&mut connections, now, window); + } + + // Clean up old entries for this specific IP + if let Some(times) = connections.get_mut(&ip) { + times.retain(|time| now.duration_since(*time) < window); + + // Check if rate limit exceeded + if times.len() >= self.max_connections { + log::debug!("Rate limit exceeded for IP: {}", ip); + return false; + } + + // Add new connection time + 
times.push(now); + } else { + connections.insert(ip, vec![now]); + } + + true + } + + fn cleanup_old_entries( + &self, + connections: &mut HashMap>, + now: Instant, + window: Duration, + ) { + let mut empty_keys = Vec::new(); + + for (ip, times) in connections.iter_mut() { + times.retain(|time| now.duration_since(*time) < window); + if times.is_empty() { + empty_keys.push(*ip); + } + } + + // Remove empty entries + for ip in empty_keys { + connections.remove(&ip); + } + } + + pub fn get_connection_count(&self, ip: &IpAddr) -> usize { + if let Ok(connections) = self.connections.try_lock() { + if let Some(times) = connections.get(ip) { + return times.len(); + } + } + 0 + } +}