use crate::ast::*; use std::fmt::Write; /// Configuration for formatting output #[derive(Debug, Clone)] pub struct FormatConfig { pub indent_style: IndentStyle, pub spaces_per_level: usize, pub optimize: bool, pub max_empty_lines: usize, } impl Default for FormatConfig { fn default() -> Self { Self { indent_style: IndentStyle::Tabs, spaces_per_level: 2, optimize: false, max_empty_lines: 1, } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum IndentStyle { Tabs, Spaces, } impl IndentStyle { pub fn format(&self, level: usize, spaces_per_level: usize) -> String { match self { IndentStyle::Tabs => "\t".repeat(level), IndentStyle::Spaces => " ".repeat(spaces_per_level * level), } } } /// Formatter for nftables AST pub struct NftablesFormatter { config: FormatConfig, } impl NftablesFormatter { pub fn new(config: FormatConfig) -> Self { Self { config } } /// Add appropriate number of empty lines based on configuration fn add_separator(&self, output: &mut String) { if self.config.optimize { output.push('\n'); } else { // Add newlines based on max_empty_lines setting for _ in 0..=self.config.max_empty_lines { output.push('\n'); } } } pub fn format_ruleset(&self, ruleset: &Ruleset) -> String { let mut output = String::new(); // Format shebang if let Some(shebang) = &ruleset.shebang { writeln!(output, "#!{}", shebang).unwrap(); } // Format includes for include in &ruleset.includes { self.format_include(&mut output, include, 0); } // Add separator if we have includes if !ruleset.includes.is_empty() { self.add_separator(&mut output); } // Format defines for define in &ruleset.defines { self.format_define(&mut output, define, 0); } // Add separator if we have defines if !ruleset.defines.is_empty() { self.add_separator(&mut output); } // Format tables let mut table_iter = ruleset.tables.values().peekable(); while let Some(table) = table_iter.next() { self.format_table(&mut output, table, 0); // Add separator between tables if table_iter.peek().is_some() { self.add_separator(&mut output); } } // Ensure file ends with newline if !output.ends_with('\n') { output.push('\n'); } output } fn format_include(&self, output: &mut String, include: &Include, level: usize) { let indent = self .config .indent_style .format(level, self.config.spaces_per_level); writeln!(output, "{}include \"{}\"", indent, include.path).unwrap(); } fn format_define(&self, output: &mut String, define: &Define, level: usize) { let indent = self .config .indent_style .format(level, self.config.spaces_per_level); write!(output, "{}define {} = ", indent, define.name).unwrap(); self.format_expression(output, &define.value); output.push('\n'); } fn format_table(&self, output: &mut String, table: &Table, level: usize) { let indent = self .config .indent_style .format(level, self.config.spaces_per_level); writeln!(output, "{}table {} {} {{", indent, table.family, table.name).unwrap(); // Format chains let mut chain_iter = table.chains.values().peekable(); while let Some(chain) = chain_iter.next() { self.format_chain(output, chain, level + 1); // Add separator between chains if chain_iter.peek().is_some() { self.add_separator(output); } } writeln!(output, "{}}}", indent).unwrap(); } fn format_chain(&self, output: &mut String, chain: &Chain, level: usize) { let indent = self .config .indent_style .format(level, self.config.spaces_per_level); writeln!(output, "{}chain {} {{", indent, chain.name).unwrap(); // Format chain properties if let Some(chain_type) = &chain.chain_type { write!( output, "{}type {}", self.config .indent_style .format(level + 1, self.config.spaces_per_level), chain_type ) .unwrap(); if let Some(hook) = &chain.hook { write!(output, " hook {}", hook).unwrap(); } if let Some(priority) = chain.priority { write!(output, " priority {}", priority).unwrap(); } // Add semicolon after type/hook/priority output.push_str(";"); // Add policy on the same line if present if let Some(policy) = &chain.policy { write!(output, " policy {}", policy).unwrap(); output.push_str(";\n"); } else { output.push_str("\n"); } if !chain.rules.is_empty() && !self.config.optimize { output.push('\n'); } } // Format rules for (i, rule) in chain.rules.iter().enumerate() { // Add spacing between rules (but not before first rule) if i > 0 && !self.config.optimize && self.config.max_empty_lines > 0 { output.push('\n'); } self.format_rule(output, rule, level + 1); } writeln!(output, "{}}}", indent).unwrap(); } fn format_rule(&self, output: &mut String, rule: &Rule, level: usize) { let indent = self .config .indent_style .format(level, self.config.spaces_per_level); write!(output, "{}", indent).unwrap(); // Format expressions for (i, expr) in rule.expressions.iter().enumerate() { if i > 0 { output.push(' '); } self.format_expression(output, expr); } // Add action if !rule.expressions.is_empty() { output.push(' '); } write!(output, "{}", rule.action).unwrap(); output.push('\n'); // Only add extra newline between rules, not after the last rule // We'll handle this in the chain formatting instead } fn format_expression(&self, output: &mut String, expr: &Expression) { match expr { Expression::Identifier(name) => write!(output, "{}", name).unwrap(), Expression::String(s) => write!(output, "\"{}\"", s).unwrap(), Expression::Number(n) => write!(output, "{}", n).unwrap(), Expression::IpAddress(addr) => write!(output, "{}", addr).unwrap(), Expression::Ipv6Address(addr) => write!(output, "{}", addr).unwrap(), Expression::MacAddress(addr) => write!(output, "{}", addr).unwrap(), Expression::Binary { left, operator, right, } => { self.format_expression(output, left); write!(output, " {} ", operator).unwrap(); self.format_expression(output, right); } Expression::Protocol(proto) => write!(output, "protocol {}", proto).unwrap(), Expression::Port { direction, value } => { match direction { PortDirection::Source => write!(output, "sport ").unwrap(), PortDirection::Destination => write!(output, "dport ").unwrap(), } self.format_expression(output, value); } Expression::Address { direction, value } => { // Include the protocol family when formatting addresses write!(output, "ip ").unwrap(); match direction { AddressDirection::Source => write!(output, "saddr ").unwrap(), AddressDirection::Destination => write!(output, "daddr ").unwrap(), } self.format_expression(output, value); } Expression::Interface { direction, name } => { match direction { InterfaceDirection::Input => write!(output, "iifname ").unwrap(), InterfaceDirection::Output => write!(output, "oifname ").unwrap(), } write!(output, "{}", name).unwrap(); } Expression::ConnTrack { field, value } => { write!(output, "ct {} ", field).unwrap(); self.format_expression(output, value); } Expression::Set(elements) => { output.push_str("{ "); for (i, element) in elements.iter().enumerate() { if i > 0 { output.push_str(", "); } self.format_expression(output, element); } output.push_str(" }"); } Expression::Range { start, end } => { self.format_expression(output, start); output.push('-'); self.format_expression(output, end); } Expression::Vmap { expr, map } => { if let Some(expr) = expr { self.format_expression(output, expr); output.push(' '); } output.push_str("vmap { "); for (i, (key, value)) in map.iter().enumerate() { if i > 0 { output.push_str(", "); } self.format_expression(output, key); output.push_str(" : "); self.format_expression(output, value); } output.push_str(" }"); } } } } /// Convert from string-based IndentStyle to our enum impl std::str::FromStr for IndentStyle { type Err = String; fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { "tabs" | "tab" => Ok(IndentStyle::Tabs), "spaces" | "space" => Ok(IndentStyle::Spaces), _ => Err(format!( "Invalid indent style: {}. Use 'tabs' or 'spaces'", s )), } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_format_simple_table() { let table = Table::new(Family::Inet, "test".to_string()).add_chain( Chain::new("input".to_string()) .with_type(ChainType::Filter) .with_hook(Hook::Input) .with_priority(0) .with_policy(Policy::Accept) .add_rule(Rule::new( vec![Expression::Interface { direction: InterfaceDirection::Input, name: "lo".to_string(), }], Action::Accept, )), ); let formatter = NftablesFormatter::new(FormatConfig::default()); let mut output = String::new(); formatter.format_table(&mut output, &table, 0); // Just verify it doesn't panic and produces some output assert!(!output.is_empty()); } }