eris: speed up responses; refactor connection handlere

This commit is contained in:
raf 2025-05-01 22:03:01 +03:00
commit a2fc2bf2bc
Signed by: NotAShelf
GPG key ID: 29D95B64378DB4BF

View file

@ -476,7 +476,48 @@ impl ScriptManager {
}
}
// Main connection handler - decides whether to tarpit or proxy
// Find end of HTTP headers
// XXX: I'm sure this could be made less fragile.
fn find_header_end(data: &[u8]) -> Option<usize> {
for i in 0..data.len().saturating_sub(3) {
if data[i] == b'\r' && data[i + 1] == b'\n' && data[i + 2] == b'\r' && data[i + 3] == b'\n'
{
return Some(i + 4);
}
}
None
}
// Extract path from raw request data
fn extract_path_from_request(data: &[u8]) -> Option<&str> {
let request_line = data
.split(|&b| b == b'\r' || b == b'\n')
.next()
.filter(|line| !line.is_empty())?;
let mut parts = request_line.split(|&b| b == b' ');
let _ = parts.next()?; // Skip HTTP method
let path = parts.next()?;
std::str::from_utf8(path).ok()
}
// Extract header value from raw request data
fn extract_header_value(data: &[u8], header_name: &str) -> Option<String> {
let data_str = std::str::from_utf8(data).ok()?;
let header_prefix = format!("{header_name}: ").to_lowercase();
for line in data_str.lines() {
let line_lower = line.to_lowercase();
if line_lower.starts_with(&header_prefix) {
return Some(line[header_prefix.len()..].trim().to_string());
}
}
None
}
// Main connection handler.
// Decides whether to tarpit or proxy
async fn handle_connection(
mut stream: TcpStream,
config: Arc<Config>,
@ -484,6 +525,7 @@ async fn handle_connection(
markov_generator: Arc<MarkovGenerator>,
script_manager: Arc<ScriptManager>,
) {
// Get peer information
let peer_addr = match stream.peer_addr() {
Ok(addr) => addr.ip(),
Err(e) => {
@ -492,29 +534,48 @@ async fn handle_connection(
}
};
log::debug!("New connection from: {peer_addr}");
// Check if IP is already blocked
// Check for blocked IPs to avoid any processing
if state.read().await.blocked.contains(&peer_addr) {
log::debug!("Rejected connection from blocked IP: {peer_addr}");
let _ = stream.shutdown().await;
return;
}
// Read the HTTP request
let mut buffer = [0; 8192];
let mut request_data = Vec::new();
// Pre-check for whitelisted IPs to bypass heavy processing
let mut whitelisted = false;
for network_str in &config.whitelist_networks {
if let Ok(network) = network_str.parse::<IpNetwork>() {
if network.contains(peer_addr) {
whitelisted = true;
break;
}
}
}
// Read with timeout to prevent hanging
// Read buffer
let mut buffer = vec![0; 8192];
let mut request_data = Vec::with_capacity(8192);
let mut header_end_pos = 0;
// Read with timeout to prevent hanging resource load ops.
let read_fut = async {
loop {
match stream.read(&mut buffer).await {
Ok(0) => break,
Ok(n) => {
request_data.extend_from_slice(&buffer[..n]);
// Stop reading at empty line, this is the end of HTTP headers
if request_data.len() > 2 && &request_data[request_data.len() - 2..] == b"\r\n"
{
let new_data = &buffer[..n];
request_data.extend_from_slice(new_data);
// Look for end of headers
if header_end_pos == 0 {
if let Some(pos) = find_header_end(&request_data) {
header_end_pos = pos;
break;
}
}
// Avoid excessive buffering
if request_data.len() > 32768 {
break;
}
}
@ -526,7 +587,7 @@ async fn handle_connection(
}
};
let timeout_fut = sleep(Duration::from_secs(5));
let timeout_fut = sleep(Duration::from_secs(3));
tokio::select! {
() = read_fut => {},
@ -537,54 +598,32 @@ async fn handle_connection(
}
}
// Parse the request
let request_str = String::from_utf8_lossy(&request_data);
let request_lines: Vec<&str> = request_str.lines().collect();
if request_lines.is_empty() {
log::debug!("Empty request from: {peer_addr}");
let _ = stream.shutdown().await;
// Fast path for whitelisted IPs. Skip full parsing and speed up "approved"
// connections automatically.
if whitelisted {
log::debug!("Whitelisted IP {peer_addr} - using fast proxy path");
proxy_fast_path(stream, request_data, &config.backend_addr).await;
return;
}
// Parse request line
let request_parts: Vec<&str> = request_lines[0].split_whitespace().collect();
if request_parts.len() < 3 {
log::debug!("Malformed request from {}: {}", peer_addr, request_lines[0]);
// Parse minimally to extract the path
let path = if let Some(p) = extract_path_from_request(&request_data) {
p
} else {
log::debug!("Invalid request from {peer_addr}");
let _ = stream.shutdown().await;
return;
}
let method = request_parts[0];
let path = request_parts[1];
let protocol = request_parts[2];
log::debug!("Request: {method} {path} {protocol} from {peer_addr}");
// Parse headers
let mut headers = HashMap::new();
for line in &request_lines[1..] {
if line.is_empty() {
break;
}
if let Some(idx) = line.find(':') {
let key = line[..idx].trim();
let value = line[idx + 1..].trim();
headers.insert(key, value.to_string());
}
}
let user_agent = headers
.get("user-agent")
.cloned()
.unwrap_or_else(|| "unknown".to_string());
};
// Check if this request matches our tarpit patterns
let should_tarpit = should_tarpit(path, &peer_addr, &config).await;
if should_tarpit {
log::info!("Tarpit triggered: {method} {path} from {peer_addr} (UA: {user_agent})");
// Extract minimal info needed for tarpit
let user_agent = extract_header_value(&request_data, "user-agent")
.unwrap_or_else(|| "unknown".to_string());
log::info!("Tarpit triggered: {path} from {peer_addr} (UA: {user_agent})");
// Update metrics
HITS_COUNTER.inc();
@ -599,7 +638,6 @@ async fn handle_connection(
*state.hits.entry(peer_addr).or_insert(0) += 1;
let hit_count = state.hits[&peer_addr];
log::debug!("Hit count for {peer_addr}: {hit_count}");
// Block IPs that hit tarpits too many times
if hit_count >= config.block_threshold && !state.blocked.contains(&peer_addr) {
@ -608,7 +646,7 @@ async fn handle_connection(
BLOCKED_IPS.set(state.blocked.len() as f64);
state.save_to_disk();
// Try to add to firewall
// Do firewall blocking in background
let peer_addr_str = peer_addr.to_string();
tokio::spawn(async move {
log::debug!("Adding IP {peer_addr_str} to firewall blacklist");
@ -660,42 +698,117 @@ async fn handle_connection(
)
.await;
} else {
log::debug!("Proxying request: {method} {path} from {peer_addr}");
// Proxy non-matching requests to the actual backend
proxy_to_backend(
stream,
method,
path,
protocol,
&headers,
&config.backend_addr,
)
.await;
log::debug!("Proxying request: {path} from {peer_addr}");
proxy_fast_path(stream, request_data, &config.backend_addr).await;
}
}
// Determine if a request should be tarpitted based on path and IP
// Forward a legitimate request to the real backend server
async fn proxy_fast_path(mut client_stream: TcpStream, request_data: Vec<u8>, backend_addr: &str) {
// Connect to backend server
let server_stream = match TcpStream::connect(backend_addr).await {
Ok(stream) => stream,
Err(e) => {
log::warn!("Failed to connect to backend {backend_addr}: {e}");
let _ = client_stream.shutdown().await;
return;
}
};
// Set TCP_NODELAY for both streams before splitting them
if let Err(e) = client_stream.set_nodelay(true) {
log::debug!("Failed to set TCP_NODELAY on client stream: {e}");
}
let mut server_stream = server_stream;
if let Err(e) = server_stream.set_nodelay(true) {
log::debug!("Failed to set TCP_NODELAY on server stream: {e}");
}
// Forward the original request bytes directly without parsing
if server_stream.write_all(&request_data).await.is_err() {
log::debug!("Failed to write request to backend server");
let _ = client_stream.shutdown().await;
return;
}
// Now split the streams for concurrent reading/writing
let (mut client_read, mut client_write) = client_stream.split();
let (mut server_read, mut server_write) = server_stream.split();
// 32KB buffer
let buf_size = 32768;
// Client -> Server
let client_to_server = async {
let mut buf = vec![0; buf_size];
let mut bytes_forwarded = 0;
loop {
match client_read.read(&mut buf).await {
Ok(0) => break,
Ok(n) => {
bytes_forwarded += n;
if server_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
Err(_) => break,
}
}
// Ensure everything is sent
let _ = server_write.flush().await;
log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes");
};
// Server -> Client
let server_to_client = async {
let mut buf = vec![0; buf_size];
let mut bytes_forwarded = 0;
loop {
match server_read.read(&mut buf).await {
Ok(0) => break,
Ok(n) => {
bytes_forwarded += n;
if client_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
Err(_) => break,
}
}
// Ensure everything is sent
let _ = client_write.flush().await;
log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes");
};
// Run both directions concurrently
tokio::join!(client_to_server, server_to_client);
log::debug!("Fast proxy connection completed");
}
// Decide if a request should be tarpitted based on path and IP
async fn should_tarpit(path: &str, ip: &IpAddr, config: &Config) -> bool {
// Don't tarpit whitelisted IPs (internal networks, etc)
// Check whitelist IPs first to avoid unnecessary pattern matching
for network_str in &config.whitelist_networks {
if let Ok(network) = network_str.parse::<IpNetwork>() {
if network.contains(*ip) {
log::debug!("IP {ip} is in whitelist network {network_str}");
return false;
}
}
}
// Check if the request path matches any of our trap patterns
// Use a more efficient pattern matching approach
let path_lower = path.to_lowercase();
for pattern in &config.trap_patterns {
if path.contains(pattern) {
log::debug!("Path '{path}' matches trap pattern '{pattern}'");
if path_lower.contains(pattern) {
return true;
}
}
// No trap patterns matched
false
}
@ -733,12 +846,12 @@ async fn generate_deceptive_response(
let markov_text = markov.generate(response_type, 30);
// Use Lua to enhance with honeytokens and other deceptive content
let enhanced =
let response_expanded =
script_manager.expand_response(&markov_text, response_type, path, &tracking_token);
// Return full HTTP response with appropriate headers
format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nX-Powered-By: PHP/7.4.3\r\nConnection: keep-alive\r\n\r\n{enhanced}"
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nX-Powered-By: PHP/7.4.3\r\nConnection: keep-alive\r\n\r\n{response_expanded}"
)
}
@ -831,96 +944,6 @@ async fn tarpit_connection(
let _ = stream.shutdown().await;
}
// Forward a legitimate request to the real backend server
async fn proxy_to_backend(
mut client_stream: TcpStream,
method: &str,
path: &str,
protocol: &str,
headers: &HashMap<&str, String>,
backend_addr: &str,
) {
// Connect to backend server
let server_stream = match TcpStream::connect(backend_addr).await {
Ok(stream) => stream,
Err(e) => {
log::warn!("Failed to connect to backend {backend_addr}: {e}");
let _ = client_stream.shutdown().await;
return;
}
};
log::debug!("Connected to backend server at {backend_addr}");
// Forward the original request
let mut request = format!("{method} {path} {protocol}\r\n");
for (key, value) in headers {
request.push_str(&format!("{key}: {value}\r\n"));
}
request.push_str("\r\n");
let mut server_stream = server_stream;
if server_stream.write_all(request.as_bytes()).await.is_err() {
log::debug!("Failed to write request to backend server");
let _ = client_stream.shutdown().await;
return;
}
// Set up bidirectional forwarding between client and backend
let (mut client_read, mut client_write) = client_stream.split();
let (mut server_read, mut server_write) = server_stream.split();
// Client -> Server
let client_to_server = async {
let mut buf = [0; 8192];
let mut bytes_forwarded = 0;
loop {
match client_read.read(&mut buf).await {
Ok(0) => break,
Ok(n) => {
bytes_forwarded += n;
if server_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
Err(_) => break,
}
}
log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes");
};
// Server -> Client
let server_to_client = async {
let mut buf = [0; 8192];
let mut bytes_forwarded = 0;
loop {
match server_read.read(&mut buf).await {
Ok(0) => break,
Ok(n) => {
bytes_forwarded += n;
if client_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
Err(_) => break,
}
}
log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes");
};
// Run both directions concurrently
tokio::select! {
() = client_to_server => {},
() = server_to_client => {},
}
log::debug!("Proxy connection completed");
}
// Set up nftables firewall rules for IP blocking
async fn setup_firewall() -> Result<(), String> {
log::info!("Setting up firewall rules");
@ -1439,4 +1462,36 @@ mod tests {
// Verify tracking token is included
assert!(resp1.contains("BOT_TestBot"));
}
#[test]
fn test_find_header_end() {
let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: test\r\n\r\nBody content";
assert_eq!(find_header_end(data), Some(53));
let incomplete = b"GET / HTTP/1.1\r\nHost: example.com\r\n";
assert_eq!(find_header_end(incomplete), None);
}
#[test]
fn test_extract_path_from_request() {
let data = b"GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n";
assert_eq!(extract_path_from_request(data), Some("/index.html"));
let bad_data = b"INVALID DATA";
assert_eq!(extract_path_from_request(bad_data), None);
}
#[test]
fn test_extract_header_value() {
let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\n\r\n";
assert_eq!(
extract_header_value(data, "user-agent"),
Some("TestBot/1.0".to_string())
);
assert_eq!(
extract_header_value(data, "Host"),
Some("example.com".to_string())
);
assert_eq!(extract_header_value(data, "nonexistent"), None);
}
}