WIP: network: implement basic rate limiting #1

Draft
NotAShelf wants to merge 16 commits from multi-site-proxy into main
15 changed files with 3037 additions and 1247 deletions

31
.github/workflows/tag.yml vendored Normal file
View file

@ -0,0 +1,31 @@
name: Tag latest version

on:
  workflow_dispatch:
  push:
    branches: [main]

# Only one tagging run at a time; concurrent pushes queue behind it.
concurrency: tag

jobs:
  tag-release:
    runs-on: ubuntu-latest
    steps:
      # Nix is needed to run `fq` for reading the version out of Cargo.toml.
      - uses: cachix/install-nix-action@master
        with:
          github_access_token: ${{ secrets.GITHUB_TOKEN }}
      - name: Checkout
        uses: actions/checkout@v4
      - name: Read version
        run: |
          # Export `version=v<crate version>` into the env for the next step.
          # FIX: this previously wrote `_version=...` while the Tag step
          # expanded `$version`, so `git tag` ran with an empty argument.
          echo -n "version=v" >> "$GITHUB_ENV"
          nix run nixpkgs#fq -- -r ".package.version" Cargo.toml >> "$GITHUB_ENV"
          cat "$GITHUB_ENV"
      - name: Tag
        run: |
          set -x
          git tag "$version"
          # `|| :` — pushing an already-existing tag is not a failure.
          git push --tags || :

90
Cargo.lock generated
View file

@ -414,9 +414,7 @@ checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"wasm-bindgen",
"windows-link",
]
@ -442,7 +440,6 @@ version = "4.5.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efd9466fac8543255d3b1fcad4762c5e116ffe808c8a3043d4263cd4fd4862a2"
dependencies = [
"anstream",
"anstyle",
"clap_lex",
"strsim",
@ -620,7 +617,7 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "eris"
version = "0.1.0"
version = "1.0.0"
dependencies = [
"actix-web",
"chrono",
@ -633,10 +630,13 @@ dependencies = [
"prometheus 0.14.0",
"prometheus_exporter",
"rand",
"regex",
"rlua",
"serde",
"serde_json",
"tempfile",
"tokio",
"toml",
]
[[package]]
@ -646,9 +646,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e"
dependencies = [
"libc",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "flate2"
version = "1.1.1"
@ -1580,7 +1586,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@ -1633,6 +1639,15 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_spanned"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1"
dependencies = [
"serde",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"
@ -1740,6 +1755,19 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "tempfile"
version = "3.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf"
dependencies = [
"fastrand",
"getrandom",
"once_cell",
"rustix",
"windows-sys 0.59.0",
]
[[package]]
name = "thiserror"
version = "1.0.69"
@ -1876,6 +1904,47 @@ dependencies = [
"tokio",
]
[[package]]
name = "toml"
version = "0.8.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05ae329d1f08c4d17a59bed7ff5b5a769d062e64a62d34a3261b219e62cd5aae"
dependencies = [
"serde",
"serde_spanned",
"toml_datetime",
"toml_edit",
]
[[package]]
name = "toml_datetime"
version = "0.6.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3"
dependencies = [
"serde",
]
[[package]]
name = "toml_edit"
version = "0.22.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "310068873db2c5b3e7659d2cc35d21855dbafa50d1ce336397c666e3cb08137e"
dependencies = [
"indexmap",
"serde",
"serde_spanned",
"toml_datetime",
"toml_write",
"winnow",
]
[[package]]
name = "toml_write"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfb942dfe1d8e29a7ee7fcbde5bd2b9a25fb89aa70caea2eba3bee836ff41076"
[[package]]
name = "tracing"
version = "0.1.41"
@ -2187,6 +2256,15 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
version = "0.7.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e27d6ad3dac991091e4d35de9ba2d2d00647c5d0fc26c5496dee55984ae111b"
dependencies = [
"memchr",
]
[[package]]
name = "winsafe"
version = "0.0.19"

View file

@ -1,12 +1,12 @@
[package]
name = "eris"
version = "0.1.0"
version = "1.0.0"
edition = "2024"
[dependencies]
actix-web = "4.3.1"
clap = { version = "4.3", features = ["derive"] }
chrono = "0.4.24"
actix-web = { version = "4.3.1" }
chrono = { version = "0.4.41", default-features = false, features = ["std", "clock"] }
clap = { version = "4.5", default-features = false, features = ["std", "derive", "help", "usage", "suggestions"] }
futures = "0.3.28"
ipnetwork = "0.21.1"
lazy_static = "1.4.0"
@ -19,3 +19,6 @@ serde_json = "1.0.96"
tokio = { version = "1.28.0", features = ["full"] }
log = "0.4.27"
env_logger = "0.11.8"
tempfile = "3.19.1"
regex = "1.11.1"
toml = "0.8.22"

View file

@ -19,6 +19,7 @@ in
fileset = fs.unions [
(fs.fileFilter (file: builtins.any file.hasExt ["rs"]) (s + /src))
(s + /contrib)
(s + /resources)
lockfile
cargoToml
];

View file

@ -0,0 +1,210 @@
--[[
Eris Default Script
This script demonstrates how to use the Eris Lua API to customize
the tarpit's behavior, and will be loaded by default if no other
scripts are loaded.
Available events:
- connection: When a new connection is established
- request: When a request is received
- response_gen: When generating a response
- response_chunk: Before sending each response chunk
- disconnection: When a connection is closed
- block_ip: When an IP is being considered for blocking
- startup: When the application starts
- shutdown: When the application is shutting down
- periodic: Called periodically
API Functions:
- eris.debug(message): Log a debug message
- eris.info(message): Log an info message
- eris.warn(message): Log a warning message
- eris.error(message): Log an error message
- eris.set_state(key, value): Store persistent state
- eris.get_state(key): Retrieve persistent state
- eris.inc_counter(key, [amount]): Increment a counter
- eris.get_counter(key): Get a counter value
- eris.gen_token([prefix]): Generate a unique token
- eris.timestamp(): Get current Unix timestamp
--]]
-- Called when the application starts
eris.on("startup", function(ctx)
    eris.info("Initializing default script")

    -- Seed counters at zero so they always exist for later reads.
    for _, counter in ipairs({ "total_connections", "total_responses", "blocked_ips" }) do
        eris.inc_counter(counter, 0)
    end

    -- Initialize banned keywords
    eris.set_state("banned_keywords", "eval,exec,system,shell,<?php,/bin/bash")
end)
-- Called for each new connection
eris.on("connection", function(ctx)
    eris.inc_counter("total_connections")
    eris.debug("New connection from " .. ctx.ip)

    -- Reject connections whose IP appears in the manually maintained
    -- blocklist state. (NOTE(review): find() treats the IP as a Lua
    -- pattern, so the dots match any character — confirm this is intended.)
    local blocklist = eris.get_state("manual_blocklist") or ""
    if not blocklist:find(ctx.ip) then
        return true -- accept the connection
    end

    eris.info("Rejecting connection from manually blocked IP: " .. ctx.ip)
    return false
end)
-- Called when generating a response
-- Returns the (possibly modified) response body; a unique honeytoken is
-- injected in a format chosen by what kind of path was requested, and the
-- token is recorded per source IP in persistent state.
eris.on("response_gen", function(ctx)
    eris.inc_counter("total_responses")
    -- Generate a unique traceable token for this request
    local token = eris.gen_token("ERIS-")
    -- Add some believable but fake honeytokens based on the request path
    local enhanced_content = ctx.content
    if ctx.path:find("wp%-") then
        -- For WordPress paths
        enhanced_content = enhanced_content
            .. "\n<!-- WordPress Debug: "
            .. token
            .. " -->"
            .. "\n<!-- WP_HOME: http://stop.crawlingmysite.com/wordpress -->"
            .. "\n<!-- DB_USER: wp_user_"
            .. math.random(1000, 9999)
            .. " -->"
    elseif ctx.path:find("phpunit") or ctx.path:find("eval") then
        -- For PHP exploit attempts
        -- Turns out you can just google "PHP error log" and search random online forums where people
        -- dump their service logs in full.
        enhanced_content = enhanced_content
            .. "\nPHP Notice: Undefined variable: _SESSION in /var/www/html/includes/core.php on line 58\n"
            .. "Warning: file_get_contents(): Filename cannot be empty in /var/www/html/vendor/autoload.php on line 23\n"
            .. "Token: "
            .. token
            .. "\n"
    elseif ctx.path:find("api") then
        -- For API requests
        local fake_api_key =
            string.format("ak_%x%x%x", math.random(1000, 9999), math.random(1000, 9999), math.random(1000, 9999))
        enhanced_content = enhanced_content
            .. "{\n"
            .. ' "status": "warning",\n'
            .. ' "message": "Test API environment detected",\n'
            .. ' "debug_token": "'
            .. token
            .. '",\n'
            .. ' "api_key": "'
            .. fake_api_key
            .. '"\n'
            .. "}\n"
    else
        -- For other requests
        enhanced_content = enhanced_content
            .. "\n<!-- Server: Apache/2.4.41 (Ubuntu) -->"
            .. "\n<!-- Debug-Token: "
            .. token
            .. " -->"
            .. "\n<!-- Environment: staging -->"
    end
    -- Track which honeytokens were sent to which IP
    local honeytokens = eris.get_state("honeytokens") or "{}"
    local ht_table = {}
    -- This is a simplistic approach - in a real script, you'd want to use
    -- a proper JSON library to handle this correctly
    if honeytokens ~= "{}" then
        -- Simple parsing of the stored data
        for ip, tok in honeytokens:gmatch('"([^"]+)":"([^"]+)"') do
            ht_table[ip] = tok
        end
    end
    -- Only the most recent token per IP is kept (earlier ones overwritten).
    ht_table[ctx.ip] = token
    -- Convert back to a simple JSON-like string
    local new_tokens = "{"
    for ip, tok in pairs(ht_table) do
        if new_tokens ~= "{" then
            new_tokens = new_tokens .. ","
        end
        new_tokens = new_tokens .. '"' .. ip .. '":"' .. tok .. '"'
    end
    new_tokens = new_tokens .. "}"
    eris.set_state("honeytokens", new_tokens)
    return enhanced_content
end)
-- Called before sending each chunk of a response
eris.on("response_chunk", function(ctx)
    -- This can be used to alter individual chunks for more deceptive behavior
    -- For example, to simulate a slow, unreliable server
    -- 5% chance of "corrupting" a chunk to confuse scanners
    if math.random(1, 100) <= 5 then
        local chunk = ctx.content
        if #chunk > 10 then
            -- Overwrite a single character with a random printable ASCII
            -- byte: sub(1, pos) keeps the first pos chars, sub(pos + 2)
            -- drops the char at pos+1, so chunk length is preserved.
            local pos = math.random(1, #chunk - 5)
            chunk = chunk:sub(1, pos) .. string.char(math.random(32, 126)) .. chunk:sub(pos + 2)
        end
        return chunk
    end
    -- Pass through unmodified in the common case.
    return ctx.content
end)
-- Called when deciding whether to block an IP
-- Return semantics: true forces a block; nil defers to the built-in
-- threshold decision; a boolean from the 10.x branch presumably maps
-- false to "don't block" — confirm against the host's event handling.
eris.on("block_ip", function(ctx)
    -- You can override the default blocking logic
    -- Check for potential attackers using specific patterns
    -- (keywords are stored as a comma-separated string; see the startup hook)
    local banned_keywords = eris.get_state("banned_keywords") or ""
    local user_agent = ctx.user_agent or ""
    -- Check if user agent contains highly suspicious patterns
    for keyword in banned_keywords:gmatch("[^,]+") do
        if user_agent:lower():find(keyword:lower()) then
            eris.info("Blocking IP " .. ctx.ip .. " due to suspicious user agent: " .. keyword)
            eris.inc_counter("blocked_ips")
            return true -- Force block
        end
    end
    -- For demonstration, we'll be more lenient with 10.x IPs
    if ctx.ip:match("^10%.") then
        -- Only block if they've hit us many times
        return ctx.hit_count >= 5
    end
    -- Default to the system's threshold-based decision
    return nil
end)
-- The enhance_response is now legacy, and I never liked it anyway. Though let's add it here
-- for the sake of backwards compatibility.
-- Appends (or, for "api", rewrites in place) a tracking token in a format
-- matching the response flavor. `path` is accepted for compatibility but unused.
function enhance_response(text, response_type, path, token)
    if response_type == "api" then
        -- JSON-ish responses: rewrite the status/message fields instead of appending.
        local out = text:gsub('"status": "[^"]+"', '"status": "warning"')
        out = out:gsub('"message": "[^"]+"', '"message": "API token: ' .. token .. '"')
        return out
    end

    -- Everything else gets a comment-style suffix carrying the token.
    local suffix
    if response_type == "php_exploit" then
        suffix = "\n/* Token: " .. token .. " */\n"
    elseif response_type == "wordpress" then
        suffix = "\n<!-- WordPress Debug Token: " .. token .. " -->\n"
    else
        suffix = "\n<!-- Debug token: " .. token .. " -->\n"
    end
    return text .. suffix
end

613
src/config.rs Normal file
View file

@ -0,0 +1,613 @@
use clap::Parser;
use ipnetwork::IpNetwork;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::env;
use std::fs;
use std::net::IpAddr;
use std::path::{Path, PathBuf};
// Command-line arguments using clap
// NOTE: when --config-file is given, the file's values override these
// flags (see help text below); otherwise Config::from_args maps them
// onto the runtime Config.
#[derive(Parser, Debug, Clone)]
#[clap(
    author,
    version,
    about = "Markov chain based HTTP tarpit/honeypot that delays and tracks potential attackers"
)]
pub struct Args {
    #[clap(
        long,
        default_value = "0.0.0.0:8888",
        help = "Address and port to listen for incoming HTTP requests (format: ip:port)"
    )]
    pub listen_addr: String,
    #[clap(
        long,
        default_value = "0.0.0.0:9100",
        help = "Address and port to expose Prometheus metrics and status endpoint (format: ip:port)"
    )]
    pub metrics_addr: String,
    #[clap(long, help = "Disable Prometheus metrics server completely")]
    pub disable_metrics: bool,
    #[clap(
        long,
        default_value = "127.0.0.1:80",
        help = "Backend server address to proxy legitimate requests to (format: ip:port)"
    )]
    pub backend_addr: String,
    // Tarpit pacing: delays are in milliseconds per chunk.
    #[clap(
        long,
        default_value = "1000",
        help = "Minimum delay in milliseconds between chunks sent to attacker"
    )]
    pub min_delay: u64,
    #[clap(
        long,
        default_value = "15000",
        help = "Maximum delay in milliseconds between chunks sent to attacker"
    )]
    pub max_delay: u64,
    #[clap(
        long,
        default_value = "600",
        help = "Maximum time in seconds to keep an attacker in the tarpit before disconnecting"
    )]
    pub max_tarpit_time: u64,
    #[clap(
        long,
        default_value = "3",
        help = "Number of hits to honeypot patterns before permanently blocking an IP"
    )]
    pub block_threshold: u32,
    #[clap(
        long,
        help = "Base directory for all application data (overrides XDG directory structure)"
    )]
    pub base_dir: Option<PathBuf>,
    #[clap(
        long,
        help = "Path to configuration file (JSON or TOML, overrides command line options)"
    )]
    pub config_file: Option<PathBuf>,
    #[clap(
        long,
        default_value = "info",
        help = "Log level: trace, debug, info, warn, error"
    )]
    pub log_level: String,
    #[clap(
        long,
        default_value = "pretty",
        help = "Log format: plain, pretty, json, pretty-json"
    )]
    pub log_format: String,
    // Per-IP connection rate limiting options.
    #[clap(long, help = "Enable rate limiting for connections from the same IP")]
    pub rate_limit_enabled: bool,
    #[clap(long, default_value = "60", help = "Rate limit window in seconds")]
    pub rate_limit_window: u64,
    #[clap(
        long,
        default_value = "30",
        help = "Maximum number of connections allowed per IP in the rate limit window"
    )]
    pub rate_limit_max: usize,
    #[clap(
        long,
        default_value = "100",
        help = "Connection attempts threshold before considering for IP blocking"
    )]
    pub rate_limit_block_threshold: usize,
    #[clap(
        long,
        help = "Send a 429 response for rate limited connections instead of dropping connection"
    )]
    pub rate_limit_slow_response: bool,
}
// Log output format, selected via the --log-format CLI flag
// ("plain", "pretty", "json", "pretty-json"); Pretty is the default.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub enum LogFormat {
    Plain,
    #[default]
    Pretty,
    Json,
    PrettyJson,
}
// Trap pattern structure. It can be either a plain string or a
// regex to catch more advanced patterns necessitated by
// more sophisticated crawlers.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum TrapPattern {
    Plain(String),
    Regex { pattern: String, regex: bool },
}

impl TrapPattern {
    /// Build a plain (substring) pattern.
    pub fn as_plain(value: &str) -> Self {
        Self::Plain(value.to_string())
    }

    /// Build a regex pattern.
    pub fn as_regex(value: &str) -> Self {
        Self::Regex {
            pattern: value.to_string(),
            regex: true,
        }
    }

    /// Returns true when `path` matches this pattern.
    ///
    /// Invalid regexes never match — a misconfigured pattern should not
    /// take the tarpit down.
    ///
    /// FIX: `Regex { regex: false, .. }`, which the untagged deserializer
    /// can produce from user config, previously fell through to a
    /// catch-all `false` arm and silently disabled the pattern; it is now
    /// treated as a plain substring match.
    pub fn matches(&self, path: &str) -> bool {
        match self {
            Self::Plain(pattern) | Self::Regex { pattern, regex: false } => path.contains(pattern),
            Self::Regex { pattern, regex: true } => {
                // NOTE(perf): the regex is recompiled on every call; if this
                // shows up in profiles, cache compiled regexes at load time.
                Regex::new(pattern).map_or(false, |re| re.is_match(path))
            }
        }
    }
}
// Configuration structure
// Runtime configuration: built from CLI args (Config::from_args) or
// loaded from a JSON/TOML file (Config::load_from_file). Field meanings
// mirror the corresponding Args help strings.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Config {
    pub listen_addr: String,  // tarpit listener (ip:port)
    pub metrics_addr: String, // Prometheus/status listener (ip:port)
    pub disable_metrics: bool,
    pub backend_addr: String, // backend for proxying legitimate requests
    pub min_delay: u64,       // minimum inter-chunk delay (ms)
    pub max_delay: u64,       // maximum inter-chunk delay (ms)
    pub max_tarpit_time: u64, // max seconds to hold a client in the tarpit
    pub block_threshold: u32, // trap hits before permanently blocking an IP
    pub trap_patterns: Vec<TrapPattern>,
    pub whitelist_networks: Vec<String>, // CIDR strings; matching sources are never tarpitted
    pub markov_corpora_dir: String,
    pub lua_scripts_dir: String,
    pub data_dir: String,
    pub config_dir: String,
    pub cache_dir: String,
    pub log_format: LogFormat,
    pub rate_limit_enabled: bool,
    pub rate_limit_window_seconds: u64,
    pub rate_limit_max_connections: usize, // per-IP connections allowed per window
    pub rate_limit_block_threshold: usize, // attempts before considering an IP block
    pub rate_limit_slow_response: bool,    // send 429 instead of dropping the connection
}
impl Default for Config {
    // Defaults use relative directories ("./conf" etc.); Config::from_args
    // replaces them with XDG or --base-dir derived paths.
    fn default() -> Self {
        Self {
            listen_addr: "0.0.0.0:8888".to_string(),
            metrics_addr: "0.0.0.0:9100".to_string(),
            disable_metrics: false,
            backend_addr: "127.0.0.1:80".to_string(),
            min_delay: 1000,
            max_delay: 15000,
            max_tarpit_time: 600,
            block_threshold: 3,
            trap_patterns: vec![
                // Basic attack patterns as plain strings
                TrapPattern::as_plain("/vendor/phpunit"),
                TrapPattern::as_plain("eval-stdin.php"),
                TrapPattern::as_plain("/wp-admin"),
                TrapPattern::as_plain("/wp-login.php"),
                TrapPattern::as_plain("/xmlrpc.php"),
                TrapPattern::as_plain("/phpMyAdmin"),
                TrapPattern::as_plain("/solr/"),
                TrapPattern::as_plain("/.env"),
                TrapPattern::as_plain("/config"),
                TrapPattern::as_plain("/actuator/"),
                // More aggressive patterns for various PHP exploits.
                // XXX: I dedicate this entire section to that one single crawler
                // that has been scanning my entire network, hitting 403s left and right
                // but not giving up, and coming back the next day at the same time to
                // scan the same paths over and over. Kudos to you, random crawler.
                TrapPattern::as_regex(r"/.*phpunit.*eval-stdin\.php"),
                TrapPattern::as_regex(r"/index\.php\?s=/index/\\think\\app/invokefunction"),
                TrapPattern::as_regex(r".*%ADd\+auto_prepend_file%3dphp://input.*"),
                TrapPattern::as_regex(r".*%ADd\+allow_url_include%3d1.*"),
                TrapPattern::as_regex(r".*/wp-content/plugins/.*\.php"),
                TrapPattern::as_regex(r".*/wp-content/themes/.*\.php"),
                TrapPattern::as_regex(r".*eval\(.*\).*"),
                TrapPattern::as_regex(r".*/adminer\.php.*"),
                TrapPattern::as_regex(r".*/admin\.php.*"),
                TrapPattern::as_regex(r".*/administrator/.*"),
                TrapPattern::as_regex(r".*/wp-json/.*"),
                TrapPattern::as_regex(r".*/api/.*\.php.*"),
                TrapPattern::as_regex(r".*/cgi-bin/.*"),
                TrapPattern::as_regex(r".*/owa/.*"),
                TrapPattern::as_regex(r".*/ecp/.*"),
                TrapPattern::as_regex(r".*/webshell\.php.*"),
                TrapPattern::as_regex(r".*/shell\.php.*"),
                TrapPattern::as_regex(r".*/cmd\.php.*"),
                TrapPattern::as_regex(r".*/struts.*"),
            ],
            // Private (RFC 1918) and loopback ranges are whitelisted by default.
            whitelist_networks: vec![
                "192.168.0.0/16".to_string(),
                "10.0.0.0/8".to_string(),
                "172.16.0.0/12".to_string(),
                "127.0.0.0/8".to_string(),
            ],
            markov_corpora_dir: "./corpora".to_string(),
            lua_scripts_dir: "./scripts".to_string(),
            data_dir: "./data".to_string(),
            config_dir: "./conf".to_string(),
            cache_dir: "./cache".to_string(),
            log_format: LogFormat::Pretty,
            rate_limit_enabled: true,
            rate_limit_window_seconds: 60,
            rate_limit_max_connections: 30,
            rate_limit_block_threshold: 100,
            rate_limit_slow_response: true,
        }
    }
}
// Gets standard XDG directory paths for config, data and cache.
// XXX: This could be "simplified" by using the Dirs crate, but I can't
// really justify pulling a library for something I can handle in less
// than 30 lines. Unless cross-platform becomes necessary, the below
// implementation is good enough. For alternative platforms, we can simply
// enhance the current implementation as needed.
pub fn get_xdg_dirs() -> (PathBuf, PathBuf, PathBuf) {
    // Resolve one XDG base directory: honor the env var when set, otherwise
    // fall back to $HOME/<fallback...> ("." stands in when HOME is unset).
    let resolve = |var: &str, fallback: &[&str]| -> PathBuf {
        env::var_os(var).map(PathBuf::from).unwrap_or_else(|| {
            let mut dir = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from);
            for part in fallback {
                dir = dir.join(part);
            }
            dir
        })
    };

    let config_home = resolve("XDG_CONFIG_HOME", &[".config"]);
    let data_home = resolve("XDG_DATA_HOME", &[".local", "share"]);
    let cache_home = resolve("XDG_CACHE_HOME", &[".cache"]);

    // The application's own subdirectory under each base.
    (
        config_home.join("eris"),
        data_home.join("eris"),
        cache_home.join("eris"),
    )
}
impl Config {
    // Create configuration from command-line args. We'll be falling back to this
    // when the configuration is invalid, so it must be validated more strictly.
    pub fn from_args(args: &Args) -> Self {
        // --base-dir flattens everything under <base>/{conf,data,cache};
        // otherwise follow the XDG base-directory layout.
        let (config_dir, data_dir, cache_dir) = if let Some(base_dir) = &args.base_dir {
            let base_str = base_dir.to_string_lossy().to_string();
            (
                format!("{base_str}/conf"),
                format!("{base_str}/data"),
                format!("{base_str}/cache"),
            )
        } else {
            let (c, d, cache) = get_xdg_dirs();
            (
                c.to_string_lossy().to_string(),
                d.to_string_lossy().to_string(),
                cache.to_string_lossy().to_string(),
            )
        };
        Self {
            listen_addr: args.listen_addr.clone(),
            metrics_addr: args.metrics_addr.clone(),
            disable_metrics: args.disable_metrics,
            backend_addr: args.backend_addr.clone(),
            min_delay: args.min_delay,
            max_delay: args.max_delay,
            max_tarpit_time: args.max_tarpit_time,
            block_threshold: args.block_threshold,
            markov_corpora_dir: format!("{data_dir}/corpora"),
            lua_scripts_dir: format!("{data_dir}/scripts"),
            data_dir,
            config_dir,
            cache_dir,
            // FIX: --log-format was accepted but ignored — this previously
            // hardcoded LogFormat::Pretty. Map the CLI string here; unknown
            // values keep the old behavior and fall back to Pretty.
            log_format: match args.log_format.to_lowercase().as_str() {
                "plain" => LogFormat::Plain,
                "json" => LogFormat::Json,
                "pretty-json" => LogFormat::PrettyJson,
                _ => LogFormat::Pretty,
            },
            rate_limit_enabled: args.rate_limit_enabled,
            rate_limit_window_seconds: args.rate_limit_window,
            rate_limit_max_connections: args.rate_limit_max,
            rate_limit_block_threshold: args.rate_limit_block_threshold,
            rate_limit_slow_response: args.rate_limit_slow_response,
            // trap_patterns and whitelist_networks come from the defaults.
            ..Default::default()
        }
    }

    // Load configuration from a file (JSON or TOML).
    // The format is chosen by extension (".toml" => TOML, anything else
    // => JSON); parse failures surface as io::ErrorKind::InvalidData.
    pub fn load_from_file(path: &Path) -> std::io::Result<Self> {
        let content = fs::read_to_string(path)?;
        let extension = path
            .extension()
            .map(|ext| ext.to_string_lossy().to_lowercase())
            .unwrap_or_default();
        let config = match extension.as_str() {
            "toml" => toml::from_str(&content).map_err(|e| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("Failed to parse TOML: {e}"),
                )
            })?,
            _ => {
                // Default to JSON for any other extension
                serde_json::from_str(&content).map_err(|e| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        format!("Failed to parse JSON: {e}"),
                    )
                })?
            }
        };
        Ok(config)
    }

    // Save configuration to a file (JSON or TOML), creating parent
    // directories as needed. Format selection mirrors load_from_file.
    pub fn save_to_file(&self, path: &Path) -> std::io::Result<()> {
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent)?;
        }
        let extension = path
            .extension()
            .map(|ext| ext.to_string_lossy().to_lowercase())
            .unwrap_or_default();
        let content = match extension.as_str() {
            "toml" => toml::to_string_pretty(self).map_err(|e| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("Failed to serialize to TOML: {e}"),
                )
            })?,
            _ => {
                // Default to JSON for any other extension
                serde_json::to_string_pretty(self).map_err(|e| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        format!("Failed to serialize to JSON: {e}"),
                    )
                })?
            }
        };
        fs::write(path, content)?;
        Ok(())
    }

    // Create required directories if they don't exist.
    // create_dir_all is a no-op for directories that already exist, so the
    // debug log line fires even when nothing was created.
    pub fn ensure_dirs_exist(&self) -> std::io::Result<()> {
        let dirs = [
            &self.markov_corpora_dir,
            &self.lua_scripts_dir,
            &self.data_dir,
            &self.config_dir,
            &self.cache_dir,
        ];
        for dir in dirs {
            fs::create_dir_all(dir)?;
            log::debug!("Created directory: {dir}");
        }
        Ok(())
    }
}
// Decide if a request should be tarpitted based on path and IP
pub fn should_tarpit(path: &str, ip: &IpAddr, config: &Config) -> bool {
    // Whitelisted source networks are never tarpitted; check them first so
    // we can skip the (potentially regex-based) pattern matching entirely.
    // Unparseable CIDR strings are silently ignored, as before.
    let whitelisted = config
        .whitelist_networks
        .iter()
        .filter_map(|net| net.parse::<IpNetwork>().ok())
        .any(|net| net.contains(*ip));
    if whitelisted {
        return false;
    }

    // Tarpit when any trap pattern (plain or regex) matches the path.
    config
        .trap_patterns
        .iter()
        .any(|pattern| pattern.matches(path))
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr};
#[test]
fn test_config_from_args() {
let args = Args {
listen_addr: "127.0.0.1:8080".to_string(),
metrics_addr: "127.0.0.1:9000".to_string(),
disable_metrics: true,
backend_addr: "127.0.0.1:8081".to_string(),
min_delay: 500,
max_delay: 10000,
max_tarpit_time: 300,
block_threshold: 5,
base_dir: Some(PathBuf::from("/tmp/eris")),
config_file: None,
log_level: "debug".to_string(),
log_format: "pretty".to_string(),
rate_limit_enabled: true,
rate_limit_window: 30,
rate_limit_max: 20,
rate_limit_block_threshold: 50,
rate_limit_slow_response: true,
};
let config = Config::from_args(&args);
assert_eq!(config.listen_addr, "127.0.0.1:8080");
assert_eq!(config.metrics_addr, "127.0.0.1:9000");
assert!(config.disable_metrics);
assert_eq!(config.backend_addr, "127.0.0.1:8081");
assert_eq!(config.min_delay, 500);
assert_eq!(config.max_delay, 10000);
assert_eq!(config.max_tarpit_time, 300);
assert_eq!(config.block_threshold, 5);
assert_eq!(config.markov_corpora_dir, "/tmp/eris/data/corpora");
assert_eq!(config.lua_scripts_dir, "/tmp/eris/data/scripts");
assert_eq!(config.data_dir, "/tmp/eris/data");
assert_eq!(config.config_dir, "/tmp/eris/conf");
assert_eq!(config.cache_dir, "/tmp/eris/cache");
assert!(config.rate_limit_enabled);
assert_eq!(config.rate_limit_window_seconds, 30);
assert_eq!(config.rate_limit_max_connections, 20);
assert_eq!(config.rate_limit_block_threshold, 50);
assert!(config.rate_limit_slow_response);
}
#[test]
fn test_trap_pattern_matching() {
// Test plain string pattern
let plain = TrapPattern::as_plain("phpunit");
assert!(plain.matches("path/to/phpunit/test"));
assert!(!plain.matches("path/to/something/else"));
// Test regex pattern
let regex = TrapPattern::as_regex(r".*eval-stdin\.php.*");
assert!(regex.matches("/vendor/phpunit/phpunit/src/Util/PHP/eval-stdin.php"));
assert!(regex.matches("/tests/eval-stdin.php?param"));
assert!(!regex.matches("/normal/path"));
// Test invalid regex pattern (should return false)
let invalid = TrapPattern::Regex {
pattern: "(invalid[regex".to_string(),
regex: true,
};
assert!(!invalid.matches("anything"));
}
#[tokio::test]
async fn test_should_tarpit() {
let config = Config::default();
// Test trap patterns
assert!(should_tarpit(
"/vendor/phpunit/whatever",
&IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
&config
));
assert!(should_tarpit(
"/wp-admin/login.php",
&IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
&config
));
assert!(should_tarpit(
"/.env",
&IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
&config
));
// Test whitelist networks
assert!(!should_tarpit(
"/wp-admin/login.php",
&IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
&config
));
assert!(!should_tarpit(
"/vendor/phpunit/whatever",
&IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
&config
));
// Test legitimate paths
assert!(!should_tarpit(
"/index.html",
&IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
&config
));
assert!(!should_tarpit(
"/images/logo.png",
&IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
&config
));
// Test regex patterns
assert!(should_tarpit(
"/index.php?s=/index/\\think\\app/invokefunction&function=call_user_func_array&vars[0]=md5&vars[1][]=Hello",
&IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
&config
));
assert!(should_tarpit(
"/hello.world?%ADd+allow_url_include%3d1+%ADd+auto_prepend_file%3dphp://input",
&IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
&config
));
}
#[test]
fn test_config_file_formats() {
// Create temporary JSON config file
let temp_dir = std::env::temp_dir();
let json_path = temp_dir.join("temp_config.json");
let toml_path = temp_dir.join("temp_config.toml");
let config = Config::default();
// Test JSON serialization and deserialization
config.save_to_file(&json_path).unwrap();
let loaded_json = Config::load_from_file(&json_path).unwrap();
assert_eq!(loaded_json.listen_addr, config.listen_addr);
assert_eq!(loaded_json.min_delay, config.min_delay);
assert_eq!(loaded_json.rate_limit_enabled, config.rate_limit_enabled);
assert_eq!(
loaded_json.rate_limit_max_connections,
config.rate_limit_max_connections
);
// Test TOML serialization and deserialization
config.save_to_file(&toml_path).unwrap();
let loaded_toml = Config::load_from_file(&toml_path).unwrap();
assert_eq!(loaded_toml.listen_addr, config.listen_addr);
assert_eq!(loaded_toml.min_delay, config.min_delay);
assert_eq!(loaded_toml.rate_limit_enabled, config.rate_limit_enabled);
assert_eq!(
loaded_toml.rate_limit_max_connections,
config.rate_limit_max_connections
);
// Clean up
let _ = std::fs::remove_file(json_path);
let _ = std::fs::remove_file(toml_path);
}
}

128
src/firewall.rs Normal file
View file

@ -0,0 +1,128 @@
use tokio::process::Command;
// Set up nftables firewall rules for IP blocking.
//
// Ensures (creating when missing) the `inet filter` table, the
// `eris_blacklist` set, and an input-chain rule dropping traffic from
// blacklisted addresses. A missing `nft` binary is a warning + no-op so
// the tarpit still runs without firewall integration; other failures are
// reported as human-readable strings.
//
// Refactor: the four near-identical `Command::new("nft")` invocations are
// factored into `run_nft` and per-object `ensure_*` helpers; behavior is
// unchanged (in particular, a create command's non-zero exit status is
// still not treated as an error — only a failure to spawn `nft` is).
pub async fn setup_firewall() -> Result<(), String> {
    log::info!("Setting up firewall rules");

    // Check if nft command exists
    let nft_exists = Command::new("which")
        .arg("nft")
        .output()
        .await
        .map(|output| output.status.success())
        .unwrap_or(false);
    if !nft_exists {
        log::warn!("nft command not found. Firewall rules will not be set up.");
        return Ok(());
    }

    ensure_table().await?;
    ensure_blacklist_set().await?;
    ensure_drop_rule().await?;

    log::info!("Firewall setup complete");
    Ok(())
}

// Run `nft` with the given arguments, capturing its output.
async fn run_nft(args: &[&str]) -> std::io::Result<std::process::Output> {
    Command::new("nft").args(args).output().await
}

// Create the `inet filter` table when `nft list table` reports it missing
// (or when the existence check itself could not be run).
async fn ensure_table() -> Result<(), String> {
    let need_create = match run_nft(&["list", "table", "inet", "filter"]).await {
        Ok(output) => !output.status.success(),
        Err(e) => {
            log::warn!("Failed to check if nftables table exists: {e}");
            log::info!("Will try to create it anyway");
            true
        }
    };
    if need_create {
        log::info!("Creating nftables table");
        run_nft(&["create", "table", "inet", "filter"])
            .await
            .map_err(|e| format!("Failed to create nftables table: {e}"))?;
    }
    Ok(())
}

// Create the `eris_blacklist` IPv4 interval set if it doesn't exist.
// Unlike the table check, a failed existence check here is fatal.
async fn ensure_blacklist_set() -> Result<(), String> {
    match run_nft(&["list", "set", "inet", "filter", "eris_blacklist"]).await {
        Ok(output) if output.status.success() => Ok(()),
        Ok(_) => {
            log::info!("Creating eris_blacklist set");
            run_nft(&[
                "create",
                "set",
                "inet",
                "filter",
                "eris_blacklist",
                "{ type ipv4_addr; flags interval; }",
            ])
            .await
            .map(|_| ())
            .map_err(|e| format!("Failed to create blacklist set: {e}"))
        }
        Err(e) => {
            log::warn!("Failed to check if blacklist set exists: {e}");
            Err(format!("Failed to check if blacklist set exists: {e}"))
        }
    }
}

// Append a drop rule for @eris_blacklist to the input chain unless the
// `nft list chain` output shows one is already present.
async fn ensure_drop_rule() -> Result<(), String> {
    match run_nft(&["list", "chain", "inet", "filter", "input"]).await {
        Ok(output) => {
            let rule_exists = String::from_utf8_lossy(&output.stdout)
                .contains("ip saddr @eris_blacklist counter drop");
            if !rule_exists {
                log::info!("Adding drop rule for blacklisted IPs");
                run_nft(&[
                    "add",
                    "rule",
                    "inet",
                    "filter",
                    "input",
                    "ip saddr @eris_blacklist",
                    "counter",
                    "drop",
                ])
                .await
                .map_err(|e| format!("Failed to add firewall rule: {e}"))?;
            }
            Ok(())
        }
        Err(e) => {
            log::warn!("Failed to check if firewall rule exists: {e}");
            Err(format!("Failed to check if firewall rule exists: {e}"))
        }
    }
}

901
src/lua/mod.rs Normal file
View file

@ -0,0 +1,901 @@
use rlua::{Function, Lua, Table, Value};
use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::fs;
use std::hash::{Hash, Hasher};
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};
// Event types for the Lua scripting system
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventType {
    Connection,    // when a new connection is established
    Request,       // when a request is received
    ResponseGen,   // when generating a response
    ResponseChunk, // before sending each response chunk
    Disconnection, // when a connection is closed
    BlockIP,       // when an IP is being considered for blocking
    Startup,       // when the application starts
    Shutdown,      // when the application is shutting down
    Periodic,      // called periodically (e.g., every minute)
}

impl EventType {
    /// Every variant, used to drive the string -> enum lookup below.
    const ALL: [Self; 9] = [
        Self::Connection,
        Self::Request,
        Self::ResponseGen,
        Self::ResponseChunk,
        Self::Disconnection,
        Self::BlockIP,
        Self::Startup,
        Self::Shutdown,
        Self::Periodic,
    ];

    /// Convert event type to string representation for Lua
    const fn as_str(&self) -> &'static str {
        match self {
            Self::Connection => "connection",
            Self::Request => "request",
            Self::ResponseGen => "response_gen",
            Self::ResponseChunk => "response_chunk",
            Self::Disconnection => "disconnection",
            Self::BlockIP => "block_ip",
            Self::Startup => "startup",
            Self::Shutdown => "shutdown",
            Self::Periodic => "periodic",
        }
    }

    /// Convert from string to `EventType` (None for unknown names).
    /// Table-driven so it can never drift out of sync with `as_str`.
    fn from_str(s: &str) -> Option<Self> {
        Self::ALL.iter().copied().find(|event| event.as_str() == s)
    }
}
// Loaded Lua script with its metadata.
struct ScriptInfo {
    // Script name: the .lua file stem, or "default" for the bundled script.
    name: String,
    // Disabled scripts are skipped when dispatching events
    // (see `is_script_enabled`).
    enabled: bool,
}
// Script state and manager for the Lua environment.
pub struct ScriptManager {
    // Single Lua interpreter shared by all scripts; the Mutex serializes
    // all access (every event dispatch takes this lock).
    lua: Mutex<Lua>,
    // Metadata for each loaded script, in load order.
    scripts: Vec<ScriptInfo>,
    // For each event type, names of the scripts that registered a handler.
    hooks: HashMap<EventType, Vec<String>>,
    // Key/value store exposed to Lua via eris.set_state/get_state.
    state: Arc<RwLock<HashMap<String, String>>>,
    // Named counters exposed to Lua via eris.inc_counter/get_counter.
    counters: Arc<RwLock<HashMap<String, i64>>>,
}
// Context passed to Lua event handlers. Optional fields are set only when
// they make sense for the event (e.g. `path` for request/response events).
pub struct EventContext {
    pub event_type: EventType,
    // Peer IP rendered as a string, when known.
    pub ip: Option<String>,
    // Request path, for request/response events.
    pub path: Option<String>,
    pub user_agent: Option<String>,
    pub request_headers: Option<HashMap<String, String>>,
    // Payload for response events; handlers may return a replacement.
    pub content: Option<String>,
    // Unix timestamp (seconds) when the event was created.
    pub timestamp: u64,
    pub session_id: Option<String>,
}
// Make ScriptManager explicitly Send + Sync since we're using Mutex.
// SAFETY: NOTE(review): this asserts thread-safety that `rlua::Lua` does
// not provide on its own. The Mutex does serialize all access to the
// interpreter, but `unsafe impl Send/Sync` is a soundness *claim*, not a
// synchronization mechanism — confirm that `Lua` holds no thread-affine
// state before relying on this across threads.
unsafe impl Send for ScriptManager {}
unsafe impl Sync for ScriptManager {}
impl ScriptManager {
/// Create a new script manager and load scripts from the given directory.
///
/// Initializes the `eris` Lua API, loads every `*.lua` file found in
/// `scripts_dir` (falling back to the bundled default script when none are
/// found), then fires the `startup` event once.
pub fn new(scripts_dir: &str) -> Self {
    let mut manager = Self {
        lua: Mutex::new(Lua::new()),
        scripts: Vec::new(),
        hooks: HashMap::new(),
        state: Arc::new(RwLock::new(HashMap::new())),
        counters: Arc::new(RwLock::new(HashMap::new())),
    };
    // Initialize Lua environment with our API (must happen before any
    // script runs, since scripts call into `eris` at load time).
    manager.init_lua_env();
    // Load scripts from directory
    manager.load_scripts_from_dir(scripts_dir);
    // If no scripts were loaded, use the default script compiled into the
    // binary via include_str!.
    if manager.scripts.is_empty() {
        log::info!("No Lua scripts found, loading default scripts");
        manager.load_script(
            "default",
            include_str!("../../resources/default_script.lua"),
        );
    }
    // Trigger startup event so scripts can initialize their own state.
    manager.trigger_event(&EventContext {
        event_type: EventType::Startup,
        ip: None,
        path: None,
        user_agent: None,
        request_headers: None,
        content: None,
        timestamp: get_timestamp(),
        session_id: None,
    });
    manager
}
// Initialize the Lua environment: build the `eris` global table and
// install the scripting API (state, events, logging) on it.
fn init_lua_env(&self) {
    let state_clone = self.state.clone();
    let counters_clone = self.counters.clone();
    // A poisoned lock means a previous panic while holding it; leave the
    // Lua environment untouched in that case (matches prior behavior).
    let Ok(lua) = self.lua.lock() else {
        return;
    };
    // Create the eris global table for our API.
    let eris_table = lua.create_table().unwrap();
    self.register_utility_functions(&lua, &eris_table, state_clone, counters_clone);
    self.register_event_functions(&lua, &eris_table);
    self.register_logging_functions(&lua, &eris_table);
    // Publish the table as the global `eris`.
    lua.globals().set("eris", eris_table).unwrap();
}
/// Register utility functions for scripts to use.
///
/// Installs on `eris`: `set_state`/`get_state` (string key/value store),
/// `inc_counter`/`get_counter` (named i64 counters), `gen_token`
/// (random hex token) and `timestamp` (Unix seconds). State and counters
/// are shared with Rust through the `Arc<RwLock<..>>` handles.
fn register_utility_functions(
    &self,
    lua: &Lua,
    eris_table: &Table,
    state: Arc<RwLock<HashMap<String, String>>>,
    counters: Arc<RwLock<HashMap<String, i64>>>,
) {
    // Store a key-value pair in persistent state.
    let state_for_set = state.clone();
    let set_state = lua
        .create_function(move |_, (key, value): (String, String)| {
            // NOTE: unwrap panics inside Lua if the lock is poisoned.
            let mut state_map = state_for_set.write().unwrap();
            state_map.insert(key, value);
            Ok(())
        })
        .unwrap();
    eris_table.set("set_state", set_state).unwrap();
    // Get a value from persistent state (nil when absent).
    let state_for_get = state;
    let get_state = lua
        .create_function(move |_, key: String| {
            let state_map = state_for_get.read().unwrap();
            let value = state_map.get(&key).cloned();
            Ok(value)
        })
        .unwrap();
    eris_table.set("get_state", get_state).unwrap();
    // Increment a counter by `amount` (default 1); returns the new value.
    let counters_for_inc = counters.clone();
    let inc_counter = lua
        .create_function(move |_, (key, amount): (String, Option<i64>)| {
            let mut counters_map = counters_for_inc.write().unwrap();
            let counter = counters_map.entry(key).or_insert(0);
            *counter += amount.unwrap_or(1);
            Ok(*counter)
        })
        .unwrap();
    eris_table.set("inc_counter", inc_counter).unwrap();
    // Get a counter value (0 when the counter does not exist).
    let counters_for_get = counters;
    let get_counter = lua
        .create_function(move |_, key: String| {
            let counters_map = counters_for_get.read().unwrap();
            let value = counters_map.get(&key).copied().unwrap_or(0);
            Ok(value)
        })
        .unwrap();
    eris_table.set("get_counter", get_counter).unwrap();
    // Generate a random token/string: optional prefix + hex timestamp +
    // hex random u32.
    let gen_token = lua
        .create_function(move |_, prefix: Option<String>| {
            let now = get_timestamp();
            let random = rand::random::<u32>();
            let token = format!("{}{:x}{:x}", prefix.unwrap_or_default(), now, random);
            Ok(token)
        })
        .unwrap();
    eris_table.set("gen_token", gen_token).unwrap();
    // Get current Unix timestamp in seconds.
    let timestamp = lua
        .create_function(move |_, ()| Ok(get_timestamp()))
        .unwrap();
    eris_table.set("timestamp", timestamp).unwrap();
}
// Register event handling functions.
// Installs `eris.handlers` (a table of event-name -> array of functions)
// and `eris.on(event_name, handler)` which appends to that table.
fn register_event_functions(&self, lua: &Lua, eris_table: &Table) {
    // Create a table to store event handlers.
    let handlers_table = lua.create_table().unwrap();
    eris_table.set("handlers", handlers_table).unwrap();
    // Function for scripts to register event handlers.
    let on_fn = lua
        .create_function(move |lua, (event_name, handler): (String, Function)| {
            let globals = lua.globals();
            let eris: Table = globals.get("eris").unwrap();
            let handlers: Table = eris.get("handlers").unwrap();
            // Get or create the per-event array of handlers. A failed
            // `get` conversion (nil) means no handlers registered yet.
            let event_handlers: Table = if let Ok(table) = handlers.get(&*event_name) {
                table
            } else {
                let new_table = lua.create_table().unwrap();
                handlers.set(&*event_name, new_table.clone()).unwrap();
                new_table
            };
            // Append the handler (Lua arrays are 1-indexed).
            let next_index = event_handlers.len().unwrap() + 1;
            event_handlers.set(next_index, handler).unwrap();
            Ok(())
        })
        .unwrap();
    eris_table.set("on", on_fn).unwrap();
}
// Register logging functions
fn register_logging_functions(&self, lua: &Lua, eris_table: &Table) {
// Debug logging
let debug = lua
.create_function(|_, message: String| {
log::debug!("[Lua] {message}");
Ok(())
})
.unwrap();
eris_table.set("debug", debug).unwrap();
// Info logging
let info = lua
.create_function(|_, message: String| {
log::info!("[Lua] {message}");
Ok(())
})
.unwrap();
eris_table.set("info", info).unwrap();
// Warning logging
let warn = lua
.create_function(|_, message: String| {
log::warn!("[Lua] {message}");
Ok(())
})
.unwrap();
eris_table.set("warn", warn).unwrap();
// Error logging
let error = lua
.create_function(|_, message: String| {
log::error!("[Lua] {message}");
Ok(())
})
.unwrap();
eris_table.set("error", error).unwrap();
}
// Load all scripts from a directory.
// Only `*.lua` files are considered; they are loaded in sorted path order
// so the result is deterministic across runs.
fn load_scripts_from_dir(&mut self, scripts_dir: &str) {
    let script_dir = Path::new(scripts_dir);
    if !script_dir.exists() {
        log::warn!("Lua scripts directory does not exist: {scripts_dir}");
        return;
    }
    log::debug!("Loading Lua scripts from directory: {scripts_dir}");
    let Ok(entries) = fs::read_dir(script_dir) else {
        return;
    };
    // Sort by full path for a consistent loading order.
    let mut paths: Vec<_> = entries
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .collect();
    paths.sort();
    for path in paths {
        if path.extension().and_then(|ext| ext.to_str()) != Some("lua") {
            continue;
        }
        match fs::read_to_string(&path) {
            Ok(content) => {
                // Script name is the file stem ("foo" for "foo.lua").
                let script_name = path
                    .file_stem()
                    .and_then(|n| n.to_str())
                    .unwrap_or("unknown")
                    .to_string();
                log::debug!("Loading Lua script: {} ({})", script_name, path.display());
                self.load_script(&script_name, &content);
            }
            Err(_) => log::warn!("Failed to read Lua script: {}", path.display()),
        }
    }
}
// Load a single script and register its event handlers.
//
// Executes `content` in the shared Lua environment; the script registers
// handlers via `eris.on`, and afterwards every event name found in
// `eris.handlers` is recorded in `self.hooks` for this script.
fn load_script(&mut self, name: &str, content: &str) {
    // Store script info first so the script appears in metadata even if
    // execution below fails.
    self.scripts.push(ScriptInfo {
        name: name.to_string(),
        enabled: true,
    });
    // Execute the script to register its event handlers.
    if let Ok(lua) = self.lua.lock() {
        if let Err(e) = lua.load(content).set_name(name).exec() {
            log::warn!("Error loading Lua script '{name}': {e}");
            return;
        }
        // Collect registered event handlers from eris.handlers.
        let globals = lua.globals();
        let eris: Table = match globals.get("eris") {
            Ok(table) => table,
            Err(_) => return,
        };
        let handlers: Table = match eris.get("handlers") {
            Ok(table) => table,
            Err(_) => return,
        };
        // Record, per known event type, that this script has handlers.
        // (Was a labeled `loop` with a manual `.next()` — the idiomatic
        // `for` over the pairs iterator is equivalent.)
        for event_pair in handlers.pairs::<String, Table>() {
            if let Ok((event_name, _)) = event_pair {
                if let Some(event_type) = EventType::from_str(&event_name) {
                    self.hooks
                        .entry(event_type)
                        .or_default()
                        .push(name.to_string());
                }
            }
        }
        log::info!("Loaded Lua script '{name}' successfully");
    }
}
/// Check if a script is enabled.
///
/// Looks up the first script with the given name (load order) and returns
/// its `enabled` flag; unknown names count as disabled.
fn is_script_enabled(&self, name: &str) -> bool {
    for script in &self.scripts {
        if script.name == name {
            return script.enabled;
        }
    }
    false
}
/// Trigger an event, calling all registered handlers.
///
/// Returns the (possibly rewritten) content: handlers for `ResponseGen`
/// and `ResponseChunk` events may return a string that replaces
/// `ctx.content`; for all other events the original content is passed
/// through. The Lua lock is held for the whole dispatch; if it is
/// poisoned, handlers are silently skipped and the original content is
/// returned.
pub fn trigger_event(&self, ctx: &EventContext) -> Option<String> {
    // Check if we have any handlers for this event.
    if !self.hooks.contains_key(&ctx.event_type) {
        return ctx.content.clone();
    }
    // Build the event data table to pass to Lua handlers.
    let mut result = ctx.content.clone();
    if let Ok(lua) = self.lua.lock() {
        // Create the event context table (shared by all handlers below,
        // so handlers see the ORIGINAL content, not earlier rewrites).
        let event_ctx = lua.create_table().unwrap();
        // Add all the context fields; optional fields only when present.
        event_ctx.set("event", ctx.event_type.as_str()).unwrap();
        if let Some(ip) = &ctx.ip {
            event_ctx.set("ip", ip.clone()).unwrap();
        }
        if let Some(path) = &ctx.path {
            event_ctx.set("path", path.clone()).unwrap();
        }
        if let Some(ua) = &ctx.user_agent {
            event_ctx.set("user_agent", ua.clone()).unwrap();
        }
        event_ctx.set("timestamp", ctx.timestamp).unwrap();
        if let Some(sid) = &ctx.session_id {
            event_ctx.set("session_id", sid.clone()).unwrap();
        }
        // Add request headers if available.
        if let Some(headers) = &ctx.request_headers {
            let headers_table = lua.create_table().unwrap();
            for (key, value) in headers {
                headers_table
                    .set(key.to_string(), value.to_string())
                    .unwrap();
            }
            event_ctx.set("headers", headers_table).unwrap();
        }
        // Add content if available.
        if let Some(content) = &ctx.content {
            event_ctx.set("content", content.clone()).unwrap();
        }
        // Call all registered handlers for this event, script by script.
        if let Some(handler_scripts) = self.hooks.get(&ctx.event_type) {
            for script_name in handler_scripts {
                // Skip disabled scripts.
                if !self.is_script_enabled(script_name) {
                    continue;
                }
                // Get the globals and handlers table.
                let globals = lua.globals();
                let eris: Table = match globals.get("eris") {
                    Ok(table) => table,
                    Err(_) => continue,
                };
                let handlers: Table = match eris.get("handlers") {
                    Ok(table) => table,
                    Err(_) => continue,
                };
                // Get handlers for this event.
                let event_handlers: Table = match handlers.get(ctx.event_type.as_str()) {
                    Ok(table) => table,
                    Err(_) => continue,
                };
                // Call each handler; errors from individual handlers are
                // ignored so one bad script cannot break dispatch.
                for pair in event_handlers.pairs::<i64, Function>() {
                    if let Ok((_, handler)) = pair {
                        let handler_result: rlua::Result<Option<String>> =
                            handler.call((event_ctx.clone(),));
                        if let Ok(Some(new_content)) = handler_result {
                            // For response events, allow handlers to modify
                            // the content (last returning handler wins).
                            if matches!(
                                ctx.event_type,
                                EventType::ResponseGen | EventType::ResponseChunk
                            ) {
                                result = Some(new_content);
                            }
                        }
                    }
                }
            }
        }
    }
    result
}
/// Generate a deceptive response, calling all `response_gen` handlers
pub fn generate_response(
&self,
path: &str,
user_agent: &str,
ip: &str,
headers: &HashMap<String, String>,
markov_text: &str,
) -> String {
// Create event context
let ctx = EventContext {
event_type: EventType::ResponseGen,
ip: Some(ip.to_string()),
path: Some(path.to_string()),
user_agent: Some(user_agent.to_string()),
request_headers: Some(headers.clone()),
content: Some(markov_text.to_string()),
timestamp: get_timestamp(),
session_id: Some(generate_session_id(ip, user_agent)),
};
/// Trigger the event and get the modified content
self.trigger_event(&ctx).unwrap_or_else(|| {
// Fallback to maintain backward compatibility
self.expand_response(
markov_text,
"generic",
path,
&generate_session_id(ip, user_agent),
)
})
}
/// Process a chunk before sending it to the client.
///
/// Runs every `response_chunk` handler over the chunk; when no handler
/// rewrites it, the original chunk is returned unchanged.
pub fn process_chunk(&self, chunk: &str, ip: &str, session_id: &str) -> String {
    let ctx = EventContext {
        event_type: EventType::ResponseChunk,
        ip: Some(ip.to_string()),
        path: None,
        user_agent: None,
        request_headers: None,
        content: Some(chunk.to_string()),
        timestamp: get_timestamp(),
        session_id: Some(session_id.to_string()),
    };
    match self.trigger_event(&ctx) {
        Some(processed) => processed,
        None => chunk.to_string(),
    }
}
/// Called when a connection is established.
///
/// Runs every enabled script's `connection` handlers in order and returns
/// `true` when the connection should be accepted. The first handler to
/// return an explicit `false` rejects it and stops dispatch; any other
/// return value (including nil or errors) accepts. If the Lua lock is
/// poisoned, the connection is accepted by default.
pub fn on_connection(&self, ip: &str) -> bool {
    let ctx = EventContext {
        event_type: EventType::Connection,
        ip: Some(ip.to_string()),
        path: None,
        user_agent: None,
        request_headers: None,
        content: None,
        timestamp: get_timestamp(),
        session_id: None,
    };
    // If any handler returns false, reject the connection.
    let mut should_accept = true;
    if let Ok(lua) = self.lua.lock() {
        if let Some(handler_scripts) = self.hooks.get(&EventType::Connection) {
            for script_name in handler_scripts {
                // Skip disabled scripts.
                if !self.is_script_enabled(script_name) {
                    continue;
                }
                let globals = lua.globals();
                let eris: Table = match globals.get("eris") {
                    Ok(table) => table,
                    Err(_) => continue,
                };
                let handlers: Table = match eris.get("handlers") {
                    Ok(table) => table,
                    Err(_) => continue,
                };
                let event_handlers: Table = match handlers.get("connection") {
                    Ok(table) => table,
                    Err(_) => continue,
                };
                for pair in event_handlers.pairs::<i64, Function>() {
                    if let Ok((_, handler)) = pair {
                        // Each handler call gets a fresh context table.
                        let event_ctx = create_event_context(&lua, &ctx);
                        if let Ok(result) = handler.call::<_, Value>((event_ctx,)) {
                            // Only an explicit boolean false rejects.
                            if result == Value::Boolean(false) {
                                should_accept = false;
                                break;
                            }
                        }
                    }
                }
                if !should_accept {
                    break;
                }
            }
        }
    }
    should_accept
}
/// Called when deciding whether to block an IP.
///
/// Runs `block_ip` handlers; the first handler that returns an explicit
/// boolean decides. Handlers returning nil (or erroring) abstain, and if
/// every handler abstains the built-in policy applies: block once
/// `hit_count` reaches 3.
pub fn should_block_ip(&self, ip: &str, hit_count: u32) -> bool {
    let ctx = EventContext {
        event_type: EventType::BlockIP,
        ip: Some(ip.to_string()),
        path: None,
        user_agent: None,
        request_headers: None,
        content: None,
        timestamp: get_timestamp(),
        session_id: None,
    };
    // We should default to not modifying the blocking decision:
    // None = no script took a position.
    let mut should_block = None;
    if let Ok(lua) = self.lua.lock() {
        if let Some(handler_scripts) = self.hooks.get(&EventType::BlockIP) {
            for script_name in handler_scripts {
                // Skip disabled scripts.
                if !self.is_script_enabled(script_name) {
                    continue;
                }
                let globals = lua.globals();
                let eris: Table = match globals.get("eris") {
                    Ok(table) => table,
                    Err(_) => continue,
                };
                let handlers: Table = match eris.get("handlers") {
                    Ok(table) => table,
                    Err(_) => continue,
                };
                let event_handlers: Table = match handlers.get("block_ip") {
                    Ok(table) => table,
                    Err(_) => continue,
                };
                for pair in event_handlers.pairs::<i64, Function>() {
                    if let Ok((_, handler)) = pair {
                        let event_ctx = create_event_context(&lua, &ctx);
                        // Add hit count for the block_ip event.
                        event_ctx.set("hit_count", hit_count).unwrap();
                        if let Ok(result) = handler.call::<_, Value>((event_ctx,)) {
                            // First explicit boolean wins; nil abstains.
                            if let Value::Boolean(block) = result {
                                should_block = Some(block);
                                break;
                            }
                        }
                    }
                }
                if should_block.is_some() {
                    break;
                }
            }
        }
    }
    // Return the script's decision, or default to the system behavior
    // (block at 3 or more tarpit hits).
    should_block.unwrap_or(hit_count >= 3)
}
// Maintains backward compatibility with the old API.
// XXX: expand_response was never a great interface and should probably be
// removed in the future.
//
// Calls a global Lua `enhance_response(text, type, path, token)` function
// when one is defined; otherwise (or when the Lua lock is unavailable)
// appends a token comment to the text.
pub fn expand_response(
    &self,
    text: &str,
    response_type: &str,
    path: &str,
    token: &str,
) -> String {
    let Ok(lua) = self.lua.lock() else {
        return format!("{text}\n<!-- Token: {token} -->");
    };
    let globals = lua.globals();
    let Ok(enhance_func) = globals.get::<_, Function>("enhance_response") else {
        return format!("{text}\n<!-- Token: {token} -->");
    };
    enhance_func
        .call::<_, String>((text, response_type, path, token))
        .unwrap_or_else(|e| {
            log::warn!("Error calling Lua function enhance_response: {e}");
            format!("{text}\n<!-- Error calling Lua enhance_response -->")
        })
}
}
/// Get the current Unix timestamp in whole seconds.
///
/// Returns 0 if the system clock reports a time before the Unix epoch.
fn get_timestamp() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs())
}
/// Create a unique session ID for tracking a connection.
///
/// Hashes "ip_useragent_timestamp" with the std `DefaultHasher` (instead
/// of pulling in an external hash crate) and appends a random suffix.
fn generate_session_id(ip: &str, user_agent: &str) -> String {
    let mut hasher = DefaultHasher::new();
    format!("{ip}_{user_agent}_{}", get_timestamp()).hash(&mut hasher);
    let digest = hasher.finish();
    let suffix = rand::random::<u32>();
    format!("SID_{digest:x}_{suffix:x}")
}
// Create an event context table in Lua mirroring an EventContext.
// Always sets `event` and `timestamp`; optional fields are copied only
// when present.
fn create_event_context<'a>(lua: &'a Lua, event_ctx: &EventContext) -> Table<'a> {
    let table = lua.create_table().unwrap();
    table.set("event", event_ctx.event_type.as_str()).unwrap();
    table.set("timestamp", event_ctx.timestamp).unwrap();
    let optional_fields = [
        ("ip", &event_ctx.ip),
        ("path", &event_ctx.path),
        ("user_agent", &event_ctx.user_agent),
        ("session_id", &event_ctx.session_id),
        ("content", &event_ctx.content),
    ];
    for (key, value) in optional_fields {
        if let Some(v) = value {
            table.set(key, v.clone()).unwrap();
        }
    }
    table
}
#[cfg(test)]
mod tests {
    //! End-to-end tests for ScriptManager: each test writes a real .lua
    //! script into a temp directory and exercises the public API.
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    // Scripts registering handlers should populate `hooks` per event type.
    #[test]
    fn test_event_registration() {
        let temp_dir = TempDir::new().unwrap();
        let script_path = temp_dir.path().join("test_events.lua");
        let script_content = r#"
        -- Example script with event handlers
        eris.info("Registering event handlers")
        -- Connection event handler
        eris.on("connection", function(ctx)
            eris.debug("Connection from " .. ctx.ip)
            return true -- accept the connection
        end)
        -- Response generation handler
        eris.on("response_gen", function(ctx)
            eris.debug("Generating response for " .. ctx.path)
            return ctx.content .. "<!-- Enhanced by Lua -->"
        end)
        "#;
        fs::write(&script_path, script_content).unwrap();
        let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap());
        // Verify hooks were registered (and only for handled events).
        assert!(script_manager.hooks.contains_key(&EventType::Connection));
        assert!(script_manager.hooks.contains_key(&EventType::ResponseGen));
        assert!(!script_manager.hooks.contains_key(&EventType::BlockIP));
    }

    // response_gen handlers can rewrite the generated content.
    #[test]
    fn test_generate_response() {
        let temp_dir = TempDir::new().unwrap();
        let script_path = temp_dir.path().join("response_test.lua");
        let script_content = r#"
        eris.on("response_gen", function(ctx)
            return ctx.content .. " - Modified by " .. ctx.user_agent
        end)
        "#;
        fs::write(&script_path, script_content).unwrap();
        let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap());
        let headers = HashMap::new();
        let result = script_manager.generate_response(
            "/test/path",
            "TestBot",
            "127.0.0.1",
            &headers,
            "Original content",
        );
        assert!(result.contains("Original content"));
        assert!(result.contains("Modified by TestBot"));
    }

    // response_chunk handlers can rewrite individual chunks.
    #[test]
    fn test_process_chunk() {
        let temp_dir = TempDir::new().unwrap();
        let script_path = temp_dir.path().join("chunk_test.lua");
        let script_content = r#"
        eris.on("response_chunk", function(ctx)
            return ctx.content:gsub("secret", "REDACTED")
        end)
        "#;
        fs::write(&script_path, script_content).unwrap();
        let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap());
        let result = script_manager.process_chunk(
            "This contains a secret password",
            "127.0.0.1",
            "test_session",
        );
        assert!(result.contains("This contains a REDACTED password"));
    }

    // block_ip handlers may force, veto, or defer the blocking decision.
    #[test]
    fn test_should_block_ip() {
        let temp_dir = TempDir::new().unwrap();
        let script_path = temp_dir.path().join("block_test.lua");
        let script_content = r#"
        eris.on("block_ip", function(ctx)
            -- Block any IP with "192.168.1" prefix regardless of hit count
            if string.match(ctx.ip, "^192%.168%.1%.") then
                return true
            end
            -- Don't block IPs with "10.0" prefix even if they hit the threshold
            if string.match(ctx.ip, "^10%.0%.") then
                return false
            end
            -- Default behavior for other IPs (nil = use system default)
            return nil
        end)
        "#;
        fs::write(&script_path, script_content).unwrap();
        let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap());
        // Should be blocked based on IP pattern.
        assert!(script_manager.should_block_ip("192.168.1.50", 1));
        // Should not be blocked despite high hit count.
        assert!(!script_manager.should_block_ip("10.0.0.5", 10));
        // Should use default behavior (block if >= 3 hits).
        assert!(!script_manager.should_block_ip("172.16.0.1", 2));
        assert!(script_manager.should_block_ip("172.16.0.1", 3));
    }

    // State and counters persist across events and surface in responses.
    #[test]
    fn test_state_and_counters() {
        let temp_dir = TempDir::new().unwrap();
        let script_path = temp_dir.path().join("state_test.lua");
        let script_content = r#"
        eris.on("startup", function(ctx)
            eris.set_state("test_key", "test_value")
            eris.inc_counter("visits", 0)
        end)
        eris.on("connection", function(ctx)
            local count = eris.inc_counter("visits")
            eris.debug("Visit count: " .. count)
            -- Store last visitor
            eris.set_state("last_visitor", ctx.ip)
            return true
        end)
        eris.on("response_gen", function(ctx)
            local last_visitor = eris.get_state("last_visitor") or "unknown"
            local visits = eris.get_counter("visits")
            return ctx.content .. "<!-- Last visitor: " .. last_visitor ..
                ", Total visits: " .. visits .. " -->"
        end)
        "#;
        fs::write(&script_path, script_content).unwrap();
        let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap());
        // Simulate connections.
        script_manager.on_connection("192.168.1.100");
        script_manager.on_connection("10.0.0.50");
        // Check response includes state.
        let headers = HashMap::new();
        let result = script_manager.generate_response(
            "/test",
            "test-agent",
            "8.8.8.8",
            &headers,
            "Response",
        );
        assert!(result.contains("Last visitor: 10.0.0.50"));
        assert!(result.contains("Total visits: 2"));
    }
}

File diff suppressed because it is too large Load diff

View file

@ -103,7 +103,7 @@ impl MarkovGenerator {
let path = Path::new(corpus_dir);
if path.exists() && path.is_dir() {
if let Ok(entries) = fs::read_dir(path) {
for entry in entries {
entries.for_each(|entry| {
if let Ok(entry) = entry {
let file_path = entry.path();
if let Some(file_name) = file_path.file_stem() {
@ -120,7 +120,7 @@ impl MarkovGenerator {
}
}
}
}
});
}
}

View file

@ -24,6 +24,11 @@ lazy_static! {
register_counter_vec!("eris_path_hits_total", "Hits by path", &["path"]).unwrap();
pub static ref UA_HITS: CounterVec =
register_counter_vec!("eris_ua_hits_total", "Hits by user agent", &["user_agent"]).unwrap();
pub static ref RATE_LIMITED_CONNECTIONS: Counter = register_counter!(
"eris_rate_limited_total",
"Number of connections rejected due to rate limiting"
)
.unwrap();
}
// Prometheus metrics endpoint
@ -76,6 +81,7 @@ mod tests {
UA_HITS.with_label_values(&["TestBot/1.0"]).inc();
BLOCKED_IPS.set(5.0);
ACTIVE_CONNECTIONS.set(3.0);
RATE_LIMITED_CONNECTIONS.inc();
// Create test app
let app =
@ -96,6 +102,7 @@ mod tests {
assert!(body_str.contains("eris_ua_hits_total"));
assert!(body_str.contains("eris_blocked_ips"));
assert!(body_str.contains("eris_active_connections"));
assert!(body_str.contains("eris_rate_limited_total"));
}
#[actix_web::test]

504
src/network/mod.rs Normal file
View file

@ -0,0 +1,504 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::process::Command;
use tokio::sync::RwLock;
use tokio::time::sleep;
use crate::config::Config;
use crate::lua::{EventContext, EventType, ScriptManager};
use crate::markov::MarkovGenerator;
use crate::metrics::{
ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, RATE_LIMITED_CONNECTIONS, UA_HITS,
};
use crate::state::BotState;
use crate::utils::{
choose_response_type, extract_all_headers, extract_header_value, extract_path_from_request,
find_header_end, generate_session_id, get_timestamp,
};
mod rate_limiter;
use rate_limiter::RateLimiter;
// Global rate limiter instance.
// Default is 30 connections per IP in a 60 second window.
// XXX: This might add overhead on top of the proxy; e.g. NGINX already
// implements rate limiting. Though we have no way of knowing whether the
// middleman we are handing connections to (from the same middleman in
// some cases) has rate limiting of its own.
// NOTE(review): std::sync::LazyLock could replace lazy_static once the
// crate's MSRV allows it.
lazy_static::lazy_static! {
    // Arguments are (window seconds, max connections) per the comment
    // above — confirm against RateLimiter::new in rate_limiter.rs.
    static ref RATE_LIMITER: RateLimiter = RateLimiter::new(60, 30);
}
// Main connection handler.
// Decides whether to tarpit or proxy an incoming connection. Check order:
// blocked IP -> rate limit -> Lua `connection` hook -> whitelist fast
// path -> request read (3 s timeout) -> tarpit-or-proxy decision.
pub async fn handle_connection(
    mut stream: TcpStream,
    config: Arc<Config>,
    state: Arc<RwLock<BotState>>,
    markov_generator: Arc<MarkovGenerator>,
    script_manager: Arc<ScriptManager>,
) {
    // Get peer information; without an address no policy can be applied.
    let peer_addr: IpAddr = match stream.peer_addr() {
        Ok(addr) => addr.ip(),
        Err(e) => {
            log::debug!("Failed to get peer address: {e}");
            return;
        }
    };
    // Check for blocked IPs to avoid any processing.
    if state.read().await.blocked.contains(&peer_addr) {
        log::debug!("Rejected connection from blocked IP: {peer_addr}");
        let _ = stream.shutdown().await;
        return;
    }
    // Apply rate limiting before any further processing.
    if config.rate_limit_enabled && !RATE_LIMITER.check_rate_limit(peer_addr).await {
        log::info!("Rate limited connection from {peer_addr}");
        RATE_LIMITED_CONNECTIONS.inc();
        // Optionally, consider the IP for a temporary block if it is
        // constantly hitting the rate limit.
        let connection_count = RATE_LIMITER.get_connection_count(&peer_addr);
        if connection_count > config.rate_limit_block_threshold {
            log::warn!(
                "IP {peer_addr} exceeding rate limit with {connection_count} connection attempts, considering for blocking"
            );
            // Trigger a blocked event for Lua scripts.
            let rate_limit_ctx = EventContext {
                event_type: EventType::BlockIP,
                ip: Some(peer_addr.to_string()),
                path: None,
                user_agent: None,
                request_headers: None,
                content: None,
                timestamp: get_timestamp(),
                session_id: None,
            };
            script_manager.trigger_event(&rate_limit_ctx);
        }
        // Either send a slow response or just close the connection.
        if config.rate_limit_slow_response {
            // Send a simple 429 Too Many Requests response. If the bots
            // actually respected HTTP error codes, the internet would be a
            // mildly better place.
            let response = "HTTP/1.1 429 Too Many Requests\r\n\
                            Content-Type: text/plain\r\n\
                            Retry-After: 60\r\n\
                            Connection: close\r\n\
                            \r\n\
                            Rate limit exceeded. Please try again later.";
            let _ = stream.write_all(response.as_bytes()).await;
            let _ = stream.flush().await;
        }
        let _ = stream.shutdown().await;
        return;
    }
    // Check if Lua scripts allow this connection.
    if !script_manager.on_connection(&peer_addr.to_string()) {
        log::debug!("Connection rejected by Lua script: {peer_addr}");
        let _ = stream.shutdown().await;
        return;
    }
    // Pre-check for whitelisted IPs to bypass heavy processing.
    // Unparseable whitelist entries are ignored.
    let whitelisted = config.whitelist_networks.iter().any(|network_str| {
        network_str
            .parse::<ipnetwork::IpNetwork>()
            .is_ok_and(|network| network.contains(peer_addr))
    });
    // Read buffer.
    let mut buffer = vec![0; 8192];
    let mut request_data = Vec::with_capacity(8192);
    let mut header_end_pos = 0;
    // Read with timeout to prevent hanging resource load ops.
    let read_fut = async {
        loop {
            match stream.read(&mut buffer).await {
                Ok(0) => break,
                Ok(n) => {
                    let new_data = &buffer[..n];
                    request_data.extend_from_slice(new_data);
                    // Stop as soon as the end of headers is seen.
                    if header_end_pos == 0 {
                        if let Some(pos) = find_header_end(&request_data) {
                            header_end_pos = pos;
                            break;
                        }
                    }
                    // Avoid excessive buffering (32 KiB cap).
                    if request_data.len() > 32768 {
                        break;
                    }
                }
                Err(e) => {
                    log::debug!("Error reading from stream: {e}");
                    break;
                }
            }
        }
    };
    let timeout_fut = sleep(Duration::from_secs(3));
    tokio::select! {
        () = read_fut => {},
        () = timeout_fut => {
            log::debug!("Connection timeout from: {peer_addr}");
            let _ = stream.shutdown().await;
            return;
        }
    }
    // Fast path for whitelisted IPs. Skip full parsing and speed up
    // "approved" connections automatically.
    if whitelisted {
        log::debug!("Whitelisted IP {peer_addr} - using fast proxy path");
        proxy_fast_path(stream, request_data, &config.backend_addr).await;
        return;
    }
    // Parse minimally to extract the path.
    let Some(path) = extract_path_from_request(&request_data) else {
        log::debug!("Invalid request from {peer_addr}");
        let _ = stream.shutdown().await;
        return;
    };
    // Extract request headers for Lua scripts.
    let headers = extract_all_headers(&request_data);
    // Extract user agent for logging and decision making.
    let user_agent =
        extract_header_value(&request_data, "user-agent").unwrap_or_else(|| "unknown".to_string());
    // Trigger request event for Lua scripts.
    let request_ctx = EventContext {
        event_type: EventType::Request,
        ip: Some(peer_addr.to_string()),
        path: Some(path.to_string()),
        user_agent: Some(user_agent.clone()),
        request_headers: Some(headers.clone()),
        content: None,
        timestamp: get_timestamp(),
        session_id: Some(generate_session_id(&peer_addr.to_string(), &user_agent)),
    };
    script_manager.trigger_event(&request_ctx);
    // Check if this request matches our tarpit patterns.
    let should_tarpit = crate::config::should_tarpit(path, &peer_addr, &config);
    if should_tarpit {
        log::info!("Tarpit triggered: {path} from {peer_addr} (UA: {user_agent})");
        // Update metrics.
        HITS_COUNTER.inc();
        PATH_HITS.with_label_values(&[path]).inc();
        UA_HITS.with_label_values(&[&user_agent]).inc();
        // Update state and check for blocking threshold.
        {
            let mut state = state.write().await;
            state.active_connections.insert(peer_addr);
            ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64);
            *state.hits.entry(peer_addr).or_insert(0) += 1;
            let hit_count = state.hits[&peer_addr];
            // Use Lua to decide whether to block this IP.
            let should_block = script_manager.should_block_ip(&peer_addr.to_string(), hit_count);
            // Block IPs that hit tarpits too many times.
            if should_block && !state.blocked.contains(&peer_addr) {
                log::info!("Blocking IP {peer_addr} after {hit_count} hits");
                state.blocked.insert(peer_addr);
                BLOCKED_IPS.set(state.blocked.len() as f64);
                state.save_to_disk();
                // Do firewall blocking in the background so the tarpit
                // response is not delayed by the nft call.
                let peer_addr_str = peer_addr.to_string();
                tokio::spawn(async move {
                    log::debug!("Adding IP {peer_addr_str} to firewall blacklist");
                    match Command::new("nft")
                        .args([
                            "add",
                            "element",
                            "inet",
                            "filter",
                            "eris_blacklist",
                            "{",
                            &peer_addr_str,
                            "}",
                        ])
                        .output()
                        .await
                    {
                        Ok(output) => {
                            if !output.status.success() {
                                log::warn!(
                                    "Failed to add IP {} to firewall: {}",
                                    peer_addr_str,
                                    String::from_utf8_lossy(&output.stderr)
                                );
                            }
                        }
                        Err(e) => {
                            log::warn!("Failed to execute nft command: {e}");
                        }
                    }
                });
            }
        }
        // Generate a deceptive response using Markov chains and Lua.
        let response = generate_deceptive_response(
            path,
            &user_agent,
            &peer_addr,
            &headers,
            &markov_generator,
            &script_manager,
        )
        .await;
        // Generate a session ID for tracking this tarpit session.
        let session_id = generate_session_id(&peer_addr.to_string(), &user_agent);
        // Send the response with the tarpit delay strategy: drip-feed a
        // few characters at a time with randomized delays.
        let min_delay = config.min_delay;
        let max_delay = config.max_delay;
        let max_tarpit_time = config.max_tarpit_time;
        let start_time = Instant::now();
        // Randomly swap ~10% of adjacent characters to vary the payload.
        let mut chars = response.chars().collect::<Vec<_>>();
        for i in (0..chars.len()).rev() {
            if i > 0 && rand::random::<f32>() < 0.1 {
                chars.swap(i, i - 1);
            }
        }
        log::debug!(
            "Starting tarpit for {} with {} chars, min_delay={}ms, max_delay={}ms",
            peer_addr,
            chars.len(),
            min_delay,
            max_delay
        );
        let mut position = 0;
        let mut chunks_sent = 0;
        let mut total_delay = 0;
        while position < chars.len() {
            // Check if we've exceeded maximum tarpit time.
            let elapsed_secs = start_time.elapsed().as_secs();
            if elapsed_secs > max_tarpit_time {
                log::info!(
                    "Tarpit maximum time ({max_tarpit_time} sec) reached for {peer_addr}"
                );
                break;
            }
            // Decide how many chars to send in this chunk (usually 1,
            // sometimes 1-3).
            let chunk_size = if rand::random::<f32>() < 0.9 {
                1
            } else {
                (rand::random::<f32>() * 3.0).floor() as usize + 1
            };
            let end = (position + chunk_size).min(chars.len());
            let chunk: String = chars[position..end].iter().collect();
            // Process chunk through Lua before sending.
            let processed_chunk =
                script_manager.process_chunk(&chunk, &peer_addr.to_string(), &session_id);
            // Try to write the processed chunk; bail out on a dead peer.
            if stream.write_all(processed_chunk.as_bytes()).await.is_err() {
                log::debug!("Connection closed by client during tarpit: {peer_addr}");
                break;
            }
            if stream.flush().await.is_err() {
                log::debug!("Failed to flush stream during tarpit: {peer_addr}");
                break;
            }
            position = end;
            chunks_sent += 1;
            // Apply random delay between min and max configured values.
            let delay_ms =
                (rand::random::<f32>() * (max_delay - min_delay) as f32) as u64 + min_delay;
            total_delay += delay_ms;
            sleep(Duration::from_millis(delay_ms)).await;
        }
        // Guard against division by zero when the generated response was
        // empty (previously `position * 100 / chars.len()` could panic).
        let percent_sent = if chars.is_empty() {
            100
        } else {
            position * 100 / chars.len()
        };
        log::debug!(
            "Tarpit stats for {}: sent {} chunks, {}% of data, total delay {}ms over {}s",
            peer_addr,
            chunks_sent,
            percent_sent,
            total_delay,
            start_time.elapsed().as_secs()
        );
        let disconnection_ctx = EventContext {
            event_type: EventType::Disconnection,
            ip: Some(peer_addr.to_string()),
            path: None,
            user_agent: None,
            request_headers: None,
            content: None,
            timestamp: get_timestamp(),
            session_id: Some(session_id),
        };
        script_manager.trigger_event(&disconnection_ctx);
        // Best-effort state cleanup; skipped if the lock is contended.
        if let Ok(mut state) = state.try_write() {
            state.active_connections.remove(&peer_addr);
            ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64);
        }
        let _ = stream.shutdown().await;
    } else {
        log::debug!("Proxying request: {path} from {peer_addr}");
        proxy_fast_path(stream, request_data, &config.backend_addr).await;
    }
}
// Forward a legitimate request to the real backend server.
//
// The request bytes already read from the client are replayed verbatim to the
// backend (no re-parsing), then both directions are pumped concurrently until
// EOF or error. When one direction finishes, the write half toward the other
// peer is shut down so that peer observes EOF instead of a hung connection
// (half-close propagation — the original code only dropped the halves at the
// end, which kept backends waiting for more request data).
pub async fn proxy_fast_path(
    mut client_stream: TcpStream,
    request_data: Vec<u8>,
    backend_addr: &str,
) {
    // Connect to backend server
    let mut server_stream = match TcpStream::connect(backend_addr).await {
        Ok(stream) => stream,
        Err(e) => {
            log::warn!("Failed to connect to backend {backend_addr}: {e}");
            let _ = client_stream.shutdown().await;
            return;
        }
    };
    // Set TCP_NODELAY for both streams before splitting them; proxied
    // request/response exchanges are latency-sensitive and should not be
    // batched by Nagle's algorithm.
    if let Err(e) = client_stream.set_nodelay(true) {
        log::debug!("Failed to set TCP_NODELAY on client stream: {e}");
    }
    if let Err(e) = server_stream.set_nodelay(true) {
        log::debug!("Failed to set TCP_NODELAY on server stream: {e}");
    }
    // Forward the original request bytes directly without parsing
    if server_stream.write_all(&request_data).await.is_err() {
        log::debug!("Failed to write request to backend server");
        let _ = client_stream.shutdown().await;
        return;
    }
    // Now split the streams for concurrent reading/writing
    let (mut client_read, mut client_write) = client_stream.split();
    let (mut server_read, mut server_write) = server_stream.split();
    // 32KB buffer
    let buf_size = 32768;
    // Client -> Server
    let client_to_server = async {
        let mut buf = vec![0; buf_size];
        let mut bytes_forwarded = 0;
        loop {
            match client_read.read(&mut buf).await {
                Ok(0) => break,
                Ok(n) => {
                    bytes_forwarded += n;
                    if server_write.write_all(&buf[..n]).await.is_err() {
                        break;
                    }
                }
                Err(_) => break,
            }
        }
        // Ensure everything is sent, then signal EOF to the backend so it
        // does not wait forever for more request data.
        let _ = server_write.flush().await;
        let _ = server_write.shutdown().await;
        log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes");
    };
    // Server -> Client
    let server_to_client = async {
        let mut buf = vec![0; buf_size];
        let mut bytes_forwarded = 0;
        loop {
            match server_read.read(&mut buf).await {
                Ok(0) => break,
                Ok(n) => {
                    bytes_forwarded += n;
                    if client_write.write_all(&buf[..n]).await.is_err() {
                        break;
                    }
                }
                Err(_) => break,
            }
        }
        // Ensure everything is sent, then signal EOF to the client.
        let _ = client_write.flush().await;
        let _ = client_write.shutdown().await;
        log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes");
    };
    // Run both directions concurrently
    tokio::join!(client_to_server, server_to_client);
    log::debug!("Fast proxy connection completed");
}
// Generate a deceptive HTTP response that appears legitimate.
//
// Picks a response category from the requested path, seeds it with
// Markov-generated filler text, then lets the Lua scripting layer decorate
// the result (honeytokens and other deceptive content).
pub async fn generate_deceptive_response(
    path: &str,
    user_agent: &str,
    peer_addr: &IpAddr,
    headers: &HashMap<String, String>,
    markov: &MarkovGenerator,
    script_manager: &ScriptManager,
) -> String {
    let category = choose_response_type(path);
    let base_text = markov.generate(category, 30);
    let client_ip = peer_addr.to_string();
    script_manager.generate_response(path, user_agent, &client_ip, headers, &base_text)
}

View file

@ -0,0 +1,85 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
// Sliding-window, per-IP connection rate limiter.
pub struct RateLimiter {
    // Timestamps of connections seen per IP within the current window.
    connections: Arc<Mutex<HashMap<IpAddr, Vec<Instant>>>>,
    // Sliding window length, in seconds.
    window_seconds: u64,
    // Maximum connections allowed per IP inside one window.
    max_connections: usize,
    // How often the whole map is swept for expired entries.
    cleanup_interval: Duration,
    // When the last full sweep ran. Wrapped in a std mutex so it can be
    // advanced through &self; it is only locked while the tokio lock on
    // `connections` is already held, and never across an .await.
    last_cleanup: std::sync::Mutex<Instant>,
}
impl RateLimiter {
    // Build a limiter allowing `max_connections` per IP per `window_seconds`.
    pub fn new(window_seconds: u64, max_connections: usize) -> Self {
        Self {
            connections: Arc::new(Mutex::new(HashMap::new())),
            window_seconds,
            max_connections,
            cleanup_interval: Duration::from_secs(60),
            last_cleanup: std::sync::Mutex::new(Instant::now()),
        }
    }
    // Record a connection attempt from `ip`. Returns true (and counts the
    // attempt) when it is within the limit, false when the limit is exceeded.
    pub async fn check_rate_limit(&self, ip: IpAddr) -> bool {
        let now = Instant::now();
        let window = Duration::from_secs(self.window_seconds);
        let mut connections = self.connections.lock().await;
        // Periodically clean up old entries across all IPs.
        // BUG FIX: the previous version never advanced `last_cleanup` (and
        // could not, through &self), so once the first interval elapsed the
        // full-map sweep ran on every single call.
        {
            let mut last_cleanup = self
                .last_cleanup
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            if now.duration_since(*last_cleanup) > self.cleanup_interval {
                self.cleanup_old_entries(&mut connections, now, window);
                *last_cleanup = now;
            }
        }
        // Clean up old entries for this specific IP
        if let Some(times) = connections.get_mut(&ip) {
            times.retain(|time| now.duration_since(*time) < window);
            // Check if rate limit exceeded
            if times.len() >= self.max_connections {
                log::debug!("Rate limit exceeded for IP: {ip}");
                return false;
            }
            // Add new connection time
            times.push(now);
        } else {
            connections.insert(ip, vec![now]);
        }
        true
    }
    // Drop expired timestamps for every IP, then remove IPs left empty so
    // the map cannot grow without bound.
    fn cleanup_old_entries(
        &self,
        connections: &mut HashMap<IpAddr, Vec<Instant>>,
        now: Instant,
        window: Duration,
    ) {
        let mut empty_keys = Vec::new();
        for (ip, times) in connections.iter_mut() {
            times.retain(|time| now.duration_since(*time) < window);
            if times.is_empty() {
                empty_keys.push(*ip);
            }
        }
        // Remove empty entries
        for ip in empty_keys {
            connections.remove(&ip);
        }
    }
    // Best-effort count of tracked connection timestamps for `ip`.
    // Returns 0 when the IP is unknown or the lock is momentarily contended
    // (try_lock fails) — callers must tolerate an occasional stale 0.
    pub fn get_connection_count(&self, ip: &IpAddr) -> usize {
        if let Ok(connections) = self.connections.try_lock() {
            if let Some(times) = connections.get(ip) {
                return times.len();
            }
        }
        0
    }
}

166
src/state.rs Normal file
View file

@ -0,0 +1,166 @@
use std::collections::{HashMap, HashSet};
use std::fs;
use std::io::Write;
use std::net::IpAddr;
use crate::metrics::BLOCKED_IPS;
// State of bots/IPs hitting the honeypot
#[derive(Clone, Debug)]
pub struct BotState {
    // Number of honeypot hits recorded per client IP (cached to
    // `{cache_dir}/hit_counters.json` by save_to_disk).
    pub hits: HashMap<IpAddr, u32>,
    // IPs that have been blocked (persisted to `{data_dir}/blocked_ips.txt`).
    pub blocked: HashSet<IpAddr>,
    // IPs with a connection currently in flight; entries are removed on
    // disconnect by the connection handler.
    pub active_connections: HashSet<IpAddr>,
    // Directory for durable state (the blocked-IP list).
    pub data_dir: String,
    // Directory for disposable state (the hit-counter cache).
    pub cache_dir: String,
}
impl BotState {
    // Fresh, empty state rooted at the given storage directories.
    pub fn new(data_dir: &str, cache_dir: &str) -> Self {
        Self {
            hits: HashMap::new(),
            blocked: HashSet::new(),
            active_connections: HashSet::new(),
            data_dir: data_dir.to_string(),
            cache_dir: cache_dir.to_string(),
        }
    }
    // Load previously persisted state from disk. Missing or unparsable files
    // are treated as empty state rather than errors.
    pub fn load_from_disk(data_dir: &str, cache_dir: &str) -> Self {
        let mut state = Self::new(data_dir, cache_dir);
        // Blocked IPs: one address per line; unparsable lines are skipped.
        let blocked_ips_file = format!("{data_dir}/blocked_ips.txt");
        match fs::read_to_string(&blocked_ips_file) {
            Ok(content) => {
                let mut loaded = 0;
                for ip in content.lines().filter_map(|line| line.parse::<IpAddr>().ok()) {
                    state.blocked.insert(ip);
                    loaded += 1;
                }
                log::info!("Loaded {loaded} blocked IPs from {blocked_ips_file}");
            }
            Err(_) => {
                log::info!("No blocked IPs file found at {blocked_ips_file}");
            }
        }
        // Hit counters: best-effort JSON cache; ignored when missing/corrupt.
        let hit_cache_file = format!("{cache_dir}/hit_counters.json");
        if let Ok(content) = fs::read_to_string(&hit_cache_file) {
            if let Ok(hit_map) = serde_json::from_str::<HashMap<String, u32>>(&content) {
                for (ip_str, count) in hit_map {
                    if let Ok(ip) = ip_str.parse::<IpAddr>() {
                        state.hits.insert(ip, count);
                    }
                }
                log::info!("Loaded hit counters for {} IPs", state.hits.len());
            }
        }
        // Keep the Prometheus gauge in sync with the restored block list.
        BLOCKED_IPS.set(state.blocked.len() as f64);
        state
    }
    // Persist state to disk so it survives restarts. Failures are logged,
    // never propagated.
    pub fn save_to_disk(&self) {
        // Blocked IPs go to the durable data directory.
        if let Err(e) = fs::create_dir_all(&self.data_dir) {
            log::error!("Failed to create data directory: {e}");
            return;
        }
        let blocked_ips_file = format!("{}/blocked_ips.txt", self.data_dir);
        match fs::File::create(&blocked_ips_file) {
            Ok(mut file) => {
                // Count only the lines that were actually written.
                let count = self
                    .blocked
                    .iter()
                    .filter(|ip| writeln!(file, "{ip}").is_ok())
                    .count();
                log::info!("Saved {count} blocked IPs to {blocked_ips_file}");
            }
            Err(e) => {
                log::error!("Failed to create blocked IPs file: {e}");
            }
        }
        // Hit counters go to the disposable cache directory as JSON.
        if let Err(e) = fs::create_dir_all(&self.cache_dir) {
            log::error!("Failed to create cache directory: {e}");
            return;
        }
        let hit_cache_file = format!("{}/hit_counters.json", self.cache_dir);
        let hit_map: HashMap<String, u32> = self
            .hits
            .iter()
            .map(|(ip, count)| (ip.to_string(), *count))
            .collect();
        match fs::File::create(&hit_cache_file) {
            Ok(file) => match serde_json::to_writer(file, &hit_map) {
                Ok(()) => {
                    log::debug!("Saved hit counters for {} IPs to cache", hit_map.len());
                }
                Err(e) => {
                    log::error!("Failed to write hit counters to cache: {e}");
                }
            },
            Err(e) => {
                log::error!("Failed to create hit counter cache file: {e}");
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use std::sync::Arc;
    use tokio::sync::RwLock;
    #[tokio::test]
    async fn test_bot_state() {
        let ip1 = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
        let ip2 = IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8));
        let state = Arc::new(RwLock::new(BotState::new(
            "/tmp/eris_test",
            "/tmp/eris_test_cache",
        )));
        // Hit counters accumulate per IP.
        {
            let mut guard = state.write().await;
            for ip in [ip1, ip1, ip2] {
                *guard.hits.entry(ip).or_insert(0) += 1;
            }
            assert_eq!(guard.hits.get(&ip1), Some(&2));
            assert_eq!(guard.hits.get(&ip2), Some(&1));
        }
        // Blocking is tracked per IP.
        {
            let mut guard = state.write().await;
            guard.blocked.insert(ip1);
            assert!(guard.blocked.contains(&ip1));
            assert!(!guard.blocked.contains(&ip2));
        }
        // Active connections can be added and removed independently.
        {
            let mut guard = state.write().await;
            guard.active_connections.insert(ip1);
            guard.active_connections.insert(ip2);
            assert_eq!(guard.active_connections.len(), 2);
            guard.active_connections.remove(&ip1);
            assert_eq!(guard.active_connections.len(), 1);
            assert!(!guard.active_connections.contains(&ip1));
            assert!(guard.active_connections.contains(&ip2));
        }
    }
}

168
src/utils.rs Normal file
View file

@ -0,0 +1,168 @@
use std::collections::HashMap;
use std::hash::Hasher;
// Find end of HTTP headers.
//
// Returns the index one past the "\r\n\r\n" terminator, or `None` when the
// header section is incomplete (terminator absent or input too short).
pub fn find_header_end(data: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    if data.len() < TERMINATOR.len() {
        return None;
    }
    (0..=data.len() - TERMINATOR.len())
        .find(|&idx| &data[idx..idx + TERMINATOR.len()] == TERMINATOR)
        .map(|idx| idx + TERMINATOR.len())
}
// Extract path from raw request data.
//
// Expects a request line of the form "METHOD PATH HTTP/x.y"; returns the
// PATH token, or `None` when the first line is empty, has fewer than three
// space-separated tokens, lacks the "HTTP/" version marker, or the path is
// not valid UTF-8.
pub fn extract_path_from_request(data: &[u8]) -> Option<&str> {
    // Isolate the first line (up to the first CR or LF, or the whole input).
    let line_len = data
        .iter()
        .position(|&b| b == b'\r' || b == b'\n')
        .unwrap_or(data.len());
    let request_line = &data[..line_len];
    if request_line.is_empty() {
        return None;
    }
    // METHOD PATH VERSION — the third token must carry the HTTP marker.
    let mut tokens = request_line.split(|&b| b == b' ');
    let _method = tokens.next()?;
    let path = tokens.next()?;
    let version = tokens.next()?;
    if !version.starts_with(b"HTTP/") {
        return None;
    }
    std::str::from_utf8(path).ok()
}
// Extract the value of a single header (case-insensitive name) from raw
// request data.
//
// Returns the trimmed value of the first matching header line, or `None`
// when the header is absent or the data is not valid UTF-8.
//
// Fixes over the previous prefix-match implementation:
// - accepts "Name:value" (no space after the colon), which HTTP permits;
// - stops at the blank line, so body content can no longer be mistaken for
//   a header;
// - no byte-offset slicing based on a lowercased copy, which could land off
//   a char boundary (panic) for non-ASCII names whose lowercase form has a
//   different byte length.
pub fn extract_header_value(data: &[u8], header_name: &str) -> Option<String> {
    let data_str = std::str::from_utf8(data).ok()?;
    let mut lines = data_str.lines();
    // Skip the request line ("METHOD PATH VERSION").
    let _ = lines.next();
    for line in lines {
        // Blank line terminates the header section.
        if line.is_empty() {
            break;
        }
        // Split on the first ':' so values containing ':' stay intact.
        if let Some((name, value)) = line.split_once(':') {
            if name.trim().eq_ignore_ascii_case(header_name) {
                return Some(value.trim().to_string());
            }
        }
    }
    None
}
// Extract all headers from request data.
//
// Header names are lowercased, values trimmed. Returns an empty map for
// non-UTF-8 input. Parsing stops at the blank line before the body.
pub fn extract_all_headers(data: &[u8]) -> HashMap<String, String> {
    let mut headers = HashMap::new();
    let text = match std::str::from_utf8(data) {
        Ok(text) => text,
        Err(_) => return headers,
    };
    // First line is the request line; headers follow until the empty line.
    for line in text.lines().skip(1) {
        if line.is_empty() {
            break;
        }
        if let Some((name, value)) = line.split_once(':') {
            headers.insert(name.trim().to_lowercase(), value.trim().to_string());
        }
    }
    headers
}
// Determine response type based on request path.
//
// Rules are checked in priority order; the first whose needle appears in the
// path wins, and anything unmatched falls back to "generic".
pub fn choose_response_type(path: &str) -> &'static str {
    let rules: [(&[&str], &'static str); 3] = [
        (&["phpunit", "eval"], "php_exploit"),
        (&["wp-"], "wordpress"),
        (&["api"], "api"),
    ];
    for (needles, label) in rules {
        if needles.iter().any(|needle| path.contains(needle)) {
            return label;
        }
    }
    "generic"
}
// Get current timestamp as whole seconds since the Unix epoch.
//
// A clock set before the epoch (duration_since error) degrades to 0 rather
// than panicking.
pub fn get_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0)
}
// Create a unique session ID for tracking a connection.
//
// Combines a hash of (ip, user-agent, timestamp) with a random nonce so two
// requests from the same client in the same second still get distinct IDs.
pub fn generate_session_id(ip: &str, user_agent: &str) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::Hash;

    let timestamp = get_timestamp();
    let nonce = rand::random::<u32>();
    // XXX: Is this fast enough for our case? I don't think hashing is a huge
    // bottleneck, but it's worth revisiting in the future to see if there is
    // an objectively faster algorithm that we can try.
    let mut hasher = DefaultHasher::new();
    format!("{ip}_{user_agent}_{timestamp}").hash(&mut hasher);
    let digest = hasher.finish();
    format!("SID_{digest:x}_{nonce:x}")
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_find_header_end() {
        let complete =
            b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: test\r\n\r\nBody content";
        assert_eq!(find_header_end(complete), Some(55));
        let truncated = b"GET / HTTP/1.1\r\nHost: example.com\r\n";
        assert_eq!(find_header_end(truncated), None);
    }
    #[test]
    fn test_extract_path_from_request() {
        assert_eq!(
            extract_path_from_request(b"GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n"),
            Some("/index.html")
        );
        assert_eq!(extract_path_from_request(b"INVALID DATA"), None);
    }
    #[test]
    fn test_extract_header_value() {
        let request = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\n\r\n";
        let cases: [(&str, Option<&str>); 3] = [
            ("user-agent", Some("TestBot/1.0")),
            ("Host", Some("example.com")),
            ("nonexistent", None),
        ];
        for (name, expected) in cases {
            assert_eq!(
                extract_header_value(request, name),
                expected.map(str::to_string)
            );
        }
    }
    #[test]
    fn test_extract_all_headers() {
        let request =
            b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\nAccept: */*\r\n\r\n";
        let headers = extract_all_headers(request);
        assert_eq!(headers.len(), 3);
        assert_eq!(headers.get("host").map(String::as_str), Some("example.com"));
        assert_eq!(
            headers.get("user-agent").map(String::as_str),
            Some("TestBot/1.0")
        );
        assert_eq!(headers.get("accept").map(String::as_str), Some("*/*"));
    }
    #[test]
    fn test_choose_response_type() {
        for (path, expected) in [
            ("/vendor/phpunit/whatever", "php_exploit"),
            ("/path/to/eval-stdin.php", "php_exploit"),
            ("/wp-admin/login.php", "wordpress"),
            ("/wp-login.php", "wordpress"),
            ("/api/v1/users", "api"),
            ("/index.html", "generic"),
        ] {
            assert_eq!(choose_response_type(path), expected);
        }
    }
}
}