nff/src/syntax.rs

361 lines
11 KiB
Rust

use crate::ast::*;
use std::fmt::Write;
/// Configuration for formatting output
///
/// Controls indentation and blank-line handling of [`NftablesFormatter`].
#[derive(Debug, Clone)]
pub struct FormatConfig {
    /// Whether each nesting level is indented with tabs or with spaces.
    pub indent_style: IndentStyle,
    /// Number of spaces per nesting level; only consulted for [`IndentStyle::Spaces`].
    pub spaces_per_level: usize,
    /// Compact mode: no blank line after a chain's `type ...;` header, no blank
    /// lines between rules, and a single newline as the section separator.
    pub optimize: bool,
    /// Controls how many empty lines separate sections (includes, defines,
    /// tables, chains); `0` also disables the blank line between rules.
    pub max_empty_lines: usize,
}

impl Default for FormatConfig {
    /// Tab indentation, 2 spaces per level (when spaces are selected),
    /// non-optimized output and `max_empty_lines` of 1.
    fn default() -> Self {
        FormatConfig {
            max_empty_lines: 1,
            optimize: false,
            spaces_per_level: 2,
            indent_style: IndentStyle::Tabs,
        }
    }
}
/// Character used to indent nested nftables constructs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndentStyle {
    /// One tab character per nesting level.
    Tabs,
    /// `spaces_per_level` space characters per nesting level.
    Spaces,
}

impl IndentStyle {
    /// Builds the indentation prefix for nesting depth `level`.
    ///
    /// `spaces_per_level` is ignored for [`IndentStyle::Tabs`]; a `level` of
    /// `0` always yields an empty string.
    pub fn format(&self, level: usize, spaces_per_level: usize) -> String {
        let (unit, width) = match self {
            IndentStyle::Tabs => ("\t", level),
            IndentStyle::Spaces => (" ", level * spaces_per_level),
        };
        unit.repeat(width)
    }
}
/// Formatter for nftables AST
///
/// Renders a parsed [`Ruleset`] (or its individual parts) back into nftables
/// source text according to a [`FormatConfig`].
pub struct NftablesFormatter {
    config: FormatConfig,
}

impl NftablesFormatter {
    /// Creates a formatter that uses the given configuration.
    pub fn new(config: FormatConfig) -> Self {
        Self { config }
    }

    /// Returns the indentation prefix for nesting depth `level`.
    fn indent(&self, level: usize) -> String {
        self.config
            .indent_style
            .format(level, self.config.spaces_per_level)
    }

    /// Emits the empty lines that separate sections (includes, defines,
    /// tables) and consecutive chains inside a table.
    ///
    /// Every `format_*` helper already terminates its output with a newline,
    /// so each `'\n'` pushed here yields exactly one empty line. Optimize mode
    /// emits one empty line; otherwise exactly `max_empty_lines` are emitted.
    fn add_separator(&self, output: &mut String) {
        let empty_lines = if self.config.optimize {
            1
        } else {
            self.config.max_empty_lines
        };
        for _ in 0..empty_lines {
            output.push('\n');
        }
    }

    /// Formats a complete ruleset: optional shebang, includes, defines, then
    /// tables, with [`Self::add_separator`] between non-empty sections.
    ///
    /// The returned text always ends with a newline.
    pub fn format_ruleset(&self, ruleset: &Ruleset) -> String {
        let mut output = String::new();

        // `write!` into a `String` cannot fail, hence the `unwrap`s throughout.
        if let Some(shebang) = &ruleset.shebang {
            writeln!(output, "#!{}", shebang).unwrap();
        }

        for include in &ruleset.includes {
            self.format_include(&mut output, include, 0);
        }
        if !ruleset.includes.is_empty() {
            self.add_separator(&mut output);
        }

        for define in &ruleset.defines {
            self.format_define(&mut output, define, 0);
        }
        if !ruleset.defines.is_empty() {
            self.add_separator(&mut output);
        }

        let mut tables = ruleset.tables.values().peekable();
        while let Some(table) = tables.next() {
            self.format_table(&mut output, table, 0);
            // Separate consecutive tables, but not after the last one.
            if tables.peek().is_some() {
                self.add_separator(&mut output);
            }
        }

        if !output.ends_with('\n') {
            output.push('\n');
        }
        output
    }

    /// Writes `include "<path>"` on its own line.
    fn format_include(&self, output: &mut String, include: &Include, level: usize) {
        writeln!(output, "{}include \"{}\"", self.indent(level), include.path).unwrap();
    }

    /// Writes `define <name> = <value>` on its own line.
    fn format_define(&self, output: &mut String, define: &Define, level: usize) {
        write!(output, "{}define {} = ", self.indent(level), define.name).unwrap();
        self.format_expression(output, &define.value);
        output.push('\n');
    }

    /// Writes a `table <family> <name> { ... }` block with its chains nested
    /// one level deeper.
    fn format_table(&self, output: &mut String, table: &Table, level: usize) {
        let indent = self.indent(level);
        writeln!(output, "{}table {} {} {{", indent, table.family, table.name).unwrap();

        let mut chains = table.chains.values().peekable();
        while let Some(chain) = chains.next() {
            self.format_chain(output, chain, level + 1);
            // Separate consecutive chains, but not after the last one.
            if chains.peek().is_some() {
                self.add_separator(output);
            }
        }

        writeln!(output, "{}}}", indent).unwrap();
    }

    /// Writes a `chain <name> { ... }` block: the optional base-chain header
    /// followed by one rule per line.
    fn format_chain(&self, output: &mut String, chain: &Chain, level: usize) {
        let indent = self.indent(level);
        writeln!(output, "{}chain {} {{", indent, chain.name).unwrap();

        // Header: `type <t> [hook <h>] [priority <p>]; [policy <p>;]`.
        // The policy is only emitted together with a chain type.
        if let Some(chain_type) = &chain.chain_type {
            write!(output, "{}type {}", self.indent(level + 1), chain_type).unwrap();
            if let Some(hook) = &chain.hook {
                write!(output, " hook {}", hook).unwrap();
            }
            if let Some(priority) = chain.priority {
                write!(output, " priority {}", priority).unwrap();
            }
            output.push(';');
            if let Some(policy) = &chain.policy {
                write!(output, " policy {};", policy).unwrap();
            }
            output.push('\n');

            // Blank line between the header and the rules, unless compacting.
            if !chain.rules.is_empty() && !self.config.optimize {
                output.push('\n');
            }
        }

        for (i, rule) in chain.rules.iter().enumerate() {
            // One blank line between rules (never before the first one).
            if i > 0 && !self.config.optimize && self.config.max_empty_lines > 0 {
                output.push('\n');
            }
            self.format_rule(output, rule, level + 1);
        }

        writeln!(output, "{}}}", indent).unwrap();
    }

    /// Writes one rule: its space-separated expressions followed by the action.
    fn format_rule(&self, output: &mut String, rule: &Rule, level: usize) {
        output.push_str(&self.indent(level));
        for expr in rule.expressions.iter() {
            self.format_expression(output, expr);
            output.push(' ');
        }
        write!(output, "{}", rule.action).unwrap();
        output.push('\n');
    }

    /// Returns `true` if `expr` is an IPv6 address literal, or a range/set
    /// containing one.
    fn is_ipv6_address(expr: &Expression) -> bool {
        match expr {
            Expression::Ipv6Address(_) => true,
            Expression::Range { start, end } => {
                Self::is_ipv6_address(start) || Self::is_ipv6_address(end)
            }
            Expression::Set(elements) => elements.iter().any(|e| Self::is_ipv6_address(e)),
            _ => false,
        }
    }

    /// Writes a single expression (recursively for compound expressions).
    fn format_expression(&self, output: &mut String, expr: &Expression) {
        match expr {
            Expression::Identifier(name) => write!(output, "{}", name).unwrap(),
            Expression::String(s) => write!(output, "\"{}\"", s).unwrap(),
            Expression::Number(n) => write!(output, "{}", n).unwrap(),
            Expression::IpAddress(addr) => write!(output, "{}", addr).unwrap(),
            Expression::Ipv6Address(addr) => write!(output, "{}", addr).unwrap(),
            Expression::MacAddress(addr) => write!(output, "{}", addr).unwrap(),
            Expression::Binary {
                left,
                operator,
                right,
            } => {
                self.format_expression(output, left);
                write!(output, " {} ", operator).unwrap();
                self.format_expression(output, right);
            }
            Expression::Protocol(proto) => write!(output, "protocol {}", proto).unwrap(),
            Expression::Port { direction, value } => {
                output.push_str(match direction {
                    PortDirection::Source => "sport ",
                    PortDirection::Destination => "dport ",
                });
                self.format_expression(output, value);
            }
            Expression::Address { direction, value } => {
                // The family keyword must match the address literal:
                // `ip saddr 10.0.0.1` but `ip6 saddr ::1`. Values whose family
                // cannot be inferred here (e.g. identifiers) keep `ip`.
                let family = if Self::is_ipv6_address(value) { "ip6" } else { "ip" };
                output.push_str(family);
                output.push_str(match direction {
                    AddressDirection::Source => " saddr ",
                    AddressDirection::Destination => " daddr ",
                });
                self.format_expression(output, value);
            }
            Expression::Interface { direction, name } => {
                output.push_str(match direction {
                    InterfaceDirection::Input => "iifname ",
                    InterfaceDirection::Output => "oifname ",
                });
                write!(output, "{}", name).unwrap();
            }
            Expression::ConnTrack { field, value } => {
                write!(output, "ct {} ", field).unwrap();
                self.format_expression(output, value);
            }
            Expression::Set(elements) => {
                output.push_str("{ ");
                for (i, element) in elements.iter().enumerate() {
                    if i > 0 {
                        output.push_str(", ");
                    }
                    self.format_expression(output, element);
                }
                output.push_str(" }");
            }
            Expression::Range { start, end } => {
                self.format_expression(output, start);
                output.push('-');
                self.format_expression(output, end);
            }
            Expression::Vmap { expr, map } => {
                if let Some(expr) = expr {
                    self.format_expression(output, expr);
                    output.push(' ');
                }
                output.push_str("vmap { ");
                for (i, (key, value)) in map.iter().enumerate() {
                    if i > 0 {
                        output.push_str(", ");
                    }
                    self.format_expression(output, key);
                    output.push_str(" : ");
                    self.format_expression(output, value);
                }
                output.push_str(" }");
            }
        }
    }
}
/// Convert from string-based IndentStyle to our enum
///
/// Matching is case-insensitive and accepts `tab`/`tabs` or `space`/`spaces`;
/// anything else yields a descriptive error message.
impl std::str::FromStr for IndentStyle {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let style = match normalized.as_str() {
            "tab" | "tabs" => IndentStyle::Tabs,
            "space" | "spaces" => IndentStyle::Spaces,
            _ => {
                return Err(format!(
                    "Invalid indent style: {}. Use 'tabs' or 'spaces'",
                    s
                ))
            }
        };
        Ok(style)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds table `test` with a single `input` base chain (filter/input,
    /// priority 0, policy accept) holding one `iifname lo` accept rule.
    fn sample_table() -> Table {
        Table::new(Family::Inet, "test".to_string()).add_chain(
            Chain::new("input".to_string())
                .with_type(ChainType::Filter)
                .with_hook(Hook::Input)
                .with_priority(0)
                .with_policy(Policy::Accept)
                .add_rule(Rule::new(
                    vec![Expression::Interface {
                        direction: InterfaceDirection::Input,
                        name: "lo".to_string(),
                    }],
                    Action::Accept,
                )),
        )
    }

    /// Formats `sample_table()` at level 0 with the given configuration.
    fn render(config: FormatConfig) -> String {
        let formatter = NftablesFormatter::new(config);
        let mut output = String::new();
        formatter.format_table(&mut output, &sample_table(), 0);
        output
    }

    #[test]
    fn test_format_simple_table() {
        let output = render(FormatConfig::default());

        assert!(output.starts_with("table "));
        // Chains are nested one level, headers and rules two levels (tabs).
        assert!(output.contains("\n\tchain input {\n"));
        assert!(output.contains("\n\t\ttype "));
        assert!(output.contains(" priority 0;"));
        // Non-optimized output separates the header from the rules.
        assert!(output.contains(";\n\n\t\tiifname lo "));
        assert!(output.ends_with("\t}\n}\n"));
    }

    #[test]
    fn test_format_table_with_spaces() {
        let output = render(FormatConfig {
            indent_style: IndentStyle::Spaces,
            spaces_per_level: 4,
            ..FormatConfig::default()
        });

        assert!(output.contains("\n    chain input {\n"));
        assert!(output.contains("\n        iifname lo "));
        assert!(!output.contains('\t'));
    }

    #[test]
    fn test_format_table_optimized_has_no_blank_lines() {
        let output = render(FormatConfig {
            optimize: true,
            ..FormatConfig::default()
        });

        assert!(!output.contains("\n\n"));
        assert!(output.contains("\n\t\tiifname lo "));
    }

    #[test]
    fn test_indent_style_from_str() {
        assert_eq!("tabs".parse::<IndentStyle>(), Ok(IndentStyle::Tabs));
        assert_eq!("Tab".parse::<IndentStyle>(), Ok(IndentStyle::Tabs));
        assert_eq!("SPACES".parse::<IndentStyle>(), Ok(IndentStyle::Spaces));
        assert_eq!("space".parse::<IndentStyle>(), Ok(IndentStyle::Spaces));
        assert!("bogus".parse::<IndentStyle>().is_err());
    }
}