361 lines
11 KiB
Rust
361 lines
11 KiB
Rust
use crate::ast::*;
|
|
use std::fmt::Write;
|
|
|
|
/// Configuration for formatting output
|
|
#[derive(Debug, Clone)]
|
|
pub struct FormatConfig {
|
|
pub indent_style: IndentStyle,
|
|
pub spaces_per_level: usize,
|
|
pub optimize: bool,
|
|
pub max_empty_lines: usize,
|
|
}
|
|
|
|
impl Default for FormatConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
indent_style: IndentStyle::Tabs,
|
|
spaces_per_level: 2,
|
|
optimize: false,
|
|
max_empty_lines: 1,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// How nested blocks are indented in formatted output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndentStyle {
    /// One tab character per nesting level.
    Tabs,
    /// A configurable number of space characters per nesting level.
    Spaces,
}
|
|
|
|
impl IndentStyle {
|
|
pub fn format(&self, level: usize, spaces_per_level: usize) -> String {
|
|
match self {
|
|
IndentStyle::Tabs => "\t".repeat(level),
|
|
IndentStyle::Spaces => " ".repeat(spaces_per_level * level),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Formatter for nftables AST
|
|
pub struct NftablesFormatter {
|
|
config: FormatConfig,
|
|
}
|
|
|
|
impl NftablesFormatter {
|
|
pub fn new(config: FormatConfig) -> Self {
|
|
Self { config }
|
|
}
|
|
|
|
/// Add appropriate number of empty lines based on configuration
|
|
fn add_separator(&self, output: &mut String) {
|
|
if self.config.optimize {
|
|
output.push('\n');
|
|
} else {
|
|
// Add newlines based on max_empty_lines setting
|
|
for _ in 0..=self.config.max_empty_lines {
|
|
output.push('\n');
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn format_ruleset(&self, ruleset: &Ruleset) -> String {
|
|
let mut output = String::new();
|
|
|
|
// Format shebang
|
|
if let Some(shebang) = &ruleset.shebang {
|
|
writeln!(output, "#!{}", shebang).unwrap();
|
|
}
|
|
|
|
// Format includes
|
|
for include in &ruleset.includes {
|
|
self.format_include(&mut output, include, 0);
|
|
}
|
|
|
|
// Add separator if we have includes
|
|
if !ruleset.includes.is_empty() {
|
|
self.add_separator(&mut output);
|
|
}
|
|
|
|
// Format defines
|
|
for define in &ruleset.defines {
|
|
self.format_define(&mut output, define, 0);
|
|
}
|
|
|
|
// Add separator if we have defines
|
|
if !ruleset.defines.is_empty() {
|
|
self.add_separator(&mut output);
|
|
}
|
|
|
|
// Format tables
|
|
let mut table_iter = ruleset.tables.values().peekable();
|
|
while let Some(table) = table_iter.next() {
|
|
self.format_table(&mut output, table, 0); // Add separator between tables
|
|
if table_iter.peek().is_some() {
|
|
self.add_separator(&mut output);
|
|
}
|
|
}
|
|
|
|
// Ensure file ends with newline
|
|
if !output.ends_with('\n') {
|
|
output.push('\n');
|
|
}
|
|
|
|
output
|
|
}
|
|
|
|
fn format_include(&self, output: &mut String, include: &Include, level: usize) {
|
|
let indent = self
|
|
.config
|
|
.indent_style
|
|
.format(level, self.config.spaces_per_level);
|
|
writeln!(output, "{}include \"{}\"", indent, include.path).unwrap();
|
|
}
|
|
|
|
fn format_define(&self, output: &mut String, define: &Define, level: usize) {
|
|
let indent = self
|
|
.config
|
|
.indent_style
|
|
.format(level, self.config.spaces_per_level);
|
|
write!(output, "{}define {} = ", indent, define.name).unwrap();
|
|
self.format_expression(output, &define.value);
|
|
output.push('\n');
|
|
}
|
|
|
|
fn format_table(&self, output: &mut String, table: &Table, level: usize) {
|
|
let indent = self
|
|
.config
|
|
.indent_style
|
|
.format(level, self.config.spaces_per_level);
|
|
|
|
writeln!(output, "{}table {} {} {{", indent, table.family, table.name).unwrap();
|
|
|
|
// Format chains
|
|
let mut chain_iter = table.chains.values().peekable();
|
|
while let Some(chain) = chain_iter.next() {
|
|
self.format_chain(output, chain, level + 1); // Add separator between chains
|
|
if chain_iter.peek().is_some() {
|
|
self.add_separator(output);
|
|
}
|
|
}
|
|
|
|
writeln!(output, "{}}}", indent).unwrap();
|
|
}
|
|
|
|
fn format_chain(&self, output: &mut String, chain: &Chain, level: usize) {
|
|
let indent = self
|
|
.config
|
|
.indent_style
|
|
.format(level, self.config.spaces_per_level);
|
|
|
|
writeln!(output, "{}chain {} {{", indent, chain.name).unwrap();
|
|
|
|
// Format chain properties
|
|
if let Some(chain_type) = &chain.chain_type {
|
|
write!(
|
|
output,
|
|
"{}type {}",
|
|
self.config
|
|
.indent_style
|
|
.format(level + 1, self.config.spaces_per_level),
|
|
chain_type
|
|
)
|
|
.unwrap();
|
|
|
|
if let Some(hook) = &chain.hook {
|
|
write!(output, " hook {}", hook).unwrap();
|
|
}
|
|
|
|
if let Some(priority) = chain.priority {
|
|
write!(output, " priority {}", priority).unwrap();
|
|
}
|
|
|
|
// Add semicolon after type/hook/priority
|
|
output.push_str(";");
|
|
|
|
// Add policy on the same line if present
|
|
if let Some(policy) = &chain.policy {
|
|
write!(output, " policy {}", policy).unwrap();
|
|
output.push_str(";\n");
|
|
} else {
|
|
output.push_str("\n");
|
|
}
|
|
|
|
if !chain.rules.is_empty() && !self.config.optimize {
|
|
output.push('\n');
|
|
}
|
|
}
|
|
|
|
// Format rules
|
|
for (i, rule) in chain.rules.iter().enumerate() {
|
|
// Add spacing between rules (but not before first rule)
|
|
if i > 0 && !self.config.optimize && self.config.max_empty_lines > 0 {
|
|
output.push('\n');
|
|
}
|
|
self.format_rule(output, rule, level + 1);
|
|
}
|
|
|
|
writeln!(output, "{}}}", indent).unwrap();
|
|
}
|
|
|
|
fn format_rule(&self, output: &mut String, rule: &Rule, level: usize) {
|
|
let indent = self
|
|
.config
|
|
.indent_style
|
|
.format(level, self.config.spaces_per_level);
|
|
|
|
write!(output, "{}", indent).unwrap();
|
|
|
|
// Format expressions
|
|
for (i, expr) in rule.expressions.iter().enumerate() {
|
|
if i > 0 {
|
|
output.push(' ');
|
|
}
|
|
self.format_expression(output, expr);
|
|
}
|
|
|
|
// Add action
|
|
if !rule.expressions.is_empty() {
|
|
output.push(' ');
|
|
}
|
|
write!(output, "{}", rule.action).unwrap();
|
|
|
|
output.push('\n');
|
|
|
|
// Only add extra newline between rules, not after the last rule
|
|
// We'll handle this in the chain formatting instead
|
|
}
|
|
|
|
fn format_expression(&self, output: &mut String, expr: &Expression) {
|
|
match expr {
|
|
Expression::Identifier(name) => write!(output, "{}", name).unwrap(),
|
|
Expression::String(s) => write!(output, "\"{}\"", s).unwrap(),
|
|
Expression::Number(n) => write!(output, "{}", n).unwrap(),
|
|
Expression::IpAddress(addr) => write!(output, "{}", addr).unwrap(),
|
|
Expression::Ipv6Address(addr) => write!(output, "{}", addr).unwrap(),
|
|
Expression::MacAddress(addr) => write!(output, "{}", addr).unwrap(),
|
|
|
|
Expression::Binary {
|
|
left,
|
|
operator,
|
|
right,
|
|
} => {
|
|
self.format_expression(output, left);
|
|
write!(output, " {} ", operator).unwrap();
|
|
self.format_expression(output, right);
|
|
}
|
|
|
|
Expression::Protocol(proto) => write!(output, "protocol {}", proto).unwrap(),
|
|
|
|
Expression::Port { direction, value } => {
|
|
match direction {
|
|
PortDirection::Source => write!(output, "sport ").unwrap(),
|
|
PortDirection::Destination => write!(output, "dport ").unwrap(),
|
|
}
|
|
self.format_expression(output, value);
|
|
}
|
|
|
|
Expression::Address { direction, value } => {
|
|
// Include the protocol family when formatting addresses
|
|
write!(output, "ip ").unwrap();
|
|
match direction {
|
|
AddressDirection::Source => write!(output, "saddr ").unwrap(),
|
|
AddressDirection::Destination => write!(output, "daddr ").unwrap(),
|
|
}
|
|
self.format_expression(output, value);
|
|
}
|
|
|
|
Expression::Interface { direction, name } => {
|
|
match direction {
|
|
InterfaceDirection::Input => write!(output, "iifname ").unwrap(),
|
|
InterfaceDirection::Output => write!(output, "oifname ").unwrap(),
|
|
}
|
|
write!(output, "{}", name).unwrap();
|
|
}
|
|
|
|
Expression::ConnTrack { field, value } => {
|
|
write!(output, "ct {} ", field).unwrap();
|
|
self.format_expression(output, value);
|
|
}
|
|
|
|
Expression::Set(elements) => {
|
|
output.push_str("{ ");
|
|
for (i, element) in elements.iter().enumerate() {
|
|
if i > 0 {
|
|
output.push_str(", ");
|
|
}
|
|
self.format_expression(output, element);
|
|
}
|
|
output.push_str(" }");
|
|
}
|
|
|
|
Expression::Range { start, end } => {
|
|
self.format_expression(output, start);
|
|
output.push('-');
|
|
self.format_expression(output, end);
|
|
}
|
|
|
|
Expression::Vmap { expr, map } => {
|
|
if let Some(expr) = expr {
|
|
self.format_expression(output, expr);
|
|
output.push(' ');
|
|
}
|
|
output.push_str("vmap { ");
|
|
for (i, (key, value)) in map.iter().enumerate() {
|
|
if i > 0 {
|
|
output.push_str(", ");
|
|
}
|
|
self.format_expression(output, key);
|
|
output.push_str(" : ");
|
|
self.format_expression(output, value);
|
|
}
|
|
output.push_str(" }");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Convert from string-based IndentStyle to our enum
|
|
impl std::str::FromStr for IndentStyle {
|
|
type Err = String;
|
|
|
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
match s.to_lowercase().as_str() {
|
|
"tabs" | "tab" => Ok(IndentStyle::Tabs),
|
|
"spaces" | "space" => Ok(IndentStyle::Spaces),
|
|
_ => Err(format!(
|
|
"Invalid indent style: {}. Use 'tabs' or 'spaces'",
|
|
s
|
|
)),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_simple_table() {
        let table = Table::new(Family::Inet, "test".to_string()).add_chain(
            Chain::new("input".to_string())
                .with_type(ChainType::Filter)
                .with_hook(Hook::Input)
                .with_priority(0)
                .with_policy(Policy::Accept)
                .add_rule(Rule::new(
                    vec![Expression::Interface {
                        direction: InterfaceDirection::Input,
                        name: "lo".to_string(),
                    }],
                    Action::Accept,
                )),
        );

        let formatter = NftablesFormatter::new(FormatConfig::default());
        let mut output = String::new();
        formatter.format_table(&mut output, &table, 0);

        // Table header at level 0 (no indentation) with the table name.
        assert!(output.starts_with("table "), "got: {output:?}");
        assert!(output.contains(" test {\n"), "got: {output:?}");
        // Chain at level 1 (one tab with the default config).
        assert!(output.contains("\tchain input {\n"), "got: {output:?}");
        // Base-chain properties are emitted because a type is set.
        assert!(output.contains(" priority 0;"), "got: {output:?}");
        assert!(output.contains(" policy "), "got: {output:?}");
        // Rule at level 2 with the interface match.
        assert!(output.contains("\t\tiifname lo "), "got: {output:?}");
        // Chain and table are both closed.
        assert!(output.ends_with("\t}\n}\n"), "got: {output:?}");
    }

    #[test]
    fn test_indent_style_format() {
        assert_eq!(IndentStyle::Tabs.format(2, 4), "\t\t");
        assert_eq!(IndentStyle::Spaces.format(2, 4), "        ");
        assert_eq!(IndentStyle::Spaces.format(0, 4), "");
    }

    #[test]
    fn test_indent_style_from_str() {
        assert_eq!("Tabs".parse::<IndentStyle>(), Ok(IndentStyle::Tabs));
        assert_eq!("space".parse::<IndentStyle>(), Ok(IndentStyle::Spaces));
        assert!("dots".parse::<IndentStyle>().is_err());
    }
}
|