diff --git a/.github/workflows/tag.yml b/.github/workflows/tag.yml new file mode 100644 index 0000000..d3ce8ab --- /dev/null +++ b/.github/workflows/tag.yml @@ -0,0 +1,31 @@ +name: Tag latest version + +on: + workflow_dispatch: + push: + branches: [ main ] + +concurrency: tag + +jobs: + tag-release: + runs-on: ubuntu-latest + steps: + - uses: cachix/install-nix-action@master + with: + github_access_token: ${{ secrets.GITHUB_TOKEN }} + + - name: Checkout + uses: actions/checkout@v4 + + - name: Read version + run: | + echo -n "_version=v" >> "$GITHUB_ENV" + nix run nixpkgs#fq -- -r ".package.version" Cargo.toml >> "$GITHUB_ENV" + cat "$GITHUB_ENV" + + - name: Tag + run: | + set -x + git tag "$_version" + git push --tags || : diff --git a/Cargo.lock b/Cargo.lock index 3ab4b04..2eaaf07 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -414,9 +414,7 @@ checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" dependencies = [ "android-tzdata", "iana-time-zone", - "js-sys", "num-traits", - "wasm-bindgen", "windows-link", ] @@ -442,7 +440,6 @@ version = "4.5.37" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "efd9466fac8543255d3b1fcad4762c5e116ffe808c8a3043d4263cd4fd4862a2" dependencies = [ - "anstream", "anstyle", "clap_lex", "strsim", @@ -620,7 +617,7 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "eris" -version = "0.1.0" +version = "1.0.0" dependencies = [ "actix-web", "chrono", @@ -633,10 +630,13 @@ dependencies = [ "prometheus 0.14.0", "prometheus_exporter", "rand", + "regex", "rlua", "serde", "serde_json", + "tempfile", "tokio", + "toml", ] [[package]] @@ -646,9 +646,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] +[[package]] +name = "fastrand" +version = "2.3.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "flate2" version = "1.1.1" @@ -1580,7 +1586,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1633,6 +1639,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1740,6 +1755,19 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "tempfile" +version = "3.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" +dependencies = [ + "fastrand", + "getrandom", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -1876,6 +1904,47 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.8.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05ae329d1f08c4d17a59bed7ff5b5a769d062e64a62d34a3261b219e62cd5aae" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "310068873db2c5b3e7659d2cc35d21855dbafa50d1ce336397c666e3cb08137e" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "toml_write", + "winnow", +] + +[[package]] +name = "toml_write" +version = "0.1.1" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfb942dfe1d8e29a7ee7fcbde5bd2b9a25fb89aa70caea2eba3bee836ff41076" + [[package]] name = "tracing" version = "0.1.41" @@ -2187,6 +2256,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e27d6ad3dac991091e4d35de9ba2d2d00647c5d0fc26c5496dee55984ae111b" +dependencies = [ + "memchr", +] + [[package]] name = "winsafe" version = "0.0.19" diff --git a/Cargo.toml b/Cargo.toml index 2292db4..42cc24b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,12 @@ [package] name = "eris" -version = "0.1.0" +version = "1.0.0" edition = "2024" [dependencies] -actix-web = "4.3.1" -clap = { version = "4.3", features = ["derive"] } -chrono = "0.4.24" +actix-web = { version = "4.3.1" } +chrono = { version = "0.4.41", default-features = false, features = ["std", "clock"] } +clap = { version = "4.5", default-features = false, features = ["std", "derive", "help", "usage", "suggestions"] } futures = "0.3.28" ipnetwork = "0.21.1" lazy_static = "1.4.0" @@ -19,3 +19,6 @@ serde_json = "1.0.96" tokio = { version = "1.28.0", features = ["full"] } log = "0.4.27" env_logger = "0.11.8" +tempfile = "3.19.1" +regex = "1.11.1" +toml = "0.8.22" diff --git a/nix/package.nix b/nix/package.nix index b227e1a..aa6e975 100644 --- a/nix/package.nix +++ b/nix/package.nix @@ -19,6 +19,7 @@ in fileset = fs.unions [ (fs.fileFilter (file: builtins.any file.hasExt ["rs"]) (s + /src)) (s + /contrib) + (s + /resources) lockfile cargoToml ]; diff --git a/resources/default_script.lua b/resources/default_script.lua new file mode 100644 index 0000000..f2f51fd --- /dev/null +++ b/resources/default_script.lua @@ -0,0 +1,210 @@ +--[[ +Eris Default Script + +This script demonstrates how to use the Eris Lua 
API to customize +the tarpit's behavior, and will be loaded by default if no other +scripts are loaded. + +Available events: +- connection: When a new connection is established +- request: When a request is received +- response_gen: When generating a response +- response_chunk: Before sending each response chunk +- disconnection: When a connection is closed +- block_ip: When an IP is being considered for blocking +- startup: When the application starts +- shutdown: When the application is shutting down +- periodic: Called periodically + +API Functions: +- eris.debug(message): Log a debug message +- eris.info(message): Log an info message +- eris.warn(message): Log a warning message +- eris.error(message): Log an error message +- eris.set_state(key, value): Store persistent state +- eris.get_state(key): Retrieve persistent state +- eris.inc_counter(key, [amount]): Increment a counter +- eris.get_counter(key): Get a counter value +- eris.gen_token([prefix]): Generate a unique token +- eris.timestamp(): Get current Unix timestamp +--]] + +-- Called when the application starts +eris.on("startup", function(ctx) + eris.info("Initializing default script") + + -- Initialize counters + eris.inc_counter("total_connections", 0) + eris.inc_counter("total_responses", 0) + eris.inc_counter("blocked_ips", 0) + + -- Initialize banned keywords + eris.set_state("banned_keywords", "eval,exec,system,shell," + .. "\n" + .. "\n" + elseif ctx.path:find("phpunit") or ctx.path:find("eval") then + -- For PHP exploit attempts + -- Turns out you can just google "PHP error log" and search random online forums where people + -- dump their service logs in full. + enhanced_content = enhanced_content + .. "\nPHP Notice: Undefined variable: _SESSION in /var/www/html/includes/core.php on line 58\n" + .. "Warning: file_get_contents(): Filename cannot be empty in /var/www/html/vendor/autoload.php on line 23\n" + .. "Token: " + .. token + .. 
"\n" + elseif ctx.path:find("api") then + -- For API requests + local fake_api_key = + string.format("ak_%x%x%x", math.random(1000, 9999), math.random(1000, 9999), math.random(1000, 9999)) + + enhanced_content = enhanced_content + .. "{\n" + .. ' "status": "warning",\n' + .. ' "message": "Test API environment detected",\n' + .. ' "debug_token": "' + .. token + .. '",\n' + .. ' "api_key": "' + .. fake_api_key + .. '"\n' + .. "}\n" + else + -- For other requests + enhanced_content = enhanced_content + .. "\n" + .. "\n" + .. "\n" + end + + -- Track which honeytokens were sent to which IP + local honeytokens = eris.get_state("honeytokens") or "{}" + local ht_table = {} + + -- This is a simplistic approach - in a real script, you'd want to use + -- a proper JSON library to handle this correctly + if honeytokens ~= "{}" then + -- Simple parsing of the stored data + for ip, tok in honeytokens:gmatch('"([^"]+)":"([^"]+)"') do + ht_table[ip] = tok + end + end + + ht_table[ctx.ip] = token + + -- Convert back to a simple JSON-like string + local new_tokens = "{" + for ip, tok in pairs(ht_table) do + if new_tokens ~= "{" then + new_tokens = new_tokens .. "," + end + new_tokens = new_tokens .. '"' .. ip .. '":"' .. tok .. '"' + end + new_tokens = new_tokens .. "}" + + eris.set_state("honeytokens", new_tokens) + + return enhanced_content +end) + +-- Called before sending each chunk of a response +eris.on("response_chunk", function(ctx) + -- This can be used to alter individual chunks for more deceptive behavior + -- For example, to simulate a slow, unreliable server + + -- 5% chance of "corrupting" a chunk to confuse scanners + if math.random(1, 100) <= 5 then + local chunk = ctx.content + if #chunk > 10 then + local pos = math.random(1, #chunk - 5) + chunk = chunk:sub(1, pos) .. string.char(math.random(32, 126)) .. 
chunk:sub(pos + 2) + end + return chunk + end + + return ctx.content +end) + +-- Called when deciding whether to block an IP +eris.on("block_ip", function(ctx) + -- You can override the default blocking logic + + -- Check for potential attackers using specific patterns + local banned_keywords = eris.get_state("banned_keywords") or "" + local user_agent = ctx.user_agent or "" + + -- Check if user agent contains highly suspicious patterns + for keyword in banned_keywords:gmatch("[^,]+") do + if user_agent:lower():find(keyword:lower()) then + eris.info("Blocking IP " .. ctx.ip .. " due to suspicious user agent: " .. keyword) + eris.inc_counter("blocked_ips") + return true -- Force block + end + end + + -- For demonstration, we'll be more lenient with 10.x IPs + if ctx.ip:match("^10%.") then + -- Only block if they've hit us many times + return ctx.hit_count >= 5 + end + + -- Default to the system's threshold-based decision + return nil +end) + +-- The enhance_response is now legacy, and I never liked it anyway. Though let's add it here +-- for the sake of backwards compatibility. +function enhance_response(text, response_type, path, token) + local enhanced = text + + -- Add token as a comment + if response_type == "php_exploit" then + enhanced = enhanced .. "\n/* Token: " .. token .. " */\n" + elseif response_type == "wordpress" then + enhanced = enhanced .. "\n\n" + elseif response_type == "api" then + enhanced = enhanced:gsub('"status": "[^"]+"', '"status": "warning"') + enhanced = enhanced:gsub('"message": "[^"]+"', '"message": "API token: ' .. token .. '"') + else + enhanced = enhanced .. 
"\n\n" + end + + return enhanced +end diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..74a1571 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,613 @@ +use clap::Parser; +use ipnetwork::IpNetwork; +use regex::Regex; +use serde::{Deserialize, Serialize}; +use std::env; +use std::fs; +use std::net::IpAddr; +use std::path::{Path, PathBuf}; + +// Command-line arguments using clap +#[derive(Parser, Debug, Clone)] +#[clap( + author, + version, + about = "Markov chain based HTTP tarpit/honeypot that delays and tracks potential attackers" +)] +pub struct Args { + #[clap( + long, + default_value = "0.0.0.0:8888", + help = "Address and port to listen for incoming HTTP requests (format: ip:port)" + )] + pub listen_addr: String, + + #[clap( + long, + default_value = "0.0.0.0:9100", + help = "Address and port to expose Prometheus metrics and status endpoint (format: ip:port)" + )] + pub metrics_addr: String, + + #[clap(long, help = "Disable Prometheus metrics server completely")] + pub disable_metrics: bool, + + #[clap( + long, + default_value = "127.0.0.1:80", + help = "Backend server address to proxy legitimate requests to (format: ip:port)" + )] + pub backend_addr: String, + + #[clap( + long, + default_value = "1000", + help = "Minimum delay in milliseconds between chunks sent to attacker" + )] + pub min_delay: u64, + + #[clap( + long, + default_value = "15000", + help = "Maximum delay in milliseconds between chunks sent to attacker" + )] + pub max_delay: u64, + + #[clap( + long, + default_value = "600", + help = "Maximum time in seconds to keep an attacker in the tarpit before disconnecting" + )] + pub max_tarpit_time: u64, + + #[clap( + long, + default_value = "3", + help = "Number of hits to honeypot patterns before permanently blocking an IP" + )] + pub block_threshold: u32, + + #[clap( + long, + help = "Base directory for all application data (overrides XDG directory structure)" + )] + pub base_dir: Option, + + #[clap( + long, + help = "Path to 
configuration file (JSON or TOML, overrides command line options)" + )] + pub config_file: Option, + + #[clap( + long, + default_value = "info", + help = "Log level: trace, debug, info, warn, error" + )] + pub log_level: String, + + #[clap( + long, + default_value = "pretty", + help = "Log format: plain, pretty, json, pretty-json" + )] + pub log_format: String, + + #[clap(long, help = "Enable rate limiting for connections from the same IP")] + pub rate_limit_enabled: bool, + + #[clap(long, default_value = "60", help = "Rate limit window in seconds")] + pub rate_limit_window: u64, + + #[clap( + long, + default_value = "30", + help = "Maximum number of connections allowed per IP in the rate limit window" + )] + pub rate_limit_max: usize, + + #[clap( + long, + default_value = "100", + help = "Connection attempts threshold before considering for IP blocking" + )] + pub rate_limit_block_threshold: usize, + + #[clap( + long, + help = "Send a 429 response for rate limited connections instead of dropping connection" + )] + pub rate_limit_slow_response: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +pub enum LogFormat { + Plain, + #[default] + Pretty, + Json, + PrettyJson, +} + +// Trap pattern structure. It can be either a plain string +// regex to catch more advanced patterns necessitated by +// more sophisticated crawlers. 
+#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(untagged)] +pub enum TrapPattern { + Plain(String), + Regex { pattern: String, regex: bool }, +} + +impl TrapPattern { + pub fn as_plain(value: &str) -> Self { + Self::Plain(value.to_string()) + } + + pub fn as_regex(value: &str) -> Self { + Self::Regex { + pattern: value.to_string(), + regex: true, + } + } + + pub fn matches(&self, path: &str) -> bool { + match self { + Self::Plain(pattern) => path.contains(pattern), + Self::Regex { + pattern, + regex: true, + } => { + if let Ok(re) = Regex::new(pattern) { + re.is_match(path) + } else { + false + } + } + _ => false, + } + } +} + +// Configuration structure +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Config { + pub listen_addr: String, + pub metrics_addr: String, + pub disable_metrics: bool, + pub backend_addr: String, + pub min_delay: u64, + pub max_delay: u64, + pub max_tarpit_time: u64, + pub block_threshold: u32, + pub trap_patterns: Vec, + pub whitelist_networks: Vec, + pub markov_corpora_dir: String, + pub lua_scripts_dir: String, + pub data_dir: String, + pub config_dir: String, + pub cache_dir: String, + pub log_format: LogFormat, + pub rate_limit_enabled: bool, + pub rate_limit_window_seconds: u64, + pub rate_limit_max_connections: usize, + pub rate_limit_block_threshold: usize, + pub rate_limit_slow_response: bool, +} + +impl Default for Config { + fn default() -> Self { + Self { + listen_addr: "0.0.0.0:8888".to_string(), + metrics_addr: "0.0.0.0:9100".to_string(), + disable_metrics: false, + backend_addr: "127.0.0.1:80".to_string(), + min_delay: 1000, + max_delay: 15000, + max_tarpit_time: 600, + block_threshold: 3, + trap_patterns: vec![ + // Basic attack patterns as plain strings + TrapPattern::as_plain("/vendor/phpunit"), + TrapPattern::as_plain("eval-stdin.php"), + TrapPattern::as_plain("/wp-admin"), + TrapPattern::as_plain("/wp-login.php"), + TrapPattern::as_plain("/xmlrpc.php"), + TrapPattern::as_plain("/phpMyAdmin"), + 
TrapPattern::as_plain("/solr/"), + TrapPattern::as_plain("/.env"), + TrapPattern::as_plain("/config"), + TrapPattern::as_plain("/actuator/"), + // More aggressive patterns for various PHP exploits. + // XXX: I dedicate this entire section to that one single crawler + // that has been scanning my entire network, hitting 403s left and right + // but not giving up, and coming back the next day at the same time to + // scan the same paths over and over. Kudos to you, random crawler. + TrapPattern::as_regex(r"/.*phpunit.*eval-stdin\.php"), + TrapPattern::as_regex(r"/index\.php\?s=/index/\\think\\app/invokefunction"), + TrapPattern::as_regex(r".*%ADd\+auto_prepend_file%3dphp://input.*"), + TrapPattern::as_regex(r".*%ADd\+allow_url_include%3d1.*"), + TrapPattern::as_regex(r".*/wp-content/plugins/.*\.php"), + TrapPattern::as_regex(r".*/wp-content/themes/.*\.php"), + TrapPattern::as_regex(r".*eval\(.*\).*"), + TrapPattern::as_regex(r".*/adminer\.php.*"), + TrapPattern::as_regex(r".*/admin\.php.*"), + TrapPattern::as_regex(r".*/administrator/.*"), + TrapPattern::as_regex(r".*/wp-json/.*"), + TrapPattern::as_regex(r".*/api/.*\.php.*"), + TrapPattern::as_regex(r".*/cgi-bin/.*"), + TrapPattern::as_regex(r".*/owa/.*"), + TrapPattern::as_regex(r".*/ecp/.*"), + TrapPattern::as_regex(r".*/webshell\.php.*"), + TrapPattern::as_regex(r".*/shell\.php.*"), + TrapPattern::as_regex(r".*/cmd\.php.*"), + TrapPattern::as_regex(r".*/struts.*"), + ], + whitelist_networks: vec![ + "192.168.0.0/16".to_string(), + "10.0.0.0/8".to_string(), + "172.16.0.0/12".to_string(), + "127.0.0.0/8".to_string(), + ], + markov_corpora_dir: "./corpora".to_string(), + lua_scripts_dir: "./scripts".to_string(), + data_dir: "./data".to_string(), + config_dir: "./conf".to_string(), + cache_dir: "./cache".to_string(), + log_format: LogFormat::Pretty, + rate_limit_enabled: true, + rate_limit_window_seconds: 60, + rate_limit_max_connections: 30, + rate_limit_block_threshold: 100, + rate_limit_slow_response: true, + } + 
} +} + +// Gets standard XDG directory paths for config, data and cache. +// XXX: This could be "simplified" by using the Dirs crate, but I can't +// really justify pulling a library for something I can handle in less +// than 30 lines. Unless cross-platform becomes necessary, the below +// implementation is good enough. For alternative platforms, we can simply +// enhance the current implementation as needed. +pub fn get_xdg_dirs() -> (PathBuf, PathBuf, PathBuf) { + let config_home = env::var_os("XDG_CONFIG_HOME") + .map(PathBuf::from) + .unwrap_or_else(|| { + let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); + home.join(".config") + }); + + let data_home = env::var_os("XDG_DATA_HOME") + .map(PathBuf::from) + .unwrap_or_else(|| { + let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); + home.join(".local").join("share") + }); + + let cache_home = env::var_os("XDG_CACHE_HOME") + .map(PathBuf::from) + .unwrap_or_else(|| { + let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); + home.join(".cache") + }); + + let config_dir = config_home.join("eris"); + let data_dir = data_home.join("eris"); + let cache_dir = cache_home.join("eris"); + + (config_dir, data_dir, cache_dir) +} + +impl Config { + // Create configuration from command-line args. We'll be falling back to this + // when the configuration is invalid, so it must be validated more strictly. 
+ pub fn from_args(args: &Args) -> Self { + let (config_dir, data_dir, cache_dir) = if let Some(base_dir) = &args.base_dir { + let base_str = base_dir.to_string_lossy().to_string(); + ( + format!("{base_str}/conf"), + format!("{base_str}/data"), + format!("{base_str}/cache"), + ) + } else { + let (c, d, cache) = get_xdg_dirs(); + ( + c.to_string_lossy().to_string(), + d.to_string_lossy().to_string(), + cache.to_string_lossy().to_string(), + ) + }; + + Self { + listen_addr: args.listen_addr.clone(), + metrics_addr: args.metrics_addr.clone(), + disable_metrics: args.disable_metrics, + backend_addr: args.backend_addr.clone(), + min_delay: args.min_delay, + max_delay: args.max_delay, + max_tarpit_time: args.max_tarpit_time, + block_threshold: args.block_threshold, + markov_corpora_dir: format!("{data_dir}/corpora"), + lua_scripts_dir: format!("{data_dir}/scripts"), + data_dir, + config_dir, + cache_dir, + log_format: LogFormat::Pretty, + rate_limit_enabled: args.rate_limit_enabled, + rate_limit_window_seconds: args.rate_limit_window, + rate_limit_max_connections: args.rate_limit_max, + rate_limit_block_threshold: args.rate_limit_block_threshold, + rate_limit_slow_response: args.rate_limit_slow_response, + ..Default::default() + } + } + + // Load configuration from a file (JSON or TOML) + pub fn load_from_file(path: &Path) -> std::io::Result { + let content = fs::read_to_string(path)?; + + let extension = path + .extension() + .map(|ext| ext.to_string_lossy().to_lowercase()) + .unwrap_or_default(); + + let config = match extension.as_str() { + "toml" => toml::from_str(&content).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Failed to parse TOML: {e}"), + ) + })?, + _ => { + // Default to JSON for any other extension + serde_json::from_str(&content).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Failed to parse JSON: {e}"), + ) + })? 
+ } + }; + + Ok(config) + } + + // Save configuration to a file (JSON or TOML) + pub fn save_to_file(&self, path: &Path) -> std::io::Result<()> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + + let extension = path + .extension() + .map(|ext| ext.to_string_lossy().to_lowercase()) + .unwrap_or_default(); + + let content = match extension.as_str() { + "toml" => toml::to_string_pretty(self).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Failed to serialize to TOML: {e}"), + ) + })?, + _ => { + // Default to JSON for any other extension + serde_json::to_string_pretty(self).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Failed to serialize to JSON: {e}"), + ) + })? + } + }; + + fs::write(path, content)?; + Ok(()) + } + + // Create required directories if they don't exist + pub fn ensure_dirs_exist(&self) -> std::io::Result<()> { + let dirs = [ + &self.markov_corpora_dir, + &self.lua_scripts_dir, + &self.data_dir, + &self.config_dir, + &self.cache_dir, + ]; + + for dir in dirs { + fs::create_dir_all(dir)?; + log::debug!("Created directory: {dir}"); + } + + Ok(()) + } +} + +// Decide if a request should be tarpitted based on path and IP +pub fn should_tarpit(path: &str, ip: &IpAddr, config: &Config) -> bool { + // Check whitelist IPs first to avoid unnecessary pattern matching + for network_str in &config.whitelist_networks { + if let Ok(network) = network_str.parse::() { + if network.contains(*ip) { + return false; + } + } + } + + // Use pattern matching based on the trap pattern type. It can be + // a plain string or regex. 
+ for pattern in &config.trap_patterns { + if pattern.matches(path) { + return true; + } + } + + false +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, Ipv4Addr}; + + #[test] + fn test_config_from_args() { + let args = Args { + listen_addr: "127.0.0.1:8080".to_string(), + metrics_addr: "127.0.0.1:9000".to_string(), + disable_metrics: true, + backend_addr: "127.0.0.1:8081".to_string(), + min_delay: 500, + max_delay: 10000, + max_tarpit_time: 300, + block_threshold: 5, + base_dir: Some(PathBuf::from("/tmp/eris")), + config_file: None, + log_level: "debug".to_string(), + log_format: "pretty".to_string(), + rate_limit_enabled: true, + rate_limit_window: 30, + rate_limit_max: 20, + rate_limit_block_threshold: 50, + rate_limit_slow_response: true, + }; + + let config = Config::from_args(&args); + assert_eq!(config.listen_addr, "127.0.0.1:8080"); + assert_eq!(config.metrics_addr, "127.0.0.1:9000"); + assert!(config.disable_metrics); + assert_eq!(config.backend_addr, "127.0.0.1:8081"); + assert_eq!(config.min_delay, 500); + assert_eq!(config.max_delay, 10000); + assert_eq!(config.max_tarpit_time, 300); + assert_eq!(config.block_threshold, 5); + assert_eq!(config.markov_corpora_dir, "/tmp/eris/data/corpora"); + assert_eq!(config.lua_scripts_dir, "/tmp/eris/data/scripts"); + assert_eq!(config.data_dir, "/tmp/eris/data"); + assert_eq!(config.config_dir, "/tmp/eris/conf"); + assert_eq!(config.cache_dir, "/tmp/eris/cache"); + assert!(config.rate_limit_enabled); + assert_eq!(config.rate_limit_window_seconds, 30); + assert_eq!(config.rate_limit_max_connections, 20); + assert_eq!(config.rate_limit_block_threshold, 50); + assert!(config.rate_limit_slow_response); + } + + #[test] + fn test_trap_pattern_matching() { + // Test plain string pattern + let plain = TrapPattern::as_plain("phpunit"); + assert!(plain.matches("path/to/phpunit/test")); + assert!(!plain.matches("path/to/something/else")); + + // Test regex pattern + let regex = 
TrapPattern::as_regex(r".*eval-stdin\.php.*"); + assert!(regex.matches("/vendor/phpunit/phpunit/src/Util/PHP/eval-stdin.php")); + assert!(regex.matches("/tests/eval-stdin.php?param")); + assert!(!regex.matches("/normal/path")); + + // Test invalid regex pattern (should return false) + let invalid = TrapPattern::Regex { + pattern: "(invalid[regex".to_string(), + regex: true, + }; + assert!(!invalid.matches("anything")); + } + + #[tokio::test] + async fn test_should_tarpit() { + let config = Config::default(); + + // Test trap patterns + assert!(should_tarpit( + "/vendor/phpunit/whatever", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + )); + assert!(should_tarpit( + "/wp-admin/login.php", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + )); + assert!(should_tarpit( + "/.env", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + )); + + // Test whitelist networks + assert!(!should_tarpit( + "/wp-admin/login.php", + &IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + &config + )); + assert!(!should_tarpit( + "/vendor/phpunit/whatever", + &IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), + &config + )); + + // Test legitimate paths + assert!(!should_tarpit( + "/index.html", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + )); + assert!(!should_tarpit( + "/images/logo.png", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + )); + + // Test regex patterns + assert!(should_tarpit( + "/index.php?s=/index/\\think\\app/invokefunction&function=call_user_func_array&vars[0]=md5&vars[1][]=Hello", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + )); + + assert!(should_tarpit( + "/hello.world?%ADd+allow_url_include%3d1+%ADd+auto_prepend_file%3dphp://input", + &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + &config + )); + } + + #[test] + fn test_config_file_formats() { + // Create temporary JSON config file + let temp_dir = std::env::temp_dir(); + let json_path = temp_dir.join("temp_config.json"); + let toml_path = temp_dir.join("temp_config.toml"); + + let config = 
Config::default(); + + // Test JSON serialization and deserialization + config.save_to_file(&json_path).unwrap(); + let loaded_json = Config::load_from_file(&json_path).unwrap(); + assert_eq!(loaded_json.listen_addr, config.listen_addr); + assert_eq!(loaded_json.min_delay, config.min_delay); + assert_eq!(loaded_json.rate_limit_enabled, config.rate_limit_enabled); + assert_eq!( + loaded_json.rate_limit_max_connections, + config.rate_limit_max_connections + ); + + // Test TOML serialization and deserialization + config.save_to_file(&toml_path).unwrap(); + let loaded_toml = Config::load_from_file(&toml_path).unwrap(); + assert_eq!(loaded_toml.listen_addr, config.listen_addr); + assert_eq!(loaded_toml.min_delay, config.min_delay); + assert_eq!(loaded_toml.rate_limit_enabled, config.rate_limit_enabled); + assert_eq!( + loaded_toml.rate_limit_max_connections, + config.rate_limit_max_connections + ); + + // Clean up + let _ = std::fs::remove_file(json_path); + let _ = std::fs::remove_file(toml_path); + } +} diff --git a/src/firewall.rs b/src/firewall.rs new file mode 100644 index 0000000..73f36e5 --- /dev/null +++ b/src/firewall.rs @@ -0,0 +1,128 @@ +use tokio::process::Command; + +// Set up nftables firewall rules for IP blocking +pub async fn setup_firewall() -> Result<(), String> { + log::info!("Setting up firewall rules"); + + // Check if nft command exists + let nft_exists = Command::new("which") + .arg("nft") + .output() + .await + .map(|output| output.status.success()) + .unwrap_or(false); + + if !nft_exists { + log::warn!("nft command not found. 
Firewall rules will not be set up."); + return Ok(()); + } + + // Create table if it doesn't exist + let output = Command::new("nft") + .args(["list", "table", "inet", "filter"]) + .output() + .await; + + match output { + Ok(output) => { + if !output.status.success() { + log::info!("Creating nftables table"); + let result = Command::new("nft") + .args(["create", "table", "inet", "filter"]) + .output() + .await; + + if let Err(e) = result { + return Err(format!("Failed to create nftables table: {e}")); + } + } + } + Err(e) => { + log::warn!("Failed to check if nftables table exists: {e}"); + log::info!("Will try to create it anyway"); + let result = Command::new("nft") + .args(["create", "table", "inet", "filter"]) + .output() + .await; + + if let Err(e) = result { + return Err(format!("Failed to create nftables table: {e}")); + } + } + } + + // Create blacklist set if it doesn't exist + let output = Command::new("nft") + .args(["list", "set", "inet", "filter", "eris_blacklist"]) + .output() + .await; + + match output { + Ok(output) => { + if !output.status.success() { + log::info!("Creating eris_blacklist set"); + let result = Command::new("nft") + .args([ + "create", + "set", + "inet", + "filter", + "eris_blacklist", + "{ type ipv4_addr; flags interval; }", + ]) + .output() + .await; + + if let Err(e) = result { + return Err(format!("Failed to create blacklist set: {e}")); + } + } + } + Err(e) => { + log::warn!("Failed to check if blacklist set exists: {e}"); + return Err(format!("Failed to check if blacklist set exists: {e}")); + } + } + + // Add rule to drop traffic from blacklisted IPs + let output = Command::new("nft") + .args(["list", "chain", "inet", "filter", "input"]) + .output() + .await; + + // Check if our rule already exists + match output { + Ok(output) => { + let rule_exists = String::from_utf8_lossy(&output.stdout) + .contains("ip saddr @eris_blacklist counter drop"); + + if !rule_exists { + log::info!("Adding drop rule for blacklisted IPs"); + let 
result = Command::new("nft") + .args([ + "add", + "rule", + "inet", + "filter", + "input", + "ip saddr @eris_blacklist", + "counter", + "drop", + ]) + .output() + .await; + + if let Err(e) = result { + return Err(format!("Failed to add firewall rule: {e}")); + } + } + } + Err(e) => { + log::warn!("Failed to check if firewall rule exists: {e}"); + return Err(format!("Failed to check if firewall rule exists: {e}")); + } + } + + log::info!("Firewall setup complete"); + Ok(()) +} diff --git a/src/lua/mod.rs b/src/lua/mod.rs new file mode 100644 index 0000000..fc96382 --- /dev/null +++ b/src/lua/mod.rs @@ -0,0 +1,901 @@ +use rlua::{Function, Lua, Table, Value}; +use std::collections::HashMap; +use std::collections::hash_map::DefaultHasher; +use std::fs; +use std::hash::{Hash, Hasher}; +use std::path::Path; +use std::sync::{Arc, Mutex, RwLock}; +use std::time::{SystemTime, UNIX_EPOCH}; + +// Event types for the Lua scripting system +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum EventType { + Connection, // when a new connection is established + Request, // when a request is received + ResponseGen, // when generating a response + ResponseChunk, // before sending each response chunk + Disconnection, // when a connection is closed + BlockIP, // when an IP is being considered for blocking + Startup, // when the application starts + Shutdown, // when the application is shutting down + Periodic, // called periodically (e.g., every minute) +} + +impl EventType { + /// Convert event type to string representation for Lua + const fn as_str(&self) -> &'static str { + match self { + Self::Connection => "connection", + Self::Request => "request", + Self::ResponseGen => "response_gen", + Self::ResponseChunk => "response_chunk", + Self::Disconnection => "disconnection", + Self::BlockIP => "block_ip", + Self::Startup => "startup", + Self::Shutdown => "shutdown", + Self::Periodic => "periodic", + } + } + + /// Convert from string to `EventType` + fn from_str(s: &str) -> 
Option { + match s { + "connection" => Some(Self::Connection), + "request" => Some(Self::Request), + "response_gen" => Some(Self::ResponseGen), + "response_chunk" => Some(Self::ResponseChunk), + "disconnection" => Some(Self::Disconnection), + "block_ip" => Some(Self::BlockIP), + "startup" => Some(Self::Startup), + "shutdown" => Some(Self::Shutdown), + "periodic" => Some(Self::Periodic), + _ => None, + } + } +} + +// Loaded Lua script with its metadata +struct ScriptInfo { + name: String, + enabled: bool, +} + +// Script state and manage the Lua environment +pub struct ScriptManager { + lua: Mutex, + scripts: Vec, + hooks: HashMap>, + state: Arc>>, + counters: Arc>>, +} + +// Context passed to Lua event handlers +pub struct EventContext { + pub event_type: EventType, + pub ip: Option, + pub path: Option, + pub user_agent: Option, + pub request_headers: Option>, + pub content: Option, + pub timestamp: u64, + pub session_id: Option, +} + +// Make ScriptManager explicitly Send + Sync since we're using Mutex +unsafe impl Send for ScriptManager {} +unsafe impl Sync for ScriptManager {} + +impl ScriptManager { + /// Create a new script manager and load scripts from the given directory + pub fn new(scripts_dir: &str) -> Self { + let mut manager = Self { + lua: Mutex::new(Lua::new()), + scripts: Vec::new(), + hooks: HashMap::new(), + state: Arc::new(RwLock::new(HashMap::new())), + counters: Arc::new(RwLock::new(HashMap::new())), + }; + + // Initialize Lua environment with our API + manager.init_lua_env(); + + // Load scripts from directory + manager.load_scripts_from_dir(scripts_dir); + + // If no scripts were loaded, use default script + if manager.scripts.is_empty() { + log::info!("No Lua scripts found, loading default scripts"); + manager.load_script( + "default", + include_str!("../../resources/default_script.lua"), + ); + } + + // Trigger startup event + manager.trigger_event(&EventContext { + event_type: EventType::Startup, + ip: None, + path: None, + user_agent: 
None, + request_headers: None, + content: None, + timestamp: get_timestamp(), + session_id: None, + }); + + manager + } + + // Initialize the Lua environment + fn init_lua_env(&self) { + let state_clone = self.state.clone(); + let counters_clone = self.counters.clone(); + + if let Ok(lua) = self.lua.lock() { + // Create eris global table for our API + let eris_table = lua.create_table().unwrap(); + + self.register_utility_functions(&lua, &eris_table, state_clone, counters_clone); + self.register_event_functions(&lua, &eris_table); + self.register_logging_functions(&lua, &eris_table); + + // Set the eris global table + lua.globals().set("eris", eris_table).unwrap(); + } + } + + /// Register utility functions for scripts to use + fn register_utility_functions( + &self, + lua: &Lua, + eris_table: &Table, + state: Arc>>, + counters: Arc>>, + ) { + // Store a key-value pair in persistent state + let state_for_set = state.clone(); + let set_state = lua + .create_function(move |_, (key, value): (String, String)| { + let mut state_map = state_for_set.write().unwrap(); + state_map.insert(key, value); + Ok(()) + }) + .unwrap(); + eris_table.set("set_state", set_state).unwrap(); + + // Get a value from persistent state + let state_for_get = state; + let get_state = lua + .create_function(move |_, key: String| { + let state_map = state_for_get.read().unwrap(); + let value = state_map.get(&key).cloned(); + Ok(value) + }) + .unwrap(); + eris_table.set("get_state", get_state).unwrap(); + + // Increment a counter + let counters_for_inc = counters.clone(); + let inc_counter = lua + .create_function(move |_, (key, amount): (String, Option)| { + let mut counters_map = counters_for_inc.write().unwrap(); + let counter = counters_map.entry(key).or_insert(0); + *counter += amount.unwrap_or(1); + Ok(*counter) + }) + .unwrap(); + eris_table.set("inc_counter", inc_counter).unwrap(); + + // Get a counter value + let counters_for_get = counters; + let get_counter = lua + .create_function(move 
|_, key: String| { + let counters_map = counters_for_get.read().unwrap(); + let value = counters_map.get(&key).copied().unwrap_or(0); + Ok(value) + }) + .unwrap(); + eris_table.set("get_counter", get_counter).unwrap(); + + // Generate a random token/string + let gen_token = lua + .create_function(move |_, prefix: Option| { + let now = get_timestamp(); + let random = rand::random::(); + let token = format!("{}{:x}{:x}", prefix.unwrap_or_default(), now, random); + Ok(token) + }) + .unwrap(); + eris_table.set("gen_token", gen_token).unwrap(); + + // Get current timestamp + let timestamp = lua + .create_function(move |_, ()| Ok(get_timestamp())) + .unwrap(); + eris_table.set("timestamp", timestamp).unwrap(); + } + + // Register event handling functions + fn register_event_functions(&self, lua: &Lua, eris_table: &Table) { + // Create a table to store event handlers + let handlers_table = lua.create_table().unwrap(); + eris_table.set("handlers", handlers_table).unwrap(); + + // Function for scripts to register event handlers + let on_fn = lua + .create_function(move |lua, (event_name, handler): (String, Function)| { + let globals = lua.globals(); + let eris: Table = globals.get("eris").unwrap(); + let handlers: Table = eris.get("handlers").unwrap(); + + // Get or create a table for this event type + let event_handlers: Table = if let Ok(table) = handlers.get(&*event_name) { + table + } else { + let new_table = lua.create_table().unwrap(); + handlers.set(&*event_name, new_table.clone()).unwrap(); + new_table + }; + + // Add the handler to the table + let next_index = event_handlers.len().unwrap() + 1; + event_handlers.set(next_index, handler).unwrap(); + + Ok(()) + }) + .unwrap(); + eris_table.set("on", on_fn).unwrap(); + } + + // Register logging functions + fn register_logging_functions(&self, lua: &Lua, eris_table: &Table) { + // Debug logging + let debug = lua + .create_function(|_, message: String| { + log::debug!("[Lua] {message}"); + Ok(()) + }) + .unwrap(); + 
eris_table.set("debug", debug).unwrap(); + + // Info logging + let info = lua + .create_function(|_, message: String| { + log::info!("[Lua] {message}"); + Ok(()) + }) + .unwrap(); + eris_table.set("info", info).unwrap(); + + // Warning logging + let warn = lua + .create_function(|_, message: String| { + log::warn!("[Lua] {message}"); + Ok(()) + }) + .unwrap(); + eris_table.set("warn", warn).unwrap(); + + // Error logging + let error = lua + .create_function(|_, message: String| { + log::error!("[Lua] {message}"); + Ok(()) + }) + .unwrap(); + eris_table.set("error", error).unwrap(); + } + + // Load all scripts from a directory + fn load_scripts_from_dir(&mut self, scripts_dir: &str) { + let script_dir = Path::new(scripts_dir); + if !script_dir.exists() { + log::warn!("Lua scripts directory does not exist: {scripts_dir}"); + return; + } + + log::debug!("Loading Lua scripts from directory: {scripts_dir}"); + if let Ok(entries) = fs::read_dir(script_dir) { + // Sort entries by filename to ensure consistent loading order + let mut sorted_entries: Vec<_> = entries.filter_map(Result::ok).collect(); + sorted_entries.sort_by_key(std::fs::DirEntry::path); + + for entry in sorted_entries { + let path = entry.path(); + if path.extension().and_then(|ext| ext.to_str()) == Some("lua") { + if let Ok(content) = fs::read_to_string(&path) { + let script_name = path + .file_stem() + .and_then(|n| n.to_str()) + .unwrap_or("unknown") + .to_string(); + + log::debug!("Loading Lua script: {} ({})", script_name, path.display()); + self.load_script(&script_name, &content); + } else { + log::warn!("Failed to read Lua script: {}", path.display()); + } + } + } + } + } + + // Load a single script and register its event handlers + fn load_script(&mut self, name: &str, content: &str) { + // Store script info + self.scripts.push(ScriptInfo { + name: name.to_string(), + enabled: true, + }); + + // Execute the script to register its event handlers + if let Ok(lua) = self.lua.lock() { + if let Err(e) 
= lua.load(content).set_name(name).exec() { + log::warn!("Error loading Lua script '{name}': {e}"); + return; + } + + // Collect registered event handlers + let globals = lua.globals(); + let eris: Table = match globals.get("eris") { + Ok(table) => table, + Err(_) => return, + }; + + let handlers: Table = match eris.get("handlers") { + Ok(table) => table, + Err(_) => return, + }; + + // Store the event handlers in our hooks map + let mut tmp: rlua::TablePairs<'_, String, Table<'_>> = + handlers.pairs::(); + 'l: loop { + if let Some(event_pair) = tmp.next() { + if let Ok((event_name, _)) = event_pair { + if let Some(event_type) = EventType::from_str(&event_name) { + self.hooks + .entry(event_type) + .or_default() + .push(name.to_string()); + } + } + } else { + break 'l; + } + } + + log::info!("Loaded Lua script '{name}' successfully"); + } + } + + /// Check if a script is enabled + fn is_script_enabled(&self, name: &str) -> bool { + self.scripts + .iter() + .find(|s| s.name == name) + .is_some_and(|s| s.enabled) + } + + /// Trigger an event, calling all registered handlers + pub fn trigger_event(&self, ctx: &EventContext) -> Option { + // Check if we have any handlers for this event + if !self.hooks.contains_key(&ctx.event_type) { + return ctx.content.clone(); + } + + // Build the event data table to pass to Lua handlers + let mut result = ctx.content.clone(); + + if let Ok(lua) = self.lua.lock() { + // Create the event context table + let event_ctx = lua.create_table().unwrap(); + + // Add all the context fields + event_ctx.set("event", ctx.event_type.as_str()).unwrap(); + if let Some(ip) = &ctx.ip { + event_ctx.set("ip", ip.clone()).unwrap(); + } + if let Some(path) = &ctx.path { + event_ctx.set("path", path.clone()).unwrap(); + } + if let Some(ua) = &ctx.user_agent { + event_ctx.set("user_agent", ua.clone()).unwrap(); + } + event_ctx.set("timestamp", ctx.timestamp).unwrap(); + if let Some(sid) = &ctx.session_id { + event_ctx.set("session_id", 
sid.clone()).unwrap(); + } + + // Add request headers if available + if let Some(headers) = &ctx.request_headers { + let headers_table = lua.create_table().unwrap(); + for (key, value) in headers { + headers_table + .set(key.to_string(), value.to_string()) + .unwrap(); + } + event_ctx.set("headers", headers_table).unwrap(); + } + + // Add content if available + if let Some(content) = &ctx.content { + event_ctx.set("content", content.clone()).unwrap(); + } + + // Call all registered handlers for this event + if let Some(handler_scripts) = self.hooks.get(&ctx.event_type) { + for script_name in handler_scripts { + // Skip disabled scripts + if !self.is_script_enabled(script_name) { + continue; + } + + // Get the globals and handlers table + let globals = lua.globals(); + let eris: Table = match globals.get("eris") { + Ok(table) => table, + Err(_) => continue, + }; + + let handlers: Table = match eris.get("handlers") { + Ok(table) => table, + Err(_) => continue, + }; + + // Get handlers for this event + let event_handlers: Table = match handlers.get(ctx.event_type.as_str()) { + Ok(table) => table, + Err(_) => continue, + }; + + // Call each handler + for pair in event_handlers.pairs::() { + if let Ok((_, handler)) = pair { + let handler_result: rlua::Result> = + handler.call((event_ctx.clone(),)); + if let Ok(Some(new_content)) = handler_result { + // For response events, allow handlers to modify the content + if matches!( + ctx.event_type, + EventType::ResponseGen | EventType::ResponseChunk + ) { + result = Some(new_content); + } + } + } + } + } + } + } + + result + } + + /// Generate a deceptive response, calling all `response_gen` handlers + pub fn generate_response( + &self, + path: &str, + user_agent: &str, + ip: &str, + headers: &HashMap, + markov_text: &str, + ) -> String { + // Create event context + let ctx = EventContext { + event_type: EventType::ResponseGen, + ip: Some(ip.to_string()), + path: Some(path.to_string()), + user_agent: 
Some(user_agent.to_string()), + request_headers: Some(headers.clone()), + content: Some(markov_text.to_string()), + timestamp: get_timestamp(), + session_id: Some(generate_session_id(ip, user_agent)), + }; + + /// Trigger the event and get the modified content + self.trigger_event(&ctx).unwrap_or_else(|| { + // Fallback to maintain backward compatibility + self.expand_response( + markov_text, + "generic", + path, + &generate_session_id(ip, user_agent), + ) + }) + } + + /// Process a chunk before sending it to client + pub fn process_chunk(&self, chunk: &str, ip: &str, session_id: &str) -> String { + let ctx = EventContext { + event_type: EventType::ResponseChunk, + ip: Some(ip.to_string()), + path: None, + user_agent: None, + request_headers: None, + content: Some(chunk.to_string()), + timestamp: get_timestamp(), + session_id: Some(session_id.to_string()), + }; + + self.trigger_event(&ctx) + .unwrap_or_else(|| chunk.to_string()) + } + + /// Called when a connection is established + pub fn on_connection(&self, ip: &str) -> bool { + let ctx = EventContext { + event_type: EventType::Connection, + ip: Some(ip.to_string()), + path: None, + user_agent: None, + request_headers: None, + content: None, + timestamp: get_timestamp(), + session_id: None, + }; + + // If any handler returns false, reject the connection + let mut should_accept = true; + + if let Ok(lua) = self.lua.lock() { + if let Some(handler_scripts) = self.hooks.get(&EventType::Connection) { + for script_name in handler_scripts { + // Skip disabled scripts + if !self.is_script_enabled(script_name) { + continue; + } + + let globals = lua.globals(); + let eris: Table = match globals.get("eris") { + Ok(table) => table, + Err(_) => continue, + }; + + let handlers: Table = match eris.get("handlers") { + Ok(table) => table, + Err(_) => continue, + }; + + let event_handlers: Table = match handlers.get("connection") { + Ok(table) => table, + Err(_) => continue, + }; + + for pair in event_handlers.pairs::() { + if let 
Ok((_, handler)) = pair { + let event_ctx = create_event_context(&lua, &ctx); + if let Ok(result) = handler.call::<_, Value>((event_ctx,)) { + if result == Value::Boolean(false) { + should_accept = false; + break; + } + } + } + } + + if !should_accept { + break; + } + } + } + } + + should_accept + } + + /// Called when deciding whether to block an IP + pub fn should_block_ip(&self, ip: &str, hit_count: u32) -> bool { + let ctx = EventContext { + event_type: EventType::BlockIP, + ip: Some(ip.to_string()), + path: None, + user_agent: None, + request_headers: None, + content: None, + timestamp: get_timestamp(), + session_id: None, + }; + + // We should default to not modifying the blocking decision + let mut should_block = None; + + if let Ok(lua) = self.lua.lock() { + if let Some(handler_scripts) = self.hooks.get(&EventType::BlockIP) { + for script_name in handler_scripts { + // Skip disabled scripts + if !self.is_script_enabled(script_name) { + continue; + } + + let globals = lua.globals(); + let eris: Table = match globals.get("eris") { + Ok(table) => table, + Err(_) => continue, + }; + + let handlers: Table = match eris.get("handlers") { + Ok(table) => table, + Err(_) => continue, + }; + + let event_handlers: Table = match handlers.get("block_ip") { + Ok(table) => table, + Err(_) => continue, + }; + + for pair in event_handlers.pairs::() { + if let Ok((_, handler)) = pair { + let event_ctx = create_event_context(&lua, &ctx); + // Add hit count for the block_ip event + event_ctx.set("hit_count", hit_count).unwrap(); + + if let Ok(result) = handler.call::<_, Value>((event_ctx,)) { + if let Value::Boolean(block) = result { + should_block = Some(block); + break; + } + } + } + } + + if should_block.is_some() { + break; + } + } + } + } + + // Return the script's decision, or default to the system behavior + should_block.unwrap_or(hit_count >= 3) + } + + // Maintains backward compatibility with the old API + // XXX: I never liked expand_response, should probably be 
removeedf + // in the future. + pub fn expand_response( + &self, + text: &str, + response_type: &str, + path: &str, + token: &str, + ) -> String { + if let Ok(lua) = self.lua.lock() { + let globals = lua.globals(); + match globals.get::<_, Function>("enhance_response") { + Ok(enhance_func) => { + match enhance_func.call::<_, String>((text, response_type, path, token)) { + Ok(result) => result, + Err(e) => { + log::warn!("Error calling Lua function enhance_response: {e}"); + format!("{text}\n") + } + } + } + Err(_) => format!("{text}\n"), + } + } else { + format!("{text}\n") + } + } +} + +/// Get current timestamp in seconds +fn get_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +/// Create a unique session ID for tracking a connection +fn generate_session_id(ip: &str, user_agent: &str) -> String { + let timestamp = get_timestamp(); + let random = rand::random::(); + + // Use std::hash instead of xxhash_rust + let mut hasher = DefaultHasher::new(); + format!("{ip}_{user_agent}_{timestamp}").hash(&mut hasher); + let hash = hasher.finish(); + + format!("SID_{hash:x}_{random:x}") +} + +// Create an event context table in Lua +fn create_event_context<'a>(lua: &'a Lua, event_ctx: &EventContext) -> Table<'a> { + let table = lua.create_table().unwrap(); + + table.set("event", event_ctx.event_type.as_str()).unwrap(); + if let Some(ip) = &event_ctx.ip { + table.set("ip", ip.clone()).unwrap(); + } + if let Some(path) = &event_ctx.path { + table.set("path", path.clone()).unwrap(); + } + if let Some(ua) = &event_ctx.user_agent { + table.set("user_agent", ua.clone()).unwrap(); + } + table.set("timestamp", event_ctx.timestamp).unwrap(); + if let Some(sid) = &event_ctx.session_id { + table.set("session_id", sid.clone()).unwrap(); + } + if let Some(content) = &event_ctx.content { + table.set("content", content.clone()).unwrap(); + } + + table +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + + use 
tempfile::TempDir; + + #[test] + fn test_event_registration() { + let temp_dir = TempDir::new().unwrap(); + let script_path = temp_dir.path().join("test_events.lua"); + let script_content = r#" + -- Example script with event handlers + eris.info("Registering event handlers") + + -- Connection event handler + eris.on("connection", function(ctx) + eris.debug("Connection from " .. ctx.ip) + return true -- accept the connection + end) + + -- Response generation handler + eris.on("response_gen", function(ctx) + eris.debug("Generating response for " .. ctx.path) + return ctx.content .. "" + end) + "#; + + fs::write(&script_path, script_content).unwrap(); + + let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); + + // Verify hooks were registered + assert!(script_manager.hooks.contains_key(&EventType::Connection)); + assert!(script_manager.hooks.contains_key(&EventType::ResponseGen)); + assert!(!script_manager.hooks.contains_key(&EventType::BlockIP)); + } + + #[test] + fn test_generate_response() { + let temp_dir = TempDir::new().unwrap(); + let script_path = temp_dir.path().join("response_test.lua"); + let script_content = r#" + eris.on("response_gen", function(ctx) + return ctx.content .. " - Modified by " .. 
ctx.user_agent + end) + "#; + + fs::write(&script_path, script_content).unwrap(); + + let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); + + let headers = HashMap::new(); + let result = script_manager.generate_response( + "/test/path", + "TestBot", + "127.0.0.1", + &headers, + "Original content", + ); + + assert!(result.contains("Original content")); + assert!(result.contains("Modified by TestBot")); + } + + #[test] + fn test_process_chunk() { + let temp_dir = TempDir::new().unwrap(); + let script_path = temp_dir.path().join("chunk_test.lua"); + let script_content = r#" + eris.on("response_chunk", function(ctx) + return ctx.content:gsub("secret", "REDACTED") + end) + "#; + + fs::write(&script_path, script_content).unwrap(); + + let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); + + let result = script_manager.process_chunk( + "This contains a secret password", + "127.0.0.1", + "test_session", + ); + + assert!(result.contains("This contains a REDACTED password")); + } + + #[test] + fn test_should_block_ip() { + let temp_dir = TempDir::new().unwrap(); + let script_path = temp_dir.path().join("block_test.lua"); + let script_content = r#" + eris.on("block_ip", function(ctx) + -- Block any IP with "192.168.1" prefix regardless of hit count + if string.match(ctx.ip, "^192%.168%.1%.") then + return true + end + + -- Don't block IPs with "10.0" prefix even if they hit the threshold + if string.match(ctx.ip, "^10%.0%.") then + return false + end + + -- Default behavior for other IPs (nil = use system default) + return nil + end) + "#; + + fs::write(&script_path, script_content).unwrap(); + + let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); + + // Should be blocked based on IP pattern + assert!(script_manager.should_block_ip("192.168.1.50", 1)); + + // Should not be blocked despite high hit count + assert!(!script_manager.should_block_ip("10.0.0.5", 10)); + + // Should use default behavior (block if 
>= 3 hits) + assert!(!script_manager.should_block_ip("172.16.0.1", 2)); + assert!(script_manager.should_block_ip("172.16.0.1", 3)); + } + + #[test] + fn test_state_and_counters() { + let temp_dir = TempDir::new().unwrap(); + let script_path = temp_dir.path().join("state_test.lua"); + let script_content = r#" + eris.on("startup", function(ctx) + eris.set_state("test_key", "test_value") + eris.inc_counter("visits", 0) + end) + + eris.on("connection", function(ctx) + local count = eris.inc_counter("visits") + eris.debug("Visit count: " .. count) + + -- Store last visitor + eris.set_state("last_visitor", ctx.ip) + return true + end) + + eris.on("response_gen", function(ctx) + local last_visitor = eris.get_state("last_visitor") or "unknown" + local visits = eris.get_counter("visits") + return ctx.content .. "" + end) + "#; + + fs::write(&script_path, script_content).unwrap(); + + let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); + + // Simulate connections + script_manager.on_connection("192.168.1.100"); + script_manager.on_connection("10.0.0.50"); + + // Check response includes state + let headers = HashMap::new(); + let result = script_manager.generate_response( + "/test", + "test-agent", + "8.8.8.8", + &headers, + "Response", + ); + + assert!(result.contains("Last visitor: 10.0.0.50")); + assert!(result.contains("Total visits: 2")); + } +} diff --git a/src/main.rs b/src/main.rs index 2471ced..c8d80a6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,1067 +1,106 @@ use actix_web::{App, HttpResponse, HttpServer, web}; use clap::Parser; -use ipnetwork::IpNetwork; -use rlua::{Function, Lua}; -use serde::{Deserialize, Serialize}; -use std::collections::{HashMap, HashSet}; -use std::env; use std::fs; -use std::io::Write; -use std::net::IpAddr; -use std::path::{Path, PathBuf}; +use std::path::Path; use std::sync::Arc; -use std::time::{Duration, Instant}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpListener, TcpStream}; -use 
tokio::process::Command; +use std::time::Duration; +use tokio::net::TcpListener; use tokio::sync::RwLock; -use tokio::time::sleep; +mod config; +mod firewall; +mod lua; mod markov; mod metrics; +mod network; +mod state; +mod utils; +use config::{Args, Config}; +use lua::{EventContext, EventType, ScriptManager}; use markov::MarkovGenerator; -use metrics::{ - ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, UA_HITS, metrics_handler, - status_handler, -}; - -// Command-line arguments using clap -#[derive(Parser, Debug, Clone)] -#[clap( - author, - version, - about = "Markov chain based HTTP tarpit/honeypot that delays and tracks potential attackers" -)] -struct Args { - #[clap( - long, - default_value = "0.0.0.0:8888", - help = "Address and port to listen for incoming HTTP requests (format: ip:port)" - )] - listen_addr: String, - - #[clap( - long, - default_value = "0.0.0.0:9100", - help = "Address and port to expose Prometheus metrics and status endpoint (format: ip:port)" - )] - metrics_addr: String, - - #[clap(long, help = "Disable Prometheus metrics server completely")] - disable_metrics: bool, - - #[clap( - long, - default_value = "127.0.0.1:80", - help = "Backend server address to proxy legitimate requests to (format: ip:port)" - )] - backend_addr: String, - - #[clap( - long, - default_value = "1000", - help = "Minimum delay in milliseconds between chunks sent to attacker" - )] - min_delay: u64, - - #[clap( - long, - default_value = "15000", - help = "Maximum delay in milliseconds between chunks sent to attacker" - )] - max_delay: u64, - - #[clap( - long, - default_value = "600", - help = "Maximum time in seconds to keep an attacker in the tarpit before disconnecting" - )] - max_tarpit_time: u64, - - #[clap( - long, - default_value = "3", - help = "Number of hits to honeypot patterns before permanently blocking an IP" - )] - block_threshold: u32, - - #[clap( - long, - help = "Base directory for all application data (overrides XDG directory structure)" - 
)] - base_dir: Option, - - #[clap( - long, - help = "Path to JSON configuration file (overrides command line options)" - )] - config_file: Option, - - #[clap( - long, - default_value = "info", - help = "Log level: trace, debug, info, warn, error" - )] - log_level: String, -} - -// Configuration structure -#[derive(Clone, Debug, Deserialize, Serialize)] -struct Config { - listen_addr: String, - metrics_addr: String, - disable_metrics: bool, - backend_addr: String, - min_delay: u64, - max_delay: u64, - max_tarpit_time: u64, - block_threshold: u32, - trap_patterns: Vec, - whitelist_networks: Vec, - markov_corpora_dir: String, - lua_scripts_dir: String, - data_dir: String, - config_dir: String, - cache_dir: String, -} - -impl Default for Config { - fn default() -> Self { - Self { - listen_addr: "0.0.0.0:8888".to_string(), - metrics_addr: "0.0.0.0:9100".to_string(), - disable_metrics: false, - backend_addr: "127.0.0.1:80".to_string(), - min_delay: 1000, - max_delay: 15000, - max_tarpit_time: 600, - block_threshold: 3, - trap_patterns: vec![ - "/vendor/phpunit".to_string(), - "eval-stdin.php".to_string(), - "/wp-admin".to_string(), - "/wp-login.php".to_string(), - "/xmlrpc.php".to_string(), - "/phpMyAdmin".to_string(), - "/solr/".to_string(), - "/.env".to_string(), - "/config".to_string(), - "/api/".to_string(), - "/actuator/".to_string(), - ], - whitelist_networks: vec![ - "192.168.0.0/16".to_string(), - "10.0.0.0/8".to_string(), - "172.16.0.0/12".to_string(), - "127.0.0.0/8".to_string(), - ], - markov_corpora_dir: "./corpora".to_string(), - lua_scripts_dir: "./scripts".to_string(), - data_dir: "./data".to_string(), - config_dir: "./conf".to_string(), - cache_dir: "./cache".to_string(), - } - } -} - -// Gets standard XDG directory paths for config, data and cache -fn get_xdg_dirs() -> (PathBuf, PathBuf, PathBuf) { - let config_home = env::var_os("XDG_CONFIG_HOME") - .map(PathBuf::from) - .unwrap_or_else(|| { - let home = env::var_os("HOME").map_or_else(|| 
PathBuf::from("."), PathBuf::from); - home.join(".config") - }); - - let data_home = env::var_os("XDG_DATA_HOME") - .map(PathBuf::from) - .unwrap_or_else(|| { - let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); - home.join(".local").join("share") - }); - - let cache_home = env::var_os("XDG_CACHE_HOME") - .map(PathBuf::from) - .unwrap_or_else(|| { - let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from); - home.join(".cache") - }); - - let config_dir = config_home.join("eris"); - let data_dir = data_home.join("eris"); - let cache_dir = cache_home.join("eris"); - - (config_dir, data_dir, cache_dir) -} - -impl Config { - // Create configuration from command-line args - fn from_args(args: &Args) -> Self { - let (config_dir, data_dir, cache_dir) = if let Some(base_dir) = &args.base_dir { - let base_str = base_dir.to_string_lossy().to_string(); - ( - format!("{base_str}/conf"), - format!("{base_str}/data"), - format!("{base_str}/cache"), - ) - } else { - let (c, d, cache) = get_xdg_dirs(); - ( - c.to_string_lossy().to_string(), - d.to_string_lossy().to_string(), - cache.to_string_lossy().to_string(), - ) - }; - - Self { - listen_addr: args.listen_addr.clone(), - metrics_addr: args.metrics_addr.clone(), - disable_metrics: args.disable_metrics, - backend_addr: args.backend_addr.clone(), - min_delay: args.min_delay, - max_delay: args.max_delay, - max_tarpit_time: args.max_tarpit_time, - block_threshold: args.block_threshold, - markov_corpora_dir: format!("{data_dir}/corpora"), - lua_scripts_dir: format!("{data_dir}/scripts"), - data_dir, - config_dir, - cache_dir, - ..Default::default() - } - } - - // Load configuration from a JSON file - fn load_from_file(path: &Path) -> std::io::Result { - let content = fs::read_to_string(path)?; - let config = serde_json::from_str(&content)?; - Ok(config) - } - - // Save configuration to a JSON file - fn save_to_file(&self, path: &Path) -> std::io::Result<()> { - if let 
Some(parent) = path.parent() { - fs::create_dir_all(parent)?; - } - let content = serde_json::to_string_pretty(self)?; - fs::write(path, content)?; - Ok(()) - } - - // Create required directories if they don't exist - fn ensure_dirs_exist(&self) -> std::io::Result<()> { - let dirs = [ - &self.markov_corpora_dir, - &self.lua_scripts_dir, - &self.data_dir, - &self.config_dir, - &self.cache_dir, - ]; - - for dir in dirs { - fs::create_dir_all(dir)?; - log::debug!("Created directory: {dir}"); - } - - Ok(()) - } -} - -// State of bots/IPs hitting the honeypot -#[derive(Clone, Debug)] -struct BotState { - hits: HashMap, - blocked: HashSet, - active_connections: HashSet, - data_dir: String, - cache_dir: String, -} - -impl BotState { - fn new(data_dir: &str, cache_dir: &str) -> Self { - Self { - hits: HashMap::new(), - blocked: HashSet::new(), - active_connections: HashSet::new(), - data_dir: data_dir.to_string(), - cache_dir: cache_dir.to_string(), - } - } - - // Load previous state from disk - fn load_from_disk(data_dir: &str, cache_dir: &str) -> Self { - let mut state = Self::new(data_dir, cache_dir); - let blocked_ips_file = format!("{data_dir}/blocked_ips.txt"); - - if let Ok(content) = fs::read_to_string(&blocked_ips_file) { - let mut loaded = 0; - for line in content.lines() { - if let Ok(ip) = line.parse::() { - state.blocked.insert(ip); - loaded += 1; - } - } - log::info!("Loaded {loaded} blocked IPs from {blocked_ips_file}"); - } else { - log::info!("No blocked IPs file found at {blocked_ips_file}"); - } - - // Check for temporary hit counter cache - let hit_cache_file = format!("{cache_dir}/hit_counters.json"); - if let Ok(content) = fs::read_to_string(&hit_cache_file) { - if let Ok(hit_map) = serde_json::from_str::>(&content) { - for (ip_str, count) in hit_map { - if let Ok(ip) = ip_str.parse::() { - state.hits.insert(ip, count); - } - } - log::info!("Loaded hit counters for {} IPs", state.hits.len()); - } - } - - BLOCKED_IPS.set(state.blocked.len() as f64); - 
state - } - - // Persist state to disk for later reloading - fn save_to_disk(&self) { - // Save blocked IPs - if let Err(e) = fs::create_dir_all(&self.data_dir) { - log::error!("Failed to create data directory: {e}"); - return; - } - - let blocked_ips_file = format!("{}/blocked_ips.txt", self.data_dir); - - match fs::File::create(&blocked_ips_file) { - Ok(mut file) => { - let mut count = 0; - for ip in &self.blocked { - if writeln!(file, "{ip}").is_ok() { - count += 1; - } - } - log::info!("Saved {count} blocked IPs to {blocked_ips_file}"); - } - Err(e) => { - log::error!("Failed to create blocked IPs file: {e}"); - } - } - - // Save hit counters to cache - if let Err(e) = fs::create_dir_all(&self.cache_dir) { - log::error!("Failed to create cache directory: {e}"); - return; - } - - let hit_cache_file = format!("{}/hit_counters.json", self.cache_dir); - let mut hit_map = HashMap::new(); - for (ip, count) in &self.hits { - hit_map.insert(ip.to_string(), *count); - } - - match fs::File::create(&hit_cache_file) { - Ok(file) => { - if let Err(e) = serde_json::to_writer(file, &hit_map) { - log::error!("Failed to write hit counters to cache: {e}"); - } else { - log::debug!("Saved hit counters for {} IPs to cache", hit_map.len()); - } - } - Err(e) => { - log::error!("Failed to create hit counter cache file: {e}"); - } - } - } -} - -// Lua scripts for response generation and customization -struct ScriptManager { - script_content: String, - scripts_loaded: bool, -} - -impl ScriptManager { - fn new(scripts_dir: &str) -> Self { - let mut script_content = String::new(); - let mut scripts_loaded = false; - - // Try to load scripts from directory - let script_dir = Path::new(scripts_dir); - if script_dir.exists() { - log::debug!("Loading Lua scripts from directory: {scripts_dir}"); - if let Ok(entries) = fs::read_dir(script_dir) { - for entry in entries { - if let Ok(entry) = entry { - let path = entry.path(); - if path.extension().and_then(|ext| ext.to_str()) == Some("lua") { - 
if let Ok(content) = fs::read_to_string(&path) { - log::debug!("Loaded Lua script: {}", path.display()); - script_content.push_str(&content); - script_content.push('\n'); - scripts_loaded = true; - } else { - log::warn!("Failed to read Lua script: {}", path.display()); - } - } - } - } - } - } else { - log::warn!("Lua scripts directory does not exist: {scripts_dir}"); - } - - // If no scripts were loaded, use a default script - if !scripts_loaded { - log::info!("No Lua scripts found, loading default scripts"); - script_content = r#" - function generate_honeytoken(token) - local token_types = {"API_KEY", "AUTH_TOKEN", "SESSION_ID", "SECRET_KEY"} - local prefix = token_types[math.random(#token_types)] - local suffix = string.format("%08x", math.random(0xffffff)) - return prefix .. "_" .. token .. "_" .. suffix - end - - function enhance_response(text, response_type, path, token) - local result = text - local honeytoken = generate_honeytoken(token) - - -- Add some fake sensitive data - result = result .. "\n" - result = result .. "\n
Server ID: " .. token .. "
" - - return result - end - "# - .to_string(); - scripts_loaded = true; - } - - Self { - script_content, - scripts_loaded, - } - } - - // Lua is a powerful configuration language we can use to expand functionality of - // Eris, e.g., with fake tokens or honeytrap content. - fn expand_response(&self, text: &str, response_type: &str, path: &str, token: &str) -> String { - if !self.scripts_loaded { - return format!("{text}\n"); - } - - let lua = Lua::new(); - if let Err(e) = lua.load(&self.script_content).exec() { - log::warn!("Error loading Lua script: {e}"); - return format!("{text}\n"); - } - - let globals = lua.globals(); - match globals.get::<_, Function>("enhance_response") { - Ok(enhance_func) => { - match enhance_func.call::<_, String>((text, response_type, path, token)) { - Ok(result) => result, - Err(e) => { - log::warn!("Error calling Lua function enhance_response: {e}"); - format!("{text}\n") - } - } - } - Err(e) => { - log::warn!("Lua enhance_response function not found: {e}"); - format!("{text}\n") - } - } - } -} - -// Main connection handler - decides whether to tarpit or proxy -async fn handle_connection( - mut stream: TcpStream, - config: Arc, - state: Arc>, - markov_generator: Arc, - script_manager: Arc, -) { - let peer_addr = match stream.peer_addr() { - Ok(addr) => addr.ip(), - Err(e) => { - log::debug!("Failed to get peer address: {e}"); - return; - } - }; - - log::debug!("New connection from: {peer_addr}"); - - // Check if IP is already blocked - if state.read().await.blocked.contains(&peer_addr) { - log::debug!("Rejected connection from blocked IP: {peer_addr}"); - let _ = stream.shutdown().await; - return; - } - - // Read the HTTP request - let mut buffer = [0; 8192]; - let mut request_data = Vec::new(); - - // Read with timeout to prevent hanging - let read_fut = async { - loop { - match stream.read(&mut buffer).await { - Ok(0) => break, - Ok(n) => { - request_data.extend_from_slice(&buffer[..n]); - // Stop reading at empty line, this is the 
end of HTTP headers - if request_data.len() > 2 && &request_data[request_data.len() - 2..] == b"\r\n" - { - break; - } - } - Err(e) => { - log::debug!("Error reading from stream: {e}"); - break; - } - } - } - }; - - let timeout_fut = sleep(Duration::from_secs(5)); - - tokio::select! { - () = read_fut => {}, - () = timeout_fut => { - log::debug!("Connection timeout from: {peer_addr}"); - let _ = stream.shutdown().await; - return; - } - } - - // Parse the request - let request_str = String::from_utf8_lossy(&request_data); - let request_lines: Vec<&str> = request_str.lines().collect(); - - if request_lines.is_empty() { - log::debug!("Empty request from: {peer_addr}"); - let _ = stream.shutdown().await; - return; - } - - // Parse request line - let request_parts: Vec<&str> = request_lines[0].split_whitespace().collect(); - if request_parts.len() < 3 { - log::debug!("Malformed request from {}: {}", peer_addr, request_lines[0]); - let _ = stream.shutdown().await; - return; - } - - let method = request_parts[0]; - let path = request_parts[1]; - let protocol = request_parts[2]; - - log::debug!("Request: {method} {path} {protocol} from {peer_addr}"); - - // Parse headers - let mut headers = HashMap::new(); - for line in &request_lines[1..] 
{ - if line.is_empty() { - break; - } - - if let Some(idx) = line.find(':') { - let key = line[..idx].trim(); - let value = line[idx + 1..].trim(); - headers.insert(key, value.to_string()); - } - } - - let user_agent = headers - .get("user-agent") - .cloned() - .unwrap_or_else(|| "unknown".to_string()); - - // Check if this request matches our tarpit patterns - let should_tarpit = should_tarpit(path, &peer_addr, &config).await; - - if should_tarpit { - log::info!("Tarpit triggered: {method} {path} from {peer_addr} (UA: {user_agent})"); - - // Update metrics - HITS_COUNTER.inc(); - PATH_HITS.with_label_values(&[path]).inc(); - UA_HITS.with_label_values(&[&user_agent]).inc(); - - // Update state and check for blocking threshold - { - let mut state = state.write().await; - state.active_connections.insert(peer_addr); - ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64); - - *state.hits.entry(peer_addr).or_insert(0) += 1; - let hit_count = state.hits[&peer_addr]; - log::debug!("Hit count for {peer_addr}: {hit_count}"); - - // Block IPs that hit tarpits too many times - if hit_count >= config.block_threshold && !state.blocked.contains(&peer_addr) { - log::info!("Blocking IP {peer_addr} after {hit_count} hits"); - state.blocked.insert(peer_addr); - BLOCKED_IPS.set(state.blocked.len() as f64); - state.save_to_disk(); - - // Try to add to firewall - let peer_addr_str = peer_addr.to_string(); - tokio::spawn(async move { - log::debug!("Adding IP {peer_addr_str} to firewall blacklist"); - match Command::new("nft") - .args([ - "add", - "element", - "inet", - "filter", - "eris_blacklist", - "{", - &peer_addr_str, - "}", - ]) - .output() - .await - { - Ok(output) => { - if !output.status.success() { - log::warn!( - "Failed to add IP {} to firewall: {}", - peer_addr_str, - String::from_utf8_lossy(&output.stderr) - ); - } - } - Err(e) => { - log::warn!("Failed to execute nft command: {e}"); - } - } - }); - } - } - - // Generate a deceptive response using Markov chains 
and Lua - let response = - generate_deceptive_response(path, &user_agent, &markov_generator, &script_manager) - .await; - - // Send the response with the tarpit delay strategy - tarpit_connection( - stream, - response, - peer_addr, - state.clone(), - config.min_delay, - config.max_delay, - config.max_tarpit_time, - ) - .await; - } else { - log::debug!("Proxying request: {method} {path} from {peer_addr}"); - - // Proxy non-matching requests to the actual backend - proxy_to_backend( - stream, - method, - path, - protocol, - &headers, - &config.backend_addr, - ) - .await; - } -} - -// Determine if a request should be tarpitted based on path and IP -async fn should_tarpit(path: &str, ip: &IpAddr, config: &Config) -> bool { - // Don't tarpit whitelisted IPs (internal networks, etc) - for network_str in &config.whitelist_networks { - if let Ok(network) = network_str.parse::() { - if network.contains(*ip) { - log::debug!("IP {ip} is in whitelist network {network_str}"); - return false; - } - } - } - - // Check if the request path matches any of our trap patterns - for pattern in &config.trap_patterns { - if path.contains(pattern) { - log::debug!("Path '{path}' matches trap pattern '{pattern}'"); - return true; - } - } - - // No trap patterns matched - false -} - -// Generate a deceptive HTTP response that appears legitimate -async fn generate_deceptive_response( - path: &str, - user_agent: &str, - markov: &MarkovGenerator, - script_manager: &ScriptManager, -) -> String { - // Choose response type based on path to seem more realistic - let response_type = if path.contains("phpunit") || path.contains("eval") { - "php_exploit" - } else if path.contains("wp-") { - "wordpress" - } else if path.contains("api") { - "api" - } else { - "generic" - }; - - log::debug!("Generating {response_type} response for path: {path}"); - - // Generate tracking token for this interaction - let tracking_token = format!( - "BOT_{}_{}", - user_agent - .chars() - .filter(|c| c.is_alphanumeric()) - 
.collect::(), - chrono::Utc::now().timestamp() - ); - - // Generate base response using Markov chain text generator - let markov_text = markov.generate(response_type, 30); - - // Use Lua to enhance with honeytokens and other deceptive content - let enhanced = - script_manager.expand_response(&markov_text, response_type, path, &tracking_token); - - // Return full HTTP response with appropriate headers - format!( - "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nX-Powered-By: PHP/7.4.3\r\nConnection: keep-alive\r\n\r\n{enhanced}" - ) -} - -// Slowly feed a response to the client with random delays to waste attacker time -async fn tarpit_connection( - mut stream: TcpStream, - response: String, - peer_addr: IpAddr, - state: Arc>, - min_delay: u64, - max_delay: u64, - max_tarpit_time: u64, -) { - let start_time = Instant::now(); - let mut chars = response.chars().collect::>(); - - // Randomize the char order slightly to confuse automated tools - for i in (0..chars.len()).rev() { - if i > 0 && rand::random::() < 0.1 { - chars.swap(i, i - 1); - } - } - - log::debug!( - "Starting tarpit for {} with {} chars, min_delay={}ms, max_delay={}ms", - peer_addr, - chars.len(), - min_delay, - max_delay - ); - - let mut position = 0; - let mut chunks_sent = 0; - let mut total_delay = 0; - - // Send the response character by character with random delays - while position < chars.len() { - // Check if we've exceeded maximum tarpit time - let elapsed_secs = start_time.elapsed().as_secs(); - if elapsed_secs > max_tarpit_time { - log::info!("Tarpit maximum time ({max_tarpit_time} sec) reached for {peer_addr}"); - break; - } - - // Decide how many chars to send in this chunk (usually 1, sometimes more) - let chunk_size = if rand::random::() < 0.9 { - 1 - } else { - (rand::random::() * 3.0).floor() as usize + 1 - }; - - let end = (position + chunk_size).min(chars.len()); - let chunk: String = chars[position..end].iter().collect(); - - // Try to write chunk - if 
stream.write_all(chunk.as_bytes()).await.is_err() { - log::debug!("Connection closed by client during tarpit: {peer_addr}"); - break; - } - - if stream.flush().await.is_err() { - log::debug!("Failed to flush stream during tarpit: {peer_addr}"); - break; - } - - position = end; - chunks_sent += 1; - - // Apply random delay between min and max configured values - let delay_ms = (rand::random::() * (max_delay - min_delay) as f32) as u64 + min_delay; - total_delay += delay_ms; - sleep(Duration::from_millis(delay_ms)).await; - } - - log::debug!( - "Tarpit stats for {}: sent {} chunks, {}% of data, total delay {}ms over {}s", - peer_addr, - chunks_sent, - position * 100 / chars.len(), - total_delay, - start_time.elapsed().as_secs() - ); - - // Remove from active connections - if let Ok(mut state) = state.try_write() { - state.active_connections.remove(&peer_addr); - ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64); - } - - let _ = stream.shutdown().await; -} - -// Forward a legitimate request to the real backend server -async fn proxy_to_backend( - mut client_stream: TcpStream, - method: &str, - path: &str, - protocol: &str, - headers: &HashMap<&str, String>, - backend_addr: &str, -) { - // Connect to backend server - let server_stream = match TcpStream::connect(backend_addr).await { - Ok(stream) => stream, - Err(e) => { - log::warn!("Failed to connect to backend {backend_addr}: {e}"); - let _ = client_stream.shutdown().await; - return; - } - }; - - log::debug!("Connected to backend server at {backend_addr}"); - - // Forward the original request - let mut request = format!("{method} {path} {protocol}\r\n"); - for (key, value) in headers { - request.push_str(&format!("{key}: {value}\r\n")); - } - request.push_str("\r\n"); - - let mut server_stream = server_stream; - if server_stream.write_all(request.as_bytes()).await.is_err() { - log::debug!("Failed to write request to backend server"); - let _ = client_stream.shutdown().await; - return; - } - - // Set up 
bidirectional forwarding between client and backend - let (mut client_read, mut client_write) = client_stream.split(); - let (mut server_read, mut server_write) = server_stream.split(); - - // Client -> Server - let client_to_server = async { - let mut buf = [0; 8192]; - let mut bytes_forwarded = 0; - - loop { - match client_read.read(&mut buf).await { - Ok(0) => break, - Ok(n) => { - bytes_forwarded += n; - if server_write.write_all(&buf[..n]).await.is_err() { - break; - } - } - Err(_) => break, - } - } - - log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes"); - }; - - // Server -> Client - let server_to_client = async { - let mut buf = [0; 8192]; - let mut bytes_forwarded = 0; - - loop { - match server_read.read(&mut buf).await { - Ok(0) => break, - Ok(n) => { - bytes_forwarded += n; - if client_write.write_all(&buf[..n]).await.is_err() { - break; - } - } - Err(_) => break, - } - } - - log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes"); - }; - - // Run both directions concurrently - tokio::select! { - () = client_to_server => {}, - () = server_to_client => {}, - } - - log::debug!("Proxy connection completed"); -} - -// Set up nftables firewall rules for IP blocking -async fn setup_firewall() -> Result<(), String> { - log::info!("Setting up firewall rules"); - - // Check if nft command exists - let nft_exists = Command::new("which") - .arg("nft") - .output() - .await - .map(|output| output.status.success()) - .unwrap_or(false); - - if !nft_exists { - log::warn!("nft command not found. 
Firewall rules will not be set up."); - return Ok(()); - } - - // Create table if it doesn't exist - let output = Command::new("nft") - .args(["list", "table", "inet", "filter"]) - .output() - .await; - - match output { - Ok(output) => { - if !output.status.success() { - log::info!("Creating nftables table"); - let result = Command::new("nft") - .args(["create", "table", "inet", "filter"]) - .output() - .await; - - if let Err(e) = result { - return Err(format!("Failed to create nftables table: {e}")); - } - } - } - Err(e) => { - log::warn!("Failed to check if nftables table exists: {e}"); - log::info!("Will try to create it anyway"); - let result = Command::new("nft") - .args(["create", "table", "inet", "filter"]) - .output() - .await; - - if let Err(e) = result { - return Err(format!("Failed to create nftables table: {e}")); - } - } - } - - // Create blacklist set if it doesn't exist - let output = Command::new("nft") - .args(["list", "set", "inet", "filter", "eris_blacklist"]) - .output() - .await; - - match output { - Ok(output) => { - if !output.status.success() { - log::info!("Creating eris_blacklist set"); - let result = Command::new("nft") - .args([ - "create", - "set", - "inet", - "filter", - "eris_blacklist", - "{ type ipv4_addr; flags interval; }", - ]) - .output() - .await; - - if let Err(e) = result { - return Err(format!("Failed to create blacklist set: {e}")); - } - } - } - Err(e) => { - log::warn!("Failed to check if blacklist set exists: {e}"); - return Err(format!("Failed to check if blacklist set exists: {e}")); - } - } - - // Add rule to drop traffic from blacklisted IPs - let output = Command::new("nft") - .args(["list", "chain", "inet", "filter", "input"]) - .output() - .await; - - // Check if our rule already exists - match output { - Ok(output) => { - let rule_exists = String::from_utf8_lossy(&output.stdout) - .contains("ip saddr @eris_blacklist counter drop"); - - if !rule_exists { - log::info!("Adding drop rule for blacklisted IPs"); - let 
result = Command::new("nft") - .args([ - "add", - "rule", - "inet", - "filter", - "input", - "ip saddr @eris_blacklist", - "counter", - "drop", - ]) - .output() - .await; - - if let Err(e) = result { - return Err(format!("Failed to add firewall rule: {e}")); - } - } - } - Err(e) => { - log::warn!("Failed to check if firewall rule exists: {e}"); - return Err(format!("Failed to check if firewall rule exists: {e}")); - } - } - - log::info!("Firewall setup complete"); - Ok(()) -} +use metrics::{metrics_handler, status_handler}; +use network::handle_connection; +use state::BotState; +use utils::get_timestamp; #[actix_web::main] async fn main() -> std::io::Result<()> { // Parse command line arguments let args = Args::parse(); - // Initialize the logger - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(&args.log_level)) - .format_timestamp_millis() - .init(); + // Determine log format from args + let log_format = match args.log_format.to_lowercase().as_str() { + "json" => config::LogFormat::Json, + "pretty-json" => config::LogFormat::PrettyJson, + "plain" => config::LogFormat::Plain, + _ => config::LogFormat::Pretty, + }; - log::info!("Starting eris tarpit system"); + // Initialize the logger with proper formatting + let env = env_logger::Env::default().default_filter_or(&args.log_level); + let mut builder = env_logger::Builder::from_env(env); + + match log_format { + config::LogFormat::Plain => { + builder.format_timestamp_millis().init(); + } + config::LogFormat::Pretty => { + builder + .format(|buf, record| { + use std::io::Write; + let timestamp = chrono::Local::now().format("%Y-%m-%d %H:%M:%S%.3f"); + writeln!( + buf, + "[{}] {} [{}] {}", + timestamp, + record.level(), + record.target(), + record.args() + ) + }) + .init(); + } + config::LogFormat::Json => { + builder + .format(|buf, record| { + use std::io::Write; + let json = serde_json::json!({ + "timestamp": chrono::Local::now().to_rfc3339(), + "level": record.level().to_string(), + 
"target": record.target(), + "message": record.args().to_string(), + "module_path": record.module_path(), + "file": record.file(), + "line": record.line(), + }); + writeln!(buf, "{}", json) + }) + .init(); + } + config::LogFormat::PrettyJson => { + builder + .format(|buf, record| { + use std::io::Write; + let json = serde_json::json!({ + "timestamp": chrono::Local::now().to_rfc3339(), + "level": record.level().to_string(), + "target": record.target(), + "message": record.args().to_string(), + "module_path": record.module_path(), + "file": record.file(), + "line": record.line(), + }); + writeln!(buf, "{}", serde_json::to_string_pretty(&json).unwrap()) + }) + .init(); + } + } + + log::info!("Starting Eris tarpit system"); // Load configuration - let config = if let Some(config_path) = &args.config_file { + let mut config = if let Some(config_path) = &args.config_file { log::info!("Loading configuration from {config_path:?}"); match Config::load_from_file(config_path) { Ok(cfg) => { @@ -1079,6 +118,11 @@ async fn main() -> std::io::Result<()> { Config::from_args(&args) }; + // Log format from the command line needs to be preserved + if args.config_file.is_none() { + config.log_format = log_format; + } + // Ensure required directories exist match config.ensure_dirs_exist() { Ok(()) => log::info!("Directory setup completed"), @@ -1093,12 +137,23 @@ async fn main() -> std::io::Result<()> { if let Err(e) = fs::create_dir_all(&config.config_dir) { log::warn!("Failed to create config directory: {e}"); } else { - let config_path = Path::new(&config.config_dir).join("config.json"); - if !config_path.exists() { - if let Err(e) = config.save_to_file(&config_path) { - log::warn!("Failed to save default configuration: {e}"); + // Save both JSON and TOML versions of the config for user reference + let config_path_json = Path::new(&config.config_dir).join("config.json"); + let config_path_toml = Path::new(&config.config_dir).join("config.toml"); + + if !config_path_json.exists() { 
+ if let Err(e) = config.save_to_file(&config_path_json) { + log::warn!("Failed to save JSON configuration: {e}"); } else { - log::info!("Saved default configuration to {config_path:?}"); + log::info!("Saved JSON configuration to {config_path_json:?}"); + } + } + + if !config_path_toml.exists() { + if let Err(e) = config.save_to_file(&config_path_toml) { + log::warn!("Failed to save TOML configuration: {e}"); + } else { + log::info!("Saved TOML configuration to {config_path_toml:?}"); } } } @@ -1114,7 +169,7 @@ async fn main() -> std::io::Result<()> { let config = Arc::new(config); // Setup firewall rules for IP blocking - match setup_firewall().await { + match firewall::setup_firewall().await { Ok(()) => {} Err(e) => { log::warn!("Failed to set up firewall rules: {e}"); @@ -1139,6 +194,8 @@ async fn main() -> std::io::Result<()> { // Initialize Lua script manager log::info!("Loading Lua scripts from {}", config.lua_scripts_dir); let script_manager = Arc::new(ScriptManager::new(&config.lua_scripts_dir)); + let script_manager_for_tarpit = script_manager.clone(); + let script_manager_for_periodic = script_manager.clone(); // Clone config for metrics server let metrics_config = config.clone(); @@ -1163,7 +220,7 @@ async fn main() -> std::io::Result<()> { let state_clone = tarpit_state.clone(); let markov_clone = markov_generator.clone(); - let script_manager_clone = script_manager.clone(); + let script_manager_clone = script_manager_for_tarpit.clone(); let config_clone = config.clone(); tokio::spawn(async move { @@ -1218,6 +275,27 @@ async fn main() -> std::io::Result<()> { } }; + // Setup periodic task runner for Lua scripts + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(60)); + loop { + interval.tick().await; + + // Trigger periodic event + let ctx = EventContext { + event_type: EventType::Periodic, + ip: None, + path: None, + user_agent: None, + request_headers: None, + content: None, + timestamp: get_timestamp(), + 
session_id: None, + }; + script_manager_for_periodic.trigger_event(&ctx); + } + }); + // Run both servers concurrently if metrics server is enabled if let Some(metrics_server) = metrics_server { tokio::select! { @@ -1257,186 +335,3 @@ async fn main() -> std::io::Result<()> { } } } - -#[cfg(test)] -mod tests { - use super::*; - use std::net::{IpAddr, Ipv4Addr}; - use tokio::sync::RwLock; - - #[test] - fn test_config_from_args() { - let args = Args { - listen_addr: "127.0.0.1:8080".to_string(), - metrics_addr: "127.0.0.1:9000".to_string(), - disable_metrics: true, - backend_addr: "127.0.0.1:8081".to_string(), - min_delay: 500, - max_delay: 10000, - max_tarpit_time: 300, - block_threshold: 5, - base_dir: Some(PathBuf::from("/tmp/eris")), - config_file: None, - log_level: "debug".to_string(), - }; - - let config = Config::from_args(&args); - assert_eq!(config.listen_addr, "127.0.0.1:8080"); - assert_eq!(config.metrics_addr, "127.0.0.1:9000"); - assert!(config.disable_metrics); - assert_eq!(config.backend_addr, "127.0.0.1:8081"); - assert_eq!(config.min_delay, 500); - assert_eq!(config.max_delay, 10000); - assert_eq!(config.max_tarpit_time, 300); - assert_eq!(config.block_threshold, 5); - assert_eq!(config.markov_corpora_dir, "/tmp/eris/data/corpora"); - assert_eq!(config.lua_scripts_dir, "/tmp/eris/data/scripts"); - assert_eq!(config.data_dir, "/tmp/eris/data"); - assert_eq!(config.config_dir, "/tmp/eris/conf"); - assert_eq!(config.cache_dir, "/tmp/eris/cache"); - } - - #[tokio::test] - async fn test_should_tarpit() { - let config = Config::default(); - - // Test trap patterns - assert!( - should_tarpit( - "/vendor/phpunit/whatever", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - ) - .await - ); - assert!( - should_tarpit( - "/wp-admin/login.php", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - ) - .await - ); - assert!(should_tarpit("/.env", &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), &config).await); - - // Test whitelist networks - assert!( - 
!should_tarpit( - "/wp-admin/login.php", - &IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - &config - ) - .await - ); - assert!( - !should_tarpit( - "/vendor/phpunit/whatever", - &IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), - &config - ) - .await - ); - - // Test legitimate paths - assert!( - !should_tarpit( - "/index.html", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - ) - .await - ); - assert!( - !should_tarpit( - "/images/logo.png", - &IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), - &config - ) - .await - ); - } - - #[test] - fn test_script_manager_default_script() { - let script_manager = ScriptManager::new("/nonexistent_directory"); - assert!(script_manager.scripts_loaded); - assert!( - script_manager - .script_content - .contains("generate_honeytoken") - ); - assert!(script_manager.script_content.contains("enhance_response")); - } - - #[tokio::test] - async fn test_bot_state() { - let state = BotState::new("/tmp/eris_test", "/tmp/eris_test_cache"); - let ip1 = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)); - let ip2 = IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8)); - - let state = Arc::new(RwLock::new(state)); - - // Test hit counter - { - let mut state = state.write().await; - *state.hits.entry(ip1).or_insert(0) += 1; - *state.hits.entry(ip1).or_insert(0) += 1; - *state.hits.entry(ip2).or_insert(0) += 1; - - assert_eq!(*state.hits.get(&ip1).unwrap(), 2); - assert_eq!(*state.hits.get(&ip2).unwrap(), 1); - } - - // Test blocking - { - let mut state = state.write().await; - state.blocked.insert(ip1); - assert!(state.blocked.contains(&ip1)); - assert!(!state.blocked.contains(&ip2)); - } - - // Test active connections - { - let mut state = state.write().await; - state.active_connections.insert(ip1); - state.active_connections.insert(ip2); - assert_eq!(state.active_connections.len(), 2); - - state.active_connections.remove(&ip1); - assert_eq!(state.active_connections.len(), 1); - assert!(!state.active_connections.contains(&ip1)); - assert!(state.active_connections.contains(&ip2)); - } 
- } - - #[tokio::test] - async fn test_generate_deceptive_response() { - // Create a simple markov generator for testing - let markov = MarkovGenerator::new("/nonexistent/path"); - let script_manager = ScriptManager::new("/nonexistent/path"); - - // Test different path types - let resp1 = generate_deceptive_response( - "/vendor/phpunit/exec", - "TestBot/1.0", - &markov, - &script_manager, - ) - .await; - assert!(resp1.contains("HTTP/1.1 200 OK")); - assert!(resp1.contains("X-Powered-By: PHP")); - - let resp2 = - generate_deceptive_response("/wp-admin/", "TestBot/1.0", &markov, &script_manager) - .await; - assert!(resp2.contains("HTTP/1.1 200 OK")); - - let resp3 = - generate_deceptive_response("/api/users", "TestBot/1.0", &markov, &script_manager) - .await; - assert!(resp3.contains("HTTP/1.1 200 OK")); - - // Verify tracking token is included - assert!(resp1.contains("BOT_TestBot")); - } -} diff --git a/src/markov.rs b/src/markov.rs index b640dac..05960cb 100644 --- a/src/markov.rs +++ b/src/markov.rs @@ -103,7 +103,7 @@ impl MarkovGenerator { let path = Path::new(corpus_dir); if path.exists() && path.is_dir() { if let Ok(entries) = fs::read_dir(path) { - for entry in entries { + entries.for_each(|entry| { if let Ok(entry) = entry { let file_path = entry.path(); if let Some(file_name) = file_path.file_stem() { @@ -120,7 +120,7 @@ impl MarkovGenerator { } } } - } + }); } } diff --git a/src/metrics.rs b/src/metrics.rs index 48bceb9..44c3fb4 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -24,6 +24,11 @@ lazy_static! 
{ register_counter_vec!("eris_path_hits_total", "Hits by path", &["path"]).unwrap(); pub static ref UA_HITS: CounterVec = register_counter_vec!("eris_ua_hits_total", "Hits by user agent", &["user_agent"]).unwrap(); + pub static ref RATE_LIMITED_CONNECTIONS: Counter = register_counter!( + "eris_rate_limited_total", + "Number of connections rejected due to rate limiting" + ) + .unwrap(); } // Prometheus metrics endpoint @@ -76,6 +81,7 @@ mod tests { UA_HITS.with_label_values(&["TestBot/1.0"]).inc(); BLOCKED_IPS.set(5.0); ACTIVE_CONNECTIONS.set(3.0); + RATE_LIMITED_CONNECTIONS.inc(); // Create test app let app = @@ -96,6 +102,7 @@ mod tests { assert!(body_str.contains("eris_ua_hits_total")); assert!(body_str.contains("eris_blocked_ips")); assert!(body_str.contains("eris_active_connections")); + assert!(body_str.contains("eris_rate_limited_total")); } #[actix_web::test] diff --git a/src/network/mod.rs b/src/network/mod.rs new file mode 100644 index 0000000..0fd1bfa --- /dev/null +++ b/src/network/mod.rs @@ -0,0 +1,504 @@ +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio::process::Command; +use tokio::sync::RwLock; +use tokio::time::sleep; + +use crate::config::Config; +use crate::lua::{EventContext, EventType, ScriptManager}; +use crate::markov::MarkovGenerator; +use crate::metrics::{ + ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, RATE_LIMITED_CONNECTIONS, UA_HITS, +}; +use crate::state::BotState; +use crate::utils::{ + choose_response_type, extract_all_headers, extract_header_value, extract_path_from_request, + find_header_end, generate_session_id, get_timestamp, +}; + +mod rate_limiter; +use rate_limiter::RateLimiter; + +// Global rate limiter instance. +// Default is 30 connections per IP in a 60 second window +// XXX: This might add overhead of the proxy, e.g. NGINX already implements +// rate limiting. 
Though I don't think we have a way of knowing if the middleman +// we are handing the connections to (from the same middleman in some cases) has +// rate limiting. +lazy_static::lazy_static! { + static ref RATE_LIMITER: RateLimiter = RateLimiter::new(60, 30); +} + +// Main connection handler. +// Decides whether to tarpit or proxy +pub async fn handle_connection( + mut stream: TcpStream, + config: Arc, + state: Arc>, + markov_generator: Arc, + script_manager: Arc, +) { + // Get peer information + let peer_addr: IpAddr = match stream.peer_addr() { + Ok(addr) => addr.ip(), + Err(e) => { + log::debug!("Failed to get peer address: {e}"); + return; + } + }; + + // Check for blocked IPs to avoid any processing + if state.read().await.blocked.contains(&peer_addr) { + log::debug!("Rejected connection from blocked IP: {peer_addr}"); + let _ = stream.shutdown().await; + return; + } + + // Apply rate limiting before any further processing + if config.rate_limit_enabled && !RATE_LIMITER.check_rate_limit(peer_addr).await { + log::info!("Rate limited connection from {peer_addr}"); + RATE_LIMITED_CONNECTIONS.inc(); + + // Optionally, add the IP to a temporary block list + // if it's constantly hitting the rate limit + let connection_count = RATE_LIMITER.get_connection_count(&peer_addr); + if connection_count > config.rate_limit_block_threshold { + log::warn!( + "IP {peer_addr} exceeding rate limit with {connection_count} connection attempts, considering for blocking" + ); + + // Trigger a blocked event for Lua scripts + let rate_limit_ctx = EventContext { + event_type: EventType::BlockIP, + ip: Some(peer_addr.to_string()), + path: None, + user_agent: None, + request_headers: None, + content: None, + timestamp: get_timestamp(), + session_id: None, + }; + script_manager.trigger_event(&rate_limit_ctx); + } + + // Either send a slow response or just close connection + if config.rate_limit_slow_response { + // Send a simple 429 Too Many Requests response. 
If the bots actually respected + // HTTP error codes, the internet would be a mildly better place. + let response = "HTTP/1.1 429 Too Many Requests\r\n\ + Content-Type: text/plain\r\n\ + Retry-After: 60\r\n\ + Connection: close\r\n\ + \r\n\ + Rate limit exceeded. Please try again later."; + + let _ = stream.write_all(response.as_bytes()).await; + let _ = stream.flush().await; + } + + let _ = stream.shutdown().await; + return; + } + + // Check if Lua scripts allow this connection + if !script_manager.on_connection(&peer_addr.to_string()) { + log::debug!("Connection rejected by Lua script: {peer_addr}"); + let _ = stream.shutdown().await; + return; + } + + // Pre-check for whitelisted IPs to bypass heavy processing + let mut whitelisted = false; + for network_str in &config.whitelist_networks { + if let Ok(network) = network_str.parse::() { + if network.contains(peer_addr) { + whitelisted = true; + break; + } + } + } + + // Read buffer + let mut buffer = vec![0; 8192]; + let mut request_data = Vec::with_capacity(8192); + let mut header_end_pos = 0; + + // Read with timeout to prevent hanging resource load ops. + let read_fut = async { + loop { + match stream.read(&mut buffer).await { + Ok(0) => break, + Ok(n) => { + let new_data = &buffer[..n]; + request_data.extend_from_slice(new_data); + + // Look for end of headers + if header_end_pos == 0 { + if let Some(pos) = find_header_end(&request_data) { + header_end_pos = pos; + break; + } + } + + // Avoid excessive buffering + if request_data.len() > 32768 { + break; + } + } + Err(e) => { + log::debug!("Error reading from stream: {e}"); + break; + } + } + } + }; + + let timeout_fut = sleep(Duration::from_secs(3)); + + tokio::select! { + () = read_fut => {}, + () = timeout_fut => { + log::debug!("Connection timeout from: {peer_addr}"); + let _ = stream.shutdown().await; + return; + } + } + + // Fast path for whitelisted IPs. Skip full parsing and speed up "approved" + // connections automatically. 
+ if whitelisted { + log::debug!("Whitelisted IP {peer_addr} - using fast proxy path"); + proxy_fast_path(stream, request_data, &config.backend_addr).await; + return; + } + + // Parse minimally to extract the path + let path = if let Some(p) = extract_path_from_request(&request_data) { + p + } else { + log::debug!("Invalid request from {peer_addr}"); + let _ = stream.shutdown().await; + return; + }; + + // Extract request headers for Lua scripts + let headers = extract_all_headers(&request_data); + + // Extract user agent for logging and decision making + let user_agent = + extract_header_value(&request_data, "user-agent").unwrap_or_else(|| "unknown".to_string()); + + // Trigger request event for Lua scripts + let request_ctx = EventContext { + event_type: EventType::Request, + ip: Some(peer_addr.to_string()), + path: Some(path.to_string()), + user_agent: Some(user_agent.clone()), + request_headers: Some(headers.clone()), + content: None, + timestamp: get_timestamp(), + session_id: Some(generate_session_id(&peer_addr.to_string(), &user_agent)), + }; + script_manager.trigger_event(&request_ctx); + + // Check if this request matches our tarpit patterns + let should_tarpit = crate::config::should_tarpit(path, &peer_addr, &config); + + if should_tarpit { + log::info!("Tarpit triggered: {path} from {peer_addr} (UA: {user_agent})"); + + // Update metrics + HITS_COUNTER.inc(); + PATH_HITS.with_label_values(&[path]).inc(); + UA_HITS.with_label_values(&[&user_agent]).inc(); + + // Update state and check for blocking threshold + { + let mut state = state.write().await; + state.active_connections.insert(peer_addr); + ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64); + + *state.hits.entry(peer_addr).or_insert(0) += 1; + let hit_count = state.hits[&peer_addr]; + + // Use Lua to decide whether to block this IP + let should_block = script_manager.should_block_ip(&peer_addr.to_string(), hit_count); + + // Block IPs that hit tarpits too many times + if should_block && 
!state.blocked.contains(&peer_addr) { + log::info!("Blocking IP {peer_addr} after {hit_count} hits"); + state.blocked.insert(peer_addr); + BLOCKED_IPS.set(state.blocked.len() as f64); + state.save_to_disk(); + + // Do firewall blocking in background + let peer_addr_str = peer_addr.to_string(); + tokio::spawn(async move { + log::debug!("Adding IP {peer_addr_str} to firewall blacklist"); + match Command::new("nft") + .args([ + "add", + "element", + "inet", + "filter", + "eris_blacklist", + "{", + &peer_addr_str, + "}", + ]) + .output() + .await + { + Ok(output) => { + if !output.status.success() { + log::warn!( + "Failed to add IP {} to firewall: {}", + peer_addr_str, + String::from_utf8_lossy(&output.stderr) + ); + } + } + Err(e) => { + log::warn!("Failed to execute nft command: {e}"); + } + } + }); + } + } + + // Generate a deceptive response using Markov chains and Lua + let response = generate_deceptive_response( + path, + &user_agent, + &peer_addr, + &headers, + &markov_generator, + &script_manager, + ) + .await; + + // Generate a session ID for tracking this tarpit session + let session_id = generate_session_id(&peer_addr.to_string(), &user_agent); + + // Send the response with the tarpit delay strategy + { + let mut stream = stream; + let peer_addr = peer_addr; + let state = state.clone(); + let min_delay = config.min_delay; + let max_delay = config.max_delay; + let max_tarpit_time = config.max_tarpit_time; + let script_manager = script_manager.clone(); + async move { + let start_time = Instant::now(); + let mut chars = response.chars().collect::>(); + for i in (0..chars.len()).rev() { + if i > 0 && rand::random::() < 0.1 { + chars.swap(i, i - 1); + } + } + log::debug!( + "Starting tarpit for {} with {} chars, min_delay={}ms, max_delay={}ms", + peer_addr, + chars.len(), + min_delay, + max_delay + ); + let mut position = 0; + let mut chunks_sent = 0; + let mut total_delay = 0; + while position < chars.len() { + // Check if we've exceeded maximum tarpit time + 
let elapsed_secs = start_time.elapsed().as_secs(); + if elapsed_secs > max_tarpit_time { + log::info!( + "Tarpit maximum time ({max_tarpit_time} sec) reached for {peer_addr}" + ); + break; + } + + // Decide how many chars to send in this chunk (usually 1, sometimes more) + let chunk_size = if rand::random::() < 0.9 { + 1 + } else { + (rand::random::() * 3.0).floor() as usize + 1 + }; + + let end = (position + chunk_size).min(chars.len()); + let chunk: String = chars[position..end].iter().collect(); + + // Process chunk through Lua before sending + let processed_chunk = + script_manager.process_chunk(&chunk, &peer_addr.to_string(), &session_id); + + // Try to write processed chunk + if stream.write_all(processed_chunk.as_bytes()).await.is_err() { + log::debug!("Connection closed by client during tarpit: {peer_addr}"); + break; + } + + if stream.flush().await.is_err() { + log::debug!("Failed to flush stream during tarpit: {peer_addr}"); + break; + } + + position = end; + chunks_sent += 1; + + // Apply random delay between min and max configured values + let delay_ms = + (rand::random::() * (max_delay - min_delay) as f32) as u64 + min_delay; + total_delay += delay_ms; + sleep(Duration::from_millis(delay_ms)).await; + } + log::debug!( + "Tarpit stats for {}: sent {} chunks, {}% of data, total delay {}ms over {}s", + peer_addr, + chunks_sent, + position * 100 / chars.len(), + total_delay, + start_time.elapsed().as_secs() + ); + let disconnection_ctx = EventContext { + event_type: EventType::Disconnection, + ip: Some(peer_addr.to_string()), + path: None, + user_agent: None, + request_headers: None, + content: None, + timestamp: get_timestamp(), + session_id: Some(session_id), + }; + script_manager.trigger_event(&disconnection_ctx); + if let Ok(mut state) = state.try_write() { + state.active_connections.remove(&peer_addr); + ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64); + } + let _ = stream.shutdown().await; + } + } + .await; + } else { + 
log::debug!("Proxying request: {path} from {peer_addr}"); + proxy_fast_path(stream, request_data, &config.backend_addr).await; + } +} + +// Forward a legitimate request to the real backend server +pub async fn proxy_fast_path( + mut client_stream: TcpStream, + request_data: Vec, + backend_addr: &str, +) { + // Connect to backend server + let server_stream = match TcpStream::connect(backend_addr).await { + Ok(stream) => stream, + Err(e) => { + log::warn!("Failed to connect to backend {backend_addr}: {e}"); + let _ = client_stream.shutdown().await; + return; + } + }; + + // Set TCP_NODELAY for both streams before splitting them + if let Err(e) = client_stream.set_nodelay(true) { + log::debug!("Failed to set TCP_NODELAY on client stream: {e}"); + } + + let mut server_stream = server_stream; + if let Err(e) = server_stream.set_nodelay(true) { + log::debug!("Failed to set TCP_NODELAY on server stream: {e}"); + } + + // Forward the original request bytes directly without parsing + if server_stream.write_all(&request_data).await.is_err() { + log::debug!("Failed to write request to backend server"); + let _ = client_stream.shutdown().await; + return; + } + + // Now split the streams for concurrent reading/writing + let (mut client_read, mut client_write) = client_stream.split(); + let (mut server_read, mut server_write) = server_stream.split(); + + // 32KB buffer + let buf_size = 32768; + + // Client -> Server + let client_to_server = async { + let mut buf = vec![0; buf_size]; + let mut bytes_forwarded = 0; + + loop { + match client_read.read(&mut buf).await { + Ok(0) => break, + Ok(n) => { + bytes_forwarded += n; + if server_write.write_all(&buf[..n]).await.is_err() { + break; + } + } + Err(_) => break, + } + } + + // Ensure everything is sent + let _ = server_write.flush().await; + log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes"); + }; + + // Server -> Client + let server_to_client = async { + let mut buf = vec![0; buf_size]; + let mut bytes_forwarded = 
0; + + loop { + match server_read.read(&mut buf).await { + Ok(0) => break, + Ok(n) => { + bytes_forwarded += n; + if client_write.write_all(&buf[..n]).await.is_err() { + break; + } + } + Err(_) => break, + } + } + + // Ensure everything is sent + let _ = client_write.flush().await; + log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes"); + }; + + // Run both directions concurrently + tokio::join!(client_to_server, server_to_client); + log::debug!("Fast proxy connection completed"); +} + +// Generate a deceptive HTTP response that appears legitimate +pub async fn generate_deceptive_response( + path: &str, + user_agent: &str, + peer_addr: &IpAddr, + headers: &HashMap, + markov: &MarkovGenerator, + script_manager: &ScriptManager, +) -> String { + // Generate base response using Markov chain text generator + let response_type = choose_response_type(path); + let markov_text = markov.generate(response_type, 30); + + // Use Lua scripts to enhance with honeytokens and other deceptive content + script_manager.generate_response( + path, + user_agent, + &peer_addr.to_string(), + headers, + &markov_text, + ) +} diff --git a/src/network/rate_limiter.rs b/src/network/rate_limiter.rs new file mode 100644 index 0000000..6150793 --- /dev/null +++ b/src/network/rate_limiter.rs @@ -0,0 +1,85 @@ +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Mutex; + +pub struct RateLimiter { + connections: Arc>>>, + window_seconds: u64, + max_connections: usize, + cleanup_interval: Duration, + last_cleanup: Instant, +} + +impl RateLimiter { + pub fn new(window_seconds: u64, max_connections: usize) -> Self { + Self { + connections: Arc::new(Mutex::new(HashMap::new())), + window_seconds, + max_connections, + cleanup_interval: Duration::from_secs(60), + last_cleanup: Instant::now(), + } + } + + pub async fn check_rate_limit(&self, ip: IpAddr) -> bool { + let now = Instant::now(); + let window = 
Duration::from_secs(self.window_seconds); + + let mut connections = self.connections.lock().await; + + // Periodically clean up old entries across all IPs + if now.duration_since(self.last_cleanup) > self.cleanup_interval { + self.cleanup_old_entries(&mut connections, now, window); + } + + // Clean up old entries for this specific IP + if let Some(times) = connections.get_mut(&ip) { + times.retain(|time| now.duration_since(*time) < window); + + // Check if rate limit exceeded + if times.len() >= self.max_connections { + log::debug!("Rate limit exceeded for IP: {}", ip); + return false; + } + + // Add new connection time + times.push(now); + } else { + connections.insert(ip, vec![now]); + } + + true + } + + fn cleanup_old_entries( + &self, + connections: &mut HashMap>, + now: Instant, + window: Duration, + ) { + let mut empty_keys = Vec::new(); + + for (ip, times) in connections.iter_mut() { + times.retain(|time| now.duration_since(*time) < window); + if times.is_empty() { + empty_keys.push(*ip); + } + } + + // Remove empty entries + for ip in empty_keys { + connections.remove(&ip); + } + } + + pub fn get_connection_count(&self, ip: &IpAddr) -> usize { + if let Ok(connections) = self.connections.try_lock() { + if let Some(times) = connections.get(ip) { + return times.len(); + } + } + 0 + } +} diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..99e5369 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,166 @@ +use std::collections::{HashMap, HashSet}; +use std::fs; +use std::io::Write; +use std::net::IpAddr; + +use crate::metrics::BLOCKED_IPS; + +// State of bots/IPs hitting the honeypot +#[derive(Clone, Debug)] +pub struct BotState { + pub hits: HashMap, + pub blocked: HashSet, + pub active_connections: HashSet, + pub data_dir: String, + pub cache_dir: String, +} + +impl BotState { + pub fn new(data_dir: &str, cache_dir: &str) -> Self { + Self { + hits: HashMap::new(), + blocked: HashSet::new(), + active_connections: HashSet::new(), + data_dir: 
data_dir.to_string(), + cache_dir: cache_dir.to_string(), + } + } + + // Load previous state from disk + pub fn load_from_disk(data_dir: &str, cache_dir: &str) -> Self { + let mut state = Self::new(data_dir, cache_dir); + let blocked_ips_file = format!("{data_dir}/blocked_ips.txt"); + + if let Ok(content) = fs::read_to_string(&blocked_ips_file) { + let mut loaded = 0; + for line in content.lines() { + if let Ok(ip) = line.parse::() { + state.blocked.insert(ip); + loaded += 1; + } + } + log::info!("Loaded {loaded} blocked IPs from {blocked_ips_file}"); + } else { + log::info!("No blocked IPs file found at {blocked_ips_file}"); + } + + // Check for temporary hit counter cache + let hit_cache_file = format!("{cache_dir}/hit_counters.json"); + if let Ok(content) = fs::read_to_string(&hit_cache_file) { + if let Ok(hit_map) = serde_json::from_str::>(&content) { + for (ip_str, count) in hit_map { + if let Ok(ip) = ip_str.parse::() { + state.hits.insert(ip, count); + } + } + log::info!("Loaded hit counters for {} IPs", state.hits.len()); + } + } + + BLOCKED_IPS.set(state.blocked.len() as f64); + state + } + + // Persist state to disk for later reloading + pub fn save_to_disk(&self) { + // Save blocked IPs + if let Err(e) = fs::create_dir_all(&self.data_dir) { + log::error!("Failed to create data directory: {e}"); + return; + } + + let blocked_ips_file = format!("{}/blocked_ips.txt", self.data_dir); + + match fs::File::create(&blocked_ips_file) { + Ok(mut file) => { + let mut count = 0; + for ip in &self.blocked { + if writeln!(file, "{ip}").is_ok() { + count += 1; + } + } + log::info!("Saved {count} blocked IPs to {blocked_ips_file}"); + } + Err(e) => { + log::error!("Failed to create blocked IPs file: {e}"); + } + } + + // Save hit counters to cache + if let Err(e) = fs::create_dir_all(&self.cache_dir) { + log::error!("Failed to create cache directory: {e}"); + return; + } + + let hit_cache_file = format!("{}/hit_counters.json", self.cache_dir); + let hit_map: HashMap = 
self + .hits + .iter() + .map(|(ip, count)| (ip.to_string(), *count)) + .collect(); + + match fs::File::create(&hit_cache_file) { + Ok(file) => { + if let Err(e) = serde_json::to_writer(file, &hit_map) { + log::error!("Failed to write hit counters to cache: {e}"); + } else { + log::debug!("Saved hit counters for {} IPs to cache", hit_map.len()); + } + } + Err(e) => { + log::error!("Failed to create hit counter cache file: {e}"); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, Ipv4Addr}; + use std::sync::Arc; + use tokio::sync::RwLock; + + #[tokio::test] + async fn test_bot_state() { + let state = BotState::new("/tmp/eris_test", "/tmp/eris_test_cache"); + let ip1 = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)); + let ip2 = IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8)); + + let state = Arc::new(RwLock::new(state)); + + // Test hit counter + { + let mut state = state.write().await; + *state.hits.entry(ip1).or_insert(0) += 1; + *state.hits.entry(ip1).or_insert(0) += 1; + *state.hits.entry(ip2).or_insert(0) += 1; + + assert_eq!(*state.hits.get(&ip1).unwrap(), 2); + assert_eq!(*state.hits.get(&ip2).unwrap(), 1); + } + + // Test blocking + { + let mut state = state.write().await; + state.blocked.insert(ip1); + assert!(state.blocked.contains(&ip1)); + assert!(!state.blocked.contains(&ip2)); + drop(state); + } + + // Test active connections + { + let mut state = state.write().await; + state.active_connections.insert(ip1); + state.active_connections.insert(ip2); + assert_eq!(state.active_connections.len(), 2); + + state.active_connections.remove(&ip1); + assert_eq!(state.active_connections.len(), 1); + assert!(!state.active_connections.contains(&ip1)); + assert!(state.active_connections.contains(&ip2)); + drop(state); + } + } +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..c925565 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,168 @@ +use std::collections::HashMap; +use std::hash::Hasher; + +// Find end of HTTP headers +pub 
fn find_header_end(data: &[u8]) -> Option { + data.windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|pos| pos + 4) +} + +// Extract path from raw request data +pub fn extract_path_from_request(data: &[u8]) -> Option<&str> { + // Get first line from request + let first_line = data + .split(|&b| b == b'\r' || b == b'\n') + .next() + .filter(|line| !line.is_empty())?; + + // Split by spaces and ensure we have at least 3 parts (METHOD PATH VERSION) + let parts: Vec<&[u8]> = first_line.split(|&b| b == b' ').collect(); + if parts.len() < 3 || !parts[2].starts_with(b"HTTP/") { + return None; + } + + // Return the path (second element) + std::str::from_utf8(parts[1]).ok() +} + +// Extract header value from raw request data +pub fn extract_header_value(data: &[u8], header_name: &str) -> Option { + let data_str = std::str::from_utf8(data).ok()?; + let header_prefix = format!("{header_name}: ").to_lowercase(); + + for line in data_str.lines() { + let line_lower = line.to_lowercase(); + if line_lower.starts_with(&header_prefix) { + return Some(line[header_prefix.len()..].trim().to_string()); + } + } + None +} + +// Extract all headers from request data +pub fn extract_all_headers(data: &[u8]) -> HashMap { + let mut headers = HashMap::new(); + + if let Ok(data_str) = std::str::from_utf8(data) { + let mut lines = data_str.lines(); + + // Skip the request line + let _ = lines.next(); + + // Parse headers until empty line + for line in lines { + if line.is_empty() { + break; + } + + if let Some(colon_pos) = line.find(':') { + let key = line[..colon_pos].trim().to_lowercase(); + let value = line[colon_pos + 1..].trim().to_string(); + headers.insert(key, value); + } + } + } + + headers +} + +// Determine response type based on request path +pub fn choose_response_type(path: &str) -> &'static str { + if path.contains("phpunit") || path.contains("eval") { + "php_exploit" + } else if path.contains("wp-") { + "wordpress" + } else if path.contains("api") { + "api" + } else { 
+ "generic" + } +} + +// Get current timestamp in seconds +pub fn get_timestamp() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +// Create a unique session ID for tracking a connection +pub fn generate_session_id(ip: &str, user_agent: &str) -> String { + let timestamp = get_timestamp(); + let random = rand::random::(); + + // XXX: Is this fast enough for our case? I don't think hashing is a huge + // bottleneck, but it's worth revisiting in the future to see if there is + // an objectively faster algorithm that we can try. + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + std::hash::Hash::hash(&format!("{ip}_{user_agent}_{timestamp}"), &mut hasher); + let hash = hasher.finish(); + + format!("SID_{hash:x}_{random:x}") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_find_header_end() { + let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: test\r\n\r\nBody content"; + assert_eq!(find_header_end(data), Some(55)); + + let incomplete = b"GET / HTTP/1.1\r\nHost: example.com\r\n"; + assert_eq!(find_header_end(incomplete), None); + } + + #[test] + fn test_extract_path_from_request() { + let data = b"GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n"; + assert_eq!(extract_path_from_request(data), Some("/index.html")); + + let bad_data = b"INVALID DATA"; + assert_eq!(extract_path_from_request(bad_data), None); + } + + #[test] + fn test_extract_header_value() { + let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\n\r\n"; + assert_eq!( + extract_header_value(data, "user-agent"), + Some("TestBot/1.0".to_string()) + ); + assert_eq!( + extract_header_value(data, "Host"), + Some("example.com".to_string()) + ); + assert_eq!(extract_header_value(data, "nonexistent"), None); + } + + #[test] + fn test_extract_all_headers() { + let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\nAccept: */*\r\n\r\n"; 
+ let headers = extract_all_headers(data); + + assert_eq!(headers.len(), 3); + assert_eq!(headers.get("host").unwrap(), "example.com"); + assert_eq!(headers.get("user-agent").unwrap(), "TestBot/1.0"); + assert_eq!(headers.get("accept").unwrap(), "*/*"); + } + + #[test] + fn test_choose_response_type() { + assert_eq!( + choose_response_type("/vendor/phpunit/whatever"), + "php_exploit" + ); + assert_eq!( + choose_response_type("/path/to/eval-stdin.php"), + "php_exploit" + ); + assert_eq!(choose_response_type("/wp-admin/login.php"), "wordpress"); + assert_eq!(choose_response_type("/wp-login.php"), "wordpress"); + assert_eq!(choose_response_type("/api/v1/users"), "api"); + assert_eq!(choose_response_type("/index.html"), "generic"); + } +}