diff --git a/src/lua/mod.rs b/src/lua/mod.rs index 938f5ec..ea0c712 100644 --- a/src/lua/mod.rs +++ b/src/lua/mod.rs @@ -1,62 +1,658 @@ -use rlua::{Function, Lua}; +use rlua::{Function, Lua, Table, Value}; +use std::collections::HashMap; +use std::collections::hash_map::DefaultHasher; use std::fs; +use std::hash::{Hash, Hasher}; use std::path::Path; +use std::sync::{Arc, Mutex, RwLock}; +use std::time::{SystemTime, UNIX_EPOCH}; -pub struct ScriptManager { - script_content: String, - scripts_loaded: bool, +// Event types for the Lua scripting system +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum EventType { + Connection, // when a new connection is established + Request, // when a request is received + ResponseGen, // when generating a response + ResponseChunk, // before sending each response chunk + Disconnection, // when a connection is closed + BlockIP, // when an IP is being considered for blocking + Startup, // when the application starts + Shutdown, // when the application is shutting down + Periodic, // called periodically (e.g., every minute) } -impl ScriptManager { - pub fn new(scripts_dir: &str) -> Self { - let mut script_content = String::new(); - let mut scripts_loaded = false; +impl EventType { + /// Convert event type to string representation for Lua + const fn as_str(&self) -> &'static str { + match self { + Self::Connection => "connection", + Self::Request => "request", + Self::ResponseGen => "response_gen", + Self::ResponseChunk => "response_chunk", + Self::Disconnection => "disconnection", + Self::BlockIP => "block_ip", + Self::Startup => "startup", + Self::Shutdown => "shutdown", + Self::Periodic => "periodic", + } + } - // Try to load scripts from directory + /// Convert from string to `EventType` + fn from_str(s: &str) -> Option { + match s { + "connection" => Some(Self::Connection), + "request" => Some(Self::Request), + "response_gen" => Some(Self::ResponseGen), + "response_chunk" => Some(Self::ResponseChunk), + "disconnection" => Some(Self::Disconnection), + "block_ip" => Some(Self::BlockIP), + "startup" => Some(Self::Startup), + "shutdown" => Some(Self::Shutdown), + "periodic" => Some(Self::Periodic), + _ => None, + } + } +} + +// Loaded Lua script with its metadata +struct ScriptInfo { + name: String, + enabled: bool, +} + +// Script state and manage the Lua environment +pub struct ScriptManager { + lua: Mutex, + scripts: Vec, + hooks: HashMap>, + state: Arc>>, + counters: Arc>>, +} + +// Context passed to Lua event handlers +pub struct EventContext { + pub event_type: EventType, + pub ip: Option, + pub path: Option, + pub user_agent: Option, + pub request_headers: Option>, + pub content: Option, + pub timestamp: u64, + pub session_id: Option, +} + +// Make ScriptManager explicitly Send + Sync since we're using Mutex +unsafe impl Send for ScriptManager {} +unsafe impl Sync for ScriptManager {} + +impl ScriptManager { + /// Create a new script manager and load scripts from the given directory + pub fn new(scripts_dir: &str) -> Self { + let mut manager = Self { + lua: Mutex::new(Lua::new()), + scripts: Vec::new(), + hooks: HashMap::new(), + state: Arc::new(RwLock::new(HashMap::new())), + counters: Arc::new(RwLock::new(HashMap::new())), + }; + + // Initialize Lua environment with our API + manager.init_lua_env(); + + // Load scripts from directory + manager.load_scripts_from_dir(scripts_dir); + + // If no scripts were loaded, use default script + if manager.scripts.is_empty() { + log::info!("No Lua scripts found, loading default scripts"); + manager.load_script( + "default", + include_str!("../../resources/default_script.lua"), + ); + } + + // Trigger startup event + manager.trigger_event(&EventContext { + event_type: EventType::Startup, + ip: None, + path: None, + user_agent: None, + request_headers: None, + content: None, + timestamp: get_timestamp(), + session_id: None, + }); + + manager + } + + // Initialize the Lua environment + fn init_lua_env(&self) { + let state_clone = self.state.clone(); + let counters_clone = self.counters.clone(); + + if let Ok(lua) = self.lua.lock() { + // Create eris global table for our API + let eris_table = lua.create_table().unwrap(); + + self.register_utility_functions(&lua, &eris_table, state_clone, counters_clone); + self.register_event_functions(&lua, &eris_table); + self.register_logging_functions(&lua, &eris_table); + + // Set the eris global table + lua.globals().set("eris", eris_table).unwrap(); + } + } + + /// Register utility functions for scripts to use + fn register_utility_functions( + &self, + lua: &Lua, + eris_table: &Table, + state: Arc>>, + counters: Arc>>, + ) { + // Store a key-value pair in persistent state + let state_for_set = state.clone(); + let set_state = lua + .create_function(move |_, (key, value): (String, String)| { + let mut state_map = state_for_set.write().unwrap(); + state_map.insert(key, value); + Ok(()) + }) + .unwrap(); + eris_table.set("set_state", set_state).unwrap(); + + // Get a value from persistent state + let state_for_get = state; + let get_state = lua + .create_function(move |_, key: String| { + let state_map = state_for_get.read().unwrap(); + let value = state_map.get(&key).cloned(); + Ok(value) + }) + .unwrap(); + eris_table.set("get_state", get_state).unwrap(); + + // Increment a counter + let counters_for_inc = counters.clone(); + let inc_counter = lua + .create_function(move |_, (key, amount): (String, Option)| { + let mut counters_map = counters_for_inc.write().unwrap(); + let counter = counters_map.entry(key).or_insert(0); + *counter += amount.unwrap_or(1); + Ok(*counter) + }) + .unwrap(); + eris_table.set("inc_counter", inc_counter).unwrap(); + + // Get a counter value + let counters_for_get = counters; + let get_counter = lua + .create_function(move |_, key: String| { + let counters_map = counters_for_get.read().unwrap(); + let value = counters_map.get(&key).copied().unwrap_or(0); + Ok(value) + }) + .unwrap(); + eris_table.set("get_counter", get_counter).unwrap(); + + // Generate a random token/string + let gen_token = lua + .create_function(move |_, prefix: Option| { + let now = get_timestamp(); + let random = rand::random::(); + let token = format!("{}{:x}{:x}", prefix.unwrap_or_default(), now, random); + Ok(token) + }) + .unwrap(); + eris_table.set("gen_token", gen_token).unwrap(); + + // Get current timestamp + let timestamp = lua + .create_function(move |_, ()| Ok(get_timestamp())) + .unwrap(); + eris_table.set("timestamp", timestamp).unwrap(); + } + + // Register event handling functions + fn register_event_functions(&self, lua: &Lua, eris_table: &Table) { + // Create a table to store event handlers + let handlers_table = lua.create_table().unwrap(); + eris_table.set("handlers", handlers_table).unwrap(); + + // Function for scripts to register event handlers + let on_fn = lua + .create_function(move |lua, (event_name, handler): (String, Function)| { + let globals = lua.globals(); + let eris: Table = globals.get("eris").unwrap(); + let handlers: Table = eris.get("handlers").unwrap(); + + // Get or create a table for this event type + let event_handlers: Table = if let Ok(table) = handlers.get(&*event_name) { + table + } else { + let new_table = lua.create_table().unwrap(); + handlers.set(&*event_name, new_table.clone()).unwrap(); + new_table + }; + + // Add the handler to the table + let next_index = event_handlers.len().unwrap() + 1; + event_handlers.set(next_index, handler).unwrap(); + + Ok(()) + }) + .unwrap(); + eris_table.set("on", on_fn).unwrap(); + } + + // Register logging functions + fn register_logging_functions(&self, lua: &Lua, eris_table: &Table) { + // Debug logging + let debug = lua + .create_function(|_, message: String| { + log::debug!("[Lua] {message}"); + Ok(()) + }) + .unwrap(); + eris_table.set("debug", debug).unwrap(); + + // Info logging + let info = lua + .create_function(|_, message: String| { + log::info!("[Lua] {message}"); + Ok(()) + }) + .unwrap(); + eris_table.set("info", info).unwrap(); + + // Warning logging + let warn = lua + .create_function(|_, message: String| { + log::warn!("[Lua] {message}"); + Ok(()) + }) + .unwrap(); + eris_table.set("warn", warn).unwrap(); + + // Error logging + let error = lua + .create_function(|_, message: String| { + log::error!("[Lua] {message}"); + Ok(()) + }) + .unwrap(); + eris_table.set("error", error).unwrap(); + } + + // Load all scripts from a directory + fn load_scripts_from_dir(&mut self, scripts_dir: &str) { let script_dir = Path::new(scripts_dir); - if script_dir.exists() { - log::debug!("Loading Lua scripts from directory: {scripts_dir}"); - if let Ok(entries) = fs::read_dir(script_dir) { - for entry in entries.filter_map(Result::ok) { - let path = entry.path(); - if path.extension().and_then(|ext| ext.to_str()) == Some("lua") { - if let Ok(content) = fs::read_to_string(&path) { - log::debug!("Loaded Lua script: {}", path.display()); - script_content.push_str(&content); - script_content.push('\n'); - scripts_loaded = true; - } else { - log::warn!("Failed to read Lua script: {}", path.display()); + if !script_dir.exists() { + log::warn!("Lua scripts directory does not exist: {scripts_dir}"); + return; + } + + log::debug!("Loading Lua scripts from directory: {scripts_dir}"); + if let Ok(entries) = fs::read_dir(script_dir) { + // Sort entries by filename to ensure consistent loading order + let mut sorted_entries: Vec<_> = entries.filter_map(Result::ok).collect(); + sorted_entries.sort_by_key(std::fs::DirEntry::path); + + for entry in sorted_entries { + let path = entry.path(); + if path.extension().and_then(|ext| ext.to_str()) == Some("lua") { + if let Ok(content) = fs::read_to_string(&path) { + let script_name = path + .file_stem() + .and_then(|n| n.to_str()) + .unwrap_or("unknown") + .to_string(); + + log::debug!("Loading Lua script: {} ({})", script_name, path.display()); + self.load_script(&script_name, &content); + } else { + log::warn!("Failed to read Lua script: {}", path.display()); + } + } + } + } + } + + // Load a single script and register its event handlers + fn load_script(&mut self, name: &str, content: &str) { + // Store script info + self.scripts.push(ScriptInfo { + name: name.to_string(), + enabled: true, + }); + + // Execute the script to register its event handlers + if let Ok(lua) = self.lua.lock() { + if let Err(e) = lua.load(content).set_name(name).exec() { + log::warn!("Error loading Lua script '{name}': {e}"); + return; + } + + // Collect registered event handlers + let globals = lua.globals(); + let eris: Table = match globals.get("eris") { + Ok(table) => table, + Err(_) => return, + }; + + let handlers: Table = match eris.get("handlers") { + Ok(table) => table, + Err(_) => return, + }; + + // Store the event handlers in our hooks map + let mut tmp: rlua::TablePairs<'_, String, Table<'_>> = + handlers.pairs::(); + 'l: loop { + if let Some(event_pair) = tmp.next() { + if let Ok((event_name, _)) = event_pair { + if let Some(event_type) = EventType::from_str(&event_name) { + self.hooks + .entry(event_type) + .or_default() + .push(name.to_string()); + } + } + } else { + break 'l; + } + } + + log::info!("Loaded Lua script '{name}' successfully"); + } + } + + /// Check if a script is enabled + fn is_script_enabled(&self, name: &str) -> bool { + self.scripts + .iter() + .find(|s| s.name == name) + .is_some_and(|s| s.enabled) + } + + /// Trigger an event, calling all registered handlers + pub fn trigger_event(&self, ctx: &EventContext) -> Option { + // Check if we have any handlers for this event + if !self.hooks.contains_key(&ctx.event_type) { + return ctx.content.clone(); + } + + // Build the event data table to pass to Lua handlers + let mut result = ctx.content.clone(); + + if let Ok(lua) = self.lua.lock() { + // Create the event context table + let event_ctx = lua.create_table().unwrap(); + + // Add all the context fields + event_ctx.set("event", ctx.event_type.as_str()).unwrap(); + if let Some(ip) = &ctx.ip { + event_ctx.set("ip", ip.clone()).unwrap(); + } + if let Some(path) = &ctx.path { + event_ctx.set("path", path.clone()).unwrap(); + } + if let Some(ua) = &ctx.user_agent { + event_ctx.set("user_agent", ua.clone()).unwrap(); + } + event_ctx.set("timestamp", ctx.timestamp).unwrap(); + if let Some(sid) = &ctx.session_id { + event_ctx.set("session_id", sid.clone()).unwrap(); + } + + // Add request headers if available + if let Some(headers) = &ctx.request_headers { + let headers_table = lua.create_table().unwrap(); + for (key, value) in headers { + headers_table + .set(key.to_string(), value.to_string()) + .unwrap(); + } + event_ctx.set("headers", headers_table).unwrap(); + } + + // Add content if available + if let Some(content) = &ctx.content { + event_ctx.set("content", content.clone()).unwrap(); + } + + // Call all registered handlers for this event + if let Some(handler_scripts) = self.hooks.get(&ctx.event_type) { + for script_name in handler_scripts { + // Skip disabled scripts + if !self.is_script_enabled(script_name) { + continue; + } + + // Get the globals and handlers table + let globals = lua.globals(); + let eris: Table = match globals.get("eris") { + Ok(table) => table, + Err(_) => continue, + }; + + let handlers: Table = match eris.get("handlers") { + Ok(table) => table, + Err(_) => continue, + }; + + // Get handlers for this event + let event_handlers: Table = match handlers.get(ctx.event_type.as_str()) { + Ok(table) => table, + Err(_) => continue, + }; + + // Call each handler + for pair in event_handlers.pairs::() { + if let Ok((_, handler)) = pair { + let handler_result: rlua::Result> = + handler.call((event_ctx.clone(),)); + if let Ok(Some(new_content)) = handler_result { + // For response events, allow handlers to modify the content + if matches!( + ctx.event_type, + EventType::ResponseGen | EventType::ResponseChunk + ) { + result = Some(new_content); + } + } } } } } - } else { - log::warn!("Lua scripts directory does not exist: {scripts_dir}"); } - // If no scripts were loaded, use a default script - if !scripts_loaded { - log::info!("No Lua scripts found, loading default scripts"); - script_content = include_str!("../../resources/default_script.lua").to_string(); - scripts_loaded = true; - } - - Self { - script_content, - scripts_loaded, - } + result } - // For testing only - #[cfg(test)] - pub fn with_content(content: &str) -> Self { - Self { - script_content: content.to_string(), - scripts_loaded: true, - } + /// Generate a deceptive response, calling all response_gen handlers + pub fn generate_response( + &self, + path: &str, + user_agent: &str, + ip: &str, + headers: &HashMap, + markov_text: &str, + ) -> String { + // Create event context + let ctx = EventContext { + event_type: EventType::ResponseGen, + ip: Some(ip.to_string()), + path: Some(path.to_string()), + user_agent: Some(user_agent.to_string()), + request_headers: Some(headers.clone()), + content: Some(markov_text.to_string()), + timestamp: get_timestamp(), + session_id: Some(generate_session_id(ip, user_agent)), + }; + + /// Trigger the event and get the modified content + self.trigger_event(&ctx).unwrap_or_else(|| { + // Fallback to maintain backward compatibility + self.expand_response( + markov_text, + "generic", + path, + &generate_session_id(ip, user_agent), + ) + }) } + /// Process a chunk before sending it to client + pub fn process_chunk(&self, chunk: &str, ip: &str, session_id: &str) -> String { + let ctx = EventContext { + event_type: EventType::ResponseChunk, + ip: Some(ip.to_string()), + path: None, + user_agent: None, + request_headers: None, + content: Some(chunk.to_string()), + timestamp: get_timestamp(), + session_id: Some(session_id.to_string()), + }; + + self.trigger_event(&ctx) + .unwrap_or_else(|| chunk.to_string()) + } + + /// Called when a connection is established + pub fn on_connection(&self, ip: &str) -> bool { + let ctx = EventContext { + event_type: EventType::Connection, + ip: Some(ip.to_string()), + path: None, + user_agent: None, + request_headers: None, + content: None, + timestamp: get_timestamp(), + session_id: None, + }; + + // If any handler returns false, reject the connection + let mut should_accept = true; + + if let Ok(lua) = self.lua.lock() { + if let Some(handler_scripts) = self.hooks.get(&EventType::Connection) { + for script_name in handler_scripts { + // Skip disabled scripts + if !self.is_script_enabled(script_name) { + continue; + } + + let globals = lua.globals(); + let eris: Table = match globals.get("eris") { + Ok(table) => table, + Err(_) => continue, + }; + + let handlers: Table = match eris.get("handlers") { + Ok(table) => table, + Err(_) => continue, + }; + + let event_handlers: Table = match handlers.get("connection") { + Ok(table) => table, + Err(_) => continue, + }; + + for pair in event_handlers.pairs::() { + if let Ok((_, handler)) = pair { + let event_ctx = create_event_context(&lua, &ctx); + if let Ok(result) = handler.call::<_, Value>((event_ctx,)) { + if result == Value::Boolean(false) { + should_accept = false; + break; + } + } + } + } + + if !should_accept { + break; + } + } + } + } + + should_accept + } + + /// Called when deciding whether to block an IP + pub fn should_block_ip(&self, ip: &str, hit_count: u32) -> bool { + let ctx = EventContext { + event_type: EventType::BlockIP, + ip: Some(ip.to_string()), + path: None, + user_agent: None, + request_headers: None, + content: None, + timestamp: get_timestamp(), + session_id: None, + }; + + // We should default to not modifying the blocking decision + let mut should_block = None; + + if let Ok(lua) = self.lua.lock() { + if let Some(handler_scripts) = self.hooks.get(&EventType::BlockIP) { + for script_name in handler_scripts { + // Skip disabled scripts + if !self.is_script_enabled(script_name) { + continue; + } + + let globals = lua.globals(); + let eris: Table = match globals.get("eris") { + Ok(table) => table, + Err(_) => continue, + }; + + let handlers: Table = match eris.get("handlers") { + Ok(table) => table, + Err(_) => continue, + }; + + let event_handlers: Table = match handlers.get("block_ip") { + Ok(table) => table, + Err(_) => continue, + }; + + for pair in event_handlers.pairs::() { + if let Ok((_, handler)) = pair { + let event_ctx = create_event_context(&lua, &ctx); + // Add hit count for the block_ip event + event_ctx.set("hit_count", hit_count).unwrap(); + + if let Ok(result) = handler.call::<_, Value>((event_ctx,)) { + if let Value::Boolean(block) = result { + should_block = Some(block); + break; + } + } + } + } + + if should_block.is_some() { + break; + } + } + } + } + + // Return the script's decision, or default to the system behavior + should_block.unwrap_or(hit_count >= 3) + } + + // Maintains backward compatibility with the old API + // XXX: I never liked expand_response, should probably be removeedf + // in the future. pub fn expand_response( &self, text: &str, @@ -64,184 +660,242 @@ impl ScriptManager { path: &str, token: &str, ) -> String { - if !self.scripts_loaded { - return format!("{text}\n"); - } - - let lua = Lua::new(); - if let Err(e) = lua.load(&self.script_content).exec() { - log::warn!("Error loading Lua script: {e}"); - return format!("{text}\n"); - } - - let globals = lua.globals(); - match globals.get::<_, Function>("enhance_response") { - Ok(enhance_func) => { - match enhance_func.call::<_, String>((text, response_type, path, token)) { - Ok(result) => result, - Err(e) => { - log::warn!("Error calling Lua function enhance_response: {e}"); - format!("{text}\n") + if let Ok(lua) = self.lua.lock() { + let globals = lua.globals(); + match globals.get::<_, Function>("enhance_response") { + Ok(enhance_func) => { + match enhance_func.call::<_, String>((text, response_type, path, token)) { + Ok(result) => result, + Err(e) => { + log::warn!("Error calling Lua function enhance_response: {e}"); + format!("{text}\n") + } } } + Err(_) => format!("{text}\n"), } - Err(e) => { - log::warn!("Lua enhance_response function not found: {e}"); - format!("{text}\n") - } + } else { + format!("{text}\n") } } } +/// Get current timestamp in seconds +fn get_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +/// Create a unique session ID for tracking a connection +fn generate_session_id(ip: &str, user_agent: &str) -> String { + let timestamp = get_timestamp(); + let random = rand::random::(); + + // Use std::hash instead of xxhash_rust + let mut hasher = DefaultHasher::new(); + format!("{ip}_{user_agent}_{timestamp}").hash(&mut hasher); + let hash = hasher.finish(); + + format!("SID_{hash:x}_{random:x}") +} + +// Create an event context table in Lua +fn create_event_context<'a>(lua: &'a Lua, event_ctx: &EventContext) -> Table<'a> { + let table = lua.create_table().unwrap(); + + table.set("event", event_ctx.event_type.as_str()).unwrap(); + if let Some(ip) = &event_ctx.ip { + table.set("ip", ip.clone()).unwrap(); + } + if let Some(path) = &event_ctx.path { + table.set("path", path.clone()).unwrap(); + } + if let Some(ua) = &event_ctx.user_agent { + table.set("user_agent", ua.clone()).unwrap(); + } + table.set("timestamp", event_ctx.timestamp).unwrap(); + if let Some(sid) = &event_ctx.session_id { + table.set("session_id", sid.clone()).unwrap(); + } + if let Some(content) = &event_ctx.content { + table.set("content", content.clone()).unwrap(); + } + + table +} + #[cfg(test)] mod tests { use super::*; use std::fs; - use std::io::Write; + use tempfile::TempDir; #[test] - fn test_script_manager_default_script() { - let script_manager = ScriptManager::new("/nonexistent_directory"); - assert!(script_manager.scripts_loaded); - assert!( - script_manager - .script_content - .contains("generate_honeytoken") - ); - assert!(script_manager.script_content.contains("enhance_response")); - } - - #[test] - fn test_script_manager_custom_scripts() { + fn test_event_registration() { let temp_dir = TempDir::new().unwrap(); - let script_dir = temp_dir.path().to_str().unwrap(); + let script_path = temp_dir.path().join("test_events.lua"); + let script_content = r#" + -- Example script with event handlers + eris.info("Registering event handlers") - // Create a test script - let script_path = temp_dir.path().join("test_script.lua"); - let mut file = fs::File::create(&script_path).unwrap(); - writeln!( - file, - "function enhance_response(text, response_type, path, token)" - ) - .unwrap(); - writeln!(file, " return text .. ' - Enhanced with token: ' .. token").unwrap(); - writeln!(file, "end").unwrap(); + -- Connection event handler + eris.on("connection", function(ctx) + eris.debug("Connection from " .. ctx.ip) + return true -- accept the connection + end) - let script_manager = ScriptManager::new(script_dir); - assert!(script_manager.scripts_loaded); - assert!( - !script_manager - .script_content - .contains("generate_honeytoken") - ); // Default script not loaded - assert!( - script_manager - .script_content - .contains("Enhanced with token") - ); - } - - #[test] - fn test_expand_response_successful() { - let lua_code = r#" - function enhance_response(text, response_type, path, token) - return text .. " | Type: " .. response_type .. " | Path: " .. path .. " | Token: " .. token - end + -- Response generation handler + eris.on("response_gen", function(ctx) + eris.debug("Generating response for " .. ctx.path) + return ctx.content .. "" + end) "#; - let script_manager = ScriptManager::with_content(lua_code); - let result = - script_manager.expand_response("Test content", "test_type", "/test/path", "12345"); + fs::write(&script_path, script_content).unwrap(); - assert_eq!( - result, - "Test content | Type: test_type | Path: /test/path | Token: 12345" - ); + let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); + + // Verify hooks were registered + assert!(script_manager.hooks.contains_key(&EventType::Connection)); + assert!(script_manager.hooks.contains_key(&EventType::ResponseGen)); + assert!(!script_manager.hooks.contains_key(&EventType::BlockIP)); } #[test] - fn test_expand_response_syntax_error() { - let lua_code = r#" - function enhance_response(text, response_type, path, token) - This is an invalid Lua syntax - return "Something" - end - "#; - - let script_manager = ScriptManager::with_content(lua_code); - let result = - script_manager.expand_response("Test content", "test_type", "/test/path", "12345"); - - assert!(result.contains("Test content")); - assert!(result.contains("")); - } - - #[test] - fn test_expand_response_runtime_error() { - let lua_code = r" - function enhance_response(text, response_type, path, token) - -- This will cause a runtime error - return nonexistent_variable - end - "; - - let script_manager = ScriptManager::with_content(lua_code); - let result = - script_manager.expand_response("Test content", "test_type", "/test/path", "12345"); - - assert!(result.contains("Test content")); - assert!(result.contains("")); - } - - #[test] - fn test_expand_response_missing_function() { - let lua_code = r#" - -- This script doesn't define enhance_response function - function some_other_function() - return "Hello, world!" - end - "#; - - let script_manager = ScriptManager::with_content(lua_code); - let result = - script_manager.expand_response("Test content", "test_type", "/test/path", "12345"); - - assert!(result.contains("Test content")); - assert!(result.contains("")); - } - - #[test] - fn test_expand_response_multiple_scripts() { + fn test_generate_response() { let temp_dir = TempDir::new().unwrap(); - let script_dir = temp_dir.path().to_str().unwrap(); + let script_path = temp_dir.path().join("response_test.lua"); + let script_content = r#" + eris.on("response_gen", function(ctx) + return ctx.content .. " - Modified by " .. ctx.user_agent + end) + "#; - // Create first script with helper function - let script1_path = temp_dir.path().join("01_helpers.lua"); - let mut file1 = fs::File::create(script1_path).unwrap(); - writeln!(file1, "function create_prefix(token)").unwrap(); - writeln!(file1, " return 'PREFIX_' .. token").unwrap(); - writeln!(file1, "end").unwrap(); + fs::write(&script_path, script_content).unwrap(); - // Create second script that uses the helper - let script2_path = temp_dir.path().join("02_responder.lua"); - let mut file2 = fs::File::create(script2_path).unwrap(); - writeln!( - file2, - "function enhance_response(text, response_type, path, token)" - ) - .unwrap(); - writeln!( - file2, - " return text .. ' [' .. create_prefix(token) .. ']'" - ) - .unwrap(); - writeln!(file2, "end").unwrap(); + let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); - let script_manager = ScriptManager::new(script_dir); - let result = - script_manager.expand_response("Test content", "test_type", "/test/path", "12345"); + let headers = HashMap::new(); + let result = script_manager.generate_response( + "/test/path", + "TestBot", + "127.0.0.1", + &headers, + "Original content", + ); - assert_eq!(result, "Test content [PREFIX_12345]"); + assert!(result.contains("Original content")); + assert!(result.contains("Modified by TestBot")); + } + + #[test] + fn test_process_chunk() { + let temp_dir = TempDir::new().unwrap(); + let script_path = temp_dir.path().join("chunk_test.lua"); + let script_content = r#" + eris.on("response_chunk", function(ctx) + return ctx.content:gsub("secret", "REDACTED") + end) + "#; + + fs::write(&script_path, script_content).unwrap(); + + let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); + + let result = script_manager.process_chunk( + "This contains a secret password", + "127.0.0.1", + "test_session", + ); + + assert!(result.contains("This contains a REDACTED password")); + } + + #[test] + fn test_should_block_ip() { + let temp_dir = TempDir::new().unwrap(); + let script_path = temp_dir.path().join("block_test.lua"); + let script_content = r#" + eris.on("block_ip", function(ctx) + -- Block any IP with "192.168.1" prefix regardless of hit count + if string.match(ctx.ip, "^192%.168%.1%.") then + return true + end + + -- Don't block IPs with "10.0" prefix even if they hit the threshold + if string.match(ctx.ip, "^10%.0%.") then + return false + end + + -- Default behavior for other IPs (nil = use system default) + return nil + end) + "#; + + fs::write(&script_path, script_content).unwrap(); + + let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); + + // Should be blocked based on IP pattern + assert!(script_manager.should_block_ip("192.168.1.50", 1)); + + // Should not be blocked despite high hit count + assert!(!script_manager.should_block_ip("10.0.0.5", 10)); + + // Should use default behavior (block if >= 3 hits) + assert!(!script_manager.should_block_ip("172.16.0.1", 2)); + assert!(script_manager.should_block_ip("172.16.0.1", 3)); + } + + #[test] + fn test_state_and_counters() { + let temp_dir = TempDir::new().unwrap(); + let script_path = temp_dir.path().join("state_test.lua"); + let script_content = r#" + eris.on("startup", function(ctx) + eris.set_state("test_key", "test_value") + eris.inc_counter("visits", 0) + end) + + eris.on("connection", function(ctx) + local count = eris.inc_counter("visits") + eris.debug("Visit count: " .. count) + + -- Store last visitor + eris.set_state("last_visitor", ctx.ip) + return true + end) + + eris.on("response_gen", function(ctx) + local last_visitor = eris.get_state("last_visitor") or "unknown" + local visits = eris.get_counter("visits") + return ctx.content .. "" + end) + "#; + + fs::write(&script_path, script_content).unwrap(); + + let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap()); + + // Simulate connections + script_manager.on_connection("192.168.1.100"); + script_manager.on_connection("10.0.0.50"); + + // Check response includes state + let headers = HashMap::new(); + let result = script_manager.generate_response( + "/test", + "test-agent", + "8.8.8.8", + &headers, + "Response", + ); + + assert!(result.contains("Last visitor: 10.0.0.50")); + assert!(result.contains("Total visits: 2")); } } diff --git a/src/main.rs b/src/main.rs index 154b374..197c0a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::env; use std::fs; +use std::hash::Hasher; use std::io::Write; use std::net::IpAddr; use std::path::{Path, PathBuf}; @@ -20,7 +21,7 @@ mod lua; mod markov; mod metrics; -use lua::ScriptManager; +use lua::{EventContext, EventType, ScriptManager}; use markov::MarkovGenerator; use metrics::{ ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, UA_HITS, metrics_handler, @@ -413,6 +414,67 @@ fn extract_header_value(data: &[u8], header_name: &str) -> Option { None } +// Extract all headers from request data +fn extract_all_headers(data: &[u8]) -> HashMap { + let mut headers = HashMap::new(); + + if let Ok(data_str) = std::str::from_utf8(data) { + let mut lines = data_str.lines(); + + // Skip the request line + let _ = lines.next(); + + // Parse headers until empty line + for line in lines { + if line.is_empty() { + break; + } + + if let Some(colon_pos) = line.find(':') { + let key = line[..colon_pos].trim().to_lowercase(); + let value = line[colon_pos + 1..].trim().to_string(); + headers.insert(key, value); + } + } + } + + headers +} + +// Determine response type based on request path +fn choose_response_type(path: &str) -> &'static str { + if path.contains("phpunit") || path.contains("eval") { + "php_exploit" + } else if path.contains("wp-") { + "wordpress" + } else if path.contains("api") { + "api" + } else { + "generic" + } +} + +// Helper function to get current timestamp in seconds +fn get_timestamp() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +// Create a unique session ID for tracking a connection +fn generate_session_id(ip: &str, user_agent: &str) -> String { + let timestamp = get_timestamp(); + let random = rand::random::(); + + // Use std::hash instead of xxhash_rust + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + std::hash::Hash::hash(&format!("{ip}_{user_agent}_{timestamp}"), &mut hasher); + let hash = hasher.finish(); + + format!("SID_{hash:x}_{random:x}") +} + // Main connection handler. // Decides whether to tarpit or proxy async fn handle_connection( @@ -438,6 +500,13 @@ async fn handle_connection( return; } + // Check if Lua scripts allow this connection + if !script_manager.on_connection(&peer_addr.to_string()) { + log::debug!("Connection rejected by Lua script: {peer_addr}"); + let _ = stream.shutdown().await; + return; + } + // Pre-check for whitelisted IPs to bypass heavy processing let mut whitelisted = false; for network_str in &config.whitelist_networks { @@ -512,14 +581,30 @@ async fn handle_connection( return; }; + // Extract request headers for Lua scripts + let headers = extract_all_headers(&request_data); + + // Extract user agent for logging and decision making + let user_agent = + extract_header_value(&request_data, "user-agent").unwrap_or_else(|| "unknown".to_string()); + + // Trigger request event for Lua scripts + let request_ctx = EventContext { + event_type: EventType::Request, + ip: Some(peer_addr.to_string()), + path: Some(path.to_string()), + user_agent: Some(user_agent.clone()), + request_headers: Some(headers.clone()), + content: None, + timestamp: get_timestamp(), + session_id: Some(generate_session_id(&peer_addr.to_string(), &user_agent)), + }; + script_manager.trigger_event(&request_ctx); + // Check if this request matches our tarpit patterns let should_tarpit = should_tarpit(path, &peer_addr, &config).await; if should_tarpit { - // Extract minimal info needed for tarpit - let user_agent = extract_header_value(&request_data, "user-agent") - .unwrap_or_else(|| "unknown".to_string()); - log::info!("Tarpit triggered: {path} from {peer_addr} (UA: {user_agent})"); // Update metrics @@ -536,8 +621,11 @@ async fn handle_connection( *state.hits.entry(peer_addr).or_insert(0) += 1; let hit_count = state.hits[&peer_addr]; + // Use Lua to decide whether to block this IP + let should_block = script_manager.should_block_ip(&peer_addr.to_string(), hit_count); + // Block IPs that hit tarpits too many times - if hit_count >= config.block_threshold && !state.blocked.contains(&peer_addr) { + if should_block && !state.blocked.contains(&peer_addr) { log::info!("Blocking IP {peer_addr} after {hit_count} hits"); state.blocked.insert(peer_addr); BLOCKED_IPS.set(state.blocked.len() as f64); @@ -579,20 +667,116 @@ async fn handle_connection( } // Generate a deceptive response using Markov chains and Lua - let response = - generate_deceptive_response(path, &user_agent, &markov_generator, &script_manager) - .await; + let response = generate_deceptive_response( + path, + &user_agent, + &peer_addr, + &headers, + &markov_generator, + &script_manager, + ) + .await; + + // Generate a session ID for tracking this tarpit session + let session_id = generate_session_id(&peer_addr.to_string(), &user_agent); // Send the response with the tarpit delay strategy - tarpit_connection( - stream, - response, - peer_addr, - state.clone(), - config.min_delay, - config.max_delay, - config.max_tarpit_time, - ) + { + let mut stream = stream; + let peer_addr = peer_addr; + let state = state.clone(); + let min_delay = config.min_delay; + let max_delay = config.max_delay; + let max_tarpit_time = config.max_tarpit_time; + let script_manager = script_manager.clone(); + async move { + let start_time = Instant::now(); + let mut chars = response.chars().collect::>(); + for i in (0..chars.len()).rev() { + if i > 0 && rand::random::() < 0.1 { + chars.swap(i, i - 1); + } + } + log::debug!( + "Starting tarpit for {} with {} chars, min_delay={}ms, max_delay={}ms", + peer_addr, + chars.len(), + min_delay, + max_delay + ); + let mut position = 0; + let mut chunks_sent = 0; + let mut total_delay = 0; + while position < chars.len() { + // Check if we've exceeded maximum tarpit time + let elapsed_secs = start_time.elapsed().as_secs(); + if elapsed_secs > max_tarpit_time { + log::info!( + "Tarpit maximum time ({max_tarpit_time} sec) reached for {peer_addr}" + ); + break; + } + + // Decide how many chars to send in this chunk (usually 1, sometimes more) + let chunk_size = if rand::random::() < 0.9 { + 1 + } else { + (rand::random::() * 3.0).floor() as usize + 1 + }; + + let end = (position + chunk_size).min(chars.len()); + let chunk: String = chars[position..end].iter().collect(); + + // Process chunk through Lua before sending + let processed_chunk = + script_manager.process_chunk(&chunk, &peer_addr.to_string(), &session_id); + + // Try to write processed chunk + if stream.write_all(processed_chunk.as_bytes()).await.is_err() { + log::debug!("Connection closed by client during tarpit: {peer_addr}"); + break; + } + + if stream.flush().await.is_err() { + log::debug!("Failed to flush stream during tarpit: {peer_addr}"); + break; + } + + position = end; + chunks_sent += 1; + + // Apply random delay between min and max configured values + let delay_ms = + (rand::random::() * (max_delay - min_delay) as f32) as u64 + min_delay; + total_delay += delay_ms; + sleep(Duration::from_millis(delay_ms)).await; + } + log::debug!( + "Tarpit stats for {}: sent {} chunks, {}% of data, total delay {}ms over {}s", + peer_addr, + chunks_sent, + position * 100 / chars.len(), + total_delay, + start_time.elapsed().as_secs() + ); + let disconnection_ctx = EventContext { + event_type: EventType::Disconnection, + ip: Some(peer_addr.to_string()), + path: None, + user_agent: None, + request_headers: None, + content: None, + timestamp: get_timestamp(), + session_id: Some(session_id), + }; + script_manager.trigger_event(&disconnection_ctx); + if let Ok(mut state) = state.try_write() { + state.active_connections.remove(&peer_addr); + ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64); + } + let _ = stream.shutdown().await; + } + } .await; } else { log::debug!("Proxying request: {path} from {peer_addr}"); @@ -713,134 +897,25 @@ async fn should_tarpit(path: &str, ip: &IpAddr, config: &Config) -> bool { async fn generate_deceptive_response( path: &str, user_agent: &str, + peer_addr: &IpAddr, + headers: &HashMap, markov: &MarkovGenerator, script_manager: &ScriptManager, ) -> String { - // Choose response type based on path to seem more realistic - let response_type = if path.contains("phpunit") || path.contains("eval") { - "php_exploit" - } else if path.contains("wp-") { - "wordpress" - } else if path.contains("api") { - "api" - } else { - "generic" - }; - - log::debug!("Generating {response_type} response for path: {path}"); - - // Generate tracking token for this interaction - let tracking_token = format!( - "BOT_{}_{}", - user_agent - .chars() - .filter(|c| c.is_alphanumeric()) - .collect::(), - chrono::Utc::now().timestamp() - ); - // Generate base response using Markov chain text generator + let response_type = choose_response_type(path); let markov_text = markov.generate(response_type, 30); - // Use Lua to enhance with honeytokens and other deceptive content - let response_expanded = - script_manager.expand_response(&markov_text, response_type, path, &tracking_token); - - // Return full HTTP response with appropriate headers - format!( - "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nX-Powered-By: PHP/7.4.3\r\nConnection: keep-alive\r\n\r\n{response_expanded}" + // Use Lua scripts to enhance with honeytokens and other deceptive content + script_manager.generate_response( + path, + user_agent, + &peer_addr.to_string(), + headers, + &markov_text, ) } -// Slowly feed a response to the client with random delays to waste attacker time -async fn tarpit_connection( - mut stream: TcpStream, - response: String, - peer_addr: IpAddr, - state: Arc>, - min_delay: u64, - max_delay: u64, - max_tarpit_time: u64, -) { - let start_time = Instant::now(); - let mut chars = response.chars().collect::>(); - - // Randomize the char order slightly to confuse automated tools - for i in (0..chars.len()).rev() { - if i > 0 && rand::random::() < 0.1 { - chars.swap(i, i - 1); - } - } - - log::debug!( - "Starting tarpit for {} with {} chars, min_delay={}ms, max_delay={}ms", - peer_addr, - chars.len(), - min_delay, - max_delay - ); - - let mut position = 0; - let mut chunks_sent = 0; - let mut total_delay = 0; - - // Send the response character by character with random delays - while position < chars.len() { - // Check if we've exceeded maximum tarpit time - let elapsed_secs = start_time.elapsed().as_secs(); - if elapsed_secs > max_tarpit_time { - log::info!("Tarpit maximum time ({max_tarpit_time} sec) reached for {peer_addr}"); - break; - } - - // Decide how many chars to send in this chunk (usually 1, sometimes more) - let chunk_size = if rand::random::() < 0.9 { - 1 - } else { - (rand::random::() * 3.0).floor() as usize + 1 - }; - - let end = (position + chunk_size).min(chars.len()); - let chunk: String = chars[position..end].iter().collect(); - - // Try to write chunk - if stream.write_all(chunk.as_bytes()).await.is_err() { - log::debug!("Connection closed by client during tarpit: {peer_addr}"); - break; - } - - if stream.flush().await.is_err() { - log::debug!("Failed to flush stream during tarpit: {peer_addr}"); - break; - } - - position = end; - chunks_sent += 1; - - // Apply random delay between min and max configured values - let delay_ms = (rand::random::() * (max_delay - min_delay) as f32) as u64 + min_delay; - total_delay += delay_ms; - sleep(Duration::from_millis(delay_ms)).await; - } - - log::debug!( - "Tarpit stats for {}: sent {} chunks, {}% of data, total delay {}ms over {}s", - peer_addr, - chunks_sent, - position * 100 / chars.len(), - total_delay, - start_time.elapsed().as_secs() - ); - - // Remove from active connections - if let Ok(mut state) = state.try_write() { - state.active_connections.remove(&peer_addr); - ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64); - } - - let _ = stream.shutdown().await; -} - // Set up nftables firewall rules for IP blocking async fn setup_firewall() -> Result<(), String> { log::info!("Setting up firewall rules"); @@ -1059,6 +1134,8 @@ async fn main() -> std::io::Result<()> { // Initialize Lua script manager log::info!("Loading Lua scripts from {}", config.lua_scripts_dir); let script_manager = Arc::new(ScriptManager::new(&config.lua_scripts_dir)); + let script_manager_for_tarpit = script_manager.clone(); + let script_manager_for_periodic = script_manager.clone(); // Clone config for metrics server let metrics_config = config.clone(); @@ -1083,7 +1160,7 @@ async fn main() -> std::io::Result<()> { let state_clone = tarpit_state.clone(); let markov_clone = markov_generator.clone(); - let script_manager_clone = script_manager.clone(); + let script_manager_clone = script_manager_for_tarpit.clone(); let config_clone = config.clone(); tokio::spawn(async move { @@ -1138,6 +1215,27 @@ async fn main() -> std::io::Result<()> { } }; + // Setup periodic task runner for Lua scripts + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(60)); + loop { + interval.tick().await; + + // Trigger periodic event + let ctx = EventContext { + event_type: EventType::Periodic, + ip: None, + path: None, + user_agent: None, + request_headers: None, + content: None, + timestamp: get_timestamp(), + session_id: None, + }; + script_manager_for_periodic.trigger_event(&ctx); + } + }); + // Run both servers concurrently if metrics server is enabled if let Some(metrics_server) = metrics_server { tokio::select! { @@ -1317,37 +1415,6 @@ mod tests { } } - #[tokio::test] - async fn test_generate_deceptive_response() { - // Create a simple markov generator for testing - let markov = MarkovGenerator::new("/nonexistent/path"); - let script_manager = ScriptManager::new("/nonexistent/path"); - - // Test different path types - let resp1 = generate_deceptive_response( - "/vendor/phpunit/exec", - "TestBot/1.0", - &markov, - &script_manager, - ) - .await; - assert!(resp1.contains("HTTP/1.1 200 OK")); - assert!(resp1.contains("X-Powered-By: PHP")); - - let resp2 = - generate_deceptive_response("/wp-admin/", "TestBot/1.0", &markov, &script_manager) - .await; - assert!(resp2.contains("HTTP/1.1 200 OK")); - - let resp3 = - generate_deceptive_response("/api/users", "TestBot/1.0", &markov, &script_manager) - .await; - assert!(resp3.contains("HTTP/1.1 200 OK")); - - // Verify tracking token is included - assert!(resp1.contains("BOT_TestBot")); - } - #[test] fn test_find_header_end() { let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: test\r\n\r\nBody content"; @@ -1379,4 +1446,31 @@ mod tests { ); assert_eq!(extract_header_value(data, "nonexistent"), None); } + + #[test] + fn test_extract_all_headers() { + let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\nAccept: */*\r\n\r\n"; + let headers = extract_all_headers(data); + + assert_eq!(headers.len(), 3); + assert_eq!(headers.get("host").unwrap(), "example.com"); + assert_eq!(headers.get("user-agent").unwrap(), "TestBot/1.0"); + assert_eq!(headers.get("accept").unwrap(), "*/*"); + } + + #[test] + fn test_choose_response_type() { + assert_eq!( + choose_response_type("/vendor/phpunit/whatever"), + "php_exploit" + ); + assert_eq!( + choose_response_type("/path/to/eval-stdin.php"), + "php_exploit" + ); + assert_eq!(choose_response_type("/wp-admin/login.php"), "wordpress"); + assert_eq!(choose_response_type("/wp-login.php"), "wordpress"); + assert_eq!(choose_response_type("/api/v1/users"), "api"); + assert_eq!(choose_response_type("/index.html"), "generic"); + } }