From 10a525b2e7e07ccbfccca8f9fcdbaacc07fe3475 Mon Sep 17 00:00:00 2001 From: NotAShelf Date: Sat, 24 May 2025 23:41:06 +0300 Subject: [PATCH] treewide: fmt --- src/cst.rs | 119 ++++++++++++++--------- src/lexer.rs | 29 ++++-- src/main.rs | 20 ++-- src/parser.rs | 255 +++++++++++++++++++++++++++++--------------------- src/syntax.rs | 109 +++++++++++++-------- 5 files changed, 326 insertions(+), 206 deletions(-) diff --git a/src/cst.rs b/src/cst.rs index e0e093c..9dcd962 100644 --- a/src/cst.rs +++ b/src/cst.rs @@ -3,14 +3,10 @@ //! Lossless representation preserving whitespace, comments, and formatting. use crate::lexer::{Token, TokenKind}; -use cstree::{ - green::GreenNode, - RawSyntaxKind, util::NodeOrToken, -}; +use cstree::{RawSyntaxKind, green::GreenNode, util::NodeOrToken}; use std::fmt; use thiserror::Error; - /// nftables syntax node types #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(u16)] @@ -350,7 +346,10 @@ impl CstBuilder { /// Internal tree builder that constructs CST according to nftables grammar struct CstTreeBuilder { - stack: Vec<(SyntaxKind, Vec>)>, + stack: Vec<( + SyntaxKind, + Vec>, + )>, } impl CstTreeBuilder { @@ -384,15 +383,15 @@ impl CstTreeBuilder { TokenKind::CommentLine(_) => { self.add_token(&tokens[start], SyntaxKind::Comment); start + 1 - }, + } TokenKind::Newline => { self.add_token(&tokens[start], SyntaxKind::Newline); start + 1 - }, + } TokenKind::Shebang(_) => { self.add_token(&tokens[start], SyntaxKind::Shebang); start + 1 - }, + } _ => { // Handle other tokens as statements self.parse_statement(tokens, start) @@ -437,17 +436,17 @@ impl CstTreeBuilder { match &tokens[pos].kind { TokenKind::Chain => { pos = self.parse_chain(tokens, pos); - }, + } TokenKind::Set => { pos = self.parse_set(tokens, pos); - }, + } TokenKind::Map => { pos = self.parse_map(tokens, pos); - }, + } TokenKind::Newline | TokenKind::CommentLine(_) => { self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); pos += 1; - }, + } _ => pos += 1, // Skip unknown tokens } } @@ -484,11 +483,11 @@ impl CstTreeBuilder { match &tokens[pos].kind { TokenKind::Type => { pos = self.parse_chain_properties(tokens, pos); - }, + } TokenKind::Newline | TokenKind::CommentLine(_) => { self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); pos += 1; - }, + } _ => { // Parse as rule pos = self.parse_rule(tokens, pos); @@ -538,7 +537,9 @@ impl CstTreeBuilder { TokenKind::Output => self.add_token(&tokens[pos], SyntaxKind::OutputKw), TokenKind::Forward => self.add_token(&tokens[pos], SyntaxKind::ForwardKw), TokenKind::Prerouting => self.add_token(&tokens[pos], SyntaxKind::PreroutingKw), - TokenKind::Postrouting => self.add_token(&tokens[pos], SyntaxKind::PostroutingKw), + TokenKind::Postrouting => { + self.add_token(&tokens[pos], SyntaxKind::PostroutingKw) + } _ => self.add_token(&tokens[pos], SyntaxKind::Identifier), } pos += 1; @@ -580,7 +581,7 @@ impl CstTreeBuilder { self.add_token(&tokens[pos], SyntaxKind::Newline); pos += 1; break; - }, + } TokenKind::Accept => self.add_token(&tokens[pos], SyntaxKind::AcceptKw), TokenKind::Drop => self.add_token(&tokens[pos], SyntaxKind::DropKw), TokenKind::Reject => self.add_token(&tokens[pos], SyntaxKind::RejectKw), @@ -596,10 +597,14 @@ impl CstTreeBuilder { TokenKind::LeftParen => { pos = self.parse_expression(tokens, pos); continue; - }, + } TokenKind::Identifier(_) => self.add_token(&tokens[pos], SyntaxKind::Identifier), - TokenKind::StringLiteral(_) => self.add_token(&tokens[pos], SyntaxKind::StringLiteral), - TokenKind::NumberLiteral(_) => self.add_token(&tokens[pos], SyntaxKind::NumberLiteral), + TokenKind::StringLiteral(_) => { + self.add_token(&tokens[pos], SyntaxKind::StringLiteral) + } + TokenKind::NumberLiteral(_) => { + self.add_token(&tokens[pos], SyntaxKind::NumberLiteral) + } TokenKind::IpAddress(_) => self.add_token(&tokens[pos], SyntaxKind::IpAddress), _ => self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())), } @@ -627,7 +632,7 @@ impl CstTreeBuilder { TokenKind::LeftParen => { paren_depth += 1; self.add_token(&tokens[pos], SyntaxKind::LeftParen); - }, + } TokenKind::RightParen => { self.add_token(&tokens[pos], SyntaxKind::RightParen); paren_depth -= 1; @@ -635,20 +640,22 @@ impl CstTreeBuilder { pos += 1; break; } - }, + } TokenKind::LeftBracket => { pos = self.parse_set_expression(tokens, pos); continue; - }, + } TokenKind::Identifier(_) => { // Check if this is a function call - if pos + 1 < tokens.len() && matches!(tokens[pos + 1].kind, TokenKind::LeftParen) { + if pos + 1 < tokens.len() + && matches!(tokens[pos + 1].kind, TokenKind::LeftParen) + { pos = self.parse_call_expression(tokens, pos); continue; } else { self.add_token(&tokens[pos], SyntaxKind::Identifier); } - }, + } TokenKind::Dash => { // Could be binary or unary expression if self.is_binary_context() { @@ -656,7 +663,7 @@ impl CstTreeBuilder { } else { self.parse_unary_expression(tokens, pos); } - }, + } _ => self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())), } pos += 1; @@ -789,7 +796,9 @@ impl CstTreeBuilder { pos += 1; // Parse what to flush (table, chain, etc.) - while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline | TokenKind::Semicolon) { + while pos < tokens.len() + && !matches!(tokens[pos].kind, TokenKind::Newline | TokenKind::Semicolon) + { self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); pos += 1; } @@ -845,7 +854,9 @@ impl CstTreeBuilder { pos += 1; // Parse element specification - while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline | TokenKind::LeftBrace) { + while pos < tokens.len() + && !matches!(tokens[pos].kind, TokenKind::Newline | TokenKind::LeftBrace) + { self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); pos += 1; } @@ -914,7 +925,9 @@ impl CstTreeBuilder { fn is_binary_context(&self) -> bool { // Simple heuristic: assume binary if we have recent tokens in current scope - self.stack.last().map_or(false, |(_, children)| children.len() > 1) + self.stack + .last() + .map_or(false, |(_, children)| children.len() > 1) } fn start_node(&mut self, kind: SyntaxKind) { @@ -925,7 +938,8 @@ impl CstTreeBuilder { // Handle whitespace specially if it's part of the token if let TokenKind::Identifier(text) = &token.kind { if text.chars().all(|c| c.is_whitespace()) { - let whitespace_token = GreenNode::new(SyntaxKind::Whitespace.to_raw(), std::iter::empty()); + let whitespace_token = + GreenNode::new(SyntaxKind::Whitespace.to_raw(), std::iter::empty()); if let Some((_, children)) = self.stack.last_mut() { children.push(NodeOrToken::Node(whitespace_token)); } @@ -978,14 +992,17 @@ impl CstValidator { SyntaxKind::Rule => self.validate_rule(node), SyntaxKind::Set | SyntaxKind::Map => self.validate_set_or_map(node), SyntaxKind::Element => self.validate_element(node), - SyntaxKind::Expression | SyntaxKind::BinaryExpr | SyntaxKind::UnaryExpr | SyntaxKind::CallExpr => { - self.validate_expression(node) - }, + SyntaxKind::Expression + | SyntaxKind::BinaryExpr + | SyntaxKind::UnaryExpr + | SyntaxKind::CallExpr => self.validate_expression(node), SyntaxKind::IncludeStmt => self.validate_include(node), SyntaxKind::DefineStmt => self.validate_define(node), - SyntaxKind::FlushStmt | SyntaxKind::AddStmt | SyntaxKind::DeleteStmt => self.validate_statement(node), + SyntaxKind::FlushStmt | SyntaxKind::AddStmt | SyntaxKind::DeleteStmt => { + self.validate_statement(node) + } SyntaxKind::Whitespace => Ok(()), // Whitespace is always valid - _ => Ok(()), // Other nodes are generally valid + _ => Ok(()), // Other nodes are generally valid } } @@ -1003,15 +1020,24 @@ impl CstValidator { if let Some(child_node) = child.as_node() { let child_kind = SyntaxKind::from_raw(child_node.kind()); match child_kind { - SyntaxKind::Table | SyntaxKind::IncludeStmt | SyntaxKind::DefineStmt | - SyntaxKind::FlushStmt | SyntaxKind::AddStmt | SyntaxKind::DeleteStmt | - SyntaxKind::Element | SyntaxKind::Comment | SyntaxKind::Newline | - SyntaxKind::Shebang | SyntaxKind::Whitespace => { + SyntaxKind::Table + | SyntaxKind::IncludeStmt + | SyntaxKind::DefineStmt + | SyntaxKind::FlushStmt + | SyntaxKind::AddStmt + | SyntaxKind::DeleteStmt + | SyntaxKind::Element + | SyntaxKind::Comment + | SyntaxKind::Newline + | SyntaxKind::Shebang + | SyntaxKind::Whitespace => { self.validate(child_node)?; - }, + } _ => { return Err(CstError::UnexpectedToken { - expected: "table, include, define, flush, add, delete, element, or comment".to_string(), + expected: + "table, include, define, flush, add, delete, element, or comment" + .to_string(), found: format!("{:?}", child_kind), }); } @@ -1052,11 +1078,16 @@ impl CstValidator { if let Some(second_node) = second.as_node() { let second_kind = SyntaxKind::from_raw(second_node.kind()); match second_kind { - SyntaxKind::InetKw | SyntaxKind::IpKw | SyntaxKind::Ip6Kw | - SyntaxKind::ArpKw | SyntaxKind::BridgeKw | SyntaxKind::NetdevKw => {}, + SyntaxKind::InetKw + | SyntaxKind::IpKw + | SyntaxKind::Ip6Kw + | SyntaxKind::ArpKw + | SyntaxKind::BridgeKw + | SyntaxKind::NetdevKw => {} _ => { return Err(CstError::UnexpectedToken { - expected: "table family (inet, ip, ip6, arp, bridge, netdev)".to_string(), + expected: "table family (inet, ip, ip6, arp, bridge, netdev)" + .to_string(), found: format!("{:?}", second_kind), }); } diff --git a/src/lexer.rs b/src/lexer.rs index 6a9e087..e28cdae 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -19,7 +19,7 @@ pub type LexResult = Result; /// Token kinds for nftables configuration files #[derive(Logos, Debug, Clone, PartialEq, Eq, Hash)] -#[logos(skip r"[ \t\f]+")] // Skip whitespace but not newlines +#[logos(skip r"[ \t\f]+")] // Skip whitespace but not newlines pub enum TokenKind { // Keywords #[token("table")] @@ -198,7 +198,11 @@ pub enum TokenKind { NumberLiteral(u64), #[regex(r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+", |lex| lex.slice().to_owned())] IpAddress(String), - #[regex(r"(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:)*::[0-9a-fA-F:]*", ipv6_address, priority = 5)] + #[regex( + r"(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:)*::[0-9a-fA-F:]*", + ipv6_address, + priority = 5 + )] Ipv6Address(String), #[regex(r"[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}", |lex| lex.slice().to_owned())] MacAddress(String), @@ -245,7 +249,9 @@ impl fmt::Display for TokenKind { fn string_literal(lex: &mut Lexer) -> String { let slice = lex.slice(); // Remove surrounding quotes and process escape sequences - slice[1..slice.len()-1].replace("\\\"", "\"").replace("\\\\", "\\") + slice[1..slice.len() - 1] + .replace("\\\"", "\"") + .replace("\\\\", "\\") } fn number_literal(lex: &mut Lexer) -> Option { @@ -309,7 +315,7 @@ impl<'a> NftablesLexer<'a> { let text = &self.source[span.clone()]; let range = TextRange::new( TextSize::from(span.start as u32), - TextSize::from(span.end as u32) + TextSize::from(span.end as u32), ); match result { @@ -322,8 +328,11 @@ impl<'a> NftablesLexer<'a> { return Err(LexError::UnterminatedString { position: span.start, }); - } else if text.chars().any(|c| c.is_ascii_digit()) && - text.chars().any(|c| !c.is_ascii_digit() && c != '.' && c != 'x' && c != 'X') { + } else if text.chars().any(|c| c.is_ascii_digit()) + && text + .chars() + .any(|c| !c.is_ascii_digit() && c != '.' && c != 'x' && c != 'X') + { return Err(LexError::InvalidNumber { text: text.to_owned(), }); @@ -349,7 +358,7 @@ impl<'a> NftablesLexer<'a> { let text = &self.source[span.clone()]; let range = TextRange::new( TextSize::from(span.start as u32), - TextSize::from(span.end as u32) + TextSize::from(span.end as u32), ); let kind = result.unwrap_or(TokenKind::Error); @@ -395,7 +404,11 @@ mod tests { let mut lexer = NftablesLexer::new(source); let tokens = lexer.tokenize().expect("Tokenization should succeed"); - assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::CommentLine(_)))); + assert!( + tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::CommentLine(_))) + ); assert!(tokens.iter().any(|t| t.kind == TokenKind::Table)); } diff --git a/src/main.rs b/src/main.rs index 5985c74..5cb40c9 100755 --- a/src/main.rs +++ b/src/main.rs @@ -4,17 +4,17 @@ mod lexer; mod parser; mod syntax; +use anyhow::{Context, Result}; +use clap::Parser; use std::fs; use std::io::{self, Write}; use std::path::Path; -use clap::Parser; -use anyhow::{Context, Result}; use thiserror::Error; +use crate::cst::CstBuilder; use crate::lexer::NftablesLexer; use crate::parser::Parser as NftablesParser; use crate::syntax::{FormatConfig, IndentStyle, NftablesFormatter}; -use crate::cst::CstBuilder; #[derive(Error, Debug)] enum FormatterError { @@ -85,14 +85,18 @@ fn process_nftables_config(args: Args) -> Result<()> { // Use error-recovery tokenization for debug mode lexer.tokenize_with_errors() } else { - lexer.tokenize() + lexer + .tokenize() .map_err(|e| FormatterError::ParseError(e.to_string()))? }; if args.debug { eprintln!("=== TOKENS ==="); for (i, token) in tokens.iter().enumerate() { - eprintln!("{:3}: {:?} @ {:?} = '{}'", i, token.kind, token.range, token.text); + eprintln!( + "{:3}: {:?} @ {:?} = '{}'", + i, token.kind, token.range, token.text + ); } eprintln!(); @@ -126,7 +130,8 @@ fn process_nftables_config(args: Args) -> Result<()> { parsed_ruleset.unwrap_or_else(|| crate::ast::Ruleset::new()) } else { let mut parser = NftablesParser::new(tokens.clone()); - parser.parse() + parser + .parse() .map_err(|e| FormatterError::ParseError(e.to_string()))? }; @@ -160,7 +165,8 @@ fn process_nftables_config(args: Args) -> Result<()> { println!("Formatted output written to: {}", output_file); } None => { - io::stdout().write_all(formatted_output.as_bytes()) + io::stdout() + .write_all(formatted_output.as_bytes()) .with_context(|| "Failed to write to stdout")?; } } diff --git a/src/parser.rs b/src/parser.rs index aed23e8..8db1593 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,12 +1,14 @@ use crate::ast::*; use crate::lexer::{LexError, Token, TokenKind}; -use anyhow::{anyhow, Result}; +use anyhow::{Result, anyhow}; use thiserror::Error; /// Parse errors for nftables configuration #[derive(Error, Debug)] pub enum ParseError { - #[error("Unexpected token at line {line}, column {column}: expected {expected}, found '{found}'")] + #[error( + "Unexpected token at line {line}, column {column}: expected {expected}, found '{found}'" + )] UnexpectedToken { line: usize, column: usize, @@ -123,7 +125,10 @@ impl Parser { line: 1, column: 1, expected: "include, define, table, or comment".to_string(), - found: self.peek().map(|t| t.text.clone()).unwrap_or("EOF".to_string()), + found: self + .peek() + .map(|t| t.text.clone()) + .unwrap_or("EOF".to_string()), }); } } @@ -154,21 +159,23 @@ impl Parser { self.consume(TokenKind::Include, "Expected 'include'")?; let path = match self.advance() { - Some(token) if matches!(token.kind, TokenKind::StringLiteral(_)) => { - match &token.kind { - TokenKind::StringLiteral(s) => s.clone(), - _ => unreachable!(), - } + Some(token) if matches!(token.kind, TokenKind::StringLiteral(_)) => match &token.kind { + TokenKind::StringLiteral(s) => s.clone(), + _ => unreachable!(), + }, + None => { + return Err(ParseError::MissingToken { + expected: "string literal after 'include'".to_string(), + }); + } + Some(token) => { + return Err(ParseError::UnexpectedToken { + line: 1, + column: 1, + expected: "string literal".to_string(), + found: token.text.clone(), + }); } - None => return Err(ParseError::MissingToken { - expected: "string literal after 'include'".to_string() - }), - Some(token) => return Err(ParseError::UnexpectedToken { - line: 1, - column: 1, - expected: "string literal".to_string(), - found: token.text.clone(), - }), }; Ok(Include { path }) @@ -179,21 +186,23 @@ impl Parser { self.consume(TokenKind::Define, "Expected 'define'")?; let name = match self.advance() { - Some(token) if matches!(token.kind, TokenKind::Identifier(_)) => { - match &token.kind { - TokenKind::Identifier(s) => s.clone(), - _ => unreachable!(), - } + Some(token) if matches!(token.kind, TokenKind::Identifier(_)) => match &token.kind { + TokenKind::Identifier(s) => s.clone(), + _ => unreachable!(), + }, + None => { + return Err(ParseError::MissingToken { + expected: "identifier after 'define'".to_string(), + }); + } + Some(token) => { + return Err(ParseError::UnexpectedToken { + line: 1, + column: 1, + expected: "identifier".to_string(), + found: token.text.clone(), + }); } - None => return Err(ParseError::MissingToken { - expected: "identifier after 'define'".to_string() - }), - Some(token) => return Err(ParseError::UnexpectedToken { - line: 1, - column: 1, - expected: "identifier".to_string(), - found: token.text.clone(), - }), }; self.consume(TokenKind::Assign, "Expected '=' after define name")?; @@ -232,10 +241,14 @@ impl Parser { Some(TokenKind::Newline) => { self.advance(); } - _ => { return Err(ParseError::SemanticError { - message: format!("Unexpected token in table: {}", - self.peek().map(|t| t.text.as_str()).unwrap_or("EOF")), - }.into()); + _ => { + return Err(ParseError::SemanticError { + message: format!( + "Unexpected token in table: {}", + self.peek().map(|t| t.text.as_str()).unwrap_or("EOF") + ), + } + .into()); } } self.skip_whitespace(); @@ -337,8 +350,8 @@ impl Parser { // Parse expressions and action while !self.current_token_is(&TokenKind::Newline) && !self.current_token_is(&TokenKind::RightBrace) - && !self.is_at_end() { - + && !self.is_at_end() + { // Check for actions first match self.peek().map(|t| &t.kind) { Some(TokenKind::Accept) => { @@ -415,7 +428,8 @@ impl Parser { } else { return Err(ParseError::InvalidStatement { message: "Expected string literal after 'comment'".to_string(), - }.into()); + } + .into()); } } _ => { @@ -451,16 +465,16 @@ impl Parser { fn parse_comparison_expression(&mut self) -> Result { let mut expr = self.parse_range_expression()?; - while let Some(token) = self.peek() { - let operator = match &token.kind { - TokenKind::Eq => BinaryOperator::Eq, - TokenKind::Ne => BinaryOperator::Ne, - TokenKind::Lt => BinaryOperator::Lt, - TokenKind::Le => BinaryOperator::Le, - TokenKind::Gt => BinaryOperator::Gt, - TokenKind::Ge => BinaryOperator::Ge, - _ => break, - }; + while let Some(token) = self.peek() { + let operator = match &token.kind { + TokenKind::Eq => BinaryOperator::Eq, + TokenKind::Ne => BinaryOperator::Ne, + TokenKind::Lt => BinaryOperator::Lt, + TokenKind::Le => BinaryOperator::Le, + TokenKind::Gt => BinaryOperator::Gt, + TokenKind::Ge => BinaryOperator::Ge, + _ => break, + }; self.advance(); // consume operator let right = self.parse_range_expression()?; @@ -664,19 +678,22 @@ impl Parser { // Check for rate expression (e.g., 10/minute) if self.current_token_is(&TokenKind::Slash) { self.advance(); // consume '/' - if let Some(token) = self.peek() { - if matches!(token.kind, TokenKind::Identifier(_)) { - let unit = self.advance().unwrap().text.clone(); - Ok(Expression::String(format!("{}/{}", num, unit))) - } else { - Err(ParseError::InvalidExpression { - message: "Expected identifier after '/' in rate expression".to_string(), - }.into()) - } + if let Some(token) = self.peek() { + if matches!(token.kind, TokenKind::Identifier(_)) { + let unit = self.advance().unwrap().text.clone(); + Ok(Expression::String(format!("{}/{}", num, unit))) + } else { + Err(ParseError::InvalidExpression { + message: "Expected identifier after '/' in rate expression" + .to_string(), + } + .into()) + } } else { Err(ParseError::InvalidExpression { message: "Expected identifier after '/' in rate expression".to_string(), - }.into()) + } + .into()) } } else { Ok(Expression::Number(num)) @@ -694,12 +711,14 @@ impl Parser { } else { Err(ParseError::InvalidExpression { message: "Expected number after '/' in CIDR notation".to_string(), - }.into()) + } + .into()) } } else { Err(ParseError::InvalidExpression { message: "Expected number after '/' in CIDR notation".to_string(), - }.into()) + } + .into()) } // Check for port specification (e.g., 192.168.1.100:80) } else if self.current_token_is(&TokenKind::Colon) { @@ -710,13 +729,17 @@ impl Parser { Ok(Expression::String(format!("{}:{}", addr, port))) } else { Err(ParseError::InvalidExpression { - message: "Expected number after ':' in address:port specification".to_string(), - }.into()) + message: "Expected number after ':' in address:port specification" + .to_string(), + } + .into()) } } else { Err(ParseError::InvalidExpression { - message: "Expected number after ':' in address:port specification".to_string(), - }.into()) + message: "Expected number after ':' in address:port specification" + .to_string(), + } + .into()) } } else { Ok(Expression::IpAddress(addr)) @@ -751,12 +774,13 @@ impl Parser { self.consume(TokenKind::RightBrace, "Expected '}' to close set")?; Ok(Expression::Set(elements)) } - _ => { - Err(ParseError::InvalidExpression { - message: format!("Unexpected token in expression: {}", - self.peek().map(|t| t.text.as_str()).unwrap_or("EOF")), - }.into()) + _ => Err(ParseError::InvalidExpression { + message: format!( + "Unexpected token in expression: {}", + self.peek().map(|t| t.text.as_str()).unwrap_or("EOF") + ), } + .into()), } } @@ -822,12 +846,12 @@ impl Parser { let token_clone = token.clone(); Err(self.unexpected_token_error_with_token( "family (ip, ip6, inet, arp, bridge, netdev)".to_string(), - &token_clone + &token_clone, )) } }, None => Err(ParseError::MissingToken { - expected: "family".to_string() + expected: "family".to_string(), }), } } @@ -842,12 +866,12 @@ impl Parser { let token_clone = token.clone(); Err(self.unexpected_token_error_with_token( "chain type (filter, nat, route)".to_string(), - &token_clone + &token_clone, )) } }, None => Err(ParseError::MissingToken { - expected: "chain type".to_string() + expected: "chain type".to_string(), }), } } @@ -864,12 +888,12 @@ impl Parser { let token_clone = token.clone(); Err(self.unexpected_token_error_with_token( "hook (input, output, forward, prerouting, postrouting)".to_string(), - &token_clone + &token_clone, )) } }, None => Err(ParseError::MissingToken { - expected: "hook".to_string() + expected: "hook".to_string(), }), } } @@ -879,7 +903,10 @@ impl Parser { Some(token) => match token.kind { TokenKind::Accept => Ok(Policy::Accept), TokenKind::Drop => Ok(Policy::Drop), - _ => Err(anyhow!("Expected policy (accept, drop), got: {}", token.text)), + _ => Err(anyhow!( + "Expected policy (accept, drop), got: {}", + token.text + )), }, None => Err(anyhow!("Expected policy")), } @@ -887,12 +914,10 @@ impl Parser { fn parse_identifier(&mut self) -> Result { match self.advance() { - Some(token) if matches!(token.kind, TokenKind::Identifier(_)) => { - match &token.kind { - TokenKind::Identifier(s) => Ok(s.clone()), - _ => unreachable!(), - } - } + Some(token) if matches!(token.kind, TokenKind::Identifier(_)) => match &token.kind { + TokenKind::Identifier(s) => Ok(s.clone()), + _ => unreachable!(), + }, Some(token) => Err(anyhow!("Expected identifier, got: {}", token.text)), None => Err(anyhow!("Expected identifier")), } @@ -905,17 +930,33 @@ impl Parser { match &token.kind { TokenKind::Identifier(s) => Ok(s.clone()), // Allow keywords to be used as identifiers in certain contexts - TokenKind::Filter | TokenKind::Nat | TokenKind::Route | - TokenKind::Input | TokenKind::Output | TokenKind::Forward | - TokenKind::Prerouting | TokenKind::Postrouting | - TokenKind::Accept | TokenKind::Drop | TokenKind::Reject | - TokenKind::State | TokenKind::Ct | TokenKind::Type | TokenKind::Hook | - TokenKind::Priority | TokenKind::Policy | - TokenKind::Tcp | TokenKind::Udp | TokenKind::Icmp | TokenKind::Icmpv6 | - TokenKind::Ip | TokenKind::Ip6 => { - Ok(token.text.clone()) - } - _ => Err(anyhow!("Expected identifier or keyword, got: {}", token.text)), + TokenKind::Filter + | TokenKind::Nat + | TokenKind::Route + | TokenKind::Input + | TokenKind::Output + | TokenKind::Forward + | TokenKind::Prerouting + | TokenKind::Postrouting + | TokenKind::Accept + | TokenKind::Drop + | TokenKind::Reject + | TokenKind::State + | TokenKind::Ct + | TokenKind::Type + | TokenKind::Hook + | TokenKind::Priority + | TokenKind::Policy + | TokenKind::Tcp + | TokenKind::Udp + | TokenKind::Icmp + | TokenKind::Icmpv6 + | TokenKind::Ip + | TokenKind::Ip6 => Ok(token.text.clone()), + _ => Err(anyhow!( + "Expected identifier or keyword, got: {}", + token.text + )), } } None => Err(anyhow!("Expected identifier or keyword")), @@ -924,26 +965,32 @@ impl Parser { fn parse_string_or_identifier(&mut self) -> Result { match self.advance() { - Some(token) if matches!(token.kind, TokenKind::Identifier(_) | TokenKind::StringLiteral(_)) => { + Some(token) + if matches!( + token.kind, + TokenKind::Identifier(_) | TokenKind::StringLiteral(_) + ) => + { match &token.kind { TokenKind::Identifier(s) => Ok(s.clone()), TokenKind::StringLiteral(s) => Ok(s.clone()), _ => unreachable!(), } } - Some(token) => Err(anyhow!("Expected string or identifier, got: {}", token.text)), + Some(token) => Err(anyhow!( + "Expected string or identifier, got: {}", + token.text + )), None => Err(anyhow!("Expected string or identifier")), } } fn parse_number(&mut self) -> Result { match self.advance() { - Some(token) if matches!(token.kind, TokenKind::NumberLiteral(_)) => { - match &token.kind { - TokenKind::NumberLiteral(n) => Ok(*n), - _ => unreachable!(), - } - } + Some(token) if matches!(token.kind, TokenKind::NumberLiteral(_)) => match &token.kind { + TokenKind::NumberLiteral(n) => Ok(*n), + _ => unreachable!(), + }, Some(token) => Err(anyhow!("Expected number, got: {}", token.text)), None => Err(anyhow!("Expected number")), } @@ -1029,11 +1076,7 @@ impl Parser { } } - fn unexpected_token_error_with_token( - &self, - expected: String, - token: &Token, - ) -> ParseError { + fn unexpected_token_error_with_token(&self, expected: String, token: &Token) -> ParseError { ParseError::UnexpectedToken { line: 1, // TODO: Calculate line from range column: token.range.start().into(), diff --git a/src/syntax.rs b/src/syntax.rs index edcf135..2c498d0 100644 --- a/src/syntax.rs +++ b/src/syntax.rs @@ -71,28 +71,28 @@ impl NftablesFormatter { self.format_include(&mut output, include, 0); } - // Add separator if we have includes - if !ruleset.includes.is_empty() { - self.add_separator(&mut output); - } + // Add separator if we have includes + if !ruleset.includes.is_empty() { + self.add_separator(&mut output); + } // Format defines for define in &ruleset.defines { self.format_define(&mut output, define, 0); } - // Add separator if we have defines - if !ruleset.defines.is_empty() { - self.add_separator(&mut output); - } + // Add separator if we have defines + if !ruleset.defines.is_empty() { + self.add_separator(&mut output); + } // Format tables let mut table_iter = ruleset.tables.values().peekable(); while let Some(table) = table_iter.next() { - self.format_table(&mut output, table, 0); // Add separator between tables - if table_iter.peek().is_some() { - self.add_separator(&mut output); - } + self.format_table(&mut output, table, 0); // Add separator between tables + if table_iter.peek().is_some() { + self.add_separator(&mut output); + } } // Ensure file ends with newline @@ -104,44 +104,62 @@ impl NftablesFormatter { } fn format_include(&self, output: &mut String, include: &Include, level: usize) { - let indent = self.config.indent_style.format(level, self.config.spaces_per_level); + let indent = self + .config + .indent_style + .format(level, self.config.spaces_per_level); writeln!(output, "{}include \"{}\"", indent, include.path).unwrap(); } fn format_define(&self, output: &mut String, define: &Define, level: usize) { - let indent = self.config.indent_style.format(level, self.config.spaces_per_level); + let indent = self + .config + .indent_style + .format(level, self.config.spaces_per_level); write!(output, "{}define {} = ", indent, define.name).unwrap(); self.format_expression(output, &define.value); output.push('\n'); } fn format_table(&self, output: &mut String, table: &Table, level: usize) { - let indent = self.config.indent_style.format(level, self.config.spaces_per_level); + let indent = self + .config + .indent_style + .format(level, self.config.spaces_per_level); writeln!(output, "{}table {} {} {{", indent, table.family, table.name).unwrap(); // Format chains let mut chain_iter = table.chains.values().peekable(); while let Some(chain) = chain_iter.next() { - self.format_chain(output, chain, level + 1); // Add separator between chains - if chain_iter.peek().is_some() { - self.add_separator(output); - } + self.format_chain(output, chain, level + 1); // Add separator between chains + if chain_iter.peek().is_some() { + self.add_separator(output); + } } writeln!(output, "{}}}", indent).unwrap(); } fn format_chain(&self, output: &mut String, chain: &Chain, level: usize) { - let indent = self.config.indent_style.format(level, self.config.spaces_per_level); + let indent = self + .config + .indent_style + .format(level, self.config.spaces_per_level); writeln!(output, "{}chain {} {{", indent, chain.name).unwrap(); // Format chain properties if let Some(chain_type) = &chain.chain_type { - write!(output, "{}type {}", - self.config.indent_style.format(level + 1, self.config.spaces_per_level), - chain_type).unwrap(); + write!( + output, + "{}type {}", + self.config + .indent_style + .format(level + 1, self.config.spaces_per_level), + chain_type + ) + .unwrap(); if let Some(hook) = &chain.hook { write!(output, " hook {}", hook).unwrap(); @@ -179,7 +197,10 @@ impl NftablesFormatter { } fn format_rule(&self, output: &mut String, rule: &Rule, level: usize) { - let indent = self.config.indent_style.format(level, self.config.spaces_per_level); + let indent = self + .config + .indent_style + .format(level, self.config.spaces_per_level); write!(output, "{}", indent).unwrap(); @@ -212,7 +233,11 @@ impl NftablesFormatter { Expression::Ipv6Address(addr) => write!(output, "{}", addr).unwrap(), Expression::MacAddress(addr) => write!(output, "{}", addr).unwrap(), - Expression::Binary { left, operator, right } => { + Expression::Binary { + left, + operator, + right, + } => { self.format_expression(output, left); write!(output, " {} ", operator).unwrap(); self.format_expression(output, right); @@ -279,7 +304,10 @@ impl std::str::FromStr for IndentStyle { match s.to_lowercase().as_str() { "tabs" | "tab" => Ok(IndentStyle::Tabs), "spaces" | "space" => Ok(IndentStyle::Spaces), - _ => Err(format!("Invalid indent style: {}. Use 'tabs' or 'spaces'", s)), + _ => Err(format!( + "Invalid indent style: {}. Use 'tabs' or 'spaces'", + s + )), } } } @@ -290,21 +318,20 @@ mod tests { #[test] fn test_format_simple_table() { - let table = Table::new(Family::Inet, "test".to_string()) - .add_chain( - Chain::new("input".to_string()) - .with_type(ChainType::Filter) - .with_hook(Hook::Input) - .with_priority(0) - .with_policy(Policy::Accept) - .add_rule(Rule::new( - vec![Expression::Interface { - direction: InterfaceDirection::Input, - name: "lo".to_string(), - }], - Action::Accept, - )) - ); + let table = Table::new(Family::Inet, "test".to_string()).add_chain( + Chain::new("input".to_string()) + .with_type(ChainType::Filter) + .with_hook(Hook::Input) + .with_priority(0) + .with_policy(Policy::Accept) + .add_rule(Rule::new( + vec![Expression::Interface { + direction: InterfaceDirection::Input, + name: "lo".to_string(), + }], + Action::Accept, + )), + ); let formatter = NftablesFormatter::new(FormatConfig::default()); let mut output = String::new();