nff/src/cst.rs
2025-05-24 23:41:06 +03:00

1228 lines
40 KiB
Rust

//! Concrete Syntax Tree implementation for nftables configuration files
//!
//! Lossless representation preserving whitespace, comments, and formatting.
use crate::lexer::{Token, TokenKind};
use cstree::{RawSyntaxKind, green::GreenNode, util::NodeOrToken};
use std::fmt;
use thiserror::Error;
/// nftables syntax node types.
///
/// Discriminants double as `cstree::RawSyntaxKind` values via `to_raw` /
/// `from_raw`, so the declaration order is part of the on-tree encoding:
/// `Root` must remain first (0) and `Error` must remain last. Add new
/// variants only at the end, never in the middle.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u16)]
pub enum SyntaxKind {
    // Root and containers
    Root = 0,
    Table,
    Chain,
    Rule,
    Set,
    Map,
    Element,
    // Expressions
    Expression,
    BinaryExpr,
    UnaryExpr,
    CallExpr,
    SetExpr,
    RangeExpr,
    // Statements
    Statement,
    IncludeStmt,
    DefineStmt,
    FlushStmt,
    AddStmt,
    DeleteStmt,
    // Literals and identifiers
    Identifier,
    StringLiteral,
    NumberLiteral,
    IpAddress,
    Ipv6Address,
    MacAddress,
    // Keywords
    TableKw,
    ChainKw,
    RuleKw,
    SetKw,
    MapKw,
    ElementKw,
    IncludeKw,
    DefineKw,
    FlushKw,
    AddKw,
    DeleteKw,
    InsertKw,
    ReplaceKw,
    // Chain types and hooks
    FilterKw,
    NatKw,
    RouteKw,
    InputKw,
    OutputKw,
    ForwardKw,
    PreroutingKw,
    PostroutingKw,
    // Protocols and families
    IpKw,
    Ip6Kw,
    InetKw,
    ArpKw,
    BridgeKw,
    NetdevKw,
    TcpKw,
    UdpKw,
    IcmpKw,
    Icmpv6Kw,
    // Match keywords
    SportKw,
    DportKw,
    SaddrKw,
    DaddrKw,
    ProtocolKw,
    NexthdrKw,
    TypeKw,
    HookKw,
    PriorityKw,
    PolicyKw,
    IifnameKw,
    OifnameKw,
    CtKw,
    StateKw,
    // Actions
    AcceptKw,
    DropKw,
    RejectKw,
    ReturnKw,
    JumpKw,
    GotoKw,
    ContinueKw,
    LogKw,
    CommentKw,
    // Connection-tracking states
    EstablishedKw,
    RelatedKw,
    NewKw,
    InvalidKw,
    // Operators
    EqOp,
    NeOp,
    LeOp,
    GeOp,
    LtOp,
    GtOp,
    // Punctuation
    LeftBrace,
    RightBrace,
    LeftParen,
    RightParen,
    LeftBracket,
    RightBracket,
    Comma,
    Semicolon,
    Colon,
    Assign,
    Dash,
    Slash,
    Dot,
    // Trivia — kept in the tree so formatting can round-trip
    Whitespace,
    Newline,
    Comment,
    Shebang,
    // Error recovery — must stay the LAST variant (range check in `from_raw`)
    Error,
}
impl From<TokenKind> for SyntaxKind {
    /// Map a lexer token kind onto its corresponding CST kind.
    ///
    /// The mapping is total: every `TokenKind` has exactly one
    /// `SyntaxKind`. Payload-carrying variants (literals, identifiers,
    /// comments, shebangs) collapse onto their payload-free CST
    /// counterpart — the payload itself is not carried by the kind.
    fn from(kind: TokenKind) -> Self {
        match kind {
            TokenKind::Table => SyntaxKind::TableKw,
            TokenKind::Chain => SyntaxKind::ChainKw,
            TokenKind::Rule => SyntaxKind::RuleKw,
            TokenKind::Set => SyntaxKind::SetKw,
            TokenKind::Map => SyntaxKind::MapKw,
            TokenKind::Element => SyntaxKind::ElementKw,
            TokenKind::Include => SyntaxKind::IncludeKw,
            TokenKind::Define => SyntaxKind::DefineKw,
            TokenKind::Flush => SyntaxKind::FlushKw,
            TokenKind::Add => SyntaxKind::AddKw,
            TokenKind::Delete => SyntaxKind::DeleteKw,
            TokenKind::Insert => SyntaxKind::InsertKw,
            TokenKind::Replace => SyntaxKind::ReplaceKw,
            TokenKind::Filter => SyntaxKind::FilterKw,
            TokenKind::Nat => SyntaxKind::NatKw,
            TokenKind::Route => SyntaxKind::RouteKw,
            TokenKind::Input => SyntaxKind::InputKw,
            TokenKind::Output => SyntaxKind::OutputKw,
            TokenKind::Forward => SyntaxKind::ForwardKw,
            TokenKind::Prerouting => SyntaxKind::PreroutingKw,
            TokenKind::Postrouting => SyntaxKind::PostroutingKw,
            TokenKind::Ip => SyntaxKind::IpKw,
            TokenKind::Ip6 => SyntaxKind::Ip6Kw,
            TokenKind::Inet => SyntaxKind::InetKw,
            TokenKind::Arp => SyntaxKind::ArpKw,
            TokenKind::Bridge => SyntaxKind::BridgeKw,
            TokenKind::Netdev => SyntaxKind::NetdevKw,
            TokenKind::Tcp => SyntaxKind::TcpKw,
            TokenKind::Udp => SyntaxKind::UdpKw,
            TokenKind::Icmp => SyntaxKind::IcmpKw,
            TokenKind::Icmpv6 => SyntaxKind::Icmpv6Kw,
            TokenKind::Sport => SyntaxKind::SportKw,
            TokenKind::Dport => SyntaxKind::DportKw,
            TokenKind::Saddr => SyntaxKind::SaddrKw,
            TokenKind::Daddr => SyntaxKind::DaddrKw,
            TokenKind::Protocol => SyntaxKind::ProtocolKw,
            TokenKind::Nexthdr => SyntaxKind::NexthdrKw,
            TokenKind::Type => SyntaxKind::TypeKw,
            TokenKind::Hook => SyntaxKind::HookKw,
            TokenKind::Priority => SyntaxKind::PriorityKw,
            TokenKind::Policy => SyntaxKind::PolicyKw,
            TokenKind::Iifname => SyntaxKind::IifnameKw,
            TokenKind::Oifname => SyntaxKind::OifnameKw,
            TokenKind::Ct => SyntaxKind::CtKw,
            TokenKind::State => SyntaxKind::StateKw,
            TokenKind::Accept => SyntaxKind::AcceptKw,
            TokenKind::Drop => SyntaxKind::DropKw,
            TokenKind::Reject => SyntaxKind::RejectKw,
            TokenKind::Return => SyntaxKind::ReturnKw,
            TokenKind::Jump => SyntaxKind::JumpKw,
            TokenKind::Goto => SyntaxKind::GotoKw,
            TokenKind::Continue => SyntaxKind::ContinueKw,
            TokenKind::Log => SyntaxKind::LogKw,
            TokenKind::Comment => SyntaxKind::CommentKw,
            TokenKind::Established => SyntaxKind::EstablishedKw,
            TokenKind::Related => SyntaxKind::RelatedKw,
            TokenKind::New => SyntaxKind::NewKw,
            TokenKind::Invalid => SyntaxKind::InvalidKw,
            TokenKind::Eq => SyntaxKind::EqOp,
            TokenKind::Ne => SyntaxKind::NeOp,
            TokenKind::Le => SyntaxKind::LeOp,
            TokenKind::Ge => SyntaxKind::GeOp,
            TokenKind::Lt => SyntaxKind::LtOp,
            TokenKind::Gt => SyntaxKind::GtOp,
            TokenKind::LeftBrace => SyntaxKind::LeftBrace,
            TokenKind::RightBrace => SyntaxKind::RightBrace,
            TokenKind::LeftParen => SyntaxKind::LeftParen,
            TokenKind::RightParen => SyntaxKind::RightParen,
            TokenKind::LeftBracket => SyntaxKind::LeftBracket,
            TokenKind::RightBracket => SyntaxKind::RightBracket,
            TokenKind::Comma => SyntaxKind::Comma,
            TokenKind::Semicolon => SyntaxKind::Semicolon,
            TokenKind::Colon => SyntaxKind::Colon,
            TokenKind::Assign => SyntaxKind::Assign,
            TokenKind::Dash => SyntaxKind::Dash,
            TokenKind::Slash => SyntaxKind::Slash,
            TokenKind::Dot => SyntaxKind::Dot,
            // Payload-carrying token kinds: the payload is dropped here.
            TokenKind::StringLiteral(_) => SyntaxKind::StringLiteral,
            TokenKind::NumberLiteral(_) => SyntaxKind::NumberLiteral,
            TokenKind::IpAddress(_) => SyntaxKind::IpAddress,
            TokenKind::Ipv6Address(_) => SyntaxKind::Ipv6Address,
            TokenKind::MacAddress(_) => SyntaxKind::MacAddress,
            TokenKind::Identifier(_) => SyntaxKind::Identifier,
            TokenKind::Newline => SyntaxKind::Newline,
            TokenKind::CommentLine(_) => SyntaxKind::Comment,
            TokenKind::Shebang(_) => SyntaxKind::Shebang,
            TokenKind::Error => SyntaxKind::Error,
        }
    }
}
impl fmt::Display for SyntaxKind {
    /// Render the kind as an UPPER_SNAKE name for tree dumps.
    ///
    /// Only container, literal, and trivia kinds get a fixed name;
    /// keyword, operator, and punctuation kinds fall through to their
    /// `Debug` representation (e.g. `TableKw`, `EqOp`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            SyntaxKind::Root => "ROOT",
            SyntaxKind::Table => "TABLE",
            SyntaxKind::Chain => "CHAIN",
            SyntaxKind::Rule => "RULE",
            SyntaxKind::Set => "SET",
            SyntaxKind::Map => "MAP",
            SyntaxKind::Element => "ELEMENT",
            SyntaxKind::Expression => "EXPRESSION",
            SyntaxKind::BinaryExpr => "BINARY_EXPR",
            SyntaxKind::UnaryExpr => "UNARY_EXPR",
            SyntaxKind::CallExpr => "CALL_EXPR",
            SyntaxKind::SetExpr => "SET_EXPR",
            SyntaxKind::RangeExpr => "RANGE_EXPR",
            SyntaxKind::Statement => "STATEMENT",
            SyntaxKind::IncludeStmt => "INCLUDE_STMT",
            SyntaxKind::DefineStmt => "DEFINE_STMT",
            SyntaxKind::FlushStmt => "FLUSH_STMT",
            SyntaxKind::AddStmt => "ADD_STMT",
            SyntaxKind::DeleteStmt => "DELETE_STMT",
            SyntaxKind::Identifier => "IDENTIFIER",
            SyntaxKind::StringLiteral => "STRING_LITERAL",
            SyntaxKind::NumberLiteral => "NUMBER_LITERAL",
            SyntaxKind::IpAddress => "IP_ADDRESS",
            SyntaxKind::Ipv6Address => "IPV6_ADDRESS",
            SyntaxKind::MacAddress => "MAC_ADDRESS",
            SyntaxKind::Whitespace => "WHITESPACE",
            SyntaxKind::Newline => "NEWLINE",
            SyntaxKind::Comment => "COMMENT",
            SyntaxKind::Shebang => "SHEBANG",
            SyntaxKind::Error => "ERROR",
            // Keywords/operators/punctuation: Debug name is good enough.
            _ => return write!(f, "{:?}", self),
        };
        write!(f, "{}", name)
    }
}
/// Language definition for nftables CST.
///
/// Zero-sized marker type for parameterizing `cstree` over this crate's
/// `SyntaxKind`; it carries no data of its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NftablesLanguage;
impl SyntaxKind {
pub fn to_raw(self) -> RawSyntaxKind {
RawSyntaxKind(self as u32)
}
pub fn from_raw(raw: RawSyntaxKind) -> Self {
unsafe { std::mem::transmute(raw.0 as u16) }
}
}
/// CST construction errors.
#[derive(Error, Debug, PartialEq)]
pub enum CstError {
    /// A token was present but of the wrong kind for its position.
    #[error("Unexpected token: expected {expected}, found {found}")]
    UnexpectedToken { expected: String, found: String },
    /// A required token was absent entirely.
    #[error("Missing token: expected {expected}")]
    MissingToken { expected: String },
    /// Structurally invalid syntax at the given position.
    #[error("Invalid syntax at position {position}")]
    InvalidSyntax { position: usize },
}
/// Result type for CST operations.
pub type CstResult<T> = Result<T, CstError>;
/// Basic CST builder.
///
/// A stateless facade over the internal tree builder and validator; all
/// entry points are associated functions.
pub struct CstBuilder;
impl CstBuilder {
    /// Build a proper CST from tokens that represents nftables syntax structure.
    pub fn build_tree(tokens: &[Token]) -> GreenNode {
        let mut tree_builder = CstTreeBuilder::new();
        tree_builder.parse_tokens(tokens);
        tree_builder.finish()
    }

    /// Validate CST tree structure according to nftables grammar rules.
    pub fn validate_tree(node: &GreenNode) -> CstResult<()> {
        CstValidator::new().validate(node)
    }

    /// Parse tokens into a validated CST: build first, then validate.
    pub fn parse_to_cst(tokens: &[Token]) -> CstResult<GreenNode> {
        let tree = Self::build_tree(tokens);
        Self::validate_tree(&tree)?;
        Ok(tree)
    }
}
/// Internal tree builder that constructs CST according to nftables grammar.
///
/// `stack` holds one frame per currently-open node: the node's kind plus
/// the children accumulated so far. Frame 0 is the implicit `Root` frame
/// created by `new`.
struct CstTreeBuilder {
    stack: Vec<(
        SyntaxKind,
        Vec<NodeOrToken<GreenNode, cstree::green::GreenToken>>,
    )>,
}
impl CstTreeBuilder {
/// Create a builder with the implicit `Root` frame already open.
fn new() -> Self {
    Self {
        stack: vec![(SyntaxKind::Root, Vec::new())],
    }
}
/// Drive top-level parsing over the whole token stream.
///
/// Each call to `parse_top_level` is expected to advance the cursor; if
/// it ever fails to (e.g. on an unanticipated token shape), we skip one
/// token rather than loop forever.
fn parse_tokens(&mut self, tokens: &[Token]) {
    let mut i = 0;
    while i < tokens.len() {
        // `tokens` is already a slice; the old `&tokens` was a needless
        // extra borrow (clippy::needless_borrow).
        let next = self.parse_top_level(tokens, i);
        // Defensive forward-progress guard against infinite loops.
        i = if next > i { next } else { i + 1 };
    }
}
/// Dispatch one top-level construct starting at `start`.
///
/// Returns the index of the first unconsumed token. Trivia (comments,
/// newlines, shebangs) is kept in the tree; anything unrecognized is
/// wrapped as a single-token `Statement`.
fn parse_top_level(&mut self, tokens: &[Token], start: usize) -> usize {
    if start >= tokens.len() {
        return start;
    }
    match &tokens[start].kind {
        TokenKind::Table => self.parse_table(tokens, start),
        TokenKind::Chain => self.parse_chain(tokens, start),
        TokenKind::Include => self.parse_include(tokens, start),
        TokenKind::Define => self.parse_define(tokens, start),
        TokenKind::Flush => self.parse_flush_stmt(tokens, start),
        TokenKind::Add => self.parse_add_stmt(tokens, start),
        TokenKind::Delete => self.parse_delete_stmt(tokens, start),
        TokenKind::Element => self.parse_element(tokens, start),
        TokenKind::CommentLine(_) => {
            self.add_token(&tokens[start], SyntaxKind::Comment);
            start + 1
        }
        TokenKind::Newline => {
            self.add_token(&tokens[start], SyntaxKind::Newline);
            start + 1
        }
        TokenKind::Shebang(_) => {
            self.add_token(&tokens[start], SyntaxKind::Shebang);
            start + 1
        }
        _ => {
            // Handle other tokens as statements
            self.parse_statement(tokens, start)
        }
    }
}

/// Parse `table <family> <name> { ... }` into a `Table` node.
///
/// Missing pieces are tolerated: each section is only consumed if the
/// cursor is still in bounds, so a truncated input yields a partial node
/// rather than a panic.
fn parse_table(&mut self, tokens: &[Token], start: usize) -> usize {
    self.start_node(SyntaxKind::Table);
    // Add 'table' keyword
    self.add_token(&tokens[start], SyntaxKind::TableKw);
    let mut pos = start + 1;
    // Parse family (inet, ip, ip6, etc.); unknown families degrade to a
    // plain identifier instead of failing.
    if pos < tokens.len() {
        match &tokens[pos].kind {
            TokenKind::Inet => self.add_token(&tokens[pos], SyntaxKind::InetKw),
            TokenKind::Ip => self.add_token(&tokens[pos], SyntaxKind::IpKw),
            TokenKind::Ip6 => self.add_token(&tokens[pos], SyntaxKind::Ip6Kw),
            TokenKind::Arp => self.add_token(&tokens[pos], SyntaxKind::ArpKw),
            TokenKind::Bridge => self.add_token(&tokens[pos], SyntaxKind::BridgeKw),
            TokenKind::Netdev => self.add_token(&tokens[pos], SyntaxKind::NetdevKw),
            _ => self.add_token(&tokens[pos], SyntaxKind::Identifier),
        }
        pos += 1;
    }
    // Parse table name
    if pos < tokens.len() {
        self.add_token(&tokens[pos], SyntaxKind::Identifier);
        pos += 1;
    }
    // Parse table body
    if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::LeftBrace) {
        self.add_token(&tokens[pos], SyntaxKind::LeftBrace);
        pos += 1;
        // Parse table contents
        while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBrace) {
            match &tokens[pos].kind {
                TokenKind::Chain => {
                    pos = self.parse_chain(tokens, pos);
                }
                TokenKind::Set => {
                    pos = self.parse_set(tokens, pos);
                }
                TokenKind::Map => {
                    pos = self.parse_map(tokens, pos);
                }
                TokenKind::Newline | TokenKind::CommentLine(_) => {
                    self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
                    pos += 1;
                }
                // NOTE(review): skipped tokens are dropped from the tree
                // entirely, which breaks the lossless-CST goal — confirm
                // whether they should be added as Error nodes instead.
                _ => pos += 1, // Skip unknown tokens
            }
        }
        // Add closing brace
        if pos < tokens.len() {
            self.add_token(&tokens[pos], SyntaxKind::RightBrace);
            pos += 1;
        }
    }
    self.finish_node();
    pos
}

/// Parse `chain <name> { ... }` into a `Chain` node.
///
/// Inside the body, a `type` token starts a chain-properties run; trivia
/// is preserved; everything else is parsed as a rule line.
fn parse_chain(&mut self, tokens: &[Token], start: usize) -> usize {
    self.start_node(SyntaxKind::Chain);
    self.add_token(&tokens[start], SyntaxKind::ChainKw);
    let mut pos = start + 1;
    // Parse chain name
    if pos < tokens.len() {
        self.add_token(&tokens[pos], SyntaxKind::Identifier);
        pos += 1;
    }
    // Parse chain body
    if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::LeftBrace) {
        self.add_token(&tokens[pos], SyntaxKind::LeftBrace);
        pos += 1;
        while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBrace) {
            match &tokens[pos].kind {
                TokenKind::Type => {
                    pos = self.parse_chain_properties(tokens, pos);
                }
                TokenKind::Newline | TokenKind::CommentLine(_) => {
                    self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
                    pos += 1;
                }
                _ => {
                    // Parse as rule
                    pos = self.parse_rule(tokens, pos);
                }
            }
        }
        if pos < tokens.len() {
            self.add_token(&tokens[pos], SyntaxKind::RightBrace);
            pos += 1;
        }
    }
    self.finish_node();
    pos
}

/// Parse a `type <kind> hook <hook> priority <n>;` properties run.
///
/// Each clause is optional and consumed in this fixed order; unknown
/// type/hook names degrade to identifiers. No node is opened — the
/// tokens land directly in the enclosing `Chain` frame.
fn parse_chain_properties(&mut self, tokens: &[Token], start: usize) -> usize {
    let mut pos = start;
    // Parse 'type'
    if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Type) {
        self.add_token(&tokens[pos], SyntaxKind::TypeKw);
        pos += 1;
    }
    // Parse chain type (filter, nat, route)
    if pos < tokens.len() {
        match &tokens[pos].kind {
            TokenKind::Filter => self.add_token(&tokens[pos], SyntaxKind::FilterKw),
            TokenKind::Nat => self.add_token(&tokens[pos], SyntaxKind::NatKw),
            TokenKind::Route => self.add_token(&tokens[pos], SyntaxKind::RouteKw),
            _ => self.add_token(&tokens[pos], SyntaxKind::Identifier),
        }
        pos += 1;
    }
    // Parse 'hook'
    if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Hook) {
        self.add_token(&tokens[pos], SyntaxKind::HookKw);
        pos += 1;
        // Parse hook type
        if pos < tokens.len() {
            match &tokens[pos].kind {
                TokenKind::Input => self.add_token(&tokens[pos], SyntaxKind::InputKw),
                TokenKind::Output => self.add_token(&tokens[pos], SyntaxKind::OutputKw),
                TokenKind::Forward => self.add_token(&tokens[pos], SyntaxKind::ForwardKw),
                TokenKind::Prerouting => self.add_token(&tokens[pos], SyntaxKind::PreroutingKw),
                TokenKind::Postrouting => {
                    self.add_token(&tokens[pos], SyntaxKind::PostroutingKw)
                }
                _ => self.add_token(&tokens[pos], SyntaxKind::Identifier),
            }
            pos += 1;
        }
    }
    // Parse 'priority'
    // NOTE(review): only a bare number literal is accepted here; negative
    // priorities ("priority -100") would lex as Dash + number and the Dash
    // is not consumed — TODO confirm against the lexer.
    if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Priority) {
        self.add_token(&tokens[pos], SyntaxKind::PriorityKw);
        pos += 1;
        if pos < tokens.len() {
            if let TokenKind::NumberLiteral(_) = &tokens[pos].kind {
                self.add_token(&tokens[pos], SyntaxKind::NumberLiteral);
                pos += 1;
            }
        }
    }
    // Parse semicolon
    if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Semicolon) {
        self.add_token(&tokens[pos], SyntaxKind::Semicolon);
        pos += 1;
    }
    pos
}
/// Parse one rule line into a `Rule` node.
///
/// Consumes tokens up to and including the terminating newline (or end
/// of input). Parenthesized spans are delegated to `parse_expression`.
fn parse_rule(&mut self, tokens: &[Token], start: usize) -> usize {
    self.start_node(SyntaxKind::Rule);
    let mut pos = start;
    let rule_start = pos;
    // Parse until we hit a newline or end of input (simple rule parsing)
    while pos < tokens.len() {
        match &tokens[pos].kind {
            TokenKind::Newline => {
                self.add_token(&tokens[pos], SyntaxKind::Newline);
                pos += 1;
                break;
            }
            TokenKind::Accept => self.add_token(&tokens[pos], SyntaxKind::AcceptKw),
            TokenKind::Drop => self.add_token(&tokens[pos], SyntaxKind::DropKw),
            TokenKind::Reject => self.add_token(&tokens[pos], SyntaxKind::RejectKw),
            TokenKind::Return => self.add_token(&tokens[pos], SyntaxKind::ReturnKw),
            TokenKind::Log => self.add_token(&tokens[pos], SyntaxKind::LogKw),
            TokenKind::Comment => self.add_token(&tokens[pos], SyntaxKind::CommentKw),
            TokenKind::Eq => self.add_token(&tokens[pos], SyntaxKind::EqOp),
            TokenKind::Ne => self.add_token(&tokens[pos], SyntaxKind::NeOp),
            TokenKind::Lt => self.add_token(&tokens[pos], SyntaxKind::LtOp),
            TokenKind::Le => self.add_token(&tokens[pos], SyntaxKind::LeOp),
            TokenKind::Gt => self.add_token(&tokens[pos], SyntaxKind::GtOp),
            TokenKind::Ge => self.add_token(&tokens[pos], SyntaxKind::GeOp),
            TokenKind::LeftParen => {
                // Sub-parser advances the cursor itself; skip the shared
                // `pos += 1` at the bottom of the loop.
                pos = self.parse_expression(tokens, pos);
                continue;
            }
            TokenKind::Identifier(_) => self.add_token(&tokens[pos], SyntaxKind::Identifier),
            TokenKind::StringLiteral(_) => {
                self.add_token(&tokens[pos], SyntaxKind::StringLiteral)
            }
            TokenKind::NumberLiteral(_) => {
                self.add_token(&tokens[pos], SyntaxKind::NumberLiteral)
            }
            TokenKind::IpAddress(_) => self.add_token(&tokens[pos], SyntaxKind::IpAddress),
            _ => self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())),
        }
        pos += 1;
    }
    // If no tokens were added to the rule, add at least one statement node
    // NOTE(review): this guard looks unreachable — callers only invoke
    // parse_rule with `start < tokens.len()`, so the loop always advances
    // `pos` past `rule_start`. Kept for safety; confirm before removing.
    if pos == rule_start {
        self.start_node(SyntaxKind::Statement);
        self.finish_node();
    }
    self.finish_node();
    pos
}

/// Parse a parenthesized expression into an `Expression` node.
///
/// Tracks paren depth so nested parens stay inside one node; `[` starts
/// a nested set expression and identifiers followed by `(` become call
/// expressions.
fn parse_expression(&mut self, tokens: &[Token], start: usize) -> usize {
    self.start_node(SyntaxKind::Expression);
    let mut pos = start;
    let mut paren_depth = 0;
    while pos < tokens.len() {
        match &tokens[pos].kind {
            TokenKind::LeftParen => {
                paren_depth += 1;
                self.add_token(&tokens[pos], SyntaxKind::LeftParen);
            }
            TokenKind::RightParen => {
                self.add_token(&tokens[pos], SyntaxKind::RightParen);
                paren_depth -= 1;
                if paren_depth == 0 {
                    pos += 1;
                    break;
                }
            }
            TokenKind::LeftBracket => {
                pos = self.parse_set_expression(tokens, pos);
                continue;
            }
            TokenKind::Identifier(_) => {
                // Check if this is a function call
                if pos + 1 < tokens.len()
                    && matches!(tokens[pos + 1].kind, TokenKind::LeftParen)
                {
                    pos = self.parse_call_expression(tokens, pos);
                    continue;
                } else {
                    self.add_token(&tokens[pos], SyntaxKind::Identifier);
                }
            }
            TokenKind::Dash => {
                // Could be binary or unary expression
                if self.is_binary_context() {
                    self.parse_binary_expression(tokens, pos);
                } else {
                    self.parse_unary_expression(tokens, pos);
                }
            }
            _ => self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())),
        }
        pos += 1;
    }
    self.finish_node();
    pos
}

/// Parse `[ ... ]` into a `SetExpr` node.
///
/// A `:` inside the brackets is wrapped in a (single-token) `RangeExpr`
/// node; all other tokens are added with their default kind mapping.
fn parse_set_expression(&mut self, tokens: &[Token], start: usize) -> usize {
    self.start_node(SyntaxKind::SetExpr);
    let mut pos = start;
    self.add_token(&tokens[pos], SyntaxKind::LeftBracket);
    pos += 1;
    while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBracket) {
        if matches!(tokens[pos].kind, TokenKind::Colon) {
            // Range expression
            self.start_node(SyntaxKind::RangeExpr);
            self.add_token(&tokens[pos], SyntaxKind::Colon);
            self.finish_node();
        } else {
            self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
        }
        pos += 1;
    }
    if pos < tokens.len() {
        self.add_token(&tokens[pos], SyntaxKind::RightBracket);
        pos += 1;
    }
    self.finish_node();
    pos
}

/// Wrap a `-` token in a `BinaryExpr` node.
///
/// NOTE(review): only the dash itself goes inside the node — the
/// operands stay siblings in the parent expression. Placeholder shape.
fn parse_binary_expression(&mut self, tokens: &[Token], pos: usize) {
    self.start_node(SyntaxKind::BinaryExpr);
    self.add_token(&tokens[pos], SyntaxKind::Dash);
    self.finish_node();
}

/// Wrap a `-` token in a `UnaryExpr` node (same placeholder shape as
/// `parse_binary_expression`).
fn parse_unary_expression(&mut self, tokens: &[Token], pos: usize) {
    self.start_node(SyntaxKind::UnaryExpr);
    self.add_token(&tokens[pos], SyntaxKind::Dash);
    self.finish_node();
}
/// Parse a `set` declaration inside a table into a `Set` node.
///
/// The previous implementation skipped to the *first* `}` without
/// consuming it and discarded every skipped token; `parse_table` then
/// treated the set's closing brace as the table's, truncating the table
/// body. We now keep all tokens in the tree and consume a brace-balanced
/// body.
fn parse_set(&mut self, tokens: &[Token], start: usize) -> usize {
    self.start_node(SyntaxKind::Set);
    let mut pos = start;
    self.add_token(&tokens[pos], SyntaxKind::SetKw);
    pos += 1;
    pos = self.consume_braced_body(tokens, pos);
    self.finish_node();
    pos
}

/// Parse a `map` declaration inside a table into a `Map` node (same
/// shape and same fix as `parse_set`).
fn parse_map(&mut self, tokens: &[Token], start: usize) -> usize {
    self.start_node(SyntaxKind::Map);
    let mut pos = start;
    self.add_token(&tokens[pos], SyntaxKind::MapKw);
    pos += 1;
    pos = self.consume_braced_body(tokens, pos);
    self.finish_node();
    pos
}

/// Add tokens up to and including the `}` matching the next `{`.
///
/// Tokens before the opening brace (e.g. the set/map name) are added
/// as-is; nested braces are depth-balanced so an inner block cannot end
/// the outer one. Returns the position just past the closing brace, or
/// the end of input if no body follows / braces are unbalanced.
fn consume_braced_body(&mut self, tokens: &[Token], mut pos: usize) -> usize {
    // Leading tokens (name etc.) before the body; stop at a newline so a
    // body-less declaration does not swallow the following lines.
    while pos < tokens.len()
        && !matches!(tokens[pos].kind, TokenKind::LeftBrace | TokenKind::Newline)
    {
        self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
        pos += 1;
    }
    if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::LeftBrace) {
        let mut depth: usize = 0;
        while pos < tokens.len() {
            match &tokens[pos].kind {
                TokenKind::LeftBrace => depth += 1,
                TokenKind::RightBrace => depth -= 1,
                _ => {}
            }
            self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
            pos += 1;
            if depth == 0 {
                // Matching close brace consumed.
                break;
            }
        }
    }
    pos
}
/// Parse an `include "file"` statement into an `IncludeStmt` node.
fn parse_include(&mut self, tokens: &[Token], start: usize) -> usize {
    self.start_node(SyntaxKind::IncludeStmt);
    self.add_token(&tokens[start], SyntaxKind::IncludeKw);
    let mut pos = start + 1;
    // The include path, when present, is a quoted string literal.
    if let Some(token) = tokens.get(pos) {
        if matches!(token.kind, TokenKind::StringLiteral(_)) {
            self.add_token(token, SyntaxKind::StringLiteral);
            pos += 1;
        }
    }
    self.finish_node();
    pos
}
/// Parse a `define NAME = value...` statement into a `DefineStmt` node;
/// the value runs to the end of the line.
fn parse_define(&mut self, tokens: &[Token], start: usize) -> usize {
    self.start_node(SyntaxKind::DefineStmt);
    self.add_token(&tokens[start], SyntaxKind::DefineKw);
    let mut idx = start + 1;
    // Variable name.
    if idx < tokens.len() {
        self.add_token(&tokens[idx], SyntaxKind::Identifier);
        idx += 1;
    }
    // Optional '=' between name and value.
    if let Some(token) = tokens.get(idx) {
        if matches!(token.kind, TokenKind::Assign) {
            self.add_token(token, SyntaxKind::Assign);
            idx += 1;
        }
    }
    // Everything up to the newline is the (possibly complex) value.
    while idx < tokens.len() && !matches!(tokens[idx].kind, TokenKind::Newline) {
        self.add_token(&tokens[idx], SyntaxKind::from(tokens[idx].kind.clone()));
        idx += 1;
    }
    self.finish_node();
    idx
}
fn parse_flush_stmt(&mut self, tokens: &[Token], start: usize) -> usize {
self.start_node(SyntaxKind::FlushStmt);
let mut pos = start;
self.add_token(&tokens[pos], SyntaxKind::FlushKw);
pos += 1;
// Parse what to flush (table, chain, etc.)
while pos < tokens.len()
&& !matches!(tokens[pos].kind, TokenKind::Newline | TokenKind::Semicolon)
{
self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
pos += 1;
}
if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Semicolon) {
self.add_token(&tokens[pos], SyntaxKind::Semicolon);
pos += 1;
}
self.finish_node();
pos
}
fn parse_add_stmt(&mut self, tokens: &[Token], start: usize) -> usize {
self.start_node(SyntaxKind::AddStmt);
let mut pos = start;
self.add_token(&tokens[pos], SyntaxKind::AddKw);
pos += 1;
// Parse what to add
while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline) {
self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
pos += 1;
}
self.finish_node();
pos
}
fn parse_delete_stmt(&mut self, tokens: &[Token], start: usize) -> usize {
self.start_node(SyntaxKind::DeleteStmt);
let mut pos = start;
self.add_token(&tokens[pos], SyntaxKind::DeleteKw);
pos += 1;
// Parse what to delete
while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline) {
self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
pos += 1;
}
self.finish_node();
pos
}
/// Parse `element ... { ... }` into an `Element` node.
///
/// The specification runs until a newline or `{`; the optional braced
/// body is consumed flat (no nested-brace balancing).
fn parse_element(&mut self, tokens: &[Token], start: usize) -> usize {
    self.start_node(SyntaxKind::Element);
    let mut pos = start;
    self.add_token(&tokens[pos], SyntaxKind::ElementKw);
    pos += 1;
    // Parse element specification
    while pos < tokens.len()
        && !matches!(tokens[pos].kind, TokenKind::Newline | TokenKind::LeftBrace)
    {
        self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
        pos += 1;
    }
    // Parse element body if present
    if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::LeftBrace) {
        self.add_token(&tokens[pos], SyntaxKind::LeftBrace);
        pos += 1;
        // NOTE(review): stops at the first '}', so a nested brace inside
        // the body would terminate early — confirm elements can't nest.
        while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBrace) {
            self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
            pos += 1;
        }
        if pos < tokens.len() {
            self.add_token(&tokens[pos], SyntaxKind::RightBrace);
            pos += 1;
        }
    }
    self.finish_node();
    pos
}
/// Parse `name(arg, ...)` into a `CallExpr` node.
///
/// The caller guarantees the token at `start` is an identifier followed
/// by `(`. Returns the position just past the closing paren (or end of
/// input if unterminated).
fn parse_call_expression(&mut self, tokens: &[Token], start: usize) -> usize {
    self.start_node(SyntaxKind::CallExpr);
    let mut pos = start;
    // Add function name
    self.add_token(&tokens[pos], SyntaxKind::Identifier);
    pos += 1;
    // Add opening paren
    if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::LeftParen) {
        self.add_token(&tokens[pos], SyntaxKind::LeftParen);
        pos += 1;
        // Parse arguments. The old code special-cased `Comma`, but
        // `SyntaxKind::from(TokenKind::Comma)` already yields
        // `SyntaxKind::Comma`, so the default mapping covers it.
        while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightParen) {
            self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone()));
            pos += 1;
        }
        // Add closing paren
        if pos < tokens.len() {
            self.add_token(&tokens[pos], SyntaxKind::RightParen);
            pos += 1;
        }
    }
    self.finish_node();
    pos
}
/// Fallback for an unrecognized top-level token: wrap the single token
/// in a `Statement` node and advance by one.
fn parse_statement(&mut self, tokens: &[Token], start: usize) -> usize {
    self.start_node(SyntaxKind::Statement);
    let token = &tokens[start];
    self.add_token(token, SyntaxKind::from(token.kind.clone()));
    self.finish_node();
    start + 1
}
/// Heuristic for `-`: treat it as binary when the current frame already
/// holds more than one child (i.e. there is a plausible left operand).
fn is_binary_context(&self) -> bool {
    // `is_some_and` replaces the clippy-flagged `map_or(false, …)` form.
    self.stack
        .last()
        .is_some_and(|(_, children)| children.len() > 1)
}
/// Open a new node frame; children added after this land in the new
/// frame until the matching `finish_node`.
fn start_node(&mut self, kind: SyntaxKind) {
    self.stack.push((kind, Vec::new()));
}

/// Append one token to the current frame.
///
/// NOTE(review): tokens are stored as *empty nodes* of the given kind
/// (`GreenNode::new(.., empty)`), so the token's source text is not
/// retained — the tree is not actually lossless yet. Building real
/// `GreenToken`s with text would need cstree's builder/interner; TODO
/// confirm and fix.
fn add_token(&mut self, token: &Token, kind: SyntaxKind) {
    // Handle whitespace specially if it's part of the token: an
    // identifier consisting solely of whitespace becomes a `Whitespace`
    // node regardless of the requested kind.
    if let TokenKind::Identifier(text) = &token.kind {
        if text.chars().all(|c| c.is_whitespace()) {
            let whitespace_token =
                GreenNode::new(SyntaxKind::Whitespace.to_raw(), std::iter::empty());
            if let Some((_, children)) = self.stack.last_mut() {
                children.push(NodeOrToken::Node(whitespace_token));
            }
            return;
        }
    }
    let token_node = GreenNode::new(kind.to_raw(), std::iter::empty());
    if let Some((_, children)) = self.stack.last_mut() {
        children.push(NodeOrToken::Node(token_node));
    }
}

/// Close the current frame and attach the resulting node to its parent.
///
/// If the popped frame was the last one (the root itself), it is pushed
/// back wrapped in a fresh frame of the same kind so `finish` still has
/// something to emit.
fn finish_node(&mut self) {
    if let Some((kind, children)) = self.stack.pop() {
        let node = GreenNode::new(kind.to_raw(), children);
        if let Some((_, parent_children)) = self.stack.last_mut() {
            parent_children.push(NodeOrToken::Node(node));
        } else {
            // This is the root node
            self.stack.push((kind, vec![NodeOrToken::Node(node)]));
        }
    }
}
/// Consume the builder and produce the root green node.
///
/// The first (bottom) stack frame is the root; if the stack is somehow
/// empty, an empty `Root` node is returned instead.
fn finish(self) -> GreenNode {
    match self.stack.into_iter().next() {
        Some((kind, children)) => GreenNode::new(kind.to_raw(), children),
        None => GreenNode::new(SyntaxKind::Root.to_raw(), std::iter::empty()),
    }
}
}
/// Validator for CST structure according to nftables grammar.
///
/// Stateless; `validate` walks the green tree recursively.
struct CstValidator;
impl CstValidator {
/// Create a validator (stateless, so this is free).
fn new() -> Self {
    Self
}

/// Validate one node by dispatching on its kind; the kind-specific
/// helpers recurse back into this method for child nodes.
fn validate(&self, node: &GreenNode) -> CstResult<()> {
    let kind = SyntaxKind::from_raw(node.kind());
    match kind {
        SyntaxKind::Root => self.validate_root(node),
        SyntaxKind::Table => self.validate_table(node),
        SyntaxKind::Chain => self.validate_chain(node),
        SyntaxKind::Rule => self.validate_rule(node),
        SyntaxKind::Set | SyntaxKind::Map => self.validate_set_or_map(node),
        SyntaxKind::Element => self.validate_element(node),
        SyntaxKind::Expression
        | SyntaxKind::BinaryExpr
        | SyntaxKind::UnaryExpr
        | SyntaxKind::CallExpr => self.validate_expression(node),
        SyntaxKind::IncludeStmt => self.validate_include(node),
        SyntaxKind::DefineStmt => self.validate_define(node),
        SyntaxKind::FlushStmt | SyntaxKind::AddStmt | SyntaxKind::DeleteStmt => {
            self.validate_statement(node)
        }
        SyntaxKind::Whitespace => Ok(()), // Whitespace is always valid
        _ => Ok(()), // Other nodes are generally valid
    }
}
/// Validate the root node: it must be non-empty and every child node
/// must be an allowed top-level construct.
fn validate_root(&self, node: &GreenNode) -> CstResult<()> {
    let children: Vec<_> = node.children().collect();
    if children.is_empty() {
        return Err(CstError::MissingToken {
            expected: "at least one declaration or statement".to_string(),
        });
    }
    // Validate that root contains valid top-level constructs
    for child in children {
        if let Some(child_node) = child.as_node() {
            let child_kind = SyntaxKind::from_raw(child_node.kind());
            match child_kind {
                SyntaxKind::Table
                | SyntaxKind::IncludeStmt
                | SyntaxKind::DefineStmt
                | SyntaxKind::FlushStmt
                | SyntaxKind::AddStmt
                | SyntaxKind::DeleteStmt
                | SyntaxKind::Element
                | SyntaxKind::Comment
                | SyntaxKind::Newline
                | SyntaxKind::Shebang
                | SyntaxKind::Whitespace => {
                    self.validate(child_node)?;
                }
                _ => {
                    return Err(CstError::UnexpectedToken {
                        expected:
                            "table, include, define, flush, add, delete, element, or comment"
                                .to_string(),
                        found: format!("{:?}", child_kind),
                    });
                }
            }
        }
    }
    Ok(())
}

/// Validate a table node: `table <family> <name> { ... }`.
///
/// Checks that there are at least three children, that the first is the
/// `table` keyword and the second a known family, then recurses into all
/// child nodes.
fn validate_table(&self, node: &GreenNode) -> CstResult<()> {
    let children: Vec<_> = node.children().collect();
    if children.len() < 3 {
        return Err(CstError::MissingToken {
            expected: "table keyword, family, and name".to_string(),
        });
    }
    // Validate table structure: table <family> <name> { ... }
    let mut child_iter = children.iter();
    // First should be 'table' keyword. The checks only apply when the
    // child is a node — the builder stores all tokens as nodes, so this
    // holds for trees built by CstTreeBuilder.
    if let Some(first) = child_iter.next() {
        if let Some(first_node) = first.as_node() {
            let first_kind = SyntaxKind::from_raw(first_node.kind());
            if first_kind != SyntaxKind::TableKw {
                return Err(CstError::UnexpectedToken {
                    expected: "table keyword".to_string(),
                    found: format!("{:?}", first_kind),
                });
            }
        }
    }
    // Second should be family
    if let Some(second) = child_iter.next() {
        if let Some(second_node) = second.as_node() {
            let second_kind = SyntaxKind::from_raw(second_node.kind());
            match second_kind {
                SyntaxKind::InetKw
                | SyntaxKind::IpKw
                | SyntaxKind::Ip6Kw
                | SyntaxKind::ArpKw
                | SyntaxKind::BridgeKw
                | SyntaxKind::NetdevKw => {}
                _ => {
                    return Err(CstError::UnexpectedToken {
                        expected: "table family (inet, ip, ip6, arp, bridge, netdev)"
                            .to_string(),
                        found: format!("{:?}", second_kind),
                    });
                }
            }
        }
    }
    // Recursively validate children
    for child in children {
        if let Some(child_node) = child.as_node() {
            self.validate(child_node)?;
        }
    }
    Ok(())
}
/// Validate a chain node by recursing into each child node.
fn validate_chain(&self, node: &GreenNode) -> CstResult<()> {
    node.children()
        .filter_map(|child| child.as_node())
        .try_for_each(|child_node| self.validate(child_node))
}

/// Validate a rule node: it must be non-empty, and every child node
/// must itself validate.
fn validate_rule(&self, node: &GreenNode) -> CstResult<()> {
    let mut children = node.children().peekable();
    if children.peek().is_none() {
        return Err(CstError::InvalidSyntax { position: 0 });
    }
    children
        .filter_map(|child| child.as_node())
        .try_for_each(|child_node| self.validate(child_node))
}

/// Sets and maps are accepted as-is for now.
fn validate_set_or_map(&self, _node: &GreenNode) -> CstResult<()> {
    Ok(())
}

/// Validate an expression node by recursing into each child node.
fn validate_expression(&self, node: &GreenNode) -> CstResult<()> {
    node.children()
        .filter_map(|child| child.as_node())
        .try_for_each(|child_node| self.validate(child_node))
}
/// An include needs at least the keyword plus a file path.
fn validate_include(&self, node: &GreenNode) -> CstResult<()> {
    if node.children().count() < 2 {
        Err(CstError::MissingToken {
            expected: "include keyword and file path".to_string(),
        })
    } else {
        Ok(())
    }
}

/// A define needs at least the keyword, a variable name, and a value.
fn validate_define(&self, node: &GreenNode) -> CstResult<()> {
    if node.children().count() < 3 {
        Err(CstError::MissingToken {
            expected: "define keyword, variable name, and value".to_string(),
        })
    } else {
        Ok(())
    }
}

/// Element statements are accepted as-is for now.
fn validate_element(&self, _node: &GreenNode) -> CstResult<()> {
    Ok(())
}

/// Flush/add/delete statements must carry at least their keyword.
fn validate_statement(&self, node: &GreenNode) -> CstResult<()> {
    if node.children().next().is_none() {
        Err(CstError::MissingToken {
            expected: "statement keyword".to_string(),
        })
    } else {
        Ok(())
    }
}
}
impl Default for CstBuilder {
    // Manual impl; equivalent to `#[derive(Default)]` on the unit struct,
    // which would require touching the struct definition instead.
    fn default() -> Self {
        Self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::NftablesLexer;

    /// End-to-end smoke test: lex a minimal table, build the CST,
    /// validate it, then run the combined parse-and-validate entry point.
    #[test]
    fn test_cst_construction() {
        let source = "table inet filter { }";
        let mut lexer = NftablesLexer::new(source);
        let tokens = lexer.tokenize().expect("Tokenization should succeed");

        // Building must yield a Root green node.
        let green_tree = CstBuilder::build_tree(&tokens);
        assert_eq!(green_tree.kind(), SyntaxKind::Root.to_raw());

        // Validation of the built tree must succeed.
        CstBuilder::validate_tree(&green_tree).expect("validation should succeed");

        // The combined build-and-validate entry point must also succeed.
        CstBuilder::parse_to_cst(&tokens).expect("parse_to_cst should succeed");
    }
}