diff --git a/src/network.rs b/src/network.rs index 226ed00..53d7e3f 100644 --- a/src/network.rs +++ b/src/network.rs @@ -80,9 +80,7 @@ pub async fn handle_connection( if header_end_pos == 0 { if let Some(pos) = find_header_end(&request_data) { header_end_pos = pos; - // XXX: Breaking here appears to be malforming the request - // and causing 404 errors. - // So, continue reading the body if present but do not break. + break; } } @@ -341,6 +339,9 @@ pub async fn proxy_fast_path( Ok(stream) => stream, Err(e) => { log::warn!("Failed to connect to backend {backend_addr}: {e}"); + // Send a basic 502 Bad Gateway response instead of just closing the connection + let response = "HTTP/1.1 502 Bad Gateway\r\nContent-Length: 21\r\nConnection: close\r\n\r\nBackend unreachable."; + let _ = client_stream.write_all(response.as_bytes()).await; let _ = client_stream.shutdown().await; return; } @@ -357,68 +358,96 @@ pub async fn proxy_fast_path( } // Forward the original request bytes directly without parsing - if server_stream.write_all(&request_data).await.is_err() { - log::debug!("Failed to write request to backend server"); + if let Err(e) = server_stream.write_all(&request_data).await { + log::debug!("Failed to write request to backend server: {e}"); + let response = "HTTP/1.1 502 Bad Gateway\r\nContent-Length: 31\r\nConnection: close\r\n\r\nFailed to send request to backend."; + let _ = client_stream.write_all(response.as_bytes()).await; let _ = client_stream.shutdown().await; return; } - // Now split the streams for concurrent reading/writing - let (mut client_read, mut client_write) = client_stream.split(); - let (mut server_read, mut server_write) = server_stream.split(); - - // 32KB buffer - let buf_size = 32768; - - // Client -> Server - let client_to_server = async { - let mut buf = vec![0; buf_size]; - let mut bytes_forwarded = 0; + let _ = server_stream.flush().await; + { + // Buffer for initial response headers read + let mut buffer = vec![0; 32768]; + let mut total_bytes = 0; + let mut headers_complete = false; + // First read the response headers and validate them loop { - match client_read.read(&mut buf).await { - Ok(0) => break, + match server_stream.read(&mut buffer).await { + Ok(0) => { + if !headers_complete { + log::warn!("Backend closed connection before sending complete headers"); + let response = "HTTP/1.1 502 Bad Gateway\r\nContent-Length: 38\r\nConnection: close\r\n\r\nBackend closed connection prematurely."; + let _ = client_stream.write_all(response.as_bytes()).await; + } + return; + } Ok(n) => { - bytes_forwarded += n; - if server_write.write_all(&buf[..n]).await.is_err() { + // Write data immediately to client + if let Err(e) = client_stream.write_all(&buffer[..n]).await { + log::debug!("Failed to write response to client: {e}"); + return; + } + + total_bytes += n; + + // Check if we've received the full headers + if !headers_complete { + let slice = &buffer[..n]; + if find_header_end(slice).is_some() + || (total_bytes > 4 && slice.windows(4).any(|w| w == b"\r\n\r\n")) + { + headers_complete = true; + } + } + + // If we've processed the headers, we can continue to the next phase + if headers_complete { break; } + + // Safety timeout; don't wait forever for headers + if total_bytes > 16384 && !headers_complete { + log::warn!("Headers too large or not properly terminated"); + return; + } + } + Err(e) => { + log::debug!("Error reading from backend: {e}"); + if !headers_complete { + let response = "HTTP/1.1 502 Bad Gateway\r\nContent-Length: 29\r\nConnection: close\r\n\r\nError reading from backend."; + let _ = client_stream.write_all(response.as_bytes()).await; + } + return; } - Err(_) => break, } } - // Ensure everything is sent - let _ = server_write.flush().await; - log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes"); - }; - - // Server -> Client - let server_to_client = async { - let mut buf = vec![0; buf_size]; - let mut bytes_forwarded = 0; - - loop { - match server_read.read(&mut buf).await { - Ok(0) => break, - Ok(n) => { - bytes_forwarded += n; - if client_write.write_all(&buf[..n]).await.is_err() { - break; + // Now continue with the rest of the response body + if headers_complete { + loop { + match server_stream.read(&mut buffer).await { + Ok(0) => break, + Ok(n) => { + if client_stream.write_all(&buffer[..n]).await.is_err() { + break; + } + total_bytes += n; } + Err(_) => break, } - Err(_) => break, } + + // Ensure everything is sent + let _ = client_stream.flush().await; + log::debug!("Proxy completed successfully, total bytes: {}", total_bytes); } + } - // Ensure everything is sent - let _ = client_write.flush().await; - log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes"); - }; - - // Run both directions concurrently - tokio::join!(client_to_server, server_to_client); - log::debug!("Fast proxy connection completed"); + // After everything is done, let connections close naturally + // No more malformed headers here, I hope. } // Generate a deceptive HTTP response that appears legitimate