diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..c055435 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,281 @@ +use std::io; + +/// Result type alias for the application +pub type Result<T> = std::result::Result<T, ErisError>; + +/// Comprehensive error types for the Eris application +#[derive(thiserror::Error, Debug)] +pub enum ErisError { + /// Configuration-related errors + #[error("Configuration error: {message}")] + Config { message: String }, + + /// Network-related errors + #[error("Network error: {0}")] + Network(#[from] io::Error), + + /// HTTP parsing errors + #[error("HTTP parsing error: {message}")] + HttpParse { message: String }, + + /// Firewall operation errors + #[error("Firewall operation failed: {message}")] + Firewall { message: String }, + + /// Lua script execution errors + #[error("Lua script error: {message}")] + Lua { message: String }, + + /// Markov chain generation errors + #[error("Markov generation error: {message}")] + Markov { message: String }, + + /// Metrics collection errors + #[error("Metrics error: {message}")] + Metrics { message: String }, + + /// File system errors + #[error("File system error: {message}")] + FileSystem { message: String }, + + /// Validation errors + #[error("Validation error: {message}")] + Validation { message: String }, + + /// IP address parsing errors + #[error("Invalid IP address: {address}")] + InvalidIp { address: String }, + + /// Connection limit exceeded + #[error("Connection limit exceeded: {current}/{max}")] + ConnectionLimit { current: usize, max: usize }, + + /// Rate limiting errors + #[error("Rate limit exceeded for IP: {ip}")] + RateLimit { ip: String }, + + /// Timeout errors + #[error("Operation timed out: {operation}")] + Timeout { operation: String }, + + /// Permission errors + #[error("Permission denied: {operation}")] + Permission { operation: String }, + + /// Resource not found errors + #[error("Resource not found: {resource}")] + NotFound { resource: String }, + + /// 
Generic application errors + #[error("Application error: {message}")] + Application { message: String }, +} + +impl ErisError { + /// Create a new configuration error + pub fn config<T: Into<String>>(message: T) -> Self { + Self::Config { + message: message.into(), + } + } + + /// Create a new HTTP parsing error + pub fn http_parse<T: Into<String>>(message: T) -> Self { + Self::HttpParse { + message: message.into(), + } + } + + /// Create a new firewall error + pub fn firewall<T: Into<String>>(message: T) -> Self { + Self::Firewall { + message: message.into(), + } + } + + /// Create a new Lua script error + pub fn lua<T: Into<String>>(message: T) -> Self { + Self::Lua { + message: message.into(), + } + } + + /// Create a new Markov generation error + pub fn markov<T: Into<String>>(message: T) -> Self { + Self::Markov { + message: message.into(), + } + } + + /// Create a new metrics error + pub fn metrics<T: Into<String>>(message: T) -> Self { + Self::Metrics { + message: message.into(), + } + } + + /// Create a new file system error + pub fn filesystem<T: Into<String>>(message: T) -> Self { + Self::FileSystem { + message: message.into(), + } + } + + /// Create a new validation error + pub fn validation<T: Into<String>>(message: T) -> Self { + Self::Validation { + message: message.into(), + } + } + + /// Create a new invalid IP error + pub fn invalid_ip<T: Into<String>>(address: T) -> Self { + Self::InvalidIp { + address: address.into(), + } + } + + /// Create a new connection limit error + #[must_use] + pub const fn connection_limit(current: usize, max: usize) -> Self { + Self::ConnectionLimit { current, max } + } + + /// Create a new rate limit error + pub fn rate_limit<T: Into<String>>(ip: T) -> Self { + Self::RateLimit { ip: ip.into() } + } + + /// Create a new timeout error + pub fn timeout<T: Into<String>>(operation: T) -> Self { + Self::Timeout { + operation: operation.into(), + } + } + + /// Create a new permission error + pub fn permission<T: Into<String>>(operation: T) -> Self { + Self::Permission { + operation: operation.into(), + } + } + + /// Create a new not found error + pub fn not_found<T: Into<String>>(resource: T) -> Self { + Self::NotFound { + 
resource: resource.into(), + } + } + + /// Create a new application error + pub fn application<T: Into<String>>(message: T) -> Self { + Self::Application { + message: message.into(), + } + } + + /// Check if this is a retryable error + #[must_use] + pub const fn is_retryable(&self) -> bool { + matches!( + self, + Self::Network(_) + | Self::Timeout { .. } + | Self::ConnectionLimit { .. } + | Self::RateLimit { .. } + ) + } + + /// Check if this error should be logged at debug level + #[must_use] + pub const fn is_debug_level(&self) -> bool { + matches!( + self, + Self::Network(_) | Self::HttpParse { .. } | Self::RateLimit { .. } + ) + } + + /// Get error category for metrics + #[must_use] + pub const fn category(&self) -> &'static str { + match self { + Self::Config { .. } => "config", + Self::Network { .. } => "network", + Self::HttpParse { .. } => "http", + Self::Firewall { .. } => "firewall", + Self::Lua { .. } => "lua", + Self::Markov { .. } => "markov", + Self::Metrics { .. } => "metrics", + Self::FileSystem { .. } => "filesystem", + Self::Validation { .. } => "validation", + Self::InvalidIp { .. } => "network", + Self::ConnectionLimit { .. } => "connection", + Self::RateLimit { .. } => "rate_limit", + Self::Timeout { .. } => "timeout", + Self::Permission { .. } => "permission", + Self::NotFound { .. } => "not_found", + Self::Application { .. 
} => "application", + } + } +} + +/// Convert from `serde_json::Error` +impl From<serde_json::Error> for ErisError { + fn from(err: serde_json::Error) -> Self { + Self::config(format!("JSON parsing error: {err}")) + } +} + +/// Convert from `rlua::Error` +impl From<rlua::Error> for ErisError { + fn from(err: rlua::Error) -> Self { + Self::lua(format!("Lua execution error: {err}")) + } +} + +/// Convert from `ipnetwork::IpNetworkError` +impl From<ipnetwork::IpNetworkError> for ErisError { + fn from(err: ipnetwork::IpNetworkError) -> Self { + Self::validation(format!("IP network error: {err}")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error_creation() { + let err = ErisError::config("Invalid port"); + assert!(matches!(err, ErisError::Config { .. })); + assert_eq!(err.category(), "config"); + } + + #[test] + fn test_error_retryable() { + assert!( + ErisError::Network(io::Error::new(io::ErrorKind::TimedOut, "timeout")).is_retryable() + ); + assert!(!ErisError::config("test").is_retryable()); + } + + #[test] + fn test_error_debug_level() { + assert!( + ErisError::Network(io::Error::new(io::ErrorKind::ConnectionRefused, "refused")) + .is_debug_level() + ); + assert!(!ErisError::config("test").is_debug_level()); + } + + #[test] + fn test_error_conversions() { + let io_err = io::Error::new(io::ErrorKind::NotFound, "file not found"); + let eris_err: ErisError = io_err.into(); + assert!(matches!(eris_err, ErisError::Network(_))); + + let json_err = serde_json::from_str::<serde_json::Value>("invalid json").unwrap_err(); + let eris_err: ErisError = json_err.into(); + assert!(matches!(eris_err, ErisError::Config { .. })); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..7bbc595 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,141 @@ +//! 
Eris - Sophisticated HTTP tarpit and honeypot +use std::collections::{HashMap, HashSet}; +use std::net::IpAddr; + +pub mod error; +pub mod markov; +pub mod metrics; + +// Re-export commonly used types +pub use error::{ErisError, Result}; +pub use markov::MarkovGenerator; +pub use metrics::{ + ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, UA_HITS, metrics_handler, + status_handler, }; + +/// State of bots/IPs hitting the honeypot +#[derive(Clone, Debug)] +pub struct BotState { + pub hits: HashMap<IpAddr, u32>, + pub blocked: HashSet<IpAddr>, + pub active_connections: HashSet<IpAddr>, + pub data_dir: String, + pub cache_dir: String, +} + +impl BotState { + #[must_use] + pub fn new(data_dir: &str, cache_dir: &str) -> Self { + Self { + hits: HashMap::new(), + blocked: HashSet::new(), + active_connections: HashSet::new(), + data_dir: data_dir.to_string(), + cache_dir: cache_dir.to_string(), + } + } + + /// Load previous state from disk + #[must_use] + pub fn load_from_disk(data_dir: &str, cache_dir: &str) -> Self { + let mut state = Self::new(data_dir, cache_dir); + let blocked_ips_file = format!("{data_dir}/blocked_ips.txt"); + + if let Ok(content) = std::fs::read_to_string(&blocked_ips_file) { + let mut loaded = 0; + for line in content.lines() { + if let Ok(ip) = line.parse::<IpAddr>() { + state.blocked.insert(ip); + loaded += 1; + } + } + log::info!("Loaded {loaded} blocked IPs from {blocked_ips_file}"); + } else { + log::info!("No blocked IPs file found at {blocked_ips_file}"); + } + + // Check for temporary hit counter cache + let hit_cache_file = format!("{cache_dir}/hit_counters.json"); + if let Ok(content) = std::fs::read_to_string(&hit_cache_file) + && let Ok(hit_map) = + serde_json::from_str::<HashMap<String, u32>>(&content) + { + for (ip_str, count) in hit_map { + if let Ok(ip) = ip_str.parse::<IpAddr>() { + state.hits.insert(ip, count); + } + } + log::info!("Loaded hit counters for {} IPs", state.hits.len()); + } + + BLOCKED_IPS.set(state.blocked.len() as f64); + state + } + + /// Persist state to disk for 
later reloading + pub fn save_to_disk(&self) { + // Save blocked IPs + if let Err(e) = std::fs::create_dir_all(&self.data_dir) { + log::error!("Failed to create data directory: {e}"); + return; + } + + let blocked_ips_file = format!("{}/blocked_ips.txt", self.data_dir); + + match std::fs::File::create(&blocked_ips_file) { + Ok(mut file) => { + let mut count = 0; + for ip in &self.blocked { + if std::io::Write::write_fmt(&mut file, format_args!("{ip}\n")).is_ok() { + count += 1; + } + } + log::info!("Saved {count} blocked IPs to {blocked_ips_file}"); + } + Err(e) => { + log::error!("Failed to create blocked IPs file: {e}"); + } + } + + // Save hit counters to cache + if let Err(e) = std::fs::create_dir_all(&self.cache_dir) { + log::error!("Failed to create cache directory: {e}"); + return; + } + + let hit_cache_file = format!("{}/hit_counters.json", self.cache_dir); + let mut hit_map = std::collections::HashMap::new(); + for (ip, count) in &self.hits { + hit_map.insert(ip.to_string(), *count); + } + + match std::fs::File::create(&hit_cache_file) { + Ok(file) => { + if let Err(e) = serde_json::to_writer(file, &hit_map) { + log::error!("Failed to write hit counters to cache: {e}"); + } else { + log::debug!("Saved hit counters for {} IPs to cache", hit_map.len()); + } + } + Err(e) => { + log::error!("Failed to create hit counter cache file: {e}"); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_library_imports() { + // Test that we can import and use the main types + let _err = ErisError::config("test"); + let _result: Result<()> = Ok(()); + + // Test markov generator creation + let _markov = MarkovGenerator::new("./test_corpora"); + } +} diff --git a/src/main.rs b/src/main.rs index 2dc17d2..956ff69 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,13 @@ use actix_web::{App, HttpResponse, HttpServer, web}; use clap::Parser; +use eris::{BotState, ErisError, MarkovGenerator, Result}; use ipnetwork::IpNetwork; use rlua::{Function, 
Lua}; use serde::{Deserialize, Serialize}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::env; use std::fs; -use std::io::Write; + use std::net::IpAddr; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -17,11 +18,8 @@ use tokio::process::Command; use tokio::sync::RwLock; use tokio::time::sleep; -mod markov; -mod metrics; - -use markov::MarkovGenerator; -use metrics::{ +// Import metrics from the metrics module +use eris::{ ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, UA_HITS, metrics_handler, status_handler, }; @@ -270,114 +268,6 @@ impl Config { } } -// State of bots/IPs hitting the honeypot -#[derive(Clone, Debug)] -struct BotState { - hits: HashMap<IpAddr, u32>, - blocked: HashSet<IpAddr>, - active_connections: HashSet<IpAddr>, - data_dir: String, - cache_dir: String, -} - -impl BotState { - fn new(data_dir: &str, cache_dir: &str) -> Self { - Self { - hits: HashMap::new(), - blocked: HashSet::new(), - active_connections: HashSet::new(), - data_dir: data_dir.to_string(), - cache_dir: cache_dir.to_string(), - } - } - - // Load previous state from disk - fn load_from_disk(data_dir: &str, cache_dir: &str) -> Self { - let mut state = Self::new(data_dir, cache_dir); - let blocked_ips_file = format!("{data_dir}/blocked_ips.txt"); - - if let Ok(content) = fs::read_to_string(&blocked_ips_file) { - let mut loaded = 0; - for line in content.lines() { - if let Ok(ip) = line.parse::<IpAddr>() { - state.blocked.insert(ip); - loaded += 1; - } - } - log::info!("Loaded {loaded} blocked IPs from {blocked_ips_file}"); - } else { - log::info!("No blocked IPs file found at {blocked_ips_file}"); - } - - // Check for temporary hit counter cache - let hit_cache_file = format!("{cache_dir}/hit_counters.json"); - if let Ok(content) = fs::read_to_string(&hit_cache_file) { - if let Ok(hit_map) = serde_json::from_str::<HashMap<String, u32>>(&content) { - for (ip_str, count) in hit_map { - if let Ok(ip) = ip_str.parse::<IpAddr>() { - state.hits.insert(ip, count); - } - } - log::info!("Loaded hit 
counters for {} IPs", state.hits.len()); - } - } - - BLOCKED_IPS.set(state.blocked.len() as f64); - state - } - - // Persist state to disk for later reloading - fn save_to_disk(&self) { - // Save blocked IPs - if let Err(e) = fs::create_dir_all(&self.data_dir) { - log::error!("Failed to create data directory: {e}"); - return; - } - - let blocked_ips_file = format!("{}/blocked_ips.txt", self.data_dir); - - match fs::File::create(&blocked_ips_file) { - Ok(mut file) => { - let mut count = 0; - for ip in &self.blocked { - if writeln!(file, "{ip}").is_ok() { - count += 1; - } - } - log::info!("Saved {count} blocked IPs to {blocked_ips_file}"); - } - Err(e) => { - log::error!("Failed to create blocked IPs file: {e}"); - } - } - - // Save hit counters to cache - if let Err(e) = fs::create_dir_all(&self.cache_dir) { - log::error!("Failed to create cache directory: {e}"); - return; - } - - let hit_cache_file = format!("{}/hit_counters.json", self.cache_dir); - let mut hit_map = HashMap::new(); - for (ip, count) in &self.hits { - hit_map.insert(ip.to_string(), *count); - } - - match fs::File::create(&hit_cache_file) { - Ok(file) => { - if let Err(e) = serde_json::to_writer(file, &hit_map) { - log::error!("Failed to write hit counters to cache: {e}"); - } else { - log::debug!("Saved hit counters for {} IPs to cache", hit_map.len()); - } - } - Err(e) => { - log::error!("Failed to create hit counter cache file: {e}"); - } - } - } -} - // Lua scripts for response generation and customization struct ScriptManager { script_content: String, @@ -394,18 +284,16 @@ impl ScriptManager { if script_dir.exists() { log::debug!("Loading Lua scripts from directory: {scripts_dir}"); if let Ok(entries) = fs::read_dir(script_dir) { - for entry in entries { - if let Ok(entry) = entry { - let path = entry.path(); - if path.extension().and_then(|ext| ext.to_str()) == Some("lua") { - if let Ok(content) = fs::read_to_string(&path) { - log::debug!("Loaded Lua script: {}", path.display()); - 
script_content.push_str(&content); - script_content.push('\n'); - scripts_loaded = true; - } else { - log::warn!("Failed to read Lua script: {}", path.display()); - } + for entry in entries.flatten() { + let path = entry.path(); + if path.extension().and_then(|ext| ext.to_str()) == Some("lua") { + if let Ok(content) = fs::read_to_string(&path) { + log::debug!("Loaded Lua script: {}", path.display()); + script_content.push_str(&content); + script_content.push('\n'); + scripts_loaded = true; + } else { + log::warn!("Failed to read Lua script: {}", path.display()); } } } @@ -681,12 +569,11 @@ async fn handle_connection( async fn should_tarpit(path: &str, ip: &IpAddr, config: &Config) -> bool { // Don't tarpit whitelisted IPs (internal networks, etc) for network_str in &config.whitelist_networks { - if let Ok(network) = network_str.parse::<IpNetwork>() { - if network.contains(*ip) { + if let Ok(network) = network_str.parse::<IpNetwork>() && network.contains(*ip) { log::debug!("IP {ip} is in whitelist network {network_str}"); return false; } - } } // Check if the request path matches any of our trap patterns @@ -924,7 +811,7 @@ async fn proxy_to_backend( } // Set up nftables firewall rules for IP blocking -async fn setup_firewall() -> Result<(), String> { +async fn setup_firewall() -> Result<()> { log::info!("Setting up firewall rules"); // Check if nft command exists @@ -956,7 +843,8 @@ async fn setup_firewall() -> Result<(), String> { .await; if let Err(e) = result { - return Err(format!("Failed to create nftables table: {e}")); + log::error!("Failed to create nftables table: {e}"); + return Err(ErisError::firewall("Failed to create nftables table")); } } } @@ -969,7 +857,8 @@ async fn setup_firewall() -> Result<(), String> { .await; if let Err(e) = result { - return Err(format!("Failed to create nftables table: {e}")); + log::error!("Failed to create nftables table: {e}"); + return Err(ErisError::firewall("Failed to create nftables table")); } } } @@ -997,13 +886,16 @@ async fn 
setup_firewall() -> Result<(), String> { .await; if let Err(e) = result { - return Err(format!("Failed to create blacklist set: {e}")); + log::error!("Failed to create blacklist set: {e}"); + return Err(ErisError::firewall("Failed to create blacklist set")); } } } Err(e) => { - log::warn!("Failed to check if blacklist set exists: {e}"); - return Err(format!("Failed to check if blacklist set exists: {e}")); + log::error!("Failed to check if blacklist set exists: {e}"); + return Err(ErisError::firewall( + "Failed to check if blacklist set exists", + )); } } @@ -1036,13 +928,16 @@ async fn setup_firewall() -> Result<(), String> { .await; if let Err(e) = result { - return Err(format!("Failed to add firewall rule: {e}")); + log::error!("Failed to add firewall rule: {e}"); + return Err(ErisError::firewall("Failed to add firewall rule")); } } } Err(e) => { - log::warn!("Failed to check if firewall rule exists: {e}"); - return Err(format!("Failed to check if firewall rule exists: {e}")); + log::error!("Failed to check if firewall rule exists: {e}"); + return Err(ErisError::firewall( + "Failed to check if firewall rule exists", + )); } } @@ -1152,7 +1047,8 @@ async fn main() -> std::io::Result<()> { let listener = match TcpListener::bind(&config.listen_addr).await { Ok(l) => l, Err(e) => { - return Err(format!("Failed to bind to {}: {}", config.listen_addr, e)); + log::error!("Failed to bind to {}: {}", config.listen_addr, e); + return Err(ErisError::config("Failed to bind to listen address")); } }; @@ -1186,7 +1082,7 @@ async fn main() -> std::io::Result<()> { } #[allow(unreachable_code)] - Ok::<(), String>(()) + Ok(()) }); // Start the metrics server with actix_web only if metrics are not disabled @@ -1227,11 +1123,11 @@ async fn main() -> std::io::Result<()> { Ok(Ok(())) => Ok(()), Ok(Err(e)) => { log::error!("Tarpit server error: {e}"); - Err(std::io::Error::new(std::io::ErrorKind::Other, e)) + Err(std::io::Error::other(e)) }, Err(e) => { log::error!("Tarpit server task 
error: {e}"); - Err(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())) + Err(std::io::Error::other(e.to_string())) }, }, result = metrics_server => { @@ -1247,12 +1143,11 @@ async fn main() -> std::io::Result<()> { Ok(Ok(())) => Ok(()), Ok(Err(e)) => { log::error!("Tarpit server error: {e}"); - Err(std::io::Error::new(std::io::ErrorKind::Other, e)) + Err(std::io::Error::other(e)) } Err(e) => { log::error!("Tarpit server task error: {e}"); - Err(std::io::Error::new( - std::io::ErrorKind::Other, + Err(std::io::Error::other( e.to_string(), )) } diff --git a/src/markov.rs b/src/markov.rs index b640dac..2da848a 100644 --- a/src/markov.rs +++ b/src/markov.rs @@ -90,6 +90,7 @@ pub struct MarkovGenerator { } impl MarkovGenerator { + #[must_use] pub fn new(corpus_dir: &str) -> Self { let mut chains = HashMap::new(); @@ -101,28 +102,22 @@ impl MarkovGenerator { // Load corpus files if they exist let path = Path::new(corpus_dir); - if path.exists() && path.is_dir() { - if let Ok(entries) = fs::read_dir(path) { - for entry in entries { - if let Ok(entry) = entry { - let file_path = entry.path(); - if let Some(file_name) = file_path.file_stem() { - if let Some(file_name_str) = file_name.to_str() { - if types.contains(&file_name_str) { - if let Ok(content) = fs::read_to_string(&file_path) { - let mut chain = Chain::new(DEFAULT_ORDER); - for line in content.lines() { - chain.add(line); - } - chains.insert(file_name_str.to_string(), chain); + if path.exists() && path.is_dir() + && let Ok(entries) = fs::read_dir(path) { + for entry in entries.flatten() { + let file_path = entry.path(); + if let Some(file_name) = file_path.file_stem() + && let Some(file_name_str) = file_name.to_str() + && types.contains(&file_name_str) + && let Ok(content) = fs::read_to_string(&file_path) { + let mut chain = Chain::new(DEFAULT_ORDER); + for line in content.lines() { + chain.add(line); } + chains.insert(file_name_str.to_string(), chain); } - } - } - } } } - } // If corpus files didn't 
exist, initialize with some default content if chains["php_exploit"].start_states.is_empty() {