diff --git a/src/main.rs b/src/main.rs index ff96b67..8bc3e98 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,6 @@ -use actix_web::{App, HttpRequest, HttpResponse, HttpServer, web}; +use actix_web::{App, HttpResponse, HttpServer, web}; use clap::Parser; use ipnetwork::IpNetwork; -use lazy_static::lazy_static; -use prometheus::{ - Counter, CounterVec, Gauge, register_counter, register_counter_vec, register_gauge, -}; use rlua::{Function, Lua}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; @@ -22,40 +18,91 @@ use tokio::sync::RwLock; use tokio::time::sleep; mod markov; +mod metrics; + use markov::MarkovGenerator; +use metrics::{ + ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, UA_HITS, metrics_handler, + status_handler, +}; // Command-line arguments using clap #[derive(Parser, Debug, Clone)] -#[clap(author, version, about)] +#[clap( + author, + version, + about = "Markov chain based HTTP tarpit/honeypot that delays and tracks potential attackers" +)] struct Args { - #[clap(long, default_value = "0.0.0.0:8888")] + #[clap( + long, + default_value = "0.0.0.0:8888", + help = "Address and port to listen for incoming HTTP requests (format: ip:port)" + )] listen_addr: String, - #[clap(long, default_value = "9100")] + #[clap( + long, + default_value = "9100", + help = "Port to expose Prometheus metrics and status endpoint" + )] metrics_port: u16, - #[clap(long, default_value = "127.0.0.1:80")] + #[clap(long, help = "Disable Prometheus metrics server completely")] + disable_metrics: bool, + + #[clap( + long, + default_value = "127.0.0.1:80", + help = "Backend server address to proxy legitimate requests to (format: ip:port)" + )] backend_addr: String, - #[clap(long, default_value = "1000")] + #[clap( + long, + default_value = "1000", + help = "Minimum delay in milliseconds between chunks sent to attacker" + )] min_delay: u64, - #[clap(long, default_value = "15000")] + #[clap( + long, + default_value = "15000", + help = "Maximum delay in milliseconds between chunks sent to attacker" + )] max_delay: u64, - #[clap(long, default_value = "600")] + #[clap( + long, + default_value = "600", + help = "Maximum time in seconds to keep an attacker in the tarpit before disconnecting" + )] max_tarpit_time: u64, - #[clap(long, default_value = "3")] + #[clap( + long, + default_value = "3", + help = "Number of hits to honeypot patterns before permanently blocking an IP" + )] block_threshold: u32, - #[clap(long)] + #[clap( + long, + help = "Base directory for all application data (overrides XDG directory structure)" + )] base_dir: Option, - #[clap(long)] + #[clap( + long, + help = "Path to JSON configuration file (overrides command line options)" + )] config_file: Option, - #[clap(long, default_value = "info")] + #[clap( + long, + default_value = "info", + help = "Log level: trace, debug, info, warn, error" + )] log_level: String, } @@ -64,6 +111,7 @@ struct Args { struct Config { listen_addr: String, metrics_port: u16, + disable_metrics: bool, backend_addr: String, min_delay: u64, max_delay: u64, @@ -83,6 +131,7 @@ impl Default for Config { Self { listen_addr: "0.0.0.0:8888".to_string(), metrics_port: 9100, + disable_metrics: false, backend_addr: "127.0.0.1:80".to_string(), min_delay: 1000, max_delay: 15000, @@ -168,6 +217,7 @@ impl Config { Self { listen_addr: args.listen_addr.clone(), metrics_port: args.metrics_port, + disable_metrics: args.disable_metrics, backend_addr: args.backend_addr.clone(), min_delay: args.min_delay, max_delay: args.max_delay, @@ -218,23 +268,6 @@ impl Config { } } -// Prometheus metrics. I'll expand this with more metrics as I need to. -lazy_static! { - static ref HITS_COUNTER: Counter = - register_counter!("eris_hits_total", "Total number of hits to honeypot paths").unwrap(); - static ref BLOCKED_IPS: Gauge = - register_gauge!("eris_blocked_ips", "Number of IPs permanently blocked").unwrap(); - static ref ACTIVE_CONNECTIONS: Gauge = register_gauge!( - "eris_active_connections", - "Number of currently active connections in tarpit" - ) - .unwrap(); - static ref PATH_HITS: CounterVec = - register_counter_vec!("eris_path_hits_total", "Hits by path", &["path"]).unwrap(); - static ref UA_HITS: CounterVec = - register_counter_vec!("eris_ua_hits_total", "Hits by user agent", &["user_agent"]).unwrap(); -} - // State of bots/IPs hitting the honeypot #[derive(Clone, Debug)] struct BotState { @@ -526,9 +559,7 @@ async fn handle_connection( let path = request_parts[1]; let protocol = request_parts[2]; - log::debug!( - "Request: {method} {path} {protocol} from {peer_addr}" - ); + log::debug!("Request: {method} {path} {protocol} from {peer_addr}"); // Parse headers let mut headers = HashMap::new(); @@ -553,9 +584,7 @@ async fn handle_connection( let should_tarpit = should_tarpit(path, &peer_addr, &config).await; if should_tarpit { - log::info!( - "Tarpit triggered: {method} {path} from {peer_addr} (UA: {user_agent})" - ); + log::info!("Tarpit triggered: {method} {path} from {peer_addr} (UA: {user_agent})"); // Update metrics HITS_COUNTER.inc(); @@ -750,9 +779,7 @@ async fn tarpit_connection( // Check if we've exceeded maximum tarpit time let elapsed_secs = start_time.elapsed().as_secs(); if elapsed_secs > max_tarpit_time { - log::info!( - "Tarpit maximum time ({max_tarpit_time} sec) reached for {peer_addr}" - ); + log::info!("Tarpit maximum time ({max_tarpit_time} sec) reached for {peer_addr}"); break; } @@ -894,41 +921,6 @@ async fn proxy_to_backend( log::debug!("Proxy connection completed"); } -// Prometheus metrics endpoint -async fn metrics_handler(_req: HttpRequest) -> HttpResponse { - use prometheus::Encoder; - let encoder = prometheus::TextEncoder::new(); - let mut buffer = Vec::new(); - - match encoder.encode(&prometheus::gather(), &mut buffer) { - Ok(()) => { - log::debug!("Metrics requested, returned {} bytes", buffer.len()); - } - Err(e) => { - log::error!("Error encoding metrics: {e}"); - } - } - - HttpResponse::Ok().content_type("text/plain").body(buffer) -} - -// Status JSON endpoint -async fn status_handler(state: web::Data>>) -> HttpResponse { - let state = state.read().await; - - let info = serde_json::json!({ - "status": "running", - "version": env!("CARGO_PKG_VERSION"), - "blocked_ips": state.blocked.len(), - "active_connections": state.active_connections.len(), - "hit_count": state.hits.len(), - }); - - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string_pretty(&info).unwrap()) -} - // Set up nftables firewall rules for IP blocking async fn setup_firewall() -> Result<(), String> { log::info!("Setting up firewall rules"); @@ -1195,52 +1187,71 @@ async fn main() -> std::io::Result<()> { Ok::<(), String>(()) }); - // Start the metrics server with actix_web - let metrics_addr = format!("0.0.0.0:{}", metrics_config.metrics_port); - log::info!("Starting metrics server on {metrics_addr}"); + // Start the metrics server with actix_web only if metrics are not disabled + let metrics_server = if metrics_config.disable_metrics { + log::info!("Metrics server disabled via configuration"); + None + } else { + let metrics_addr = format!("0.0.0.0:{}", metrics_config.metrics_port); + log::info!("Starting metrics server on {metrics_addr}"); - let metrics_server = HttpServer::new(move || { - App::new() - .app_data(web::Data::new(metrics_state.clone())) - .route("/metrics", web::get().to(metrics_handler)) - .route("/status", web::get().to(|data: web::Data>>| async move { - status_handler(data).await - })) - .route("/", web::get().to(|| async { - HttpResponse::Ok().body("Botpot Server is running. Visit /metrics for metrics or /status for status.") - })) - }) - .bind(&metrics_addr); + let server = HttpServer::new(move || { + App::new() + .app_data(web::Data::new(metrics_state.clone())) + .route("/metrics", web::get().to(metrics_handler)) + .route("/status", web::get().to(status_handler)) + .route("/", web::get().to(|| async { + HttpResponse::Ok().body("Botpot Server is running. Visit /metrics for metrics or /status for status.") + })) + }) + .bind(&metrics_addr); - let metrics_server = match metrics_server { - Ok(server) => server.run(), - Err(e) => { - log::error!("Failed to bind metrics server to {metrics_addr}: {e}"); - return Err(e); + match server { + Ok(server) => Some(server.run()), + Err(e) => { + log::error!("Failed to bind metrics server to {metrics_addr}: {e}"); + None + } } }; - log::info!("Metrics server listening on {metrics_addr}"); - - // Run both servers concurrently - tokio::select! { - result = tarpit_server => match result { + // Run both servers concurrently if metrics server is enabled + if let Some(metrics_server) = metrics_server { + tokio::select! { + result = tarpit_server => match result { + Ok(Ok(())) => Ok(()), + Ok(Err(e)) => { + log::error!("Tarpit server error: {e}"); + Err(std::io::Error::new(std::io::ErrorKind::Other, e)) + }, + Err(e) => { + log::error!("Tarpit server task error: {e}"); + Err(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())) + }, + }, + result = metrics_server => { + if let Err(ref e) = result { + log::error!("Metrics server error: {e}"); + } + result + }, + } + } else { + // Just run the tarpit server if metrics are disabled + match tarpit_server.await { Ok(Ok(())) => Ok(()), Ok(Err(e)) => { log::error!("Tarpit server error: {e}"); Err(std::io::Error::new(std::io::ErrorKind::Other, e)) - }, + } Err(e) => { log::error!("Tarpit server task error: {e}"); - Err(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())) - }, - }, - result = metrics_server => { - if let Err(ref e) = result { - log::error!("Metrics server error: {e}"); + Err(std::io::Error::new( + std::io::ErrorKind::Other, + e.to_string(), + )) } - result - }, + } } } @@ -1255,6 +1266,7 @@ mod tests { let args = Args { listen_addr: "127.0.0.1:8080".to_string(), metrics_port: 9000, + disable_metrics: true, backend_addr: "127.0.0.1:8081".to_string(), min_delay: 500, max_delay: 10000, @@ -1268,6 +1280,7 @@ mod tests { let config = Config::from_args(&args); assert_eq!(config.listen_addr, "127.0.0.1:8080"); assert_eq!(config.metrics_port, 9000); + assert!(config.disable_metrics); assert_eq!(config.backend_addr, "127.0.0.1:8081"); assert_eq!(config.min_delay, 500); assert_eq!(config.max_delay, 10000); diff --git a/src/metrics.rs b/src/metrics.rs new file mode 100644 index 0000000..c733752 --- /dev/null +++ b/src/metrics.rs @@ -0,0 +1,61 @@ +use actix_web::{HttpRequest, HttpResponse, web}; +use lazy_static::lazy_static; +use prometheus::{ + Counter, CounterVec, Encoder, Gauge, register_counter, register_counter_vec, register_gauge, +}; +use serde_json::json; +use std::sync::Arc; +use tokio::sync::RwLock; + +use crate::BotState; + +// Prometheus metrics. I'll expand this with more metrics as I need to. +lazy_static! { + pub static ref HITS_COUNTER: Counter = + register_counter!("eris_hits_total", "Total number of hits to honeypot paths").unwrap(); + pub static ref BLOCKED_IPS: Gauge = + register_gauge!("eris_blocked_ips", "Number of IPs permanently blocked").unwrap(); + pub static ref ACTIVE_CONNECTIONS: Gauge = register_gauge!( + "eris_active_connections", + "Number of currently active connections in tarpit" + ) + .unwrap(); + pub static ref PATH_HITS: CounterVec = + register_counter_vec!("eris_path_hits_total", "Hits by path", &["path"]).unwrap(); + pub static ref UA_HITS: CounterVec = + register_counter_vec!("eris_ua_hits_total", "Hits by user agent", &["user_agent"]).unwrap(); +} + +// Prometheus metrics endpoint +pub async fn metrics_handler(_req: HttpRequest) -> HttpResponse { + let encoder = prometheus::TextEncoder::new(); + let mut buffer = Vec::new(); + + match encoder.encode(&prometheus::gather(), &mut buffer) { + Ok(()) => { + log::debug!("Metrics requested, returned {} bytes", buffer.len()); + } + Err(e) => { + log::error!("Error encoding metrics: {e}"); + } + } + + HttpResponse::Ok().content_type("text/plain").body(buffer) +} + +// Status JSON endpoint +pub async fn status_handler(state: web::Data>>) -> HttpResponse { + let state = state.read().await; + + let info = json!({ + "status": "running", + "version": env!("CARGO_PKG_VERSION"), + "blocked_ips": state.blocked.len(), + "active_connections": state.active_connections.len(), + "hit_count": state.hits.len(), + }); + + HttpResponse::Ok() + .content_type("application/json") + .body(serde_json::to_string_pretty(&info).unwrap()) +}