eris: speed up responses; refactor connection handler

raf 2025-05-01 22:03:01 +03:00
commit a2fc2bf2bc
Signed by: NotAShelf
GPG key ID: 29D95B64378DB4BF


@@ -476,7 +476,48 @@ impl ScriptManager {
    }
}

// Find end of HTTP headers
// XXX: I'm sure this could be made less fragile.
fn find_header_end(data: &[u8]) -> Option<usize> {
    for i in 0..data.len().saturating_sub(3) {
        if data[i] == b'\r' && data[i + 1] == b'\n' && data[i + 2] == b'\r' && data[i + 3] == b'\n'
        {
            return Some(i + 4);
        }
    }
    None
}
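
// Sketch, not part of this commit: one way to make the scan above less fragile
// is to search with slice windows instead of manual index arithmetic. Same
// behavior (returns the offset just past the CRLFCRLF), no bounds bookkeeping.
fn find_header_end_windows(data: &[u8]) -> Option<usize> {
    data.windows(4)
        .position(|w| w == b"\r\n\r\n")
        .map(|pos| pos + 4)
}
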
// Extract path from raw request data
fn extract_path_from_request(data: &[u8]) -> Option<&str> {
    let request_line = data
        .split(|&b| b == b'\r' || b == b'\n')
        .next()
        .filter(|line| !line.is_empty())?;

    let mut parts = request_line.split(|&b| b == b' ');
    let _ = parts.next()?; // Skip HTTP method
    let path = parts.next()?;

    // Only accept origin-form targets; anything else is treated as unparseable
    if !path.starts_with(b"/") {
        return None;
    }

    std::str::from_utf8(path).ok()
}
// Extract header value from raw request data
fn extract_header_value(data: &[u8], header_name: &str) -> Option<String> {
    let data_str = std::str::from_utf8(data).ok()?;
    let header_prefix = format!("{header_name}: ").to_lowercase();

    for line in data_str.lines() {
        let line_lower = line.to_lowercase();
        if line_lower.starts_with(&header_prefix) {
            return Some(line[header_prefix.len()..].trim().to_string());
        }
    }
    None
}
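
// Sketch, not part of this commit: the same lookup without lowercasing every
// line. Splitting on ':' and comparing the header name case-insensitively also
// tolerates "Name:value" with no space after the colon.
fn extract_header_value_nocase(data: &[u8], header_name: &str) -> Option<String> {
    let data_str = std::str::from_utf8(data).ok()?;
    for line in data_str.lines() {
        if let Some((name, value)) = line.split_once(':') {
            if name.trim().eq_ignore_ascii_case(header_name) {
                return Some(value.trim().to_string());
            }
        }
    }
    None
}
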
// Main connection handler.
// Decides whether to tarpit or proxy
async fn handle_connection(
    mut stream: TcpStream,
    config: Arc<Config>,
@@ -484,6 +525,7 @@ async fn handle_connection(
    markov_generator: Arc<MarkovGenerator>,
    script_manager: Arc<ScriptManager>,
) {
    // Get peer information
    let peer_addr = match stream.peer_addr() {
        Ok(addr) => addr.ip(),
        Err(e) => {
@@ -492,29 +534,48 @@ async fn handle_connection(
        }
    };
log::debug!("New connection from: {peer_addr}"); // Check for blocked IPs to avoid any processing
// Check if IP is already blocked
if state.read().await.blocked.contains(&peer_addr) { if state.read().await.blocked.contains(&peer_addr) {
log::debug!("Rejected connection from blocked IP: {peer_addr}"); log::debug!("Rejected connection from blocked IP: {peer_addr}");
let _ = stream.shutdown().await; let _ = stream.shutdown().await;
return; return;
} }
    // Pre-check for whitelisted IPs to bypass heavy processing
    let mut whitelisted = false;
    for network_str in &config.whitelist_networks {
        if let Ok(network) = network_str.parse::<IpNetwork>() {
            if network.contains(peer_addr) {
                whitelisted = true;
                break;
            }
        }
    }

    // Read buffer
    let mut buffer = vec![0; 8192];
    let mut request_data = Vec::with_capacity(8192);
    let mut header_end_pos = 0;

    // Read with timeout to prevent hanging resource load ops.
    let read_fut = async {
        loop {
            match stream.read(&mut buffer).await {
                Ok(0) => break,
                Ok(n) => {
                    let new_data = &buffer[..n];
                    request_data.extend_from_slice(new_data);

                    // Look for end of headers
                    if header_end_pos == 0 {
                        if let Some(pos) = find_header_end(&request_data) {
                            header_end_pos = pos;
                            break;
                        }
                    }

                    // Avoid excessive buffering
                    if request_data.len() > 32768 {
                        break;
                    }
                }
@@ -526,7 +587,7 @@ async fn handle_connection(
        }
    };

    let timeout_fut = sleep(Duration::from_secs(3));

    tokio::select! {
        () = read_fut => {},
@@ -537,54 +598,32 @@ async fn handle_connection(
        }
    }
    // Fast path for whitelisted IPs. Skip full parsing and speed up "approved"
    // connections automatically.
    if whitelisted {
        log::debug!("Whitelisted IP {peer_addr} - using fast proxy path");
        proxy_fast_path(stream, request_data, &config.backend_addr).await;
        return;
    }

    // Parse minimally to extract the path
    let path = if let Some(p) = extract_path_from_request(&request_data) {
        p
    } else {
        log::debug!("Invalid request from {peer_addr}");
        let _ = stream.shutdown().await;
        return;
    };

    // Check if this request matches our tarpit patterns
    let should_tarpit = should_tarpit(path, &peer_addr, &config).await;

    if should_tarpit {
        // Extract minimal info needed for tarpit
        let user_agent = extract_header_value(&request_data, "user-agent")
            .unwrap_or_else(|| "unknown".to_string());

        log::info!("Tarpit triggered: {path} from {peer_addr} (UA: {user_agent})");
        // Update metrics
        HITS_COUNTER.inc();
@@ -599,7 +638,6 @@ async fn handle_connection(
        *state.hits.entry(peer_addr).or_insert(0) += 1;
        let hit_count = state.hits[&peer_addr];

        // Block IPs that hit tarpits too many times
        if hit_count >= config.block_threshold && !state.blocked.contains(&peer_addr) {
@@ -608,7 +646,7 @@ async fn handle_connection(
            BLOCKED_IPS.set(state.blocked.len() as f64);
            state.save_to_disk();

            // Do firewall blocking in background
            let peer_addr_str = peer_addr.to_string();
            tokio::spawn(async move {
                log::debug!("Adding IP {peer_addr_str} to firewall blacklist");
@@ -660,42 +698,117 @@ async fn handle_connection(
        )
        .await;
    } else {
        log::debug!("Proxying request: {path} from {peer_addr}");
        proxy_fast_path(stream, request_data, &config.backend_addr).await;
    }
}

// Forward a legitimate request to the real backend server
async fn proxy_fast_path(mut client_stream: TcpStream, request_data: Vec<u8>, backend_addr: &str) {
    // Connect to backend server
    let server_stream = match TcpStream::connect(backend_addr).await {
        Ok(stream) => stream,
        Err(e) => {
            log::warn!("Failed to connect to backend {backend_addr}: {e}");
            let _ = client_stream.shutdown().await;
            return;
        }
    };

    // Set TCP_NODELAY for both streams before splitting them
    if let Err(e) = client_stream.set_nodelay(true) {
        log::debug!("Failed to set TCP_NODELAY on client stream: {e}");
    }

    let mut server_stream = server_stream;
    if let Err(e) = server_stream.set_nodelay(true) {
        log::debug!("Failed to set TCP_NODELAY on server stream: {e}");
    }

    // Forward the original request bytes directly without parsing
    if server_stream.write_all(&request_data).await.is_err() {
        log::debug!("Failed to write request to backend server");
        let _ = client_stream.shutdown().await;
        return;
    }

    // Now split the streams for concurrent reading/writing
    let (mut client_read, mut client_write) = client_stream.split();
    let (mut server_read, mut server_write) = server_stream.split();

    // 32KB buffer
    let buf_size = 32768;

    // Client -> Server
    let client_to_server = async {
        let mut buf = vec![0; buf_size];
        let mut bytes_forwarded = 0;
        loop {
            match client_read.read(&mut buf).await {
                Ok(0) => break,
                Ok(n) => {
                    bytes_forwarded += n;
                    if server_write.write_all(&buf[..n]).await.is_err() {
                        break;
                    }
                }
                Err(_) => break,
            }
        }

        // Ensure everything is sent
        let _ = server_write.flush().await;
        log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes");
    };

    // Server -> Client
    let server_to_client = async {
        let mut buf = vec![0; buf_size];
        let mut bytes_forwarded = 0;
        loop {
            match server_read.read(&mut buf).await {
                Ok(0) => break,
                Ok(n) => {
                    bytes_forwarded += n;
                    if client_write.write_all(&buf[..n]).await.is_err() {
                        break;
                    }
                }
                Err(_) => break,
            }
        }

        // Ensure everything is sent
        let _ = client_write.flush().await;
        log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes");
    };

    // Run both directions concurrently
    tokio::join!(client_to_server, server_to_client);
    log::debug!("Fast proxy connection completed");
}
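
// Sketch, not part of this commit: the two forwarding loops above can also be
// expressed with tokio::io::copy_bidirectional, which pumps bytes both ways
// until either side closes and reports the byte counts. It relies on the same
// TcpStream/AsyncWriteExt imports used in this file; `proxy_fast_path_alt` is
// a hypothetical name for illustration only.
async fn proxy_fast_path_alt(mut client: TcpStream, request_data: Vec<u8>, backend_addr: &str) {
    let mut server = match TcpStream::connect(backend_addr).await {
        Ok(stream) => stream,
        Err(e) => {
            log::warn!("Failed to connect to backend {backend_addr}: {e}");
            let _ = client.shutdown().await;
            return;
        }
    };

    // Replay the bytes already read from the client, then splice the rest.
    if server.write_all(&request_data).await.is_err() {
        let _ = client.shutdown().await;
        return;
    }

    match tokio::io::copy_bidirectional(&mut client, &mut server).await {
        Ok((to_server, to_client)) => {
            log::debug!("Proxied {to_server} bytes to backend, {to_client} bytes to client");
        }
        Err(e) => log::debug!("Fast proxy ended with error: {e}"),
    }
}
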
// Decide if a request should be tarpitted based on path and IP
async fn should_tarpit(path: &str, ip: &IpAddr, config: &Config) -> bool {
    // Check whitelist IPs first to avoid unnecessary pattern matching
    for network_str in &config.whitelist_networks {
        if let Ok(network) = network_str.parse::<IpNetwork>() {
            if network.contains(*ip) {
                return false;
            }
        }
    }

    // Use a more efficient pattern matching approach
    let path_lower = path.to_lowercase();
    for pattern in &config.trap_patterns {
        if path_lower.contains(pattern) {
            return true;
        }
    }

    // No trap patterns matched
    false
}
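
// Sketch, not part of this commit: `path_lower.contains(pattern)` only matches
// when the configured pattern is already lowercase. Pre-lowercasing the
// patterns once (the `patterns_lower` slice here is a hypothetical
// pre-computed list) keeps a mixed-case config entry from silently never
// matching and avoids repeating that work on every request.
fn path_matches_any(path: &str, patterns_lower: &[String]) -> bool {
    let path_lower = path.to_lowercase();
    patterns_lower.iter().any(|p| path_lower.contains(p.as_str()))
}
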
@@ -733,12 +846,12 @@ async fn generate_deceptive_response(
    let markov_text = markov.generate(response_type, 30);

    // Use Lua to enhance with honeytokens and other deceptive content
    let response_expanded =
        script_manager.expand_response(&markov_text, response_type, path, &tracking_token);

    // Return full HTTP response with appropriate headers
    format!(
        "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nX-Powered-By: PHP/7.4.3\r\nConnection: keep-alive\r\n\r\n{response_expanded}"
    )
}
@@ -831,96 +944,6 @@ async fn tarpit_connection(
    let _ = stream.shutdown().await;
}
// Forward a legitimate request to the real backend server
async fn proxy_to_backend(
    mut client_stream: TcpStream,
    method: &str,
    path: &str,
    protocol: &str,
    headers: &HashMap<&str, String>,
    backend_addr: &str,
) {
    // Connect to backend server
    let server_stream = match TcpStream::connect(backend_addr).await {
        Ok(stream) => stream,
        Err(e) => {
            log::warn!("Failed to connect to backend {backend_addr}: {e}");
            let _ = client_stream.shutdown().await;
            return;
        }
    };

    log::debug!("Connected to backend server at {backend_addr}");

    // Forward the original request
    let mut request = format!("{method} {path} {protocol}\r\n");
    for (key, value) in headers {
        request.push_str(&format!("{key}: {value}\r\n"));
    }
    request.push_str("\r\n");

    let mut server_stream = server_stream;
    if server_stream.write_all(request.as_bytes()).await.is_err() {
        log::debug!("Failed to write request to backend server");
        let _ = client_stream.shutdown().await;
        return;
    }

    // Set up bidirectional forwarding between client and backend
    let (mut client_read, mut client_write) = client_stream.split();
    let (mut server_read, mut server_write) = server_stream.split();

    // Client -> Server
    let client_to_server = async {
        let mut buf = [0; 8192];
        let mut bytes_forwarded = 0;
        loop {
            match client_read.read(&mut buf).await {
                Ok(0) => break,
                Ok(n) => {
                    bytes_forwarded += n;
                    if server_write.write_all(&buf[..n]).await.is_err() {
                        break;
                    }
                }
                Err(_) => break,
            }
        }
        log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes");
    };

    // Server -> Client
    let server_to_client = async {
        let mut buf = [0; 8192];
        let mut bytes_forwarded = 0;
        loop {
            match server_read.read(&mut buf).await {
                Ok(0) => break,
                Ok(n) => {
                    bytes_forwarded += n;
                    if client_write.write_all(&buf[..n]).await.is_err() {
                        break;
                    }
                }
                Err(_) => break,
            }
        }
        log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes");
    };

    // Run both directions concurrently
    tokio::select! {
        () = client_to_server => {},
        () = server_to_client => {},
    }

    log::debug!("Proxy connection completed");
}
// Set up nftables firewall rules for IP blocking
async fn setup_firewall() -> Result<(), String> {
    log::info!("Setting up firewall rules");
@@ -1439,4 +1462,36 @@ mod tests {
        // Verify tracking token is included
        assert!(resp1.contains("BOT_TestBot"));
    }

    #[test]
    fn test_find_header_end() {
        let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: test\r\n\r\nBody content";
        assert_eq!(find_header_end(data), Some(55));

        let incomplete = b"GET / HTTP/1.1\r\nHost: example.com\r\n";
        assert_eq!(find_header_end(incomplete), None);
    }

    #[test]
    fn test_extract_path_from_request() {
        let data = b"GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n";
        assert_eq!(extract_path_from_request(data), Some("/index.html"));

        let bad_data = b"INVALID DATA";
        assert_eq!(extract_path_from_request(bad_data), None);
    }

    #[test]
    fn test_extract_header_value() {
        let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\n\r\n";
        assert_eq!(
            extract_header_value(data, "user-agent"),
            Some("TestBot/1.0".to_string())
        );
        assert_eq!(
            extract_header_value(data, "Host"),
            Some("example.com".to_string())
        );
        assert_eq!(extract_header_value(data, "nonexistent"), None);
    }
}