//! Concrete Syntax Tree implementation for nftables configuration files //! //! Lossless representation preserving whitespace, comments, and formatting. use crate::lexer::{Token, TokenKind}; use cstree::{RawSyntaxKind, green::GreenNode, util::NodeOrToken}; use std::fmt; use thiserror::Error; /// nftables syntax node types #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(u16)] pub enum SyntaxKind { // Root and containers Root = 0, Table, Chain, Rule, Set, Map, Element, // Expressions Expression, BinaryExpr, UnaryExpr, CallExpr, SetExpr, RangeExpr, // Statements Statement, IncludeStmt, DefineStmt, FlushStmt, AddStmt, DeleteStmt, // Literals and identifiers Identifier, StringLiteral, NumberLiteral, IpAddress, Ipv6Address, MacAddress, // Keywords TableKw, ChainKw, RuleKw, SetKw, MapKw, ElementKw, IncludeKw, DefineKw, FlushKw, AddKw, DeleteKw, InsertKw, ReplaceKw, // Chain types and hooks FilterKw, NatKw, RouteKw, InputKw, OutputKw, ForwardKw, PreroutingKw, PostroutingKw, // Protocols and families IpKw, Ip6Kw, InetKw, ArpKw, BridgeKw, NetdevKw, TcpKw, UdpKw, IcmpKw, Icmpv6Kw, // Match keywords SportKw, DportKw, SaddrKw, DaddrKw, ProtocolKw, NexthdrKw, TypeKw, HookKw, PriorityKw, PolicyKw, IifnameKw, OifnameKw, CtKw, StateKw, // Actions AcceptKw, DropKw, RejectKw, ReturnKw, JumpKw, GotoKw, ContinueKw, LogKw, CommentKw, // States EstablishedKw, RelatedKw, NewKw, InvalidKw, // Operators EqOp, NeOp, LeOp, GeOp, LtOp, GtOp, // Punctuation LeftBrace, RightBrace, LeftParen, RightParen, LeftBracket, RightBracket, Comma, Semicolon, Colon, Assign, Dash, Slash, Dot, // Trivia Whitespace, Newline, Comment, Shebang, // Error recovery Error, } impl From for SyntaxKind { fn from(kind: TokenKind) -> Self { match kind { TokenKind::Table => SyntaxKind::TableKw, TokenKind::Chain => SyntaxKind::ChainKw, TokenKind::Rule => SyntaxKind::RuleKw, TokenKind::Set => SyntaxKind::SetKw, TokenKind::Map => SyntaxKind::MapKw, TokenKind::Element => SyntaxKind::ElementKw, TokenKind::Include => SyntaxKind::IncludeKw, TokenKind::Define => SyntaxKind::DefineKw, TokenKind::Flush => SyntaxKind::FlushKw, TokenKind::Add => SyntaxKind::AddKw, TokenKind::Delete => SyntaxKind::DeleteKw, TokenKind::Insert => SyntaxKind::InsertKw, TokenKind::Replace => SyntaxKind::ReplaceKw, TokenKind::Filter => SyntaxKind::FilterKw, TokenKind::Nat => SyntaxKind::NatKw, TokenKind::Route => SyntaxKind::RouteKw, TokenKind::Input => SyntaxKind::InputKw, TokenKind::Output => SyntaxKind::OutputKw, TokenKind::Forward => SyntaxKind::ForwardKw, TokenKind::Prerouting => SyntaxKind::PreroutingKw, TokenKind::Postrouting => SyntaxKind::PostroutingKw, TokenKind::Ip => SyntaxKind::IpKw, TokenKind::Ip6 => SyntaxKind::Ip6Kw, TokenKind::Inet => SyntaxKind::InetKw, TokenKind::Arp => SyntaxKind::ArpKw, TokenKind::Bridge => SyntaxKind::BridgeKw, TokenKind::Netdev => SyntaxKind::NetdevKw, TokenKind::Tcp => SyntaxKind::TcpKw, TokenKind::Udp => SyntaxKind::UdpKw, TokenKind::Icmp => SyntaxKind::IcmpKw, TokenKind::Icmpv6 => SyntaxKind::Icmpv6Kw, TokenKind::Sport => SyntaxKind::SportKw, TokenKind::Dport => SyntaxKind::DportKw, TokenKind::Saddr => SyntaxKind::SaddrKw, TokenKind::Daddr => SyntaxKind::DaddrKw, TokenKind::Protocol => SyntaxKind::ProtocolKw, TokenKind::Nexthdr => SyntaxKind::NexthdrKw, TokenKind::Type => SyntaxKind::TypeKw, TokenKind::Hook => SyntaxKind::HookKw, TokenKind::Priority => SyntaxKind::PriorityKw, TokenKind::Policy => SyntaxKind::PolicyKw, TokenKind::Iifname => SyntaxKind::IifnameKw, TokenKind::Oifname => SyntaxKind::OifnameKw, TokenKind::Ct => SyntaxKind::CtKw, TokenKind::State => SyntaxKind::StateKw, TokenKind::Accept => SyntaxKind::AcceptKw, TokenKind::Drop => SyntaxKind::DropKw, TokenKind::Reject => SyntaxKind::RejectKw, TokenKind::Return => SyntaxKind::ReturnKw, TokenKind::Jump => SyntaxKind::JumpKw, TokenKind::Goto => SyntaxKind::GotoKw, TokenKind::Continue => SyntaxKind::ContinueKw, TokenKind::Log => SyntaxKind::LogKw, TokenKind::Comment => SyntaxKind::CommentKw, TokenKind::Established => SyntaxKind::EstablishedKw, TokenKind::Related => SyntaxKind::RelatedKw, TokenKind::New => SyntaxKind::NewKw, TokenKind::Invalid => SyntaxKind::InvalidKw, TokenKind::Eq => SyntaxKind::EqOp, TokenKind::Ne => SyntaxKind::NeOp, TokenKind::Le => SyntaxKind::LeOp, TokenKind::Ge => SyntaxKind::GeOp, TokenKind::Lt => SyntaxKind::LtOp, TokenKind::Gt => SyntaxKind::GtOp, TokenKind::LeftBrace => SyntaxKind::LeftBrace, TokenKind::RightBrace => SyntaxKind::RightBrace, TokenKind::LeftParen => SyntaxKind::LeftParen, TokenKind::RightParen => SyntaxKind::RightParen, TokenKind::LeftBracket => SyntaxKind::LeftBracket, TokenKind::RightBracket => SyntaxKind::RightBracket, TokenKind::Comma => SyntaxKind::Comma, TokenKind::Semicolon => SyntaxKind::Semicolon, TokenKind::Colon => SyntaxKind::Colon, TokenKind::Assign => SyntaxKind::Assign, TokenKind::Dash => SyntaxKind::Dash, TokenKind::Slash => SyntaxKind::Slash, TokenKind::Dot => SyntaxKind::Dot, TokenKind::StringLiteral(_) => SyntaxKind::StringLiteral, TokenKind::NumberLiteral(_) => SyntaxKind::NumberLiteral, TokenKind::IpAddress(_) => SyntaxKind::IpAddress, TokenKind::Ipv6Address(_) => SyntaxKind::Ipv6Address, TokenKind::MacAddress(_) => SyntaxKind::MacAddress, TokenKind::Identifier(_) => SyntaxKind::Identifier, TokenKind::Newline => SyntaxKind::Newline, TokenKind::CommentLine(_) => SyntaxKind::Comment, TokenKind::Shebang(_) => SyntaxKind::Shebang, TokenKind::Error => SyntaxKind::Error, } } } impl fmt::Display for SyntaxKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let name = match self { SyntaxKind::Root => "ROOT", SyntaxKind::Table => "TABLE", SyntaxKind::Chain => "CHAIN", SyntaxKind::Rule => "RULE", SyntaxKind::Set => "SET", SyntaxKind::Map => "MAP", SyntaxKind::Element => "ELEMENT", SyntaxKind::Expression => "EXPRESSION", SyntaxKind::BinaryExpr => "BINARY_EXPR", SyntaxKind::UnaryExpr => "UNARY_EXPR", SyntaxKind::CallExpr => "CALL_EXPR", SyntaxKind::SetExpr => "SET_EXPR", SyntaxKind::RangeExpr => "RANGE_EXPR", SyntaxKind::Statement => "STATEMENT", SyntaxKind::IncludeStmt => "INCLUDE_STMT", SyntaxKind::DefineStmt => "DEFINE_STMT", SyntaxKind::FlushStmt => "FLUSH_STMT", SyntaxKind::AddStmt => "ADD_STMT", SyntaxKind::DeleteStmt => "DELETE_STMT", SyntaxKind::Identifier => "IDENTIFIER", SyntaxKind::StringLiteral => "STRING_LITERAL", SyntaxKind::NumberLiteral => "NUMBER_LITERAL", SyntaxKind::IpAddress => "IP_ADDRESS", SyntaxKind::Ipv6Address => "IPV6_ADDRESS", SyntaxKind::MacAddress => "MAC_ADDRESS", SyntaxKind::Whitespace => "WHITESPACE", SyntaxKind::Newline => "NEWLINE", SyntaxKind::Comment => "COMMENT", SyntaxKind::Shebang => "SHEBANG", SyntaxKind::Error => "ERROR", _ => return write!(f, "{:?}", self), }; write!(f, "{}", name) } } /// Language definition for nftables CST #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct NftablesLanguage; impl SyntaxKind { pub fn to_raw(self) -> RawSyntaxKind { RawSyntaxKind(self as u32) } pub fn from_raw(raw: RawSyntaxKind) -> Self { unsafe { std::mem::transmute(raw.0 as u16) } } } /// CST construction errors #[derive(Error, Debug, PartialEq)] pub enum CstError { #[error("Unexpected token: expected {expected}, found {found}")] UnexpectedToken { expected: String, found: String }, #[error("Missing token: expected {expected}")] MissingToken { expected: String }, #[error("Invalid syntax at position {position}")] InvalidSyntax { position: usize }, } /// Result type for CST operations pub type CstResult = Result; /// Basic CST builder pub struct CstBuilder; impl CstBuilder { /// Build a proper CST from tokens that represents nftables syntax structure pub fn build_tree(tokens: &[Token]) -> GreenNode { let mut builder = CstTreeBuilder::new(); builder.parse_tokens(tokens); builder.finish() } /// Validate CST tree structure according to nftables grammar rules pub fn validate_tree(node: &GreenNode) -> CstResult<()> { let validator = CstValidator::new(); validator.validate(node) } /// Parse tokens into a validated CST pub fn parse_to_cst(tokens: &[Token]) -> CstResult { let tree = Self::build_tree(tokens); Self::validate_tree(&tree)?; Ok(tree) } } /// Internal tree builder that constructs CST according to nftables grammar struct CstTreeBuilder { stack: Vec<( SyntaxKind, Vec>, )>, } impl CstTreeBuilder { fn new() -> Self { Self { stack: vec![(SyntaxKind::Root, Vec::new())], } } fn parse_tokens(&mut self, tokens: &[Token]) { let mut i = 0; while i < tokens.len() { i = self.parse_top_level(&tokens, i); } } fn parse_top_level(&mut self, tokens: &[Token], start: usize) -> usize { if start >= tokens.len() { return start; } match &tokens[start].kind { TokenKind::Table => self.parse_table(tokens, start), TokenKind::Chain => self.parse_chain(tokens, start), TokenKind::Include => self.parse_include(tokens, start), TokenKind::Define => self.parse_define(tokens, start), TokenKind::Flush => self.parse_flush_stmt(tokens, start), TokenKind::Add => self.parse_add_stmt(tokens, start), TokenKind::Delete => self.parse_delete_stmt(tokens, start), TokenKind::Element => self.parse_element(tokens, start), TokenKind::CommentLine(_) => { self.add_token(&tokens[start], SyntaxKind::Comment); start + 1 } TokenKind::Newline => { self.add_token(&tokens[start], SyntaxKind::Newline); start + 1 } TokenKind::Shebang(_) => { self.add_token(&tokens[start], SyntaxKind::Shebang); start + 1 } _ => { // Handle other tokens as statements self.parse_statement(tokens, start) } } } fn parse_table(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::Table); // Add 'table' keyword self.add_token(&tokens[start], SyntaxKind::TableKw); let mut pos = start + 1; // Parse family (inet, ip, ip6, etc.) if pos < tokens.len() { match &tokens[pos].kind { TokenKind::Inet => self.add_token(&tokens[pos], SyntaxKind::InetKw), TokenKind::Ip => self.add_token(&tokens[pos], SyntaxKind::IpKw), TokenKind::Ip6 => self.add_token(&tokens[pos], SyntaxKind::Ip6Kw), TokenKind::Arp => self.add_token(&tokens[pos], SyntaxKind::ArpKw), TokenKind::Bridge => self.add_token(&tokens[pos], SyntaxKind::BridgeKw), TokenKind::Netdev => self.add_token(&tokens[pos], SyntaxKind::NetdevKw), _ => self.add_token(&tokens[pos], SyntaxKind::Identifier), } pos += 1; } // Parse table name if pos < tokens.len() { self.add_token(&tokens[pos], SyntaxKind::Identifier); pos += 1; } // Parse table body if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::LeftBrace) { self.add_token(&tokens[pos], SyntaxKind::LeftBrace); pos += 1; // Parse table contents while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBrace) { match &tokens[pos].kind { TokenKind::Chain => { pos = self.parse_chain(tokens, pos); } TokenKind::Set => { pos = self.parse_set(tokens, pos); } TokenKind::Map => { pos = self.parse_map(tokens, pos); } TokenKind::Newline | TokenKind::CommentLine(_) => { self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); pos += 1; } _ => pos += 1, // Skip unknown tokens } } // Add closing brace if pos < tokens.len() { self.add_token(&tokens[pos], SyntaxKind::RightBrace); pos += 1; } } self.finish_node(); pos } fn parse_chain(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::Chain); self.add_token(&tokens[start], SyntaxKind::ChainKw); let mut pos = start + 1; // Parse chain name if pos < tokens.len() { self.add_token(&tokens[pos], SyntaxKind::Identifier); pos += 1; } // Parse chain body if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::LeftBrace) { self.add_token(&tokens[pos], SyntaxKind::LeftBrace); pos += 1; while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBrace) { match &tokens[pos].kind { TokenKind::Type => { pos = self.parse_chain_properties(tokens, pos); } TokenKind::Newline | TokenKind::CommentLine(_) => { self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); pos += 1; } _ => { // Parse as rule pos = self.parse_rule(tokens, pos); } } } if pos < tokens.len() { self.add_token(&tokens[pos], SyntaxKind::RightBrace); pos += 1; } } self.finish_node(); pos } fn parse_chain_properties(&mut self, tokens: &[Token], start: usize) -> usize { let mut pos = start; // Parse 'type' if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Type) { self.add_token(&tokens[pos], SyntaxKind::TypeKw); pos += 1; } // Parse chain type (filter, nat, route) if pos < tokens.len() { match &tokens[pos].kind { TokenKind::Filter => self.add_token(&tokens[pos], SyntaxKind::FilterKw), TokenKind::Nat => self.add_token(&tokens[pos], SyntaxKind::NatKw), TokenKind::Route => self.add_token(&tokens[pos], SyntaxKind::RouteKw), _ => self.add_token(&tokens[pos], SyntaxKind::Identifier), } pos += 1; } // Parse 'hook' if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Hook) { self.add_token(&tokens[pos], SyntaxKind::HookKw); pos += 1; // Parse hook type if pos < tokens.len() { match &tokens[pos].kind { TokenKind::Input => self.add_token(&tokens[pos], SyntaxKind::InputKw), TokenKind::Output => self.add_token(&tokens[pos], SyntaxKind::OutputKw), TokenKind::Forward => self.add_token(&tokens[pos], SyntaxKind::ForwardKw), TokenKind::Prerouting => self.add_token(&tokens[pos], SyntaxKind::PreroutingKw), TokenKind::Postrouting => { self.add_token(&tokens[pos], SyntaxKind::PostroutingKw) } _ => self.add_token(&tokens[pos], SyntaxKind::Identifier), } pos += 1; } } // Parse 'priority' if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Priority) { self.add_token(&tokens[pos], SyntaxKind::PriorityKw); pos += 1; if pos < tokens.len() { if let TokenKind::NumberLiteral(_) = &tokens[pos].kind { self.add_token(&tokens[pos], SyntaxKind::NumberLiteral); pos += 1; } } } // Parse semicolon if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Semicolon) { self.add_token(&tokens[pos], SyntaxKind::Semicolon); pos += 1; } pos } fn parse_rule(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::Rule); let mut pos = start; let rule_start = pos; // Parse until we hit a newline or end of input (simple rule parsing) while pos < tokens.len() { match &tokens[pos].kind { TokenKind::Newline => { self.add_token(&tokens[pos], SyntaxKind::Newline); pos += 1; break; } TokenKind::Accept => self.add_token(&tokens[pos], SyntaxKind::AcceptKw), TokenKind::Drop => self.add_token(&tokens[pos], SyntaxKind::DropKw), TokenKind::Reject => self.add_token(&tokens[pos], SyntaxKind::RejectKw), TokenKind::Return => self.add_token(&tokens[pos], SyntaxKind::ReturnKw), TokenKind::Log => self.add_token(&tokens[pos], SyntaxKind::LogKw), TokenKind::Comment => self.add_token(&tokens[pos], SyntaxKind::CommentKw), TokenKind::Eq => self.add_token(&tokens[pos], SyntaxKind::EqOp), TokenKind::Ne => self.add_token(&tokens[pos], SyntaxKind::NeOp), TokenKind::Lt => self.add_token(&tokens[pos], SyntaxKind::LtOp), TokenKind::Le => self.add_token(&tokens[pos], SyntaxKind::LeOp), TokenKind::Gt => self.add_token(&tokens[pos], SyntaxKind::GtOp), TokenKind::Ge => self.add_token(&tokens[pos], SyntaxKind::GeOp), TokenKind::LeftParen => { pos = self.parse_expression(tokens, pos); continue; } TokenKind::Identifier(_) => self.add_token(&tokens[pos], SyntaxKind::Identifier), TokenKind::StringLiteral(_) => { self.add_token(&tokens[pos], SyntaxKind::StringLiteral) } TokenKind::NumberLiteral(_) => { self.add_token(&tokens[pos], SyntaxKind::NumberLiteral) } TokenKind::IpAddress(_) => self.add_token(&tokens[pos], SyntaxKind::IpAddress), _ => self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())), } pos += 1; } // If no tokens were added to the rule, add at least one statement node if pos == rule_start { self.start_node(SyntaxKind::Statement); self.finish_node(); } self.finish_node(); pos } fn parse_expression(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::Expression); let mut pos = start; let mut paren_depth = 0; while pos < tokens.len() { match &tokens[pos].kind { TokenKind::LeftParen => { paren_depth += 1; self.add_token(&tokens[pos], SyntaxKind::LeftParen); } TokenKind::RightParen => { self.add_token(&tokens[pos], SyntaxKind::RightParen); paren_depth -= 1; if paren_depth == 0 { pos += 1; break; } } TokenKind::LeftBracket => { pos = self.parse_set_expression(tokens, pos); continue; } TokenKind::Identifier(_) => { // Check if this is a function call if pos + 1 < tokens.len() && matches!(tokens[pos + 1].kind, TokenKind::LeftParen) { pos = self.parse_call_expression(tokens, pos); continue; } else { self.add_token(&tokens[pos], SyntaxKind::Identifier); } } TokenKind::Dash => { // Could be binary or unary expression if self.is_binary_context() { self.parse_binary_expression(tokens, pos); } else { self.parse_unary_expression(tokens, pos); } } _ => self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())), } pos += 1; } self.finish_node(); pos } fn parse_set_expression(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::SetExpr); let mut pos = start; self.add_token(&tokens[pos], SyntaxKind::LeftBracket); pos += 1; while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBracket) { if matches!(tokens[pos].kind, TokenKind::Colon) { // Range expression self.start_node(SyntaxKind::RangeExpr); self.add_token(&tokens[pos], SyntaxKind::Colon); self.finish_node(); } else { self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); } pos += 1; } if pos < tokens.len() { self.add_token(&tokens[pos], SyntaxKind::RightBracket); pos += 1; } self.finish_node(); pos } fn parse_binary_expression(&mut self, tokens: &[Token], pos: usize) { self.start_node(SyntaxKind::BinaryExpr); self.add_token(&tokens[pos], SyntaxKind::Dash); self.finish_node(); } fn parse_unary_expression(&mut self, tokens: &[Token], pos: usize) { self.start_node(SyntaxKind::UnaryExpr); self.add_token(&tokens[pos], SyntaxKind::Dash); self.finish_node(); } fn parse_set(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::Set); // Basic set parsing implementation let mut pos = start; self.add_token(&tokens[pos], SyntaxKind::SetKw); pos += 1; // Skip to end of set for now while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBrace) { pos += 1; } self.finish_node(); pos } fn parse_map(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::Map); // Basic map parsing implementation let mut pos = start; self.add_token(&tokens[pos], SyntaxKind::MapKw); pos += 1; // Skip to end of map for now while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBrace) { pos += 1; } self.finish_node(); pos } fn parse_include(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::IncludeStmt); let mut pos = start; self.add_token(&tokens[pos], SyntaxKind::IncludeKw); pos += 1; if pos < tokens.len() { if let TokenKind::StringLiteral(_) = &tokens[pos].kind { self.add_token(&tokens[pos], SyntaxKind::StringLiteral); pos += 1; } } self.finish_node(); pos } fn parse_define(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::DefineStmt); let mut pos = start; self.add_token(&tokens[pos], SyntaxKind::DefineKw); pos += 1; // Parse variable name if pos < tokens.len() { self.add_token(&tokens[pos], SyntaxKind::Identifier); pos += 1; } // Parse assignment if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Assign) { self.add_token(&tokens[pos], SyntaxKind::Assign); pos += 1; } // Parse value - could be complex expression while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline) { self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); pos += 1; } self.finish_node(); pos } fn parse_flush_stmt(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::FlushStmt); let mut pos = start; self.add_token(&tokens[pos], SyntaxKind::FlushKw); pos += 1; // Parse what to flush (table, chain, etc.) while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline | TokenKind::Semicolon) { self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); pos += 1; } if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Semicolon) { self.add_token(&tokens[pos], SyntaxKind::Semicolon); pos += 1; } self.finish_node(); pos } fn parse_add_stmt(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::AddStmt); let mut pos = start; self.add_token(&tokens[pos], SyntaxKind::AddKw); pos += 1; // Parse what to add while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline) { self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); pos += 1; } self.finish_node(); pos } fn parse_delete_stmt(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::DeleteStmt); let mut pos = start; self.add_token(&tokens[pos], SyntaxKind::DeleteKw); pos += 1; // Parse what to delete while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline) { self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); pos += 1; } self.finish_node(); pos } fn parse_element(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::Element); let mut pos = start; self.add_token(&tokens[pos], SyntaxKind::ElementKw); pos += 1; // Parse element specification while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline | TokenKind::LeftBrace) { self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); pos += 1; } // Parse element body if present if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::LeftBrace) { self.add_token(&tokens[pos], SyntaxKind::LeftBrace); pos += 1; while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBrace) { self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); pos += 1; } if pos < tokens.len() { self.add_token(&tokens[pos], SyntaxKind::RightBrace); pos += 1; } } self.finish_node(); pos } fn parse_call_expression(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::CallExpr); let mut pos = start; // Add function name self.add_token(&tokens[pos], SyntaxKind::Identifier); pos += 1; // Add opening paren if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::LeftParen) { self.add_token(&tokens[pos], SyntaxKind::LeftParen); pos += 1; // Parse arguments while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightParen) { if matches!(tokens[pos].kind, TokenKind::Comma) { self.add_token(&tokens[pos], SyntaxKind::Comma); } else { self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); } pos += 1; } // Add closing paren if pos < tokens.len() { self.add_token(&tokens[pos], SyntaxKind::RightParen); pos += 1; } } self.finish_node(); pos } fn parse_statement(&mut self, tokens: &[Token], start: usize) -> usize { self.start_node(SyntaxKind::Statement); self.add_token(&tokens[start], SyntaxKind::from(tokens[start].kind.clone())); self.finish_node(); start + 1 } fn is_binary_context(&self) -> bool { // Simple heuristic: assume binary if we have recent tokens in current scope self.stack .last() .map_or(false, |(_, children)| children.len() > 1) } fn start_node(&mut self, kind: SyntaxKind) { self.stack.push((kind, Vec::new())); } fn add_token(&mut self, token: &Token, kind: SyntaxKind) { // Handle whitespace specially if it's part of the token if let TokenKind::Identifier(text) = &token.kind { if text.chars().all(|c| c.is_whitespace()) { let whitespace_token = GreenNode::new(SyntaxKind::Whitespace.to_raw(), std::iter::empty()); if let Some((_, children)) = self.stack.last_mut() { children.push(NodeOrToken::Node(whitespace_token)); } return; } } let token_node = GreenNode::new(kind.to_raw(), std::iter::empty()); if let Some((_, children)) = self.stack.last_mut() { children.push(NodeOrToken::Node(token_node)); } } fn finish_node(&mut self) { if let Some((kind, children)) = self.stack.pop() { let node = GreenNode::new(kind.to_raw(), children); if let Some((_, parent_children)) = self.stack.last_mut() { parent_children.push(NodeOrToken::Node(node)); } else { // This is the root node self.stack.push((kind, vec![NodeOrToken::Node(node)])); } } } fn finish(self) -> GreenNode { if let Some((kind, children)) = self.stack.into_iter().next() { GreenNode::new(kind.to_raw(), children) } else { GreenNode::new(SyntaxKind::Root.to_raw(), std::iter::empty()) } } } /// Validator for CST structure according to nftables grammar struct CstValidator; impl CstValidator { fn new() -> Self { Self } fn validate(&self, node: &GreenNode) -> CstResult<()> { let kind = SyntaxKind::from_raw(node.kind()); match kind { SyntaxKind::Root => self.validate_root(node), SyntaxKind::Table => self.validate_table(node), SyntaxKind::Chain => self.validate_chain(node), SyntaxKind::Rule => self.validate_rule(node), SyntaxKind::Set | SyntaxKind::Map => self.validate_set_or_map(node), SyntaxKind::Element => self.validate_element(node), SyntaxKind::Expression | SyntaxKind::BinaryExpr | SyntaxKind::UnaryExpr | SyntaxKind::CallExpr => self.validate_expression(node), SyntaxKind::IncludeStmt => self.validate_include(node), SyntaxKind::DefineStmt => self.validate_define(node), SyntaxKind::FlushStmt | SyntaxKind::AddStmt | SyntaxKind::DeleteStmt => { self.validate_statement(node) } SyntaxKind::Whitespace => Ok(()), // Whitespace is always valid _ => Ok(()), // Other nodes are generally valid } } fn validate_root(&self, node: &GreenNode) -> CstResult<()> { let children: Vec<_> = node.children().collect(); if children.is_empty() { return Err(CstError::MissingToken { expected: "at least one declaration or statement".to_string(), }); } // Validate that root contains valid top-level constructs for child in children { if let Some(child_node) = child.as_node() { let child_kind = SyntaxKind::from_raw(child_node.kind()); match child_kind { SyntaxKind::Table | SyntaxKind::IncludeStmt | SyntaxKind::DefineStmt | SyntaxKind::FlushStmt | SyntaxKind::AddStmt | SyntaxKind::DeleteStmt | SyntaxKind::Element | SyntaxKind::Comment | SyntaxKind::Newline | SyntaxKind::Shebang | SyntaxKind::Whitespace => { self.validate(child_node)?; } _ => { return Err(CstError::UnexpectedToken { expected: "table, include, define, flush, add, delete, element, or comment" .to_string(), found: format!("{:?}", child_kind), }); } } } } Ok(()) } fn validate_table(&self, node: &GreenNode) -> CstResult<()> { let children: Vec<_> = node.children().collect(); if children.len() < 3 { return Err(CstError::MissingToken { expected: "table keyword, family, and name".to_string(), }); } // Validate table structure: table { ... } let mut child_iter = children.iter(); // First should be 'table' keyword if let Some(first) = child_iter.next() { if let Some(first_node) = first.as_node() { let first_kind = SyntaxKind::from_raw(first_node.kind()); if first_kind != SyntaxKind::TableKw { return Err(CstError::UnexpectedToken { expected: "table keyword".to_string(), found: format!("{:?}", first_kind), }); } } } // Second should be family if let Some(second) = child_iter.next() { if let Some(second_node) = second.as_node() { let second_kind = SyntaxKind::from_raw(second_node.kind()); match second_kind { SyntaxKind::InetKw | SyntaxKind::IpKw | SyntaxKind::Ip6Kw | SyntaxKind::ArpKw | SyntaxKind::BridgeKw | SyntaxKind::NetdevKw => {} _ => { return Err(CstError::UnexpectedToken { expected: "table family (inet, ip, ip6, arp, bridge, netdev)" .to_string(), found: format!("{:?}", second_kind), }); } } } } // Recursively validate children for child in children { if let Some(child_node) = child.as_node() { self.validate(child_node)?; } } Ok(()) } fn validate_chain(&self, node: &GreenNode) -> CstResult<()> { // Basic chain validation for child in node.children() { if let Some(child_node) = child.as_node() { self.validate(child_node)?; } } Ok(()) } fn validate_rule(&self, node: &GreenNode) -> CstResult<()> { // Rules should have at least some content let children: Vec<_> = node.children().collect(); if children.is_empty() { return Err(CstError::InvalidSyntax { position: 0 }); } // Validate rule children for child in children { if let Some(child_node) = child.as_node() { self.validate(child_node)?; } } Ok(()) } fn validate_set_or_map(&self, _node: &GreenNode) -> CstResult<()> { // Basic validation for sets and maps Ok(()) } fn validate_expression(&self, node: &GreenNode) -> CstResult<()> { // Validate expression structure for child in node.children() { if let Some(child_node) = child.as_node() { self.validate(child_node)?; } } Ok(()) } fn validate_include(&self, node: &GreenNode) -> CstResult<()> { let children: Vec<_> = node.children().collect(); if children.len() < 2 { return Err(CstError::MissingToken { expected: "include keyword and file path".to_string(), }); } Ok(()) } fn validate_define(&self, node: &GreenNode) -> CstResult<()> { let children: Vec<_> = node.children().collect(); if children.len() < 3 { return Err(CstError::MissingToken { expected: "define keyword, variable name, and value".to_string(), }); } Ok(()) } fn validate_element(&self, _node: &GreenNode) -> CstResult<()> { // Basic validation for element statements Ok(()) } fn validate_statement(&self, node: &GreenNode) -> CstResult<()> { // Basic validation for statements like flush, add, delete let children: Vec<_> = node.children().collect(); if children.is_empty() { return Err(CstError::MissingToken { expected: "statement keyword".to_string(), }); } Ok(()) } } impl Default for CstBuilder { fn default() -> Self { Self } } #[cfg(test)] mod tests { use super::*; use crate::lexer::NftablesLexer; #[test] fn test_cst_construction() { let source = "table inet filter { }"; let mut lexer = NftablesLexer::new(source); let tokens = lexer.tokenize().expect("Tokenization should succeed"); // CST is now implemented - test that it works let green_tree = CstBuilder::build_tree(&tokens); // Verify the tree was created successfully assert_eq!(green_tree.kind(), SyntaxKind::Root.to_raw()); // Test validation works let validation_result = CstBuilder::validate_tree(&green_tree); if let Err(ref e) = validation_result { println!("Validation error: {}", e); } assert!(validation_result.is_ok()); // Test parse_to_cst works let cst_result = CstBuilder::parse_to_cst(&tokens); assert!(cst_result.is_ok()); } }