eris: speed up responses; refactor connection handlere
This commit is contained in:
parent
6c18427dc3
commit
a2fc2bf2bc
1 changed files with 219 additions and 164 deletions
383
src/main.rs
383
src/main.rs
|
@ -476,7 +476,48 @@ impl ScriptManager {
|
|||
}
|
||||
}
|
||||
|
||||
// Main connection handler - decides whether to tarpit or proxy
|
||||
// Find end of HTTP headers
|
||||
// XXX: I'm sure this could be made less fragile.
|
||||
fn find_header_end(data: &[u8]) -> Option<usize> {
|
||||
for i in 0..data.len().saturating_sub(3) {
|
||||
if data[i] == b'\r' && data[i + 1] == b'\n' && data[i + 2] == b'\r' && data[i + 3] == b'\n'
|
||||
{
|
||||
return Some(i + 4);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// Extract path from raw request data
|
||||
fn extract_path_from_request(data: &[u8]) -> Option<&str> {
|
||||
let request_line = data
|
||||
.split(|&b| b == b'\r' || b == b'\n')
|
||||
.next()
|
||||
.filter(|line| !line.is_empty())?;
|
||||
|
||||
let mut parts = request_line.split(|&b| b == b' ');
|
||||
let _ = parts.next()?; // Skip HTTP method
|
||||
let path = parts.next()?;
|
||||
|
||||
std::str::from_utf8(path).ok()
|
||||
}
|
||||
|
||||
// Extract header value from raw request data
|
||||
fn extract_header_value(data: &[u8], header_name: &str) -> Option<String> {
|
||||
let data_str = std::str::from_utf8(data).ok()?;
|
||||
let header_prefix = format!("{header_name}: ").to_lowercase();
|
||||
|
||||
for line in data_str.lines() {
|
||||
let line_lower = line.to_lowercase();
|
||||
if line_lower.starts_with(&header_prefix) {
|
||||
return Some(line[header_prefix.len()..].trim().to_string());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// Main connection handler.
|
||||
// Decides whether to tarpit or proxy
|
||||
async fn handle_connection(
|
||||
mut stream: TcpStream,
|
||||
config: Arc<Config>,
|
||||
|
@ -484,6 +525,7 @@ async fn handle_connection(
|
|||
markov_generator: Arc<MarkovGenerator>,
|
||||
script_manager: Arc<ScriptManager>,
|
||||
) {
|
||||
// Get peer information
|
||||
let peer_addr = match stream.peer_addr() {
|
||||
Ok(addr) => addr.ip(),
|
||||
Err(e) => {
|
||||
|
@ -492,29 +534,48 @@ async fn handle_connection(
|
|||
}
|
||||
};
|
||||
|
||||
log::debug!("New connection from: {peer_addr}");
|
||||
|
||||
// Check if IP is already blocked
|
||||
// Check for blocked IPs to avoid any processing
|
||||
if state.read().await.blocked.contains(&peer_addr) {
|
||||
log::debug!("Rejected connection from blocked IP: {peer_addr}");
|
||||
let _ = stream.shutdown().await;
|
||||
return;
|
||||
}
|
||||
|
||||
// Read the HTTP request
|
||||
let mut buffer = [0; 8192];
|
||||
let mut request_data = Vec::new();
|
||||
// Pre-check for whitelisted IPs to bypass heavy processing
|
||||
let mut whitelisted = false;
|
||||
for network_str in &config.whitelist_networks {
|
||||
if let Ok(network) = network_str.parse::<IpNetwork>() {
|
||||
if network.contains(peer_addr) {
|
||||
whitelisted = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Read with timeout to prevent hanging
|
||||
// Read buffer
|
||||
let mut buffer = vec![0; 8192];
|
||||
let mut request_data = Vec::with_capacity(8192);
|
||||
let mut header_end_pos = 0;
|
||||
|
||||
// Read with timeout to prevent hanging resource load ops.
|
||||
let read_fut = async {
|
||||
loop {
|
||||
match stream.read(&mut buffer).await {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
request_data.extend_from_slice(&buffer[..n]);
|
||||
// Stop reading at empty line, this is the end of HTTP headers
|
||||
if request_data.len() > 2 && &request_data[request_data.len() - 2..] == b"\r\n"
|
||||
{
|
||||
let new_data = &buffer[..n];
|
||||
request_data.extend_from_slice(new_data);
|
||||
|
||||
// Look for end of headers
|
||||
if header_end_pos == 0 {
|
||||
if let Some(pos) = find_header_end(&request_data) {
|
||||
header_end_pos = pos;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Avoid excessive buffering
|
||||
if request_data.len() > 32768 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -526,7 +587,7 @@ async fn handle_connection(
|
|||
}
|
||||
};
|
||||
|
||||
let timeout_fut = sleep(Duration::from_secs(5));
|
||||
let timeout_fut = sleep(Duration::from_secs(3));
|
||||
|
||||
tokio::select! {
|
||||
() = read_fut => {},
|
||||
|
@ -537,54 +598,32 @@ async fn handle_connection(
|
|||
}
|
||||
}
|
||||
|
||||
// Parse the request
|
||||
let request_str = String::from_utf8_lossy(&request_data);
|
||||
let request_lines: Vec<&str> = request_str.lines().collect();
|
||||
|
||||
if request_lines.is_empty() {
|
||||
log::debug!("Empty request from: {peer_addr}");
|
||||
let _ = stream.shutdown().await;
|
||||
// Fast path for whitelisted IPs. Skip full parsing and speed up "approved"
|
||||
// connections automatically.
|
||||
if whitelisted {
|
||||
log::debug!("Whitelisted IP {peer_addr} - using fast proxy path");
|
||||
proxy_fast_path(stream, request_data, &config.backend_addr).await;
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse request line
|
||||
let request_parts: Vec<&str> = request_lines[0].split_whitespace().collect();
|
||||
if request_parts.len() < 3 {
|
||||
log::debug!("Malformed request from {}: {}", peer_addr, request_lines[0]);
|
||||
// Parse minimally to extract the path
|
||||
let path = if let Some(p) = extract_path_from_request(&request_data) {
|
||||
p
|
||||
} else {
|
||||
log::debug!("Invalid request from {peer_addr}");
|
||||
let _ = stream.shutdown().await;
|
||||
return;
|
||||
}
|
||||
|
||||
let method = request_parts[0];
|
||||
let path = request_parts[1];
|
||||
let protocol = request_parts[2];
|
||||
|
||||
log::debug!("Request: {method} {path} {protocol} from {peer_addr}");
|
||||
|
||||
// Parse headers
|
||||
let mut headers = HashMap::new();
|
||||
for line in &request_lines[1..] {
|
||||
if line.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(idx) = line.find(':') {
|
||||
let key = line[..idx].trim();
|
||||
let value = line[idx + 1..].trim();
|
||||
headers.insert(key, value.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let user_agent = headers
|
||||
.get("user-agent")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
};
|
||||
|
||||
// Check if this request matches our tarpit patterns
|
||||
let should_tarpit = should_tarpit(path, &peer_addr, &config).await;
|
||||
|
||||
if should_tarpit {
|
||||
log::info!("Tarpit triggered: {method} {path} from {peer_addr} (UA: {user_agent})");
|
||||
// Extract minimal info needed for tarpit
|
||||
let user_agent = extract_header_value(&request_data, "user-agent")
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
log::info!("Tarpit triggered: {path} from {peer_addr} (UA: {user_agent})");
|
||||
|
||||
// Update metrics
|
||||
HITS_COUNTER.inc();
|
||||
|
@ -599,7 +638,6 @@ async fn handle_connection(
|
|||
|
||||
*state.hits.entry(peer_addr).or_insert(0) += 1;
|
||||
let hit_count = state.hits[&peer_addr];
|
||||
log::debug!("Hit count for {peer_addr}: {hit_count}");
|
||||
|
||||
// Block IPs that hit tarpits too many times
|
||||
if hit_count >= config.block_threshold && !state.blocked.contains(&peer_addr) {
|
||||
|
@ -608,7 +646,7 @@ async fn handle_connection(
|
|||
BLOCKED_IPS.set(state.blocked.len() as f64);
|
||||
state.save_to_disk();
|
||||
|
||||
// Try to add to firewall
|
||||
// Do firewall blocking in background
|
||||
let peer_addr_str = peer_addr.to_string();
|
||||
tokio::spawn(async move {
|
||||
log::debug!("Adding IP {peer_addr_str} to firewall blacklist");
|
||||
|
@ -660,42 +698,117 @@ async fn handle_connection(
|
|||
)
|
||||
.await;
|
||||
} else {
|
||||
log::debug!("Proxying request: {method} {path} from {peer_addr}");
|
||||
|
||||
// Proxy non-matching requests to the actual backend
|
||||
proxy_to_backend(
|
||||
stream,
|
||||
method,
|
||||
path,
|
||||
protocol,
|
||||
&headers,
|
||||
&config.backend_addr,
|
||||
)
|
||||
.await;
|
||||
log::debug!("Proxying request: {path} from {peer_addr}");
|
||||
proxy_fast_path(stream, request_data, &config.backend_addr).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Determine if a request should be tarpitted based on path and IP
|
||||
// Forward a legitimate request to the real backend server
|
||||
async fn proxy_fast_path(mut client_stream: TcpStream, request_data: Vec<u8>, backend_addr: &str) {
|
||||
// Connect to backend server
|
||||
let server_stream = match TcpStream::connect(backend_addr).await {
|
||||
Ok(stream) => stream,
|
||||
Err(e) => {
|
||||
log::warn!("Failed to connect to backend {backend_addr}: {e}");
|
||||
let _ = client_stream.shutdown().await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Set TCP_NODELAY for both streams before splitting them
|
||||
if let Err(e) = client_stream.set_nodelay(true) {
|
||||
log::debug!("Failed to set TCP_NODELAY on client stream: {e}");
|
||||
}
|
||||
|
||||
let mut server_stream = server_stream;
|
||||
if let Err(e) = server_stream.set_nodelay(true) {
|
||||
log::debug!("Failed to set TCP_NODELAY on server stream: {e}");
|
||||
}
|
||||
|
||||
// Forward the original request bytes directly without parsing
|
||||
if server_stream.write_all(&request_data).await.is_err() {
|
||||
log::debug!("Failed to write request to backend server");
|
||||
let _ = client_stream.shutdown().await;
|
||||
return;
|
||||
}
|
||||
|
||||
// Now split the streams for concurrent reading/writing
|
||||
let (mut client_read, mut client_write) = client_stream.split();
|
||||
let (mut server_read, mut server_write) = server_stream.split();
|
||||
|
||||
// 32KB buffer
|
||||
let buf_size = 32768;
|
||||
|
||||
// Client -> Server
|
||||
let client_to_server = async {
|
||||
let mut buf = vec![0; buf_size];
|
||||
let mut bytes_forwarded = 0;
|
||||
|
||||
loop {
|
||||
match client_read.read(&mut buf).await {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
bytes_forwarded += n;
|
||||
if server_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure everything is sent
|
||||
let _ = server_write.flush().await;
|
||||
log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes");
|
||||
};
|
||||
|
||||
// Server -> Client
|
||||
let server_to_client = async {
|
||||
let mut buf = vec![0; buf_size];
|
||||
let mut bytes_forwarded = 0;
|
||||
|
||||
loop {
|
||||
match server_read.read(&mut buf).await {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
bytes_forwarded += n;
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure everything is sent
|
||||
let _ = client_write.flush().await;
|
||||
log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes");
|
||||
};
|
||||
|
||||
// Run both directions concurrently
|
||||
tokio::join!(client_to_server, server_to_client);
|
||||
log::debug!("Fast proxy connection completed");
|
||||
}
|
||||
|
||||
// Decide if a request should be tarpitted based on path and IP
|
||||
async fn should_tarpit(path: &str, ip: &IpAddr, config: &Config) -> bool {
|
||||
// Don't tarpit whitelisted IPs (internal networks, etc)
|
||||
// Check whitelist IPs first to avoid unnecessary pattern matching
|
||||
for network_str in &config.whitelist_networks {
|
||||
if let Ok(network) = network_str.parse::<IpNetwork>() {
|
||||
if network.contains(*ip) {
|
||||
log::debug!("IP {ip} is in whitelist network {network_str}");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the request path matches any of our trap patterns
|
||||
// Use a more efficient pattern matching approach
|
||||
let path_lower = path.to_lowercase();
|
||||
for pattern in &config.trap_patterns {
|
||||
if path.contains(pattern) {
|
||||
log::debug!("Path '{path}' matches trap pattern '{pattern}'");
|
||||
if path_lower.contains(pattern) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// No trap patterns matched
|
||||
false
|
||||
}
|
||||
|
||||
|
@ -733,12 +846,12 @@ async fn generate_deceptive_response(
|
|||
let markov_text = markov.generate(response_type, 30);
|
||||
|
||||
// Use Lua to enhance with honeytokens and other deceptive content
|
||||
let enhanced =
|
||||
let response_expanded =
|
||||
script_manager.expand_response(&markov_text, response_type, path, &tracking_token);
|
||||
|
||||
// Return full HTTP response with appropriate headers
|
||||
format!(
|
||||
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nX-Powered-By: PHP/7.4.3\r\nConnection: keep-alive\r\n\r\n{enhanced}"
|
||||
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nX-Powered-By: PHP/7.4.3\r\nConnection: keep-alive\r\n\r\n{response_expanded}"
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -831,96 +944,6 @@ async fn tarpit_connection(
|
|||
let _ = stream.shutdown().await;
|
||||
}
|
||||
|
||||
// Forward a legitimate request to the real backend server
|
||||
async fn proxy_to_backend(
|
||||
mut client_stream: TcpStream,
|
||||
method: &str,
|
||||
path: &str,
|
||||
protocol: &str,
|
||||
headers: &HashMap<&str, String>,
|
||||
backend_addr: &str,
|
||||
) {
|
||||
// Connect to backend server
|
||||
let server_stream = match TcpStream::connect(backend_addr).await {
|
||||
Ok(stream) => stream,
|
||||
Err(e) => {
|
||||
log::warn!("Failed to connect to backend {backend_addr}: {e}");
|
||||
let _ = client_stream.shutdown().await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::debug!("Connected to backend server at {backend_addr}");
|
||||
|
||||
// Forward the original request
|
||||
let mut request = format!("{method} {path} {protocol}\r\n");
|
||||
for (key, value) in headers {
|
||||
request.push_str(&format!("{key}: {value}\r\n"));
|
||||
}
|
||||
request.push_str("\r\n");
|
||||
|
||||
let mut server_stream = server_stream;
|
||||
if server_stream.write_all(request.as_bytes()).await.is_err() {
|
||||
log::debug!("Failed to write request to backend server");
|
||||
let _ = client_stream.shutdown().await;
|
||||
return;
|
||||
}
|
||||
|
||||
// Set up bidirectional forwarding between client and backend
|
||||
let (mut client_read, mut client_write) = client_stream.split();
|
||||
let (mut server_read, mut server_write) = server_stream.split();
|
||||
|
||||
// Client -> Server
|
||||
let client_to_server = async {
|
||||
let mut buf = [0; 8192];
|
||||
let mut bytes_forwarded = 0;
|
||||
|
||||
loop {
|
||||
match client_read.read(&mut buf).await {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
bytes_forwarded += n;
|
||||
if server_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
|
||||
log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes");
|
||||
};
|
||||
|
||||
// Server -> Client
|
||||
let server_to_client = async {
|
||||
let mut buf = [0; 8192];
|
||||
let mut bytes_forwarded = 0;
|
||||
|
||||
loop {
|
||||
match server_read.read(&mut buf).await {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
bytes_forwarded += n;
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
|
||||
log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes");
|
||||
};
|
||||
|
||||
// Run both directions concurrently
|
||||
tokio::select! {
|
||||
() = client_to_server => {},
|
||||
() = server_to_client => {},
|
||||
}
|
||||
|
||||
log::debug!("Proxy connection completed");
|
||||
}
|
||||
|
||||
// Set up nftables firewall rules for IP blocking
|
||||
async fn setup_firewall() -> Result<(), String> {
|
||||
log::info!("Setting up firewall rules");
|
||||
|
@ -1439,4 +1462,36 @@ mod tests {
|
|||
// Verify tracking token is included
|
||||
assert!(resp1.contains("BOT_TestBot"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_header_end() {
|
||||
let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: test\r\n\r\nBody content";
|
||||
assert_eq!(find_header_end(data), Some(53));
|
||||
|
||||
let incomplete = b"GET / HTTP/1.1\r\nHost: example.com\r\n";
|
||||
assert_eq!(find_header_end(incomplete), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_path_from_request() {
|
||||
let data = b"GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n";
|
||||
assert_eq!(extract_path_from_request(data), Some("/index.html"));
|
||||
|
||||
let bad_data = b"INVALID DATA";
|
||||
assert_eq!(extract_path_from_request(bad_data), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_header_value() {
|
||||
let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\n\r\n";
|
||||
assert_eq!(
|
||||
extract_header_value(data, "user-agent"),
|
||||
Some("TestBot/1.0".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
extract_header_value(data, "Host"),
|
||||
Some("example.com".to_string())
|
||||
);
|
||||
assert_eq!(extract_header_value(data, "nonexistent"), None);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue