diff --git a/src/main.rs b/src/main.rs index 8bc3e98..ff96b67 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,10 @@ -use actix_web::{App, HttpResponse, HttpServer, web}; +use actix_web::{App, HttpRequest, HttpResponse, HttpServer, web}; use clap::Parser; use ipnetwork::IpNetwork; +use lazy_static::lazy_static; +use prometheus::{ + Counter, CounterVec, Gauge, register_counter, register_counter_vec, register_gauge, +}; use rlua::{Function, Lua}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; @@ -18,91 +22,40 @@ use tokio::sync::RwLock; use tokio::time::sleep; mod markov; -mod metrics; - use markov::MarkovGenerator; -use metrics::{ - ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, UA_HITS, metrics_handler, - status_handler, -}; // Command-line arguments using clap #[derive(Parser, Debug, Clone)] -#[clap( - author, - version, - about = "Markov chain based HTTP tarpit/honeypot that delays and tracks potential attackers" -)] +#[clap(author, version, about)] struct Args { - #[clap( - long, - default_value = "0.0.0.0:8888", - help = "Address and port to listen for incoming HTTP requests (format: ip:port)" - )] + #[clap(long, default_value = "0.0.0.0:8888")] listen_addr: String, - #[clap( - long, - default_value = "9100", - help = "Port to expose Prometheus metrics and status endpoint" - )] + #[clap(long, default_value = "9100")] metrics_port: u16, - #[clap(long, help = "Disable Prometheus metrics server completely")] - disable_metrics: bool, - - #[clap( - long, - default_value = "127.0.0.1:80", - help = "Backend server address to proxy legitimate requests to (format: ip:port)" - )] + #[clap(long, default_value = "127.0.0.1:80")] backend_addr: String, - #[clap( - long, - default_value = "1000", - help = "Minimum delay in milliseconds between chunks sent to attacker" - )] + #[clap(long, default_value = "1000")] min_delay: u64, - #[clap( - long, - default_value = "15000", - help = "Maximum delay in milliseconds between chunks sent to attacker" - )] + #[clap(long, default_value = "15000")] max_delay: u64, - #[clap( - long, - default_value = "600", - help = "Maximum time in seconds to keep an attacker in the tarpit before disconnecting" - )] + #[clap(long, default_value = "600")] max_tarpit_time: u64, - #[clap( - long, - default_value = "3", - help = "Number of hits to honeypot patterns before permanently blocking an IP" - )] + #[clap(long, default_value = "3")] block_threshold: u32, - #[clap( - long, - help = "Base directory for all application data (overrides XDG directory structure)" - )] + #[clap(long)] base_dir: Option, - #[clap( - long, - help = "Path to JSON configuration file (overrides command line options)" - )] + #[clap(long)] config_file: Option, - #[clap( - long, - default_value = "info", - help = "Log level: trace, debug, info, warn, error" - )] + #[clap(long, default_value = "info")] log_level: String, } @@ -111,7 +64,6 @@ struct Args { struct Config { listen_addr: String, metrics_port: u16, - disable_metrics: bool, backend_addr: String, min_delay: u64, max_delay: u64, @@ -131,7 +83,6 @@ impl Default for Config { Self { listen_addr: "0.0.0.0:8888".to_string(), metrics_port: 9100, - disable_metrics: false, backend_addr: "127.0.0.1:80".to_string(), min_delay: 1000, max_delay: 15000, @@ -217,7 +168,6 @@ impl Config { Self { listen_addr: args.listen_addr.clone(), metrics_port: args.metrics_port, - disable_metrics: args.disable_metrics, backend_addr: args.backend_addr.clone(), min_delay: args.min_delay, max_delay: args.max_delay, @@ -268,6 +218,23 @@ impl Config { } } +// Prometheus metrics. I'll expand this with more metrics as I need to. +lazy_static! { + static ref HITS_COUNTER: Counter = + register_counter!("eris_hits_total", "Total number of hits to honeypot paths").unwrap(); + static ref BLOCKED_IPS: Gauge = + register_gauge!("eris_blocked_ips", "Number of IPs permanently blocked").unwrap(); + static ref ACTIVE_CONNECTIONS: Gauge = register_gauge!( + "eris_active_connections", + "Number of currently active connections in tarpit" + ) + .unwrap(); + static ref PATH_HITS: CounterVec = + register_counter_vec!("eris_path_hits_total", "Hits by path", &["path"]).unwrap(); + static ref UA_HITS: CounterVec = + register_counter_vec!("eris_ua_hits_total", "Hits by user agent", &["user_agent"]).unwrap(); +} + // State of bots/IPs hitting the honeypot #[derive(Clone, Debug)] struct BotState { @@ -559,7 +526,9 @@ async fn handle_connection( let path = request_parts[1]; let protocol = request_parts[2]; - log::debug!("Request: {method} {path} {protocol} from {peer_addr}"); + log::debug!( + "Request: {method} {path} {protocol} from {peer_addr}" + ); // Parse headers let mut headers = HashMap::new(); @@ -584,7 +553,9 @@ async fn handle_connection( let should_tarpit = should_tarpit(path, &peer_addr, &config).await; if should_tarpit { - log::info!("Tarpit triggered: {method} {path} from {peer_addr} (UA: {user_agent})"); + log::info!( + "Tarpit triggered: {method} {path} from {peer_addr} (UA: {user_agent})" + ); // Update metrics HITS_COUNTER.inc(); @@ -779,7 +750,9 @@ async fn tarpit_connection( // Check if we've exceeded maximum tarpit time let elapsed_secs = start_time.elapsed().as_secs(); if elapsed_secs > max_tarpit_time { - log::info!("Tarpit maximum time ({max_tarpit_time} sec) reached for {peer_addr}"); + log::info!( + "Tarpit maximum time ({max_tarpit_time} sec) reached for {peer_addr}" + ); break; } @@ -921,6 +894,41 @@ async fn proxy_to_backend( log::debug!("Proxy connection completed"); } +// Prometheus metrics endpoint +async fn metrics_handler(_req: HttpRequest) -> HttpResponse { + use prometheus::Encoder; + let encoder = prometheus::TextEncoder::new(); + let mut buffer = Vec::new(); + + match encoder.encode(&prometheus::gather(), &mut buffer) { + Ok(()) => { + log::debug!("Metrics requested, returned {} bytes", buffer.len()); + } + Err(e) => { + log::error!("Error encoding metrics: {e}"); + } + } + + HttpResponse::Ok().content_type("text/plain").body(buffer) +} + +// Status JSON endpoint +async fn status_handler(state: web::Data>>) -> HttpResponse { + let state = state.read().await; + + let info = serde_json::json!({ + "status": "running", + "version": env!("CARGO_PKG_VERSION"), + "blocked_ips": state.blocked.len(), + "active_connections": state.active_connections.len(), + "hit_count": state.hits.len(), + }); + + HttpResponse::Ok() + .content_type("application/json") + .body(serde_json::to_string_pretty(&info).unwrap()) +} + // Set up nftables firewall rules for IP blocking async fn setup_firewall() -> Result<(), String> { log::info!("Setting up firewall rules"); @@ -1187,71 +1195,52 @@ async fn main() -> std::io::Result<()> { Ok::<(), String>(()) }); - // Start the metrics server with actix_web only if metrics are not disabled - let metrics_server = if metrics_config.disable_metrics { - log::info!("Metrics server disabled via configuration"); - None - } else { - let metrics_addr = format!("0.0.0.0:{}", metrics_config.metrics_port); - log::info!("Starting metrics server on {metrics_addr}"); + // Start the metrics server with actix_web + let metrics_addr = format!("0.0.0.0:{}", metrics_config.metrics_port); + log::info!("Starting metrics server on {metrics_addr}"); - let server = HttpServer::new(move || { - App::new() - .app_data(web::Data::new(metrics_state.clone())) - .route("/metrics", web::get().to(metrics_handler)) - .route("/status", web::get().to(status_handler)) - .route("/", web::get().to(|| async { - HttpResponse::Ok().body("Botpot Server is running. Visit /metrics for metrics or /status for status.") - })) - }) - .bind(&metrics_addr); + let metrics_server = HttpServer::new(move || { + App::new() + .app_data(web::Data::new(metrics_state.clone())) + .route("/metrics", web::get().to(metrics_handler)) + .route("/status", web::get().to(|data: web::Data>>| async move { + status_handler(data).await + })) + .route("/", web::get().to(|| async { + HttpResponse::Ok().body("Botpot Server is running. Visit /metrics for metrics or /status for status.") + })) + }) + .bind(&metrics_addr); - match server { - Ok(server) => Some(server.run()), - Err(e) => { - log::error!("Failed to bind metrics server to {metrics_addr}: {e}"); - None - } + let metrics_server = match metrics_server { + Ok(server) => server.run(), + Err(e) => { + log::error!("Failed to bind metrics server to {metrics_addr}: {e}"); + return Err(e); } }; - // Run both servers concurrently if metrics server is enabled - if let Some(metrics_server) = metrics_server { - tokio::select! { - result = tarpit_server => match result { - Ok(Ok(())) => Ok(()), - Ok(Err(e)) => { - log::error!("Tarpit server error: {e}"); - Err(std::io::Error::new(std::io::ErrorKind::Other, e)) - }, - Err(e) => { - log::error!("Tarpit server task error: {e}"); - Err(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())) - }, - }, - result = metrics_server => { - if let Err(ref e) = result { - log::error!("Metrics server error: {e}"); - } - result - }, - } - } else { - // Just run the tarpit server if metrics are disabled - match tarpit_server.await { + log::info!("Metrics server listening on {metrics_addr}"); + + // Run both servers concurrently + tokio::select! { + result = tarpit_server => match result { Ok(Ok(())) => Ok(()), Ok(Err(e)) => { log::error!("Tarpit server error: {e}"); Err(std::io::Error::new(std::io::ErrorKind::Other, e)) - } + }, Err(e) => { log::error!("Tarpit server task error: {e}"); - Err(std::io::Error::new( - std::io::ErrorKind::Other, - e.to_string(), - )) + Err(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())) + }, + }, + result = metrics_server => { + if let Err(ref e) = result { + log::error!("Metrics server error: {e}"); } - } + result + }, } } @@ -1266,7 +1255,6 @@ mod tests { let args = Args { listen_addr: "127.0.0.1:8080".to_string(), metrics_port: 9000, - disable_metrics: true, backend_addr: "127.0.0.1:8081".to_string(), min_delay: 500, max_delay: 10000, @@ -1280,7 +1268,6 @@ mod tests { let config = Config::from_args(&args); assert_eq!(config.listen_addr, "127.0.0.1:8080"); assert_eq!(config.metrics_port, 9000); - assert!(config.disable_metrics); assert_eq!(config.backend_addr, "127.0.0.1:8081"); assert_eq!(config.min_delay, 500); assert_eq!(config.max_delay, 10000); diff --git a/src/metrics.rs b/src/metrics.rs deleted file mode 100644 index c733752..0000000 --- a/src/metrics.rs +++ /dev/null @@ -1,61 +0,0 @@ -use actix_web::{HttpRequest, HttpResponse, web}; -use lazy_static::lazy_static; -use prometheus::{ - Counter, CounterVec, Encoder, Gauge, register_counter, register_counter_vec, register_gauge, -}; -use serde_json::json; -use std::sync::Arc; -use tokio::sync::RwLock; - -use crate::BotState; - -// Prometheus metrics. I'll expand this with more metrics as I need to. -lazy_static! { - pub static ref HITS_COUNTER: Counter = - register_counter!("eris_hits_total", "Total number of hits to honeypot paths").unwrap(); - pub static ref BLOCKED_IPS: Gauge = - register_gauge!("eris_blocked_ips", "Number of IPs permanently blocked").unwrap(); - pub static ref ACTIVE_CONNECTIONS: Gauge = register_gauge!( - "eris_active_connections", - "Number of currently active connections in tarpit" - ) - .unwrap(); - pub static ref PATH_HITS: CounterVec = - register_counter_vec!("eris_path_hits_total", "Hits by path", &["path"]).unwrap(); - pub static ref UA_HITS: CounterVec = - register_counter_vec!("eris_ua_hits_total", "Hits by user agent", &["user_agent"]).unwrap(); -} - -// Prometheus metrics endpoint -pub async fn metrics_handler(_req: HttpRequest) -> HttpResponse { - let encoder = prometheus::TextEncoder::new(); - let mut buffer = Vec::new(); - - match encoder.encode(&prometheus::gather(), &mut buffer) { - Ok(()) => { - log::debug!("Metrics requested, returned {} bytes", buffer.len()); - } - Err(e) => { - log::error!("Error encoding metrics: {e}"); - } - } - - HttpResponse::Ok().content_type("text/plain").body(buffer) -} - -// Status JSON endpoint -pub async fn status_handler(state: web::Data>>) -> HttpResponse { - let state = state.read().await; - - let info = json!({ - "status": "running", - "version": env!("CARGO_PKG_VERSION"), - "blocked_ips": state.blocked.len(), - "active_connections": state.active_connections.len(), - "hit_count": state.hits.len(), - }); - - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string_pretty(&info).unwrap()) -}