From a2fc2bf2bcbb4584bf46be21ba514f402e249ad9 Mon Sep 17 00:00:00 2001 From: NotAShelf Date: Thu, 1 May 2025 22:03:01 +0300 Subject: [PATCH] eris: speed up responses; refactor connection handlere --- src/main.rs | 383 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 219 insertions(+), 164 deletions(-) diff --git a/src/main.rs b/src/main.rs index 2471ced..b384a9b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -476,7 +476,48 @@ impl ScriptManager { } } -// Main connection handler - decides whether to tarpit or proxy +// Find end of HTTP headers +// XXX: I'm sure this could be made less fragile. +fn find_header_end(data: &[u8]) -> Option { + for i in 0..data.len().saturating_sub(3) { + if data[i] == b'\r' && data[i + 1] == b'\n' && data[i + 2] == b'\r' && data[i + 3] == b'\n' + { + return Some(i + 4); + } + } + None +} + +// Extract path from raw request data +fn extract_path_from_request(data: &[u8]) -> Option<&str> { + let request_line = data + .split(|&b| b == b'\r' || b == b'\n') + .next() + .filter(|line| !line.is_empty())?; + + let mut parts = request_line.split(|&b| b == b' '); + let _ = parts.next()?; // Skip HTTP method + let path = parts.next()?; + + std::str::from_utf8(path).ok() +} + +// Extract header value from raw request data +fn extract_header_value(data: &[u8], header_name: &str) -> Option { + let data_str = std::str::from_utf8(data).ok()?; + let header_prefix = format!("{header_name}: ").to_lowercase(); + + for line in data_str.lines() { + let line_lower = line.to_lowercase(); + if line_lower.starts_with(&header_prefix) { + return Some(line[header_prefix.len()..].trim().to_string()); + } + } + None +} + +// Main connection handler. +// Decides whether to tarpit or proxy async fn handle_connection( mut stream: TcpStream, config: Arc, @@ -484,6 +525,7 @@ async fn handle_connection( markov_generator: Arc, script_manager: Arc, ) { + // Get peer information let peer_addr = match stream.peer_addr() { Ok(addr) => addr.ip(), Err(e) => { @@ -492,29 +534,48 @@ async fn handle_connection( } }; - log::debug!("New connection from: {peer_addr}"); - - // Check if IP is already blocked + // Check for blocked IPs to avoid any processing if state.read().await.blocked.contains(&peer_addr) { log::debug!("Rejected connection from blocked IP: {peer_addr}"); let _ = stream.shutdown().await; return; } - // Read the HTTP request - let mut buffer = [0; 8192]; - let mut request_data = Vec::new(); + // Pre-check for whitelisted IPs to bypass heavy processing + let mut whitelisted = false; + for network_str in &config.whitelist_networks { + if let Ok(network) = network_str.parse::() { + if network.contains(peer_addr) { + whitelisted = true; + break; + } + } + } - // Read with timeout to prevent hanging + // Read buffer + let mut buffer = vec![0; 8192]; + let mut request_data = Vec::with_capacity(8192); + let mut header_end_pos = 0; + + // Read with timeout to prevent hanging resource load ops. let read_fut = async { loop { match stream.read(&mut buffer).await { Ok(0) => break, Ok(n) => { - request_data.extend_from_slice(&buffer[..n]); - // Stop reading at empty line, this is the end of HTTP headers - if request_data.len() > 2 && &request_data[request_data.len() - 2..] == b"\r\n" - { + let new_data = &buffer[..n]; + request_data.extend_from_slice(new_data); + + // Look for end of headers + if header_end_pos == 0 { + if let Some(pos) = find_header_end(&request_data) { + header_end_pos = pos; + break; + } + } + + // Avoid excessive buffering + if request_data.len() > 32768 { break; } } @@ -526,7 +587,7 @@ async fn handle_connection( } }; - let timeout_fut = sleep(Duration::from_secs(5)); + let timeout_fut = sleep(Duration::from_secs(3)); tokio::select! { () = read_fut => {}, @@ -537,54 +598,32 @@ async fn handle_connection( } } - // Parse the request - let request_str = String::from_utf8_lossy(&request_data); - let request_lines: Vec<&str> = request_str.lines().collect(); - - if request_lines.is_empty() { - log::debug!("Empty request from: {peer_addr}"); - let _ = stream.shutdown().await; + // Fast path for whitelisted IPs. Skip full parsing and speed up "approved" + // connections automatically. + if whitelisted { + log::debug!("Whitelisted IP {peer_addr} - using fast proxy path"); + proxy_fast_path(stream, request_data, &config.backend_addr).await; return; } - // Parse request line - let request_parts: Vec<&str> = request_lines[0].split_whitespace().collect(); - if request_parts.len() < 3 { - log::debug!("Malformed request from {}: {}", peer_addr, request_lines[0]); + // Parse minimally to extract the path + let path = if let Some(p) = extract_path_from_request(&request_data) { + p + } else { + log::debug!("Invalid request from {peer_addr}"); let _ = stream.shutdown().await; return; - } - - let method = request_parts[0]; - let path = request_parts[1]; - let protocol = request_parts[2]; - - log::debug!("Request: {method} {path} {protocol} from {peer_addr}"); - - // Parse headers - let mut headers = HashMap::new(); - for line in &request_lines[1..] { - if line.is_empty() { - break; - } - - if let Some(idx) = line.find(':') { - let key = line[..idx].trim(); - let value = line[idx + 1..].trim(); - headers.insert(key, value.to_string()); - } - } - - let user_agent = headers - .get("user-agent") - .cloned() - .unwrap_or_else(|| "unknown".to_string()); + }; // Check if this request matches our tarpit patterns let should_tarpit = should_tarpit(path, &peer_addr, &config).await; if should_tarpit { - log::info!("Tarpit triggered: {method} {path} from {peer_addr} (UA: {user_agent})"); + // Extract minimal info needed for tarpit + let user_agent = extract_header_value(&request_data, "user-agent") + .unwrap_or_else(|| "unknown".to_string()); + + log::info!("Tarpit triggered: {path} from {peer_addr} (UA: {user_agent})"); // Update metrics HITS_COUNTER.inc(); @@ -599,7 +638,6 @@ async fn handle_connection( *state.hits.entry(peer_addr).or_insert(0) += 1; let hit_count = state.hits[&peer_addr]; - log::debug!("Hit count for {peer_addr}: {hit_count}"); // Block IPs that hit tarpits too many times if hit_count >= config.block_threshold && !state.blocked.contains(&peer_addr) { @@ -608,7 +646,7 @@ async fn handle_connection( BLOCKED_IPS.set(state.blocked.len() as f64); state.save_to_disk(); - // Try to add to firewall + // Do firewall blocking in background let peer_addr_str = peer_addr.to_string(); tokio::spawn(async move { log::debug!("Adding IP {peer_addr_str} to firewall blacklist"); @@ -660,42 +698,117 @@ async fn handle_connection( ) .await; } else { - log::debug!("Proxying request: {method} {path} from {peer_addr}"); - - // Proxy non-matching requests to the actual backend - proxy_to_backend( - stream, - method, - path, - protocol, - &headers, - &config.backend_addr, - ) - .await; + log::debug!("Proxying request: {path} from {peer_addr}"); + proxy_fast_path(stream, request_data, &config.backend_addr).await; } } -// Determine if a request should be tarpitted based on path and IP +// Forward a legitimate request to the real backend server +async fn proxy_fast_path(mut client_stream: TcpStream, request_data: Vec, backend_addr: &str) { + // Connect to backend server + let server_stream = match TcpStream::connect(backend_addr).await { + Ok(stream) => stream, + Err(e) => { + log::warn!("Failed to connect to backend {backend_addr}: {e}"); + let _ = client_stream.shutdown().await; + return; + } + }; + + // Set TCP_NODELAY for both streams before splitting them + if let Err(e) = client_stream.set_nodelay(true) { + log::debug!("Failed to set TCP_NODELAY on client stream: {e}"); + } + + let mut server_stream = server_stream; + if let Err(e) = server_stream.set_nodelay(true) { + log::debug!("Failed to set TCP_NODELAY on server stream: {e}"); + } + + // Forward the original request bytes directly without parsing + if server_stream.write_all(&request_data).await.is_err() { + log::debug!("Failed to write request to backend server"); + let _ = client_stream.shutdown().await; + return; + } + + // Now split the streams for concurrent reading/writing + let (mut client_read, mut client_write) = client_stream.split(); + let (mut server_read, mut server_write) = server_stream.split(); + + // 32KB buffer + let buf_size = 32768; + + // Client -> Server + let client_to_server = async { + let mut buf = vec![0; buf_size]; + let mut bytes_forwarded = 0; + + loop { + match client_read.read(&mut buf).await { + Ok(0) => break, + Ok(n) => { + bytes_forwarded += n; + if server_write.write_all(&buf[..n]).await.is_err() { + break; + } + } + Err(_) => break, + } + } + + // Ensure everything is sent + let _ = server_write.flush().await; + log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes"); + }; + + // Server -> Client + let server_to_client = async { + let mut buf = vec![0; buf_size]; + let mut bytes_forwarded = 0; + + loop { + match server_read.read(&mut buf).await { + Ok(0) => break, + Ok(n) => { + bytes_forwarded += n; + if client_write.write_all(&buf[..n]).await.is_err() { + break; + } + } + Err(_) => break, + } + } + + // Ensure everything is sent + let _ = client_write.flush().await; + log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes"); + }; + + // Run both directions concurrently + tokio::join!(client_to_server, server_to_client); + log::debug!("Fast proxy connection completed"); +} + +// Decide if a request should be tarpitted based on path and IP async fn should_tarpit(path: &str, ip: &IpAddr, config: &Config) -> bool { - // Don't tarpit whitelisted IPs (internal networks, etc) + // Check whitelist IPs first to avoid unnecessary pattern matching for network_str in &config.whitelist_networks { if let Ok(network) = network_str.parse::() { if network.contains(*ip) { - log::debug!("IP {ip} is in whitelist network {network_str}"); return false; } } } - // Check if the request path matches any of our trap patterns + // Use a more efficient pattern matching approach + let path_lower = path.to_lowercase(); for pattern in &config.trap_patterns { - if path.contains(pattern) { - log::debug!("Path '{path}' matches trap pattern '{pattern}'"); + if path_lower.contains(pattern) { return true; } } - // No trap patterns matched false } @@ -733,12 +846,12 @@ async fn generate_deceptive_response( let markov_text = markov.generate(response_type, 30); // Use Lua to enhance with honeytokens and other deceptive content - let enhanced = + let response_expanded = script_manager.expand_response(&markov_text, response_type, path, &tracking_token); // Return full HTTP response with appropriate headers format!( - "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nX-Powered-By: PHP/7.4.3\r\nConnection: keep-alive\r\n\r\n{enhanced}" + "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nX-Powered-By: PHP/7.4.3\r\nConnection: keep-alive\r\n\r\n{response_expanded}" ) } @@ -831,96 +944,6 @@ async fn tarpit_connection( let _ = stream.shutdown().await; } -// Forward a legitimate request to the real backend server -async fn proxy_to_backend( - mut client_stream: TcpStream, - method: &str, - path: &str, - protocol: &str, - headers: &HashMap<&str, String>, - backend_addr: &str, -) { - // Connect to backend server - let server_stream = match TcpStream::connect(backend_addr).await { - Ok(stream) => stream, - Err(e) => { - log::warn!("Failed to connect to backend {backend_addr}: {e}"); - let _ = client_stream.shutdown().await; - return; - } - }; - - log::debug!("Connected to backend server at {backend_addr}"); - - // Forward the original request - let mut request = format!("{method} {path} {protocol}\r\n"); - for (key, value) in headers { - request.push_str(&format!("{key}: {value}\r\n")); - } - request.push_str("\r\n"); - - let mut server_stream = server_stream; - if server_stream.write_all(request.as_bytes()).await.is_err() { - log::debug!("Failed to write request to backend server"); - let _ = client_stream.shutdown().await; - return; - } - - // Set up bidirectional forwarding between client and backend - let (mut client_read, mut client_write) = client_stream.split(); - let (mut server_read, mut server_write) = server_stream.split(); - - // Client -> Server - let client_to_server = async { - let mut buf = [0; 8192]; - let mut bytes_forwarded = 0; - - loop { - match client_read.read(&mut buf).await { - Ok(0) => break, - Ok(n) => { - bytes_forwarded += n; - if server_write.write_all(&buf[..n]).await.is_err() { - break; - } - } - Err(_) => break, - } - } - - log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes"); - }; - - // Server -> Client - let server_to_client = async { - let mut buf = [0; 8192]; - let mut bytes_forwarded = 0; - - loop { - match server_read.read(&mut buf).await { - Ok(0) => break, - Ok(n) => { - bytes_forwarded += n; - if client_write.write_all(&buf[..n]).await.is_err() { - break; - } - } - Err(_) => break, - } - } - - log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes"); - }; - - // Run both directions concurrently - tokio::select! { - () = client_to_server => {}, - () = server_to_client => {}, - } - - log::debug!("Proxy connection completed"); -} - // Set up nftables firewall rules for IP blocking async fn setup_firewall() -> Result<(), String> { log::info!("Setting up firewall rules"); @@ -1439,4 +1462,36 @@ mod tests { // Verify tracking token is included assert!(resp1.contains("BOT_TestBot")); } + + #[test] + fn test_find_header_end() { + let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: test\r\n\r\nBody content"; + assert_eq!(find_header_end(data), Some(53)); + + let incomplete = b"GET / HTTP/1.1\r\nHost: example.com\r\n"; + assert_eq!(find_header_end(incomplete), None); + } + + #[test] + fn test_extract_path_from_request() { + let data = b"GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n"; + assert_eq!(extract_path_from_request(data), Some("/index.html")); + + let bad_data = b"INVALID DATA"; + assert_eq!(extract_path_from_request(bad_data), None); + } + + #[test] + fn test_extract_header_value() { + let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\n\r\n"; + assert_eq!( + extract_header_value(data, "user-agent"), + Some("TestBot/1.0".to_string()) + ); + assert_eq!( + extract_header_value(data, "Host"), + Some("example.com".to_string()) + ); + assert_eq!(extract_header_value(data, "nonexistent"), None); + } }