WIP: network: implement basic ratelimiting #1
15 changed files with 3037 additions and 1247 deletions
31
.github/workflows/tag.yml
vendored
Normal file
31
.github/workflows/tag.yml
vendored
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
name: Tag latest version
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
concurrency: tag
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
tag-release:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: cachix/install-nix-action@master
|
||||||
|
with:
|
||||||
|
github_access_token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Read version
|
||||||
|
run: |
|
||||||
|
echo -n "_version=v" >> "$GITHUB_ENV"
|
||||||
|
nix run nixpkgs#fq -- -r ".package.version" Cargo.toml >> "$GITHUB_ENV"
|
||||||
|
cat "$GITHUB_ENV"
|
||||||
|
|
||||||
|
- name: Tag
|
||||||
|
run: |
|
||||||
|
set -x
|
||||||
|
git tag $version
|
||||||
|
git push --tags || :
|
90
Cargo.lock
generated
90
Cargo.lock
generated
|
@ -414,9 +414,7 @@ checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"android-tzdata",
|
"android-tzdata",
|
||||||
"iana-time-zone",
|
"iana-time-zone",
|
||||||
"js-sys",
|
|
||||||
"num-traits",
|
"num-traits",
|
||||||
"wasm-bindgen",
|
|
||||||
"windows-link",
|
"windows-link",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -442,7 +440,6 @@ version = "4.5.37"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "efd9466fac8543255d3b1fcad4762c5e116ffe808c8a3043d4263cd4fd4862a2"
|
checksum = "efd9466fac8543255d3b1fcad4762c5e116ffe808c8a3043d4263cd4fd4862a2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anstream",
|
|
||||||
"anstyle",
|
"anstyle",
|
||||||
"clap_lex",
|
"clap_lex",
|
||||||
"strsim",
|
"strsim",
|
||||||
|
@ -620,7 +617,7 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "eris"
|
name = "eris"
|
||||||
version = "0.1.0"
|
version = "1.0.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"actix-web",
|
"actix-web",
|
||||||
"chrono",
|
"chrono",
|
||||||
|
@ -633,10 +630,13 @@ dependencies = [
|
||||||
"prometheus 0.14.0",
|
"prometheus 0.14.0",
|
||||||
"prometheus_exporter",
|
"prometheus_exporter",
|
||||||
"rand",
|
"rand",
|
||||||
|
"regex",
|
||||||
"rlua",
|
"rlua",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"tempfile",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"toml",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -646,9 +646,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e"
|
checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fastrand"
|
||||||
|
version = "2.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "flate2"
|
name = "flate2"
|
||||||
version = "1.1.1"
|
version = "1.1.1"
|
||||||
|
@ -1580,7 +1586,7 @@ dependencies = [
|
||||||
"errno",
|
"errno",
|
||||||
"libc",
|
"libc",
|
||||||
"linux-raw-sys",
|
"linux-raw-sys",
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1633,6 +1639,15 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_spanned"
|
||||||
|
version = "0.6.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_urlencoded"
|
name = "serde_urlencoded"
|
||||||
version = "0.7.1"
|
version = "0.7.1"
|
||||||
|
@ -1740,6 +1755,19 @@ dependencies = [
|
||||||
"syn 2.0.101",
|
"syn 2.0.101",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tempfile"
|
||||||
|
version = "3.19.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf"
|
||||||
|
dependencies = [
|
||||||
|
"fastrand",
|
||||||
|
"getrandom",
|
||||||
|
"once_cell",
|
||||||
|
"rustix",
|
||||||
|
"windows-sys 0.59.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror"
|
name = "thiserror"
|
||||||
version = "1.0.69"
|
version = "1.0.69"
|
||||||
|
@ -1876,6 +1904,47 @@ dependencies = [
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "toml"
|
||||||
|
version = "0.8.22"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "05ae329d1f08c4d17a59bed7ff5b5a769d062e64a62d34a3261b219e62cd5aae"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
"serde_spanned",
|
||||||
|
"toml_datetime",
|
||||||
|
"toml_edit",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "toml_datetime"
|
||||||
|
version = "0.6.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "toml_edit"
|
||||||
|
version = "0.22.26"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "310068873db2c5b3e7659d2cc35d21855dbafa50d1ce336397c666e3cb08137e"
|
||||||
|
dependencies = [
|
||||||
|
"indexmap",
|
||||||
|
"serde",
|
||||||
|
"serde_spanned",
|
||||||
|
"toml_datetime",
|
||||||
|
"toml_write",
|
||||||
|
"winnow",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "toml_write"
|
||||||
|
version = "0.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bfb942dfe1d8e29a7ee7fcbde5bd2b9a25fb89aa70caea2eba3bee836ff41076"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tracing"
|
name = "tracing"
|
||||||
version = "0.1.41"
|
version = "0.1.41"
|
||||||
|
@ -2187,6 +2256,15 @@ version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winnow"
|
||||||
|
version = "0.7.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9e27d6ad3dac991091e4d35de9ba2d2d00647c5d0fc26c5496dee55984ae111b"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "winsafe"
|
name = "winsafe"
|
||||||
version = "0.0.19"
|
version = "0.0.19"
|
||||||
|
|
11
Cargo.toml
11
Cargo.toml
|
@ -1,12 +1,12 @@
|
||||||
[package]
|
[package]
|
||||||
name = "eris"
|
name = "eris"
|
||||||
version = "0.1.0"
|
version = "1.0.0"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
actix-web = "4.3.1"
|
actix-web = { version = "4.3.1" }
|
||||||
clap = { version = "4.3", features = ["derive"] }
|
chrono = { version = "0.4.41", default-features = false, features = ["std", "clock"] }
|
||||||
chrono = "0.4.24"
|
clap = { version = "4.5", default-features = false, features = ["std", "derive", "help", "usage", "suggestions"] }
|
||||||
futures = "0.3.28"
|
futures = "0.3.28"
|
||||||
ipnetwork = "0.21.1"
|
ipnetwork = "0.21.1"
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
|
@ -19,3 +19,6 @@ serde_json = "1.0.96"
|
||||||
tokio = { version = "1.28.0", features = ["full"] }
|
tokio = { version = "1.28.0", features = ["full"] }
|
||||||
log = "0.4.27"
|
log = "0.4.27"
|
||||||
env_logger = "0.11.8"
|
env_logger = "0.11.8"
|
||||||
|
tempfile = "3.19.1"
|
||||||
|
regex = "1.11.1"
|
||||||
|
toml = "0.8.22"
|
||||||
|
|
|
@ -19,6 +19,7 @@ in
|
||||||
fileset = fs.unions [
|
fileset = fs.unions [
|
||||||
(fs.fileFilter (file: builtins.any file.hasExt ["rs"]) (s + /src))
|
(fs.fileFilter (file: builtins.any file.hasExt ["rs"]) (s + /src))
|
||||||
(s + /contrib)
|
(s + /contrib)
|
||||||
|
(s + /resources)
|
||||||
lockfile
|
lockfile
|
||||||
cargoToml
|
cargoToml
|
||||||
];
|
];
|
||||||
|
|
210
resources/default_script.lua
Normal file
210
resources/default_script.lua
Normal file
|
@ -0,0 +1,210 @@
|
||||||
|
--[[
|
||||||
|
Eris Default Script
|
||||||
|
|
||||||
|
This script demonstrates how to use the Eris Lua API to customize
|
||||||
|
the tarpit's behavior, and will be loaded by default if no other
|
||||||
|
scripts are loaded.
|
||||||
|
|
||||||
|
Available events:
|
||||||
|
- connection: When a new connection is established
|
||||||
|
- request: When a request is received
|
||||||
|
- response_gen: When generating a response
|
||||||
|
- response_chunk: Before sending each response chunk
|
||||||
|
- disconnection: When a connection is closed
|
||||||
|
- block_ip: When an IP is being considered for blocking
|
||||||
|
- startup: When the application starts
|
||||||
|
- shutdown: When the application is shutting down
|
||||||
|
- periodic: Called periodically
|
||||||
|
|
||||||
|
API Functions:
|
||||||
|
- eris.debug(message): Log a debug message
|
||||||
|
- eris.info(message): Log an info message
|
||||||
|
- eris.warn(message): Log a warning message
|
||||||
|
- eris.error(message): Log an error message
|
||||||
|
- eris.set_state(key, value): Store persistent state
|
||||||
|
- eris.get_state(key): Retrieve persistent state
|
||||||
|
- eris.inc_counter(key, [amount]): Increment a counter
|
||||||
|
- eris.get_counter(key): Get a counter value
|
||||||
|
- eris.gen_token([prefix]): Generate a unique token
|
||||||
|
- eris.timestamp(): Get current Unix timestamp
|
||||||
|
--]]
|
||||||
|
|
||||||
|
-- Called when the application starts
|
||||||
|
eris.on("startup", function(ctx)
|
||||||
|
eris.info("Initializing default script")
|
||||||
|
|
||||||
|
-- Initialize counters
|
||||||
|
eris.inc_counter("total_connections", 0)
|
||||||
|
eris.inc_counter("total_responses", 0)
|
||||||
|
eris.inc_counter("blocked_ips", 0)
|
||||||
|
|
||||||
|
-- Initialize banned keywords
|
||||||
|
eris.set_state("banned_keywords", "eval,exec,system,shell,<?php,/bin/bash")
|
||||||
|
end)
|
||||||
|
|
||||||
|
-- Called for each new connection
|
||||||
|
eris.on("connection", function(ctx)
|
||||||
|
eris.inc_counter("total_connections")
|
||||||
|
eris.debug("New connection from " .. ctx.ip)
|
||||||
|
|
||||||
|
-- You can reject connections by returning false
|
||||||
|
-- This example checks a blocklist
|
||||||
|
local blocklist = eris.get_state("manual_blocklist") or ""
|
||||||
|
if blocklist:find(ctx.ip) then
|
||||||
|
eris.info("Rejecting connection from manually blocked IP: " .. ctx.ip)
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
|
return true -- accept the connection
|
||||||
|
end)
|
||||||
|
|
||||||
|
-- Called when generating a response
|
||||||
|
eris.on("response_gen", function(ctx)
|
||||||
|
eris.inc_counter("total_responses")
|
||||||
|
|
||||||
|
-- Generate a unique traceable token for this request
|
||||||
|
local token = eris.gen_token("ERIS-")
|
||||||
|
|
||||||
|
-- Add some believable but fake honeytokens based on the request path
|
||||||
|
local enhanced_content = ctx.content
|
||||||
|
|
||||||
|
if ctx.path:find("wp%-") then
|
||||||
|
-- For WordPress paths
|
||||||
|
enhanced_content = enhanced_content
|
||||||
|
.. "\n<!-- WordPress Debug: "
|
||||||
|
.. token
|
||||||
|
.. " -->"
|
||||||
|
.. "\n<!-- WP_HOME: http://stop.crawlingmysite.com/wordpress -->"
|
||||||
|
.. "\n<!-- DB_USER: wp_user_"
|
||||||
|
.. math.random(1000, 9999)
|
||||||
|
.. " -->"
|
||||||
|
elseif ctx.path:find("phpunit") or ctx.path:find("eval") then
|
||||||
|
-- For PHP exploit attempts
|
||||||
|
-- Turns out you can just google "PHP error log" and search random online forums where people
|
||||||
|
-- dump their service logs in full.
|
||||||
|
enhanced_content = enhanced_content
|
||||||
|
.. "\nPHP Notice: Undefined variable: _SESSION in /var/www/html/includes/core.php on line 58\n"
|
||||||
|
.. "Warning: file_get_contents(): Filename cannot be empty in /var/www/html/vendor/autoload.php on line 23\n"
|
||||||
|
.. "Token: "
|
||||||
|
.. token
|
||||||
|
.. "\n"
|
||||||
|
elseif ctx.path:find("api") then
|
||||||
|
-- For API requests
|
||||||
|
local fake_api_key =
|
||||||
|
string.format("ak_%x%x%x", math.random(1000, 9999), math.random(1000, 9999), math.random(1000, 9999))
|
||||||
|
|
||||||
|
enhanced_content = enhanced_content
|
||||||
|
.. "{\n"
|
||||||
|
.. ' "status": "warning",\n'
|
||||||
|
.. ' "message": "Test API environment detected",\n'
|
||||||
|
.. ' "debug_token": "'
|
||||||
|
.. token
|
||||||
|
.. '",\n'
|
||||||
|
.. ' "api_key": "'
|
||||||
|
.. fake_api_key
|
||||||
|
.. '"\n'
|
||||||
|
.. "}\n"
|
||||||
|
else
|
||||||
|
-- For other requests
|
||||||
|
enhanced_content = enhanced_content
|
||||||
|
.. "\n<!-- Server: Apache/2.4.41 (Ubuntu) -->"
|
||||||
|
.. "\n<!-- Debug-Token: "
|
||||||
|
.. token
|
||||||
|
.. " -->"
|
||||||
|
.. "\n<!-- Environment: staging -->"
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Track which honeytokens were sent to which IP
|
||||||
|
local honeytokens = eris.get_state("honeytokens") or "{}"
|
||||||
|
local ht_table = {}
|
||||||
|
|
||||||
|
-- This is a simplistic approach - in a real script, you'd want to use
|
||||||
|
-- a proper JSON library to handle this correctly
|
||||||
|
if honeytokens ~= "{}" then
|
||||||
|
-- Simple parsing of the stored data
|
||||||
|
for ip, tok in honeytokens:gmatch('"([^"]+)":"([^"]+)"') do
|
||||||
|
ht_table[ip] = tok
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
ht_table[ctx.ip] = token
|
||||||
|
|
||||||
|
-- Convert back to a simple JSON-like string
|
||||||
|
local new_tokens = "{"
|
||||||
|
for ip, tok in pairs(ht_table) do
|
||||||
|
if new_tokens ~= "{" then
|
||||||
|
new_tokens = new_tokens .. ","
|
||||||
|
end
|
||||||
|
new_tokens = new_tokens .. '"' .. ip .. '":"' .. tok .. '"'
|
||||||
|
end
|
||||||
|
new_tokens = new_tokens .. "}"
|
||||||
|
|
||||||
|
eris.set_state("honeytokens", new_tokens)
|
||||||
|
|
||||||
|
return enhanced_content
|
||||||
|
end)
|
||||||
|
|
||||||
|
-- Called before sending each chunk of a response
|
||||||
|
eris.on("response_chunk", function(ctx)
|
||||||
|
-- This can be used to alter individual chunks for more deceptive behavior
|
||||||
|
-- For example, to simulate a slow, unreliable server
|
||||||
|
|
||||||
|
-- 5% chance of "corrupting" a chunk to confuse scanners
|
||||||
|
if math.random(1, 100) <= 5 then
|
||||||
|
local chunk = ctx.content
|
||||||
|
if #chunk > 10 then
|
||||||
|
local pos = math.random(1, #chunk - 5)
|
||||||
|
chunk = chunk:sub(1, pos) .. string.char(math.random(32, 126)) .. chunk:sub(pos + 2)
|
||||||
|
end
|
||||||
|
return chunk
|
||||||
|
end
|
||||||
|
|
||||||
|
return ctx.content
|
||||||
|
end)
|
||||||
|
|
||||||
|
-- Called when deciding whether to block an IP
|
||||||
|
eris.on("block_ip", function(ctx)
|
||||||
|
-- You can override the default blocking logic
|
||||||
|
|
||||||
|
-- Check for potential attackers using specific patterns
|
||||||
|
local banned_keywords = eris.get_state("banned_keywords") or ""
|
||||||
|
local user_agent = ctx.user_agent or ""
|
||||||
|
|
||||||
|
-- Check if user agent contains highly suspicious patterns
|
||||||
|
for keyword in banned_keywords:gmatch("[^,]+") do
|
||||||
|
if user_agent:lower():find(keyword:lower()) then
|
||||||
|
eris.info("Blocking IP " .. ctx.ip .. " due to suspicious user agent: " .. keyword)
|
||||||
|
eris.inc_counter("blocked_ips")
|
||||||
|
return true -- Force block
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
-- For demonstration, we'll be more lenient with 10.x IPs
|
||||||
|
if ctx.ip:match("^10%.") then
|
||||||
|
-- Only block if they've hit us many times
|
||||||
|
return ctx.hit_count >= 5
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Default to the system's threshold-based decision
|
||||||
|
return nil
|
||||||
|
end)
|
||||||
|
|
||||||
|
-- The enhance_response is now legacy, and I never liked it anyway. Though let's add it here
|
||||||
|
-- for the sake of backwards compatibility.
|
||||||
|
function enhance_response(text, response_type, path, token)
|
||||||
|
local enhanced = text
|
||||||
|
|
||||||
|
-- Add token as a comment
|
||||||
|
if response_type == "php_exploit" then
|
||||||
|
enhanced = enhanced .. "\n/* Token: " .. token .. " */\n"
|
||||||
|
elseif response_type == "wordpress" then
|
||||||
|
enhanced = enhanced .. "\n<!-- WordPress Debug Token: " .. token .. " -->\n"
|
||||||
|
elseif response_type == "api" then
|
||||||
|
enhanced = enhanced:gsub('"status": "[^"]+"', '"status": "warning"')
|
||||||
|
enhanced = enhanced:gsub('"message": "[^"]+"', '"message": "API token: ' .. token .. '"')
|
||||||
|
else
|
||||||
|
enhanced = enhanced .. "\n<!-- Debug token: " .. token .. " -->\n"
|
||||||
|
end
|
||||||
|
|
||||||
|
return enhanced
|
||||||
|
end
|
613
src/config.rs
Normal file
613
src/config.rs
Normal file
|
@ -0,0 +1,613 @@
|
||||||
|
use clap::Parser;
|
||||||
|
use ipnetwork::IpNetwork;
|
||||||
|
use regex::Regex;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::env;
|
||||||
|
use std::fs;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
// Command-line arguments using clap
|
||||||
|
#[derive(Parser, Debug, Clone)]
|
||||||
|
#[clap(
|
||||||
|
author,
|
||||||
|
version,
|
||||||
|
about = "Markov chain based HTTP tarpit/honeypot that delays and tracks potential attackers"
|
||||||
|
)]
|
||||||
|
pub struct Args {
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
default_value = "0.0.0.0:8888",
|
||||||
|
help = "Address and port to listen for incoming HTTP requests (format: ip:port)"
|
||||||
|
)]
|
||||||
|
pub listen_addr: String,
|
||||||
|
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
default_value = "0.0.0.0:9100",
|
||||||
|
help = "Address and port to expose Prometheus metrics and status endpoint (format: ip:port)"
|
||||||
|
)]
|
||||||
|
pub metrics_addr: String,
|
||||||
|
|
||||||
|
#[clap(long, help = "Disable Prometheus metrics server completely")]
|
||||||
|
pub disable_metrics: bool,
|
||||||
|
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
default_value = "127.0.0.1:80",
|
||||||
|
help = "Backend server address to proxy legitimate requests to (format: ip:port)"
|
||||||
|
)]
|
||||||
|
pub backend_addr: String,
|
||||||
|
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
default_value = "1000",
|
||||||
|
help = "Minimum delay in milliseconds between chunks sent to attacker"
|
||||||
|
)]
|
||||||
|
pub min_delay: u64,
|
||||||
|
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
default_value = "15000",
|
||||||
|
help = "Maximum delay in milliseconds between chunks sent to attacker"
|
||||||
|
)]
|
||||||
|
pub max_delay: u64,
|
||||||
|
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
default_value = "600",
|
||||||
|
help = "Maximum time in seconds to keep an attacker in the tarpit before disconnecting"
|
||||||
|
)]
|
||||||
|
pub max_tarpit_time: u64,
|
||||||
|
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
default_value = "3",
|
||||||
|
help = "Number of hits to honeypot patterns before permanently blocking an IP"
|
||||||
|
)]
|
||||||
|
pub block_threshold: u32,
|
||||||
|
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
help = "Base directory for all application data (overrides XDG directory structure)"
|
||||||
|
)]
|
||||||
|
pub base_dir: Option<PathBuf>,
|
||||||
|
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
help = "Path to configuration file (JSON or TOML, overrides command line options)"
|
||||||
|
)]
|
||||||
|
pub config_file: Option<PathBuf>,
|
||||||
|
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
default_value = "info",
|
||||||
|
help = "Log level: trace, debug, info, warn, error"
|
||||||
|
)]
|
||||||
|
pub log_level: String,
|
||||||
|
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
default_value = "pretty",
|
||||||
|
help = "Log format: plain, pretty, json, pretty-json"
|
||||||
|
)]
|
||||||
|
pub log_format: String,
|
||||||
|
|
||||||
|
#[clap(long, help = "Enable rate limiting for connections from the same IP")]
|
||||||
|
pub rate_limit_enabled: bool,
|
||||||
|
|
||||||
|
#[clap(long, default_value = "60", help = "Rate limit window in seconds")]
|
||||||
|
pub rate_limit_window: u64,
|
||||||
|
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
default_value = "30",
|
||||||
|
help = "Maximum number of connections allowed per IP in the rate limit window"
|
||||||
|
)]
|
||||||
|
pub rate_limit_max: usize,
|
||||||
|
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
default_value = "100",
|
||||||
|
help = "Connection attempts threshold before considering for IP blocking"
|
||||||
|
)]
|
||||||
|
pub rate_limit_block_threshold: usize,
|
||||||
|
|
||||||
|
#[clap(
|
||||||
|
long,
|
||||||
|
help = "Send a 429 response for rate limited connections instead of dropping connection"
|
||||||
|
)]
|
||||||
|
pub rate_limit_slow_response: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||||
|
pub enum LogFormat {
|
||||||
|
Plain,
|
||||||
|
#[default]
|
||||||
|
Pretty,
|
||||||
|
Json,
|
||||||
|
PrettyJson,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trap pattern structure. It can be either a plain string
|
||||||
|
// regex to catch more advanced patterns necessitated by
|
||||||
|
// more sophisticated crawlers.
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum TrapPattern {
|
||||||
|
Plain(String),
|
||||||
|
Regex { pattern: String, regex: bool },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrapPattern {
|
||||||
|
pub fn as_plain(value: &str) -> Self {
|
||||||
|
Self::Plain(value.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_regex(value: &str) -> Self {
|
||||||
|
Self::Regex {
|
||||||
|
pattern: value.to_string(),
|
||||||
|
regex: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn matches(&self, path: &str) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::Plain(pattern) => path.contains(pattern),
|
||||||
|
Self::Regex {
|
||||||
|
pattern,
|
||||||
|
regex: true,
|
||||||
|
} => {
|
||||||
|
if let Ok(re) = Regex::new(pattern) {
|
||||||
|
re.is_match(path)
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configuration structure
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub listen_addr: String,
|
||||||
|
pub metrics_addr: String,
|
||||||
|
pub disable_metrics: bool,
|
||||||
|
pub backend_addr: String,
|
||||||
|
pub min_delay: u64,
|
||||||
|
pub max_delay: u64,
|
||||||
|
pub max_tarpit_time: u64,
|
||||||
|
pub block_threshold: u32,
|
||||||
|
pub trap_patterns: Vec<TrapPattern>,
|
||||||
|
pub whitelist_networks: Vec<String>,
|
||||||
|
pub markov_corpora_dir: String,
|
||||||
|
pub lua_scripts_dir: String,
|
||||||
|
pub data_dir: String,
|
||||||
|
pub config_dir: String,
|
||||||
|
pub cache_dir: String,
|
||||||
|
pub log_format: LogFormat,
|
||||||
|
pub rate_limit_enabled: bool,
|
||||||
|
pub rate_limit_window_seconds: u64,
|
||||||
|
pub rate_limit_max_connections: usize,
|
||||||
|
pub rate_limit_block_threshold: usize,
|
||||||
|
pub rate_limit_slow_response: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Config {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
listen_addr: "0.0.0.0:8888".to_string(),
|
||||||
|
metrics_addr: "0.0.0.0:9100".to_string(),
|
||||||
|
disable_metrics: false,
|
||||||
|
backend_addr: "127.0.0.1:80".to_string(),
|
||||||
|
min_delay: 1000,
|
||||||
|
max_delay: 15000,
|
||||||
|
max_tarpit_time: 600,
|
||||||
|
block_threshold: 3,
|
||||||
|
trap_patterns: vec![
|
||||||
|
// Basic attack patterns as plain strings
|
||||||
|
TrapPattern::as_plain("/vendor/phpunit"),
|
||||||
|
TrapPattern::as_plain("eval-stdin.php"),
|
||||||
|
TrapPattern::as_plain("/wp-admin"),
|
||||||
|
TrapPattern::as_plain("/wp-login.php"),
|
||||||
|
TrapPattern::as_plain("/xmlrpc.php"),
|
||||||
|
TrapPattern::as_plain("/phpMyAdmin"),
|
||||||
|
TrapPattern::as_plain("/solr/"),
|
||||||
|
TrapPattern::as_plain("/.env"),
|
||||||
|
TrapPattern::as_plain("/config"),
|
||||||
|
TrapPattern::as_plain("/actuator/"),
|
||||||
|
// More aggressive patterns for various PHP exploits.
|
||||||
|
// XXX: I dedicate this entire section to that one single crawler
|
||||||
|
// that has been scanning my entire network, hitting 403s left and right
|
||||||
|
// but not giving up, and coming back the next day at the same time to
|
||||||
|
// scan the same paths over and over. Kudos to you, random crawler.
|
||||||
|
TrapPattern::as_regex(r"/.*phpunit.*eval-stdin\.php"),
|
||||||
|
TrapPattern::as_regex(r"/index\.php\?s=/index/\\think\\app/invokefunction"),
|
||||||
|
TrapPattern::as_regex(r".*%ADd\+auto_prepend_file%3dphp://input.*"),
|
||||||
|
TrapPattern::as_regex(r".*%ADd\+allow_url_include%3d1.*"),
|
||||||
|
TrapPattern::as_regex(r".*/wp-content/plugins/.*\.php"),
|
||||||
|
TrapPattern::as_regex(r".*/wp-content/themes/.*\.php"),
|
||||||
|
TrapPattern::as_regex(r".*eval\(.*\).*"),
|
||||||
|
TrapPattern::as_regex(r".*/adminer\.php.*"),
|
||||||
|
TrapPattern::as_regex(r".*/admin\.php.*"),
|
||||||
|
TrapPattern::as_regex(r".*/administrator/.*"),
|
||||||
|
TrapPattern::as_regex(r".*/wp-json/.*"),
|
||||||
|
TrapPattern::as_regex(r".*/api/.*\.php.*"),
|
||||||
|
TrapPattern::as_regex(r".*/cgi-bin/.*"),
|
||||||
|
TrapPattern::as_regex(r".*/owa/.*"),
|
||||||
|
TrapPattern::as_regex(r".*/ecp/.*"),
|
||||||
|
TrapPattern::as_regex(r".*/webshell\.php.*"),
|
||||||
|
TrapPattern::as_regex(r".*/shell\.php.*"),
|
||||||
|
TrapPattern::as_regex(r".*/cmd\.php.*"),
|
||||||
|
TrapPattern::as_regex(r".*/struts.*"),
|
||||||
|
],
|
||||||
|
whitelist_networks: vec![
|
||||||
|
"192.168.0.0/16".to_string(),
|
||||||
|
"10.0.0.0/8".to_string(),
|
||||||
|
"172.16.0.0/12".to_string(),
|
||||||
|
"127.0.0.0/8".to_string(),
|
||||||
|
],
|
||||||
|
markov_corpora_dir: "./corpora".to_string(),
|
||||||
|
lua_scripts_dir: "./scripts".to_string(),
|
||||||
|
data_dir: "./data".to_string(),
|
||||||
|
config_dir: "./conf".to_string(),
|
||||||
|
cache_dir: "./cache".to_string(),
|
||||||
|
log_format: LogFormat::Pretty,
|
||||||
|
rate_limit_enabled: true,
|
||||||
|
rate_limit_window_seconds: 60,
|
||||||
|
rate_limit_max_connections: 30,
|
||||||
|
rate_limit_block_threshold: 100,
|
||||||
|
rate_limit_slow_response: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gets standard XDG directory paths for config, data and cache.
|
||||||
|
// XXX: This could be "simplified" by using the Dirs crate, but I can't
|
||||||
|
// really justify pulling a library for something I can handle in less
|
||||||
|
// than 30 lines. Unless cross-platform becomes necessary, the below
|
||||||
|
// implementation is good enough. For alternative platforms, we can simply
|
||||||
|
// enhance the current implementation as needed.
|
||||||
|
pub fn get_xdg_dirs() -> (PathBuf, PathBuf, PathBuf) {
|
||||||
|
let config_home = env::var_os("XDG_CONFIG_HOME")
|
||||||
|
.map(PathBuf::from)
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from);
|
||||||
|
home.join(".config")
|
||||||
|
});
|
||||||
|
|
||||||
|
let data_home = env::var_os("XDG_DATA_HOME")
|
||||||
|
.map(PathBuf::from)
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from);
|
||||||
|
home.join(".local").join("share")
|
||||||
|
});
|
||||||
|
|
||||||
|
let cache_home = env::var_os("XDG_CACHE_HOME")
|
||||||
|
.map(PathBuf::from)
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
let home = env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from);
|
||||||
|
home.join(".cache")
|
||||||
|
});
|
||||||
|
|
||||||
|
let config_dir = config_home.join("eris");
|
||||||
|
let data_dir = data_home.join("eris");
|
||||||
|
let cache_dir = cache_home.join("eris");
|
||||||
|
|
||||||
|
(config_dir, data_dir, cache_dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
// Create configuration from command-line args. We'll be falling back to this
|
||||||
|
// when the configuration is invalid, so it must be validated more strictly.
|
||||||
|
pub fn from_args(args: &Args) -> Self {
|
||||||
|
let (config_dir, data_dir, cache_dir) = if let Some(base_dir) = &args.base_dir {
|
||||||
|
let base_str = base_dir.to_string_lossy().to_string();
|
||||||
|
(
|
||||||
|
format!("{base_str}/conf"),
|
||||||
|
format!("{base_str}/data"),
|
||||||
|
format!("{base_str}/cache"),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
let (c, d, cache) = get_xdg_dirs();
|
||||||
|
(
|
||||||
|
c.to_string_lossy().to_string(),
|
||||||
|
d.to_string_lossy().to_string(),
|
||||||
|
cache.to_string_lossy().to_string(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
Self {
|
||||||
|
listen_addr: args.listen_addr.clone(),
|
||||||
|
metrics_addr: args.metrics_addr.clone(),
|
||||||
|
disable_metrics: args.disable_metrics,
|
||||||
|
backend_addr: args.backend_addr.clone(),
|
||||||
|
min_delay: args.min_delay,
|
||||||
|
max_delay: args.max_delay,
|
||||||
|
max_tarpit_time: args.max_tarpit_time,
|
||||||
|
block_threshold: args.block_threshold,
|
||||||
|
markov_corpora_dir: format!("{data_dir}/corpora"),
|
||||||
|
lua_scripts_dir: format!("{data_dir}/scripts"),
|
||||||
|
data_dir,
|
||||||
|
config_dir,
|
||||||
|
cache_dir,
|
||||||
|
log_format: LogFormat::Pretty,
|
||||||
|
rate_limit_enabled: args.rate_limit_enabled,
|
||||||
|
rate_limit_window_seconds: args.rate_limit_window,
|
||||||
|
rate_limit_max_connections: args.rate_limit_max,
|
||||||
|
rate_limit_block_threshold: args.rate_limit_block_threshold,
|
||||||
|
rate_limit_slow_response: args.rate_limit_slow_response,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load configuration from a file (JSON or TOML)
|
||||||
|
pub fn load_from_file(path: &Path) -> std::io::Result<Self> {
|
||||||
|
let content = fs::read_to_string(path)?;
|
||||||
|
|
||||||
|
let extension = path
|
||||||
|
.extension()
|
||||||
|
.map(|ext| ext.to_string_lossy().to_lowercase())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let config = match extension.as_str() {
|
||||||
|
"toml" => toml::from_str(&content).map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("Failed to parse TOML: {e}"),
|
||||||
|
)
|
||||||
|
})?,
|
||||||
|
_ => {
|
||||||
|
// Default to JSON for any other extension
|
||||||
|
serde_json::from_str(&content).map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("Failed to parse JSON: {e}"),
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save configuration to a file (JSON or TOML)
|
||||||
|
pub fn save_to_file(&self, path: &Path) -> std::io::Result<()> {
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let extension = path
|
||||||
|
.extension()
|
||||||
|
.map(|ext| ext.to_string_lossy().to_lowercase())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let content = match extension.as_str() {
|
||||||
|
"toml" => toml::to_string_pretty(self).map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("Failed to serialize to TOML: {e}"),
|
||||||
|
)
|
||||||
|
})?,
|
||||||
|
_ => {
|
||||||
|
// Default to JSON for any other extension
|
||||||
|
serde_json::to_string_pretty(self).map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("Failed to serialize to JSON: {e}"),
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
fs::write(path, content)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create required directories if they don't exist
|
||||||
|
pub fn ensure_dirs_exist(&self) -> std::io::Result<()> {
|
||||||
|
let dirs = [
|
||||||
|
&self.markov_corpora_dir,
|
||||||
|
&self.lua_scripts_dir,
|
||||||
|
&self.data_dir,
|
||||||
|
&self.config_dir,
|
||||||
|
&self.cache_dir,
|
||||||
|
];
|
||||||
|
|
||||||
|
for dir in dirs {
|
||||||
|
fs::create_dir_all(dir)?;
|
||||||
|
log::debug!("Created directory: {dir}");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decide if a request should be tarpitted based on path and IP
|
||||||
|
pub fn should_tarpit(path: &str, ip: &IpAddr, config: &Config) -> bool {
|
||||||
|
// Check whitelist IPs first to avoid unnecessary pattern matching
|
||||||
|
for network_str in &config.whitelist_networks {
|
||||||
|
if let Ok(network) = network_str.parse::<IpNetwork>() {
|
||||||
|
if network.contains(*ip) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use pattern matching based on the trap pattern type. It can be
|
||||||
|
// a plain string or regex.
|
||||||
|
for pattern in &config.trap_patterns {
|
||||||
|
if pattern.matches(path) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_config_from_args() {
|
||||||
|
let args = Args {
|
||||||
|
listen_addr: "127.0.0.1:8080".to_string(),
|
||||||
|
metrics_addr: "127.0.0.1:9000".to_string(),
|
||||||
|
disable_metrics: true,
|
||||||
|
backend_addr: "127.0.0.1:8081".to_string(),
|
||||||
|
min_delay: 500,
|
||||||
|
max_delay: 10000,
|
||||||
|
max_tarpit_time: 300,
|
||||||
|
block_threshold: 5,
|
||||||
|
base_dir: Some(PathBuf::from("/tmp/eris")),
|
||||||
|
config_file: None,
|
||||||
|
log_level: "debug".to_string(),
|
||||||
|
log_format: "pretty".to_string(),
|
||||||
|
rate_limit_enabled: true,
|
||||||
|
rate_limit_window: 30,
|
||||||
|
rate_limit_max: 20,
|
||||||
|
rate_limit_block_threshold: 50,
|
||||||
|
rate_limit_slow_response: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let config = Config::from_args(&args);
|
||||||
|
assert_eq!(config.listen_addr, "127.0.0.1:8080");
|
||||||
|
assert_eq!(config.metrics_addr, "127.0.0.1:9000");
|
||||||
|
assert!(config.disable_metrics);
|
||||||
|
assert_eq!(config.backend_addr, "127.0.0.1:8081");
|
||||||
|
assert_eq!(config.min_delay, 500);
|
||||||
|
assert_eq!(config.max_delay, 10000);
|
||||||
|
assert_eq!(config.max_tarpit_time, 300);
|
||||||
|
assert_eq!(config.block_threshold, 5);
|
||||||
|
assert_eq!(config.markov_corpora_dir, "/tmp/eris/data/corpora");
|
||||||
|
assert_eq!(config.lua_scripts_dir, "/tmp/eris/data/scripts");
|
||||||
|
assert_eq!(config.data_dir, "/tmp/eris/data");
|
||||||
|
assert_eq!(config.config_dir, "/tmp/eris/conf");
|
||||||
|
assert_eq!(config.cache_dir, "/tmp/eris/cache");
|
||||||
|
assert!(config.rate_limit_enabled);
|
||||||
|
assert_eq!(config.rate_limit_window_seconds, 30);
|
||||||
|
assert_eq!(config.rate_limit_max_connections, 20);
|
||||||
|
assert_eq!(config.rate_limit_block_threshold, 50);
|
||||||
|
assert!(config.rate_limit_slow_response);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_trap_pattern_matching() {
|
||||||
|
// Test plain string pattern
|
||||||
|
let plain = TrapPattern::as_plain("phpunit");
|
||||||
|
assert!(plain.matches("path/to/phpunit/test"));
|
||||||
|
assert!(!plain.matches("path/to/something/else"));
|
||||||
|
|
||||||
|
// Test regex pattern
|
||||||
|
let regex = TrapPattern::as_regex(r".*eval-stdin\.php.*");
|
||||||
|
assert!(regex.matches("/vendor/phpunit/phpunit/src/Util/PHP/eval-stdin.php"));
|
||||||
|
assert!(regex.matches("/tests/eval-stdin.php?param"));
|
||||||
|
assert!(!regex.matches("/normal/path"));
|
||||||
|
|
||||||
|
// Test invalid regex pattern (should return false)
|
||||||
|
let invalid = TrapPattern::Regex {
|
||||||
|
pattern: "(invalid[regex".to_string(),
|
||||||
|
regex: true,
|
||||||
|
};
|
||||||
|
assert!(!invalid.matches("anything"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_should_tarpit() {
|
||||||
|
let config = Config::default();
|
||||||
|
|
||||||
|
// Test trap patterns
|
||||||
|
assert!(should_tarpit(
|
||||||
|
"/vendor/phpunit/whatever",
|
||||||
|
&IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
|
||||||
|
&config
|
||||||
|
));
|
||||||
|
assert!(should_tarpit(
|
||||||
|
"/wp-admin/login.php",
|
||||||
|
&IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
|
||||||
|
&config
|
||||||
|
));
|
||||||
|
assert!(should_tarpit(
|
||||||
|
"/.env",
|
||||||
|
&IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
|
||||||
|
&config
|
||||||
|
));
|
||||||
|
|
||||||
|
// Test whitelist networks
|
||||||
|
assert!(!should_tarpit(
|
||||||
|
"/wp-admin/login.php",
|
||||||
|
&IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
|
||||||
|
&config
|
||||||
|
));
|
||||||
|
assert!(!should_tarpit(
|
||||||
|
"/vendor/phpunit/whatever",
|
||||||
|
&IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
|
||||||
|
&config
|
||||||
|
));
|
||||||
|
|
||||||
|
// Test legitimate paths
|
||||||
|
assert!(!should_tarpit(
|
||||||
|
"/index.html",
|
||||||
|
&IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
|
||||||
|
&config
|
||||||
|
));
|
||||||
|
assert!(!should_tarpit(
|
||||||
|
"/images/logo.png",
|
||||||
|
&IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
|
||||||
|
&config
|
||||||
|
));
|
||||||
|
|
||||||
|
// Test regex patterns
|
||||||
|
assert!(should_tarpit(
|
||||||
|
"/index.php?s=/index/\\think\\app/invokefunction&function=call_user_func_array&vars[0]=md5&vars[1][]=Hello",
|
||||||
|
&IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
|
||||||
|
&config
|
||||||
|
));
|
||||||
|
|
||||||
|
assert!(should_tarpit(
|
||||||
|
"/hello.world?%ADd+allow_url_include%3d1+%ADd+auto_prepend_file%3dphp://input",
|
||||||
|
&IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
|
||||||
|
&config
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_config_file_formats() {
|
||||||
|
// Create temporary JSON config file
|
||||||
|
let temp_dir = std::env::temp_dir();
|
||||||
|
let json_path = temp_dir.join("temp_config.json");
|
||||||
|
let toml_path = temp_dir.join("temp_config.toml");
|
||||||
|
|
||||||
|
let config = Config::default();
|
||||||
|
|
||||||
|
// Test JSON serialization and deserialization
|
||||||
|
config.save_to_file(&json_path).unwrap();
|
||||||
|
let loaded_json = Config::load_from_file(&json_path).unwrap();
|
||||||
|
assert_eq!(loaded_json.listen_addr, config.listen_addr);
|
||||||
|
assert_eq!(loaded_json.min_delay, config.min_delay);
|
||||||
|
assert_eq!(loaded_json.rate_limit_enabled, config.rate_limit_enabled);
|
||||||
|
assert_eq!(
|
||||||
|
loaded_json.rate_limit_max_connections,
|
||||||
|
config.rate_limit_max_connections
|
||||||
|
);
|
||||||
|
|
||||||
|
// Test TOML serialization and deserialization
|
||||||
|
config.save_to_file(&toml_path).unwrap();
|
||||||
|
let loaded_toml = Config::load_from_file(&toml_path).unwrap();
|
||||||
|
assert_eq!(loaded_toml.listen_addr, config.listen_addr);
|
||||||
|
assert_eq!(loaded_toml.min_delay, config.min_delay);
|
||||||
|
assert_eq!(loaded_toml.rate_limit_enabled, config.rate_limit_enabled);
|
||||||
|
assert_eq!(
|
||||||
|
loaded_toml.rate_limit_max_connections,
|
||||||
|
config.rate_limit_max_connections
|
||||||
|
);
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
let _ = std::fs::remove_file(json_path);
|
||||||
|
let _ = std::fs::remove_file(toml_path);
|
||||||
|
}
|
||||||
|
}
|
128
src/firewall.rs
Normal file
128
src/firewall.rs
Normal file
|
@ -0,0 +1,128 @@
|
||||||
|
use tokio::process::Command;
|
||||||
|
|
||||||
|
// Set up nftables firewall rules for IP blocking
|
||||||
|
pub async fn setup_firewall() -> Result<(), String> {
|
||||||
|
log::info!("Setting up firewall rules");
|
||||||
|
|
||||||
|
// Check if nft command exists
|
||||||
|
let nft_exists = Command::new("which")
|
||||||
|
.arg("nft")
|
||||||
|
.output()
|
||||||
|
.await
|
||||||
|
.map(|output| output.status.success())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
if !nft_exists {
|
||||||
|
log::warn!("nft command not found. Firewall rules will not be set up.");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create table if it doesn't exist
|
||||||
|
let output = Command::new("nft")
|
||||||
|
.args(["list", "table", "inet", "filter"])
|
||||||
|
.output()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match output {
|
||||||
|
Ok(output) => {
|
||||||
|
if !output.status.success() {
|
||||||
|
log::info!("Creating nftables table");
|
||||||
|
let result = Command::new("nft")
|
||||||
|
.args(["create", "table", "inet", "filter"])
|
||||||
|
.output()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Err(e) = result {
|
||||||
|
return Err(format!("Failed to create nftables table: {e}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!("Failed to check if nftables table exists: {e}");
|
||||||
|
log::info!("Will try to create it anyway");
|
||||||
|
let result = Command::new("nft")
|
||||||
|
.args(["create", "table", "inet", "filter"])
|
||||||
|
.output()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Err(e) = result {
|
||||||
|
return Err(format!("Failed to create nftables table: {e}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create blacklist set if it doesn't exist
|
||||||
|
let output = Command::new("nft")
|
||||||
|
.args(["list", "set", "inet", "filter", "eris_blacklist"])
|
||||||
|
.output()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match output {
|
||||||
|
Ok(output) => {
|
||||||
|
if !output.status.success() {
|
||||||
|
log::info!("Creating eris_blacklist set");
|
||||||
|
let result = Command::new("nft")
|
||||||
|
.args([
|
||||||
|
"create",
|
||||||
|
"set",
|
||||||
|
"inet",
|
||||||
|
"filter",
|
||||||
|
"eris_blacklist",
|
||||||
|
"{ type ipv4_addr; flags interval; }",
|
||||||
|
])
|
||||||
|
.output()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Err(e) = result {
|
||||||
|
return Err(format!("Failed to create blacklist set: {e}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!("Failed to check if blacklist set exists: {e}");
|
||||||
|
return Err(format!("Failed to check if blacklist set exists: {e}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add rule to drop traffic from blacklisted IPs
|
||||||
|
let output = Command::new("nft")
|
||||||
|
.args(["list", "chain", "inet", "filter", "input"])
|
||||||
|
.output()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Check if our rule already exists
|
||||||
|
match output {
|
||||||
|
Ok(output) => {
|
||||||
|
let rule_exists = String::from_utf8_lossy(&output.stdout)
|
||||||
|
.contains("ip saddr @eris_blacklist counter drop");
|
||||||
|
|
||||||
|
if !rule_exists {
|
||||||
|
log::info!("Adding drop rule for blacklisted IPs");
|
||||||
|
let result = Command::new("nft")
|
||||||
|
.args([
|
||||||
|
"add",
|
||||||
|
"rule",
|
||||||
|
"inet",
|
||||||
|
"filter",
|
||||||
|
"input",
|
||||||
|
"ip saddr @eris_blacklist",
|
||||||
|
"counter",
|
||||||
|
"drop",
|
||||||
|
])
|
||||||
|
.output()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Err(e) = result {
|
||||||
|
return Err(format!("Failed to add firewall rule: {e}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!("Failed to check if firewall rule exists: {e}");
|
||||||
|
return Err(format!("Failed to check if firewall rule exists: {e}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log::info!("Firewall setup complete");
|
||||||
|
Ok(())
|
||||||
|
}
|
901
src/lua/mod.rs
Normal file
901
src/lua/mod.rs
Normal file
|
@ -0,0 +1,901 @@
|
||||||
|
use rlua::{Function, Lua, Table, Value};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::collections::hash_map::DefaultHasher;
|
||||||
|
use std::fs;
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
|
use std::path::Path;
|
||||||
|
use std::sync::{Arc, Mutex, RwLock};
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
// Event types for the Lua scripting system
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
|
pub enum EventType {
|
||||||
|
Connection, // when a new connection is established
|
||||||
|
Request, // when a request is received
|
||||||
|
ResponseGen, // when generating a response
|
||||||
|
ResponseChunk, // before sending each response chunk
|
||||||
|
Disconnection, // when a connection is closed
|
||||||
|
BlockIP, // when an IP is being considered for blocking
|
||||||
|
Startup, // when the application starts
|
||||||
|
Shutdown, // when the application is shutting down
|
||||||
|
Periodic, // called periodically (e.g., every minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EventType {
|
||||||
|
/// Convert event type to string representation for Lua
|
||||||
|
const fn as_str(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::Connection => "connection",
|
||||||
|
Self::Request => "request",
|
||||||
|
Self::ResponseGen => "response_gen",
|
||||||
|
Self::ResponseChunk => "response_chunk",
|
||||||
|
Self::Disconnection => "disconnection",
|
||||||
|
Self::BlockIP => "block_ip",
|
||||||
|
Self::Startup => "startup",
|
||||||
|
Self::Shutdown => "shutdown",
|
||||||
|
Self::Periodic => "periodic",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert from string to `EventType`
|
||||||
|
fn from_str(s: &str) -> Option<Self> {
|
||||||
|
match s {
|
||||||
|
"connection" => Some(Self::Connection),
|
||||||
|
"request" => Some(Self::Request),
|
||||||
|
"response_gen" => Some(Self::ResponseGen),
|
||||||
|
"response_chunk" => Some(Self::ResponseChunk),
|
||||||
|
"disconnection" => Some(Self::Disconnection),
|
||||||
|
"block_ip" => Some(Self::BlockIP),
|
||||||
|
"startup" => Some(Self::Startup),
|
||||||
|
"shutdown" => Some(Self::Shutdown),
|
||||||
|
"periodic" => Some(Self::Periodic),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loaded Lua script with its metadata
|
||||||
|
struct ScriptInfo {
|
||||||
|
name: String,
|
||||||
|
enabled: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Script state and manage the Lua environment
|
||||||
|
pub struct ScriptManager {
|
||||||
|
lua: Mutex<Lua>,
|
||||||
|
scripts: Vec<ScriptInfo>,
|
||||||
|
hooks: HashMap<EventType, Vec<String>>,
|
||||||
|
state: Arc<RwLock<HashMap<String, String>>>,
|
||||||
|
counters: Arc<RwLock<HashMap<String, i64>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context passed to Lua event handlers
|
||||||
|
pub struct EventContext {
|
||||||
|
pub event_type: EventType,
|
||||||
|
pub ip: Option<String>,
|
||||||
|
pub path: Option<String>,
|
||||||
|
pub user_agent: Option<String>,
|
||||||
|
pub request_headers: Option<HashMap<String, String>>,
|
||||||
|
pub content: Option<String>,
|
||||||
|
pub timestamp: u64,
|
||||||
|
pub session_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make ScriptManager explicitly Send + Sync since we're using Mutex
|
||||||
|
unsafe impl Send for ScriptManager {}
|
||||||
|
unsafe impl Sync for ScriptManager {}
|
||||||
|
|
||||||
|
impl ScriptManager {
|
||||||
|
/// Create a new script manager and load scripts from the given directory
|
||||||
|
pub fn new(scripts_dir: &str) -> Self {
|
||||||
|
let mut manager = Self {
|
||||||
|
lua: Mutex::new(Lua::new()),
|
||||||
|
scripts: Vec::new(),
|
||||||
|
hooks: HashMap::new(),
|
||||||
|
state: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
counters: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Initialize Lua environment with our API
|
||||||
|
manager.init_lua_env();
|
||||||
|
|
||||||
|
// Load scripts from directory
|
||||||
|
manager.load_scripts_from_dir(scripts_dir);
|
||||||
|
|
||||||
|
// If no scripts were loaded, use default script
|
||||||
|
if manager.scripts.is_empty() {
|
||||||
|
log::info!("No Lua scripts found, loading default scripts");
|
||||||
|
manager.load_script(
|
||||||
|
"default",
|
||||||
|
include_str!("../../resources/default_script.lua"),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trigger startup event
|
||||||
|
manager.trigger_event(&EventContext {
|
||||||
|
event_type: EventType::Startup,
|
||||||
|
ip: None,
|
||||||
|
path: None,
|
||||||
|
user_agent: None,
|
||||||
|
request_headers: None,
|
||||||
|
content: None,
|
||||||
|
timestamp: get_timestamp(),
|
||||||
|
session_id: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
manager
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the Lua environment
|
||||||
|
fn init_lua_env(&self) {
|
||||||
|
let state_clone = self.state.clone();
|
||||||
|
let counters_clone = self.counters.clone();
|
||||||
|
|
||||||
|
if let Ok(lua) = self.lua.lock() {
|
||||||
|
// Create eris global table for our API
|
||||||
|
let eris_table = lua.create_table().unwrap();
|
||||||
|
|
||||||
|
self.register_utility_functions(&lua, &eris_table, state_clone, counters_clone);
|
||||||
|
self.register_event_functions(&lua, &eris_table);
|
||||||
|
self.register_logging_functions(&lua, &eris_table);
|
||||||
|
|
||||||
|
// Set the eris global table
|
||||||
|
lua.globals().set("eris", eris_table).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register utility functions for scripts to use
|
||||||
|
fn register_utility_functions(
|
||||||
|
&self,
|
||||||
|
lua: &Lua,
|
||||||
|
eris_table: &Table,
|
||||||
|
state: Arc<RwLock<HashMap<String, String>>>,
|
||||||
|
counters: Arc<RwLock<HashMap<String, i64>>>,
|
||||||
|
) {
|
||||||
|
// Store a key-value pair in persistent state
|
||||||
|
let state_for_set = state.clone();
|
||||||
|
let set_state = lua
|
||||||
|
.create_function(move |_, (key, value): (String, String)| {
|
||||||
|
let mut state_map = state_for_set.write().unwrap();
|
||||||
|
state_map.insert(key, value);
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
eris_table.set("set_state", set_state).unwrap();
|
||||||
|
|
||||||
|
// Get a value from persistent state
|
||||||
|
let state_for_get = state;
|
||||||
|
let get_state = lua
|
||||||
|
.create_function(move |_, key: String| {
|
||||||
|
let state_map = state_for_get.read().unwrap();
|
||||||
|
let value = state_map.get(&key).cloned();
|
||||||
|
Ok(value)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
eris_table.set("get_state", get_state).unwrap();
|
||||||
|
|
||||||
|
// Increment a counter
|
||||||
|
let counters_for_inc = counters.clone();
|
||||||
|
let inc_counter = lua
|
||||||
|
.create_function(move |_, (key, amount): (String, Option<i64>)| {
|
||||||
|
let mut counters_map = counters_for_inc.write().unwrap();
|
||||||
|
let counter = counters_map.entry(key).or_insert(0);
|
||||||
|
*counter += amount.unwrap_or(1);
|
||||||
|
Ok(*counter)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
eris_table.set("inc_counter", inc_counter).unwrap();
|
||||||
|
|
||||||
|
// Get a counter value
|
||||||
|
let counters_for_get = counters;
|
||||||
|
let get_counter = lua
|
||||||
|
.create_function(move |_, key: String| {
|
||||||
|
let counters_map = counters_for_get.read().unwrap();
|
||||||
|
let value = counters_map.get(&key).copied().unwrap_or(0);
|
||||||
|
Ok(value)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
eris_table.set("get_counter", get_counter).unwrap();
|
||||||
|
|
||||||
|
// Generate a random token/string
|
||||||
|
let gen_token = lua
|
||||||
|
.create_function(move |_, prefix: Option<String>| {
|
||||||
|
let now = get_timestamp();
|
||||||
|
let random = rand::random::<u32>();
|
||||||
|
let token = format!("{}{:x}{:x}", prefix.unwrap_or_default(), now, random);
|
||||||
|
Ok(token)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
eris_table.set("gen_token", gen_token).unwrap();
|
||||||
|
|
||||||
|
// Get current timestamp
|
||||||
|
let timestamp = lua
|
||||||
|
.create_function(move |_, ()| Ok(get_timestamp()))
|
||||||
|
.unwrap();
|
||||||
|
eris_table.set("timestamp", timestamp).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register event handling functions
|
||||||
|
fn register_event_functions(&self, lua: &Lua, eris_table: &Table) {
|
||||||
|
// Create a table to store event handlers
|
||||||
|
let handlers_table = lua.create_table().unwrap();
|
||||||
|
eris_table.set("handlers", handlers_table).unwrap();
|
||||||
|
|
||||||
|
// Function for scripts to register event handlers
|
||||||
|
let on_fn = lua
|
||||||
|
.create_function(move |lua, (event_name, handler): (String, Function)| {
|
||||||
|
let globals = lua.globals();
|
||||||
|
let eris: Table = globals.get("eris").unwrap();
|
||||||
|
let handlers: Table = eris.get("handlers").unwrap();
|
||||||
|
|
||||||
|
// Get or create a table for this event type
|
||||||
|
let event_handlers: Table = if let Ok(table) = handlers.get(&*event_name) {
|
||||||
|
table
|
||||||
|
} else {
|
||||||
|
let new_table = lua.create_table().unwrap();
|
||||||
|
handlers.set(&*event_name, new_table.clone()).unwrap();
|
||||||
|
new_table
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add the handler to the table
|
||||||
|
let next_index = event_handlers.len().unwrap() + 1;
|
||||||
|
event_handlers.set(next_index, handler).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
eris_table.set("on", on_fn).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register logging functions
|
||||||
|
fn register_logging_functions(&self, lua: &Lua, eris_table: &Table) {
|
||||||
|
// Debug logging
|
||||||
|
let debug = lua
|
||||||
|
.create_function(|_, message: String| {
|
||||||
|
log::debug!("[Lua] {message}");
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
eris_table.set("debug", debug).unwrap();
|
||||||
|
|
||||||
|
// Info logging
|
||||||
|
let info = lua
|
||||||
|
.create_function(|_, message: String| {
|
||||||
|
log::info!("[Lua] {message}");
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
eris_table.set("info", info).unwrap();
|
||||||
|
|
||||||
|
// Warning logging
|
||||||
|
let warn = lua
|
||||||
|
.create_function(|_, message: String| {
|
||||||
|
log::warn!("[Lua] {message}");
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
eris_table.set("warn", warn).unwrap();
|
||||||
|
|
||||||
|
// Error logging
|
||||||
|
let error = lua
|
||||||
|
.create_function(|_, message: String| {
|
||||||
|
log::error!("[Lua] {message}");
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
eris_table.set("error", error).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load all scripts from a directory
|
||||||
|
fn load_scripts_from_dir(&mut self, scripts_dir: &str) {
|
||||||
|
let script_dir = Path::new(scripts_dir);
|
||||||
|
if !script_dir.exists() {
|
||||||
|
log::warn!("Lua scripts directory does not exist: {scripts_dir}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
log::debug!("Loading Lua scripts from directory: {scripts_dir}");
|
||||||
|
if let Ok(entries) = fs::read_dir(script_dir) {
|
||||||
|
// Sort entries by filename to ensure consistent loading order
|
||||||
|
let mut sorted_entries: Vec<_> = entries.filter_map(Result::ok).collect();
|
||||||
|
sorted_entries.sort_by_key(std::fs::DirEntry::path);
|
||||||
|
|
||||||
|
for entry in sorted_entries {
|
||||||
|
let path = entry.path();
|
||||||
|
if path.extension().and_then(|ext| ext.to_str()) == Some("lua") {
|
||||||
|
if let Ok(content) = fs::read_to_string(&path) {
|
||||||
|
let script_name = path
|
||||||
|
.file_stem()
|
||||||
|
.and_then(|n| n.to_str())
|
||||||
|
.unwrap_or("unknown")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
log::debug!("Loading Lua script: {} ({})", script_name, path.display());
|
||||||
|
self.load_script(&script_name, &content);
|
||||||
|
} else {
|
||||||
|
log::warn!("Failed to read Lua script: {}", path.display());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load a single script and register its event handlers
|
||||||
|
fn load_script(&mut self, name: &str, content: &str) {
|
||||||
|
// Store script info
|
||||||
|
self.scripts.push(ScriptInfo {
|
||||||
|
name: name.to_string(),
|
||||||
|
enabled: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Execute the script to register its event handlers
|
||||||
|
if let Ok(lua) = self.lua.lock() {
|
||||||
|
if let Err(e) = lua.load(content).set_name(name).exec() {
|
||||||
|
log::warn!("Error loading Lua script '{name}': {e}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect registered event handlers
|
||||||
|
let globals = lua.globals();
|
||||||
|
let eris: Table = match globals.get("eris") {
|
||||||
|
Ok(table) => table,
|
||||||
|
Err(_) => return,
|
||||||
|
};
|
||||||
|
|
||||||
|
let handlers: Table = match eris.get("handlers") {
|
||||||
|
Ok(table) => table,
|
||||||
|
Err(_) => return,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Store the event handlers in our hooks map
|
||||||
|
let mut tmp: rlua::TablePairs<'_, String, Table<'_>> =
|
||||||
|
handlers.pairs::<String, Table>();
|
||||||
|
'l: loop {
|
||||||
|
if let Some(event_pair) = tmp.next() {
|
||||||
|
if let Ok((event_name, _)) = event_pair {
|
||||||
|
if let Some(event_type) = EventType::from_str(&event_name) {
|
||||||
|
self.hooks
|
||||||
|
.entry(event_type)
|
||||||
|
.or_default()
|
||||||
|
.push(name.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break 'l;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log::info!("Loaded Lua script '{name}' successfully");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a script is enabled
|
||||||
|
fn is_script_enabled(&self, name: &str) -> bool {
|
||||||
|
self.scripts
|
||||||
|
.iter()
|
||||||
|
.find(|s| s.name == name)
|
||||||
|
.is_some_and(|s| s.enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Trigger an event, calling all registered handlers
|
||||||
|
pub fn trigger_event(&self, ctx: &EventContext) -> Option<String> {
|
||||||
|
// Check if we have any handlers for this event
|
||||||
|
if !self.hooks.contains_key(&ctx.event_type) {
|
||||||
|
return ctx.content.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the event data table to pass to Lua handlers
|
||||||
|
let mut result = ctx.content.clone();
|
||||||
|
|
||||||
|
if let Ok(lua) = self.lua.lock() {
|
||||||
|
// Create the event context table
|
||||||
|
let event_ctx = lua.create_table().unwrap();
|
||||||
|
|
||||||
|
// Add all the context fields
|
||||||
|
event_ctx.set("event", ctx.event_type.as_str()).unwrap();
|
||||||
|
if let Some(ip) = &ctx.ip {
|
||||||
|
event_ctx.set("ip", ip.clone()).unwrap();
|
||||||
|
}
|
||||||
|
if let Some(path) = &ctx.path {
|
||||||
|
event_ctx.set("path", path.clone()).unwrap();
|
||||||
|
}
|
||||||
|
if let Some(ua) = &ctx.user_agent {
|
||||||
|
event_ctx.set("user_agent", ua.clone()).unwrap();
|
||||||
|
}
|
||||||
|
event_ctx.set("timestamp", ctx.timestamp).unwrap();
|
||||||
|
if let Some(sid) = &ctx.session_id {
|
||||||
|
event_ctx.set("session_id", sid.clone()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add request headers if available
|
||||||
|
if let Some(headers) = &ctx.request_headers {
|
||||||
|
let headers_table = lua.create_table().unwrap();
|
||||||
|
for (key, value) in headers {
|
||||||
|
headers_table
|
||||||
|
.set(key.to_string(), value.to_string())
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
event_ctx.set("headers", headers_table).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add content if available
|
||||||
|
if let Some(content) = &ctx.content {
|
||||||
|
event_ctx.set("content", content.clone()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call all registered handlers for this event
|
||||||
|
if let Some(handler_scripts) = self.hooks.get(&ctx.event_type) {
|
||||||
|
for script_name in handler_scripts {
|
||||||
|
// Skip disabled scripts
|
||||||
|
if !self.is_script_enabled(script_name) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the globals and handlers table
|
||||||
|
let globals = lua.globals();
|
||||||
|
let eris: Table = match globals.get("eris") {
|
||||||
|
Ok(table) => table,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
let handlers: Table = match eris.get("handlers") {
|
||||||
|
Ok(table) => table,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get handlers for this event
|
||||||
|
let event_handlers: Table = match handlers.get(ctx.event_type.as_str()) {
|
||||||
|
Ok(table) => table,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Call each handler
|
||||||
|
for pair in event_handlers.pairs::<i64, Function>() {
|
||||||
|
if let Ok((_, handler)) = pair {
|
||||||
|
let handler_result: rlua::Result<Option<String>> =
|
||||||
|
handler.call((event_ctx.clone(),));
|
||||||
|
if let Ok(Some(new_content)) = handler_result {
|
||||||
|
// For response events, allow handlers to modify the content
|
||||||
|
if matches!(
|
||||||
|
ctx.event_type,
|
||||||
|
EventType::ResponseGen | EventType::ResponseChunk
|
||||||
|
) {
|
||||||
|
result = Some(new_content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate a deceptive response, calling all `response_gen` handlers
|
||||||
|
pub fn generate_response(
|
||||||
|
&self,
|
||||||
|
path: &str,
|
||||||
|
user_agent: &str,
|
||||||
|
ip: &str,
|
||||||
|
headers: &HashMap<String, String>,
|
||||||
|
markov_text: &str,
|
||||||
|
) -> String {
|
||||||
|
// Create event context
|
||||||
|
let ctx = EventContext {
|
||||||
|
event_type: EventType::ResponseGen,
|
||||||
|
ip: Some(ip.to_string()),
|
||||||
|
path: Some(path.to_string()),
|
||||||
|
user_agent: Some(user_agent.to_string()),
|
||||||
|
request_headers: Some(headers.clone()),
|
||||||
|
content: Some(markov_text.to_string()),
|
||||||
|
timestamp: get_timestamp(),
|
||||||
|
session_id: Some(generate_session_id(ip, user_agent)),
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Trigger the event and get the modified content
|
||||||
|
self.trigger_event(&ctx).unwrap_or_else(|| {
|
||||||
|
// Fallback to maintain backward compatibility
|
||||||
|
self.expand_response(
|
||||||
|
markov_text,
|
||||||
|
"generic",
|
||||||
|
path,
|
||||||
|
&generate_session_id(ip, user_agent),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process a chunk before sending it to client
|
||||||
|
pub fn process_chunk(&self, chunk: &str, ip: &str, session_id: &str) -> String {
|
||||||
|
let ctx = EventContext {
|
||||||
|
event_type: EventType::ResponseChunk,
|
||||||
|
ip: Some(ip.to_string()),
|
||||||
|
path: None,
|
||||||
|
user_agent: None,
|
||||||
|
request_headers: None,
|
||||||
|
content: Some(chunk.to_string()),
|
||||||
|
timestamp: get_timestamp(),
|
||||||
|
session_id: Some(session_id.to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
self.trigger_event(&ctx)
|
||||||
|
.unwrap_or_else(|| chunk.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Called when a connection is established
|
||||||
|
pub fn on_connection(&self, ip: &str) -> bool {
|
||||||
|
let ctx = EventContext {
|
||||||
|
event_type: EventType::Connection,
|
||||||
|
ip: Some(ip.to_string()),
|
||||||
|
path: None,
|
||||||
|
user_agent: None,
|
||||||
|
request_headers: None,
|
||||||
|
content: None,
|
||||||
|
timestamp: get_timestamp(),
|
||||||
|
session_id: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// If any handler returns false, reject the connection
|
||||||
|
let mut should_accept = true;
|
||||||
|
|
||||||
|
if let Ok(lua) = self.lua.lock() {
|
||||||
|
if let Some(handler_scripts) = self.hooks.get(&EventType::Connection) {
|
||||||
|
for script_name in handler_scripts {
|
||||||
|
// Skip disabled scripts
|
||||||
|
if !self.is_script_enabled(script_name) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let globals = lua.globals();
|
||||||
|
let eris: Table = match globals.get("eris") {
|
||||||
|
Ok(table) => table,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
let handlers: Table = match eris.get("handlers") {
|
||||||
|
Ok(table) => table,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
let event_handlers: Table = match handlers.get("connection") {
|
||||||
|
Ok(table) => table,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
for pair in event_handlers.pairs::<i64, Function>() {
|
||||||
|
if let Ok((_, handler)) = pair {
|
||||||
|
let event_ctx = create_event_context(&lua, &ctx);
|
||||||
|
if let Ok(result) = handler.call::<_, Value>((event_ctx,)) {
|
||||||
|
if result == Value::Boolean(false) {
|
||||||
|
should_accept = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !should_accept {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
should_accept
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Called when deciding whether to block an IP
|
||||||
|
pub fn should_block_ip(&self, ip: &str, hit_count: u32) -> bool {
|
||||||
|
let ctx = EventContext {
|
||||||
|
event_type: EventType::BlockIP,
|
||||||
|
ip: Some(ip.to_string()),
|
||||||
|
path: None,
|
||||||
|
user_agent: None,
|
||||||
|
request_headers: None,
|
||||||
|
content: None,
|
||||||
|
timestamp: get_timestamp(),
|
||||||
|
session_id: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// We should default to not modifying the blocking decision
|
||||||
|
let mut should_block = None;
|
||||||
|
|
||||||
|
if let Ok(lua) = self.lua.lock() {
|
||||||
|
if let Some(handler_scripts) = self.hooks.get(&EventType::BlockIP) {
|
||||||
|
for script_name in handler_scripts {
|
||||||
|
// Skip disabled scripts
|
||||||
|
if !self.is_script_enabled(script_name) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let globals = lua.globals();
|
||||||
|
let eris: Table = match globals.get("eris") {
|
||||||
|
Ok(table) => table,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
let handlers: Table = match eris.get("handlers") {
|
||||||
|
Ok(table) => table,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
let event_handlers: Table = match handlers.get("block_ip") {
|
||||||
|
Ok(table) => table,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
for pair in event_handlers.pairs::<i64, Function>() {
|
||||||
|
if let Ok((_, handler)) = pair {
|
||||||
|
let event_ctx = create_event_context(&lua, &ctx);
|
||||||
|
// Add hit count for the block_ip event
|
||||||
|
event_ctx.set("hit_count", hit_count).unwrap();
|
||||||
|
|
||||||
|
if let Ok(result) = handler.call::<_, Value>((event_ctx,)) {
|
||||||
|
if let Value::Boolean(block) = result {
|
||||||
|
should_block = Some(block);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if should_block.is_some() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the script's decision, or default to the system behavior
|
||||||
|
should_block.unwrap_or(hit_count >= 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Maintains backward compatibility with the old API
|
||||||
|
// XXX: I never liked expand_response, should probably be removeedf
|
||||||
|
// in the future.
|
||||||
|
pub fn expand_response(
|
||||||
|
&self,
|
||||||
|
text: &str,
|
||||||
|
response_type: &str,
|
||||||
|
path: &str,
|
||||||
|
token: &str,
|
||||||
|
) -> String {
|
||||||
|
if let Ok(lua) = self.lua.lock() {
|
||||||
|
let globals = lua.globals();
|
||||||
|
match globals.get::<_, Function>("enhance_response") {
|
||||||
|
Ok(enhance_func) => {
|
||||||
|
match enhance_func.call::<_, String>((text, response_type, path, token)) {
|
||||||
|
Ok(result) => result,
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!("Error calling Lua function enhance_response: {e}");
|
||||||
|
format!("{text}\n<!-- Error calling Lua enhance_response -->")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => format!("{text}\n<!-- Token: {token} -->"),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
format!("{text}\n<!-- Token: {token} -->")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get current timestamp in seconds
|
||||||
|
fn get_timestamp() -> u64 {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a unique session ID for tracking a connection
|
||||||
|
fn generate_session_id(ip: &str, user_agent: &str) -> String {
|
||||||
|
let timestamp = get_timestamp();
|
||||||
|
let random = rand::random::<u32>();
|
||||||
|
|
||||||
|
// Use std::hash instead of xxhash_rust
|
||||||
|
let mut hasher = DefaultHasher::new();
|
||||||
|
format!("{ip}_{user_agent}_{timestamp}").hash(&mut hasher);
|
||||||
|
let hash = hasher.finish();
|
||||||
|
|
||||||
|
format!("SID_{hash:x}_{random:x}")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an event context table in Lua
|
||||||
|
fn create_event_context<'a>(lua: &'a Lua, event_ctx: &EventContext) -> Table<'a> {
|
||||||
|
let table = lua.create_table().unwrap();
|
||||||
|
|
||||||
|
table.set("event", event_ctx.event_type.as_str()).unwrap();
|
||||||
|
if let Some(ip) = &event_ctx.ip {
|
||||||
|
table.set("ip", ip.clone()).unwrap();
|
||||||
|
}
|
||||||
|
if let Some(path) = &event_ctx.path {
|
||||||
|
table.set("path", path.clone()).unwrap();
|
||||||
|
}
|
||||||
|
if let Some(ua) = &event_ctx.user_agent {
|
||||||
|
table.set("user_agent", ua.clone()).unwrap();
|
||||||
|
}
|
||||||
|
table.set("timestamp", event_ctx.timestamp).unwrap();
|
||||||
|
if let Some(sid) = &event_ctx.session_id {
|
||||||
|
table.set("session_id", sid.clone()).unwrap();
|
||||||
|
}
|
||||||
|
if let Some(content) = &event_ctx.content {
|
||||||
|
table.set("content", content.clone()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
table
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::fs;
|
||||||
|
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_event_registration() {
|
||||||
|
let temp_dir = TempDir::new().unwrap();
|
||||||
|
let script_path = temp_dir.path().join("test_events.lua");
|
||||||
|
let script_content = r#"
|
||||||
|
-- Example script with event handlers
|
||||||
|
eris.info("Registering event handlers")
|
||||||
|
|
||||||
|
-- Connection event handler
|
||||||
|
eris.on("connection", function(ctx)
|
||||||
|
eris.debug("Connection from " .. ctx.ip)
|
||||||
|
return true -- accept the connection
|
||||||
|
end)
|
||||||
|
|
||||||
|
-- Response generation handler
|
||||||
|
eris.on("response_gen", function(ctx)
|
||||||
|
eris.debug("Generating response for " .. ctx.path)
|
||||||
|
return ctx.content .. "<!-- Enhanced by Lua -->"
|
||||||
|
end)
|
||||||
|
"#;
|
||||||
|
|
||||||
|
fs::write(&script_path, script_content).unwrap();
|
||||||
|
|
||||||
|
let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap());
|
||||||
|
|
||||||
|
// Verify hooks were registered
|
||||||
|
assert!(script_manager.hooks.contains_key(&EventType::Connection));
|
||||||
|
assert!(script_manager.hooks.contains_key(&EventType::ResponseGen));
|
||||||
|
assert!(!script_manager.hooks.contains_key(&EventType::BlockIP));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_generate_response() {
|
||||||
|
let temp_dir = TempDir::new().unwrap();
|
||||||
|
let script_path = temp_dir.path().join("response_test.lua");
|
||||||
|
let script_content = r#"
|
||||||
|
eris.on("response_gen", function(ctx)
|
||||||
|
return ctx.content .. " - Modified by " .. ctx.user_agent
|
||||||
|
end)
|
||||||
|
"#;
|
||||||
|
|
||||||
|
fs::write(&script_path, script_content).unwrap();
|
||||||
|
|
||||||
|
let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap());
|
||||||
|
|
||||||
|
let headers = HashMap::new();
|
||||||
|
let result = script_manager.generate_response(
|
||||||
|
"/test/path",
|
||||||
|
"TestBot",
|
||||||
|
"127.0.0.1",
|
||||||
|
&headers,
|
||||||
|
"Original content",
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(result.contains("Original content"));
|
||||||
|
assert!(result.contains("Modified by TestBot"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_process_chunk() {
|
||||||
|
let temp_dir = TempDir::new().unwrap();
|
||||||
|
let script_path = temp_dir.path().join("chunk_test.lua");
|
||||||
|
let script_content = r#"
|
||||||
|
eris.on("response_chunk", function(ctx)
|
||||||
|
return ctx.content:gsub("secret", "REDACTED")
|
||||||
|
end)
|
||||||
|
"#;
|
||||||
|
|
||||||
|
fs::write(&script_path, script_content).unwrap();
|
||||||
|
|
||||||
|
let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap());
|
||||||
|
|
||||||
|
let result = script_manager.process_chunk(
|
||||||
|
"This contains a secret password",
|
||||||
|
"127.0.0.1",
|
||||||
|
"test_session",
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(result.contains("This contains a REDACTED password"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_should_block_ip() {
|
||||||
|
let temp_dir = TempDir::new().unwrap();
|
||||||
|
let script_path = temp_dir.path().join("block_test.lua");
|
||||||
|
let script_content = r#"
|
||||||
|
eris.on("block_ip", function(ctx)
|
||||||
|
-- Block any IP with "192.168.1" prefix regardless of hit count
|
||||||
|
if string.match(ctx.ip, "^192%.168%.1%.") then
|
||||||
|
return true
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Don't block IPs with "10.0" prefix even if they hit the threshold
|
||||||
|
if string.match(ctx.ip, "^10%.0%.") then
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Default behavior for other IPs (nil = use system default)
|
||||||
|
return nil
|
||||||
|
end)
|
||||||
|
"#;
|
||||||
|
|
||||||
|
fs::write(&script_path, script_content).unwrap();
|
||||||
|
|
||||||
|
let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap());
|
||||||
|
|
||||||
|
// Should be blocked based on IP pattern
|
||||||
|
assert!(script_manager.should_block_ip("192.168.1.50", 1));
|
||||||
|
|
||||||
|
// Should not be blocked despite high hit count
|
||||||
|
assert!(!script_manager.should_block_ip("10.0.0.5", 10));
|
||||||
|
|
||||||
|
// Should use default behavior (block if >= 3 hits)
|
||||||
|
assert!(!script_manager.should_block_ip("172.16.0.1", 2));
|
||||||
|
assert!(script_manager.should_block_ip("172.16.0.1", 3));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_state_and_counters() {
|
||||||
|
let temp_dir = TempDir::new().unwrap();
|
||||||
|
let script_path = temp_dir.path().join("state_test.lua");
|
||||||
|
let script_content = r#"
|
||||||
|
eris.on("startup", function(ctx)
|
||||||
|
eris.set_state("test_key", "test_value")
|
||||||
|
eris.inc_counter("visits", 0)
|
||||||
|
end)
|
||||||
|
|
||||||
|
eris.on("connection", function(ctx)
|
||||||
|
local count = eris.inc_counter("visits")
|
||||||
|
eris.debug("Visit count: " .. count)
|
||||||
|
|
||||||
|
-- Store last visitor
|
||||||
|
eris.set_state("last_visitor", ctx.ip)
|
||||||
|
return true
|
||||||
|
end)
|
||||||
|
|
||||||
|
eris.on("response_gen", function(ctx)
|
||||||
|
local last_visitor = eris.get_state("last_visitor") or "unknown"
|
||||||
|
local visits = eris.get_counter("visits")
|
||||||
|
return ctx.content .. "<!-- Last visitor: " .. last_visitor ..
|
||||||
|
", Total visits: " .. visits .. " -->"
|
||||||
|
end)
|
||||||
|
"#;
|
||||||
|
|
||||||
|
fs::write(&script_path, script_content).unwrap();
|
||||||
|
|
||||||
|
let script_manager = ScriptManager::new(temp_dir.path().to_str().unwrap());
|
||||||
|
|
||||||
|
// Simulate connections
|
||||||
|
script_manager.on_connection("192.168.1.100");
|
||||||
|
script_manager.on_connection("10.0.0.50");
|
||||||
|
|
||||||
|
// Check response includes state
|
||||||
|
let headers = HashMap::new();
|
||||||
|
let result = script_manager.generate_response(
|
||||||
|
"/test",
|
||||||
|
"test-agent",
|
||||||
|
"8.8.8.8",
|
||||||
|
&headers,
|
||||||
|
"Response",
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(result.contains("Last visitor: 10.0.0.50"));
|
||||||
|
assert!(result.contains("Total visits: 2"));
|
||||||
|
}
|
||||||
|
}
|
1365
src/main.rs
1365
src/main.rs
File diff suppressed because it is too large
Load diff
|
@ -103,7 +103,7 @@ impl MarkovGenerator {
|
||||||
let path = Path::new(corpus_dir);
|
let path = Path::new(corpus_dir);
|
||||||
if path.exists() && path.is_dir() {
|
if path.exists() && path.is_dir() {
|
||||||
if let Ok(entries) = fs::read_dir(path) {
|
if let Ok(entries) = fs::read_dir(path) {
|
||||||
for entry in entries {
|
entries.for_each(|entry| {
|
||||||
if let Ok(entry) = entry {
|
if let Ok(entry) = entry {
|
||||||
let file_path = entry.path();
|
let file_path = entry.path();
|
||||||
if let Some(file_name) = file_path.file_stem() {
|
if let Some(file_name) = file_path.file_stem() {
|
||||||
|
@ -120,7 +120,7 @@ impl MarkovGenerator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,11 @@ lazy_static! {
|
||||||
register_counter_vec!("eris_path_hits_total", "Hits by path", &["path"]).unwrap();
|
register_counter_vec!("eris_path_hits_total", "Hits by path", &["path"]).unwrap();
|
||||||
pub static ref UA_HITS: CounterVec =
|
pub static ref UA_HITS: CounterVec =
|
||||||
register_counter_vec!("eris_ua_hits_total", "Hits by user agent", &["user_agent"]).unwrap();
|
register_counter_vec!("eris_ua_hits_total", "Hits by user agent", &["user_agent"]).unwrap();
|
||||||
|
pub static ref RATE_LIMITED_CONNECTIONS: Counter = register_counter!(
|
||||||
|
"eris_rate_limited_total",
|
||||||
|
"Number of connections rejected due to rate limiting"
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prometheus metrics endpoint
|
// Prometheus metrics endpoint
|
||||||
|
@ -76,6 +81,7 @@ mod tests {
|
||||||
UA_HITS.with_label_values(&["TestBot/1.0"]).inc();
|
UA_HITS.with_label_values(&["TestBot/1.0"]).inc();
|
||||||
BLOCKED_IPS.set(5.0);
|
BLOCKED_IPS.set(5.0);
|
||||||
ACTIVE_CONNECTIONS.set(3.0);
|
ACTIVE_CONNECTIONS.set(3.0);
|
||||||
|
RATE_LIMITED_CONNECTIONS.inc();
|
||||||
|
|
||||||
// Create test app
|
// Create test app
|
||||||
let app =
|
let app =
|
||||||
|
@ -96,6 +102,7 @@ mod tests {
|
||||||
assert!(body_str.contains("eris_ua_hits_total"));
|
assert!(body_str.contains("eris_ua_hits_total"));
|
||||||
assert!(body_str.contains("eris_blocked_ips"));
|
assert!(body_str.contains("eris_blocked_ips"));
|
||||||
assert!(body_str.contains("eris_active_connections"));
|
assert!(body_str.contains("eris_active_connections"));
|
||||||
|
assert!(body_str.contains("eris_rate_limited_total"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_web::test]
|
#[actix_web::test]
|
||||||
|
|
504
src/network/mod.rs
Normal file
504
src/network/mod.rs
Normal file
|
@ -0,0 +1,504 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::process::Command;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tokio::time::sleep;
|
||||||
|
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::lua::{EventContext, EventType, ScriptManager};
|
||||||
|
use crate::markov::MarkovGenerator;
|
||||||
|
use crate::metrics::{
|
||||||
|
ACTIVE_CONNECTIONS, BLOCKED_IPS, HITS_COUNTER, PATH_HITS, RATE_LIMITED_CONNECTIONS, UA_HITS,
|
||||||
|
};
|
||||||
|
use crate::state::BotState;
|
||||||
|
use crate::utils::{
|
||||||
|
choose_response_type, extract_all_headers, extract_header_value, extract_path_from_request,
|
||||||
|
find_header_end, generate_session_id, get_timestamp,
|
||||||
|
};
|
||||||
|
|
||||||
|
mod rate_limiter;
|
||||||
|
use rate_limiter::RateLimiter;
|
||||||
|
|
||||||
|
// Global rate limiter instance.
|
||||||
|
// Default is 30 connections per IP in a 60 second window
|
||||||
|
// XXX: This might add overhead of the proxy, e.g. NGINX already implements
|
||||||
|
// rate limiting. Though I don't think we have a way of knowing if the middleman
|
||||||
|
// we are handing the connections to (from the same middleman in some cases) has
|
||||||
|
// rate limiting.
|
||||||
|
lazy_static::lazy_static! {
|
||||||
|
static ref RATE_LIMITER: RateLimiter = RateLimiter::new(60, 30);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main connection handler.
|
||||||
|
// Decides whether to tarpit or proxy
|
||||||
|
pub async fn handle_connection(
|
||||||
|
mut stream: TcpStream,
|
||||||
|
config: Arc<Config>,
|
||||||
|
state: Arc<RwLock<BotState>>,
|
||||||
|
markov_generator: Arc<MarkovGenerator>,
|
||||||
|
script_manager: Arc<ScriptManager>,
|
||||||
|
) {
|
||||||
|
// Get peer information
|
||||||
|
let peer_addr: IpAddr = match stream.peer_addr() {
|
||||||
|
Ok(addr) => addr.ip(),
|
||||||
|
Err(e) => {
|
||||||
|
log::debug!("Failed to get peer address: {e}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check for blocked IPs to avoid any processing
|
||||||
|
if state.read().await.blocked.contains(&peer_addr) {
|
||||||
|
log::debug!("Rejected connection from blocked IP: {peer_addr}");
|
||||||
|
let _ = stream.shutdown().await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply rate limiting before any further processing
|
||||||
|
if config.rate_limit_enabled && !RATE_LIMITER.check_rate_limit(peer_addr).await {
|
||||||
|
log::info!("Rate limited connection from {peer_addr}");
|
||||||
|
RATE_LIMITED_CONNECTIONS.inc();
|
||||||
|
|
||||||
|
// Optionally, add the IP to a temporary block list
|
||||||
|
// if it's constantly hitting the rate limit
|
||||||
|
let connection_count = RATE_LIMITER.get_connection_count(&peer_addr);
|
||||||
|
if connection_count > config.rate_limit_block_threshold {
|
||||||
|
log::warn!(
|
||||||
|
"IP {peer_addr} exceeding rate limit with {connection_count} connection attempts, considering for blocking"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Trigger a blocked event for Lua scripts
|
||||||
|
let rate_limit_ctx = EventContext {
|
||||||
|
event_type: EventType::BlockIP,
|
||||||
|
ip: Some(peer_addr.to_string()),
|
||||||
|
path: None,
|
||||||
|
user_agent: None,
|
||||||
|
request_headers: None,
|
||||||
|
content: None,
|
||||||
|
timestamp: get_timestamp(),
|
||||||
|
session_id: None,
|
||||||
|
};
|
||||||
|
script_manager.trigger_event(&rate_limit_ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Either send a slow response or just close connection
|
||||||
|
if config.rate_limit_slow_response {
|
||||||
|
// Send a simple 429 Too Many Requests respons. If the bots actually respected
|
||||||
|
// HTTP error codes, the internet would be a mildly better place.
|
||||||
|
let response = "HTTP/1.1 429 Too Many Requests\r\n\
|
||||||
|
Content-Type: text/plain\r\n\
|
||||||
|
Retry-After: 60\r\n\
|
||||||
|
Connection: close\r\n\
|
||||||
|
\r\n\
|
||||||
|
Rate limit exceeded. Please try again later.";
|
||||||
|
|
||||||
|
let _ = stream.write_all(response.as_bytes()).await;
|
||||||
|
let _ = stream.flush().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = stream.shutdown().await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if Lua scripts allow this connection
|
||||||
|
if !script_manager.on_connection(&peer_addr.to_string()) {
|
||||||
|
log::debug!("Connection rejected by Lua script: {peer_addr}");
|
||||||
|
let _ = stream.shutdown().await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-check for whitelisted IPs to bypass heavy processing
|
||||||
|
let mut whitelisted = false;
|
||||||
|
for network_str in &config.whitelist_networks {
|
||||||
|
if let Ok(network) = network_str.parse::<ipnetwork::IpNetwork>() {
|
||||||
|
if network.contains(peer_addr) {
|
||||||
|
whitelisted = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read buffer
|
||||||
|
let mut buffer = vec![0; 8192];
|
||||||
|
let mut request_data = Vec::with_capacity(8192);
|
||||||
|
let mut header_end_pos = 0;
|
||||||
|
|
||||||
|
// Read with timeout to prevent hanging resource load ops.
|
||||||
|
let read_fut = async {
|
||||||
|
loop {
|
||||||
|
match stream.read(&mut buffer).await {
|
||||||
|
Ok(0) => break,
|
||||||
|
Ok(n) => {
|
||||||
|
let new_data = &buffer[..n];
|
||||||
|
request_data.extend_from_slice(new_data);
|
||||||
|
|
||||||
|
// Look for end of headers
|
||||||
|
if header_end_pos == 0 {
|
||||||
|
if let Some(pos) = find_header_end(&request_data) {
|
||||||
|
header_end_pos = pos;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Avoid excessive buffering
|
||||||
|
if request_data.len() > 32768 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::debug!("Error reading from stream: {e}");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let timeout_fut = sleep(Duration::from_secs(3));
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
() = read_fut => {},
|
||||||
|
() = timeout_fut => {
|
||||||
|
log::debug!("Connection timeout from: {peer_addr}");
|
||||||
|
let _ = stream.shutdown().await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fast path for whitelisted IPs. Skip full parsing and speed up "approved"
|
||||||
|
// connections automatically.
|
||||||
|
if whitelisted {
|
||||||
|
log::debug!("Whitelisted IP {peer_addr} - using fast proxy path");
|
||||||
|
proxy_fast_path(stream, request_data, &config.backend_addr).await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse minimally to extract the path
|
||||||
|
let path = if let Some(p) = extract_path_from_request(&request_data) {
|
||||||
|
p
|
||||||
|
} else {
|
||||||
|
log::debug!("Invalid request from {peer_addr}");
|
||||||
|
let _ = stream.shutdown().await;
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract request headers for Lua scripts
|
||||||
|
let headers = extract_all_headers(&request_data);
|
||||||
|
|
||||||
|
// Extract user agent for logging and decision making
|
||||||
|
let user_agent =
|
||||||
|
extract_header_value(&request_data, "user-agent").unwrap_or_else(|| "unknown".to_string());
|
||||||
|
|
||||||
|
// Trigger request event for Lua scripts
|
||||||
|
let request_ctx = EventContext {
|
||||||
|
event_type: EventType::Request,
|
||||||
|
ip: Some(peer_addr.to_string()),
|
||||||
|
path: Some(path.to_string()),
|
||||||
|
user_agent: Some(user_agent.clone()),
|
||||||
|
request_headers: Some(headers.clone()),
|
||||||
|
content: None,
|
||||||
|
timestamp: get_timestamp(),
|
||||||
|
session_id: Some(generate_session_id(&peer_addr.to_string(), &user_agent)),
|
||||||
|
};
|
||||||
|
script_manager.trigger_event(&request_ctx);
|
||||||
|
|
||||||
|
// Check if this request matches our tarpit patterns
|
||||||
|
let should_tarpit = crate::config::should_tarpit(path, &peer_addr, &config);
|
||||||
|
|
||||||
|
if should_tarpit {
|
||||||
|
log::info!("Tarpit triggered: {path} from {peer_addr} (UA: {user_agent})");
|
||||||
|
|
||||||
|
// Update metrics
|
||||||
|
HITS_COUNTER.inc();
|
||||||
|
PATH_HITS.with_label_values(&[path]).inc();
|
||||||
|
UA_HITS.with_label_values(&[&user_agent]).inc();
|
||||||
|
|
||||||
|
// Update state and check for blocking threshold
|
||||||
|
{
|
||||||
|
let mut state = state.write().await;
|
||||||
|
state.active_connections.insert(peer_addr);
|
||||||
|
ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64);
|
||||||
|
|
||||||
|
*state.hits.entry(peer_addr).or_insert(0) += 1;
|
||||||
|
let hit_count = state.hits[&peer_addr];
|
||||||
|
|
||||||
|
// Use Lua to decide whether to block this IP
|
||||||
|
let should_block = script_manager.should_block_ip(&peer_addr.to_string(), hit_count);
|
||||||
|
|
||||||
|
// Block IPs that hit tarpits too many times
|
||||||
|
if should_block && !state.blocked.contains(&peer_addr) {
|
||||||
|
log::info!("Blocking IP {peer_addr} after {hit_count} hits");
|
||||||
|
state.blocked.insert(peer_addr);
|
||||||
|
BLOCKED_IPS.set(state.blocked.len() as f64);
|
||||||
|
state.save_to_disk();
|
||||||
|
|
||||||
|
// Do firewall blocking in background
|
||||||
|
let peer_addr_str = peer_addr.to_string();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
log::debug!("Adding IP {peer_addr_str} to firewall blacklist");
|
||||||
|
match Command::new("nft")
|
||||||
|
.args([
|
||||||
|
"add",
|
||||||
|
"element",
|
||||||
|
"inet",
|
||||||
|
"filter",
|
||||||
|
"eris_blacklist",
|
||||||
|
"{",
|
||||||
|
&peer_addr_str,
|
||||||
|
"}",
|
||||||
|
])
|
||||||
|
.output()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(output) => {
|
||||||
|
if !output.status.success() {
|
||||||
|
log::warn!(
|
||||||
|
"Failed to add IP {} to firewall: {}",
|
||||||
|
peer_addr_str,
|
||||||
|
String::from_utf8_lossy(&output.stderr)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!("Failed to execute nft command: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a deceptive response using Markov chains and Lua
|
||||||
|
let response = generate_deceptive_response(
|
||||||
|
path,
|
||||||
|
&user_agent,
|
||||||
|
&peer_addr,
|
||||||
|
&headers,
|
||||||
|
&markov_generator,
|
||||||
|
&script_manager,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Generate a session ID for tracking this tarpit session
|
||||||
|
let session_id = generate_session_id(&peer_addr.to_string(), &user_agent);
|
||||||
|
|
||||||
|
// Send the response with the tarpit delay strategy
|
||||||
|
{
|
||||||
|
let mut stream = stream;
|
||||||
|
let peer_addr = peer_addr;
|
||||||
|
let state = state.clone();
|
||||||
|
let min_delay = config.min_delay;
|
||||||
|
let max_delay = config.max_delay;
|
||||||
|
let max_tarpit_time = config.max_tarpit_time;
|
||||||
|
let script_manager = script_manager.clone();
|
||||||
|
async move {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
let mut chars = response.chars().collect::<Vec<_>>();
|
||||||
|
for i in (0..chars.len()).rev() {
|
||||||
|
if i > 0 && rand::random::<f32>() < 0.1 {
|
||||||
|
chars.swap(i, i - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log::debug!(
|
||||||
|
"Starting tarpit for {} with {} chars, min_delay={}ms, max_delay={}ms",
|
||||||
|
peer_addr,
|
||||||
|
chars.len(),
|
||||||
|
min_delay,
|
||||||
|
max_delay
|
||||||
|
);
|
||||||
|
let mut position = 0;
|
||||||
|
let mut chunks_sent = 0;
|
||||||
|
let mut total_delay = 0;
|
||||||
|
while position < chars.len() {
|
||||||
|
// Check if we've exceeded maximum tarpit time
|
||||||
|
let elapsed_secs = start_time.elapsed().as_secs();
|
||||||
|
if elapsed_secs > max_tarpit_time {
|
||||||
|
log::info!(
|
||||||
|
"Tarpit maximum time ({max_tarpit_time} sec) reached for {peer_addr}"
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decide how many chars to send in this chunk (usually 1, sometimes more)
|
||||||
|
let chunk_size = if rand::random::<f32>() < 0.9 {
|
||||||
|
1
|
||||||
|
} else {
|
||||||
|
(rand::random::<f32>() * 3.0).floor() as usize + 1
|
||||||
|
};
|
||||||
|
|
||||||
|
let end = (position + chunk_size).min(chars.len());
|
||||||
|
let chunk: String = chars[position..end].iter().collect();
|
||||||
|
|
||||||
|
// Process chunk through Lua before sending
|
||||||
|
let processed_chunk =
|
||||||
|
script_manager.process_chunk(&chunk, &peer_addr.to_string(), &session_id);
|
||||||
|
|
||||||
|
// Try to write processed chunk
|
||||||
|
if stream.write_all(processed_chunk.as_bytes()).await.is_err() {
|
||||||
|
log::debug!("Connection closed by client during tarpit: {peer_addr}");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if stream.flush().await.is_err() {
|
||||||
|
log::debug!("Failed to flush stream during tarpit: {peer_addr}");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
position = end;
|
||||||
|
chunks_sent += 1;
|
||||||
|
|
||||||
|
// Apply random delay between min and max configured values
|
||||||
|
let delay_ms =
|
||||||
|
(rand::random::<f32>() * (max_delay - min_delay) as f32) as u64 + min_delay;
|
||||||
|
total_delay += delay_ms;
|
||||||
|
sleep(Duration::from_millis(delay_ms)).await;
|
||||||
|
}
|
||||||
|
log::debug!(
|
||||||
|
"Tarpit stats for {}: sent {} chunks, {}% of data, total delay {}ms over {}s",
|
||||||
|
peer_addr,
|
||||||
|
chunks_sent,
|
||||||
|
position * 100 / chars.len(),
|
||||||
|
total_delay,
|
||||||
|
start_time.elapsed().as_secs()
|
||||||
|
);
|
||||||
|
let disconnection_ctx = EventContext {
|
||||||
|
event_type: EventType::Disconnection,
|
||||||
|
ip: Some(peer_addr.to_string()),
|
||||||
|
path: None,
|
||||||
|
user_agent: None,
|
||||||
|
request_headers: None,
|
||||||
|
content: None,
|
||||||
|
timestamp: get_timestamp(),
|
||||||
|
session_id: Some(session_id),
|
||||||
|
};
|
||||||
|
script_manager.trigger_event(&disconnection_ctx);
|
||||||
|
if let Ok(mut state) = state.try_write() {
|
||||||
|
state.active_connections.remove(&peer_addr);
|
||||||
|
ACTIVE_CONNECTIONS.set(state.active_connections.len() as f64);
|
||||||
|
}
|
||||||
|
let _ = stream.shutdown().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.await;
|
||||||
|
} else {
|
||||||
|
log::debug!("Proxying request: {path} from {peer_addr}");
|
||||||
|
proxy_fast_path(stream, request_data, &config.backend_addr).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward a legitimate request to the real backend server
|
||||||
|
pub async fn proxy_fast_path(
|
||||||
|
mut client_stream: TcpStream,
|
||||||
|
request_data: Vec<u8>,
|
||||||
|
backend_addr: &str,
|
||||||
|
) {
|
||||||
|
// Connect to backend server
|
||||||
|
let server_stream = match TcpStream::connect(backend_addr).await {
|
||||||
|
Ok(stream) => stream,
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!("Failed to connect to backend {backend_addr}: {e}");
|
||||||
|
let _ = client_stream.shutdown().await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Set TCP_NODELAY for both streams before splitting them
|
||||||
|
if let Err(e) = client_stream.set_nodelay(true) {
|
||||||
|
log::debug!("Failed to set TCP_NODELAY on client stream: {e}");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut server_stream = server_stream;
|
||||||
|
if let Err(e) = server_stream.set_nodelay(true) {
|
||||||
|
log::debug!("Failed to set TCP_NODELAY on server stream: {e}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward the original request bytes directly without parsing
|
||||||
|
if server_stream.write_all(&request_data).await.is_err() {
|
||||||
|
log::debug!("Failed to write request to backend server");
|
||||||
|
let _ = client_stream.shutdown().await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now split the streams for concurrent reading/writing
|
||||||
|
let (mut client_read, mut client_write) = client_stream.split();
|
||||||
|
let (mut server_read, mut server_write) = server_stream.split();
|
||||||
|
|
||||||
|
// 32KB buffer
|
||||||
|
let buf_size = 32768;
|
||||||
|
|
||||||
|
// Client -> Server
|
||||||
|
let client_to_server = async {
|
||||||
|
let mut buf = vec![0; buf_size];
|
||||||
|
let mut bytes_forwarded = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match client_read.read(&mut buf).await {
|
||||||
|
Ok(0) => break,
|
||||||
|
Ok(n) => {
|
||||||
|
bytes_forwarded += n;
|
||||||
|
if server_write.write_all(&buf[..n]).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure everything is sent
|
||||||
|
let _ = server_write.flush().await;
|
||||||
|
log::debug!("Client -> Server: forwarded {bytes_forwarded} bytes");
|
||||||
|
};
|
||||||
|
|
||||||
|
// Server -> Client
|
||||||
|
let server_to_client = async {
|
||||||
|
let mut buf = vec![0; buf_size];
|
||||||
|
let mut bytes_forwarded = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match server_read.read(&mut buf).await {
|
||||||
|
Ok(0) => break,
|
||||||
|
Ok(n) => {
|
||||||
|
bytes_forwarded += n;
|
||||||
|
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure everything is sent
|
||||||
|
let _ = client_write.flush().await;
|
||||||
|
log::debug!("Server -> Client: forwarded {bytes_forwarded} bytes");
|
||||||
|
};
|
||||||
|
|
||||||
|
// Run both directions concurrently
|
||||||
|
tokio::join!(client_to_server, server_to_client);
|
||||||
|
log::debug!("Fast proxy connection completed");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a deceptive HTTP response that appears legitimate
|
||||||
|
pub async fn generate_deceptive_response(
|
||||||
|
path: &str,
|
||||||
|
user_agent: &str,
|
||||||
|
peer_addr: &IpAddr,
|
||||||
|
headers: &HashMap<String, String>,
|
||||||
|
markov: &MarkovGenerator,
|
||||||
|
script_manager: &ScriptManager,
|
||||||
|
) -> String {
|
||||||
|
// Generate base response using Markov chain text generator
|
||||||
|
let response_type = choose_response_type(path);
|
||||||
|
let markov_text = markov.generate(response_type, 30);
|
||||||
|
|
||||||
|
// Use Lua scripts to enhance with honeytokens and other deceptive content
|
||||||
|
script_manager.generate_response(
|
||||||
|
path,
|
||||||
|
user_agent,
|
||||||
|
&peer_addr.to_string(),
|
||||||
|
headers,
|
||||||
|
&markov_text,
|
||||||
|
)
|
||||||
|
}
|
85
src/network/rate_limiter.rs
Normal file
85
src/network/rate_limiter.rs
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
pub struct RateLimiter {
|
||||||
|
connections: Arc<Mutex<HashMap<IpAddr, Vec<Instant>>>>,
|
||||||
|
window_seconds: u64,
|
||||||
|
max_connections: usize,
|
||||||
|
cleanup_interval: Duration,
|
||||||
|
last_cleanup: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RateLimiter {
|
||||||
|
pub fn new(window_seconds: u64, max_connections: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
connections: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
window_seconds,
|
||||||
|
max_connections,
|
||||||
|
cleanup_interval: Duration::from_secs(60),
|
||||||
|
last_cleanup: Instant::now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn check_rate_limit(&self, ip: IpAddr) -> bool {
|
||||||
|
let now = Instant::now();
|
||||||
|
let window = Duration::from_secs(self.window_seconds);
|
||||||
|
|
||||||
|
let mut connections = self.connections.lock().await;
|
||||||
|
|
||||||
|
// Periodically clean up old entries across all IPs
|
||||||
|
if now.duration_since(self.last_cleanup) > self.cleanup_interval {
|
||||||
|
self.cleanup_old_entries(&mut connections, now, window);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up old entries for this specific IP
|
||||||
|
if let Some(times) = connections.get_mut(&ip) {
|
||||||
|
times.retain(|time| now.duration_since(*time) < window);
|
||||||
|
|
||||||
|
// Check if rate limit exceeded
|
||||||
|
if times.len() >= self.max_connections {
|
||||||
|
log::debug!("Rate limit exceeded for IP: {}", ip);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new connection time
|
||||||
|
times.push(now);
|
||||||
|
} else {
|
||||||
|
connections.insert(ip, vec![now]);
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cleanup_old_entries(
|
||||||
|
&self,
|
||||||
|
connections: &mut HashMap<IpAddr, Vec<Instant>>,
|
||||||
|
now: Instant,
|
||||||
|
window: Duration,
|
||||||
|
) {
|
||||||
|
let mut empty_keys = Vec::new();
|
||||||
|
|
||||||
|
for (ip, times) in connections.iter_mut() {
|
||||||
|
times.retain(|time| now.duration_since(*time) < window);
|
||||||
|
if times.is_empty() {
|
||||||
|
empty_keys.push(*ip);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove empty entries
|
||||||
|
for ip in empty_keys {
|
||||||
|
connections.remove(&ip);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_connection_count(&self, ip: &IpAddr) -> usize {
|
||||||
|
if let Ok(connections) = self.connections.try_lock() {
|
||||||
|
if let Some(times) = connections.get(ip) {
|
||||||
|
return times.len();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
0
|
||||||
|
}
|
||||||
|
}
|
166
src/state.rs
Normal file
166
src/state.rs
Normal file
|
@ -0,0 +1,166 @@
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::fs;
|
||||||
|
use std::io::Write;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
|
||||||
|
use crate::metrics::BLOCKED_IPS;
|
||||||
|
|
||||||
|
// State of bots/IPs hitting the honeypot
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct BotState {
|
||||||
|
pub hits: HashMap<IpAddr, u32>,
|
||||||
|
pub blocked: HashSet<IpAddr>,
|
||||||
|
pub active_connections: HashSet<IpAddr>,
|
||||||
|
pub data_dir: String,
|
||||||
|
pub cache_dir: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BotState {
|
||||||
|
pub fn new(data_dir: &str, cache_dir: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
hits: HashMap::new(),
|
||||||
|
blocked: HashSet::new(),
|
||||||
|
active_connections: HashSet::new(),
|
||||||
|
data_dir: data_dir.to_string(),
|
||||||
|
cache_dir: cache_dir.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load previous state from disk
|
||||||
|
pub fn load_from_disk(data_dir: &str, cache_dir: &str) -> Self {
|
||||||
|
let mut state = Self::new(data_dir, cache_dir);
|
||||||
|
let blocked_ips_file = format!("{data_dir}/blocked_ips.txt");
|
||||||
|
|
||||||
|
if let Ok(content) = fs::read_to_string(&blocked_ips_file) {
|
||||||
|
let mut loaded = 0;
|
||||||
|
for line in content.lines() {
|
||||||
|
if let Ok(ip) = line.parse::<IpAddr>() {
|
||||||
|
state.blocked.insert(ip);
|
||||||
|
loaded += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log::info!("Loaded {loaded} blocked IPs from {blocked_ips_file}");
|
||||||
|
} else {
|
||||||
|
log::info!("No blocked IPs file found at {blocked_ips_file}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for temporary hit counter cache
|
||||||
|
let hit_cache_file = format!("{cache_dir}/hit_counters.json");
|
||||||
|
if let Ok(content) = fs::read_to_string(&hit_cache_file) {
|
||||||
|
if let Ok(hit_map) = serde_json::from_str::<HashMap<String, u32>>(&content) {
|
||||||
|
for (ip_str, count) in hit_map {
|
||||||
|
if let Ok(ip) = ip_str.parse::<IpAddr>() {
|
||||||
|
state.hits.insert(ip, count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log::info!("Loaded hit counters for {} IPs", state.hits.len());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BLOCKED_IPS.set(state.blocked.len() as f64);
|
||||||
|
state
|
||||||
|
}
|
||||||
|
|
||||||
|
// Persist state to disk for later reloading
|
||||||
|
pub fn save_to_disk(&self) {
|
||||||
|
// Save blocked IPs
|
||||||
|
if let Err(e) = fs::create_dir_all(&self.data_dir) {
|
||||||
|
log::error!("Failed to create data directory: {e}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let blocked_ips_file = format!("{}/blocked_ips.txt", self.data_dir);
|
||||||
|
|
||||||
|
match fs::File::create(&blocked_ips_file) {
|
||||||
|
Ok(mut file) => {
|
||||||
|
let mut count = 0;
|
||||||
|
for ip in &self.blocked {
|
||||||
|
if writeln!(file, "{ip}").is_ok() {
|
||||||
|
count += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log::info!("Saved {count} blocked IPs to {blocked_ips_file}");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("Failed to create blocked IPs file: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save hit counters to cache
|
||||||
|
if let Err(e) = fs::create_dir_all(&self.cache_dir) {
|
||||||
|
log::error!("Failed to create cache directory: {e}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let hit_cache_file = format!("{}/hit_counters.json", self.cache_dir);
|
||||||
|
let hit_map: HashMap<String, u32> = self
|
||||||
|
.hits
|
||||||
|
.iter()
|
||||||
|
.map(|(ip, count)| (ip.to_string(), *count))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
match fs::File::create(&hit_cache_file) {
|
||||||
|
Ok(file) => {
|
||||||
|
if let Err(e) = serde_json::to_writer(file, &hit_map) {
|
||||||
|
log::error!("Failed to write hit counters to cache: {e}");
|
||||||
|
} else {
|
||||||
|
log::debug!("Saved hit counters for {} IPs to cache", hit_map.len());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("Failed to create hit counter cache file: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_bot_state() {
|
||||||
|
let state = BotState::new("/tmp/eris_test", "/tmp/eris_test_cache");
|
||||||
|
let ip1 = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
|
||||||
|
let ip2 = IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8));
|
||||||
|
|
||||||
|
let state = Arc::new(RwLock::new(state));
|
||||||
|
|
||||||
|
// Test hit counter
|
||||||
|
{
|
||||||
|
let mut state = state.write().await;
|
||||||
|
*state.hits.entry(ip1).or_insert(0) += 1;
|
||||||
|
*state.hits.entry(ip1).or_insert(0) += 1;
|
||||||
|
*state.hits.entry(ip2).or_insert(0) += 1;
|
||||||
|
|
||||||
|
assert_eq!(*state.hits.get(&ip1).unwrap(), 2);
|
||||||
|
assert_eq!(*state.hits.get(&ip2).unwrap(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test blocking
|
||||||
|
{
|
||||||
|
let mut state = state.write().await;
|
||||||
|
state.blocked.insert(ip1);
|
||||||
|
assert!(state.blocked.contains(&ip1));
|
||||||
|
assert!(!state.blocked.contains(&ip2));
|
||||||
|
drop(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test active connections
|
||||||
|
{
|
||||||
|
let mut state = state.write().await;
|
||||||
|
state.active_connections.insert(ip1);
|
||||||
|
state.active_connections.insert(ip2);
|
||||||
|
assert_eq!(state.active_connections.len(), 2);
|
||||||
|
|
||||||
|
state.active_connections.remove(&ip1);
|
||||||
|
assert_eq!(state.active_connections.len(), 1);
|
||||||
|
assert!(!state.active_connections.contains(&ip1));
|
||||||
|
assert!(state.active_connections.contains(&ip2));
|
||||||
|
drop(state);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
168
src/utils.rs
Normal file
168
src/utils.rs
Normal file
|
@ -0,0 +1,168 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::hash::Hasher;
|
||||||
|
|
||||||
|
// Find end of HTTP headers
|
||||||
|
pub fn find_header_end(data: &[u8]) -> Option<usize> {
|
||||||
|
data.windows(4)
|
||||||
|
.position(|window| window == b"\r\n\r\n")
|
||||||
|
.map(|pos| pos + 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract path from raw request data
|
||||||
|
pub fn extract_path_from_request(data: &[u8]) -> Option<&str> {
|
||||||
|
// Get first line from request
|
||||||
|
let first_line = data
|
||||||
|
.split(|&b| b == b'\r' || b == b'\n')
|
||||||
|
.next()
|
||||||
|
.filter(|line| !line.is_empty())?;
|
||||||
|
|
||||||
|
// Split by spaces and ensure we have at least 3 parts (METHOD PATH VERSION)
|
||||||
|
let parts: Vec<&[u8]> = first_line.split(|&b| b == b' ').collect();
|
||||||
|
if parts.len() < 3 || !parts[2].starts_with(b"HTTP/") {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the path (second element)
|
||||||
|
std::str::from_utf8(parts[1]).ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract header value from raw request data
|
||||||
|
pub fn extract_header_value(data: &[u8], header_name: &str) -> Option<String> {
|
||||||
|
let data_str = std::str::from_utf8(data).ok()?;
|
||||||
|
let header_prefix = format!("{header_name}: ").to_lowercase();
|
||||||
|
|
||||||
|
for line in data_str.lines() {
|
||||||
|
let line_lower = line.to_lowercase();
|
||||||
|
if line_lower.starts_with(&header_prefix) {
|
||||||
|
return Some(line[header_prefix.len()..].trim().to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract all headers from request data
|
||||||
|
pub fn extract_all_headers(data: &[u8]) -> HashMap<String, String> {
|
||||||
|
let mut headers = HashMap::new();
|
||||||
|
|
||||||
|
if let Ok(data_str) = std::str::from_utf8(data) {
|
||||||
|
let mut lines = data_str.lines();
|
||||||
|
|
||||||
|
// Skip the request line
|
||||||
|
let _ = lines.next();
|
||||||
|
|
||||||
|
// Parse headers until empty line
|
||||||
|
for line in lines {
|
||||||
|
if line.is_empty() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(colon_pos) = line.find(':') {
|
||||||
|
let key = line[..colon_pos].trim().to_lowercase();
|
||||||
|
let value = line[colon_pos + 1..].trim().to_string();
|
||||||
|
headers.insert(key, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
headers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine response type based on request path
|
||||||
|
pub fn choose_response_type(path: &str) -> &'static str {
|
||||||
|
if path.contains("phpunit") || path.contains("eval") {
|
||||||
|
"php_exploit"
|
||||||
|
} else if path.contains("wp-") {
|
||||||
|
"wordpress"
|
||||||
|
} else if path.contains("api") {
|
||||||
|
"api"
|
||||||
|
} else {
|
||||||
|
"generic"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get current timestamp in seconds
|
||||||
|
pub fn get_timestamp() -> u64 {
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a unique session ID for tracking a connection
|
||||||
|
pub fn generate_session_id(ip: &str, user_agent: &str) -> String {
|
||||||
|
let timestamp = get_timestamp();
|
||||||
|
let random = rand::random::<u32>();
|
||||||
|
|
||||||
|
// XXX: Is this fast enough for our case? I don't think hashing is a huge
|
||||||
|
// bottleneck, but it's worth revisiting in the future to see if there is
|
||||||
|
// an objectively faster algorithm that we can try.
|
||||||
|
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||||
|
std::hash::Hash::hash(&format!("{ip}_{user_agent}_{timestamp}"), &mut hasher);
|
||||||
|
let hash = hasher.finish();
|
||||||
|
|
||||||
|
format!("SID_{hash:x}_{random:x}")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_find_header_end() {
|
||||||
|
let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: test\r\n\r\nBody content";
|
||||||
|
assert_eq!(find_header_end(data), Some(55));
|
||||||
|
|
||||||
|
let incomplete = b"GET / HTTP/1.1\r\nHost: example.com\r\n";
|
||||||
|
assert_eq!(find_header_end(incomplete), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_path_from_request() {
|
||||||
|
let data = b"GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n";
|
||||||
|
assert_eq!(extract_path_from_request(data), Some("/index.html"));
|
||||||
|
|
||||||
|
let bad_data = b"INVALID DATA";
|
||||||
|
assert_eq!(extract_path_from_request(bad_data), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_header_value() {
|
||||||
|
let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\n\r\n";
|
||||||
|
assert_eq!(
|
||||||
|
extract_header_value(data, "user-agent"),
|
||||||
|
Some("TestBot/1.0".to_string())
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
extract_header_value(data, "Host"),
|
||||||
|
Some("example.com".to_string())
|
||||||
|
);
|
||||||
|
assert_eq!(extract_header_value(data, "nonexistent"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_all_headers() {
|
||||||
|
let data = b"GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: TestBot/1.0\r\nAccept: */*\r\n\r\n";
|
||||||
|
let headers = extract_all_headers(data);
|
||||||
|
|
||||||
|
assert_eq!(headers.len(), 3);
|
||||||
|
assert_eq!(headers.get("host").unwrap(), "example.com");
|
||||||
|
assert_eq!(headers.get("user-agent").unwrap(), "TestBot/1.0");
|
||||||
|
assert_eq!(headers.get("accept").unwrap(), "*/*");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_choose_response_type() {
|
||||||
|
assert_eq!(
|
||||||
|
choose_response_type("/vendor/phpunit/whatever"),
|
||||||
|
"php_exploit"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
choose_response_type("/path/to/eval-stdin.php"),
|
||||||
|
"php_exploit"
|
||||||
|
);
|
||||||
|
assert_eq!(choose_response_type("/wp-admin/login.php"), "wordpress");
|
||||||
|
assert_eq!(choose_response_type("/wp-login.php"), "wordpress");
|
||||||
|
assert_eq!(choose_response_type("/api/v1/users"), "api");
|
||||||
|
assert_eq!(choose_response_type("/index.html"), "generic");
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue