From 37e57fa01510a775a33a2393d296fe613a9be94d Mon Sep 17 00:00:00 2001 From: NotAShelf Date: Fri, 2 May 2025 10:35:18 +0300 Subject: [PATCH] eris: modulerize; fix some lints --- src/config.rs | 532 ++++++++++++++++++ src/firewall.rs | 128 +++++ src/main.rs | 1422 +---------------------------------------------- src/network.rs | 443 +++++++++++++++ src/state.rs | 164 ++++++ src/utils.rs | 168 ++++++ 6 files changed, 1449 insertions(+), 1408 deletions(-) create mode 100644 src/config.rs create mode 100644 src/firewall.rs create mode 100644 src/network.rs create mode 100644 src/state.rs create mode 100644 src/utils.rs diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..ab85136 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,532 @@ +use clap::Parser; +use ipnetwork::IpNetwork; +use regex::Regex; +use serde::{Deserialize, Serialize}; +use std::env; +use std::fs; +use std::net::IpAddr; +use std::path::{Path, PathBuf}; + +// Command-line arguments using clap +#[derive(Parser, Debug, Clone)] +#[clap( + author, + version, + about = "Markov chain based HTTP tarpit/honeypot that delays and tracks potential attackers" +)] +pub struct Args { + #[clap( + long, + default_value = "0.0.0.0:8888", + help = "Address and port to listen for incoming HTTP requests (format: ip:port)" + )] + pub listen_addr: String, + + #[clap( + long, + default_value = "0.0.0.0:9100", + help = "Address and port to expose Prometheus metrics and status endpoint (format: ip:port)" + )] + pub metrics_addr: String, + + #[clap(long, help = "Disable Prometheus metrics server completely")] + pub disable_metrics: bool, + + #[clap( + long, + default_value = "127.0.0.1:80", + help = "Backend server address to proxy legitimate requests to (format: ip:port)" + )] + pub backend_addr: String, + + #[clap( + long, + default_value = "1000", + help = "Minimum delay in milliseconds between chunks sent to attacker" + )] + pub min_delay: u64, + + #[clap( + long, + default_value = "15000", + help = "Maximum delay in milliseconds between chunks sent to attacker" + )] + pub max_delay: u64, + + #[clap( + long, + default_value = "600", + help = "Maximum time in seconds to keep an attacker in the tarpit before disconnecting" + )] + pub max_tarpit_time: u64, + + #[clap( + long, + default_value = "3", + help = "Number of hits to honeypot patterns before permanently blocking an IP" + )] + pub block_threshold: u32, + + #[clap( + long, + help = "Base directory for all application data (overrides XDG directory structure)" + )] + pub base_dir: Option, + + #[clap( + long, + help = "Path to configuration file (JSON or TOML, overrides command line options)" + )] + pub config_file: Option, + + #[clap( + long, + default_value = "info", + help = "Log level: trace, debug, info, warn, error" + )] + pub log_level: String, +} + +// Trap pattern structure. It can be either a plain string +// regex to catch more advanced patterns necessitated by +// more sophisticated crawlers. +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(untagged)] +pub enum TrapPattern { + Plain(String), + Regex { pattern: String, regex: bool }, +} + +impl TrapPattern { + pub fn as_plain(value: &str) -> Self { + Self::Plain(value.to_string()) + } + + pub fn as_regex(value: &str) -> Self { + Self::Regex { + pattern: value.to_string(), + regex: true, + } + } + + pub fn matches(&self, path: &str) -> bool { + match self { + Self::Plain(pattern) => path.contains(pattern), + Self::Regex { + pattern, + regex: true, + } => { + if let Ok(re) = Regex::new(pattern) { + re.is_match(path) + } else { + false + } + } + _ => false, + } + } +} + +// Configuration structure +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Config { + pub listen_addr: String, + pub metrics_addr: String, + pub disable_metrics: bool, + pub backend_addr: String, + pub min_delay: u64, + pub max_delay: u64, + pub max_tarpit_time: u64, + pub block_threshold: u32, + pub trap_patterns: Vec, + pub whitelist_networks: Vec, + pub markov_corpora_dir: String, + pub lua_scripts_dir: String, + pub data_dir: String, + pub config_dir: String, + pub cache_dir: String, +} + +impl Default for Config { + fn default() -> Self { + Self { + listen_addr: "0.0.0.0:8888".to_string(), + metrics_addr: "0.0.0.0:9100".to_string(), + disable_metrics: false, + backend_addr: "127.0.0.1:80".to_string(), + min_delay: 1000, + max_delay: 15000, + max_tarpit_time: 600, + block_threshold: 3, + trap_patterns: vec![ + // Basic attack patterns as plain strings + TrapPattern::as_plain("/vendor/phpunit"), + TrapPattern::as_plain("eval-stdin.php"), + TrapPattern::as_plain("/wp-admin"), + TrapPattern::as_plain("/wp-login.php"), + TrapPattern::as_plain("/xmlrpc.php"), + TrapPattern::as_plain("/phpMyAdmin"), + TrapPattern::as_plain("/solr/"), + TrapPattern::as_plain("/.env"), + TrapPattern::as_plain("/config"), + TrapPattern::as_plain("/actuator/"), + // More aggressive patterns for various PHP exploits. + // XXX: I dedicate this entire section to that one single crawler + // that has been scanning my entire network, hitting 403s left and right + // but not giving up, and coming back the next day at the same time to + // scan the same paths over and over. Kudos to you, random crawler. + TrapPattern::as_regex(r"/.*phpunit.*eval-stdin\.php"), + TrapPattern::as_regex(r"/index\.php\?s=/index/\\think\\app/invokefunction"), + TrapPattern::as_regex(r".*%ADd\+auto_prepend_file%3dphp://input.*"), + TrapPattern::as_regex(r".*%ADd\+allow_url_include%3d1.*"), + TrapPattern::as_regex(r".*/wp-content/plugins/.*\.php"), + TrapPattern::as_regex(r".*/wp-content/themes/.*\.php"), + TrapPattern::as_regex(r".*eval\(.*\).*"), + TrapPattern::as_regex(r".*/adminer\.php.*"), + TrapPattern::as_regex(r".*/admin\.php.*"), + TrapPattern::as_regex(r".*/administrator/.*"), + TrapPattern::as_regex(r".*/wp-json/.*"), + TrapPattern::as_regex(r".*/api/.*\.php.*"), + TrapPattern::as_regex(r".*/cgi-bin/.*"), + TrapPattern::as_regex(r".*/owa/.*"), + TrapPattern::as_regex(r".*/ecp/.*"), + TrapPattern::as_regex(r".*/webshell\.php.*"), + TrapPattern::as_regex(r".*/shell\.php.*"), + TrapPattern::as_regex(r".*/cmd\.php.*"), + TrapPattern::as_regex(r".*/struts.*"), + ], + whitelist_networks: vec![ + "192.168.0.0/16".to_string(), + "10.0.0.0/8".to_string(), + "172.16.0.0/12".to_string(), + "127.0.0.0/8".to_string(), + ], + markov_corpora_dir: "./corpora".to_string(), + lua_scripts_dir: "./scripts".to_string(), + data_dir: "./data".to_string(), + config_dir: "./conf".to_string(), + cache_dir: "./cache".to_string(), + } + } +} + +// Gets standard XDG directory paths for config, data and cache. +// XXX: This could be "simplified" by using the Dirs crate, but I can't +// really justify pulling a library for something I can handle in less +// than 30 lines. Unless cross-platform becomes necessary, the below +// implementation is good enough. For alternative platforms, we can simply +// enhance the current implementation as needed. +pub fn get_xdg_dirs() -> (PathBuf, PathBuf, PathBuf) { + let config_home = env::var_os("XDG_CONFIG_HOME") + .map(PathBuf::from) + .unwrap_or_else(|| { + let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); + home.join(".config") + }); + + let data_home = env::var_os("XDG_DATA_HOME") + .map(PathBuf::from) + .unwrap_or_else(|| { + let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); + home.join(".local").join("share") + }); + + let cache_home = env::var_os("XDG_CACHE_HOME") + .map(PathBuf::from) + .unwrap_or_else(|| { + let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); + home.join(".cache") + }); + + let config_dir = config_home.join("eris"); + let data_dir = data_home.join("eris"); + let cache_dir = cache_home.join("eris"); + + (config_dir, data_dir, cache_dir) +} + +impl Config { + // Create configuration from command-line args. We'll be falling back to this + // when the configuration is invalid, so it must be validated more strictly. + pub fn from_args(args: &Args) -> Self { + let (config_dir, data_dir, cache_dir) = if let Some(base_dir) = &args.base_dir { + let base_str = base_dir.to_string_lossy().to_string(); + ( + format!("{base_str}/conf"), + format!("{base_str}/data"), + format!("{base_str}/cache"), + ) + } else { + let (c, d, cache) = get_xdg_dirs(); + ( + c.to_string_lossy().to_string(), + d.to_string_lossy().to_string(), + cache.to_string_lossy().to_string(), + ) + }; + + Self { + listen_addr: args.listen_addr.clone(), + metrics_addr: args.metrics_addr.clone(), + disable_metrics: args.disable_metrics, + backend_addr: args.backend_addr.clone(), + min_delay: args.min_delay, + max_delay: args.max_delay, + max_tarpit_time: args.max_tarpit_time, + block_threshold: args.block_threshold, + markov_corpora_dir: format!("{data_dir}/corpora"), + lua_scripts_dir: format!("{data_dir}/scripts"), + data_dir, + config_dir, + cache_dir, + ..Default::default() + } + } + + // Load configuration from a file (JSON or TOML) + pub fn load_from_file(path: &Path) -> std::io::Result { + let content = fs::read_to_string(path)?; + + let extension = path + .extension() + .map(|ext| ext.to_string_lossy().to_lowercase()) + .unwrap_or_default(); + + let config = match extension.as_str() { + "toml" => toml::from_str(&content).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Failed to parse TOML: {e}"), + ) + })?, + _ => { + // Default to JSON for any other extension + serde_json::from_str(&content).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Failed to parse JSON: {e}"), + ) + })? + } + }; + + Ok(config) + } + + // Save configuration to a file (JSON or TOML) + pub fn save_to_file(&self, path: &Path) -> std::io::Result<()> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + + let extension = path + .extension() + .map(|ext| ext.to_string_lossy().to_lowercase()) + .unwrap_or_default(); + + let content = match extension.as_str() { + "toml" => toml::to_string_pretty(self).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Failed to serialize to TOML: {e}"), + ) + })?, + _ => { + // Default to JSON for any other extension + serde_json::to_string_pretty(self).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Failed to serialize to JSON: {e}"), + ) + })? + } + }; + + fs::write(path, content)?; + Ok(()) + } + + // Create required directories if they don't exist + pub fn ensure_dirs_exist(&self) -> std::io::Result<()> { + let dirs = [ + &self.markov_corpora_dir, + &self.lua_scripts_dir, + &self.data_dir, + &self.config_dir, + &self.cache_dir, + ]; + + for dir in dirs { + fs::create_dir_all(dir)?; + log::debug!("Created directory: {dir}"); + } + + Ok(()) + } +} + +// Decide if a request should be tarpitted based on path and IP +pub fn should_tarpit(path: &str, ip: &IpAddr, config: &Config) -> bool { + // Check whitelist IPs first to avoid unnecessary pattern matching + for network_str in &config.whitelist_networks { + if let Ok(network) = network_str.parse::() { + if network.contains(*ip) { + return false; + } + } + } + + // Use pattern matching based on the trap pattern type. It can be + // a plain string or regex. + for pattern in &config.trap_patterns { + if pattern.matches(path) { + return true; + } + } + + false +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, Ipv4Addr}; + + #[test] + fn test_config_from_args() { + let args = Args { + listen_addr: "127.0.0.1:8080".to_string(), + metrics_addr: "127.0.0.1:9000".to_string(), + disable_metrics: true, + backend_addr: "127.0.0.1:8081".to_string(), + min_delay: 500, + max_delay: 10000, + max_tarpit_time: 300, + block_threshold: 5, + base_dir: Some(PathBuf::from("/tmp/eris")), + config_file: None, + log_level: "debug".to_string(), + }; + + let config = Config::from_args(&args); + assert_eq!(config.listen_addr, "127.0.0.1:8080"); + assert_eq!(config.metrics_addr, "127.0.0.1:9000"); + assert!(config.disable_metrics); + assert_eq!(config.backend_addr, "127.0.0.1:8081"); + assert_eq!(config.min_delay, 500); + assert_eq!(config.max_delay, 10000); + assert_eq!(config.max_tarpit_time, 300); + assert_eq!(config.block_threshold, 5); + assert_eq!(config.markov_corpora_dir, "/tmp/eris/data/corpora"); + assert_eq!(config.lua_scripts_dir, "/tmp/eris/data/scripts"); + assert_eq!(config.data_dir, "/tmp/eris/data"); + assert_eq!(config.config_dir, "/tmp/eris/conf"); + assert_eq!(config.cache_dir, "/tmp/eris/cache"); + } + + #[test] + fn test_trap_pattern_matching() { + // Test plain string pattern + let plain = TrapPattern::as_plain("phpunit"); + assert!(plain.matches("path/to/phpunit/test")); + assert!(!plain.matches("path/to/something/else")); + + // Test regex pattern + let regex = TrapPattern::as_regex(r".*eval-stdin\.php.*"); + assert!(regex.matches("/vendor/phpunit/phpunit/src/Util/PHP/eval-stdin.php")); + assert!(regex.matches("/tests/eval-stdin.php?param")); + assert!(!regex.matches("/normal/path")); + + // Test invalid regex pattern (should return false) + let invalid = TrapPattern::Regex { + pattern: "(invalid[regex".to_string(), + regex: true, + }; + assert!(!invalid.matches("anything")); + } + + #[tokio::test] + async fn test_should_tarpit() { + let config = Config::default(); + + // Test trap patterns + assert!(should_tarpit( + "/vendor/phpunit/whatever", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + )); + assert!(should_tarpit( + "/wp-admin/login.php", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + )); + assert!(should_tarpit( + "/.env", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + )); + + // Test whitelist networks + assert!(!should_tarpit( + "/wp-admin/login.php", + &IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + &config + )); + assert!(!should_tarpit( + "/vendor/phpunit/whatever", + &IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), + &config + )); + + // Test legitimate paths + assert!(!should_tarpit( + "/index.html", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + )); + assert!(!should_tarpit( + "/images/logo.png", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + )); + + // Test regex patterns + assert!(should_tarpit( + "/index.php?s=/index/\\think\\app/invokefunction&function=call_user_func_array&vars[0]=md5&vars[1][]=Hello", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + )); + + assert!(should_tarpit( + "/hello.world?%ADd+allow_url_include%3d1+%ADd+auto_prepend_file%3dphp://input", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + )); + } + + #[test] + fn test_config_file_formats() { + // Create temporary JSON config file + let temp_dir = std::env::temp_dir(); + let json_path = temp_dir.join("temp_config.json"); + let toml_path = temp_dir.join("temp_config.toml"); + + let config = Config::default(); + + // Test JSON serialization and deserialization + config.save_to_file(&json_path).unwrap(); + let loaded_json = Config::load_from_file(&json_path).unwrap(); + assert_eq!(loaded_json.listen_addr, config.listen_addr); + assert_eq!(loaded_json.min_delay, config.min_delay); + + // Test TOML serialization and deserialization + config.save_to_file(&toml_path).unwrap(); + let loaded_toml = Config::load_from_file(&toml_path).unwrap(); + assert_eq!(loaded_toml.listen_addr, config.listen_addr); + assert_eq!(loaded_toml.min_delay, config.min_delay); + + // Clean up + let _ = std::fs::remove_file(json_path); + let _ = std::fs::remove_file(toml_path); + } +} diff --git a/src/firewall.rs b/src/firewall.rs new file mode 100644 index 0000000..73f36e5 --- /dev/null +++ b/src/firewall.rs @@ -0,0 +1,128 @@ +use tokio::process::Command; + +// Set up nftables firewall rules for IP blocking +pub async fn setup_firewall() -> Result<(), String> { + log::info!("Setting up firewall rules"); + + // Check if nft command exists + let nft_exists = Command::new("which") + .arg("nft") + .output() + .await + .map(|output| output.status.success()) + .unwrap_or(false); + + if !nft_exists { + log::warn!("nft command not found. Firewall rules will not be set up."); + return Ok(()); + } + + // Create table if it doesn't exist + let output = Command::new("nft") + .args(["list", "table", "inet", "filter"]) + .output() + .await; + + match output { + Ok(output) => { + if !output.status.success() { + log::info!("Creating nftables table"); + let result = Command::new("nft") + .args(["create", "table", "inet", "filter"]) + .output() + .await; + + if let Err(e) = result { + return Err(format!("Failed to create nftables table: {e}")); + } + } + } + Err(e) => { + log::warn!("Failed to check if nftables table exists: {e}"); + log::info!("Will try to create it anyway"); + let result = Command::new("nft") + .args(["create", "table", "inet", "filter"]) + .output() + .await; + + if let Err(e) = result { + return Err(format!("Failed to create nftables table: {e}")); + } + } + } + + // Create blacklist set if it doesn't exist + let output = Command::new("nft") + .args(["list", "set", "inet", "filter", "eris_blacklist"]) + .output() + .await; + + match output { + Ok(output) => { + if !output.status.success() { + log::info!("Creating eris_blacklist set"); + let result = Command::new("nft") + .args([ + "create", + "set", + "inet", + "filter", + "eris_blacklist", + "{ type ipv4_addr; flags interval; }", + ]) + .output() + .await; + + if let Err(e) = result { + return Err(format!("Failed to create blacklist set: {e}")); + } + } + } + Err(e) => { + log::warn!("Failed to check if blacklist set exists: {e}"); + return Err(format!("Failed to check if blacklist set exists: {e}")); + } + } + + // Add rule to drop traffic from blacklisted IPs + let output = Command::new("nft") + .args(["list", "chain", "inet", "filter", "input"]) + .output() + .await; + + // Check if our rule already exists + match output { + Ok(output) => { + let rule_exists = String::from_utf8_lossy(&output.stdout) + .contains("ip saddr @eris_blacklist counter drop"); + + if !rule_exists { + log::info!("Adding drop rule for blacklisted IPs"); + let result = Command::new("nft") + .args([ + "add", + "rule", + "inet", + "filter", + "input", + "ip saddr @eris_blacklist", + "counter", + "drop", + ]) + .output() + .await; + + if let Err(e) = result { + return Err(format!("Failed to add firewall rule: {e}")); + } + } + } + Err(e) => { + log::warn!("Failed to check if firewall rule exists: {e}"); + return Err(format!("Failed to check if firewall rule exists: {e}")); + } + } + + log::info!("Firewall setup complete"); + Ok(()) +} diff --git a/src/main.rs b/src/main.rs index 185c258..9c6c73f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,1155 +1,28 @@ use actix_web::{App, HttpResponse, HttpServer, web}; use clap::Parser; -use ipnetwork::IpNetwork; -use regex::Regex; -use serde::{Deserialize, Serialize}; -use std::collections::{HashMap, HashSet}; -use std::env; use std::fs; -use std::hash::Hasher; -use std::io::Write; -use std::net::IpAddr; -use std::path::{Path, PathBuf}; +use std::path::Path; use std::sync::Arc; -use std::time::{Duration, Instant}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpListener, TcpStream}; -use tokio::process::Command; +use std::time::Duration; +use tokio::net::TcpListener; use tokio::sync::RwLock; -use tokio::time::sleep; +mod config; +mod firewall; mod lua; mod markov; mod metrics; +mod network; +mod state; +mod utils; +use config::{Args, Config}; use lua::{EventContext, EventType, ScriptManager}; use markov::MarkovGenerator; -use metrics::{ - ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, UA_HITS, metrics_handler, - status_handler, -}; - -// Command-line arguments using clap -#[derive(Parser, Debug, Clone)] -#[clap( - author, - version, - about = "Markov chain based HTTP tarpit/honeypot that delays and tracks potential attackers" -)] -struct Args { - #[clap( - long, - default_value = "0.0.0.0:8888", - help = "Address and port to listen for incoming HTTP requests (format: ip:port)" - )] - listen_addr: String, - - #[clap( - long, - default_value = "0.0.0.0:9100", - help = "Address and port to expose Prometheus metrics and status endpoint (format: ip:port)" - )] - metrics_addr: String, - - #[clap(long, help = "Disable Prometheus metrics server completely")] - disable_metrics: bool, - - #[clap( - long, - default_value = "127.0.0.1:80", - help = "Backend server address to proxy legitimate requests to (format: ip:port)" - )] - backend_addr: String, - - #[clap( - long, - default_value = "1000", - help = "Minimum delay in milliseconds between chunks sent to attacker" - )] - min_delay: u64, - - #[clap( - long, - default_value = "15000", - help = "Maximum delay in milliseconds between chunks sent to attacker" - )] - max_delay: u64, - - #[clap( - long, - default_value = "600", - help = "Maximum time in seconds to keep an attacker in the tarpit before disconnecting" - )] - max_tarpit_time: u64, - - #[clap( - long, - default_value = "3", - help = "Number of hits to honeypot patterns before permanently blocking an IP" - )] - block_threshold: u32, - - #[clap( - long, - help = "Base directory for all application data (overrides XDG directory structure)" - )] - base_dir: Option, - - #[clap( - long, - help = "Path to configuration file (JSON or TOML, overrides command line options)" - )] - config_file: Option, - - #[clap( - long, - default_value = "info", - help = "Log level: trace, debug, info, warn, error" - )] - log_level: String, -} - -// Trap pattern structure that can be either a plain string or regex -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(untagged)] -enum TrapPattern { - Plain(String), - Regex { pattern: String, regex: bool }, -} - -impl TrapPattern { - fn as_plain(value: &str) -> Self { - Self::Plain(value.to_string()) - } - - fn as_regex(value: &str) -> Self { - Self::Regex { - pattern: value.to_string(), - regex: true, - } - } - - fn matches(&self, path: &str) -> bool { - match self { - Self::Plain(pattern) => path.contains(pattern), - Self::Regex { - pattern, - regex: true, - } => { - if let Ok(re) = Regex::new(pattern) { - re.is_match(path) - } else { - false - } - } - _ => false, - } - } -} - -// Configuration structure -#[derive(Clone, Debug, Deserialize, Serialize)] -struct Config { - listen_addr: String, - metrics_addr: String, - disable_metrics: bool, - backend_addr: String, - min_delay: u64, - max_delay: u64, - max_tarpit_time: u64, - block_threshold: u32, - trap_patterns: Vec, - whitelist_networks: Vec, - markov_corpora_dir: String, - lua_scripts_dir: String, - data_dir: String, - config_dir: String, - cache_dir: String, -} - -impl Default for Config { - fn default() -> Self { - Self { - listen_addr: "0.0.0.0:8888".to_string(), - metrics_addr: "0.0.0.0:9100".to_string(), - disable_metrics: false, - backend_addr: "127.0.0.1:80".to_string(), - min_delay: 1000, - max_delay: 15000, - max_tarpit_time: 600, - block_threshold: 3, - trap_patterns: vec![ - // Basic attack patterns as plain strings - TrapPattern::as_plain("/vendor/phpunit"), - TrapPattern::as_plain("eval-stdin.php"), - TrapPattern::as_plain("/wp-admin"), - TrapPattern::as_plain("/wp-login.php"), - TrapPattern::as_plain("/xmlrpc.php"), - TrapPattern::as_plain("/phpMyAdmin"), - TrapPattern::as_plain("/solr/"), - TrapPattern::as_plain("/.env"), - TrapPattern::as_plain("/config"), - TrapPattern::as_plain("/actuator/"), - // More aggressive patterns for various PHP exploits - TrapPattern::as_regex(r"/.*phpunit.*eval-stdin\.php"), - TrapPattern::as_regex(r"/index\.php\?s=/index/\\think\\app/invokefunction"), - TrapPattern::as_regex(r".*%ADd\+auto_prepend_file%3dphp://input.*"), - TrapPattern::as_regex(r".*%ADd\+allow_url_include%3d1.*"), - TrapPattern::as_regex(r".*/wp-content/plugins/.*\.php"), - TrapPattern::as_regex(r".*/wp-content/themes/.*\.php"), - TrapPattern::as_regex(r".*eval\(.*\).*"), - TrapPattern::as_regex(r".*/adminer\.php.*"), - TrapPattern::as_regex(r".*/admin\.php.*"), - TrapPattern::as_regex(r".*/administrator/.*"), - TrapPattern::as_regex(r".*/wp-json/.*"), - TrapPattern::as_regex(r".*/api/.*\.php.*"), - TrapPattern::as_regex(r".*/cgi-bin/.*"), - TrapPattern::as_regex(r".*/owa/.*"), - TrapPattern::as_regex(r".*/ecp/.*"), - TrapPattern::as_regex(r".*/webshell\.php.*"), - TrapPattern::as_regex(r".*/shell\.php.*"), - TrapPattern::as_regex(r".*/cmd\.php.*"), - TrapPattern::as_regex(r".*/struts.*"), - ], - whitelist_networks: vec![ - "192.168.0.0/16".to_string(), - "10.0.0.0/8".to_string(), - "172.16.0.0/12".to_string(), - "127.0.0.0/8".to_string(), - ], - markov_corpora_dir: "./corpora".to_string(), - lua_scripts_dir: "./scripts".to_string(), - data_dir: "./data".to_string(), - config_dir: "./conf".to_string(), - cache_dir: "./cache".to_string(), - } - } -} - -// Gets standard XDG directory paths for config, data and cache -fn get_xdg_dirs() -> (PathBuf, PathBuf, PathBuf) { - let config_home = env::var_os("XDG_CONFIG_HOME") - .map(PathBuf::from) - .unwrap_or_else(|| { - let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); - home.join(".config") - }); - - let data_home = env::var_os("XDG_DATA_HOME") - .map(PathBuf::from) - .unwrap_or_else(|| { - let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); - home.join(".local").join("share") - }); - - let cache_home = env::var_os("XDG_CACHE_HOME") - .map(PathBuf::from) - .unwrap_or_else(|| { - let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); - home.join(".cache") - }); - - let config_dir = config_home.join("eris"); - let data_dir = data_home.join("eris"); - let cache_dir = cache_home.join("eris"); - - (config_dir, data_dir, cache_dir) -} - -impl Config { - // Create configuration from command-line args - fn from_args(args: &Args) -> Self { - let (config_dir, data_dir, cache_dir) = if let Some(base_dir) = &args.base_dir { - let base_str = base_dir.to_string_lossy().to_string(); - ( - format!("{base_str}/conf"), - format!("{base_str}/data"), - format!("{base_str}/cache"), - ) - } else { - let (c, d, cache) = get_xdg_dirs(); - ( - c.to_string_lossy().to_string(), - d.to_string_lossy().to_string(), - cache.to_string_lossy().to_string(), - ) - }; - - Self { - listen_addr: args.listen_addr.clone(), - metrics_addr: args.metrics_addr.clone(), - disable_metrics: args.disable_metrics, - backend_addr: args.backend_addr.clone(), - min_delay: args.min_delay, - max_delay: args.max_delay, - max_tarpit_time: args.max_tarpit_time, - block_threshold: args.block_threshold, - markov_corpora_dir: format!("{data_dir}/corpora"), - lua_scripts_dir: format!("{data_dir}/scripts"), - data_dir, - config_dir, - cache_dir, - ..Default::default() - } - } - - // Load configuration from a file (JSON or TOML) - fn load_from_file(path: &Path) -> std::io::Result { - let content = fs::read_to_string(path)?; - - let extension = path - .extension() - .map(|ext| ext.to_string_lossy().to_lowercase()) - .unwrap_or_default(); - - let config = match extension.as_str() { - "toml" => toml::from_str(&content).map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Failed to parse TOML: {e}"), - ) - })?, - _ => { - // Default to JSON for any other extension - serde_json::from_str(&content).map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Failed to parse JSON: {e}"), - ) - })? - } - }; - - Ok(config) - } - - // Save configuration to a file (JSON or TOML) - fn save_to_file(&self, path: &Path) -> std::io::Result<()> { - if let Some(parent) = path.parent() { - fs::create_dir_all(parent)?; - } - - let extension = path - .extension() - .map(|ext| ext.to_string_lossy().to_lowercase()) - .unwrap_or_default(); - - let content = match extension.as_str() { - "toml" => toml::to_string_pretty(self).map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Failed to serialize to TOML: {e}"), - ) - })?, - _ => { - // Default to JSON for any other extension - serde_json::to_string_pretty(self).map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Failed to serialize to JSON: {e}"), - ) - })? - } - }; - - fs::write(path, content)?; - Ok(()) - } - - // Create required directories if they don't exist - fn ensure_dirs_exist(&self) -> std::io::Result<()> { - let dirs = [ - &self.markov_corpora_dir, - &self.lua_scripts_dir, - &self.data_dir, - &self.config_dir, - &self.cache_dir, - ]; - - for dir in dirs { - fs::create_dir_all(dir)?; - log::debug!("Created directory: {dir}"); - } - - Ok(()) - } -} - -// State of bots/IPs hitting the honeypot -#[derive(Clone, Debug)] -struct BotState { - hits: HashMap, - blocked: HashSet, - active_connections: HashSet, - data_dir: String, - cache_dir: String, -} - -impl BotState { - fn new(data_dir: &str, cache_dir: &str) -> Self { - Self { - hits: HashMap::new(), - blocked: HashSet::new(), - active_connections: HashSet::new(), - data_dir: data_dir.to_string(), - cache_dir: cache_dir.to_string(), - } - } - - // Load previous state from disk - fn load_from_disk(data_dir: &str, cache_dir: &str) -> Self { - let mut state = Self::new(data_dir, cache_dir); - let blocked_ips_file = format!("{data_dir}/blocked_ips.txt"); - - if let Ok(content) = fs::read_to_string(&blocked_ips_file) { - let mut loaded = 0; - for line in content.lines() { - if let Ok(ip) = line.parse::() { - state.blocked.insert(ip); - loaded += 1; - } - } - log::info!("Loaded {loaded} blocked IPs from {blocked_ips_file}"); - } else { - log::info!("No blocked IPs file found at {blocked_ips_file}"); - } - - // Check for temporary hit counter cache - let hit_cache_file = format!("{cache_dir}/hit_counters.json"); - if let Ok(content) = fs::read_to_string(&hit_cache_file) { - if let Ok(hit_map) = serde_json::from_str::>(&content) { - for (ip_str, count) in hit_map { - if let Ok(ip) = ip_str.parse::() { - state.hits.insert(ip, count); - } - } - log::info!("Loaded hit counters for {} IPs", state.hits.len()); - } - } - - BLOCKED_IPS.set(state.blocked.len() as f64); - state - } - - // Persist state to disk for later reloading - fn save_to_disk(&self) { - // Save blocked IPs - if let Err(e) = fs::create_dir_all(&self.data_dir) { - log::error!("Failed to create data directory: {e}"); - return; - } - - let blocked_ips_file = format!("{}/blocked_ips.txt", self.data_dir); - - match fs::File::create(&blocked_ips_file) { - Ok(mut file) => { - let mut count = 0; - for ip in &self.blocked { - if writeln!(file, "{ip}").is_ok() { - count += 1; - } - } - log::info!("Saved {count} blocked IPs to {blocked_ips_file}"); - } - Err(e) => { - log::error!("Failed to create blocked IPs file: {e}"); - } - } - - // Save hit counters to cache - if let Err(e) = fs::create_dir_all(&self.cache_dir) { - log::error!("Failed to create cache directory: {e}"); - return; - } - - let hit_cache_file = format!("{}/hit_counters.json", self.cache_dir); - let hit_map: HashMap = self - .hits - .iter() - .map(|(ip, count)| (ip.to_string(), *count)) - .collect(); - - match fs::File::create(&hit_cache_file) { - Ok(file) => { - if let Err(e) = serde_json::to_writer(file, &hit_map) { - log::error!("Failed to write hit counters to cache: {e}"); - } else { - log::debug!("Saved hit counters for {} IPs to cache", hit_map.len()); - } - } - Err(e) => { - log::error!("Failed to create hit counter cache file: {e}"); - } - } - } -} - -// Find end of HTTP headers -fn find_header_end(data: &[u8]) -> Option { - data.windows(4) - .position(|window| window == b"\r\n\r\n") - .map(|pos| pos + 4) -} - -// Extract path from raw request data -fn extract_path_from_request(data: &[u8]) -> Option<&str> { - // Get first line from request - let first_line = data - .split(|&b| b == b'\r' || b == b'\n') - .next() - .filter(|line| !line.is_empty())?; - - // Split by spaces and ensure we have at least 3 parts (METHOD PATH VERSION) - let parts: Vec<&[u8]> = first_line.split(|&b| b == b' ').collect(); - if parts.len() < 3 || !parts[2].starts_with(b"HTTP/") { - return None; - } - - // Return the path (second element) - std::str::from_utf8(parts[1]).ok() -} - -// Extract header value from raw request data -fn extract_header_value(data: &[u8], header_name: &str) -> Option { - let data_str = std::str::from_utf8(data).ok()?; - let header_prefix = format!("{header_name}: ").to_lowercase(); - - for line in data_str.lines() { - let line_lower = line.to_lowercase(); - if line_lower.starts_with(&header_prefix) { - return Some(line[header_prefix.len()..].trim().to_string()); - } - } - None -} - -// Extract all headers from request data -fn extract_all_headers(data: &[u8]) -> HashMap { - let mut headers = HashMap::new(); - - if let Ok(data_str) = std::str::from_utf8(data) { - let mut lines = data_str.lines(); - - // Skip the request line - let _ = lines.next(); - - // Parse headers until empty line - for line in lines { - if line.is_empty() { - break; - } - - if let Some(colon_pos) = line.find(':') { - let key = line[..colon_pos].trim().to_lowercase(); - let value = line[colon_pos + 1..].trim().to_string(); - headers.insert(key, value); - } - } - } - - headers -} - -// Determine response type based on request path -fn choose_response_type(path: &str) -> &'static str { - if path.contains("phpunit") || path.contains("eval") { - "php_exploit" - } else if path.contains("wp-") { - "wordpress" - } else if path.contains("api") { - "api" - } else { - "generic" - } -} - -// Helper function to get current timestamp in seconds -fn get_timestamp() -> u64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs() -} - -// Create a unique session ID for tracking a connection -fn generate_session_id(ip: &str, user_agent: &str) -> String { - let timestamp = get_timestamp(); - let random = rand::random::(); - - // Use std::hash instead of xxhash_rust - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - std::hash::Hash::hash(&format!("{ip}_{user_agent}_{timestamp}"), &mut hasher); - let hash = hasher.finish(); - - format!("SID_{hash:x}_{random:x}") -} - -// Main connection handler. -// Decides whether to tarpit or proxy -async fn handle_connection( - mut stream: TcpStream, - config: Arc, - state: Arc>, - markov_generator: Arc, - script_manager: Arc, -) { - // Get peer information - let peer_addr = match stream.peer_addr() { - Ok(addr) => addr.ip(), - Err(e) => { - log::debug!("Failed to get peer address: {e}"); - return; - } - }; - - // Check for blocked IPs to avoid any processing - if state.read().await.blocked.contains(&peer_addr) { - log::debug!("Rejected connection from blocked IP: {peer_addr}"); - let _ = stream.shutdown().await; - return; - } - - // Check if Lua scripts allow this connection - if !script_manager.on_connection(&peer_addr.to_string()) { - log::debug!("Connection rejected by Lua script: {peer_addr}"); - let _ = stream.shutdown().await; - return; - } - - // Pre-check for whitelisted IPs to bypass heavy processing - let mut whitelisted = false; - for network_str in &config.whitelist_networks { - if let Ok(network) = network_str.parse::() { - if network.contains(peer_addr) { - whitelisted = true; - break; - } - } - } - - // Read buffer - let mut buffer = vec![0; 8192]; - let mut request_data = Vec::with_capacity(8192); - let mut header_end_pos = 0; - - // Read with timeout to prevent hanging resource load ops. - let read_fut = async { - loop { - match stream.read(&mut buffer).await { - Ok(0) => break, - Ok(n) => { - let new_data = &buffer[..n]; - request_data.extend_from_slice(new_data); - - // Look for end of headers - if header_end_pos == 0 { - if let Some(pos) = find_header_end(&request_data) { - header_end_pos = pos; - break; - } - } - - // Avoid excessive buffering - if request_data.len() > 32768 { - break; - } - } - Err(e) => { - log::debug!("Error reading from stream: {e}"); - break; - } - } - } - }; - - let timeout_fut = sleep(Duration::from_secs(3)); - - tokio::select! { - () = read_fut => {}, - () = timeout_fut => { - log::debug!("Connection timeout from: {peer_addr}"); - let _ = stream.shutdown().await; - return; - } - } - - // Fast path for whitelisted IPs. Skip full parsing and speed up "approved" - // connections automatically. - if whitelisted { - log::debug!("Whitelisted IP {peer_addr} - using fast proxy path"); - proxy_fast_path(stream, request_data, &config.backend_addr).await; - return; - } - - // Parse minimally to extract the path - let path = if let Some(p) = extract_path_from_request(&request_data) { - p - } else { - log::debug!("Invalid request from {peer_addr}"); - let _ = stream.shutdown().await; - return; - }; - - // Extract request headers for Lua scripts - let headers = extract_all_headers(&request_data); - - // Extract user agent for logging and decision making - let user_agent = - extract_header_value(&request_data, "user-agent").unwrap_or_else(|| "unknown".to_string()); - - // Trigger request event for Lua scripts - let request_ctx = EventContext { - event_type: EventType::Request, - ip: Some(peer_addr.to_string()), - path: Some(path.to_string()), - user_agent: Some(user_agent.clone()), - request_headers: Some(headers.clone()), - content: None, - timestamp: get_timestamp(), - session_id: Some(generate_session_id(&peer_addr.to_string(), &user_agent)), - }; - script_manager.trigger_event(&request_ctx); - - // Check if this request matches our tarpit patterns - let should_tarpit = should_tarpit(path, &peer_addr, &config).await; - - if should_tarpit { - log::info!("Tarpit triggered: {path} from {peer_addr} (UA: {user_agent})"); - - // Update metrics - HITS_COUNTER.inc(); - PATH_HITS.with_label_values(&[path]).inc(); - UA_HITS.with_label_values(&[&user_agent]).inc(); - - // Update state and check for blocking threshold - { - let mut state = state.write().await; - state.active_connections.insert(peer_addr); - ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64); - - *state.hits.entry(peer_addr).or_insert(0) += 1; - let hit_count = state.hits[&peer_addr]; - - // Use Lua to decide whether to block this IP - let should_block = script_manager.should_block_ip(&peer_addr.to_string(), hit_count); - - // Block IPs that hit tarpits too many times - if should_block && !state.blocked.contains(&peer_addr) { - log::info!("Blocking IP {peer_addr} after {hit_count} hits"); - state.blocked.insert(peer_addr); - BLOCKED_IPS.set(state.blocked.len() as f64); - state.save_to_disk(); - - // Do firewall blocking in background - let peer_addr_str = peer_addr.to_string(); - tokio::spawn(async move { - log::debug!("Adding IP {peer_addr_str} to firewall blacklist"); - match Command::new("nft") - .args([ - "add", - "element", - "inet", - "filter", - "eris_blacklist", - "{", - &peer_addr_str, - "}", - ]) - .output() - .await - { - Ok(output) => { - if !output.status.success() { - log::warn!( - "Failed to add IP {} to firewall: {}", - peer_addr_str, - String::from_utf8_lossy(&output.stderr) - ); - } - } - Err(e) => { - log::warn!("Failed to execute nft command: {e}"); - } - } - }); - } - } - - // Generate a deceptive response using Markov chains and Lua - let response = generate_deceptive_response( - path, - &user_agent, - &peer_addr, - &headers, - &markov_generator, - &script_manager, - ) - .await; - - // Generate a session ID for tracking this tarpit session - let session_id = generate_session_id(&peer_addr.to_string(), &user_agent); - - // Send the response with the tarpit delay strategy - { - let mut stream = stream; - let peer_addr = peer_addr; - let state = state.clone(); - let min_delay = config.min_delay; - let max_delay = config.max_delay; - let max_tarpit_time = config.max_tarpit_time; - let script_manager = script_manager.clone(); - async move { - let start_time = Instant::now(); - let mut chars = response.chars().collect::>(); - for i in (0..chars.len()).rev() { - if i > 0 && rand::random::() < 0.1 { - chars.swap(i, i - 1); - } - } - log::debug!( - "Starting tarpit for {} with {} chars, min_delay={}ms, max_delay={}ms", - peer_addr, - chars.len(), - min_delay, - max_delay - ); - let mut position = 0; - let mut chunks_sent = 0; - let mut total_delay = 0; - while position < chars.len() { - // Check if we've exceeded maximum tarpit time - let elapsed_secs = start_time.elapsed().as_secs(); - if elapsed_secs > max_tarpit_time { - log::info!( - "Tarpit maximum time ({max_tarpit_time} sec) reached for {peer_addr}" - ); - break; - } - - // Decide how many chars to send in this chunk (usually 1, sometimes more) - let chunk_size = if rand::random::() < 0.9 { - 1 - } else { - (rand::random::() * 3.0).floor() as usize + 1 - }; - - let end = (position + chunk_size).min(chars.len()); - let chunk: String = chars[position..end].iter().collect(); - - // Process chunk through Lua before sending - let processed_chunk = - script_manager.process_chunk(&chunk, &peer_addr.to_string(), &session_id); - - // Try to write processed chunk - if stream.write_all(processed_chunk.as_bytes()).await.is_err() { - log::debug!("Connection closed by client during tarpit: {peer_addr}"); - break; - } - - if stream.flush().await.is_err() { - log::debug!("Failed to flush stream during tarpit: {peer_addr}"); - break; - } - - position = end; - chunks_sent += 1; - - // Apply random delay between min and max configured values - let delay_ms = - (rand::random::() * (max_delay - min_delay) as f32) as u64 + min_delay; - total_delay += delay_ms; - sleep(Duration::from_millis(delay_ms)).await; - } - log::debug!( - "Tarpit stats for {}: sent {} chunks, {}% of data, total delay {}ms over {}s", - peer_addr, - chunks_sent, - position * 100 / chars.len(), - total_delay, - start_time.elapsed().as_secs() - ); - let disconnection_ctx = EventContext { - event_type: EventType::Disconnection, - ip: Some(peer_addr.to_string()), - path: None, - user_agent: None, - request_headers: None, - content: None, - timestamp: get_timestamp(), - session_id: Some(session_id), - }; - script_manager.trigger_event(&disconnection_ctx); - if let Ok(mut state) = state.try_write() { - state.active_connections.remove(&peer_addr); - ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64); - } - let _ = stream.shutdown().await; - } - } - .await; - } else { - log::debug!("Proxying request: {path} from {peer_addr}"); - proxy_fast_path(stream, request_data, &config.backend_addr).await; - } -} - -// Forward a legitimate request to the real backend server -async fn proxy_fast_path(mut client_stream: TcpStream, request_data: Vec, backend_addr: &str) { - // Connect to backend server - let server_stream = match TcpStream::connect(backend_addr).await { - Ok(stream) => stream, - Err(e) => { - log::warn!("Failed to connect to backend {backend_addr}: {e}"); - let _ = client_stream.shutdown().await; - return; - } - }; - - // Set TCP_NODELAY for both streams before splitting them - if let Err(e) = client_stream.set_nodelay(true) { - log::debug!("Failed to set TCP_NODELAY on client stream: {e}"); - } - - let mut server_stream = server_stream; - if let Err(e) = server_stream.set_nodelay(true) { - log::debug!("Failed to set TCP_NODELAY on server stream: {e}"); - } - - // Forward the original request bytes directly without parsing - if server_stream.write_all(&request_data).await.is_err() { - log::debug!("Failed to write request to backend server"); - let _ = client_stream.shutdown().await; - return; - } - - // Now split the streams for concurrent reading/writing - let (mut client_read, mut client_write) = client_stream.split(); - let (mut server_read, mut server_write) = server_stream.split(); - - // 32KB buffer - let buf_size = 32768; - - // Client -> Server - let client_to_server = async { - let mut buf = vec![0; buf_size]; - let mut bytes_forwarded = 0; - - loop { - match client_read.read(&mut buf).await { - Ok(0) => break, - Ok(n) => { - bytes_forwarded += n; - if server_write.write_all(&buf[..n]).await.is_err() { - break; - } - } - Err(_) => break, - } - } - - // Ensure everything is sent - let _ = server_write.flush().await; - log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes"); - }; - - // Server -> Client - let server_to_client = async { - let mut buf = vec![0; buf_size]; - let mut bytes_forwarded = 0; - - loop { - match server_read.read(&mut buf).await { - Ok(0) => break, - Ok(n) => { - bytes_forwarded += n; - if client_write.write_all(&buf[..n]).await.is_err() { - break; - } - } - Err(_) => break, - } - } - - // Ensure everything is sent - let _ = client_write.flush().await; - log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes"); - }; - - // Run both directions concurrently - tokio::join!(client_to_server, server_to_client); - log::debug!("Fast proxy connection completed"); -} - -// Decide if a request should be tarpitted based on path and IP -async fn should_tarpit(path: &str, ip: &IpAddr, config: &Config) -> bool { - // Check whitelist IPs first to avoid unnecessary pattern matching - for network_str in &config.whitelist_networks { - if let Ok(network) = network_str.parse::() { - if network.contains(*ip) { - return false; - } - } - } - - // Use pattern matching based on the trap pattern type (plain string or regex) - for pattern in &config.trap_patterns { - if pattern.matches(path) { - return true; - } - } - - false -} - -// Generate a deceptive HTTP response that appears legitimate -async fn generate_deceptive_response( - path: &str, - user_agent: &str, - peer_addr: &IpAddr, - headers: &HashMap, - markov: &MarkovGenerator, - script_manager: &ScriptManager, -) -> String { - // Generate base response using Markov chain text generator - let response_type = choose_response_type(path); - let markov_text = markov.generate(response_type, 30); - - // Use Lua scripts to enhance with honeytokens and other deceptive content - script_manager.generate_response( - path, - user_agent, - &peer_addr.to_string(), - headers, - &markov_text, - ) -} - -// Set up nftables firewall rules for IP blocking -async fn setup_firewall() -> Result<(), String> { - log::info!("Setting up firewall rules"); - - // Check if nft command exists - let nft_exists = Command::new("which") - .arg("nft") - .output() - .await - .map(|output| output.status.success()) - .unwrap_or(false); - - if !nft_exists { - log::warn!("nft command not found. Firewall rules will not be set up."); - return Ok(()); - } - - // Create table if it doesn't exist - let output = Command::new("nft") - .args(["list", "table", "inet", "filter"]) - .output() - .await; - - match output { - Ok(output) => { - if !output.status.success() { - log::info!("Creating nftables table"); - let result = Command::new("nft") - .args(["create", "table", "inet", "filter"]) - .output() - .await; - - if let Err(e) = result { - return Err(format!("Failed to create nftables table: {e}")); - } - } - } - Err(e) => { - log::warn!("Failed to check if nftables table exists: {e}"); - log::info!("Will try to create it anyway"); - let result = Command::new("nft") - .args(["create", "table", "inet", "filter"]) - .output() - .await; - - if let Err(e) = result { - return Err(format!("Failed to create nftables table: {e}")); - } - } - } - - // Create blacklist set if it doesn't exist - let output = Command::new("nft") - .args(["list", "set", "inet", "filter", "eris_blacklist"]) - .output() - .await; - - match output { - Ok(output) => { - if !output.status.success() { - log::info!("Creating eris_blacklist set"); - let result = Command::new("nft") - .args([ - "create", - "set", - "inet", - "filter", - "eris_blacklist", - "{ type ipv4_addr; flags interval; }", - ]) - .output() - .await; - - if let Err(e) = result { - return Err(format!("Failed to create blacklist set: {e}")); - } - } - } - Err(e) => { - log::warn!("Failed to check if blacklist set exists: {e}"); - return Err(format!("Failed to check if blacklist set exists: {e}")); - } - } - - // Add rule to drop traffic from blacklisted IPs - let output = Command::new("nft") - .args(["list", "chain", "inet", "filter", "input"]) - .output() - .await; - - // Check if our rule already exists - match output { - Ok(output) => { - let rule_exists = String::from_utf8_lossy(&output.stdout) - .contains("ip saddr @eris_blacklist counter drop"); - - if !rule_exists { - log::info!("Adding drop rule for blacklisted IPs"); - let result = Command::new("nft") - .args([ - "add", - "rule", - "inet", - "filter", - "input", - "ip saddr @eris_blacklist", - "counter", - "drop", - ]) - .output() - .await; - - if let Err(e) = result { - return Err(format!("Failed to add firewall rule: {e}")); - } - } - } - Err(e) => { - log::warn!("Failed to check if firewall rule exists: {e}"); - return Err(format!("Failed to check if firewall rule exists: {e}")); - } - } - - log::info!("Firewall setup complete"); - Ok(()) -} +use metrics::{metrics_handler, status_handler}; +use network::handle_connection; +use state::BotState; +use utils::get_timestamp; #[actix_web::main] async fn main() -> std::io::Result<()> { @@ -1228,7 +101,7 @@ async fn main() -> std::io::Result<()> { let config = Arc::new(config); // Setup firewall rules for IP blocking - match setup_firewall().await { + match firewall::setup_firewall().await { Ok(()) => {} Err(e) => { log::warn!("Failed to set up firewall rules: {e}"); @@ -1394,270 +267,3 @@ async fn main() -> std::io::Result<()> { } } } - -#[cfg(test)] -mod tests { - use super::*; - use std::net::{IpAddr, Ipv4Addr}; - use tokio::sync::RwLock; - - #[test] - fn test_config_from_args() { - let args = Args { - listen_addr: "127.0.0.1:8080".to_string(), - metrics_addr: "127.0.0.1:9000".to_string(), - disable_metrics: true, - backend_addr: "127.0.0.1:8081".to_string(), - min_delay: 500, - max_delay: 10000, - max_tarpit_time: 300, - block_threshold: 5, - base_dir: Some(PathBuf::from("/tmp/eris")), - config_file: None, - log_level: "debug".to_string(), - }; - - let config = Config::from_args(&args); - assert_eq!(config.listen_addr, "127.0.0.1:8080"); - assert_eq!(config.metrics_addr, "127.0.0.1:9000"); - assert!(config.disable_metrics); - assert_eq!(config.backend_addr, "127.0.0.1:8081"); - assert_eq!(config.min_delay, 500); - assert_eq!(config.max_delay, 10000); - assert_eq!(config.max_tarpit_time, 300); - assert_eq!(config.block_threshold, 5); - assert_eq!(config.markov_corpora_dir, "/tmp/eris/data/corpora"); - assert_eq!(config.lua_scripts_dir, "/tmp/eris/data/scripts"); - assert_eq!(config.data_dir, "/tmp/eris/data"); - assert_eq!(config.config_dir, "/tmp/eris/conf"); - assert_eq!(config.cache_dir, "/tmp/eris/cache"); - } - - #[test] - fn test_trap_pattern_matching() { - // Test plain string pattern - let plain = TrapPattern::as_plain("phpunit"); - assert!(plain.matches("path/to/phpunit/test")); - assert!(!plain.matches("path/to/something/else")); - - // Test regex pattern - let regex = TrapPattern::as_regex(r".*eval-stdin\.php.*"); - assert!(regex.matches("/vendor/phpunit/phpunit/src/Util/PHP/eval-stdin.php")); - assert!(regex.matches("/tests/eval-stdin.php?param")); - assert!(!regex.matches("/normal/path")); - - // Test invalid regex pattern (should return false) - let invalid = TrapPattern::Regex { - pattern: "(invalid[regex".to_string(), - regex: true, - }; - assert!(!invalid.matches("anything")); - } - - #[tokio::test] - async fn test_should_tarpit() { - let config = Config::default(); - - // Test trap patterns - assert!( - should_tarpit( - "/vendor/phpunit/whatever", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - ) - .await - ); - assert!( - should_tarpit( - "/wp-admin/login.php", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - ) - .await - ); - assert!(should_tarpit("/.env", &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), &config).await); - - // Test whitelist networks - assert!( - !should_tarpit( - "/wp-admin/login.php", - &IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - &config - ) - .await - ); - assert!( - !should_tarpit( - "/vendor/phpunit/whatever", - &IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), - &config - ) - .await - ); - - // Test legitimate paths - assert!( - !should_tarpit( - "/index.html", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - ) - .await - ); - assert!( - !should_tarpit( - "/images/logo.png", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - ) - .await - ); - - // Test regex patterns - assert!( - should_tarpit( - "/index.php?s=/index/\\think\\app/invokefunction&function=call_user_func_array&vars[0]=md5&vars[1][]=Hello", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - ) - .await - ); - - assert!( - should_tarpit( - "/hello.world?%ADd+allow_url_include%3d1+%ADd+auto_prepend_file%3dphp://input", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - ) - .await - ); - } - - #[tokio::test] - async fn test_bot_state() { - let state = BotState::new("/tmp/eris_test", "/tmp/eris_test_cache"); - let ip1 = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)); - let ip2 = IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8)); - - let state = Arc::new(RwLock::new(state)); - - // Test hit counter - { - let mut state = state.write().await; - *state.hits.entry(ip1).or_insert(0) += 1; - *state.hits.entry(ip1).or_insert(0) += 1; - *state.hits.entry(ip2).or_insert(0) += 1; - - assert_eq!(*state.hits.get(&ip1).unwrap(), 2); - assert_eq!(*state.hits.get(&ip2).unwrap(), 1); - } - - // Test blocking - { - let mut state = state.write().await; - state.blocked.insert(ip1); - assert!(state.blocked.contains(&ip1)); - assert!(!state.blocked.contains(&ip2)); - drop(state); - } - - // Test active connections - { - let mut state = state.write().await; - state.active_connections.insert(ip1); - state.active_connections.insert(ip2); - assert_eq!(state.active_connections.len(), 2); - - state.active_connections.remove(&ip1); - assert_eq!(state.active_connections.len(), 1); - assert!(!state.active_connections.contains(&ip1)); - assert!(state.active_connections.contains(&ip2)); - drop(state); - } - } - - #[test] - fn test_find_header_end() { - let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: test\r\n\r\nBody content"; - assert_eq!(find_header_end(data), Some(55)); - - let incomplete = b"GET / HTTP/1.1\r\nHost: example.com\r\n"; - assert_eq!(find_header_end(incomplete), None); - } - - #[test] - fn test_extract_path_from_request() { - let data = b"GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n"; - assert_eq!(extract_path_from_request(data), Some("/index.html")); - - let bad_data = b"INVALID DATA"; - assert_eq!(extract_path_from_request(bad_data), None); - } - - #[test] - fn test_extract_header_value() { - let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\n\r\n"; - assert_eq!( - extract_header_value(data, "user-agent"), - Some("TestBot/1.0".to_string()) - ); - assert_eq!( - extract_header_value(data, "Host"), - Some("example.com".to_string()) - ); - assert_eq!(extract_header_value(data, "nonexistent"), None); - } - - #[test] - fn test_extract_all_headers() { - let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\nAccept: */*\r\n\r\n"; - let headers = extract_all_headers(data); - - assert_eq!(headers.len(), 3); - assert_eq!(headers.get("host").unwrap(), "example.com"); - assert_eq!(headers.get("user-agent").unwrap(), "TestBot/1.0"); - assert_eq!(headers.get("accept").unwrap(), "*/*"); - } - - #[test] - fn test_choose_response_type() { - assert_eq!( - choose_response_type("/vendor/phpunit/whatever"), - "php_exploit" - ); - assert_eq!( - choose_response_type("/path/to/eval-stdin.php"), - "php_exploit" - ); - assert_eq!(choose_response_type("/wp-admin/login.php"), "wordpress"); - assert_eq!(choose_response_type("/wp-login.php"), "wordpress"); - assert_eq!(choose_response_type("/api/v1/users"), "api"); - assert_eq!(choose_response_type("/index.html"), "generic"); - } - - #[test] - fn test_config_file_formats() { - // Create temporary JSON config file - let temp_dir = std::env::temp_dir(); - let json_path = temp_dir.join("temp_config.json"); - let toml_path = temp_dir.join("temp_config.toml"); - - let config = Config::default(); - - // Test JSON serialization and deserialization - config.save_to_file(&json_path).unwrap(); - let loaded_json = Config::load_from_file(&json_path).unwrap(); - assert_eq!(loaded_json.listen_addr, config.listen_addr); - assert_eq!(loaded_json.min_delay, config.min_delay); - - // Test TOML serialization and deserialization - config.save_to_file(&toml_path).unwrap(); - let loaded_toml = Config::load_from_file(&toml_path).unwrap(); - assert_eq!(loaded_toml.listen_addr, config.listen_addr); - assert_eq!(loaded_toml.min_delay, config.min_delay); - - // Clean up - let _ = std::fs::remove_file(json_path); - let _ = std::fs::remove_file(toml_path); - } -} diff --git a/src/network.rs b/src/network.rs new file mode 100644 index 0000000..56d6c65 --- /dev/null +++ b/src/network.rs @@ -0,0 +1,443 @@ +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio::process::Command; +use tokio::sync::RwLock; +use tokio::time::sleep; + +use crate::config::Config; +use crate::lua::{EventContext, EventType, ScriptManager}; +use crate::markov::MarkovGenerator; +use crate::metrics::{ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, UA_HITS}; +use crate::state::BotState; +use crate::utils::{ + choose_response_type, extract_all_headers, extract_header_value, extract_path_from_request, + find_header_end, generate_session_id, get_timestamp, +}; + +// Main connection handler. +// Decides whether to tarpit or proxy +pub async fn handle_connection( + mut stream: TcpStream, + config: Arc, + state: Arc>, + markov_generator: Arc, + script_manager: Arc, +) { + // Get peer information + let peer_addr = match stream.peer_addr() { + Ok(addr) => addr.ip(), + Err(e) => { + log::debug!("Failed to get peer address: {e}"); + return; + } + }; + + // Check for blocked IPs to avoid any processing + if state.read().await.blocked.contains(&peer_addr) { + log::debug!("Rejected connection from blocked IP: {peer_addr}"); + let _ = stream.shutdown().await; + return; + } + + // Check if Lua scripts allow this connection + if !script_manager.on_connection(&peer_addr.to_string()) { + log::debug!("Connection rejected by Lua script: {peer_addr}"); + let _ = stream.shutdown().await; + return; + } + + // Pre-check for whitelisted IPs to bypass heavy processing + let mut whitelisted = false; + for network_str in &config.whitelist_networks { + if let Ok(network) = network_str.parse::() { + if network.contains(peer_addr) { + whitelisted = true; + break; + } + } + } + + // Read buffer + let mut buffer = vec![0; 8192]; + let mut request_data = Vec::with_capacity(8192); + let mut header_end_pos = 0; + + // Read with timeout to prevent hanging resource load ops. + let read_fut = async { + loop { + match stream.read(&mut buffer).await { + Ok(0) => break, + Ok(n) => { + let new_data = &buffer[..n]; + request_data.extend_from_slice(new_data); + + // Look for end of headers + if header_end_pos == 0 { + if let Some(pos) = find_header_end(&request_data) { + header_end_pos = pos; + break; + } + } + + // Avoid excessive buffering + if request_data.len() > 32768 { + break; + } + } + Err(e) => { + log::debug!("Error reading from stream: {e}"); + break; + } + } + } + }; + + let timeout_fut = sleep(Duration::from_secs(3)); + + tokio::select! { + () = read_fut => {}, + () = timeout_fut => { + log::debug!("Connection timeout from: {peer_addr}"); + let _ = stream.shutdown().await; + return; + } + } + + // Fast path for whitelisted IPs. Skip full parsing and speed up "approved" + // connections automatically. + if whitelisted { + log::debug!("Whitelisted IP {peer_addr} - using fast proxy path"); + proxy_fast_path(stream, request_data, &config.backend_addr).await; + return; + } + + // Parse minimally to extract the path + let path = if let Some(p) = extract_path_from_request(&request_data) { + p + } else { + log::debug!("Invalid request from {peer_addr}"); + let _ = stream.shutdown().await; + return; + }; + + // Extract request headers for Lua scripts + let headers = extract_all_headers(&request_data); + + // Extract user agent for logging and decision making + let user_agent = + extract_header_value(&request_data, "user-agent").unwrap_or_else(|| "unknown".to_string()); + + // Trigger request event for Lua scripts + let request_ctx = EventContext { + event_type: EventType::Request, + ip: Some(peer_addr.to_string()), + path: Some(path.to_string()), + user_agent: Some(user_agent.clone()), + request_headers: Some(headers.clone()), + content: None, + timestamp: get_timestamp(), + session_id: Some(generate_session_id(&peer_addr.to_string(), &user_agent)), + }; + script_manager.trigger_event(&request_ctx); + + // Check if this request matches our tarpit patterns + let should_tarpit = crate::config::should_tarpit(path, &peer_addr, &config); + + if should_tarpit { + log::info!("Tarpit triggered: {path} from {peer_addr} (UA: {user_agent})"); + + // Update metrics + HITS_COUNTER.inc(); + PATH_HITS.with_label_values(&[path]).inc(); + UA_HITS.with_label_values(&[&user_agent]).inc(); + + // Update state and check for blocking threshold + { + let mut state = state.write().await; + state.active_connections.insert(peer_addr); + ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64); + + *state.hits.entry(peer_addr).or_insert(0) += 1; + let hit_count = state.hits[&peer_addr]; + + // Use Lua to decide whether to block this IP + let should_block = script_manager.should_block_ip(&peer_addr.to_string(), hit_count); + + // Block IPs that hit tarpits too many times + if should_block && !state.blocked.contains(&peer_addr) { + log::info!("Blocking IP {peer_addr} after {hit_count} hits"); + state.blocked.insert(peer_addr); + BLOCKED_IPS.set(state.blocked.len() as f64); + state.save_to_disk(); + + // Do firewall blocking in background + let peer_addr_str = peer_addr.to_string(); + tokio::spawn(async move { + log::debug!("Adding IP {peer_addr_str} to firewall blacklist"); + match Command::new("nft") + .args([ + "add", + "element", + "inet", + "filter", + "eris_blacklist", + "{", + &peer_addr_str, + "}", + ]) + .output() + .await + { + Ok(output) => { + if !output.status.success() { + log::warn!( + "Failed to add IP {} to firewall: {}", + peer_addr_str, + String::from_utf8_lossy(&output.stderr) + ); + } + } + Err(e) => { + log::warn!("Failed to execute nft command: {e}"); + } + } + }); + } + } + + // Generate a deceptive response using Markov chains and Lua + let response = generate_deceptive_response( + path, + &user_agent, + &peer_addr, + &headers, + &markov_generator, + &script_manager, + ) + .await; + + // Generate a session ID for tracking this tarpit session + let session_id = generate_session_id(&peer_addr.to_string(), &user_agent); + + // Send the response with the tarpit delay strategy + { + let mut stream = stream; + let peer_addr = peer_addr; + let state = state.clone(); + let min_delay = config.min_delay; + let max_delay = config.max_delay; + let max_tarpit_time = config.max_tarpit_time; + let script_manager = script_manager.clone(); + async move { + let start_time = Instant::now(); + let mut chars = response.chars().collect::>(); + for i in (0..chars.len()).rev() { + if i > 0 && rand::random::() < 0.1 { + chars.swap(i, i - 1); + } + } + log::debug!( + "Starting tarpit for {} with {} chars, min_delay={}ms, max_delay={}ms", + peer_addr, + chars.len(), + min_delay, + max_delay + ); + let mut position = 0; + let mut chunks_sent = 0; + let mut total_delay = 0; + while position < chars.len() { + // Check if we've exceeded maximum tarpit time + let elapsed_secs = start_time.elapsed().as_secs(); + if elapsed_secs > max_tarpit_time { + log::info!( + "Tarpit maximum time ({max_tarpit_time} sec) reached for {peer_addr}" + ); + break; + } + + // Decide how many chars to send in this chunk (usually 1, sometimes more) + let chunk_size = if rand::random::() < 0.9 { + 1 + } else { + (rand::random::() * 3.0).floor() as usize + 1 + }; + + let end = (position + chunk_size).min(chars.len()); + let chunk: String = chars[position..end].iter().collect(); + + // Process chunk through Lua before sending + let processed_chunk = + script_manager.process_chunk(&chunk, &peer_addr.to_string(), &session_id); + + // Try to write processed chunk + if stream.write_all(processed_chunk.as_bytes()).await.is_err() { + log::debug!("Connection closed by client during tarpit: {peer_addr}"); + break; + } + + if stream.flush().await.is_err() { + log::debug!("Failed to flush stream during tarpit: {peer_addr}"); + break; + } + + position = end; + chunks_sent += 1; + + // Apply random delay between min and max configured values + let delay_ms = + (rand::random::() * (max_delay - min_delay) as f32) as u64 + min_delay; + total_delay += delay_ms; + sleep(Duration::from_millis(delay_ms)).await; + } + log::debug!( + "Tarpit stats for {}: sent {} chunks, {}% of data, total delay {}ms over {}s", + peer_addr, + chunks_sent, + position * 100 / chars.len(), + total_delay, + start_time.elapsed().as_secs() + ); + let disconnection_ctx = EventContext { + event_type: EventType::Disconnection, + ip: Some(peer_addr.to_string()), + path: None, + user_agent: None, + request_headers: None, + content: None, + timestamp: get_timestamp(), + session_id: Some(session_id), + }; + script_manager.trigger_event(&disconnection_ctx); + if let Ok(mut state) = state.try_write() { + state.active_connections.remove(&peer_addr); + ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64); + } + let _ = stream.shutdown().await; + } + } + .await; + } else { + log::debug!("Proxying request: {path} from {peer_addr}"); + proxy_fast_path(stream, request_data, &config.backend_addr).await; + } +} + +// Forward a legitimate request to the real backend server +pub async fn proxy_fast_path( + mut client_stream: TcpStream, + request_data: Vec, + backend_addr: &str, +) { + // Connect to backend server + let server_stream = match TcpStream::connect(backend_addr).await { + Ok(stream) => stream, + Err(e) => { + log::warn!("Failed to connect to backend {backend_addr}: {e}"); + let _ = client_stream.shutdown().await; + return; + } + }; + + // Set TCP_NODELAY for both streams before splitting them + if let Err(e) = client_stream.set_nodelay(true) { + log::debug!("Failed to set TCP_NODELAY on client stream: {e}"); + } + + let mut server_stream = server_stream; + if let Err(e) = server_stream.set_nodelay(true) { + log::debug!("Failed to set TCP_NODELAY on server stream: {e}"); + } + + // Forward the original request bytes directly without parsing + if server_stream.write_all(&request_data).await.is_err() { + log::debug!("Failed to write request to backend server"); + let _ = client_stream.shutdown().await; + return; + } + + // Now split the streams for concurrent reading/writing + let (mut client_read, mut client_write) = client_stream.split(); + let (mut server_read, mut server_write) = server_stream.split(); + + // 32KB buffer + let buf_size = 32768; + + // Client -> Server + let client_to_server = async { + let mut buf = vec![0; buf_size]; + let mut bytes_forwarded = 0; + + loop { + match client_read.read(&mut buf).await { + Ok(0) => break, + Ok(n) => { + bytes_forwarded += n; + if server_write.write_all(&buf[..n]).await.is_err() { + break; + } + } + Err(_) => break, + } + } + + // Ensure everything is sent + let _ = server_write.flush().await; + log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes"); + }; + + // Server -> Client + let server_to_client = async { + let mut buf = vec![0; buf_size]; + let mut bytes_forwarded = 0; + + loop { + match server_read.read(&mut buf).await { + Ok(0) => break, + Ok(n) => { + bytes_forwarded += n; + if client_write.write_all(&buf[..n]).await.is_err() { + break; + } + } + Err(_) => break, + } + } + + // Ensure everything is sent + let _ = client_write.flush().await; + log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes"); + }; + + // Run both directions concurrently + tokio::join!(client_to_server, server_to_client); + log::debug!("Fast proxy connection completed"); +} + +// Generate a deceptive HTTP response that appears legitimate +pub async fn generate_deceptive_response( + path: &str, + user_agent: &str, + peer_addr: &IpAddr, + headers: &HashMap, + markov: &MarkovGenerator, + script_manager: &ScriptManager, +) -> String { + // Generate base response using Markov chain text generator + let response_type = choose_response_type(path); + let markov_text = markov.generate(response_type, 30); + + // Use Lua scripts to enhance with honeytokens and other deceptive content + script_manager.generate_response( + path, + user_agent, + &peer_addr.to_string(), + headers, + &markov_text, + ) +} diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..a391401 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,164 @@ +use std::collections::{HashMap, HashSet}; +use std::fs; +use std::io::Write; +use std::net::IpAddr; + +use crate::metrics::BLOCKED_IPS; + +// State of bots/IPs hitting the honeypot +#[derive(Clone, Debug)] +pub struct BotState { + pub hits: HashMap, + pub blocked: HashSet, + pub active_connections: HashSet, + pub data_dir: String, + pub cache_dir: String, +} + +impl BotState { + pub fn new(data_dir: &str, cache_dir: &str) -> Self { + Self { + hits: HashMap::new(), + blocked: HashSet::new(), + active_connections: HashSet::new(), + data_dir: data_dir.to_string(), + cache_dir: cache_dir.to_string(), + } + } + + // Load previous state from disk + pub fn load_from_disk(data_dir: &str, cache_dir: &str) -> Self { + let mut state = Self::new(data_dir, cache_dir); + let blocked_ips_file = format!("{data_dir}/blocked_ips.txt"); + + if let Ok(content) = fs::read_to_string(&blocked_ips_file) { + let mut loaded = 0; + for line in content.lines() { + if let Ok(ip) = line.parse::() { + state.blocked.insert(ip); + loaded += 1; + } + } + log::info!("Loaded {loaded} blocked IPs from {blocked_ips_file}"); + } else { + log::info!("No blocked IPs file found at {blocked_ips_file}"); + } + + // Check for temporary hit counter cache + let hit_cache_file = format!("{cache_dir}/hit_counters.json"); + if let Ok(content) = fs::read_to_string(&hit_cache_file) { + if let Ok(hit_map) = serde_json::from_str::>(&content) { + for (ip_str, count) in hit_map { + if let Ok(ip) = ip_str.parse::() { + state.hits.insert(ip, count); + } + } + log::info!("Loaded hit counters for {} IPs", state.hits.len()); + } + } + + BLOCKED_IPS.set(state.blocked.len() as f64); + state + } + + // Persist state to disk for later reloading + pub fn save_to_disk(&self) { + // Save blocked IPs + if let Err(e) = fs::create_dir_all(&self.data_dir) { + log::error!("Failed to create data directory: {e}"); + return; + } + + let blocked_ips_file = format!("{}/blocked_ips.txt", self.data_dir); + + match fs::File::create(&blocked_ips_file) { + Ok(mut file) => { + let mut count = 0; + for ip in &self.blocked { + if writeln!(file, "{ip}").is_ok() { + count += 1; + } + } + log::info!("Saved {count} blocked IPs to {blocked_ips_file}"); + } + Err(e) => { + log::error!("Failed to create blocked IPs file: {e}"); + } + } + + // Save hit counters to cache + if let Err(e) = fs::create_dir_all(&self.cache_dir) { + log::error!("Failed to create cache directory: {e}"); + return; + } + + let hit_cache_file = format!("{}/hit_counters.json", self.cache_dir); + let hit_map: HashMap = self + .hits + .iter() + .map(|(ip, count)| (ip.to_string(), *count)) + .collect(); + + match fs::File::create(&hit_cache_file) { + Ok(file) => { + if let Err(e) = serde_json::to_writer(file, &hit_map) { + log::error!("Failed to write hit counters to cache: {e}"); + } else { + log::debug!("Saved hit counters for {} IPs to cache", hit_map.len()); + } + } + Err(e) => { + log::error!("Failed to create hit counter cache file: {e}"); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, Ipv4Addr}; + use std::sync::Arc; + use tokio::sync::RwLock; + + #[tokio::test] + async fn test_bot_state() { + let state = BotState::new("/tmp/eris_test", "/tmp/eris_test_cache"); + let ip1 = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)); + let ip2 = IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8)); + + let state = Arc::new(RwLock::new(state)); + + // Test hit counter + { + let mut state = state.write().await; + *state.hits.entry(ip1).or_insert(0) += 1; + *state.hits.entry(ip1).or_insert(0) += 1; + *state.hits.entry(ip2).or_insert(0) += 1; + + assert_eq!(*state.hits.get(&ip1).unwrap(), 2); + assert_eq!(*state.hits.get(&ip2).unwrap(), 1); + } + + // Test blocking + { + let mut state = state.write().await; + state.blocked.insert(ip1); + assert!(state.blocked.contains(&ip1)); + assert!(!state.blocked.contains(&ip2)); + } + + // Test active connections + { + let mut state = state.write().await; + state.active_connections.insert(ip1); + state.active_connections.insert(ip2); + assert_eq!(state.active_connections.len(), 2); + + state.active_connections.remove(&ip1); + assert_eq!(state.active_connections.len(), 1); + assert!(!state.active_connections.contains(&ip1)); + assert!(state.active_connections.contains(&ip2)); + } + } +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..c925565 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,168 @@ +use std::collections::HashMap; +use std::hash::Hasher; + +// Find end of HTTP headers +pub fn find_header_end(data: &[u8]) -> Option { + data.windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|pos| pos + 4) +} + +// Extract path from raw request data +pub fn extract_path_from_request(data: &[u8]) -> Option<&str> { + // Get first line from request + let first_line = data + .split(|&b| b == b'\r' || b == b'\n') + .next() + .filter(|line| !line.is_empty())?; + + // Split by spaces and ensure we have at least 3 parts (METHOD PATH VERSION) + let parts: Vec<&[u8]> = first_line.split(|&b| b == b' ').collect(); + if parts.len() < 3 || !parts[2].starts_with(b"HTTP/") { + return None; + } + + // Return the path (second element) + std::str::from_utf8(parts[1]).ok() +} + +// Extract header value from raw request data +pub fn extract_header_value(data: &[u8], header_name: &str) -> Option { + let data_str = std::str::from_utf8(data).ok()?; + let header_prefix = format!("{header_name}: ").to_lowercase(); + + for line in data_str.lines() { + let line_lower = line.to_lowercase(); + if line_lower.starts_with(&header_prefix) { + return Some(line[header_prefix.len()..].trim().to_string()); + } + } + None +} + +// Extract all headers from request data +pub fn extract_all_headers(data: &[u8]) -> HashMap { + let mut headers = HashMap::new(); + + if let Ok(data_str) = std::str::from_utf8(data) { + let mut lines = data_str.lines(); + + // Skip the request line + let _ = lines.next(); + + // Parse headers until empty line + for line in lines { + if line.is_empty() { + break; + } + + if let Some(colon_pos) = line.find(':') { + let key = line[..colon_pos].trim().to_lowercase(); + let value = line[colon_pos + 1..].trim().to_string(); + headers.insert(key, value); + } + } + } + + headers +} + +// Determine response type based on request path +pub fn choose_response_type(path: &str) -> &'static str { + if path.contains("phpunit") || path.contains("eval") { + "php_exploit" + } else if path.contains("wp-") { + "wordpress" + } else if path.contains("api") { + "api" + } else { + "generic" + } +} + +// Get current timestamp in seconds +pub fn get_timestamp() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +// Create a unique session ID for tracking a connection +pub fn generate_session_id(ip: &str, user_agent: &str) -> String { + let timestamp = get_timestamp(); + let random = rand::random::(); + + // XXX: Is this fast enough for our case? I don't think hashing is a huge + // bottleneck, but it's worth revisiting in the future to see if there is + // an objectively faster algorithm that we can try. + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + std::hash::Hash::hash(&format!("{ip}_{user_agent}_{timestamp}"), &mut hasher); + let hash = hasher.finish(); + + format!("SID_{hash:x}_{random:x}") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_find_header_end() { + let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: test\r\n\r\nBody content"; + assert_eq!(find_header_end(data), Some(55)); + + let incomplete = b"GET / HTTP/1.1\r\nHost: example.com\r\n"; + assert_eq!(find_header_end(incomplete), None); + } + + #[test] + fn test_extract_path_from_request() { + let data = b"GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n"; + assert_eq!(extract_path_from_request(data), Some("/index.html")); + + let bad_data = b"INVALID DATA"; + assert_eq!(extract_path_from_request(bad_data), None); + } + + #[test] + fn test_extract_header_value() { + let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\n\r\n"; + assert_eq!( + extract_header_value(data, "user-agent"), + Some("TestBot/1.0".to_string()) + ); + assert_eq!( + extract_header_value(data, "Host"), + Some("example.com".to_string()) + ); + assert_eq!(extract_header_value(data, "nonexistent"), None); + } + + #[test] + fn test_extract_all_headers() { + let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\nAccept: */*\r\n\r\n"; + let headers = extract_all_headers(data); + + assert_eq!(headers.len(), 3); + assert_eq!(headers.get("host").unwrap(), "example.com"); + assert_eq!(headers.get("user-agent").unwrap(), "TestBot/1.0"); + assert_eq!(headers.get("accept").unwrap(), "*/*"); + } + + #[test] + fn test_choose_response_type() { + assert_eq!( + choose_response_type("/vendor/phpunit/whatever"), + "php_exploit" + ); + assert_eq!( + choose_response_type("/path/to/eval-stdin.php"), + "php_exploit" + ); + assert_eq!(choose_response_type("/wp-admin/login.php"), "wordpress"); + assert_eq!(choose_response_type("/wp-login.php"), "wordpress"); + assert_eq!(choose_response_type("/api/v1/users"), "api"); + assert_eq!(choose_response_type("/index.html"), "generic"); + } +}