network: implement basic ratelimiting
Still pretty barebones, all things considered. I want to revisit this once multi-site support is in.
parent 6f6a60f667
commit ff3f80adda
4 changed files with 217 additions and 2 deletions
@@ -92,6 +92,32 @@ pub struct Args {
        help = "Log format: plain, pretty, json, pretty-json"
    )]
    pub log_format: String,

    #[clap(long, help = "Enable rate limiting for connections from the same IP")]
    pub rate_limit_enabled: bool,

    #[clap(long, default_value = "60", help = "Rate limit window in seconds")]
    pub rate_limit_window: u64,

    #[clap(
        long,
        default_value = "30",
        help = "Maximum number of connections allowed per IP in the rate limit window"
    )]
    pub rate_limit_max: usize,

    #[clap(
        long,
        default_value = "100",
        help = "Connection attempts threshold before considering for IP blocking"
    )]
    pub rate_limit_block_threshold: usize,

    #[clap(
        long,
        help = "Send a 429 response for rate limited connections instead of dropping connection"
    )]
    pub rate_limit_slow_response: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
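For reference, a rough sketch of how these flags would be passed on the command line. This is illustrative only: the `eris` binary name is assumed from the metric names and test paths elsewhere in the diff, and the long flag spellings assume clap's default kebab-case derivation of the field names.

    eris --rate-limit-enabled \
         --rate-limit-window 60 \
         --rate-limit-max 30 \
         --rate-limit-block-threshold 100 \
         --rate-limit-slow-response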
@@ -162,6 +188,11 @@ pub struct Config {
    pub config_dir: String,
    pub cache_dir: String,
    pub log_format: LogFormat,
    pub rate_limit_enabled: bool,
    pub rate_limit_window_seconds: u64,
    pub rate_limit_max_connections: usize,
    pub rate_limit_block_threshold: usize,
    pub rate_limit_slow_response: bool,
}

impl Default for Config {

@@ -224,6 +255,11 @@ impl Default for Config {
            config_dir: "./conf".to_string(),
            cache_dir: "./cache".to_string(),
            log_format: LogFormat::Pretty,
            rate_limit_enabled: true,
            rate_limit_window_seconds: 60,
            rate_limit_max_connections: 30,
            rate_limit_block_threshold: 100,
            rate_limit_slow_response: true,
        }
    }
}

@@ -297,6 +333,12 @@ impl Config {
            data_dir,
            config_dir,
            cache_dir,
            log_format: LogFormat::Pretty,
            rate_limit_enabled: args.rate_limit_enabled,
            rate_limit_window_seconds: args.rate_limit_window,
            rate_limit_max_connections: args.rate_limit_max,
            rate_limit_block_threshold: args.rate_limit_block_threshold,
            rate_limit_slow_response: args.rate_limit_slow_response,
            ..Default::default()
        }
    }

@@ -425,6 +467,11 @@ mod tests {
            config_file: None,
            log_level: "debug".to_string(),
            log_format: "pretty".to_string(),
            rate_limit_enabled: true,
            rate_limit_window: 30,
            rate_limit_max: 20,
            rate_limit_block_threshold: 50,
            rate_limit_slow_response: true,
        };

        let config = Config::from_args(&args);

@@ -441,6 +488,11 @@ mod tests {
        assert_eq!(config.data_dir, "/tmp/eris/data");
        assert_eq!(config.config_dir, "/tmp/eris/conf");
        assert_eq!(config.cache_dir, "/tmp/eris/cache");
        assert!(config.rate_limit_enabled);
        assert_eq!(config.rate_limit_window_seconds, 30);
        assert_eq!(config.rate_limit_max_connections, 20);
        assert_eq!(config.rate_limit_block_threshold, 50);
        assert!(config.rate_limit_slow_response);
    }

    #[test]

@@ -537,12 +589,22 @@ mod tests {
        let loaded_json = Config::load_from_file(&json_path).unwrap();
        assert_eq!(loaded_json.listen_addr, config.listen_addr);
        assert_eq!(loaded_json.min_delay, config.min_delay);
        assert_eq!(loaded_json.rate_limit_enabled, config.rate_limit_enabled);
        assert_eq!(
            loaded_json.rate_limit_max_connections,
            config.rate_limit_max_connections
        );

        // Test TOML serialization and deserialization
        config.save_to_file(&toml_path).unwrap();
        let loaded_toml = Config::load_from_file(&toml_path).unwrap();
        assert_eq!(loaded_toml.listen_addr, config.listen_addr);
        assert_eq!(loaded_toml.min_delay, config.min_delay);
        assert_eq!(loaded_toml.rate_limit_enabled, config.rate_limit_enabled);
        assert_eq!(
            loaded_toml.rate_limit_max_connections,
            config.rate_limit_max_connections
        );

        // Clean up
        let _ = std::fs::remove_file(json_path);
@@ -24,6 +24,11 @@ lazy_static! {
        register_counter_vec!("eris_path_hits_total", "Hits by path", &["path"]).unwrap();
    pub static ref UA_HITS: CounterVec =
        register_counter_vec!("eris_ua_hits_total", "Hits by user agent", &["user_agent"]).unwrap();
    pub static ref RATE_LIMITED_CONNECTIONS: Counter = register_counter!(
        "eris_rate_limited_total",
        "Number of connections rejected due to rate limiting"
    )
    .unwrap();
}

// Prometheus metrics endpoint

@@ -76,6 +81,7 @@ mod tests {
        UA_HITS.with_label_values(&["TestBot/1.0"]).inc();
        BLOCKED_IPS.set(5.0);
        ACTIVE_CONNECTIONS.set(3.0);
        RATE_LIMITED_CONNECTIONS.inc();

        // Create test app
        let app =

@@ -96,6 +102,7 @@ mod tests {
        assert!(body_str.contains("eris_ua_hits_total"));
        assert!(body_str.contains("eris_blocked_ips"));
        assert!(body_str.contains("eris_active_connections"));
        assert!(body_str.contains("eris_rate_limited_total"));
    }

    #[actix_web::test]
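Once registered, the new counter shows up in the /metrics scrape output roughly as below (illustrative sample in the standard Prometheus text format, not taken from the diff; the value is made up):

    # HELP eris_rate_limited_total Number of connections rejected due to rate limiting
    # TYPE eris_rate_limited_total counter
    eris_rate_limited_total 3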
@@ -12,13 +12,28 @@ use tokio::time::sleep;
use crate::config::Config;
use crate::lua::{EventContext, EventType, ScriptManager};
use crate::markov::MarkovGenerator;
-use crate::metrics::{ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, UA_HITS};
+use crate::metrics::{
+    ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, RATE_LIMITED_CONNECTIONS, UA_HITS,
+};
use crate::state::BotState;
use crate::utils::{
    choose_response_type, extract_all_headers, extract_header_value, extract_path_from_request,
    find_header_end, generate_session_id, get_timestamp,
};

mod rate_limiter;
use rate_limiter::RateLimiter;

// Global rate limiter instance.
// Default is 30 connections per IP in a 60 second window.
// XXX: This might duplicate what a proxy already does, e.g. NGINX already
// implements rate limiting. Though I don't think we have a way of knowing if
// the middleman we are handing the connections to (the same middleman, in some
// cases) has rate limiting.
lazy_static::lazy_static! {
    static ref RATE_LIMITER: RateLimiter = RateLimiter::new(60, 30);
}

// Main connection handler.
// Decides whether to tarpit or proxy
pub async fn handle_connection(
@@ -29,7 +44,7 @@ pub async fn handle_connection(
    script_manager: Arc<ScriptManager>,
) {
    // Get peer information
-   let peer_addr = match stream.peer_addr() {
+   let peer_addr: IpAddr = match stream.peer_addr() {
        Ok(addr) => addr.ip(),
        Err(e) => {
            log::debug!("Failed to get peer address: {e}");
@@ -44,6 +59,52 @@ pub async fn handle_connection(
        return;
    }

    // Apply rate limiting before any further processing
    if config.rate_limit_enabled && !RATE_LIMITER.check_rate_limit(peer_addr).await {
        log::info!("Rate limited connection from {peer_addr}");
        RATE_LIMITED_CONNECTIONS.inc();

        // Optionally, add the IP to a temporary block list
        // if it's constantly hitting the rate limit
        let connection_count = RATE_LIMITER.get_connection_count(&peer_addr);
        if connection_count > config.rate_limit_block_threshold {
            log::warn!(
                "IP {peer_addr} exceeding rate limit with {connection_count} connection attempts, considering for blocking"
            );

            // Trigger a blocked event for Lua scripts
            let rate_limit_ctx = EventContext {
                event_type: EventType::BlockIP,
                ip: Some(peer_addr.to_string()),
                path: None,
                user_agent: None,
                request_headers: None,
                content: None,
                timestamp: get_timestamp(),
                session_id: None,
            };
            script_manager.trigger_event(&rate_limit_ctx);
        }

        // Either send a slow response or just close the connection
        if config.rate_limit_slow_response {
            // Send a simple 429 Too Many Requests response. If the bots actually respected
            // HTTP error codes, the internet would be a mildly better place.
            let response = "HTTP/1.1 429 Too Many Requests\r\n\
                Content-Type: text/plain\r\n\
                Retry-After: 60\r\n\
                Connection: close\r\n\
                \r\n\
                Rate limit exceeded. Please try again later.";

            let _ = stream.write_all(response.as_bytes()).await;
            let _ = stream.flush().await;
        }

        let _ = stream.shutdown().await;
        return;
    }

    // Check if Lua scripts allow this connection
    if !script_manager.on_connection(&peer_addr.to_string()) {
        log::debug!("Connection rejected by Lua script: {peer_addr}");
src/network/rate_limiter.rs (new file, 85 lines)
@@ -0,0 +1,85 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;

pub struct RateLimiter {
    connections: Arc<Mutex<HashMap<IpAddr, Vec<Instant>>>>,
    window_seconds: u64,
    max_connections: usize,
    cleanup_interval: Duration,
    last_cleanup: Instant,
}

impl RateLimiter {
    pub fn new(window_seconds: u64, max_connections: usize) -> Self {
        Self {
            connections: Arc::new(Mutex::new(HashMap::new())),
            window_seconds,
            max_connections,
            cleanup_interval: Duration::from_secs(60),
            last_cleanup: Instant::now(),
        }
    }

    pub async fn check_rate_limit(&self, ip: IpAddr) -> bool {
        let now = Instant::now();
        let window = Duration::from_secs(self.window_seconds);

        let mut connections = self.connections.lock().await;

        // Periodically clean up old entries across all IPs
        if now.duration_since(self.last_cleanup) > self.cleanup_interval {
            self.cleanup_old_entries(&mut connections, now, window);
        }

        // Clean up old entries for this specific IP
        if let Some(times) = connections.get_mut(&ip) {
            times.retain(|time| now.duration_since(*time) < window);

            // Check if rate limit exceeded
            if times.len() >= self.max_connections {
                log::debug!("Rate limit exceeded for IP: {}", ip);
                return false;
            }

            // Add new connection time
            times.push(now);
        } else {
            connections.insert(ip, vec![now]);
        }

        true
    }

    fn cleanup_old_entries(
        &self,
        connections: &mut HashMap<IpAddr, Vec<Instant>>,
        now: Instant,
        window: Duration,
    ) {
        let mut empty_keys = Vec::new();

        for (ip, times) in connections.iter_mut() {
            times.retain(|time| now.duration_since(*time) < window);
            if times.is_empty() {
                empty_keys.push(*ip);
            }
        }

        // Remove empty entries
        for ip in empty_keys {
            connections.remove(&ip);
        }
    }

    pub fn get_connection_count(&self, ip: &IpAddr) -> usize {
        if let Ok(connections) = self.connections.try_lock() {
            if let Some(times) = connections.get(ip) {
                return times.len();
            }
        }
        0
    }
}
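To make the sliding-window behaviour concrete, here is a small illustrative unit test against the API above. It is not part of this commit, just a sketch assuming the crate's existing tokio dependency provides the #[tokio::test] macro:

    #[cfg(test)]
    mod tests {
        use super::*;
        use std::net::IpAddr;

        #[tokio::test]
        async fn rejects_after_max_connections_within_window() {
            // 3 connections allowed per 60-second window.
            let limiter = RateLimiter::new(60, 3);
            let ip: IpAddr = "127.0.0.1".parse().unwrap();

            // The first `max_connections` attempts pass...
            for _ in 0..3 {
                assert!(limiter.check_rate_limit(ip).await);
            }

            // ...the next attempt inside the same window is rejected,
            // and the per-IP count reflects the recorded attempts.
            assert!(!limiter.check_rate_limit(ip).await);
            assert_eq!(limiter.get_connection_count(&ip), 3);
        }
    }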