diff --git a/.github/workflows/tag.yml b/.github/workflows/tag.yml deleted file mode 100644 index d3ce8ab..0000000 --- a/.github/workflows/tag.yml +++ /dev/null @@ -1,31 +0,0 @@ -name: Tag latest version - -on: - workflow_dispatch: - push: - branches: [ main ] - -concurrency: tag - -jobs: - tag-release: - runs-on: ubuntu-latest - steps: - - uses: cachix/install-nix-action@master - with: - github_access_token: ${{ secrets.GITHUB_TOKEN }} - - - name: Checkout - uses: actions/checkout@v4 - - - name: Read version - run: | - echo -n "_version=v" >> "$GITHUB_ENV" - nix run nixpkgs#fq -- -r ".package.version" Cargo.toml >> "$GITHUB_ENV" - cat "$GITHUB_ENV" - - - name: Tag - run: | - set -x - git tag $version - git push --tags || : diff --git a/Cargo.lock b/Cargo.lock index 2eaaf07..3ab4b04 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -414,7 +414,9 @@ checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" dependencies = [ "android-tzdata", "iana-time-zone", + "js-sys", "num-traits", + "wasm-bindgen", "windows-link", ] @@ -440,6 +442,7 @@ version = "4.5.37" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "efd9466fac8543255d3b1fcad4762c5e116ffe808c8a3043d4263cd4fd4862a2" dependencies = [ + "anstream", "anstyle", "clap_lex", "strsim", @@ -617,7 +620,7 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "eris" -version = "1.0.0" +version = "0.1.0" dependencies = [ "actix-web", "chrono", @@ -630,13 +633,10 @@ dependencies = [ "prometheus 0.14.0", "prometheus_exporter", "rand", - "regex", "rlua", "serde", "serde_json", - "tempfile", "tokio", - "toml", ] [[package]] @@ -646,15 +646,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] -[[package]] -name = "fastrand" -version = "2.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" - [[package]] name = "flate2" version = "1.1.1" @@ -1586,7 +1580,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -1639,15 +1633,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_spanned" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" -dependencies = [ - "serde", -] - [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1755,19 +1740,6 @@ dependencies = [ "syn 2.0.101", ] -[[package]] -name = "tempfile" -version = "3.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" -dependencies = [ - "fastrand", - "getrandom", - "once_cell", - "rustix", - "windows-sys 0.59.0", -] - [[package]] name = "thiserror" version = "1.0.69" @@ -1904,47 +1876,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "toml" -version = "0.8.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05ae329d1f08c4d17a59bed7ff5b5a769d062e64a62d34a3261b219e62cd5aae" -dependencies = [ - "serde", - "serde_spanned", - "toml_datetime", - "toml_edit", -] - -[[package]] -name = "toml_datetime" -version = "0.6.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3" -dependencies = [ - "serde", -] - -[[package]] -name = "toml_edit" -version = "0.22.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "310068873db2c5b3e7659d2cc35d21855dbafa50d1ce336397c666e3cb08137e" -dependencies = [ - "indexmap", - "serde", - "serde_spanned", - "toml_datetime", - "toml_write", - "winnow", -] - -[[package]] -name = "toml_write" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfb942dfe1d8e29a7ee7fcbde5bd2b9a25fb89aa70caea2eba3bee836ff41076" - [[package]] name = "tracing" version = "0.1.41" @@ -2256,15 +2187,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" -[[package]] -name = "winnow" -version = "0.7.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e27d6ad3dac991091e4d35de9ba2d2d00647c5d0fc26c5496dee55984ae111b" -dependencies = [ - "memchr", -] - [[package]] name = "winsafe" version = "0.0.19" diff --git a/Cargo.toml b/Cargo.toml index 42cc24b..2292db4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,12 @@ [package] name = "eris" -version = "1.0.0" +version = "0.1.0" edition = "2024" [dependencies] -actix-web = { version = "4.3.1" } -chrono = { version = "0.4.41", default-features = false, features = ["std", "clock"] } -clap = { version = "4.5", default-features = false, features = ["std", "derive", "help", "usage", "suggestions"] } +actix-web = "4.3.1" +clap = { version = "4.3", features = ["derive"] } +chrono = "0.4.24" futures = "0.3.28" ipnetwork = "0.21.1" lazy_static = "1.4.0" @@ -19,6 +19,3 @@ serde_json = "1.0.96" tokio = { version = "1.28.0", features = ["full"] } log = "0.4.27" env_logger = "0.11.8" -tempfile = "3.19.1" -regex = "1.11.1" -toml = "0.8.22" diff --git a/nix/package.nix b/nix/package.nix index aa6e975..b227e1a 100644 --- a/nix/package.nix +++ b/nix/package.nix @@ -19,7 +19,6 @@ in fileset = fs.unions [ (fs.fileFilter (file: builtins.any file.hasExt ["rs"]) (s + /src)) (s + /contrib) - (s + /resources) lockfile cargoToml ]; diff --git a/resources/default_script.lua b/resources/default_script.lua deleted file mode 100644 index f2f51fd..0000000 --- a/resources/default_script.lua +++ /dev/null @@ -1,210 +0,0 @@ ---[[ -Eris Default Script - -This script demonstrates how to use the Eris Lua API to customize -the tarpit's behavior, and will be loaded by default if no other -scripts are loaded. - -Available events: -- connection: When a new connection is established -- request: When a request is received -- response_gen: When generating a response -- response_chunk: Before sending each response chunk -- disconnection: When a connection is closed -- block_ip: When an IP is being considered for blocking -- startup: When the application starts -- shutdown: When the application is shutting down -- periodic: Called periodically - -API Functions: -- eris.debug(message): Log a debug message -- eris.info(message): Log an info message -- eris.warn(message): Log a warning message -- eris.error(message): Log an error message -- eris.set_state(key, value): Store persistent state -- eris.get_state(key): Retrieve persistent state -- eris.inc_counter(key, [amount]): Increment a counter -- eris.get_counter(key): Get a counter value -- eris.gen_token([prefix]): Generate a unique token -- eris.timestamp(): Get current Unix timestamp ---]] - --- Called when the application starts -eris.on("startup", function(ctx) - eris.info("Initializing default script") - - -- Initialize counters - eris.inc_counter("total_connections", 0) - eris.inc_counter("total_responses", 0) - eris.inc_counter("blocked_ips", 0) - - -- Initialize banned keywords - eris.set_state("banned_keywords", "eval,exec,system,shell," - .. "\n" - .. "\n" - elseif ctx.path:find("phpunit") or ctx.path:find("eval") then - -- For PHP exploit attempts - -- Turns out you can just google "PHP error log" and search random online forums where people - -- dump their service logs in full. - enhanced_content = enhanced_content - .. "\nPHP Notice: Undefined variable: _SESSION in /var/www/html/includes/core.php on line 58\n" - .. "Warning: file_get_contents(): Filename cannot be empty in /var/www/html/vendor/autoload.php on line 23\n" - .. "Token: " - .. token - .. "\n" - elseif ctx.path:find("api") then - -- For API requests - local fake_api_key = - string.format("ak_%x%x%x", math.random(1000, 9999), math.random(1000, 9999), math.random(1000, 9999)) - - enhanced_content = enhanced_content - .. "{\n" - .. ' "status": "warning",\n' - .. ' "message": "Test API environment detected",\n' - .. ' "debug_token": "' - .. token - .. '",\n' - .. ' "api_key": "' - .. fake_api_key - .. '"\n' - .. "}\n" - else - -- For other requests - enhanced_content = enhanced_content - .. "\n" - .. "\n" - .. "\n" - end - - -- Track which honeytokens were sent to which IP - local honeytokens = eris.get_state("honeytokens") or "{}" - local ht_table = {} - - -- This is a simplistic approach - in a real script, you'd want to use - -- a proper JSON library to handle this correctly - if honeytokens ~= "{}" then - -- Simple parsing of the stored data - for ip, tok in honeytokens:gmatch('"([^"]+)":"([^"]+)"') do - ht_table[ip] = tok - end - end - - ht_table[ctx.ip] = token - - -- Convert back to a simple JSON-like string - local new_tokens = "{" - for ip, tok in pairs(ht_table) do - if new_tokens ~= "{" then - new_tokens = new_tokens .. "," - end - new_tokens = new_tokens .. '"' .. ip .. '":"' .. tok .. '"' - end - new_tokens = new_tokens .. "}" - - eris.set_state("honeytokens", new_tokens) - - return enhanced_content -end) - --- Called before sending each chunk of a response -eris.on("response_chunk", function(ctx) - -- This can be used to alter individual chunks for more deceptive behavior - -- For example, to simulate a slow, unreliable server - - -- 5% chance of "corrupting" a chunk to confuse scanners - if math.random(1, 100) <= 5 then - local chunk = ctx.content - if #chunk > 10 then - local pos = math.random(1, #chunk - 5) - chunk = chunk:sub(1, pos) .. string.char(math.random(32, 126)) .. chunk:sub(pos + 2) - end - return chunk - end - - return ctx.content -end) - --- Called when deciding whether to block an IP -eris.on("block_ip", function(ctx) - -- You can override the default blocking logic - - -- Check for potential attackers using specific patterns - local banned_keywords = eris.get_state("banned_keywords") or "" - local user_agent = ctx.user_agent or "" - - -- Check if user agent contains highly suspicious patterns - for keyword in banned_keywords:gmatch("[^,]+") do - if user_agent:lower():find(keyword:lower()) then - eris.info("Blocking IP " .. ctx.ip .. " due to suspicious user agent: " .. keyword) - eris.inc_counter("blocked_ips") - return true -- Force block - end - end - - -- For demonstration, we'll be more lenient with 10.x IPs - if ctx.ip:match("^10%.") then - -- Only block if they've hit us many times - return ctx.hit_count >= 5 - end - - -- Default to the system's threshold-based decision - return nil -end) - --- The enhance_response is now legacy, and I never liked it anyway. Though let's add it here --- for the sake of backwards compatibility. -function enhance_response(text, response_type, path, token) - local enhanced = text - - -- Add token as a comment - if response_type == "php_exploit" then - enhanced = enhanced .. "\n/* Token: " .. token .. " */\n" - elseif response_type == "wordpress" then - enhanced = enhanced .. "\n\n" - elseif response_type == "api" then - enhanced = enhanced:gsub('"status": "[^"]+"', '"status": "warning"') - enhanced = enhanced:gsub('"message": "[^"]+"', '"message": "API token: ' .. token .. '"') - else - enhanced = enhanced .. "\n\n" - end - - return enhanced -end diff --git a/src/config.rs b/src/config.rs deleted file mode 100644 index 74a1571..0000000 --- a/src/config.rs +++ /dev/null @@ -1,613 +0,0 @@ -use clap::Parser; -use ipnetwork::IpNetwork; -use regex::Regex; -use serde::{Deserialize, Serialize}; -use std::env; -use std::fs; -use std::net::IpAddr; -use std::path::{Path, PathBuf}; - -// Command-line arguments using clap -#[derive(Parser, Debug, Clone)] -#[clap( - author, - version, - about = "Markov chain based HTTP tarpit/honeypot that delays and tracks potential attackers" -)] -pub struct Args { - #[clap( - long, - default_value = "0.0.0.0:8888", - help = "Address and port to listen for incoming HTTP requests (format: ip:port)" - )] - pub listen_addr: String, - - #[clap( - long, - default_value = "0.0.0.0:9100", - help = "Address and port to expose Prometheus metrics and status endpoint (format: ip:port)" - )] - pub metrics_addr: String, - - #[clap(long, help = "Disable Prometheus metrics server completely")] - pub disable_metrics: bool, - - #[clap( - long, - default_value = "127.0.0.1:80", - help = "Backend server address to proxy legitimate requests to (format: ip:port)" - )] - pub backend_addr: String, - - #[clap( - long, - default_value = "1000", - help = "Minimum delay in milliseconds between chunks sent to attacker" - )] - pub min_delay: u64, - - #[clap( - long, - default_value = "15000", - help = "Maximum delay in milliseconds between chunks sent to attacker" - )] - pub max_delay: u64, - - #[clap( - long, - default_value = "600", - help = "Maximum time in seconds to keep an attacker in the tarpit before disconnecting" - )] - pub max_tarpit_time: u64, - - #[clap( - long, - default_value = "3", - help = "Number of hits to honeypot patterns before permanently blocking an IP" - )] - pub block_threshold: u32, - - #[clap( - long, - help = "Base directory for all application data (overrides XDG directory structure)" - )] - pub base_dir: Option, - - #[clap( - long, - help = "Path to configuration file (JSON or TOML, overrides command line options)" - )] - pub config_file: Option, - - #[clap( - long, - default_value = "info", - help = "Log level: trace, debug, info, warn, error" - )] - pub log_level: String, - - #[clap( - long, - default_value = "pretty", - help = "Log format: plain, pretty, json, pretty-json" - )] - pub log_format: String, - - #[clap(long, help = "Enable rate limiting for connections from the same IP")] - pub rate_limit_enabled: bool, - - #[clap(long, default_value = "60", help = "Rate limit window in seconds")] - pub rate_limit_window: u64, - - #[clap( - long, - default_value = "30", - help = "Maximum number of connections allowed per IP in the rate limit window" - )] - pub rate_limit_max: usize, - - #[clap( - long, - default_value = "100", - help = "Connection attempts threshold before considering for IP blocking" - )] - pub rate_limit_block_threshold: usize, - - #[clap( - long, - help = "Send a 429 response for rate limited connections instead of dropping connection" - )] - pub rate_limit_slow_response: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] -pub enum LogFormat { - Plain, - #[default] - Pretty, - Json, - PrettyJson, -} - -// Trap pattern structure. It can be either a plain string -// regex to catch more advanced patterns necessitated by -// more sophisticated crawlers. -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(untagged)] -pub enum TrapPattern { - Plain(String), - Regex { pattern: String, regex: bool }, -} - -impl TrapPattern { - pub fn as_plain(value: &str) -> Self { - Self::Plain(value.to_string()) - } - - pub fn as_regex(value: &str) -> Self { - Self::Regex { - pattern: value.to_string(), - regex: true, - } - } - - pub fn matches(&self, path: &str) -> bool { - match self { - Self::Plain(pattern) => path.contains(pattern), - Self::Regex { - pattern, - regex: true, - } => { - if let Ok(re) = Regex::new(pattern) { - re.is_match(path) - } else { - false - } - } - _ => false, - } - } -} - -// Configuration structure -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Config { - pub listen_addr: String, - pub metrics_addr: String, - pub disable_metrics: bool, - pub backend_addr: String, - pub min_delay: u64, - pub max_delay: u64, - pub max_tarpit_time: u64, - pub block_threshold: u32, - pub trap_patterns: Vec, - pub whitelist_networks: Vec, - pub markov_corpora_dir: String, - pub lua_scripts_dir: String, - pub data_dir: String, - pub config_dir: String, - pub cache_dir: String, - pub log_format: LogFormat, - pub rate_limit_enabled: bool, - pub rate_limit_window_seconds: u64, - pub rate_limit_max_connections: usize, - pub rate_limit_block_threshold: usize, - pub rate_limit_slow_response: bool, -} - -impl Default for Config { - fn default() -> Self { - Self { - listen_addr: "0.0.0.0:8888".to_string(), - metrics_addr: "0.0.0.0:9100".to_string(), - disable_metrics: false, - backend_addr: "127.0.0.1:80".to_string(), - min_delay: 1000, - max_delay: 15000, - max_tarpit_time: 600, - block_threshold: 3, - trap_patterns: vec![ - // Basic attack patterns as plain strings - TrapPattern::as_plain("/vendor/phpunit"), - TrapPattern::as_plain("eval-stdin.php"), - TrapPattern::as_plain("/wp-admin"), - TrapPattern::as_plain("/wp-login.php"), - TrapPattern::as_plain("/xmlrpc.php"), - TrapPattern::as_plain("/phpMyAdmin"), - TrapPattern::as_plain("/solr/"), - TrapPattern::as_plain("/.env"), - TrapPattern::as_plain("/config"), - TrapPattern::as_plain("/actuator/"), - // More aggressive patterns for various PHP exploits. - // XXX: I dedicate this entire section to that one single crawler - // that has been scanning my entire network, hitting 403s left and right - // but not giving up, and coming back the next day at the same time to - // scan the same paths over and over. Kudos to you, random crawler. - TrapPattern::as_regex(r"/.*phpunit.*eval-stdin\.php"), - TrapPattern::as_regex(r"/index\.php\?s=/index/\\think\\app/invokefunction"), - TrapPattern::as_regex(r".*%ADd\+auto_prepend_file%3dphp://input.*"), - TrapPattern::as_regex(r".*%ADd\+allow_url_include%3d1.*"), - TrapPattern::as_regex(r".*/wp-content/plugins/.*\.php"), - TrapPattern::as_regex(r".*/wp-content/themes/.*\.php"), - TrapPattern::as_regex(r".*eval\(.*\).*"), - TrapPattern::as_regex(r".*/adminer\.php.*"), - TrapPattern::as_regex(r".*/admin\.php.*"), - TrapPattern::as_regex(r".*/administrator/.*"), - TrapPattern::as_regex(r".*/wp-json/.*"), - TrapPattern::as_regex(r".*/api/.*\.php.*"), - TrapPattern::as_regex(r".*/cgi-bin/.*"), - TrapPattern::as_regex(r".*/owa/.*"), - TrapPattern::as_regex(r".*/ecp/.*"), - TrapPattern::as_regex(r".*/webshell\.php.*"), - TrapPattern::as_regex(r".*/shell\.php.*"), - TrapPattern::as_regex(r".*/cmd\.php.*"), - TrapPattern::as_regex(r".*/struts.*"), - ], - whitelist_networks: vec![ - "192.168.0.0/16".to_string(), - "10.0.0.0/8".to_string(), - "172.16.0.0/12".to_string(), - "127.0.0.0/8".to_string(), - ], - markov_corpora_dir: "./corpora".to_string(), - lua_scripts_dir: "./scripts".to_string(), - data_dir: "./data".to_string(), - config_dir: "./conf".to_string(), - cache_dir: "./cache".to_string(), - log_format: LogFormat::Pretty, - rate_limit_enabled: true, - rate_limit_window_seconds: 60, - rate_limit_max_connections: 30, - rate_limit_block_threshold: 100, - rate_limit_slow_response: true, - } - } -} - -// Gets standard XDG directory paths for config, data and cache. -// XXX: This could be "simplified" by using the Dirs crate, but I can't -// really justify pulling a library for something I can handle in less -// than 30 lines. Unless cross-platform becomes necessary, the below -// implementation is good enough. For alternative platforms, we can simply -// enhance the current implementation as needed. -pub fn get_xdg_dirs() -> (PathBuf, PathBuf, PathBuf) { - let config_home = env::var_os("XDG_CONFIG_HOME") - .map(PathBuf::from) - .unwrap_or_else(|| { - let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); - home.join(".config") - }); - - let data_home = env::var_os("XDG_DATA_HOME") - .map(PathBuf::from) - .unwrap_or_else(|| { - let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); - home.join(".local").join("share") - }); - - let cache_home = env::var_os("XDG_CACHE_HOME") - .map(PathBuf::from) - .unwrap_or_else(|| { - let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); - home.join(".cache") - }); - - let config_dir = config_home.join("eris"); - let data_dir = data_home.join("eris"); - let cache_dir = cache_home.join("eris"); - - (config_dir, data_dir, cache_dir) -} - -impl Config { - // Create configuration from command-line args. We'll be falling back to this - // when the configuration is invalid, so it must be validated more strictly. - pub fn from_args(args: &Args) -> Self { - let (config_dir, data_dir, cache_dir) = if let Some(base_dir) = &args.base_dir { - let base_str = base_dir.to_string_lossy().to_string(); - ( - format!("{base_str}/conf"), - format!("{base_str}/data"), - format!("{base_str}/cache"), - ) - } else { - let (c, d, cache) = get_xdg_dirs(); - ( - c.to_string_lossy().to_string(), - d.to_string_lossy().to_string(), - cache.to_string_lossy().to_string(), - ) - }; - - Self { - listen_addr: args.listen_addr.clone(), - metrics_addr: args.metrics_addr.clone(), - disable_metrics: args.disable_metrics, - backend_addr: args.backend_addr.clone(), - min_delay: args.min_delay, - max_delay: args.max_delay, - max_tarpit_time: args.max_tarpit_time, - block_threshold: args.block_threshold, - markov_corpora_dir: format!("{data_dir}/corpora"), - lua_scripts_dir: format!("{data_dir}/scripts"), - data_dir, - config_dir, - cache_dir, - log_format: LogFormat::Pretty, - rate_limit_enabled: args.rate_limit_enabled, - rate_limit_window_seconds: args.rate_limit_window, - rate_limit_max_connections: args.rate_limit_max, - rate_limit_block_threshold: args.rate_limit_block_threshold, - rate_limit_slow_response: args.rate_limit_slow_response, - ..Default::default() - } - } - - // Load configuration from a file (JSON or TOML) - pub fn load_from_file(path: &Path) -> std::io::Result { - let content = fs::read_to_string(path)?; - - let extension = path - .extension() - .map(|ext| ext.to_string_lossy().to_lowercase()) - .unwrap_or_default(); - - let config = match extension.as_str() { - "toml" => toml::from_str(&content).map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Failed to parse TOML: {e}"), - ) - })?, - _ => { - // Default to JSON for any other extension - serde_json::from_str(&content).map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Failed to parse JSON: {e}"), - ) - })? - } - }; - - Ok(config) - } - - // Save configuration to a file (JSON or TOML) - pub fn save_to_file(&self, path: &Path) -> std::io::Result<()> { - if let Some(parent) = path.parent() { - fs::create_dir_all(parent)?; - } - - let extension = path - .extension() - .map(|ext| ext.to_string_lossy().to_lowercase()) - .unwrap_or_default(); - - let content = match extension.as_str() { - "toml" => toml::to_string_pretty(self).map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Failed to serialize to TOML: {e}"), - ) - })?, - _ => { - // Default to JSON for any other extension - serde_json::to_string_pretty(self).map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Failed to serialize to JSON: {e}"), - ) - })? - } - }; - - fs::write(path, content)?; - Ok(()) - } - - // Create required directories if they don't exist - pub fn ensure_dirs_exist(&self) -> std::io::Result<()> { - let dirs = [ - &self.markov_corpora_dir, - &self.lua_scripts_dir, - &self.data_dir, - &self.config_dir, - &self.cache_dir, - ]; - - for dir in dirs { - fs::create_dir_all(dir)?; - log::debug!("Created directory: {dir}"); - } - - Ok(()) - } -} - -// Decide if a request should be tarpitted based on path and IP -pub fn should_tarpit(path: &str, ip: &IpAddr, config: &Config) -> bool { - // Check whitelist IPs first to avoid unnecessary pattern matching - for network_str in &config.whitelist_networks { - if let Ok(network) = network_str.parse::() { - if network.contains(*ip) { - return false; - } - } - } - - // Use pattern matching based on the trap pattern type. It can be - // a plain string or regex. - for pattern in &config.trap_patterns { - if pattern.matches(path) { - return true; - } - } - - false -} - -#[cfg(test)] -mod tests { - use super::*; - use std::net::{IpAddr, Ipv4Addr}; - - #[test] - fn test_config_from_args() { - let args = Args { - listen_addr: "127.0.0.1:8080".to_string(), - metrics_addr: "127.0.0.1:9000".to_string(), - disable_metrics: true, - backend_addr: "127.0.0.1:8081".to_string(), - min_delay: 500, - max_delay: 10000, - max_tarpit_time: 300, - block_threshold: 5, - base_dir: Some(PathBuf::from("/tmp/eris")), - config_file: None, - log_level: "debug".to_string(), - log_format: "pretty".to_string(), - rate_limit_enabled: true, - rate_limit_window: 30, - rate_limit_max: 20, - rate_limit_block_threshold: 50, - rate_limit_slow_response: true, - }; - - let config = Config::from_args(&args); - assert_eq!(config.listen_addr, "127.0.0.1:8080"); - assert_eq!(config.metrics_addr, "127.0.0.1:9000"); - assert!(config.disable_metrics); - assert_eq!(config.backend_addr, "127.0.0.1:8081"); - assert_eq!(config.min_delay, 500); - assert_eq!(config.max_delay, 10000); - assert_eq!(config.max_tarpit_time, 300); - assert_eq!(config.block_threshold, 5); - assert_eq!(config.markov_corpora_dir, "/tmp/eris/data/corpora"); - assert_eq!(config.lua_scripts_dir, "/tmp/eris/data/scripts"); - assert_eq!(config.data_dir, "/tmp/eris/data"); - assert_eq!(config.config_dir, "/tmp/eris/conf"); - assert_eq!(config.cache_dir, "/tmp/eris/cache"); - assert!(config.rate_limit_enabled); - assert_eq!(config.rate_limit_window_seconds, 30); - assert_eq!(config.rate_limit_max_connections, 20); - assert_eq!(config.rate_limit_block_threshold, 50); - assert!(config.rate_limit_slow_response); - } - - #[test] - fn test_trap_pattern_matching() { - // Test plain string pattern - let plain = TrapPattern::as_plain("phpunit"); - assert!(plain.matches("path/to/phpunit/test")); - assert!(!plain.matches("path/to/something/else")); - - // Test regex pattern - let regex = TrapPattern::as_regex(r".*eval-stdin\.php.*"); - assert!(regex.matches("/vendor/phpunit/phpunit/src/Util/PHP/eval-stdin.php")); - assert!(regex.matches("/tests/eval-stdin.php?param")); - assert!(!regex.matches("/normal/path")); - - // Test invalid regex pattern (should return false) - let invalid = TrapPattern::Regex { - pattern: "(invalid[regex".to_string(), - regex: true, - }; - assert!(!invalid.matches("anything")); - } - - #[tokio::test] - async fn test_should_tarpit() { - let config = Config::default(); - - // Test trap patterns - assert!(should_tarpit( - "/vendor/phpunit/whatever", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - )); - assert!(should_tarpit( - "/wp-admin/login.php", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - )); - assert!(should_tarpit( - "/.env", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - )); - - // Test whitelist networks - assert!(!should_tarpit( - "/wp-admin/login.php", - &IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - &config - )); - assert!(!should_tarpit( - "/vendor/phpunit/whatever", - &IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), - &config - )); - - // Test legitimate paths - assert!(!should_tarpit( - "/index.html", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - )); - assert!(!should_tarpit( - "/images/logo.png", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - )); - - // Test regex patterns - assert!(should_tarpit( - "/index.php?s=/index/\\think\\app/invokefunction&function=call_user_func_array&vars[0]=md5&vars[1][]=Hello", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - )); - - assert!(should_tarpit( - "/hello.world?%ADd+allow_url_include%3d1+%ADd+auto_prepend_file%3dphp://input", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - )); - } - - #[test] - fn test_config_file_formats() { - // Create temporary JSON config file - let temp_dir = std::env::temp_dir(); - let json_path = temp_dir.join("temp_config.json"); - let toml_path = temp_dir.join("temp_config.toml"); - - let config = Config::default(); - - // Test JSON serialization and deserialization - config.save_to_file(&json_path).unwrap(); - let loaded_json = Config::load_from_file(&json_path).unwrap(); - assert_eq!(loaded_json.listen_addr, config.listen_addr); - assert_eq!(loaded_json.min_delay, config.min_delay); - assert_eq!(loaded_json.rate_limit_enabled, config.rate_limit_enabled); - assert_eq!( - loaded_json.rate_limit_max_connections, - config.rate_limit_max_connections - ); - - // Test TOML serialization and deserialization - config.save_to_file(&toml_path).unwrap(); - let loaded_toml = Config::load_from_file(&toml_path).unwrap(); - assert_eq!(loaded_toml.listen_addr, config.listen_addr); - assert_eq!(loaded_toml.min_delay, config.min_delay); - assert_eq!(loaded_toml.rate_limit_enabled, config.rate_limit_enabled); - assert_eq!( - loaded_toml.rate_limit_max_connections, - config.rate_limit_max_connections - ); - - // Clean up - let _ = std::fs::remove_file(json_path); - let _ = std::fs::remove_file(toml_path); - } -} diff --git a/src/firewall.rs b/src/firewall.rs deleted file mode 100644 index 73f36e5..0000000 --- a/src/firewall.rs +++ /dev/null @@ -1,128 +0,0 @@ -use tokio::process::Command; - -// Set up nftables firewall rules for IP blocking -pub async fn setup_firewall() -> Result<(), String> { - log::info!("Setting up firewall rules"); - - // Check if nft command exists - let nft_exists = Command::new("which") - .arg("nft") - .output() - .await - .map(|output| output.status.success()) - .unwrap_or(false); - - if !nft_exists { - log::warn!("nft command not found. Firewall rules will not be set up."); - return Ok(()); - } - - // Create table if it doesn't exist - let output = Command::new("nft") - .args(["list", "table", "inet", "filter"]) - .output() - .await; - - match output { - Ok(output) => { - if !output.status.success() { - log::info!("Creating nftables table"); - let result = Command::new("nft") - .args(["create", "table", "inet", "filter"]) - .output() - .await; - - if let Err(e) = result { - return Err(format!("Failed to create nftables table: {e}")); - } - } - } - Err(e) => { - log::warn!("Failed to check if nftables table exists: {e}"); - log::info!("Will try to create it anyway"); - let result = Command::new("nft") - .args(["create", "table", "inet", "filter"]) - .output() - .await; - - if let Err(e) = result { - return Err(format!("Failed to create nftables table: {e}")); - } - } - } - - // Create blacklist set if it doesn't exist - let output = Command::new("nft") - .args(["list", "set", "inet", "filter", "eris_blacklist"]) - .output() - .await; - - match output { - Ok(output) => { - if !output.status.success() { - log::info!("Creating eris_blacklist set"); - let result = Command::new("nft") - .args([ - "create", - "set", - "inet", - "filter", - "eris_blacklist", - "{ type ipv4_addr; flags interval; }", - ]) - .output() - .await; - - if let Err(e) = result { - return Err(format!("Failed to create blacklist set: {e}")); - } - } - } - Err(e) => { - log::warn!("Failed to check if blacklist set exists: {e}"); - return Err(format!("Failed to check if blacklist set exists: {e}")); - } - } - - // Add rule to drop traffic from blacklisted IPs - let output = Command::new("nft") - .args(["list", "chain", "inet", "filter", "input"]) - .output() - .await; - - // Check if our rule already exists - match output { - Ok(output) => { - let rule_exists = String::from_utf8_lossy(&output.stdout) - .contains("ip saddr @eris_blacklist counter drop"); - - if !rule_exists { - log::info!("Adding drop rule for blacklisted IPs"); - let result = Command::new("nft") - .args([ - "add", - "rule", - "inet", - "filter", - "input", - "ip saddr @eris_blacklist", - "counter", - "drop", - ]) - .output() - .await; - - if let Err(e) = result { - return Err(format!("Failed to add firewall rule: {e}")); - } - } - } - Err(e) => { - log::warn!("Failed to check if firewall rule exists: {e}"); - return Err(format!("Failed to check if firewall rule exists: {e}")); - } - } - - log::info!("Firewall setup complete"); - Ok(()) -} diff --git a/src/lua/mod.rs b/src/lua/mod.rs deleted file mode 100644 index fc96382..0000000 --- a/src/lua/mod.rs +++ /dev/null @@ -1,901 +0,0 @@ -use rlua::{Function, Lua, Table, Value}; -use std::collections::HashMap; -use std::collections::hash_map::DefaultHasher; -use std::fs; -use std::hash::{Hash, Hasher}; -use std::path::Path; -use std::sync::{Arc, Mutex, RwLock}; -use std::time::{SystemTime, UNIX_EPOCH}; - -// Event types for the Lua scripting system -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum EventType { - Connection, // when a new connection is established - Request, // when a request is received - ResponseGen, // when generating a response - ResponseChunk, // before sending each response chunk - Disconnection, // when a connection is closed - BlockIP, // when an IP is being considered for blocking - Startup, // when the application starts - Shutdown, // when the application is shutting down - Periodic, // called periodically (e.g., every minute) -} - -impl EventType { - /// Convert event type to string representation for Lua - const fn as_str(&self) -> &'static str { - match self { - Self::Connection => "connection", - Self::Request => "request", - Self::ResponseGen => "response_gen", - Self::ResponseChunk => "response_chunk", - Self::Disconnection => "disconnection", - Self::BlockIP => "block_ip", - Self::Startup => "startup", - Self::Shutdown => "shutdown", - Self::Periodic => "periodic", - } - } - - /// Convert from string to `EventType` - fn from_str(s: &str) -> Option { - match s { - "connection" => Some(Self::Connection), - "request" => Some(Self::Request), - "response_gen" => Some(Self::ResponseGen), - "response_chunk" => Some(Self::ResponseChunk), - "disconnection" => Some(Self::Disconnection), - "block_ip" => Some(Self::BlockIP), - "startup" => Some(Self::Startup), - "shutdown" => Some(Self::Shutdown), - "periodic" => Some(Self::Periodic), - _ => None, - } - } -} - -// Loaded Lua script with its metadata -struct ScriptInfo { - name: String, - enabled: bool, -} - -// Script state and manage the Lua environment -pub struct ScriptManager { - lua: Mutex, - scripts: Vec, - hooks: HashMap>, - state: Arc>>, - counters: Arc>>, -} - -// Context passed to Lua event handlers -pub struct EventContext { - pub event_type: EventType, - pub ip: Option, - pub path: Option, - pub user_agent: Option, - pub request_headers: Option>, - pub content: Option, - pub timestamp: u64, - pub session_id: Option, -} - -// Make ScriptManager explicitly Send + Sync since we're using Mutex -unsafe impl Send for ScriptManager {} -unsafe impl Sync for ScriptManager {} - -impl ScriptManager { - /// Create a new script manager and load scripts from the given directory - pub fn new(scripts_dir: &str) -> Self { - let mut manager = Self { - lua: Mutex::new(Lua::new()), - scripts: Vec::new(), - hooks: HashMap::new(), - state: Arc::new(RwLock::new(HashMap::new())), - counters: Arc::new(RwLock::new(HashMap::new())), - }; - - // Initialize Lua environment with our API - manager.init_lua_env(); - - // Load scripts from directory - manager.load_scripts_from_dir(scripts_dir); - - // If no scripts were loaded, use default script - if manager.scripts.is_empty() { - log::info!("No Lua scripts found, loading default scripts"); - manager.load_script( - "default", - include_str!("../../resources/default_script.lua"), - ); - } - - // Trigger startup event - manager.trigger_event(&EventContext { - event_type: EventType::Startup, - ip: None, - path: None, - user_agent: None, - request_headers: None, - content: None, - timestamp: get_timestamp(), - session_id: None, - }); - - manager - } - - // Initialize the Lua environment - fn init_lua_env(&self) { - let state_clone = self.state.clone(); - let counters_clone = self.counters.clone(); - - if let Ok(lua) = self.lua.lock() { - // Create eris global table for our API - let eris_table = lua.create_table().unwrap(); - - self.register_utility_functions(&lua, &eris_table, state_clone, counters_clone); - self.register_event_functions(&lua, &eris_table); - self.register_logging_functions(&lua, &eris_table); - - // Set the eris global table - lua.globals().set("eris", eris_table).unwrap(); - } - } - - /// Register utility functions for scripts to use - fn register_utility_functions( - &self, - lua: &Lua, - eris_table: &Table, - state: Arc>>, - counters: Arc>>, - ) { - // Store a key-value pair in persistent state - let state_for_set = state.clone(); - let set_state = lua - .create_function(move |_, (key, value): (String, String)| { - let mut state_map = state_for_set.write().unwrap(); - state_map.insert(key, value); - Ok(()) - }) - .unwrap(); - eris_table.set("set_state", set_state).unwrap(); - - // Get a value from persistent state - let state_for_get = state; - let get_state = lua - .create_function(move |_, key: String| { - let state_map = state_for_get.read().unwrap(); - let value = state_map.get(&key).cloned(); - Ok(value) - }) - .unwrap(); - eris_table.set("get_state", get_state).unwrap(); - - // Increment a counter - let counters_for_inc = counters.clone(); - let inc_counter = lua - .create_function(move |_, (key, amount): (String, Option)| { - let mut counters_map = counters_for_inc.write().unwrap(); - let counter = counters_map.entry(key).or_insert(0); - *counter += amount.unwrap_or(1); - Ok(*counter) - }) - .unwrap(); - eris_table.set("inc_counter", inc_counter).unwrap(); - - // Get a counter value - let counters_for_get = counters; - let get_counter = lua - .create_function(move |_, key: String| { - let counters_map = counters_for_get.read().unwrap(); - let value = counters_map.get(&key).copied().unwrap_or(0); - Ok(value) - }) - .unwrap(); - eris_table.set("get_counter", get_counter).unwrap(); - - // Generate a random token/string - let gen_token = lua - .create_function(move |_, prefix: Option| { - let now = get_timestamp(); - let random = rand::random::(); - let token = format!("{}{:x}{:x}", prefix.unwrap_or_default(), now, random); - Ok(token) - }) - .unwrap(); - eris_table.set("gen_token", gen_token).unwrap(); - - // Get current timestamp - let timestamp = lua - .create_function(move |_, ()| Ok(get_timestamp())) - .unwrap(); - eris_table.set("timestamp", timestamp).unwrap(); - } - - // Register event handling functions - fn register_event_functions(&self, lua: &Lua, eris_table: &Table) { - // Create a table to store event handlers - let handlers_table = lua.create_table().unwrap(); - eris_table.set("handlers", handlers_table).unwrap(); - - // Function for scripts to register event handlers - let on_fn = lua - .create_function(move |lua, (event_name, handler): (String, Function)| { - let globals = lua.globals(); - let eris: Table = globals.get("eris").unwrap(); - let handlers: Table = eris.get("handlers").unwrap(); - - // Get or create a table for this event type - let event_handlers: Table = if let Ok(table) = handlers.get(&*event_name) { - table - } else { - let new_table = lua.create_table().unwrap(); - handlers.set(&*event_name, new_table.clone()).unwrap(); - new_table - }; - - // Add the handler to the table - let next_index = event_handlers.len().unwrap() + 1; - event_handlers.set(next_index, handler).unwrap(); - - Ok(()) - }) - .unwrap(); - eris_table.set("on", on_fn).unwrap(); - } - - // Register logging functions - fn register_logging_functions(&self, lua: &Lua, eris_table: &Table) { - // Debug logging - let debug = lua - .create_function(|_, message: String| { - log::debug!("[Lua] {message}"); - Ok(()) - }) - .unwrap(); - eris_table.set("debug", debug).unwrap(); - - // Info logging - let info = lua - .create_function(|_, message: String| { - log::info!("[Lua] {message}"); - Ok(()) - }) - .unwrap(); - eris_table.set("info", info).unwrap(); - - // Warning logging - let warn = lua - .create_function(|_, message: String| { - log::warn!("[Lua] {message}"); - Ok(()) - }) - .unwrap(); - eris_table.set("warn", warn).unwrap(); - - // Error logging - let error = lua - .create_function(|_, message: String| { - log::error!("[Lua] {message}"); - Ok(()) - }) - .unwrap(); - eris_table.set("error", error).unwrap(); - } - - // Load all scripts from a directory - fn load_scripts_from_dir(&mut self, scripts_dir: &str) { - let script_dir = Path::new(scripts_dir); - if !script_dir.exists() { - log::warn!("Lua scripts directory does not exist: {scripts_dir}"); - return; - } - - log::debug!("Loading Lua scripts from directory: {scripts_dir}"); - if let Ok(entries) = fs::read_dir(script_dir) { - // Sort entries by filename to ensure consistent loading order - let mut sorted_entries: Vec<_> = entries.filter_map(Result::ok).collect(); - sorted_entries.sort_by_key(std::fs::DirEntry::path); - - for entry in sorted_entries { - let path = entry.path(); - if path.extension().and_then(|ext| ext.to_str()) == Some("lua") { - if let Ok(content) = fs::read_to_string(&path) { - let script_name = path - .file_stem() - .and_then(|n| n.to_str()) - .unwrap_or("unknown") - .to_string(); - - log::debug!("Loading Lua script: {} ({})", script_name, path.display()); - self.load_script(&script_name, &content); - } else { - log::warn!("Failed to read Lua script: {}", path.display()); - } - } - } - } - } - - // Load a single script and register its event handlers - fn load_script(&mut self, name: &str, content: &str) { - // Store script info - self.scripts.push(ScriptInfo { - name: name.to_string(), - enabled: true, - }); - - // Execute the script to register its event handlers - if let Ok(lua) = self.lua.lock() { - if let Err(e) = lua.load(content).set_name(name).exec() { - log::warn!("Error loading Lua script '{name}': {e}"); - return; - } - - // Collect registered event handlers - let globals = lua.globals(); - let eris: Table = match globals.get("eris") { - Ok(table) => table, - Err(_) => return, - }; - - let handlers: Table = match eris.get("handlers") { - Ok(table) => table, - Err(_) => return, - }; - - // Store the event handlers in our hooks map - let mut tmp: rlua::TablePairs<'_, String, Table<'_>> = - handlers.pairs::(); - 'l: loop { - if let Some(event_pair) = tmp.next() { - if let Ok((event_name, _)) = event_pair { - if let Some(event_type) = EventType::from_str(&event_name) { - self.hooks - .entry(event_type) - .or_default() - .push(name.to_string()); - } - } - } else { - break 'l; - } - } - - log::info!("Loaded Lua script '{name}' successfully"); - } - } - - /// Check if a script is enabled - fn is_script_enabled(&self, name: &str) -> bool { - self.scripts - .iter() - .find(|s| s.name == name) - .is_some_and(|s| s.enabled) - } - - /// Trigger an event, calling all registered handlers - pub fn trigger_event(&self, ctx: &EventContext) -> Option { - // Check if we have any handlers for this event - if !self.hooks.contains_key(&ctx.event_type) { - return ctx.content.clone(); - } - - // Build the event data table to pass to Lua handlers - let mut result = ctx.content.clone(); - - if let Ok(lua) = self.lua.lock() { - // Create the event context table - let event_ctx = lua.create_table().unwrap(); - - // Add all the context fields - event_ctx.set("event", ctx.event_type.as_str()).unwrap(); - if let Some(ip) = &ctx.ip { - event_ctx.set("ip", ip.clone()).unwrap(); - } - if let Some(path) = &ctx.path { - event_ctx.set("path", path.clone()).unwrap(); - } - if let Some(ua) = &ctx.user_agent { - event_ctx.set("user_agent", ua.clone()).unwrap(); - } - event_ctx.set("timestamp", ctx.timestamp).unwrap(); - if let Some(sid) = &ctx.session_id { - event_ctx.set("session_id", sid.clone()).unwrap(); - } - - // Add request headers if available - if let Some(headers) = &ctx.request_headers { - let headers_table = lua.create_table().unwrap(); - for (key, value) in headers { - headers_table - .set(key.to_string(), value.to_string()) - .unwrap(); - } - event_ctx.set("headers", headers_table).unwrap(); - } - - // Add content if available - if let Some(content) = &ctx.content { - event_ctx.set("content", content.clone()).unwrap(); - } - - // Call all registered handlers for this event - if let Some(handler_scripts) = self.hooks.get(&ctx.event_type) { - for script_name in handler_scripts { - // Skip disabled scripts - if !self.is_script_enabled(script_name) { - continue; - } - - // Get the globals and handlers table - let globals = lua.globals(); - let eris: Table = match globals.get("eris") { - Ok(table) => table, - Err(_) => continue, - }; - - let handlers: Table = match eris.get("handlers") { - Ok(table) => table, - Err(_) => continue, - }; - - // Get handlers for this event - let event_handlers: Table = match handlers.get(ctx.event_type.as_str()) { - Ok(table) => table, - Err(_) => continue, - }; - - // Call each handler - for pair in event_handlers.pairs::() { - if let Ok((_, handler)) = pair { - let handler_result: rlua::Result> = - handler.call((event_ctx.clone(),)); - if let Ok(Some(new_content)) = handler_result { - // For response events, allow handlers to modify the content - if matches!( - ctx.event_type, - EventType::ResponseGen | EventType::ResponseChunk - ) { - result = Some(new_content); - } - } - } - } - } - } - } - - result - } - - /// Generate a deceptive response, calling all `response_gen` handlers - pub fn generate_response( - &self, - path: &str, - user_agent: &str, - ip: &str, - headers: &HashMap, - markov_text: &str, - ) -> String { - // Create event context - let ctx = EventContext { - event_type: EventType::ResponseGen, - ip: Some(ip.to_string()), - path: Some(path.to_string()), - user_agent: Some(user_agent.to_string()), - request_headers: Some(headers.clone()), - content: Some(markov_text.to_string()), - timestamp: get_timestamp(), - session_id: Some(generate_session_id(ip, user_agent)), - }; - - /// Trigger the event and get the modified content - self.trigger_event(&ctx).unwrap_or_else(|| { - // Fallback to maintain backward compatibility - self.expand_response( - markov_text, - "generic", - path, - &generate_session_id(ip, user_agent), - ) - }) - } - - /// Process a chunk before sending it to client - pub fn process_chunk(&self, chunk: &str, ip: &str, session_id: &str) -> String { - let ctx = EventContext { - event_type: EventType::ResponseChunk, - ip: Some(ip.to_string()), - path: None, - user_agent: None, - request_headers: None, - content: Some(chunk.to_string()), - timestamp: get_timestamp(), - session_id: Some(session_id.to_string()), - }; - - self.trigger_event(&ctx) - .unwrap_or_else(|| chunk.to_string()) - } - - /// Called when a connection is established - pub fn on_connection(&self, ip: &str) -> bool { - let ctx = EventContext { - event_type: EventType::Connection, - ip: Some(ip.to_string()), - path: None, - user_agent: None, - request_headers: None, - content: None, - timestamp: get_timestamp(), - session_id: None, - }; - - // If any handler returns false, reject the connection - let mut should_accept = true; - - if let Ok(lua) = self.lua.lock() { - if let Some(handler_scripts) = self.hooks.get(&EventType::Connection) { - for script_name in handler_scripts { - // Skip disabled scripts - if !self.is_script_enabled(script_name) { - continue; - } - - let globals = lua.globals(); - let eris: Table = match globals.get("eris") { - Ok(table) => table, - Err(_) => continue, - }; - - let handlers: Table = match eris.get("handlers") { - Ok(table) => table, - Err(_) => continue, - }; - - let event_handlers: Table = match handlers.get("connection") { - Ok(table) => table, - Err(_) => continue, - }; - - for pair in event_handlers.pairs::() { - if let Ok((_, handler)) = pair { - let event_ctx = create_event_context(&lua, &ctx); - if let Ok(result) = handler.call::<_, Value>((event_ctx,)) { - if result == Value::Boolean(false) { - should_accept = false; - break; - } - } - } - } - - if !should_accept { - break; - } - } - } - } - - should_accept - } - - /// Called when deciding whether to block an IP - pub fn should_block_ip(&self, ip: &str, hit_count: u32) -> bool { - let ctx = EventContext { - event_type: EventType::BlockIP, - ip: Some(ip.to_string()), - path: None, - user_agent: None, - request_headers: None, - content: None, - timestamp: get_timestamp(), - session_id: None, - }; - - // We should default to not modifying the blocking decision - let mut should_block = None; - - if let Ok(lua) = self.lua.lock() { - if let Some(handler_scripts) = self.hooks.get(&EventType::BlockIP) { - for script_name in handler_scripts { - // Skip disabled scripts - if !self.is_script_enabled(script_name) { - continue; - } - - let globals = lua.globals(); - let eris: Table = match globals.get("eris") { - Ok(table) => table, - Err(_) => continue, - }; - - let handlers: Table = match eris.get("handlers") { - Ok(table) => table, - Err(_) => continue, - }; - - let event_handlers: Table = match handlers.get("block_ip") { - Ok(table) => table, - Err(_) => continue, - }; - - for pair in event_handlers.pairs::() { - if let Ok((_, handler)) = pair { - let event_ctx = create_event_context(&lua, &ctx); - // Add hit count for the block_ip event - event_ctx.set("hit_count", hit_count).unwrap(); - - if let Ok(result) = handler.call::<_, Value>((event_ctx,)) { - if let Value::Boolean(block) = result { - should_block = Some(block); - break; - } - } - } - } - - if should_block.is_some() { - break; - } - } - } - } - - // Return the script's decision, or default to the system behavior - should_block.unwrap_or(hit_count >= 3) - } - - // Maintains backward compatibility with the old API - // XXX: I never liked expand_response, should probably be removeedf - // in the future. - pub fn expand_response( - &self, - text: &str, - response_type: &str, - path: &str, - token: &str, - ) -> String { - if let Ok(lua) = self.lua.lock() { - let globals = lua.globals(); - match globals.get::<_, Function>("enhance_response") { - Ok(enhance_func) => { - match enhance_func.call::<_, String>((text, response_type, path, token)) { - Ok(result) => result, - Err(e) => { - log::warn!("Error calling Lua function enhance_response: {e}"); - format!("{text}\n") - } - } - } - Err(_) => format!("{text}\n"), - } - } else { - format!("{text}\n") - } - } -} - -/// Get current timestamp in seconds -fn get_timestamp() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs() -} - -/// Create a unique session ID for tracking a connection -fn generate_session_id(ip: &str, user_agent: &str) -> String { - let timestamp = get_timestamp(); - let random = rand::random::(); - - // Use std::hash instead of xxhash_rust - let mut hasher = DefaultHasher::new(); - format!("{ip}_{user_agent}_{timestamp}").hash(&mut hasher); - let hash = hasher.finish(); - - format!("SID_{hash:x}_{random:x}") -} - -// Create an event context table in Lua -fn create_event_context<'a>(lua: &'a Lua, event_ctx: &EventContext) -> Table<'a> { - let table = lua.create_table().unwrap(); - - table.set("event", event_ctx.event_type.as_str()).unwrap(); - if let Some(ip) = &event_ctx.ip { - table.set("ip", ip.clone()).unwrap(); - } - if let Some(path) = &event_ctx.path { - table.set("path", path.clone()).unwrap(); - } - if let Some(ua) = &event_ctx.user_agent { - table.set("user_agent", ua.clone()).unwrap(); - } - table.set("timestamp", event_ctx.timestamp).unwrap(); - if let Some(sid) = &event_ctx.session_id { - table.set("session_id", sid.clone()).unwrap(); - } - if let Some(content) = &event_ctx.content { - table.set("content", content.clone()).unwrap(); - } - - table -} - -#[cfg(test)] -mod tests { - use super::*; - use std::fs; - - use tempfile::TempDir; - - #[test] - fn test_event_registration() { - let temp_dir = TempDir::new().unwrap(); - let script_path = temp_dir.path().join("test_events.lua"); - let script_content = r#" - -- Example script with event handlers - eris.info("Registering event handlers") - - -- Connection event handler - eris.on("connection", function(ctx) - eris.debug("Connection from " .. ctx.ip) - return true -- accept the connection - end) - - -- Response generation handler - eris.on("response_gen", function(ctx) - eris.debug("Generating response for " .. ctx.path) - return ctx.content .. "" - end) - "#; - - fs::write(&script_path, script_content).unwrap(); - - let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); - - // Verify hooks were registered - assert!(script_manager.hooks.contains_key(&EventType::Connection)); - assert!(script_manager.hooks.contains_key(&EventType::ResponseGen)); - assert!(!script_manager.hooks.contains_key(&EventType::BlockIP)); - } - - #[test] - fn test_generate_response() { - let temp_dir = TempDir::new().unwrap(); - let script_path = temp_dir.path().join("response_test.lua"); - let script_content = r#" - eris.on("response_gen", function(ctx) - return ctx.content .. " - Modified by " .. ctx.user_agent - end) - "#; - - fs::write(&script_path, script_content).unwrap(); - - let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); - - let headers = HashMap::new(); - let result = script_manager.generate_response( - "/test/path", - "TestBot", - "127.0.0.1", - &headers, - "Original content", - ); - - assert!(result.contains("Original content")); - assert!(result.contains("Modified by TestBot")); - } - - #[test] - fn test_process_chunk() { - let temp_dir = TempDir::new().unwrap(); - let script_path = temp_dir.path().join("chunk_test.lua"); - let script_content = r#" - eris.on("response_chunk", function(ctx) - return ctx.content:gsub("secret", "REDACTED") - end) - "#; - - fs::write(&script_path, script_content).unwrap(); - - let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); - - let result = script_manager.process_chunk( - "This contains a secret password", - "127.0.0.1", - "test_session", - ); - - assert!(result.contains("This contains a REDACTED password")); - } - - #[test] - fn test_should_block_ip() { - let temp_dir = TempDir::new().unwrap(); - let script_path = temp_dir.path().join("block_test.lua"); - let script_content = r#" - eris.on("block_ip", function(ctx) - -- Block any IP with "192.168.1" prefix regardless of hit count - if string.match(ctx.ip, "^192%.168%.1%.") then - return true - end - - -- Don't block IPs with "10.0" prefix even if they hit the threshold - if string.match(ctx.ip, "^10%.0%.") then - return false - end - - -- Default behavior for other IPs (nil = use system default) - return nil - end) - "#; - - fs::write(&script_path, script_content).unwrap(); - - let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); - - // Should be blocked based on IP pattern - assert!(script_manager.should_block_ip("192.168.1.50", 1)); - - // Should not be blocked despite high hit count - assert!(!script_manager.should_block_ip("10.0.0.5", 10)); - - // Should use default behavior (block if >= 3 hits) - assert!(!script_manager.should_block_ip("172.16.0.1", 2)); - assert!(script_manager.should_block_ip("172.16.0.1", 3)); - } - - #[test] - fn test_state_and_counters() { - let temp_dir = TempDir::new().unwrap(); - let script_path = temp_dir.path().join("state_test.lua"); - let script_content = r#" - eris.on("startup", function(ctx) - eris.set_state("test_key", "test_value") - eris.inc_counter("visits", 0) - end) - - eris.on("connection", function(ctx) - local count = eris.inc_counter("visits") - eris.debug("Visit count: " .. count) - - -- Store last visitor - eris.set_state("last_visitor", ctx.ip) - return true - end) - - eris.on("response_gen", function(ctx) - local last_visitor = eris.get_state("last_visitor") or "unknown" - local visits = eris.get_counter("visits") - return ctx.content .. "" - end) - "#; - - fs::write(&script_path, script_content).unwrap(); - - let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); - - // Simulate connections - script_manager.on_connection("192.168.1.100"); - script_manager.on_connection("10.0.0.50"); - - // Check response includes state - let headers = HashMap::new(); - let result = script_manager.generate_response( - "/test", - "test-agent", - "8.8.8.8", - &headers, - "Response", - ); - - assert!(result.contains("Last visitor: 10.0.0.50")); - assert!(result.contains("Total visits: 2")); - } -} diff --git a/src/main.rs b/src/main.rs index c8d80a6..2471ced 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,106 +1,1067 @@ use actix_web::{App, HttpResponse, HttpServer, web}; use clap::Parser; +use ipnetwork::IpNetwork; +use rlua::{Function, Lua}; +use serde::{Deserialize, Serialize}; +use std::collections::{HashMap, HashSet}; +use std::env; use std::fs; -use std::path::Path; +use std::io::Write; +use std::net::IpAddr; +use std::path::{Path, PathBuf}; use std::sync::Arc; -use std::time::Duration; -use tokio::net::TcpListener; +use std::time::{Duration, Instant}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::process::Command; use tokio::sync::RwLock; +use tokio::time::sleep; -mod config; -mod firewall; -mod lua; mod markov; mod metrics; -mod network; -mod state; -mod utils; -use config::{Args, Config}; -use lua::{EventContext, EventType, ScriptManager}; use markov::MarkovGenerator; -use metrics::{metrics_handler, status_handler}; -use network::handle_connection; -use state::BotState; -use utils::get_timestamp; +use metrics::{ + ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, UA_HITS, metrics_handler, + status_handler, +}; + +// Command-line arguments using clap +#[derive(Parser, Debug, Clone)] +#[clap( + author, + version, + about = "Markov chain based HTTP tarpit/honeypot that delays and tracks potential attackers" +)] +struct Args { + #[clap( + long, + default_value = "0.0.0.0:8888", + help = "Address and port to listen for incoming HTTP requests (format: ip:port)" + )] + listen_addr: String, + + #[clap( + long, + default_value = "0.0.0.0:9100", + help = "Address and port to expose Prometheus metrics and status endpoint (format: ip:port)" + )] + metrics_addr: String, + + #[clap(long, help = "Disable Prometheus metrics server completely")] + disable_metrics: bool, + + #[clap( + long, + default_value = "127.0.0.1:80", + help = "Backend server address to proxy legitimate requests to (format: ip:port)" + )] + backend_addr: String, + + #[clap( + long, + default_value = "1000", + help = "Minimum delay in milliseconds between chunks sent to attacker" + )] + min_delay: u64, + + #[clap( + long, + default_value = "15000", + help = "Maximum delay in milliseconds between chunks sent to attacker" + )] + max_delay: u64, + + #[clap( + long, + default_value = "600", + help = "Maximum time in seconds to keep an attacker in the tarpit before disconnecting" + )] + max_tarpit_time: u64, + + #[clap( + long, + default_value = "3", + help = "Number of hits to honeypot patterns before permanently blocking an IP" + )] + block_threshold: u32, + + #[clap( + long, + help = "Base directory for all application data (overrides XDG directory structure)" + )] + base_dir: Option, + + #[clap( + long, + help = "Path to JSON configuration file (overrides command line options)" + )] + config_file: Option, + + #[clap( + long, + default_value = "info", + help = "Log level: trace, debug, info, warn, error" + )] + log_level: String, +} + +// Configuration structure +#[derive(Clone, Debug, Deserialize, Serialize)] +struct Config { + listen_addr: String, + metrics_addr: String, + disable_metrics: bool, + backend_addr: String, + min_delay: u64, + max_delay: u64, + max_tarpit_time: u64, + block_threshold: u32, + trap_patterns: Vec, + whitelist_networks: Vec, + markov_corpora_dir: String, + lua_scripts_dir: String, + data_dir: String, + config_dir: String, + cache_dir: String, +} + +impl Default for Config { + fn default() -> Self { + Self { + listen_addr: "0.0.0.0:8888".to_string(), + metrics_addr: "0.0.0.0:9100".to_string(), + disable_metrics: false, + backend_addr: "127.0.0.1:80".to_string(), + min_delay: 1000, + max_delay: 15000, + max_tarpit_time: 600, + block_threshold: 3, + trap_patterns: vec![ + "/vendor/phpunit".to_string(), + "eval-stdin.php".to_string(), + "/wp-admin".to_string(), + "/wp-login.php".to_string(), + "/xmlrpc.php".to_string(), + "/phpMyAdmin".to_string(), + "/solr/".to_string(), + "/.env".to_string(), + "/config".to_string(), + "/api/".to_string(), + "/actuator/".to_string(), + ], + whitelist_networks: vec![ + "192.168.0.0/16".to_string(), + "10.0.0.0/8".to_string(), + "172.16.0.0/12".to_string(), + "127.0.0.0/8".to_string(), + ], + markov_corpora_dir: "./corpora".to_string(), + lua_scripts_dir: "./scripts".to_string(), + data_dir: "./data".to_string(), + config_dir: "./conf".to_string(), + cache_dir: "./cache".to_string(), + } + } +} + +// Gets standard XDG directory paths for config, data and cache +fn get_xdg_dirs() -> (PathBuf, PathBuf, PathBuf) { + let config_home = env::var_os("XDG_CONFIG_HOME") + .map(PathBuf::from) + .unwrap_or_else(|| { + let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); + home.join(".config") + }); + + let data_home = env::var_os("XDG_DATA_HOME") + .map(PathBuf::from) + .unwrap_or_else(|| { + let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); + home.join(".local").join("share") + }); + + let cache_home = env::var_os("XDG_CACHE_HOME") + .map(PathBuf::from) + .unwrap_or_else(|| { + let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); + home.join(".cache") + }); + + let config_dir = config_home.join("eris"); + let data_dir = data_home.join("eris"); + let cache_dir = cache_home.join("eris"); + + (config_dir, data_dir, cache_dir) +} + +impl Config { + // Create configuration from command-line args + fn from_args(args: &Args) -> Self { + let (config_dir, data_dir, cache_dir) = if let Some(base_dir) = &args.base_dir { + let base_str = base_dir.to_string_lossy().to_string(); + ( + format!("{base_str}/conf"), + format!("{base_str}/data"), + format!("{base_str}/cache"), + ) + } else { + let (c, d, cache) = get_xdg_dirs(); + ( + c.to_string_lossy().to_string(), + d.to_string_lossy().to_string(), + cache.to_string_lossy().to_string(), + ) + }; + + Self { + listen_addr: args.listen_addr.clone(), + metrics_addr: args.metrics_addr.clone(), + disable_metrics: args.disable_metrics, + backend_addr: args.backend_addr.clone(), + min_delay: args.min_delay, + max_delay: args.max_delay, + max_tarpit_time: args.max_tarpit_time, + block_threshold: args.block_threshold, + markov_corpora_dir: format!("{data_dir}/corpora"), + lua_scripts_dir: format!("{data_dir}/scripts"), + data_dir, + config_dir, + cache_dir, + ..Default::default() + } + } + + // Load configuration from a JSON file + fn load_from_file(path: &Path) -> std::io::Result { + let content = fs::read_to_string(path)?; + let config = serde_json::from_str(&content)?; + Ok(config) + } + + // Save configuration to a JSON file + fn save_to_file(&self, path: &Path) -> std::io::Result<()> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + let content = serde_json::to_string_pretty(self)?; + fs::write(path, content)?; + Ok(()) + } + + // Create required directories if they don't exist + fn ensure_dirs_exist(&self) -> std::io::Result<()> { + let dirs = [ + &self.markov_corpora_dir, + &self.lua_scripts_dir, + &self.data_dir, + &self.config_dir, + &self.cache_dir, + ]; + + for dir in dirs { + fs::create_dir_all(dir)?; + log::debug!("Created directory: {dir}"); + } + + Ok(()) + } +} + +// State of bots/IPs hitting the honeypot +#[derive(Clone, Debug)] +struct BotState { + hits: HashMap, + blocked: HashSet, + active_connections: HashSet, + data_dir: String, + cache_dir: String, +} + +impl BotState { + fn new(data_dir: &str, cache_dir: &str) -> Self { + Self { + hits: HashMap::new(), + blocked: HashSet::new(), + active_connections: HashSet::new(), + data_dir: data_dir.to_string(), + cache_dir: cache_dir.to_string(), + } + } + + // Load previous state from disk + fn load_from_disk(data_dir: &str, cache_dir: &str) -> Self { + let mut state = Self::new(data_dir, cache_dir); + let blocked_ips_file = format!("{data_dir}/blocked_ips.txt"); + + if let Ok(content) = fs::read_to_string(&blocked_ips_file) { + let mut loaded = 0; + for line in content.lines() { + if let Ok(ip) = line.parse::() { + state.blocked.insert(ip); + loaded += 1; + } + } + log::info!("Loaded {loaded} blocked IPs from {blocked_ips_file}"); + } else { + log::info!("No blocked IPs file found at {blocked_ips_file}"); + } + + // Check for temporary hit counter cache + let hit_cache_file = format!("{cache_dir}/hit_counters.json"); + if let Ok(content) = fs::read_to_string(&hit_cache_file) { + if let Ok(hit_map) = serde_json::from_str::>(&content) { + for (ip_str, count) in hit_map { + if let Ok(ip) = ip_str.parse::() { + state.hits.insert(ip, count); + } + } + log::info!("Loaded hit counters for {} IPs", state.hits.len()); + } + } + + BLOCKED_IPS.set(state.blocked.len() as f64); + state + } + + // Persist state to disk for later reloading + fn save_to_disk(&self) { + // Save blocked IPs + if let Err(e) = fs::create_dir_all(&self.data_dir) { + log::error!("Failed to create data directory: {e}"); + return; + } + + let blocked_ips_file = format!("{}/blocked_ips.txt", self.data_dir); + + match fs::File::create(&blocked_ips_file) { + Ok(mut file) => { + let mut count = 0; + for ip in &self.blocked { + if writeln!(file, "{ip}").is_ok() { + count += 1; + } + } + log::info!("Saved {count} blocked IPs to {blocked_ips_file}"); + } + Err(e) => { + log::error!("Failed to create blocked IPs file: {e}"); + } + } + + // Save hit counters to cache + if let Err(e) = fs::create_dir_all(&self.cache_dir) { + log::error!("Failed to create cache directory: {e}"); + return; + } + + let hit_cache_file = format!("{}/hit_counters.json", self.cache_dir); + let mut hit_map = HashMap::new(); + for (ip, count) in &self.hits { + hit_map.insert(ip.to_string(), *count); + } + + match fs::File::create(&hit_cache_file) { + Ok(file) => { + if let Err(e) = serde_json::to_writer(file, &hit_map) { + log::error!("Failed to write hit counters to cache: {e}"); + } else { + log::debug!("Saved hit counters for {} IPs to cache", hit_map.len()); + } + } + Err(e) => { + log::error!("Failed to create hit counter cache file: {e}"); + } + } + } +} + +// Lua scripts for response generation and customization +struct ScriptManager { + script_content: String, + scripts_loaded: bool, +} + +impl ScriptManager { + fn new(scripts_dir: &str) -> Self { + let mut script_content = String::new(); + let mut scripts_loaded = false; + + // Try to load scripts from directory + let script_dir = Path::new(scripts_dir); + if script_dir.exists() { + log::debug!("Loading Lua scripts from directory: {scripts_dir}"); + if let Ok(entries) = fs::read_dir(script_dir) { + for entry in entries { + if let Ok(entry) = entry { + let path = entry.path(); + if path.extension().and_then(|ext| ext.to_str()) == Some("lua") { + if let Ok(content) = fs::read_to_string(&path) { + log::debug!("Loaded Lua script: {}", path.display()); + script_content.push_str(&content); + script_content.push('\n'); + scripts_loaded = true; + } else { + log::warn!("Failed to read Lua script: {}", path.display()); + } + } + } + } + } + } else { + log::warn!("Lua scripts directory does not exist: {scripts_dir}"); + } + + // If no scripts were loaded, use a default script + if !scripts_loaded { + log::info!("No Lua scripts found, loading default scripts"); + script_content = r#" + function generate_honeytoken(token) + local token_types = {"API_KEY", "AUTH_TOKEN", "SESSION_ID", "SECRET_KEY"} + local prefix = token_types[math.random(#token_types)] + local suffix = string.format("%08x", math.random(0xffffff)) + return prefix .. "_" .. token .. "_" .. suffix + end + + function enhance_response(text, response_type, path, token) + local result = text + local honeytoken = generate_honeytoken(token) + + -- Add some fake sensitive data + result = result .. "\n" + result = result .. "\n
Server ID: " .. token .. "
" + + return result + end + "# + .to_string(); + scripts_loaded = true; + } + + Self { + script_content, + scripts_loaded, + } + } + + // Lua is a powerful configuration language we can use to expand functionality of + // Eris, e.g., with fake tokens or honeytrap content. + fn expand_response(&self, text: &str, response_type: &str, path: &str, token: &str) -> String { + if !self.scripts_loaded { + return format!("{text}\n"); + } + + let lua = Lua::new(); + if let Err(e) = lua.load(&self.script_content).exec() { + log::warn!("Error loading Lua script: {e}"); + return format!("{text}\n"); + } + + let globals = lua.globals(); + match globals.get::<_, Function>("enhance_response") { + Ok(enhance_func) => { + match enhance_func.call::<_, String>((text, response_type, path, token)) { + Ok(result) => result, + Err(e) => { + log::warn!("Error calling Lua function enhance_response: {e}"); + format!("{text}\n") + } + } + } + Err(e) => { + log::warn!("Lua enhance_response function not found: {e}"); + format!("{text}\n") + } + } + } +} + +// Main connection handler - decides whether to tarpit or proxy +async fn handle_connection( + mut stream: TcpStream, + config: Arc, + state: Arc>, + markov_generator: Arc, + script_manager: Arc, +) { + let peer_addr = match stream.peer_addr() { + Ok(addr) => addr.ip(), + Err(e) => { + log::debug!("Failed to get peer address: {e}"); + return; + } + }; + + log::debug!("New connection from: {peer_addr}"); + + // Check if IP is already blocked + if state.read().await.blocked.contains(&peer_addr) { + log::debug!("Rejected connection from blocked IP: {peer_addr}"); + let _ = stream.shutdown().await; + return; + } + + // Read the HTTP request + let mut buffer = [0; 8192]; + let mut request_data = Vec::new(); + + // Read with timeout to prevent hanging + let read_fut = async { + loop { + match stream.read(&mut buffer).await { + Ok(0) => break, + Ok(n) => { + request_data.extend_from_slice(&buffer[..n]); + // Stop reading at empty line, this is the end of HTTP headers + if request_data.len() > 2 && &request_data[request_data.len() - 2..] == b"\r\n" + { + break; + } + } + Err(e) => { + log::debug!("Error reading from stream: {e}"); + break; + } + } + } + }; + + let timeout_fut = sleep(Duration::from_secs(5)); + + tokio::select! { + () = read_fut => {}, + () = timeout_fut => { + log::debug!("Connection timeout from: {peer_addr}"); + let _ = stream.shutdown().await; + return; + } + } + + // Parse the request + let request_str = String::from_utf8_lossy(&request_data); + let request_lines: Vec<&str> = request_str.lines().collect(); + + if request_lines.is_empty() { + log::debug!("Empty request from: {peer_addr}"); + let _ = stream.shutdown().await; + return; + } + + // Parse request line + let request_parts: Vec<&str> = request_lines[0].split_whitespace().collect(); + if request_parts.len() < 3 { + log::debug!("Malformed request from {}: {}", peer_addr, request_lines[0]); + let _ = stream.shutdown().await; + return; + } + + let method = request_parts[0]; + let path = request_parts[1]; + let protocol = request_parts[2]; + + log::debug!("Request: {method} {path} {protocol} from {peer_addr}"); + + // Parse headers + let mut headers = HashMap::new(); + for line in &request_lines[1..] { + if line.is_empty() { + break; + } + + if let Some(idx) = line.find(':') { + let key = line[..idx].trim(); + let value = line[idx + 1..].trim(); + headers.insert(key, value.to_string()); + } + } + + let user_agent = headers + .get("user-agent") + .cloned() + .unwrap_or_else(|| "unknown".to_string()); + + // Check if this request matches our tarpit patterns + let should_tarpit = should_tarpit(path, &peer_addr, &config).await; + + if should_tarpit { + log::info!("Tarpit triggered: {method} {path} from {peer_addr} (UA: {user_agent})"); + + // Update metrics + HITS_COUNTER.inc(); + PATH_HITS.with_label_values(&[path]).inc(); + UA_HITS.with_label_values(&[&user_agent]).inc(); + + // Update state and check for blocking threshold + { + let mut state = state.write().await; + state.active_connections.insert(peer_addr); + ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64); + + *state.hits.entry(peer_addr).or_insert(0) += 1; + let hit_count = state.hits[&peer_addr]; + log::debug!("Hit count for {peer_addr}: {hit_count}"); + + // Block IPs that hit tarpits too many times + if hit_count >= config.block_threshold && !state.blocked.contains(&peer_addr) { + log::info!("Blocking IP {peer_addr} after {hit_count} hits"); + state.blocked.insert(peer_addr); + BLOCKED_IPS.set(state.blocked.len() as f64); + state.save_to_disk(); + + // Try to add to firewall + let peer_addr_str = peer_addr.to_string(); + tokio::spawn(async move { + log::debug!("Adding IP {peer_addr_str} to firewall blacklist"); + match Command::new("nft") + .args([ + "add", + "element", + "inet", + "filter", + "eris_blacklist", + "{", + &peer_addr_str, + "}", + ]) + .output() + .await + { + Ok(output) => { + if !output.status.success() { + log::warn!( + "Failed to add IP {} to firewall: {}", + peer_addr_str, + String::from_utf8_lossy(&output.stderr) + ); + } + } + Err(e) => { + log::warn!("Failed to execute nft command: {e}"); + } + } + }); + } + } + + // Generate a deceptive response using Markov chains and Lua + let response = + generate_deceptive_response(path, &user_agent, &markov_generator, &script_manager) + .await; + + // Send the response with the tarpit delay strategy + tarpit_connection( + stream, + response, + peer_addr, + state.clone(), + config.min_delay, + config.max_delay, + config.max_tarpit_time, + ) + .await; + } else { + log::debug!("Proxying request: {method} {path} from {peer_addr}"); + + // Proxy non-matching requests to the actual backend + proxy_to_backend( + stream, + method, + path, + protocol, + &headers, + &config.backend_addr, + ) + .await; + } +} + +// Determine if a request should be tarpitted based on path and IP +async fn should_tarpit(path: &str, ip: &IpAddr, config: &Config) -> bool { + // Don't tarpit whitelisted IPs (internal networks, etc) + for network_str in &config.whitelist_networks { + if let Ok(network) = network_str.parse::() { + if network.contains(*ip) { + log::debug!("IP {ip} is in whitelist network {network_str}"); + return false; + } + } + } + + // Check if the request path matches any of our trap patterns + for pattern in &config.trap_patterns { + if path.contains(pattern) { + log::debug!("Path '{path}' matches trap pattern '{pattern}'"); + return true; + } + } + + // No trap patterns matched + false +} + +// Generate a deceptive HTTP response that appears legitimate +async fn generate_deceptive_response( + path: &str, + user_agent: &str, + markov: &MarkovGenerator, + script_manager: &ScriptManager, +) -> String { + // Choose response type based on path to seem more realistic + let response_type = if path.contains("phpunit") || path.contains("eval") { + "php_exploit" + } else if path.contains("wp-") { + "wordpress" + } else if path.contains("api") { + "api" + } else { + "generic" + }; + + log::debug!("Generating {response_type} response for path: {path}"); + + // Generate tracking token for this interaction + let tracking_token = format!( + "BOT_{}_{}", + user_agent + .chars() + .filter(|c| c.is_alphanumeric()) + .collect::(), + chrono::Utc::now().timestamp() + ); + + // Generate base response using Markov chain text generator + let markov_text = markov.generate(response_type, 30); + + // Use Lua to enhance with honeytokens and other deceptive content + let enhanced = + script_manager.expand_response(&markov_text, response_type, path, &tracking_token); + + // Return full HTTP response with appropriate headers + format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nX-Powered-By: PHP/7.4.3\r\nConnection: keep-alive\r\n\r\n{enhanced}" + ) +} + +// Slowly feed a response to the client with random delays to waste attacker time +async fn tarpit_connection( + mut stream: TcpStream, + response: String, + peer_addr: IpAddr, + state: Arc>, + min_delay: u64, + max_delay: u64, + max_tarpit_time: u64, +) { + let start_time = Instant::now(); + let mut chars = response.chars().collect::>(); + + // Randomize the char order slightly to confuse automated tools + for i in (0..chars.len()).rev() { + if i > 0 && rand::random::() < 0.1 { + chars.swap(i, i - 1); + } + } + + log::debug!( + "Starting tarpit for {} with {} chars, min_delay={}ms, max_delay={}ms", + peer_addr, + chars.len(), + min_delay, + max_delay + ); + + let mut position = 0; + let mut chunks_sent = 0; + let mut total_delay = 0; + + // Send the response character by character with random delays + while position < chars.len() { + // Check if we've exceeded maximum tarpit time + let elapsed_secs = start_time.elapsed().as_secs(); + if elapsed_secs > max_tarpit_time { + log::info!("Tarpit maximum time ({max_tarpit_time} sec) reached for {peer_addr}"); + break; + } + + // Decide how many chars to send in this chunk (usually 1, sometimes more) + let chunk_size = if rand::random::() < 0.9 { + 1 + } else { + (rand::random::() * 3.0).floor() as usize + 1 + }; + + let end = (position + chunk_size).min(chars.len()); + let chunk: String = chars[position..end].iter().collect(); + + // Try to write chunk + if stream.write_all(chunk.as_bytes()).await.is_err() { + log::debug!("Connection closed by client during tarpit: {peer_addr}"); + break; + } + + if stream.flush().await.is_err() { + log::debug!("Failed to flush stream during tarpit: {peer_addr}"); + break; + } + + position = end; + chunks_sent += 1; + + // Apply random delay between min and max configured values + let delay_ms = (rand::random::() * (max_delay - min_delay) as f32) as u64 + min_delay; + total_delay += delay_ms; + sleep(Duration::from_millis(delay_ms)).await; + } + + log::debug!( + "Tarpit stats for {}: sent {} chunks, {}% of data, total delay {}ms over {}s", + peer_addr, + chunks_sent, + position * 100 / chars.len(), + total_delay, + start_time.elapsed().as_secs() + ); + + // Remove from active connections + if let Ok(mut state) = state.try_write() { + state.active_connections.remove(&peer_addr); + ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64); + } + + let _ = stream.shutdown().await; +} + +// Forward a legitimate request to the real backend server +async fn proxy_to_backend( + mut client_stream: TcpStream, + method: &str, + path: &str, + protocol: &str, + headers: &HashMap<&str, String>, + backend_addr: &str, +) { + // Connect to backend server + let server_stream = match TcpStream::connect(backend_addr).await { + Ok(stream) => stream, + Err(e) => { + log::warn!("Failed to connect to backend {backend_addr}: {e}"); + let _ = client_stream.shutdown().await; + return; + } + }; + + log::debug!("Connected to backend server at {backend_addr}"); + + // Forward the original request + let mut request = format!("{method} {path} {protocol}\r\n"); + for (key, value) in headers { + request.push_str(&format!("{key}: {value}\r\n")); + } + request.push_str("\r\n"); + + let mut server_stream = server_stream; + if server_stream.write_all(request.as_bytes()).await.is_err() { + log::debug!("Failed to write request to backend server"); + let _ = client_stream.shutdown().await; + return; + } + + // Set up bidirectional forwarding between client and backend + let (mut client_read, mut client_write) = client_stream.split(); + let (mut server_read, mut server_write) = server_stream.split(); + + // Client -> Server + let client_to_server = async { + let mut buf = [0; 8192]; + let mut bytes_forwarded = 0; + + loop { + match client_read.read(&mut buf).await { + Ok(0) => break, + Ok(n) => { + bytes_forwarded += n; + if server_write.write_all(&buf[..n]).await.is_err() { + break; + } + } + Err(_) => break, + } + } + + log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes"); + }; + + // Server -> Client + let server_to_client = async { + let mut buf = [0; 8192]; + let mut bytes_forwarded = 0; + + loop { + match server_read.read(&mut buf).await { + Ok(0) => break, + Ok(n) => { + bytes_forwarded += n; + if client_write.write_all(&buf[..n]).await.is_err() { + break; + } + } + Err(_) => break, + } + } + + log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes"); + }; + + // Run both directions concurrently + tokio::select! { + () = client_to_server => {}, + () = server_to_client => {}, + } + + log::debug!("Proxy connection completed"); +} + +// Set up nftables firewall rules for IP blocking +async fn setup_firewall() -> Result<(), String> { + log::info!("Setting up firewall rules"); + + // Check if nft command exists + let nft_exists = Command::new("which") + .arg("nft") + .output() + .await + .map(|output| output.status.success()) + .unwrap_or(false); + + if !nft_exists { + log::warn!("nft command not found. Firewall rules will not be set up."); + return Ok(()); + } + + // Create table if it doesn't exist + let output = Command::new("nft") + .args(["list", "table", "inet", "filter"]) + .output() + .await; + + match output { + Ok(output) => { + if !output.status.success() { + log::info!("Creating nftables table"); + let result = Command::new("nft") + .args(["create", "table", "inet", "filter"]) + .output() + .await; + + if let Err(e) = result { + return Err(format!("Failed to create nftables table: {e}")); + } + } + } + Err(e) => { + log::warn!("Failed to check if nftables table exists: {e}"); + log::info!("Will try to create it anyway"); + let result = Command::new("nft") + .args(["create", "table", "inet", "filter"]) + .output() + .await; + + if let Err(e) = result { + return Err(format!("Failed to create nftables table: {e}")); + } + } + } + + // Create blacklist set if it doesn't exist + let output = Command::new("nft") + .args(["list", "set", "inet", "filter", "eris_blacklist"]) + .output() + .await; + + match output { + Ok(output) => { + if !output.status.success() { + log::info!("Creating eris_blacklist set"); + let result = Command::new("nft") + .args([ + "create", + "set", + "inet", + "filter", + "eris_blacklist", + "{ type ipv4_addr; flags interval; }", + ]) + .output() + .await; + + if let Err(e) = result { + return Err(format!("Failed to create blacklist set: {e}")); + } + } + } + Err(e) => { + log::warn!("Failed to check if blacklist set exists: {e}"); + return Err(format!("Failed to check if blacklist set exists: {e}")); + } + } + + // Add rule to drop traffic from blacklisted IPs + let output = Command::new("nft") + .args(["list", "chain", "inet", "filter", "input"]) + .output() + .await; + + // Check if our rule already exists + match output { + Ok(output) => { + let rule_exists = String::from_utf8_lossy(&output.stdout) + .contains("ip saddr @eris_blacklist counter drop"); + + if !rule_exists { + log::info!("Adding drop rule for blacklisted IPs"); + let result = Command::new("nft") + .args([ + "add", + "rule", + "inet", + "filter", + "input", + "ip saddr @eris_blacklist", + "counter", + "drop", + ]) + .output() + .await; + + if let Err(e) = result { + return Err(format!("Failed to add firewall rule: {e}")); + } + } + } + Err(e) => { + log::warn!("Failed to check if firewall rule exists: {e}"); + return Err(format!("Failed to check if firewall rule exists: {e}")); + } + } + + log::info!("Firewall setup complete"); + Ok(()) +} #[actix_web::main] async fn main() -> std::io::Result<()> { // Parse command line arguments let args = Args::parse(); - // Determine log format from args - let log_format = match args.log_format.to_lowercase().as_str() { - "json" => config::LogFormat::Json, - "pretty-json" => config::LogFormat::PrettyJson, - "plain" => config::LogFormat::Plain, - _ => config::LogFormat::Pretty, - }; + // Initialize the logger + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(&args.log_level)) + .format_timestamp_millis() + .init(); - // Initialize the logger with proper formatting - let env = env_logger::Env::default().default_filter_or(&args.log_level); - let mut builder = env_logger::Builder::from_env(env); - - match log_format { - config::LogFormat::Plain => { - builder.format_timestamp_millis().init(); - } - config::LogFormat::Pretty => { - builder - .format(|buf, record| { - use std::io::Write; - let timestamp = chrono::Local::now().format("%Y-%m-%d %H:%M:%S%.3f"); - writeln!( - buf, - "[{}] {} [{}] {}", - timestamp, - record.level(), - record.target(), - record.args() - ) - }) - .init(); - } - config::LogFormat::Json => { - builder - .format(|buf, record| { - use std::io::Write; - let json = serde_json::json!({ - "timestamp": chrono::Local::now().to_rfc3339(), - "level": record.level().to_string(), - "target": record.target(), - "message": record.args().to_string(), - "module_path": record.module_path(), - "file": record.file(), - "line": record.line(), - }); - writeln!(buf, "{}", json) - }) - .init(); - } - config::LogFormat::PrettyJson => { - builder - .format(|buf, record| { - use std::io::Write; - let json = serde_json::json!({ - "timestamp": chrono::Local::now().to_rfc3339(), - "level": record.level().to_string(), - "target": record.target(), - "message": record.args().to_string(), - "module_path": record.module_path(), - "file": record.file(), - "line": record.line(), - }); - writeln!(buf, "{}", serde_json::to_string_pretty(&json).unwrap()) - }) - .init(); - } - } - - log::info!("Starting Eris tarpit system"); + log::info!("Starting eris tarpit system"); // Load configuration - let mut config = if let Some(config_path) = &args.config_file { + let config = if let Some(config_path) = &args.config_file { log::info!("Loading configuration from {config_path:?}"); match Config::load_from_file(config_path) { Ok(cfg) => { @@ -118,11 +1079,6 @@ async fn main() -> std::io::Result<()> { Config::from_args(&args) }; - // Log format from the command line needs to be preserved - if args.config_file.is_none() { - config.log_format = log_format; - } - // Ensure required directories exist match config.ensure_dirs_exist() { Ok(()) => log::info!("Directory setup completed"), @@ -137,23 +1093,12 @@ async fn main() -> std::io::Result<()> { if let Err(e) = fs::create_dir_all(&config.config_dir) { log::warn!("Failed to create config directory: {e}"); } else { - // Save both JSON and TOML versions of the config for user reference - let config_path_json = Path::new(&config.config_dir).join("config.json"); - let config_path_toml = Path::new(&config.config_dir).join("config.toml"); - - if !config_path_json.exists() { - if let Err(e) = config.save_to_file(&config_path_json) { - log::warn!("Failed to save JSON configuration: {e}"); + let config_path = Path::new(&config.config_dir).join("config.json"); + if !config_path.exists() { + if let Err(e) = config.save_to_file(&config_path) { + log::warn!("Failed to save default configuration: {e}"); } else { - log::info!("Saved JSON configuration to {config_path_json:?}"); - } - } - - if !config_path_toml.exists() { - if let Err(e) = config.save_to_file(&config_path_toml) { - log::warn!("Failed to save TOML configuration: {e}"); - } else { - log::info!("Saved TOML configuration to {config_path_toml:?}"); + log::info!("Saved default configuration to {config_path:?}"); } } } @@ -169,7 +1114,7 @@ async fn main() -> std::io::Result<()> { let config = Arc::new(config); // Setup firewall rules for IP blocking - match firewall::setup_firewall().await { + match setup_firewall().await { Ok(()) => {} Err(e) => { log::warn!("Failed to set up firewall rules: {e}"); @@ -194,8 +1139,6 @@ async fn main() -> std::io::Result<()> { // Initialize Lua script manager log::info!("Loading Lua scripts from {}", config.lua_scripts_dir); let script_manager = Arc::new(ScriptManager::new(&config.lua_scripts_dir)); - let script_manager_for_tarpit = script_manager.clone(); - let script_manager_for_periodic = script_manager.clone(); // Clone config for metrics server let metrics_config = config.clone(); @@ -220,7 +1163,7 @@ async fn main() -> std::io::Result<()> { let state_clone = tarpit_state.clone(); let markov_clone = markov_generator.clone(); - let script_manager_clone = script_manager_for_tarpit.clone(); + let script_manager_clone = script_manager.clone(); let config_clone = config.clone(); tokio::spawn(async move { @@ -275,27 +1218,6 @@ async fn main() -> std::io::Result<()> { } }; - // Setup periodic task runner for Lua scripts - tokio::spawn(async move { - let mut interval = tokio::time::interval(Duration::from_secs(60)); - loop { - interval.tick().await; - - // Trigger periodic event - let ctx = EventContext { - event_type: EventType::Periodic, - ip: None, - path: None, - user_agent: None, - request_headers: None, - content: None, - timestamp: get_timestamp(), - session_id: None, - }; - script_manager_for_periodic.trigger_event(&ctx); - } - }); - // Run both servers concurrently if metrics server is enabled if let Some(metrics_server) = metrics_server { tokio::select! { @@ -335,3 +1257,186 @@ async fn main() -> std::io::Result<()> { } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, Ipv4Addr}; + use tokio::sync::RwLock; + + #[test] + fn test_config_from_args() { + let args = Args { + listen_addr: "127.0.0.1:8080".to_string(), + metrics_addr: "127.0.0.1:9000".to_string(), + disable_metrics: true, + backend_addr: "127.0.0.1:8081".to_string(), + min_delay: 500, + max_delay: 10000, + max_tarpit_time: 300, + block_threshold: 5, + base_dir: Some(PathBuf::from("/tmp/eris")), + config_file: None, + log_level: "debug".to_string(), + }; + + let config = Config::from_args(&args); + assert_eq!(config.listen_addr, "127.0.0.1:8080"); + assert_eq!(config.metrics_addr, "127.0.0.1:9000"); + assert!(config.disable_metrics); + assert_eq!(config.backend_addr, "127.0.0.1:8081"); + assert_eq!(config.min_delay, 500); + assert_eq!(config.max_delay, 10000); + assert_eq!(config.max_tarpit_time, 300); + assert_eq!(config.block_threshold, 5); + assert_eq!(config.markov_corpora_dir, "/tmp/eris/data/corpora"); + assert_eq!(config.lua_scripts_dir, "/tmp/eris/data/scripts"); + assert_eq!(config.data_dir, "/tmp/eris/data"); + assert_eq!(config.config_dir, "/tmp/eris/conf"); + assert_eq!(config.cache_dir, "/tmp/eris/cache"); + } + + #[tokio::test] + async fn test_should_tarpit() { + let config = Config::default(); + + // Test trap patterns + assert!( + should_tarpit( + "/vendor/phpunit/whatever", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + ) + .await + ); + assert!( + should_tarpit( + "/wp-admin/login.php", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + ) + .await + ); + assert!(should_tarpit("/.env", &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), &config).await); + + // Test whitelist networks + assert!( + !should_tarpit( + "/wp-admin/login.php", + &IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + &config + ) + .await + ); + assert!( + !should_tarpit( + "/vendor/phpunit/whatever", + &IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), + &config + ) + .await + ); + + // Test legitimate paths + assert!( + !should_tarpit( + "/index.html", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + ) + .await + ); + assert!( + !should_tarpit( + "/images/logo.png", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + ) + .await + ); + } + + #[test] + fn test_script_manager_default_script() { + let script_manager = ScriptManager::new("/nonexistent_directory"); + assert!(script_manager.scripts_loaded); + assert!( + script_manager + .script_content + .contains("generate_honeytoken") + ); + assert!(script_manager.script_content.contains("enhance_response")); + } + + #[tokio::test] + async fn test_bot_state() { + let state = BotState::new("/tmp/eris_test", "/tmp/eris_test_cache"); + let ip1 = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)); + let ip2 = IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8)); + + let state = Arc::new(RwLock::new(state)); + + // Test hit counter + { + let mut state = state.write().await; + *state.hits.entry(ip1).or_insert(0) += 1; + *state.hits.entry(ip1).or_insert(0) += 1; + *state.hits.entry(ip2).or_insert(0) += 1; + + assert_eq!(*state.hits.get(&ip1).unwrap(), 2); + assert_eq!(*state.hits.get(&ip2).unwrap(), 1); + } + + // Test blocking + { + let mut state = state.write().await; + state.blocked.insert(ip1); + assert!(state.blocked.contains(&ip1)); + assert!(!state.blocked.contains(&ip2)); + } + + // Test active connections + { + let mut state = state.write().await; + state.active_connections.insert(ip1); + state.active_connections.insert(ip2); + assert_eq!(state.active_connections.len(), 2); + + state.active_connections.remove(&ip1); + assert_eq!(state.active_connections.len(), 1); + assert!(!state.active_connections.contains(&ip1)); + assert!(state.active_connections.contains(&ip2)); + } + } + + #[tokio::test] + async fn test_generate_deceptive_response() { + // Create a simple markov generator for testing + let markov = MarkovGenerator::new("/nonexistent/path"); + let script_manager = ScriptManager::new("/nonexistent/path"); + + // Test different path types + let resp1 = generate_deceptive_response( + "/vendor/phpunit/exec", + "TestBot/1.0", + &markov, + &script_manager, + ) + .await; + assert!(resp1.contains("HTTP/1.1 200 OK")); + assert!(resp1.contains("X-Powered-By: PHP")); + + let resp2 = + generate_deceptive_response("/wp-admin/", "TestBot/1.0", &markov, &script_manager) + .await; + assert!(resp2.contains("HTTP/1.1 200 OK")); + + let resp3 = + generate_deceptive_response("/api/users", "TestBot/1.0", &markov, &script_manager) + .await; + assert!(resp3.contains("HTTP/1.1 200 OK")); + + // Verify tracking token is included + assert!(resp1.contains("BOT_TestBot")); + } +} diff --git a/src/markov.rs b/src/markov.rs index 05960cb..b640dac 100644 --- a/src/markov.rs +++ b/src/markov.rs @@ -103,7 +103,7 @@ impl MarkovGenerator { let path = Path::new(corpus_dir); if path.exists() && path.is_dir() { if let Ok(entries) = fs::read_dir(path) { - entries.for_each(|entry| { + for entry in entries { if let Ok(entry) = entry { let file_path = entry.path(); if let Some(file_name) = file_path.file_stem() { @@ -120,7 +120,7 @@ impl MarkovGenerator { } } } - }); + } } } diff --git a/src/metrics.rs b/src/metrics.rs index 44c3fb4..48bceb9 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -24,11 +24,6 @@ lazy_static! { register_counter_vec!("eris_path_hits_total", "Hits by path", &["path"]).unwrap(); pub static ref UA_HITS: CounterVec = register_counter_vec!("eris_ua_hits_total", "Hits by user agent", &["user_agent"]).unwrap(); - pub static ref RATE_LIMITED_CONNECTIONS: Counter = register_counter!( - "eris_rate_limited_total", - "Number of connections rejected due to rate limiting" - ) - .unwrap(); } // Prometheus metrics endpoint @@ -81,7 +76,6 @@ mod tests { UA_HITS.with_label_values(&["TestBot/1.0"]).inc(); BLOCKED_IPS.set(5.0); ACTIVE_CONNECTIONS.set(3.0); - RATE_LIMITED_CONNECTIONS.inc(); // Create test app let app = @@ -102,7 +96,6 @@ mod tests { assert!(body_str.contains("eris_ua_hits_total")); assert!(body_str.contains("eris_blocked_ips")); assert!(body_str.contains("eris_active_connections")); - assert!(body_str.contains("eris_rate_limited_total")); } #[actix_web::test] diff --git a/src/network/mod.rs b/src/network/mod.rs deleted file mode 100644 index 0fd1bfa..0000000 --- a/src/network/mod.rs +++ /dev/null @@ -1,504 +0,0 @@ -use std::collections::HashMap; -use std::net::IpAddr; -use std::sync::Arc; -use std::time::{Duration, Instant}; - -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpStream; -use tokio::process::Command; -use tokio::sync::RwLock; -use tokio::time::sleep; - -use crate::config::Config; -use crate::lua::{EventContext, EventType, ScriptManager}; -use crate::markov::MarkovGenerator; -use crate::metrics::{ - ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, RATE_LIMITED_CONNECTIONS, UA_HITS, -}; -use crate::state::BotState; -use crate::utils::{ - choose_response_type, extract_all_headers, extract_header_value, extract_path_from_request, - find_header_end, generate_session_id, get_timestamp, -}; - -mod rate_limiter; -use rate_limiter::RateLimiter; - -// Global rate limiter instance. -// Default is 30 connections per IP in a 60 second window -// XXX: This might add overhead of the proxy, e.g. NGINX already implements -// rate limiting. Though I don't think we have a way of knowing if the middleman -// we are handing the connections to (from the same middleman in some cases) has -// rate limiting. -lazy_static::lazy_static! { - static ref RATE_LIMITER: RateLimiter = RateLimiter::new(60, 30); -} - -// Main connection handler. -// Decides whether to tarpit or proxy -pub async fn handle_connection( - mut stream: TcpStream, - config: Arc, - state: Arc>, - markov_generator: Arc, - script_manager: Arc, -) { - // Get peer information - let peer_addr: IpAddr = match stream.peer_addr() { - Ok(addr) => addr.ip(), - Err(e) => { - log::debug!("Failed to get peer address: {e}"); - return; - } - }; - - // Check for blocked IPs to avoid any processing - if state.read().await.blocked.contains(&peer_addr) { - log::debug!("Rejected connection from blocked IP: {peer_addr}"); - let _ = stream.shutdown().await; - return; - } - - // Apply rate limiting before any further processing - if config.rate_limit_enabled && !RATE_LIMITER.check_rate_limit(peer_addr).await { - log::info!("Rate limited connection from {peer_addr}"); - RATE_LIMITED_CONNECTIONS.inc(); - - // Optionally, add the IP to a temporary block list - // if it's constantly hitting the rate limit - let connection_count = RATE_LIMITER.get_connection_count(&peer_addr); - if connection_count > config.rate_limit_block_threshold { - log::warn!( - "IP {peer_addr} exceeding rate limit with {connection_count} connection attempts, considering for blocking" - ); - - // Trigger a blocked event for Lua scripts - let rate_limit_ctx = EventContext { - event_type: EventType::BlockIP, - ip: Some(peer_addr.to_string()), - path: None, - user_agent: None, - request_headers: None, - content: None, - timestamp: get_timestamp(), - session_id: None, - }; - script_manager.trigger_event(&rate_limit_ctx); - } - - // Either send a slow response or just close connection - if config.rate_limit_slow_response { - // Send a simple 429 Too Many Requests respons. If the bots actually respected - // HTTP error codes, the internet would be a mildly better place. - let response = "HTTP/1.1 429 Too Many Requests\r\n\ - Content-Type: text/plain\r\n\ - Retry-After: 60\r\n\ - Connection: close\r\n\ - \r\n\ - Rate limit exceeded. Please try again later."; - - let _ = stream.write_all(response.as_bytes()).await; - let _ = stream.flush().await; - } - - let _ = stream.shutdown().await; - return; - } - - // Check if Lua scripts allow this connection - if !script_manager.on_connection(&peer_addr.to_string()) { - log::debug!("Connection rejected by Lua script: {peer_addr}"); - let _ = stream.shutdown().await; - return; - } - - // Pre-check for whitelisted IPs to bypass heavy processing - let mut whitelisted = false; - for network_str in &config.whitelist_networks { - if let Ok(network) = network_str.parse::() { - if network.contains(peer_addr) { - whitelisted = true; - break; - } - } - } - - // Read buffer - let mut buffer = vec![0; 8192]; - let mut request_data = Vec::with_capacity(8192); - let mut header_end_pos = 0; - - // Read with timeout to prevent hanging resource load ops. - let read_fut = async { - loop { - match stream.read(&mut buffer).await { - Ok(0) => break, - Ok(n) => { - let new_data = &buffer[..n]; - request_data.extend_from_slice(new_data); - - // Look for end of headers - if header_end_pos == 0 { - if let Some(pos) = find_header_end(&request_data) { - header_end_pos = pos; - break; - } - } - - // Avoid excessive buffering - if request_data.len() > 32768 { - break; - } - } - Err(e) => { - log::debug!("Error reading from stream: {e}"); - break; - } - } - } - }; - - let timeout_fut = sleep(Duration::from_secs(3)); - - tokio::select! { - () = read_fut => {}, - () = timeout_fut => { - log::debug!("Connection timeout from: {peer_addr}"); - let _ = stream.shutdown().await; - return; - } - } - - // Fast path for whitelisted IPs. Skip full parsing and speed up "approved" - // connections automatically. - if whitelisted { - log::debug!("Whitelisted IP {peer_addr} - using fast proxy path"); - proxy_fast_path(stream, request_data, &config.backend_addr).await; - return; - } - - // Parse minimally to extract the path - let path = if let Some(p) = extract_path_from_request(&request_data) { - p - } else { - log::debug!("Invalid request from {peer_addr}"); - let _ = stream.shutdown().await; - return; - }; - - // Extract request headers for Lua scripts - let headers = extract_all_headers(&request_data); - - // Extract user agent for logging and decision making - let user_agent = - extract_header_value(&request_data, "user-agent").unwrap_or_else(|| "unknown".to_string()); - - // Trigger request event for Lua scripts - let request_ctx = EventContext { - event_type: EventType::Request, - ip: Some(peer_addr.to_string()), - path: Some(path.to_string()), - user_agent: Some(user_agent.clone()), - request_headers: Some(headers.clone()), - content: None, - timestamp: get_timestamp(), - session_id: Some(generate_session_id(&peer_addr.to_string(), &user_agent)), - }; - script_manager.trigger_event(&request_ctx); - - // Check if this request matches our tarpit patterns - let should_tarpit = crate::config::should_tarpit(path, &peer_addr, &config); - - if should_tarpit { - log::info!("Tarpit triggered: {path} from {peer_addr} (UA: {user_agent})"); - - // Update metrics - HITS_COUNTER.inc(); - PATH_HITS.with_label_values(&[path]).inc(); - UA_HITS.with_label_values(&[&user_agent]).inc(); - - // Update state and check for blocking threshold - { - let mut state = state.write().await; - state.active_connections.insert(peer_addr); - ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64); - - *state.hits.entry(peer_addr).or_insert(0) += 1; - let hit_count = state.hits[&peer_addr]; - - // Use Lua to decide whether to block this IP - let should_block = script_manager.should_block_ip(&peer_addr.to_string(), hit_count); - - // Block IPs that hit tarpits too many times - if should_block && !state.blocked.contains(&peer_addr) { - log::info!("Blocking IP {peer_addr} after {hit_count} hits"); - state.blocked.insert(peer_addr); - BLOCKED_IPS.set(state.blocked.len() as f64); - state.save_to_disk(); - - // Do firewall blocking in background - let peer_addr_str = peer_addr.to_string(); - tokio::spawn(async move { - log::debug!("Adding IP {peer_addr_str} to firewall blacklist"); - match Command::new("nft") - .args([ - "add", - "element", - "inet", - "filter", - "eris_blacklist", - "{", - &peer_addr_str, - "}", - ]) - .output() - .await - { - Ok(output) => { - if !output.status.success() { - log::warn!( - "Failed to add IP {} to firewall: {}", - peer_addr_str, - String::from_utf8_lossy(&output.stderr) - ); - } - } - Err(e) => { - log::warn!("Failed to execute nft command: {e}"); - } - } - }); - } - } - - // Generate a deceptive response using Markov chains and Lua - let response = generate_deceptive_response( - path, - &user_agent, - &peer_addr, - &headers, - &markov_generator, - &script_manager, - ) - .await; - - // Generate a session ID for tracking this tarpit session - let session_id = generate_session_id(&peer_addr.to_string(), &user_agent); - - // Send the response with the tarpit delay strategy - { - let mut stream = stream; - let peer_addr = peer_addr; - let state = state.clone(); - let min_delay = config.min_delay; - let max_delay = config.max_delay; - let max_tarpit_time = config.max_tarpit_time; - let script_manager = script_manager.clone(); - async move { - let start_time = Instant::now(); - let mut chars = response.chars().collect::>(); - for i in (0..chars.len()).rev() { - if i > 0 && rand::random::() < 0.1 { - chars.swap(i, i - 1); - } - } - log::debug!( - "Starting tarpit for {} with {} chars, min_delay={}ms, max_delay={}ms", - peer_addr, - chars.len(), - min_delay, - max_delay - ); - let mut position = 0; - let mut chunks_sent = 0; - let mut total_delay = 0; - while position < chars.len() { - // Check if we've exceeded maximum tarpit time - let elapsed_secs = start_time.elapsed().as_secs(); - if elapsed_secs > max_tarpit_time { - log::info!( - "Tarpit maximum time ({max_tarpit_time} sec) reached for {peer_addr}" - ); - break; - } - - // Decide how many chars to send in this chunk (usually 1, sometimes more) - let chunk_size = if rand::random::() < 0.9 { - 1 - } else { - (rand::random::() * 3.0).floor() as usize + 1 - }; - - let end = (position + chunk_size).min(chars.len()); - let chunk: String = chars[position..end].iter().collect(); - - // Process chunk through Lua before sending - let processed_chunk = - script_manager.process_chunk(&chunk, &peer_addr.to_string(), &session_id); - - // Try to write processed chunk - if stream.write_all(processed_chunk.as_bytes()).await.is_err() { - log::debug!("Connection closed by client during tarpit: {peer_addr}"); - break; - } - - if stream.flush().await.is_err() { - log::debug!("Failed to flush stream during tarpit: {peer_addr}"); - break; - } - - position = end; - chunks_sent += 1; - - // Apply random delay between min and max configured values - let delay_ms = - (rand::random::() * (max_delay - min_delay) as f32) as u64 + min_delay; - total_delay += delay_ms; - sleep(Duration::from_millis(delay_ms)).await; - } - log::debug!( - "Tarpit stats for {}: sent {} chunks, {}% of data, total delay {}ms over {}s", - peer_addr, - chunks_sent, - position * 100 / chars.len(), - total_delay, - start_time.elapsed().as_secs() - ); - let disconnection_ctx = EventContext { - event_type: EventType::Disconnection, - ip: Some(peer_addr.to_string()), - path: None, - user_agent: None, - request_headers: None, - content: None, - timestamp: get_timestamp(), - session_id: Some(session_id), - }; - script_manager.trigger_event(&disconnection_ctx); - if let Ok(mut state) = state.try_write() { - state.active_connections.remove(&peer_addr); - ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64); - } - let _ = stream.shutdown().await; - } - } - .await; - } else { - log::debug!("Proxying request: {path} from {peer_addr}"); - proxy_fast_path(stream, request_data, &config.backend_addr).await; - } -} - -// Forward a legitimate request to the real backend server -pub async fn proxy_fast_path( - mut client_stream: TcpStream, - request_data: Vec, - backend_addr: &str, -) { - // Connect to backend server - let server_stream = match TcpStream::connect(backend_addr).await { - Ok(stream) => stream, - Err(e) => { - log::warn!("Failed to connect to backend {backend_addr}: {e}"); - let _ = client_stream.shutdown().await; - return; - } - }; - - // Set TCP_NODELAY for both streams before splitting them - if let Err(e) = client_stream.set_nodelay(true) { - log::debug!("Failed to set TCP_NODELAY on client stream: {e}"); - } - - let mut server_stream = server_stream; - if let Err(e) = server_stream.set_nodelay(true) { - log::debug!("Failed to set TCP_NODELAY on server stream: {e}"); - } - - // Forward the original request bytes directly without parsing - if server_stream.write_all(&request_data).await.is_err() { - log::debug!("Failed to write request to backend server"); - let _ = client_stream.shutdown().await; - return; - } - - // Now split the streams for concurrent reading/writing - let (mut client_read, mut client_write) = client_stream.split(); - let (mut server_read, mut server_write) = server_stream.split(); - - // 32KB buffer - let buf_size = 32768; - - // Client -> Server - let client_to_server = async { - let mut buf = vec![0; buf_size]; - let mut bytes_forwarded = 0; - - loop { - match client_read.read(&mut buf).await { - Ok(0) => break, - Ok(n) => { - bytes_forwarded += n; - if server_write.write_all(&buf[..n]).await.is_err() { - break; - } - } - Err(_) => break, - } - } - - // Ensure everything is sent - let _ = server_write.flush().await; - log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes"); - }; - - // Server -> Client - let server_to_client = async { - let mut buf = vec![0; buf_size]; - let mut bytes_forwarded = 0; - - loop { - match server_read.read(&mut buf).await { - Ok(0) => break, - Ok(n) => { - bytes_forwarded += n; - if client_write.write_all(&buf[..n]).await.is_err() { - break; - } - } - Err(_) => break, - } - } - - // Ensure everything is sent - let _ = client_write.flush().await; - log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes"); - }; - - // Run both directions concurrently - tokio::join!(client_to_server, server_to_client); - log::debug!("Fast proxy connection completed"); -} - -// Generate a deceptive HTTP response that appears legitimate -pub async fn generate_deceptive_response( - path: &str, - user_agent: &str, - peer_addr: &IpAddr, - headers: &HashMap, - markov: &MarkovGenerator, - script_manager: &ScriptManager, -) -> String { - // Generate base response using Markov chain text generator - let response_type = choose_response_type(path); - let markov_text = markov.generate(response_type, 30); - - // Use Lua scripts to enhance with honeytokens and other deceptive content - script_manager.generate_response( - path, - user_agent, - &peer_addr.to_string(), - headers, - &markov_text, - ) -} diff --git a/src/network/rate_limiter.rs b/src/network/rate_limiter.rs deleted file mode 100644 index 6150793..0000000 --- a/src/network/rate_limiter.rs +++ /dev/null @@ -1,85 +0,0 @@ -use std::collections::HashMap; -use std::net::IpAddr; -use std::sync::Arc; -use std::time::{Duration, Instant}; -use tokio::sync::Mutex; - -pub struct RateLimiter { - connections: Arc>>>, - window_seconds: u64, - max_connections: usize, - cleanup_interval: Duration, - last_cleanup: Instant, -} - -impl RateLimiter { - pub fn new(window_seconds: u64, max_connections: usize) -> Self { - Self { - connections: Arc::new(Mutex::new(HashMap::new())), - window_seconds, - max_connections, - cleanup_interval: Duration::from_secs(60), - last_cleanup: Instant::now(), - } - } - - pub async fn check_rate_limit(&self, ip: IpAddr) -> bool { - let now = Instant::now(); - let window = Duration::from_secs(self.window_seconds); - - let mut connections = self.connections.lock().await; - - // Periodically clean up old entries across all IPs - if now.duration_since(self.last_cleanup) > self.cleanup_interval { - self.cleanup_old_entries(&mut connections, now, window); - } - - // Clean up old entries for this specific IP - if let Some(times) = connections.get_mut(&ip) { - times.retain(|time| now.duration_since(*time) < window); - - // Check if rate limit exceeded - if times.len() >= self.max_connections { - log::debug!("Rate limit exceeded for IP: {}", ip); - return false; - } - - // Add new connection time - times.push(now); - } else { - connections.insert(ip, vec![now]); - } - - true - } - - fn cleanup_old_entries( - &self, - connections: &mut HashMap>, - now: Instant, - window: Duration, - ) { - let mut empty_keys = Vec::new(); - - for (ip, times) in connections.iter_mut() { - times.retain(|time| now.duration_since(*time) < window); - if times.is_empty() { - empty_keys.push(*ip); - } - } - - // Remove empty entries - for ip in empty_keys { - connections.remove(&ip); - } - } - - pub fn get_connection_count(&self, ip: &IpAddr) -> usize { - if let Ok(connections) = self.connections.try_lock() { - if let Some(times) = connections.get(ip) { - return times.len(); - } - } - 0 - } -} diff --git a/src/state.rs b/src/state.rs deleted file mode 100644 index 99e5369..0000000 --- a/src/state.rs +++ /dev/null @@ -1,166 +0,0 @@ -use std::collections::{HashMap, HashSet}; -use std::fs; -use std::io::Write; -use std::net::IpAddr; - -use crate::metrics::BLOCKED_IPS; - -// State of bots/IPs hitting the honeypot -#[derive(Clone, Debug)] -pub struct BotState { - pub hits: HashMap, - pub blocked: HashSet, - pub active_connections: HashSet, - pub data_dir: String, - pub cache_dir: String, -} - -impl BotState { - pub fn new(data_dir: &str, cache_dir: &str) -> Self { - Self { - hits: HashMap::new(), - blocked: HashSet::new(), - active_connections: HashSet::new(), - data_dir: data_dir.to_string(), - cache_dir: cache_dir.to_string(), - } - } - - // Load previous state from disk - pub fn load_from_disk(data_dir: &str, cache_dir: &str) -> Self { - let mut state = Self::new(data_dir, cache_dir); - let blocked_ips_file = format!("{data_dir}/blocked_ips.txt"); - - if let Ok(content) = fs::read_to_string(&blocked_ips_file) { - let mut loaded = 0; - for line in content.lines() { - if let Ok(ip) = line.parse::() { - state.blocked.insert(ip); - loaded += 1; - } - } - log::info!("Loaded {loaded} blocked IPs from {blocked_ips_file}"); - } else { - log::info!("No blocked IPs file found at {blocked_ips_file}"); - } - - // Check for temporary hit counter cache - let hit_cache_file = format!("{cache_dir}/hit_counters.json"); - if let Ok(content) = fs::read_to_string(&hit_cache_file) { - if let Ok(hit_map) = serde_json::from_str::>(&content) { - for (ip_str, count) in hit_map { - if let Ok(ip) = ip_str.parse::() { - state.hits.insert(ip, count); - } - } - log::info!("Loaded hit counters for {} IPs", state.hits.len()); - } - } - - BLOCKED_IPS.set(state.blocked.len() as f64); - state - } - - // Persist state to disk for later reloading - pub fn save_to_disk(&self) { - // Save blocked IPs - if let Err(e) = fs::create_dir_all(&self.data_dir) { - log::error!("Failed to create data directory: {e}"); - return; - } - - let blocked_ips_file = format!("{}/blocked_ips.txt", self.data_dir); - - match fs::File::create(&blocked_ips_file) { - Ok(mut file) => { - let mut count = 0; - for ip in &self.blocked { - if writeln!(file, "{ip}").is_ok() { - count += 1; - } - } - log::info!("Saved {count} blocked IPs to {blocked_ips_file}"); - } - Err(e) => { - log::error!("Failed to create blocked IPs file: {e}"); - } - } - - // Save hit counters to cache - if let Err(e) = fs::create_dir_all(&self.cache_dir) { - log::error!("Failed to create cache directory: {e}"); - return; - } - - let hit_cache_file = format!("{}/hit_counters.json", self.cache_dir); - let hit_map: HashMap = self - .hits - .iter() - .map(|(ip, count)| (ip.to_string(), *count)) - .collect(); - - match fs::File::create(&hit_cache_file) { - Ok(file) => { - if let Err(e) = serde_json::to_writer(file, &hit_map) { - log::error!("Failed to write hit counters to cache: {e}"); - } else { - log::debug!("Saved hit counters for {} IPs to cache", hit_map.len()); - } - } - Err(e) => { - log::error!("Failed to create hit counter cache file: {e}"); - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::net::{IpAddr, Ipv4Addr}; - use std::sync::Arc; - use tokio::sync::RwLock; - - #[tokio::test] - async fn test_bot_state() { - let state = BotState::new("/tmp/eris_test", "/tmp/eris_test_cache"); - let ip1 = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)); - let ip2 = IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8)); - - let state = Arc::new(RwLock::new(state)); - - // Test hit counter - { - let mut state = state.write().await; - *state.hits.entry(ip1).or_insert(0) += 1; - *state.hits.entry(ip1).or_insert(0) += 1; - *state.hits.entry(ip2).or_insert(0) += 1; - - assert_eq!(*state.hits.get(&ip1).unwrap(), 2); - assert_eq!(*state.hits.get(&ip2).unwrap(), 1); - } - - // Test blocking - { - let mut state = state.write().await; - state.blocked.insert(ip1); - assert!(state.blocked.contains(&ip1)); - assert!(!state.blocked.contains(&ip2)); - drop(state); - } - - // Test active connections - { - let mut state = state.write().await; - state.active_connections.insert(ip1); - state.active_connections.insert(ip2); - assert_eq!(state.active_connections.len(), 2); - - state.active_connections.remove(&ip1); - assert_eq!(state.active_connections.len(), 1); - assert!(!state.active_connections.contains(&ip1)); - assert!(state.active_connections.contains(&ip2)); - drop(state); - } - } -} diff --git a/src/utils.rs b/src/utils.rs deleted file mode 100644 index c925565..0000000 --- a/src/utils.rs +++ /dev/null @@ -1,168 +0,0 @@ -use std::collections::HashMap; -use std::hash::Hasher; - -// Find end of HTTP headers -pub fn find_header_end(data: &[u8]) -> Option { - data.windows(4) - .position(|window| window == b"\r\n\r\n") - .map(|pos| pos + 4) -} - -// Extract path from raw request data -pub fn extract_path_from_request(data: &[u8]) -> Option<&str> { - // Get first line from request - let first_line = data - .split(|&b| b == b'\r' || b == b'\n') - .next() - .filter(|line| !line.is_empty())?; - - // Split by spaces and ensure we have at least 3 parts (METHOD PATH VERSION) - let parts: Vec<&[u8]> = first_line.split(|&b| b == b' ').collect(); - if parts.len() < 3 || !parts[2].starts_with(b"HTTP/") { - return None; - } - - // Return the path (second element) - std::str::from_utf8(parts[1]).ok() -} - -// Extract header value from raw request data -pub fn extract_header_value(data: &[u8], header_name: &str) -> Option { - let data_str = std::str::from_utf8(data).ok()?; - let header_prefix = format!("{header_name}: ").to_lowercase(); - - for line in data_str.lines() { - let line_lower = line.to_lowercase(); - if line_lower.starts_with(&header_prefix) { - return Some(line[header_prefix.len()..].trim().to_string()); - } - } - None -} - -// Extract all headers from request data -pub fn extract_all_headers(data: &[u8]) -> HashMap { - let mut headers = HashMap::new(); - - if let Ok(data_str) = std::str::from_utf8(data) { - let mut lines = data_str.lines(); - - // Skip the request line - let _ = lines.next(); - - // Parse headers until empty line - for line in lines { - if line.is_empty() { - break; - } - - if let Some(colon_pos) = line.find(':') { - let key = line[..colon_pos].trim().to_lowercase(); - let value = line[colon_pos + 1..].trim().to_string(); - headers.insert(key, value); - } - } - } - - headers -} - -// Determine response type based on request path -pub fn choose_response_type(path: &str) -> &'static str { - if path.contains("phpunit") || path.contains("eval") { - "php_exploit" - } else if path.contains("wp-") { - "wordpress" - } else if path.contains("api") { - "api" - } else { - "generic" - } -} - -// Get current timestamp in seconds -pub fn get_timestamp() -> u64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs() -} - -// Create a unique session ID for tracking a connection -pub fn generate_session_id(ip: &str, user_agent: &str) -> String { - let timestamp = get_timestamp(); - let random = rand::random::(); - - // XXX: Is this fast enough for our case? I don't think hashing is a huge - // bottleneck, but it's worth revisiting in the future to see if there is - // an objectively faster algorithm that we can try. - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - std::hash::Hash::hash(&format!("{ip}_{user_agent}_{timestamp}"), &mut hasher); - let hash = hasher.finish(); - - format!("SID_{hash:x}_{random:x}") -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_find_header_end() { - let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: test\r\n\r\nBody content"; - assert_eq!(find_header_end(data), Some(55)); - - let incomplete = b"GET / HTTP/1.1\r\nHost: example.com\r\n"; - assert_eq!(find_header_end(incomplete), None); - } - - #[test] - fn test_extract_path_from_request() { - let data = b"GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n"; - assert_eq!(extract_path_from_request(data), Some("/index.html")); - - let bad_data = b"INVALID DATA"; - assert_eq!(extract_path_from_request(bad_data), None); - } - - #[test] - fn test_extract_header_value() { - let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\n\r\n"; - assert_eq!( - extract_header_value(data, "user-agent"), - Some("TestBot/1.0".to_string()) - ); - assert_eq!( - extract_header_value(data, "Host"), - Some("example.com".to_string()) - ); - assert_eq!(extract_header_value(data, "nonexistent"), None); - } - - #[test] - fn test_extract_all_headers() { - let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\nAccept: */*\r\n\r\n"; - let headers = extract_all_headers(data); - - assert_eq!(headers.len(), 3); - assert_eq!(headers.get("host").unwrap(), "example.com"); - assert_eq!(headers.get("user-agent").unwrap(), "TestBot/1.0"); - assert_eq!(headers.get("accept").unwrap(), "*/*"); - } - - #[test] - fn test_choose_response_type() { - assert_eq!( - choose_response_type("/vendor/phpunit/whatever"), - "php_exploit" - ); - assert_eq!( - choose_response_type("/path/to/eval-stdin.php"), - "php_exploit" - ); - assert_eq!(choose_response_type("/wp-admin/login.php"), "wordpress"); - assert_eq!(choose_response_type("/wp-login.php"), "wordpress"); - assert_eq!(choose_response_type("/api/v1/users"), "api"); - assert_eq!(choose_response_type("/index.html"), "generic"); - } -}