treewide: fmt

This commit is contained in:
raf 2025-05-24 23:41:06 +03:00
commit 10a525b2e7
Signed by: NotAShelf
GPG key ID: 29D95B64378DB4BF
5 changed files with 326 additions and 206 deletions

View file

@ -3,14 +3,10 @@
//! Lossless representation preserving whitespace, comments, and formatting.
use crate::lexer::{Token, TokenKind};
use cstree::{
green::GreenNode,
RawSyntaxKind, util::NodeOrToken,
};
use cstree::{RawSyntaxKind, green::GreenNode, util::NodeOrToken};
use std::fmt;
use thiserror::Error;
/// nftables syntax node types
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u16)]
@ -350,7 +346,10 @@ impl CstBuilder {
/// Internal tree builder that constructs CST according to nftables grammar
struct CstTreeBuilder {
stack: Vec<(SyntaxKind, Vec<NodeOrToken<GreenNode, cstree::green::GreenToken>>)>,
stack: Vec<(
SyntaxKind,
Vec<NodeOrToken<GreenNode, cstree::green::GreenToken>>,
)>,
}
impl CstTreeBuilder {
@ -384,15 +383,15 @@ impl CstTreeBuilder {
TokenKind::CommentLine(_) => {
self.add_token(&tokens[start], SyntaxKind::Comment);
start + 1
},
}
TokenKind::Newline => {
self.add_token(&tokens[start], SyntaxKind::Newline);
start + 1
},
}
TokenKind::Shebang(_) => {
self.add_token(&tokens[start], SyntaxKind::Shebang);
start + 1
},
}
_ => {
// Handle other tokens as statements
self.parse_statement(tokens, start)
@ -437,17 +436,17 @@ impl CstTreeBuilder {
match &tokens[pos].kind {
TokenKind::Chain => {
pos = self.parse_chain(tokens, pos);
},
}
TokenKind::Set => {
pos = self.parse_set(tokens, pos);
},
}
TokenKind::Map => {
pos = self.parse_map(tokens, pos);
},
}
TokenKind::Newline | TokenKind::CommentLine(_) => {
self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
pos += 1;
},
}
_ => pos += 1, // Skip unknown tokens
}
}
@ -484,11 +483,11 @@ impl CstTreeBuilder {
match &tokens[pos].kind {
TokenKind::Type => {
pos = self.parse_chain_properties(tokens, pos);
},
}
TokenKind::Newline | TokenKind::CommentLine(_) => {
self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
pos += 1;
},
}
_ => {
// Parse as rule
pos = self.parse_rule(tokens, pos);
@ -538,7 +537,9 @@ impl CstTreeBuilder {
TokenKind::Output => self.add_token(&tokens[pos], SyntaxKind::OutputKw),
TokenKind::Forward => self.add_token(&tokens[pos], SyntaxKind::ForwardKw),
TokenKind::Prerouting => self.add_token(&tokens[pos], SyntaxKind::PreroutingKw),
TokenKind::Postrouting => self.add_token(&tokens[pos], SyntaxKind::PostroutingKw),
TokenKind::Postrouting => {
self.add_token(&tokens[pos], SyntaxKind::PostroutingKw)
}
_ => self.add_token(&tokens[pos], SyntaxKind::Identifier),
}
pos += 1;
@ -580,7 +581,7 @@ impl CstTreeBuilder {
self.add_token(&tokens[pos], SyntaxKind::Newline);
pos += 1;
break;
},
}
TokenKind::Accept => self.add_token(&tokens[pos], SyntaxKind::AcceptKw),
TokenKind::Drop => self.add_token(&tokens[pos], SyntaxKind::DropKw),
TokenKind::Reject => self.add_token(&tokens[pos], SyntaxKind::RejectKw),
@ -596,10 +597,14 @@ impl CstTreeBuilder {
TokenKind::LeftParen => {
pos = self.parse_expression(tokens, pos);
continue;
},
}
TokenKind::Identifier(_) => self.add_token(&tokens[pos], SyntaxKind::Identifier),
TokenKind::StringLiteral(_) => self.add_token(&tokens[pos], SyntaxKind::StringLiteral),
TokenKind::NumberLiteral(_) => self.add_token(&tokens[pos], SyntaxKind::NumberLiteral),
TokenKind::StringLiteral(_) => {
self.add_token(&tokens[pos], SyntaxKind::StringLiteral)
}
TokenKind::NumberLiteral(_) => {
self.add_token(&tokens[pos], SyntaxKind::NumberLiteral)
}
TokenKind::IpAddress(_) => self.add_token(&tokens[pos], SyntaxKind::IpAddress),
_ => self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())),
}
@ -627,7 +632,7 @@ impl CstTreeBuilder {
TokenKind::LeftParen => {
paren_depth += 1;
self.add_token(&tokens[pos], SyntaxKind::LeftParen);
},
}
TokenKind::RightParen => {
self.add_token(&tokens[pos], SyntaxKind::RightParen);
paren_depth -= 1;
@ -635,20 +640,22 @@ impl CstTreeBuilder {
pos += 1;
break;
}
},
}
TokenKind::LeftBracket => {
pos = self.parse_set_expression(tokens, pos);
continue;
},
}
TokenKind::Identifier(_) => {
// Check if this is a function call
if pos + 1 < tokens.len() && matches!(tokens[pos + 1].kind, TokenKind::LeftParen) {
if pos + 1 < tokens.len()
&& matches!(tokens[pos + 1].kind, TokenKind::LeftParen)
{
pos = self.parse_call_expression(tokens, pos);
continue;
} else {
self.add_token(&tokens[pos], SyntaxKind::Identifier);
}
},
}
TokenKind::Dash => {
// Could be binary or unary expression
if self.is_binary_context() {
@ -656,7 +663,7 @@ impl CstTreeBuilder {
} else {
self.parse_unary_expression(tokens, pos);
}
},
}
_ => self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())),
}
pos += 1;
@ -789,7 +796,9 @@ impl CstTreeBuilder {
pos += 1;
// Parse what to flush (table, chain, etc.)
while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline | TokenKind::Semicolon) {
while pos < tokens.len()
&& !matches!(tokens[pos].kind, TokenKind::Newline | TokenKind::Semicolon)
{
self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
pos += 1;
}
@ -845,7 +854,9 @@ impl CstTreeBuilder {
pos += 1;
// Parse element specification
while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline | TokenKind::LeftBrace) {
while pos < tokens.len()
&& !matches!(tokens[pos].kind, TokenKind::Newline | TokenKind::LeftBrace)
{
self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
pos += 1;
}
@ -914,7 +925,9 @@ impl CstTreeBuilder {
fn is_binary_context(&self) -> bool {
// Simple heuristic: assume binary if we have recent tokens in current scope
self.stack.last().map_or(false, |(_, children)| children.len() > 1)
self.stack
.last()
.map_or(false, |(_, children)| children.len() > 1)
}
fn start_node(&mut self, kind: SyntaxKind) {
@ -925,7 +938,8 @@ impl CstTreeBuilder {
// Handle whitespace specially if it's part of the token
if let TokenKind::Identifier(text) = &token.kind {
if text.chars().all(|c| c.is_whitespace()) {
let whitespace_token = GreenNode::new(SyntaxKind::Whitespace.to_raw(), std::iter::empty());
let whitespace_token =
GreenNode::new(SyntaxKind::Whitespace.to_raw(), std::iter::empty());
if let Some((_, children)) = self.stack.last_mut() {
children.push(NodeOrToken::Node(whitespace_token));
}
@ -978,14 +992,17 @@ impl CstValidator {
SyntaxKind::Rule => self.validate_rule(node),
SyntaxKind::Set | SyntaxKind::Map => self.validate_set_or_map(node),
SyntaxKind::Element => self.validate_element(node),
SyntaxKind::Expression | SyntaxKind::BinaryExpr | SyntaxKind::UnaryExpr | SyntaxKind::CallExpr => {
self.validate_expression(node)
},
SyntaxKind::Expression
| SyntaxKind::BinaryExpr
| SyntaxKind::UnaryExpr
| SyntaxKind::CallExpr => self.validate_expression(node),
SyntaxKind::IncludeStmt => self.validate_include(node),
SyntaxKind::DefineStmt => self.validate_define(node),
SyntaxKind::FlushStmt | SyntaxKind::AddStmt | SyntaxKind::DeleteStmt => self.validate_statement(node),
SyntaxKind::FlushStmt | SyntaxKind::AddStmt | SyntaxKind::DeleteStmt => {
self.validate_statement(node)
}
SyntaxKind::Whitespace => Ok(()), // Whitespace is always valid
_ => Ok(()), // Other nodes are generally valid
_ => Ok(()), // Other nodes are generally valid
}
}
@ -1003,15 +1020,24 @@ impl CstValidator {
if let Some(child_node) = child.as_node() {
let child_kind = SyntaxKind::from_raw(child_node.kind());
match child_kind {
SyntaxKind::Table | SyntaxKind::IncludeStmt | SyntaxKind::DefineStmt |
SyntaxKind::FlushStmt | SyntaxKind::AddStmt | SyntaxKind::DeleteStmt |
SyntaxKind::Element | SyntaxKind::Comment | SyntaxKind::Newline |
SyntaxKind::Shebang | SyntaxKind::Whitespace => {
SyntaxKind::Table
| SyntaxKind::IncludeStmt
| SyntaxKind::DefineStmt
| SyntaxKind::FlushStmt
| SyntaxKind::AddStmt
| SyntaxKind::DeleteStmt
| SyntaxKind::Element
| SyntaxKind::Comment
| SyntaxKind::Newline
| SyntaxKind::Shebang
| SyntaxKind::Whitespace => {
self.validate(child_node)?;
},
}
_ => {
return Err(CstError::UnexpectedToken {
expected: "table, include, define, flush, add, delete, element, or comment".to_string(),
expected:
"table, include, define, flush, add, delete, element, or comment"
.to_string(),
found: format!("{:?}", child_kind),
});
}
@ -1052,11 +1078,16 @@ impl CstValidator {
if let Some(second_node) = second.as_node() {
let second_kind = SyntaxKind::from_raw(second_node.kind());
match second_kind {
SyntaxKind::InetKw | SyntaxKind::IpKw | SyntaxKind::Ip6Kw |
SyntaxKind::ArpKw | SyntaxKind::BridgeKw | SyntaxKind::NetdevKw => {},
SyntaxKind::InetKw
| SyntaxKind::IpKw
| SyntaxKind::Ip6Kw
| SyntaxKind::ArpKw
| SyntaxKind::BridgeKw
| SyntaxKind::NetdevKw => {}
_ => {
return Err(CstError::UnexpectedToken {
expected: "table family (inet, ip, ip6, arp, bridge, netdev)".to_string(),
expected: "table family (inet, ip, ip6, arp, bridge, netdev)"
.to_string(),
found: format!("{:?}", second_kind),
});
}

View file

@ -19,7 +19,7 @@ pub type LexResult<T> = Result<T, LexError>;
/// Token kinds for nftables configuration files
#[derive(Logos, Debug, Clone, PartialEq, Eq, Hash)]
#[logos(skip r"[ \t\f]+")] // Skip whitespace but not newlines
#[logos(skip r"[ \t\f]+")] // Skip whitespace but not newlines
pub enum TokenKind {
// Keywords
#[token("table")]
@ -198,7 +198,11 @@ pub enum TokenKind {
NumberLiteral(u64),
#[regex(r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+", |lex| lex.slice().to_owned())]
IpAddress(String),
#[regex(r"(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:)*::[0-9a-fA-F:]*", ipv6_address, priority = 5)]
#[regex(
r"(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:)*::[0-9a-fA-F:]*",
ipv6_address,
priority = 5
)]
Ipv6Address(String),
#[regex(r"[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}", |lex| lex.slice().to_owned())]
MacAddress(String),
@ -245,7 +249,9 @@ impl fmt::Display for TokenKind {
fn string_literal(lex: &mut Lexer<TokenKind>) -> String {
let slice = lex.slice();
// Remove surrounding quotes and process escape sequences
slice[1..slice.len()-1].replace("\\\"", "\"").replace("\\\\", "\\")
slice[1..slice.len() - 1]
.replace("\\\"", "\"")
.replace("\\\\", "\\")
}
fn number_literal(lex: &mut Lexer<TokenKind>) -> Option<u64> {
@ -309,7 +315,7 @@ impl<'a> NftablesLexer<'a> {
let text = &self.source[span.clone()];
let range = TextRange::new(
TextSize::from(span.start as u32),
TextSize::from(span.end as u32)
TextSize::from(span.end as u32),
);
match result {
@ -322,8 +328,11 @@ impl<'a> NftablesLexer<'a> {
return Err(LexError::UnterminatedString {
position: span.start,
});
} else if text.chars().any(|c| c.is_ascii_digit()) &&
text.chars().any(|c| !c.is_ascii_digit() && c != '.' && c != 'x' && c != 'X') {
} else if text.chars().any(|c| c.is_ascii_digit())
&& text
.chars()
.any(|c| !c.is_ascii_digit() && c != '.' && c != 'x' && c != 'X')
{
return Err(LexError::InvalidNumber {
text: text.to_owned(),
});
@ -349,7 +358,7 @@ impl<'a> NftablesLexer<'a> {
let text = &self.source[span.clone()];
let range = TextRange::new(
TextSize::from(span.start as u32),
TextSize::from(span.end as u32)
TextSize::from(span.end as u32),
);
let kind = result.unwrap_or(TokenKind::Error);
@ -395,7 +404,11 @@ mod tests {
let mut lexer = NftablesLexer::new(source);
let tokens = lexer.tokenize().expect("Tokenization should succeed");
assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::CommentLine(_))));
assert!(
tokens
.iter()
.any(|t| matches!(t.kind, TokenKind::CommentLine(_)))
);
assert!(tokens.iter().any(|t| t.kind == TokenKind::Table));
}

View file

@ -4,17 +4,17 @@ mod lexer;
mod parser;
mod syntax;
use anyhow::{Context, Result};
use clap::Parser;
use std::fs;
use std::io::{self, Write};
use std::path::Path;
use clap::Parser;
use anyhow::{Context, Result};
use thiserror::Error;
use crate::cst::CstBuilder;
use crate::lexer::NftablesLexer;
use crate::parser::Parser as NftablesParser;
use crate::syntax::{FormatConfig, IndentStyle, NftablesFormatter};
use crate::cst::CstBuilder;
#[derive(Error, Debug)]
enum FormatterError {
@ -85,14 +85,18 @@ fn process_nftables_config(args: Args) -> Result<()> {
// Use error-recovery tokenization for debug mode
lexer.tokenize_with_errors()
} else {
lexer.tokenize()
lexer
.tokenize()
.map_err(|e| FormatterError::ParseError(e.to_string()))?
};
if args.debug {
eprintln!("=== TOKENS ===");
for (i, token) in tokens.iter().enumerate() {
eprintln!("{:3}: {:?} @ {:?} = '{}'", i, token.kind, token.range, token.text);
eprintln!(
"{:3}: {:?} @ {:?} = '{}'",
i, token.kind, token.range, token.text
);
}
eprintln!();
@ -126,7 +130,8 @@ fn process_nftables_config(args: Args) -> Result<()> {
parsed_ruleset.unwrap_or_else(|| crate::ast::Ruleset::new())
} else {
let mut parser = NftablesParser::new(tokens.clone());
parser.parse()
parser
.parse()
.map_err(|e| FormatterError::ParseError(e.to_string()))?
};
@ -160,7 +165,8 @@ fn process_nftables_config(args: Args) -> Result<()> {
println!("Formatted output written to: {}", output_file);
}
None => {
io::stdout().write_all(formatted_output.as_bytes())
io::stdout()
.write_all(formatted_output.as_bytes())
.with_context(|| "Failed to write to stdout")?;
}
}

View file

@ -1,12 +1,14 @@
use crate::ast::*;
use crate::lexer::{LexError, Token, TokenKind};
use anyhow::{anyhow, Result};
use anyhow::{Result, anyhow};
use thiserror::Error;
/// Parse errors for nftables configuration
#[derive(Error, Debug)]
pub enum ParseError {
#[error("Unexpected token at line {line}, column {column}: expected {expected}, found '{found}'")]
#[error(
"Unexpected token at line {line}, column {column}: expected {expected}, found '{found}'"
)]
UnexpectedToken {
line: usize,
column: usize,
@ -123,7 +125,10 @@ impl Parser {
line: 1,
column: 1,
expected: "include, define, table, or comment".to_string(),
found: self.peek().map(|t| t.text.clone()).unwrap_or("EOF".to_string()),
found: self
.peek()
.map(|t| t.text.clone())
.unwrap_or("EOF".to_string()),
});
}
}
@ -154,21 +159,23 @@ impl Parser {
self.consume(TokenKind::Include, "Expected 'include'")?;
let path = match self.advance() {
Some(token) if matches!(token.kind, TokenKind::StringLiteral(_)) => {
match &token.kind {
TokenKind::StringLiteral(s) => s.clone(),
_ => unreachable!(),
}
Some(token) if matches!(token.kind, TokenKind::StringLiteral(_)) => match &token.kind {
TokenKind::StringLiteral(s) => s.clone(),
_ => unreachable!(),
},
None => {
return Err(ParseError::MissingToken {
expected: "string literal after 'include'".to_string(),
});
}
Some(token) => {
return Err(ParseError::UnexpectedToken {
line: 1,
column: 1,
expected: "string literal".to_string(),
found: token.text.clone(),
});
}
None => return Err(ParseError::MissingToken {
expected: "string literal after 'include'".to_string()
}),
Some(token) => return Err(ParseError::UnexpectedToken {
line: 1,
column: 1,
expected: "string literal".to_string(),
found: token.text.clone(),
}),
};
Ok(Include { path })
@ -179,21 +186,23 @@ impl Parser {
self.consume(TokenKind::Define, "Expected 'define'")?;
let name = match self.advance() {
Some(token) if matches!(token.kind, TokenKind::Identifier(_)) => {
match &token.kind {
TokenKind::Identifier(s) => s.clone(),
_ => unreachable!(),
}
Some(token) if matches!(token.kind, TokenKind::Identifier(_)) => match &token.kind {
TokenKind::Identifier(s) => s.clone(),
_ => unreachable!(),
},
None => {
return Err(ParseError::MissingToken {
expected: "identifier after 'define'".to_string(),
});
}
Some(token) => {
return Err(ParseError::UnexpectedToken {
line: 1,
column: 1,
expected: "identifier".to_string(),
found: token.text.clone(),
});
}
None => return Err(ParseError::MissingToken {
expected: "identifier after 'define'".to_string()
}),
Some(token) => return Err(ParseError::UnexpectedToken {
line: 1,
column: 1,
expected: "identifier".to_string(),
found: token.text.clone(),
}),
};
self.consume(TokenKind::Assign, "Expected '=' after define name")?;
@ -232,10 +241,14 @@ impl Parser {
Some(TokenKind::Newline) => {
self.advance();
}
_ => { return Err(ParseError::SemanticError {
message: format!("Unexpected token in table: {}",
self.peek().map(|t| t.text.as_str()).unwrap_or("EOF")),
}.into());
_ => {
return Err(ParseError::SemanticError {
message: format!(
"Unexpected token in table: {}",
self.peek().map(|t| t.text.as_str()).unwrap_or("EOF")
),
}
.into());
}
}
self.skip_whitespace();
@ -337,8 +350,8 @@ impl Parser {
// Parse expressions and action
while !self.current_token_is(&TokenKind::Newline)
&& !self.current_token_is(&TokenKind::RightBrace)
&& !self.is_at_end() {
&& !self.is_at_end()
{
// Check for actions first
match self.peek().map(|t| &t.kind) {
Some(TokenKind::Accept) => {
@ -415,7 +428,8 @@ impl Parser {
} else {
return Err(ParseError::InvalidStatement {
message: "Expected string literal after 'comment'".to_string(),
}.into());
}
.into());
}
}
_ => {
@ -451,16 +465,16 @@ impl Parser {
fn parse_comparison_expression(&mut self) -> Result<Expression> {
let mut expr = self.parse_range_expression()?;
while let Some(token) = self.peek() {
let operator = match &token.kind {
TokenKind::Eq => BinaryOperator::Eq,
TokenKind::Ne => BinaryOperator::Ne,
TokenKind::Lt => BinaryOperator::Lt,
TokenKind::Le => BinaryOperator::Le,
TokenKind::Gt => BinaryOperator::Gt,
TokenKind::Ge => BinaryOperator::Ge,
_ => break,
};
while let Some(token) = self.peek() {
let operator = match &token.kind {
TokenKind::Eq => BinaryOperator::Eq,
TokenKind::Ne => BinaryOperator::Ne,
TokenKind::Lt => BinaryOperator::Lt,
TokenKind::Le => BinaryOperator::Le,
TokenKind::Gt => BinaryOperator::Gt,
TokenKind::Ge => BinaryOperator::Ge,
_ => break,
};
self.advance(); // consume operator
let right = self.parse_range_expression()?;
@ -664,19 +678,22 @@ impl Parser {
// Check for rate expression (e.g., 10/minute)
if self.current_token_is(&TokenKind::Slash) {
self.advance(); // consume '/'
if let Some(token) = self.peek() {
if matches!(token.kind, TokenKind::Identifier(_)) {
let unit = self.advance().unwrap().text.clone();
Ok(Expression::String(format!("{}/{}", num, unit)))
} else {
Err(ParseError::InvalidExpression {
message: "Expected identifier after '/' in rate expression".to_string(),
}.into())
}
if let Some(token) = self.peek() {
if matches!(token.kind, TokenKind::Identifier(_)) {
let unit = self.advance().unwrap().text.clone();
Ok(Expression::String(format!("{}/{}", num, unit)))
} else {
Err(ParseError::InvalidExpression {
message: "Expected identifier after '/' in rate expression"
.to_string(),
}
.into())
}
} else {
Err(ParseError::InvalidExpression {
message: "Expected identifier after '/' in rate expression".to_string(),
}.into())
}
.into())
}
} else {
Ok(Expression::Number(num))
@ -694,12 +711,14 @@ impl Parser {
} else {
Err(ParseError::InvalidExpression {
message: "Expected number after '/' in CIDR notation".to_string(),
}.into())
}
.into())
}
} else {
Err(ParseError::InvalidExpression {
message: "Expected number after '/' in CIDR notation".to_string(),
}.into())
}
.into())
}
// Check for port specification (e.g., 192.168.1.100:80)
} else if self.current_token_is(&TokenKind::Colon) {
@ -710,13 +729,17 @@ impl Parser {
Ok(Expression::String(format!("{}:{}", addr, port)))
} else {
Err(ParseError::InvalidExpression {
message: "Expected number after ':' in address:port specification".to_string(),
}.into())
message: "Expected number after ':' in address:port specification"
.to_string(),
}
.into())
}
} else {
Err(ParseError::InvalidExpression {
message: "Expected number after ':' in address:port specification".to_string(),
}.into())
message: "Expected number after ':' in address:port specification"
.to_string(),
}
.into())
}
} else {
Ok(Expression::IpAddress(addr))
@ -751,12 +774,13 @@ impl Parser {
self.consume(TokenKind::RightBrace, "Expected '}' to close set")?;
Ok(Expression::Set(elements))
}
_ => {
Err(ParseError::InvalidExpression {
message: format!("Unexpected token in expression: {}",
self.peek().map(|t| t.text.as_str()).unwrap_or("EOF")),
}.into())
_ => Err(ParseError::InvalidExpression {
message: format!(
"Unexpected token in expression: {}",
self.peek().map(|t| t.text.as_str()).unwrap_or("EOF")
),
}
.into()),
}
}
@ -822,12 +846,12 @@ impl Parser {
let token_clone = token.clone();
Err(self.unexpected_token_error_with_token(
"family (ip, ip6, inet, arp, bridge, netdev)".to_string(),
&token_clone
&token_clone,
))
}
},
None => Err(ParseError::MissingToken {
expected: "family".to_string()
expected: "family".to_string(),
}),
}
}
@ -842,12 +866,12 @@ impl Parser {
let token_clone = token.clone();
Err(self.unexpected_token_error_with_token(
"chain type (filter, nat, route)".to_string(),
&token_clone
&token_clone,
))
}
},
None => Err(ParseError::MissingToken {
expected: "chain type".to_string()
expected: "chain type".to_string(),
}),
}
}
@ -864,12 +888,12 @@ impl Parser {
let token_clone = token.clone();
Err(self.unexpected_token_error_with_token(
"hook (input, output, forward, prerouting, postrouting)".to_string(),
&token_clone
&token_clone,
))
}
},
None => Err(ParseError::MissingToken {
expected: "hook".to_string()
expected: "hook".to_string(),
}),
}
}
@ -879,7 +903,10 @@ impl Parser {
Some(token) => match token.kind {
TokenKind::Accept => Ok(Policy::Accept),
TokenKind::Drop => Ok(Policy::Drop),
_ => Err(anyhow!("Expected policy (accept, drop), got: {}", token.text)),
_ => Err(anyhow!(
"Expected policy (accept, drop), got: {}",
token.text
)),
},
None => Err(anyhow!("Expected policy")),
}
@ -887,12 +914,10 @@ impl Parser {
fn parse_identifier(&mut self) -> Result<String> {
match self.advance() {
Some(token) if matches!(token.kind, TokenKind::Identifier(_)) => {
match &token.kind {
TokenKind::Identifier(s) => Ok(s.clone()),
_ => unreachable!(),
}
}
Some(token) if matches!(token.kind, TokenKind::Identifier(_)) => match &token.kind {
TokenKind::Identifier(s) => Ok(s.clone()),
_ => unreachable!(),
},
Some(token) => Err(anyhow!("Expected identifier, got: {}", token.text)),
None => Err(anyhow!("Expected identifier")),
}
@ -905,17 +930,33 @@ impl Parser {
match &token.kind {
TokenKind::Identifier(s) => Ok(s.clone()),
// Allow keywords to be used as identifiers in certain contexts
TokenKind::Filter | TokenKind::Nat | TokenKind::Route |
TokenKind::Input | TokenKind::Output | TokenKind::Forward |
TokenKind::Prerouting | TokenKind::Postrouting |
TokenKind::Accept | TokenKind::Drop | TokenKind::Reject |
TokenKind::State | TokenKind::Ct | TokenKind::Type | TokenKind::Hook |
TokenKind::Priority | TokenKind::Policy |
TokenKind::Tcp | TokenKind::Udp | TokenKind::Icmp | TokenKind::Icmpv6 |
TokenKind::Ip | TokenKind::Ip6 => {
Ok(token.text.clone())
}
_ => Err(anyhow!("Expected identifier or keyword, got: {}", token.text)),
TokenKind::Filter
| TokenKind::Nat
| TokenKind::Route
| TokenKind::Input
| TokenKind::Output
| TokenKind::Forward
| TokenKind::Prerouting
| TokenKind::Postrouting
| TokenKind::Accept
| TokenKind::Drop
| TokenKind::Reject
| TokenKind::State
| TokenKind::Ct
| TokenKind::Type
| TokenKind::Hook
| TokenKind::Priority
| TokenKind::Policy
| TokenKind::Tcp
| TokenKind::Udp
| TokenKind::Icmp
| TokenKind::Icmpv6
| TokenKind::Ip
| TokenKind::Ip6 => Ok(token.text.clone()),
_ => Err(anyhow!(
"Expected identifier or keyword, got: {}",
token.text
)),
}
}
None => Err(anyhow!("Expected identifier or keyword")),
@ -924,26 +965,32 @@ impl Parser {
fn parse_string_or_identifier(&mut self) -> Result<String> {
match self.advance() {
Some(token) if matches!(token.kind, TokenKind::Identifier(_) | TokenKind::StringLiteral(_)) => {
Some(token)
if matches!(
token.kind,
TokenKind::Identifier(_) | TokenKind::StringLiteral(_)
) =>
{
match &token.kind {
TokenKind::Identifier(s) => Ok(s.clone()),
TokenKind::StringLiteral(s) => Ok(s.clone()),
_ => unreachable!(),
}
}
Some(token) => Err(anyhow!("Expected string or identifier, got: {}", token.text)),
Some(token) => Err(anyhow!(
"Expected string or identifier, got: {}",
token.text
)),
None => Err(anyhow!("Expected string or identifier")),
}
}
fn parse_number(&mut self) -> Result<u64> {
match self.advance() {
Some(token) if matches!(token.kind, TokenKind::NumberLiteral(_)) => {
match &token.kind {
TokenKind::NumberLiteral(n) => Ok(*n),
_ => unreachable!(),
}
}
Some(token) if matches!(token.kind, TokenKind::NumberLiteral(_)) => match &token.kind {
TokenKind::NumberLiteral(n) => Ok(*n),
_ => unreachable!(),
},
Some(token) => Err(anyhow!("Expected number, got: {}", token.text)),
None => Err(anyhow!("Expected number")),
}
@ -1029,11 +1076,7 @@ impl Parser {
}
}
fn unexpected_token_error_with_token(
&self,
expected: String,
token: &Token,
) -> ParseError {
fn unexpected_token_error_with_token(&self, expected: String, token: &Token) -> ParseError {
ParseError::UnexpectedToken {
line: 1, // TODO: Calculate line from range
column: token.range.start().into(),

View file

@ -71,28 +71,28 @@ impl NftablesFormatter {
self.format_include(&mut output, include, 0);
}
// Add separator if we have includes
if !ruleset.includes.is_empty() {
self.add_separator(&mut output);
}
// Add separator if we have includes
if !ruleset.includes.is_empty() {
self.add_separator(&mut output);
}
// Format defines
for define in &ruleset.defines {
self.format_define(&mut output, define, 0);
}
// Add separator if we have defines
if !ruleset.defines.is_empty() {
self.add_separator(&mut output);
}
// Add separator if we have defines
if !ruleset.defines.is_empty() {
self.add_separator(&mut output);
}
// Format tables
let mut table_iter = ruleset.tables.values().peekable();
while let Some(table) = table_iter.next() {
self.format_table(&mut output, table, 0); // Add separator between tables
if table_iter.peek().is_some() {
self.add_separator(&mut output);
}
self.format_table(&mut output, table, 0); // Add separator between tables
if table_iter.peek().is_some() {
self.add_separator(&mut output);
}
}
// Ensure file ends with newline
@ -104,44 +104,62 @@ impl NftablesFormatter {
}
fn format_include(&self, output: &mut String, include: &Include, level: usize) {
let indent = self.config.indent_style.format(level, self.config.spaces_per_level);
let indent = self
.config
.indent_style
.format(level, self.config.spaces_per_level);
writeln!(output, "{}include \"{}\"", indent, include.path).unwrap();
}
fn format_define(&self, output: &mut String, define: &Define, level: usize) {
let indent = self.config.indent_style.format(level, self.config.spaces_per_level);
let indent = self
.config
.indent_style
.format(level, self.config.spaces_per_level);
write!(output, "{}define {} = ", indent, define.name).unwrap();
self.format_expression(output, &define.value);
output.push('\n');
}
fn format_table(&self, output: &mut String, table: &Table, level: usize) {
let indent = self.config.indent_style.format(level, self.config.spaces_per_level);
let indent = self
.config
.indent_style
.format(level, self.config.spaces_per_level);
writeln!(output, "{}table {} {} {{", indent, table.family, table.name).unwrap();
// Format chains
let mut chain_iter = table.chains.values().peekable();
while let Some(chain) = chain_iter.next() {
self.format_chain(output, chain, level + 1); // Add separator between chains
if chain_iter.peek().is_some() {
self.add_separator(output);
}
self.format_chain(output, chain, level + 1); // Add separator between chains
if chain_iter.peek().is_some() {
self.add_separator(output);
}
}
writeln!(output, "{}}}", indent).unwrap();
}
fn format_chain(&self, output: &mut String, chain: &Chain, level: usize) {
let indent = self.config.indent_style.format(level, self.config.spaces_per_level);
let indent = self
.config
.indent_style
.format(level, self.config.spaces_per_level);
writeln!(output, "{}chain {} {{", indent, chain.name).unwrap();
// Format chain properties
if let Some(chain_type) = &chain.chain_type {
write!(output, "{}type {}",
self.config.indent_style.format(level + 1, self.config.spaces_per_level),
chain_type).unwrap();
write!(
output,
"{}type {}",
self.config
.indent_style
.format(level + 1, self.config.spaces_per_level),
chain_type
)
.unwrap();
if let Some(hook) = &chain.hook {
write!(output, " hook {}", hook).unwrap();
@ -179,7 +197,10 @@ impl NftablesFormatter {
}
fn format_rule(&self, output: &mut String, rule: &Rule, level: usize) {
let indent = self.config.indent_style.format(level, self.config.spaces_per_level);
let indent = self
.config
.indent_style
.format(level, self.config.spaces_per_level);
write!(output, "{}", indent).unwrap();
@ -212,7 +233,11 @@ impl NftablesFormatter {
Expression::Ipv6Address(addr) => write!(output, "{}", addr).unwrap(),
Expression::MacAddress(addr) => write!(output, "{}", addr).unwrap(),
Expression::Binary { left, operator, right } => {
Expression::Binary {
left,
operator,
right,
} => {
self.format_expression(output, left);
write!(output, " {} ", operator).unwrap();
self.format_expression(output, right);
@ -279,7 +304,10 @@ impl std::str::FromStr for IndentStyle {
match s.to_lowercase().as_str() {
"tabs" | "tab" => Ok(IndentStyle::Tabs),
"spaces" | "space" => Ok(IndentStyle::Spaces),
_ => Err(format!("Invalid indent style: {}. Use 'tabs' or 'spaces'", s)),
_ => Err(format!(
"Invalid indent style: {}. Use 'tabs' or 'spaces'",
s
)),
}
}
}
@ -290,21 +318,20 @@ mod tests {
#[test]
fn test_format_simple_table() {
let table = Table::new(Family::Inet, "test".to_string())
.add_chain(
Chain::new("input".to_string())
.with_type(ChainType::Filter)
.with_hook(Hook::Input)
.with_priority(0)
.with_policy(Policy::Accept)
.add_rule(Rule::new(
vec![Expression::Interface {
direction: InterfaceDirection::Input,
name: "lo".to_string(),
}],
Action::Accept,
))
);
let table = Table::new(Family::Inet, "test".to_string()).add_chain(
Chain::new("input".to_string())
.with_type(ChainType::Filter)
.with_hook(Hook::Input)
.with_priority(0)
.with_policy(Policy::Accept)
.add_rule(Rule::new(
vec![Expression::Interface {
direction: InterfaceDirection::Input,
name: "lo".to_string(),
}],
Action::Accept,
)),
);
let formatter = NftablesFormatter::new(FormatConfig::default());
let mut output = String::new();