diff --git a/Cargo.lock b/Cargo.lock index 94021ad..f333612 100755 --- a/Cargo.lock +++ b/Cargo.lock @@ -56,6 +56,36 @@ version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "beef" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a8241f3ebb85c056b509d4327ad0358fbbba6ffb340bf388f26350aeda225b1" + +[[package]] +name = "bitflags" +version = "2.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + [[package]] name = "clap" version = "4.5.2" @@ -84,7 +114,7 @@ version = "4.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "307bc0538d5f0f83b8248db3087aa92fe504e4691294d0c96c0eabc33f47ba47" dependencies = [ - "heck", + "heck 0.4.1", "proc-macro2", "quote", "syn", @@ -102,21 +132,194 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "cstree" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d609e3b8b73dbace666e8a06351fd9062e1ec025e74b27952a932ccb8ec3a25" +dependencies = [ + "fxhash", + "indexmap", + 
"parking_lot", + "sptr", + "text-size", + "triomphe", +] + +[[package]] +name = "dyn-clone" +version = "1.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "fxhash" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" +dependencies = [ + "byteorder", +] + +[[package]] +name = "hashbrown" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" + [[package]] name = "heck" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "indexmap" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.172" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "logos" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab6f536c1af4c7cc81edf73da1f8029896e7e1e16a219ef09b184e76a296f3db" +dependencies = [ + "logos-derive", +] + +[[package]] +name = "logos-codegen" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "189bbfd0b61330abea797e5e9276408f2edbe4f822d7ad08685d67419aafb34e" +dependencies = [ + "beef", + "fnv", + "lazy_static", + "proc-macro2", + "quote", + "regex-syntax", + "rustc_version", + "syn", +] + +[[package]] +name = "logos-derive" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebfe8e1a19049ddbfccbd14ac834b215e11b85b90bab0c2dba7c7b92fb5d5cba" +dependencies = [ + "logos-codegen", +] + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + [[package]] name = "nff" version = "0.1.0" dependencies = [ "anyhow", "clap", + "cstree", + "logos", + "nftables", + "text-size", "thiserror", ] +[[package]] +name = "nftables" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "180f5f3983a76df4a48e01b9317832cc6fa6aa90ef0c73328658e0e5653f175a" +dependencies = [ + "schemars", + "serde", + "serde_json", + 
"serde_path_to_error", + "strum", + "strum_macros", + "thiserror", +] + +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + [[package]] name = "proc-macro2" version = "1.0.95" @@ -135,12 +338,174 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "redox_syscall" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "928fca9cf2aa042393a8325b9ead81d2f0df4cb12e1e24cef072922ccd99c5af" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] +name = "rustversion" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" + +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "schemars" +version = "0.8.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" +dependencies = [ + "dyn-clone", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "0.8.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "semver" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" + +[[package]] +name = "serde" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.140" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + +[[package]] +name = "smallvec" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" + +[[package]] +name = "sptr" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b9b39299b249ad65f3b7e96443bad61c02ca5cd3589f46cb6d610a0fd6c0d6a" + +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "strsim" version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01" +[[package]] +name = "strum" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32" + +[[package]] +name = "strum_macros" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + [[package]] name = "syn" version = "2.0.101" @@ -152,6 +517,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "text-size" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f18aa187839b2bdb1ad2fa35ead8c4c2976b64e4363c386d45ac0f7ee85c9233" + [[package]] name = "thiserror" version = "2.0.12" @@ -172,6 +543,15 @@ dependencies = [ "syn", ] +[[package]] +name = "triomphe" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ef8f7726da4807b58ea5c96fdc122f80702030edc33b35aff9190a51148ccc85" +dependencies = [ + "stable_deref_trait", +] + [[package]] name = "unicode-ident" version = "1.0.12" diff --git a/Cargo.toml b/Cargo.toml index b7cbf7a..ae3c856 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,3 +9,7 @@ edition = "2024" clap = { version = "4.5", features = ["derive"] } anyhow = "1.0" thiserror = "2.0" +logos = "0.15" +cstree = "0.12" +text-size = "1.1" +nftables = "0.6" diff --git a/src/ast.rs b/src/ast.rs new file mode 100644 index 0000000..e1a3623 --- /dev/null +++ b/src/ast.rs @@ -0,0 +1,389 @@ +use std::collections::HashMap; +use std::fmt; + +/// Represents the nftables address family +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Family { + Ip, + Ip6, + Inet, + Arp, + Bridge, + Netdev, +} + +impl fmt::Display for Family { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Family::Ip => write!(f, "ip"), + Family::Ip6 => write!(f, "ip6"), + Family::Inet => write!(f, "inet"), + Family::Arp => write!(f, "arp"), + Family::Bridge => write!(f, "bridge"), + Family::Netdev => write!(f, "netdev"), + } + } +} + +/// Represents chain types +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ChainType { + Filter, + Nat, + Route, +} + +impl fmt::Display for ChainType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ChainType::Filter => write!(f, "filter"), + ChainType::Nat => write!(f, "nat"), + ChainType::Route => write!(f, "route"), + } + } +} + +/// Represents chain hooks +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Hook { + Input, + Output, + Forward, + Prerouting, + Postrouting, +} + +impl fmt::Display for Hook { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Hook::Input => write!(f, "input"), + Hook::Output => write!(f, "output"), + Hook::Forward => write!(f, "forward"), + Hook::Prerouting => write!(f, "prerouting"), + Hook::Postrouting => write!(f, "postrouting"), + } + } +} + +/// 
Represents chain policies +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Policy { + Accept, + Drop, +} + +impl fmt::Display for Policy { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Policy::Accept => write!(f, "accept"), + Policy::Drop => write!(f, "drop"), + } + } +} + +/// Represents expressions in nftables rules +#[derive(Debug, Clone, PartialEq)] +pub enum Expression { + // Literals + Identifier(String), + String(String), + Number(u64), + IpAddress(String), + Ipv6Address(String), + MacAddress(String), + + // Binary operations + Binary { + left: Box, + operator: BinaryOperator, + right: Box, + }, + + // Protocol matches + Protocol(String), + Port { + direction: PortDirection, + value: Box, + }, + Address { + direction: AddressDirection, + value: Box, + }, + + // Interface matches + Interface { + direction: InterfaceDirection, + name: String, + }, + + // Connection tracking + ConnTrack { + field: String, + value: Box, + }, + + // Set expressions + Set(Vec), + + // Range expressions + Range { + start: Box, + end: Box, + }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BinaryOperator { + Eq, + Ne, + Lt, + Le, + Gt, + Ge, +} + +impl fmt::Display for BinaryOperator { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BinaryOperator::Eq => write!(f, "=="), + BinaryOperator::Ne => write!(f, "!="), + BinaryOperator::Lt => write!(f, "<"), + BinaryOperator::Le => write!(f, "<="), + BinaryOperator::Gt => write!(f, ">"), + BinaryOperator::Ge => write!(f, ">="), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PortDirection { + Source, + Destination, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AddressDirection { + Source, + Destination, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum InterfaceDirection { + Input, + Output, +} + +/// Represents actions that can be taken on matching packets +#[derive(Debug, Clone, PartialEq)] +pub enum Action { + Accept, + Drop, + Reject, + 
Return, + Jump(String), + Goto(String), + Continue, + Log { + prefix: Option, + level: Option, + }, + Comment(String), +} + +impl fmt::Display for Action { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Action::Accept => write!(f, "accept"), + Action::Drop => write!(f, "drop"), + Action::Reject => write!(f, "reject"), + Action::Return => write!(f, "return"), + Action::Jump(target) => write!(f, "jump {}", target), + Action::Goto(target) => write!(f, "goto {}", target), + Action::Continue => write!(f, "continue"), + Action::Log { prefix, level } => { + write!(f, "log")?; + if let Some(p) = prefix { + write!(f, " prefix \"{}\"", p)?; + } + if let Some(l) = level { + write!(f, " level {}", l)?; + } + Ok(()) + } + Action::Comment(text) => write!(f, "comment \"{}\"", text), + } + } +} + +/// Represents a rule in a chain +#[derive(Debug, Clone, PartialEq)] +pub struct Rule { + pub expressions: Vec, + pub action: Action, + pub handle: Option, +} + +impl Rule { + pub fn new(expressions: Vec, action: Action) -> Self { + Self { + expressions, + action, + handle: None, + } + } + + pub fn with_handle(mut self, handle: u64) -> Self { + self.handle = Some(handle); + self + } +} + +/// Represents a chain in a table +#[derive(Debug, Clone, PartialEq)] +pub struct Chain { + pub name: String, + pub chain_type: Option, + pub hook: Option, + pub priority: Option, + pub policy: Option, + pub device: Option, + pub rules: Vec, + pub handle: Option, +} + +impl Chain { + pub fn new(name: String) -> Self { + Self { + name, + chain_type: None, + hook: None, + priority: None, + policy: None, + device: None, + rules: Vec::new(), + handle: None, + } + } + + pub fn with_type(mut self, chain_type: ChainType) -> Self { + self.chain_type = Some(chain_type); + self + } + + pub fn with_hook(mut self, hook: Hook) -> Self { + self.hook = Some(hook); + self + } + + pub fn with_priority(mut self, priority: i32) -> Self { + self.priority = Some(priority); + self + } + + pub 
fn with_policy(mut self, policy: Policy) -> Self { + self.policy = Some(policy); + self + } + + pub fn with_device(mut self, device: String) -> Self { + self.device = Some(device); + self + } + + pub fn add_rule(mut self, rule: Rule) -> Self { + self.rules.push(rule); + self + } +} + +/// Represents a table containing chains +#[derive(Debug, Clone, PartialEq)] +pub struct Table { + pub family: Family, + pub name: String, + pub chains: HashMap, + pub handle: Option, +} + +impl Table { + pub fn new(family: Family, name: String) -> Self { + Self { + family, + name, + chains: HashMap::new(), + handle: None, + } + } + + pub fn add_chain(mut self, chain: Chain) -> Self { + self.chains.insert(chain.name.clone(), chain); + self + } +} + +/// Represents an include statement +#[derive(Debug, Clone, PartialEq)] +pub struct Include { + pub path: String, +} + +/// Represents a define statement +#[derive(Debug, Clone, PartialEq)] +pub struct Define { + pub name: String, + pub value: Expression, +} + +/// Represents the root of an nftables configuration +#[derive(Debug, Clone, PartialEq)] +pub struct Ruleset { + pub includes: Vec, + pub defines: Vec, + pub tables: HashMap<(Family, String), Table>, + pub shebang: Option, + pub comments: Vec, +} + +impl Ruleset { + pub fn new() -> Self { + Self { + includes: Vec::new(), + defines: Vec::new(), + tables: HashMap::new(), + shebang: None, + comments: Vec::new(), + } + } + + pub fn with_shebang(mut self, shebang: String) -> Self { + self.shebang = Some(shebang); + self + } + + pub fn add_include(mut self, include: Include) -> Self { + self.includes.push(include); + self + } + + pub fn add_define(mut self, define: Define) -> Self { + self.defines.push(define); + self + } + + pub fn add_table(mut self, table: Table) -> Self { + let key = (table.family.clone(), table.name.clone()); + self.tables.insert(key, table); + self + } + + pub fn add_comment(mut self, comment: String) -> Self { + self.comments.push(comment); + self + } +} + +impl 
Default for Ruleset { + fn default() -> Self { + Self::new() + } +} diff --git a/src/cst.rs b/src/cst.rs new file mode 100644 index 0000000..e0e093c --- /dev/null +++ b/src/cst.rs @@ -0,0 +1,1197 @@ +//! Concrete Syntax Tree implementation for nftables configuration files +//! +//! Lossless representation preserving whitespace, comments, and formatting. + +use crate::lexer::{Token, TokenKind}; +use cstree::{ + green::GreenNode, + RawSyntaxKind, util::NodeOrToken, +}; +use std::fmt; +use thiserror::Error; + + +/// nftables syntax node types +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u16)] +pub enum SyntaxKind { + // Root and containers + Root = 0, + Table, + Chain, + Rule, + Set, + Map, + Element, + + // Expressions + Expression, + BinaryExpr, + UnaryExpr, + CallExpr, + SetExpr, + RangeExpr, + + // Statements + Statement, + IncludeStmt, + DefineStmt, + FlushStmt, + AddStmt, + DeleteStmt, + + // Literals and identifiers + Identifier, + StringLiteral, + NumberLiteral, + IpAddress, + Ipv6Address, + MacAddress, + + // Keywords + TableKw, + ChainKw, + RuleKw, + SetKw, + MapKw, + ElementKw, + IncludeKw, + DefineKw, + FlushKw, + AddKw, + DeleteKw, + InsertKw, + ReplaceKw, + + // Chain types and hooks + FilterKw, + NatKw, + RouteKw, + InputKw, + OutputKw, + ForwardKw, + PreroutingKw, + PostroutingKw, + + // Protocols and families + IpKw, + Ip6Kw, + InetKw, + ArpKw, + BridgeKw, + NetdevKw, + TcpKw, + UdpKw, + IcmpKw, + Icmpv6Kw, + + // Match keywords + SportKw, + DportKw, + SaddrKw, + DaddrKw, + ProtocolKw, + NexthdrKw, + TypeKw, + HookKw, + PriorityKw, + PolicyKw, + IifnameKw, + OifnameKw, + CtKw, + StateKw, + + // Actions + AcceptKw, + DropKw, + RejectKw, + ReturnKw, + JumpKw, + GotoKw, + ContinueKw, + LogKw, + CommentKw, + + // States + EstablishedKw, + RelatedKw, + NewKw, + InvalidKw, + + // Operators + EqOp, + NeOp, + LeOp, + GeOp, + LtOp, + GtOp, + + // Punctuation + LeftBrace, + RightBrace, + LeftParen, + RightParen, + LeftBracket, + 
RightBracket, + Comma, + Semicolon, + Colon, + Assign, + Dash, + Slash, + Dot, + + // Trivia + Whitespace, + Newline, + Comment, + Shebang, + + // Error recovery + Error, +} + +impl From for SyntaxKind { + fn from(kind: TokenKind) -> Self { + match kind { + TokenKind::Table => SyntaxKind::TableKw, + TokenKind::Chain => SyntaxKind::ChainKw, + TokenKind::Rule => SyntaxKind::RuleKw, + TokenKind::Set => SyntaxKind::SetKw, + TokenKind::Map => SyntaxKind::MapKw, + TokenKind::Element => SyntaxKind::ElementKw, + TokenKind::Include => SyntaxKind::IncludeKw, + TokenKind::Define => SyntaxKind::DefineKw, + TokenKind::Flush => SyntaxKind::FlushKw, + TokenKind::Add => SyntaxKind::AddKw, + TokenKind::Delete => SyntaxKind::DeleteKw, + TokenKind::Insert => SyntaxKind::InsertKw, + TokenKind::Replace => SyntaxKind::ReplaceKw, + + TokenKind::Filter => SyntaxKind::FilterKw, + TokenKind::Nat => SyntaxKind::NatKw, + TokenKind::Route => SyntaxKind::RouteKw, + + TokenKind::Input => SyntaxKind::InputKw, + TokenKind::Output => SyntaxKind::OutputKw, + TokenKind::Forward => SyntaxKind::ForwardKw, + TokenKind::Prerouting => SyntaxKind::PreroutingKw, + TokenKind::Postrouting => SyntaxKind::PostroutingKw, + + TokenKind::Ip => SyntaxKind::IpKw, + TokenKind::Ip6 => SyntaxKind::Ip6Kw, + TokenKind::Inet => SyntaxKind::InetKw, + TokenKind::Arp => SyntaxKind::ArpKw, + TokenKind::Bridge => SyntaxKind::BridgeKw, + TokenKind::Netdev => SyntaxKind::NetdevKw, + TokenKind::Tcp => SyntaxKind::TcpKw, + TokenKind::Udp => SyntaxKind::UdpKw, + TokenKind::Icmp => SyntaxKind::IcmpKw, + TokenKind::Icmpv6 => SyntaxKind::Icmpv6Kw, + + TokenKind::Sport => SyntaxKind::SportKw, + TokenKind::Dport => SyntaxKind::DportKw, + TokenKind::Saddr => SyntaxKind::SaddrKw, + TokenKind::Daddr => SyntaxKind::DaddrKw, + TokenKind::Protocol => SyntaxKind::ProtocolKw, + TokenKind::Nexthdr => SyntaxKind::NexthdrKw, + TokenKind::Type => SyntaxKind::TypeKw, + TokenKind::Hook => SyntaxKind::HookKw, + TokenKind::Priority => 
SyntaxKind::PriorityKw, + TokenKind::Policy => SyntaxKind::PolicyKw, + TokenKind::Iifname => SyntaxKind::IifnameKw, + TokenKind::Oifname => SyntaxKind::OifnameKw, + TokenKind::Ct => SyntaxKind::CtKw, + TokenKind::State => SyntaxKind::StateKw, + + TokenKind::Accept => SyntaxKind::AcceptKw, + TokenKind::Drop => SyntaxKind::DropKw, + TokenKind::Reject => SyntaxKind::RejectKw, + TokenKind::Return => SyntaxKind::ReturnKw, + TokenKind::Jump => SyntaxKind::JumpKw, + TokenKind::Goto => SyntaxKind::GotoKw, + TokenKind::Continue => SyntaxKind::ContinueKw, + TokenKind::Log => SyntaxKind::LogKw, + TokenKind::Comment => SyntaxKind::CommentKw, + + TokenKind::Established => SyntaxKind::EstablishedKw, + TokenKind::Related => SyntaxKind::RelatedKw, + TokenKind::New => SyntaxKind::NewKw, + TokenKind::Invalid => SyntaxKind::InvalidKw, + + TokenKind::Eq => SyntaxKind::EqOp, + TokenKind::Ne => SyntaxKind::NeOp, + TokenKind::Le => SyntaxKind::LeOp, + TokenKind::Ge => SyntaxKind::GeOp, + TokenKind::Lt => SyntaxKind::LtOp, + TokenKind::Gt => SyntaxKind::GtOp, + + TokenKind::LeftBrace => SyntaxKind::LeftBrace, + TokenKind::RightBrace => SyntaxKind::RightBrace, + TokenKind::LeftParen => SyntaxKind::LeftParen, + TokenKind::RightParen => SyntaxKind::RightParen, + TokenKind::LeftBracket => SyntaxKind::LeftBracket, + TokenKind::RightBracket => SyntaxKind::RightBracket, + TokenKind::Comma => SyntaxKind::Comma, + TokenKind::Semicolon => SyntaxKind::Semicolon, + TokenKind::Colon => SyntaxKind::Colon, + TokenKind::Assign => SyntaxKind::Assign, + TokenKind::Dash => SyntaxKind::Dash, + TokenKind::Slash => SyntaxKind::Slash, + TokenKind::Dot => SyntaxKind::Dot, + + TokenKind::StringLiteral(_) => SyntaxKind::StringLiteral, + TokenKind::NumberLiteral(_) => SyntaxKind::NumberLiteral, + TokenKind::IpAddress(_) => SyntaxKind::IpAddress, + TokenKind::Ipv6Address(_) => SyntaxKind::Ipv6Address, + TokenKind::MacAddress(_) => SyntaxKind::MacAddress, + TokenKind::Identifier(_) => SyntaxKind::Identifier, + + 
TokenKind::Newline => SyntaxKind::Newline, + TokenKind::CommentLine(_) => SyntaxKind::Comment, + TokenKind::Shebang(_) => SyntaxKind::Shebang, + + TokenKind::Error => SyntaxKind::Error, + } + } +} + +impl fmt::Display for SyntaxKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let name = match self { + SyntaxKind::Root => "ROOT", + SyntaxKind::Table => "TABLE", + SyntaxKind::Chain => "CHAIN", + SyntaxKind::Rule => "RULE", + SyntaxKind::Set => "SET", + SyntaxKind::Map => "MAP", + SyntaxKind::Element => "ELEMENT", + SyntaxKind::Expression => "EXPRESSION", + SyntaxKind::BinaryExpr => "BINARY_EXPR", + SyntaxKind::UnaryExpr => "UNARY_EXPR", + SyntaxKind::CallExpr => "CALL_EXPR", + SyntaxKind::SetExpr => "SET_EXPR", + SyntaxKind::RangeExpr => "RANGE_EXPR", + SyntaxKind::Statement => "STATEMENT", + SyntaxKind::IncludeStmt => "INCLUDE_STMT", + SyntaxKind::DefineStmt => "DEFINE_STMT", + SyntaxKind::FlushStmt => "FLUSH_STMT", + SyntaxKind::AddStmt => "ADD_STMT", + SyntaxKind::DeleteStmt => "DELETE_STMT", + SyntaxKind::Identifier => "IDENTIFIER", + SyntaxKind::StringLiteral => "STRING_LITERAL", + SyntaxKind::NumberLiteral => "NUMBER_LITERAL", + SyntaxKind::IpAddress => "IP_ADDRESS", + SyntaxKind::Ipv6Address => "IPV6_ADDRESS", + SyntaxKind::MacAddress => "MAC_ADDRESS", + SyntaxKind::Whitespace => "WHITESPACE", + SyntaxKind::Newline => "NEWLINE", + SyntaxKind::Comment => "COMMENT", + SyntaxKind::Shebang => "SHEBANG", + SyntaxKind::Error => "ERROR", + _ => return write!(f, "{:?}", self), + }; + write!(f, "{}", name) + } +} + +/// Language definition for nftables CST +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct NftablesLanguage; + +impl SyntaxKind { + pub fn to_raw(self) -> RawSyntaxKind { + RawSyntaxKind(self as u32) + } + + pub fn from_raw(raw: RawSyntaxKind) -> Self { + unsafe { std::mem::transmute(raw.0 as u16) } + } +} + +/// CST construction errors +#[derive(Error, Debug, PartialEq)] +pub enum CstError { + 
#[error("Unexpected token: expected {expected}, found {found}")] + UnexpectedToken { expected: String, found: String }, + #[error("Missing token: expected {expected}")] + MissingToken { expected: String }, + #[error("Invalid syntax at position {position}")] + InvalidSyntax { position: usize }, +} + +/// Result type for CST operations +pub type CstResult = Result; + +/// Basic CST builder +pub struct CstBuilder; + +impl CstBuilder { + /// Build a proper CST from tokens that represents nftables syntax structure + pub fn build_tree(tokens: &[Token]) -> GreenNode { + let mut builder = CstTreeBuilder::new(); + builder.parse_tokens(tokens); + builder.finish() + } + + /// Validate CST tree structure according to nftables grammar rules + pub fn validate_tree(node: &GreenNode) -> CstResult<()> { + let validator = CstValidator::new(); + validator.validate(node) + } + + /// Parse tokens into a validated CST + pub fn parse_to_cst(tokens: &[Token]) -> CstResult { + let tree = Self::build_tree(tokens); + Self::validate_tree(&tree)?; + Ok(tree) + } +} + +/// Internal tree builder that constructs CST according to nftables grammar +struct CstTreeBuilder { + stack: Vec<(SyntaxKind, Vec>)>, +} + +impl CstTreeBuilder { + fn new() -> Self { + Self { + stack: vec![(SyntaxKind::Root, Vec::new())], + } + } + + fn parse_tokens(&mut self, tokens: &[Token]) { + let mut i = 0; + while i < tokens.len() { + i = self.parse_top_level(&tokens, i); + } + } + + fn parse_top_level(&mut self, tokens: &[Token], start: usize) -> usize { + if start >= tokens.len() { + return start; + } + + match &tokens[start].kind { + TokenKind::Table => self.parse_table(tokens, start), + TokenKind::Chain => self.parse_chain(tokens, start), + TokenKind::Include => self.parse_include(tokens, start), + TokenKind::Define => self.parse_define(tokens, start), + TokenKind::Flush => self.parse_flush_stmt(tokens, start), + TokenKind::Add => self.parse_add_stmt(tokens, start), + TokenKind::Delete => 
self.parse_delete_stmt(tokens, start), + TokenKind::Element => self.parse_element(tokens, start), + TokenKind::CommentLine(_) => { + self.add_token(&tokens[start], SyntaxKind::Comment); + start + 1 + }, + TokenKind::Newline => { + self.add_token(&tokens[start], SyntaxKind::Newline); + start + 1 + }, + TokenKind::Shebang(_) => { + self.add_token(&tokens[start], SyntaxKind::Shebang); + start + 1 + }, + _ => { + // Handle other tokens as statements + self.parse_statement(tokens, start) + } + } + } + + fn parse_table(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::Table); + + // Add 'table' keyword + self.add_token(&tokens[start], SyntaxKind::TableKw); + let mut pos = start + 1; + + // Parse family (inet, ip, ip6, etc.) + if pos < tokens.len() { + match &tokens[pos].kind { + TokenKind::Inet => self.add_token(&tokens[pos], SyntaxKind::InetKw), + TokenKind::Ip => self.add_token(&tokens[pos], SyntaxKind::IpKw), + TokenKind::Ip6 => self.add_token(&tokens[pos], SyntaxKind::Ip6Kw), + TokenKind::Arp => self.add_token(&tokens[pos], SyntaxKind::ArpKw), + TokenKind::Bridge => self.add_token(&tokens[pos], SyntaxKind::BridgeKw), + TokenKind::Netdev => self.add_token(&tokens[pos], SyntaxKind::NetdevKw), + _ => self.add_token(&tokens[pos], SyntaxKind::Identifier), + } + pos += 1; + } + + // Parse table name + if pos < tokens.len() { + self.add_token(&tokens[pos], SyntaxKind::Identifier); + pos += 1; + } + + // Parse table body + if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::LeftBrace) { + self.add_token(&tokens[pos], SyntaxKind::LeftBrace); + pos += 1; + + // Parse table contents + while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBrace) { + match &tokens[pos].kind { + TokenKind::Chain => { + pos = self.parse_chain(tokens, pos); + }, + TokenKind::Set => { + pos = self.parse_set(tokens, pos); + }, + TokenKind::Map => { + pos = self.parse_map(tokens, pos); + }, + TokenKind::Newline | TokenKind::CommentLine(_) 
=> { + self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); + pos += 1; + }, + _ => pos += 1, // Skip unknown tokens + } + } + + // Add closing brace + if pos < tokens.len() { + self.add_token(&tokens[pos], SyntaxKind::RightBrace); + pos += 1; + } + } + + self.finish_node(); + pos + } + + fn parse_chain(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::Chain); + + self.add_token(&tokens[start], SyntaxKind::ChainKw); + let mut pos = start + 1; + + // Parse chain name + if pos < tokens.len() { + self.add_token(&tokens[pos], SyntaxKind::Identifier); + pos += 1; + } + + // Parse chain body + if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::LeftBrace) { + self.add_token(&tokens[pos], SyntaxKind::LeftBrace); + pos += 1; + + while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBrace) { + match &tokens[pos].kind { + TokenKind::Type => { + pos = self.parse_chain_properties(tokens, pos); + }, + TokenKind::Newline | TokenKind::CommentLine(_) => { + self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); + pos += 1; + }, + _ => { + // Parse as rule + pos = self.parse_rule(tokens, pos); + } + } + } + + if pos < tokens.len() { + self.add_token(&tokens[pos], SyntaxKind::RightBrace); + pos += 1; + } + } + + self.finish_node(); + pos + } + + fn parse_chain_properties(&mut self, tokens: &[Token], start: usize) -> usize { + let mut pos = start; + + // Parse 'type' + if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Type) { + self.add_token(&tokens[pos], SyntaxKind::TypeKw); + pos += 1; + } + + // Parse chain type (filter, nat, route) + if pos < tokens.len() { + match &tokens[pos].kind { + TokenKind::Filter => self.add_token(&tokens[pos], SyntaxKind::FilterKw), + TokenKind::Nat => self.add_token(&tokens[pos], SyntaxKind::NatKw), + TokenKind::Route => self.add_token(&tokens[pos], SyntaxKind::RouteKw), + _ => self.add_token(&tokens[pos], SyntaxKind::Identifier), + } + 
pos += 1; + } + + // Parse 'hook' + if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Hook) { + self.add_token(&tokens[pos], SyntaxKind::HookKw); + pos += 1; + + // Parse hook type + if pos < tokens.len() { + match &tokens[pos].kind { + TokenKind::Input => self.add_token(&tokens[pos], SyntaxKind::InputKw), + TokenKind::Output => self.add_token(&tokens[pos], SyntaxKind::OutputKw), + TokenKind::Forward => self.add_token(&tokens[pos], SyntaxKind::ForwardKw), + TokenKind::Prerouting => self.add_token(&tokens[pos], SyntaxKind::PreroutingKw), + TokenKind::Postrouting => self.add_token(&tokens[pos], SyntaxKind::PostroutingKw), + _ => self.add_token(&tokens[pos], SyntaxKind::Identifier), + } + pos += 1; + } + } + + // Parse 'priority' + if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Priority) { + self.add_token(&tokens[pos], SyntaxKind::PriorityKw); + pos += 1; + + if pos < tokens.len() { + if let TokenKind::NumberLiteral(_) = &tokens[pos].kind { + self.add_token(&tokens[pos], SyntaxKind::NumberLiteral); + pos += 1; + } + } + } + + // Parse semicolon + if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Semicolon) { + self.add_token(&tokens[pos], SyntaxKind::Semicolon); + pos += 1; + } + + pos + } + + fn parse_rule(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::Rule); + + let mut pos = start; + let rule_start = pos; + + // Parse until we hit a newline or end of input (simple rule parsing) + while pos < tokens.len() { + match &tokens[pos].kind { + TokenKind::Newline => { + self.add_token(&tokens[pos], SyntaxKind::Newline); + pos += 1; + break; + }, + TokenKind::Accept => self.add_token(&tokens[pos], SyntaxKind::AcceptKw), + TokenKind::Drop => self.add_token(&tokens[pos], SyntaxKind::DropKw), + TokenKind::Reject => self.add_token(&tokens[pos], SyntaxKind::RejectKw), + TokenKind::Return => self.add_token(&tokens[pos], SyntaxKind::ReturnKw), + TokenKind::Log => self.add_token(&tokens[pos], 
SyntaxKind::LogKw), + TokenKind::Comment => self.add_token(&tokens[pos], SyntaxKind::CommentKw), + TokenKind::Eq => self.add_token(&tokens[pos], SyntaxKind::EqOp), + TokenKind::Ne => self.add_token(&tokens[pos], SyntaxKind::NeOp), + TokenKind::Lt => self.add_token(&tokens[pos], SyntaxKind::LtOp), + TokenKind::Le => self.add_token(&tokens[pos], SyntaxKind::LeOp), + TokenKind::Gt => self.add_token(&tokens[pos], SyntaxKind::GtOp), + TokenKind::Ge => self.add_token(&tokens[pos], SyntaxKind::GeOp), + TokenKind::LeftParen => { + pos = self.parse_expression(tokens, pos); + continue; + }, + TokenKind::Identifier(_) => self.add_token(&tokens[pos], SyntaxKind::Identifier), + TokenKind::StringLiteral(_) => self.add_token(&tokens[pos], SyntaxKind::StringLiteral), + TokenKind::NumberLiteral(_) => self.add_token(&tokens[pos], SyntaxKind::NumberLiteral), + TokenKind::IpAddress(_) => self.add_token(&tokens[pos], SyntaxKind::IpAddress), + _ => self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())), + } + pos += 1; + } + + // If no tokens were added to the rule, add at least one statement node + if pos == rule_start { + self.start_node(SyntaxKind::Statement); + self.finish_node(); + } + + self.finish_node(); + pos + } + + fn parse_expression(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::Expression); + + let mut pos = start; + let mut paren_depth = 0; + + while pos < tokens.len() { + match &tokens[pos].kind { + TokenKind::LeftParen => { + paren_depth += 1; + self.add_token(&tokens[pos], SyntaxKind::LeftParen); + }, + TokenKind::RightParen => { + self.add_token(&tokens[pos], SyntaxKind::RightParen); + paren_depth -= 1; + if paren_depth == 0 { + pos += 1; + break; + } + }, + TokenKind::LeftBracket => { + pos = self.parse_set_expression(tokens, pos); + continue; + }, + TokenKind::Identifier(_) => { + // Check if this is a function call + if pos + 1 < tokens.len() && matches!(tokens[pos + 1].kind, TokenKind::LeftParen) { + pos = 
self.parse_call_expression(tokens, pos); + continue; + } else { + self.add_token(&tokens[pos], SyntaxKind::Identifier); + } + }, + TokenKind::Dash => { + // Could be binary or unary expression + if self.is_binary_context() { + self.parse_binary_expression(tokens, pos); + } else { + self.parse_unary_expression(tokens, pos); + } + }, + _ => self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())), + } + pos += 1; + } + + self.finish_node(); + pos + } + + fn parse_set_expression(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::SetExpr); + + let mut pos = start; + self.add_token(&tokens[pos], SyntaxKind::LeftBracket); + pos += 1; + + while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBracket) { + if matches!(tokens[pos].kind, TokenKind::Colon) { + // Range expression + self.start_node(SyntaxKind::RangeExpr); + self.add_token(&tokens[pos], SyntaxKind::Colon); + self.finish_node(); + } else { + self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); + } + pos += 1; + } + + if pos < tokens.len() { + self.add_token(&tokens[pos], SyntaxKind::RightBracket); + pos += 1; + } + + self.finish_node(); + pos + } + + fn parse_binary_expression(&mut self, tokens: &[Token], pos: usize) { + self.start_node(SyntaxKind::BinaryExpr); + self.add_token(&tokens[pos], SyntaxKind::Dash); + self.finish_node(); + } + + fn parse_unary_expression(&mut self, tokens: &[Token], pos: usize) { + self.start_node(SyntaxKind::UnaryExpr); + self.add_token(&tokens[pos], SyntaxKind::Dash); + self.finish_node(); + } + + fn parse_set(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::Set); + // Basic set parsing implementation + let mut pos = start; + self.add_token(&tokens[pos], SyntaxKind::SetKw); + pos += 1; + // Skip to end of set for now + while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBrace) { + pos += 1; + } + self.finish_node(); + pos + } + + fn 
parse_map(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::Map); + // Basic map parsing implementation + let mut pos = start; + self.add_token(&tokens[pos], SyntaxKind::MapKw); + pos += 1; + // Skip to end of map for now + while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBrace) { + pos += 1; + } + self.finish_node(); + pos + } + + fn parse_include(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::IncludeStmt); + + let mut pos = start; + self.add_token(&tokens[pos], SyntaxKind::IncludeKw); + pos += 1; + + if pos < tokens.len() { + if let TokenKind::StringLiteral(_) = &tokens[pos].kind { + self.add_token(&tokens[pos], SyntaxKind::StringLiteral); + pos += 1; + } + } + + self.finish_node(); + pos + } + + fn parse_define(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::DefineStmt); + + let mut pos = start; + self.add_token(&tokens[pos], SyntaxKind::DefineKw); + pos += 1; + + // Parse variable name + if pos < tokens.len() { + self.add_token(&tokens[pos], SyntaxKind::Identifier); + pos += 1; + } + + // Parse assignment + if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Assign) { + self.add_token(&tokens[pos], SyntaxKind::Assign); + pos += 1; + } + + // Parse value - could be complex expression + while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline) { + self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); + pos += 1; + } + + self.finish_node(); + pos + } + + fn parse_flush_stmt(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::FlushStmt); + + let mut pos = start; + self.add_token(&tokens[pos], SyntaxKind::FlushKw); + pos += 1; + + // Parse what to flush (table, chain, etc.) 
+ while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline | TokenKind::Semicolon) { + self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); + pos += 1; + } + + if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::Semicolon) { + self.add_token(&tokens[pos], SyntaxKind::Semicolon); + pos += 1; + } + + self.finish_node(); + pos + } + + fn parse_add_stmt(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::AddStmt); + + let mut pos = start; + self.add_token(&tokens[pos], SyntaxKind::AddKw); + pos += 1; + + // Parse what to add + while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline) { + self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); + pos += 1; + } + + self.finish_node(); + pos + } + + fn parse_delete_stmt(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::DeleteStmt); + + let mut pos = start; + self.add_token(&tokens[pos], SyntaxKind::DeleteKw); + pos += 1; + + // Parse what to delete + while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline) { + self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); + pos += 1; + } + + self.finish_node(); + pos + } + + fn parse_element(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::Element); + + let mut pos = start; + self.add_token(&tokens[pos], SyntaxKind::ElementKw); + pos += 1; + + // Parse element specification + while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::Newline | TokenKind::LeftBrace) { + self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); + pos += 1; + } + + // Parse element body if present + if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::LeftBrace) { + self.add_token(&tokens[pos], SyntaxKind::LeftBrace); + pos += 1; + + while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightBrace) { + self.add_token(&tokens[pos], 
SyntaxKind::from(tokens[pos].kind.clone())); + pos += 1; + } + + if pos < tokens.len() { + self.add_token(&tokens[pos], SyntaxKind::RightBrace); + pos += 1; + } + } + + self.finish_node(); + pos + } + + fn parse_call_expression(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::CallExpr); + + let mut pos = start; + + // Add function name + self.add_token(&tokens[pos], SyntaxKind::Identifier); + pos += 1; + + // Add opening paren + if pos < tokens.len() && matches!(tokens[pos].kind, TokenKind::LeftParen) { + self.add_token(&tokens[pos], SyntaxKind::LeftParen); + pos += 1; + + // Parse arguments + while pos < tokens.len() && !matches!(tokens[pos].kind, TokenKind::RightParen) { + if matches!(tokens[pos].kind, TokenKind::Comma) { + self.add_token(&tokens[pos], SyntaxKind::Comma); + } else { + self.add_token(&tokens[pos], SyntaxKind::from(tokens[pos].kind.clone())); + } + pos += 1; + } + + // Add closing paren + if pos < tokens.len() { + self.add_token(&tokens[pos], SyntaxKind::RightParen); + pos += 1; + } + } + + self.finish_node(); + pos + } + + fn parse_statement(&mut self, tokens: &[Token], start: usize) -> usize { + self.start_node(SyntaxKind::Statement); + self.add_token(&tokens[start], SyntaxKind::from(tokens[start].kind.clone())); + self.finish_node(); + start + 1 + } + + fn is_binary_context(&self) -> bool { + // Simple heuristic: assume binary if we have recent tokens in current scope + self.stack.last().map_or(false, |(_, children)| children.len() > 1) + } + + fn start_node(&mut self, kind: SyntaxKind) { + self.stack.push((kind, Vec::new())); + } + + fn add_token(&mut self, token: &Token, kind: SyntaxKind) { + // Handle whitespace specially if it's part of the token + if let TokenKind::Identifier(text) = &token.kind { + if text.chars().all(|c| c.is_whitespace()) { + let whitespace_token = GreenNode::new(SyntaxKind::Whitespace.to_raw(), std::iter::empty()); + if let Some((_, children)) = self.stack.last_mut() { + 
children.push(NodeOrToken::Node(whitespace_token)); + } + return; + } + } + + let token_node = GreenNode::new(kind.to_raw(), std::iter::empty()); + if let Some((_, children)) = self.stack.last_mut() { + children.push(NodeOrToken::Node(token_node)); + } + } + + fn finish_node(&mut self) { + if let Some((kind, children)) = self.stack.pop() { + let node = GreenNode::new(kind.to_raw(), children); + if let Some((_, parent_children)) = self.stack.last_mut() { + parent_children.push(NodeOrToken::Node(node)); + } else { + // This is the root node + self.stack.push((kind, vec![NodeOrToken::Node(node)])); + } + } + } + + fn finish(self) -> GreenNode { + if let Some((kind, children)) = self.stack.into_iter().next() { + GreenNode::new(kind.to_raw(), children) + } else { + GreenNode::new(SyntaxKind::Root.to_raw(), std::iter::empty()) + } + } +} + +/// Validator for CST structure according to nftables grammar +struct CstValidator; + +impl CstValidator { + fn new() -> Self { + Self + } + + fn validate(&self, node: &GreenNode) -> CstResult<()> { + let kind = SyntaxKind::from_raw(node.kind()); + + match kind { + SyntaxKind::Root => self.validate_root(node), + SyntaxKind::Table => self.validate_table(node), + SyntaxKind::Chain => self.validate_chain(node), + SyntaxKind::Rule => self.validate_rule(node), + SyntaxKind::Set | SyntaxKind::Map => self.validate_set_or_map(node), + SyntaxKind::Element => self.validate_element(node), + SyntaxKind::Expression | SyntaxKind::BinaryExpr | SyntaxKind::UnaryExpr | SyntaxKind::CallExpr => { + self.validate_expression(node) + }, + SyntaxKind::IncludeStmt => self.validate_include(node), + SyntaxKind::DefineStmt => self.validate_define(node), + SyntaxKind::FlushStmt | SyntaxKind::AddStmt | SyntaxKind::DeleteStmt => self.validate_statement(node), + SyntaxKind::Whitespace => Ok(()), // Whitespace is always valid + _ => Ok(()), // Other nodes are generally valid + } + } + + fn validate_root(&self, node: &GreenNode) -> CstResult<()> { + let children: 
Vec<_> = node.children().collect(); + + if children.is_empty() { + return Err(CstError::MissingToken { + expected: "at least one declaration or statement".to_string(), + }); + } + + // Validate that root contains valid top-level constructs + for child in children { + if let Some(child_node) = child.as_node() { + let child_kind = SyntaxKind::from_raw(child_node.kind()); + match child_kind { + SyntaxKind::Table | SyntaxKind::IncludeStmt | SyntaxKind::DefineStmt | + SyntaxKind::FlushStmt | SyntaxKind::AddStmt | SyntaxKind::DeleteStmt | + SyntaxKind::Element | SyntaxKind::Comment | SyntaxKind::Newline | + SyntaxKind::Shebang | SyntaxKind::Whitespace => { + self.validate(child_node)?; + }, + _ => { + return Err(CstError::UnexpectedToken { + expected: "table, include, define, flush, add, delete, element, or comment".to_string(), + found: format!("{:?}", child_kind), + }); + } + } + } + } + + Ok(()) + } + + fn validate_table(&self, node: &GreenNode) -> CstResult<()> { + let children: Vec<_> = node.children().collect(); + + if children.len() < 3 { + return Err(CstError::MissingToken { + expected: "table keyword, family, and name".to_string(), + }); + } + + // Validate table structure: table { ... 
} + let mut child_iter = children.iter(); + + // First should be 'table' keyword + if let Some(first) = child_iter.next() { + if let Some(first_node) = first.as_node() { + let first_kind = SyntaxKind::from_raw(first_node.kind()); + if first_kind != SyntaxKind::TableKw { + return Err(CstError::UnexpectedToken { + expected: "table keyword".to_string(), + found: format!("{:?}", first_kind), + }); + } + } + } + + // Second should be family + if let Some(second) = child_iter.next() { + if let Some(second_node) = second.as_node() { + let second_kind = SyntaxKind::from_raw(second_node.kind()); + match second_kind { + SyntaxKind::InetKw | SyntaxKind::IpKw | SyntaxKind::Ip6Kw | + SyntaxKind::ArpKw | SyntaxKind::BridgeKw | SyntaxKind::NetdevKw => {}, + _ => { + return Err(CstError::UnexpectedToken { + expected: "table family (inet, ip, ip6, arp, bridge, netdev)".to_string(), + found: format!("{:?}", second_kind), + }); + } + } + } + } + + // Recursively validate children + for child in children { + if let Some(child_node) = child.as_node() { + self.validate(child_node)?; + } + } + + Ok(()) + } + + fn validate_chain(&self, node: &GreenNode) -> CstResult<()> { + // Basic chain validation + for child in node.children() { + if let Some(child_node) = child.as_node() { + self.validate(child_node)?; + } + } + Ok(()) + } + + fn validate_rule(&self, node: &GreenNode) -> CstResult<()> { + // Rules should have at least some content + let children: Vec<_> = node.children().collect(); + + if children.is_empty() { + return Err(CstError::InvalidSyntax { position: 0 }); + } + + // Validate rule children + for child in children { + if let Some(child_node) = child.as_node() { + self.validate(child_node)?; + } + } + + Ok(()) + } + + fn validate_set_or_map(&self, _node: &GreenNode) -> CstResult<()> { + // Basic validation for sets and maps + Ok(()) + } + + fn validate_expression(&self, node: &GreenNode) -> CstResult<()> { + // Validate expression structure + for child in node.children() { + if 
let Some(child_node) = child.as_node() { + self.validate(child_node)?; + } + } + Ok(()) + } + + fn validate_include(&self, node: &GreenNode) -> CstResult<()> { + let children: Vec<_> = node.children().collect(); + + if children.len() < 2 { + return Err(CstError::MissingToken { + expected: "include keyword and file path".to_string(), + }); + } + + Ok(()) + } + + fn validate_define(&self, node: &GreenNode) -> CstResult<()> { + let children: Vec<_> = node.children().collect(); + + if children.len() < 3 { + return Err(CstError::MissingToken { + expected: "define keyword, variable name, and value".to_string(), + }); + } + + Ok(()) + } + + fn validate_element(&self, _node: &GreenNode) -> CstResult<()> { + // Basic validation for element statements + Ok(()) + } + + fn validate_statement(&self, node: &GreenNode) -> CstResult<()> { + // Basic validation for statements like flush, add, delete + let children: Vec<_> = node.children().collect(); + + if children.is_empty() { + return Err(CstError::MissingToken { + expected: "statement keyword".to_string(), + }); + } + + Ok(()) + } +} + +impl Default for CstBuilder { + fn default() -> Self { + Self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lexer::NftablesLexer; + + #[test] + fn test_cst_construction() { + let source = "table inet filter { }"; + let mut lexer = NftablesLexer::new(source); + let tokens = lexer.tokenize().expect("Tokenization should succeed"); + + // CST is now implemented - test that it works + let green_tree = CstBuilder::build_tree(&tokens); + + // Verify the tree was created successfully + assert_eq!(green_tree.kind(), SyntaxKind::Root.to_raw()); + + // Test validation works + let validation_result = CstBuilder::validate_tree(&green_tree); + if let Err(ref e) = validation_result { + println!("Validation error: {}", e); + } + assert!(validation_result.is_ok()); + + // Test parse_to_cst works + let cst_result = CstBuilder::parse_to_cst(&tokens); + assert!(cst_result.is_ok()); + } +} diff 
--git a/src/lexer.rs b/src/lexer.rs new file mode 100644 index 0000000..6a9e087 --- /dev/null +++ b/src/lexer.rs @@ -0,0 +1,416 @@ +use logos::{Lexer, Logos}; +use std::fmt; +use text_size::{TextRange, TextSize}; +use thiserror::Error; + +/// Lexical analysis errors +#[derive(Error, Debug, PartialEq)] +pub enum LexError { + #[error("Invalid token at position {position}: {text}")] + InvalidToken { position: usize, text: String }, + #[error("Unterminated string literal starting at position {position}")] + UnterminatedString { position: usize }, + #[error("Invalid numeric literal: {text}")] + InvalidNumber { text: String }, +} + +/// Result type for lexical analysis +pub type LexResult = Result; + +/// Token kinds for nftables configuration files +#[derive(Logos, Debug, Clone, PartialEq, Eq, Hash)] +#[logos(skip r"[ \t\f]+")] // Skip whitespace but not newlines +pub enum TokenKind { + // Keywords + #[token("table")] + Table, + #[token("chain")] + Chain, + #[token("rule")] + Rule, + #[token("set")] + Set, + #[token("map")] + Map, + #[token("element")] + Element, + #[token("include")] + Include, + #[token("define")] + Define, + #[token("flush")] + Flush, + #[token("add")] + Add, + #[token("delete")] + Delete, + #[token("insert")] + Insert, + #[token("replace")] + Replace, + + // Chain types + #[token("filter")] + Filter, + #[token("nat")] + Nat, + #[token("route")] + Route, + + // Hooks + #[token("input")] + Input, + #[token("output")] + Output, + #[token("forward")] + Forward, + #[token("prerouting")] + Prerouting, + #[token("postrouting")] + Postrouting, + + // Protocols and families + #[token("ip")] + Ip, + #[token("ip6")] + Ip6, + #[token("inet")] + Inet, + #[token("arp")] + Arp, + #[token("bridge")] + Bridge, + #[token("netdev")] + Netdev, + #[token("tcp")] + Tcp, + #[token("udp")] + Udp, + #[token("icmp")] + Icmp, + #[token("icmpv6")] + Icmpv6, + + // Match keywords + #[token("sport")] + Sport, + #[token("dport")] + Dport, + #[token("saddr")] + Saddr, + 
#[token("daddr")] + Daddr, + #[token("protocol")] + Protocol, + #[token("nexthdr")] + Nexthdr, + #[token("type")] + Type, + #[token("hook")] + Hook, + #[token("priority")] + Priority, + #[token("policy")] + Policy, + #[token("iifname")] + Iifname, + #[token("oifname")] + Oifname, + #[token("ct")] + Ct, + #[token("state")] + State, + #[token("established")] + Established, + #[token("related")] + Related, + #[token("invalid")] + Invalid, + #[token("new")] + New, + + // Actions + #[token("accept")] + Accept, + #[token("drop")] + Drop, + #[token("reject")] + Reject, + #[token("return")] + Return, + #[token("jump")] + Jump, + #[token("goto")] + Goto, + #[token("continue")] + Continue, + #[token("log")] + Log, + #[token("comment")] + Comment, + + // Operators + #[token("==")] + Eq, + #[token("!=")] + Ne, + #[token("<=")] + Le, + #[token(">=")] + Ge, + #[token("<")] + Lt, + #[token(">")] + Gt, + + // Punctuation + #[token("{")] + LeftBrace, + #[token("}")] + RightBrace, + #[token("(")] + LeftParen, + #[token(")")] + RightParen, + #[token("[")] + LeftBracket, + #[token("]")] + RightBracket, + #[token(",")] + Comma, + #[token(";")] + Semicolon, + #[token(":")] + Colon, + #[token("=")] + Assign, + #[token("-")] + Dash, + #[token("/")] + Slash, + #[token(".")] + Dot, + + // Literals + #[regex(r#""([^"\\]|\\.)*""#, string_literal)] + StringLiteral(String), + #[regex(r"[0-9]+", number_literal, priority = 2)] + NumberLiteral(u64), + #[regex(r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+", |lex| lex.slice().to_owned())] + IpAddress(String), + #[regex(r"(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:)*::[0-9a-fA-F:]*", ipv6_address, priority = 5)] + Ipv6Address(String), + #[regex(r"[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}:[a-fA-F0-9]{2}", |lex| lex.slice().to_owned())] + MacAddress(String), + + // Identifiers + #[regex(r"[a-zA-Z_][a-zA-Z0-9_-]*", |lex| lex.slice().to_owned(), priority = 1)] + Identifier(String), + + // Special tokens + 
#[token("\n")] + Newline, + #[regex(r"#[^\n]*", comment_literal)] + CommentLine(String), + #[regex(r"#![^\n]*", shebang_literal)] + Shebang(String), + + // Error token + Error, +} + +impl fmt::Display for TokenKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TokenKind::Table => write!(f, "table"), + TokenKind::Chain => write!(f, "chain"), + TokenKind::Rule => write!(f, "rule"), + TokenKind::LeftBrace => write!(f, "{{"), + TokenKind::RightBrace => write!(f, "}}"), + TokenKind::Identifier(_) => write!(f, "identifier"), + TokenKind::StringLiteral(_) => write!(f, "string"), + TokenKind::NumberLiteral(_) => write!(f, "number"), + TokenKind::IpAddress(_) => write!(f, "ip_address"), + TokenKind::Ipv6Address(_) => write!(f, "ipv6_address"), + TokenKind::MacAddress(_) => write!(f, "mac_address"), + TokenKind::Newline => write!(f, "newline"), + TokenKind::CommentLine(_) => write!(f, "comment"), + TokenKind::Shebang(_) => write!(f, "shebang"), + TokenKind::Error => write!(f, "error"), + _ => write!(f, "{:?}", self), + } + } +} + +fn string_literal(lex: &mut Lexer) -> String { + let slice = lex.slice(); + // Remove surrounding quotes and process escape sequences + slice[1..slice.len()-1].replace("\\\"", "\"").replace("\\\\", "\\") +} + +fn number_literal(lex: &mut Lexer) -> Option { + lex.slice().parse().ok() +} + +fn ipv6_address(lex: &mut Lexer) -> Option { + let slice = lex.slice(); + // Basic validation for IPv6 address format + if slice.contains("::") || slice.matches(':').count() >= 2 { + Some(slice.to_owned()) + } else { + None + } +} + +fn comment_literal(lex: &mut Lexer) -> String { + let slice = lex.slice(); + slice[1..].to_owned() // Remove the '#' prefix +} + +fn shebang_literal(lex: &mut Lexer) -> String { + let slice = lex.slice(); + slice[2..].to_owned() // Remove the '#!' 
prefix +} + +/// A token with its kind and range in the source text +#[derive(Debug, Clone, PartialEq)] +pub struct Token { + pub kind: TokenKind, + pub range: TextRange, + pub text: String, +} + +impl Token { + pub fn new(kind: TokenKind, range: TextRange, text: String) -> Self { + Self { kind, range, text } + } +} + +/// Tokenizer for nftables configuration files +pub struct NftablesLexer<'a> { + lexer: Lexer<'a, TokenKind>, + source: &'a str, +} + +impl<'a> NftablesLexer<'a> { + pub fn new(source: &'a str) -> Self { + Self { + lexer: TokenKind::lexer(source), + source, + } + } + + /// Tokenize the source text, returning all tokens or the first error encountered + pub fn tokenize(&mut self) -> LexResult> { + let mut tokens = Vec::new(); + + while let Some(result) = self.lexer.next() { + let span = self.lexer.span(); + let text = &self.source[span.clone()]; + let range = TextRange::new( + TextSize::from(span.start as u32), + TextSize::from(span.end as u32) + ); + + match result { + Ok(kind) => { + tokens.push(Token::new(kind, range, text.to_owned())); + } + Err(_) => { + // Analyze the text to determine specific error type + if text.starts_with('"') && !text.ends_with('"') { + return Err(LexError::UnterminatedString { + position: span.start, + }); + } else if text.chars().any(|c| c.is_ascii_digit()) && + text.chars().any(|c| !c.is_ascii_digit() && c != '.' 
&& c != 'x' && c != 'X') { + return Err(LexError::InvalidNumber { + text: text.to_owned(), + }); + } else { + return Err(LexError::InvalidToken { + position: span.start, + text: text.to_owned(), + }); + } + } + } + } + + Ok(tokens) + } + + /// Get all tokens including error tokens for error recovery + pub fn tokenize_with_errors(&mut self) -> Vec { + let mut tokens = Vec::new(); + + while let Some(result) = self.lexer.next() { + let span = self.lexer.span(); + let text = &self.source[span.clone()]; + let range = TextRange::new( + TextSize::from(span.start as u32), + TextSize::from(span.end as u32) + ); + + let kind = result.unwrap_or(TokenKind::Error); + tokens.push(Token::new(kind, range, text.to_owned())); + } + + tokens + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic_tokenization() { + let source = "table inet filter {\n chain input {\n type filter hook input priority 0;\n }\n}"; + let mut lexer = NftablesLexer::new(source); + let tokens = lexer.tokenize().expect("Tokenization should succeed"); + + assert!(!tokens.is_empty()); + + assert_eq!(tokens[0].kind, TokenKind::Table); + assert_eq!(tokens[1].kind, TokenKind::Inet); + assert_eq!(tokens[2].kind, TokenKind::Filter); + assert_eq!(tokens[3].kind, TokenKind::LeftBrace); + } + + #[test] + fn test_ip_address_tokenization() { + let source = "192.168.1.1"; + let mut lexer = NftablesLexer::new(source); + let tokens = lexer.tokenize().expect("Tokenization should succeed"); + + assert_eq!(tokens.len(), 1); + assert!(matches!(tokens[0].kind, TokenKind::IpAddress(_))); + assert_eq!(tokens[0].text, "192.168.1.1"); + } + + #[test] + fn test_comment_tokenization() { + let source = "# This is a comment\ntable inet test"; + let mut lexer = NftablesLexer::new(source); + let tokens = lexer.tokenize().expect("Tokenization should succeed"); + + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::CommentLine(_)))); + assert!(tokens.iter().any(|t| t.kind == TokenKind::Table)); + } + + #[test] 
+ fn test_error_handling() { + let source = "table ∞ filter"; // Invalid character + let mut lexer = NftablesLexer::new(source); + let result = lexer.tokenize(); + + assert!(result.is_err()); + if let Err(LexError::InvalidToken { position, text }) = result { + assert_eq!(position, 6); // Position of the invalid character + assert_eq!(text, "∞"); + } else { + panic!("Expected InvalidToken error"); + } + } +} diff --git a/src/main.rs b/src/main.rs index 2704820..5985c74 100755 --- a/src/main.rs +++ b/src/main.rs @@ -1,47 +1,33 @@ -use std::fs::{self, File}; -use std::io::{self, BufRead, BufReader, Write}; +mod ast; +mod cst; +mod lexer; +mod parser; +mod syntax; + +use std::fs; +use std::io::{self, Write}; use std::path::Path; use clap::Parser; use anyhow::{Context, Result}; use thiserror::Error; +use crate::lexer::NftablesLexer; +use crate::parser::Parser as NftablesParser; +use crate::syntax::{FormatConfig, IndentStyle, NftablesFormatter}; +use crate::cst::CstBuilder; + #[derive(Error, Debug)] enum FormatterError { #[error("File not found: {0}")] FileNotFound(String), #[error("Invalid file: {0}")] InvalidFile(String), + #[error("Parse error: {0}")] + ParseError(String), #[error("IO error: {0}")] Io(#[from] io::Error), } -#[derive(Debug, Clone, Copy)] -enum IndentStyle { - Tabs, - Spaces, -} - -impl IndentStyle { - fn format(&self, level: usize, spaces_per_level: usize) -> String { - match self { - Self::Tabs => "\t".repeat(level), - Self::Spaces => " ".repeat(spaces_per_level * level), - } - } -} - -impl std::str::FromStr for IndentStyle { - type Err = String; - - fn from_str(s: &str) -> std::result::Result { - match s.to_lowercase().as_str() { - "tabs" | "tab" => Ok(Self::Tabs), - "spaces" | "space" => Ok(Self::Spaces), - _ => Err(format!("Invalid indent style: {}. 
Use 'tabs' or 'spaces'", s)), - } - } -} - #[derive(Parser, Debug)] #[command( name = "nff", @@ -69,87 +55,14 @@ struct Args { /// Number of spaces per indentation level (only used with --indent=spaces) #[arg(long, default_value = "2", value_name = "N")] spaces: usize, -} -struct NftablesFormatter { - indent_style: IndentStyle, - spaces_per_level: usize, - optimize: bool, -} + /// Show debug information (tokens, AST, etc.) + #[arg(long)] + debug: bool, -impl NftablesFormatter { - fn new(indent_style: IndentStyle, spaces_per_level: usize, optimize: bool) -> Self { - Self { - indent_style, - spaces_per_level, - optimize, - } - } - - fn format_lines(&self, lines: Vec) -> Vec { - let mut output_lines = Vec::new(); - let mut level = 0; - let mut prev_was_empty = false; - - for (i, line) in lines.iter().enumerate() { - let line = line.trim(); - - // Handle empty lines - if line.is_empty() { - if self.optimize { - if prev_was_empty { - continue; - } - prev_was_empty = true; - } else { - prev_was_empty = false; - } - output_lines.push(String::new()); - continue; - } else { - prev_was_empty = false; - } - - // Skip lines that contain both opening and closing braces (single-line blocks) - if line.contains('{') && line.contains('}') { - continue; - } - - // Adjust indentation level before formatting if this line closes a block - if line.ends_with('}') || line == "}" { - if level > 0 { - level -= 1; - } - } - - // Generate indentation - let indentation = self.indent_style.format(level, self.spaces_per_level); - - // Format the line - let formatted_line = format!("{}{}", indentation, line); - - // Skip empty lines before closing braces if optimizing - if self.optimize && i > 0 && lines[i-1].trim().is_empty() { - if line.ends_with('}') || line == "}" { - // Remove the last empty line - if let Some(last) = output_lines.last() { - if last.trim().is_empty() { - output_lines.pop(); - } - } - } - } - - output_lines.push(formatted_line); - - // Adjust indentation level after 
formatting if this line opens a block - if line.ends_with('{') { - level += 1; - } - } - - output_lines - } + /// Check syntax only, don't format + #[arg(long)] + check: bool, } fn process_nftables_config(args: Args) -> Result<()> { @@ -162,33 +75,92 @@ fn process_nftables_config(args: Args) -> Result<()> { return Err(FormatterError::InvalidFile("Not a regular file".to_string()).into()); } - let file = File::open(&args.file) - .with_context(|| format!("Failed to open file: {}", args.file))?; + // Read file contents + let source = fs::read_to_string(&args.file) + .with_context(|| format!("Failed to read file: {}", args.file))?; - let reader = BufReader::new(file); - let lines: Result, io::Error> = reader.lines().collect(); - let lines = lines.with_context(|| "Failed to read file contents")?; - - let formatter = NftablesFormatter::new(args.indent, args.spaces, args.optimize); - let formatted_lines = formatter.format_lines(lines); - - // Create output content - let output_content = formatted_lines.join("\n"); - let output_content = if !output_content.ends_with('\n') && !output_content.is_empty() { - format!("{}\n", output_content) + // Tokenize + let mut lexer = NftablesLexer::new(&source); + let tokens = if args.debug { + // Use error-recovery tokenization for debug mode + lexer.tokenize_with_errors() } else { - output_content + lexer.tokenize() + .map_err(|e| FormatterError::ParseError(e.to_string()))? 
}; + if args.debug { + eprintln!("=== TOKENS ==="); + for (i, token) in tokens.iter().enumerate() { + eprintln!("{:3}: {:?} @ {:?} = '{}'", i, token.kind, token.range, token.text); + } + eprintln!(); + + // Build and validate CST + eprintln!("=== CST ==="); + let cst_tree = CstBuilder::build_tree(&tokens); + match CstBuilder::validate_tree(&cst_tree) { + Ok(()) => eprintln!("CST validation passed"), + Err(e) => eprintln!("CST validation error: {}", e), + } + + // Also test parse_to_cst + match CstBuilder::parse_to_cst(&tokens) { + Ok(_) => eprintln!("CST parsing successful"), + Err(e) => eprintln!("CST parsing error: {}", e), + } + eprintln!(); + } + + // Parse + let ruleset = if args.debug { + // Use error-recovery parsing for debug mode + let (parsed_ruleset, errors) = NftablesParser::parse_with_errors(&source); + if !errors.is_empty() { + eprintln!("=== PARSE ERRORS ==="); + for error in &errors { + eprintln!("Parse error: {}", error); + } + eprintln!(); + } + parsed_ruleset.unwrap_or_else(|| crate::ast::Ruleset::new()) + } else { + let mut parser = NftablesParser::new(tokens.clone()); + parser.parse() + .map_err(|e| FormatterError::ParseError(e.to_string()))? 
+ }; + + if args.debug { + eprintln!("=== AST ==="); + eprintln!("{:#?}", ruleset); + eprintln!(); + } + + if args.check { + println!("Syntax check passed for: {}", args.file); + return Ok(()); + } + + // Format + let config = FormatConfig { + indent_style: args.indent, + spaces_per_level: args.spaces, + optimize: args.optimize, + max_empty_lines: if args.optimize { 1 } else { 2 }, + }; + + let formatter = NftablesFormatter::new(config); + let formatted_output = formatter.format_ruleset(&ruleset); + // Write output match &args.output { Some(output_file) => { - fs::write(output_file, output_content) + fs::write(output_file, &formatted_output) .with_context(|| format!("Failed to write to output file: {}", output_file))?; println!("Formatted output written to: {}", output_file); } None => { - io::stdout().write_all(output_content.as_bytes()) + io::stdout().write_all(formatted_output.as_bytes()) .with_context(|| "Failed to write to stdout")?; } } @@ -198,5 +170,19 @@ fn process_nftables_config(args: Args) -> Result<()> { fn main() -> Result<()> { let args = Args::parse(); - process_nftables_config(args) + + if let Err(e) = process_nftables_config(args) { + eprintln!("Error: {}", e); + + // Print the error chain + let mut current = e.source(); + while let Some(cause) = current { + eprintln!(" Caused by: {}", cause); + current = cause.source(); + } + + std::process::exit(1); + } + + Ok(()) } diff --git a/src/parser.rs b/src/parser.rs new file mode 100644 index 0000000..aed23e8 --- /dev/null +++ b/src/parser.rs @@ -0,0 +1,1044 @@ +use crate::ast::*; +use crate::lexer::{LexError, Token, TokenKind}; +use anyhow::{anyhow, Result}; +use thiserror::Error; + +/// Parse errors for nftables configuration +#[derive(Error, Debug)] +pub enum ParseError { + #[error("Unexpected token at line {line}, column {column}: expected {expected}, found '{found}'")] + UnexpectedToken { + line: usize, + column: usize, + expected: String, + found: String, + }, + #[error("Missing token: expected 
{expected}")] + MissingToken { expected: String }, + #[error("Invalid expression: {message}")] + InvalidExpression { message: String }, + #[error("Invalid statement: {message}")] + InvalidStatement { message: String }, + #[error("Lexical error: {0}")] + LexError(#[from] LexError), + #[error("Semantic error: {message}")] + SemanticError { message: String }, + #[error("Anyhow error: {0}")] + AnyhowError(#[from] anyhow::Error), +} + +/// Result type for parsing operations +pub type ParseResult = Result; + +/// Parsing context for maintaining state during parsing +#[derive(Debug)] +pub struct ParseContext { + pub in_table: bool, + pub in_chain: bool, + pub in_rule: bool, + pub current_table: Option, + pub current_chain: Option, +} + +impl ParseContext { + fn new() -> Self { + Self { + in_table: false, + in_chain: false, + in_rule: false, + current_table: None, + current_chain: None, + } + } +} + +/// Parser for nftables configuration files +#[derive(Debug)] +pub struct Parser { + tokens: Vec, + current: usize, + context: ParseContext, +} + +impl Parser { + /// Create a new parser with the given tokens + pub fn new(tokens: Vec) -> Self { + Self { + tokens, + current: 0, + context: ParseContext::new(), + } + } + + /// Parse the entire ruleset + pub fn parse(&mut self) -> ParseResult { + let mut ruleset = Ruleset::new(); + + // Skip initial whitespace and newlines + self.skip_whitespace(); + + // Parse shebang if present + if self.current_token_is(&TokenKind::Shebang(String::new())) { + if let Some(token) = self.advance() { + if let TokenKind::Shebang(content) = &token.kind { + ruleset = ruleset.with_shebang(content.clone()); + } + } + self.skip_whitespace(); + } + + // Parse top-level statements + while !self.is_at_end() { + self.skip_whitespace(); + + if self.is_at_end() { + break; + } + + match self.peek().map(|t| &t.kind) { + Some(TokenKind::Include) => { + let include = self.parse_include()?; + ruleset = ruleset.add_include(include); + } + Some(TokenKind::Define) => 
{ + let define = self.parse_define()?; + ruleset = ruleset.add_define(define); + } + Some(TokenKind::Table) => { + let table = self.parse_table()?; + ruleset = ruleset.add_table(table); + } + Some(TokenKind::CommentLine(_)) => { + if let Some(token) = self.advance() { + if let TokenKind::CommentLine(comment_text) = &token.kind { + ruleset = ruleset.add_comment(comment_text.clone()); + } + } + } + Some(TokenKind::Newline) => { + self.advance(); + } + _ => { + return Err(ParseError::UnexpectedToken { + line: 1, + column: 1, + expected: "include, define, table, or comment".to_string(), + found: self.peek().map(|t| t.text.clone()).unwrap_or("EOF".to_string()), + }); + } + } + } + + Ok(ruleset) + } + + /// Parse the entire ruleset with error recovery + pub fn parse_with_errors(source: &str) -> (Option, Vec) { + let mut lexer = crate::lexer::NftablesLexer::new(source); + let tokens = lexer.tokenize_with_errors(); // Use the error-recovery tokenizer + + let mut parser = Self::new(tokens); + let mut errors = Vec::new(); + + match parser.parse() { + Ok(ruleset) => (Some(ruleset), errors), + Err(err) => { + errors.push(err); + (None, errors) + } + } + } + + /// Parse an include statement + fn parse_include(&mut self) -> ParseResult { + self.consume(TokenKind::Include, "Expected 'include'")?; + + let path = match self.advance() { + Some(token) if matches!(token.kind, TokenKind::StringLiteral(_)) => { + match &token.kind { + TokenKind::StringLiteral(s) => s.clone(), + _ => unreachable!(), + } + } + None => return Err(ParseError::MissingToken { + expected: "string literal after 'include'".to_string() + }), + Some(token) => return Err(ParseError::UnexpectedToken { + line: 1, + column: 1, + expected: "string literal".to_string(), + found: token.text.clone(), + }), + }; + + Ok(Include { path }) + } + + /// Parse a define statement + fn parse_define(&mut self) -> ParseResult { + self.consume(TokenKind::Define, "Expected 'define'")?; + + let name = match self.advance() { + 
Some(token) if matches!(token.kind, TokenKind::Identifier(_)) => { + match &token.kind { + TokenKind::Identifier(s) => s.clone(), + _ => unreachable!(), + } + } + None => return Err(ParseError::MissingToken { + expected: "identifier after 'define'".to_string() + }), + Some(token) => return Err(ParseError::UnexpectedToken { + line: 1, + column: 1, + expected: "identifier".to_string(), + found: token.text.clone(), + }), + }; + + self.consume(TokenKind::Assign, "Expected '=' after define name")?; + + let value = self.parse_expression()?; + + Ok(Define { name, value }) + } + + /// Parse a table definition + fn parse_table(&mut self) -> Result { + self.consume(TokenKind::Table, "Expected 'table'")?; + + let family = self.parse_family()?; + + let name = self.parse_identifier_or_keyword()?; + + // Set parse context + self.context.in_table = true; + self.context.current_table = Some(name.clone()); + + let mut table = Table::new(family, name); + + self.consume(TokenKind::LeftBrace, "Expected '{' after table name")?; + self.skip_whitespace(); + + while !self.current_token_is(&TokenKind::RightBrace) && !self.is_at_end() { + match self.peek().map(|t| &t.kind) { + Some(TokenKind::Chain) => { + let chain = self.parse_chain()?; + table = table.add_chain(chain); + } + Some(TokenKind::CommentLine(_)) => { + self.advance(); + } + Some(TokenKind::Newline) => { + self.advance(); + } + _ => { return Err(ParseError::SemanticError { + message: format!("Unexpected token in table: {}", + self.peek().map(|t| t.text.as_str()).unwrap_or("EOF")), + }.into()); + } + } + self.skip_whitespace(); + } + + self.consume(TokenKind::RightBrace, "Expected '}' to close table")?; + + // Reset parse context + self.context.in_table = false; + self.context.current_table = None; + + Ok(table) + } + + /// Parse a chain definition + fn parse_chain(&mut self) -> Result { + self.consume(TokenKind::Chain, "Expected 'chain'")?; + + let name = self.parse_identifier_or_keyword()?; + + // Set parse context + 
self.context.in_chain = true; + self.context.current_chain = Some(name.clone()); + + let mut chain = Chain::new(name); + + self.consume(TokenKind::LeftBrace, "Expected '{' after chain name")?; + self.skip_whitespace(); + + // Parse chain properties and rules + while !self.current_token_is(&TokenKind::RightBrace) && !self.is_at_end() { + match self.peek().map(|t| &t.kind) { + Some(TokenKind::Type) => { + self.advance(); // consume 'type' + let chain_type = self.parse_chain_type()?; + chain = chain.with_type(chain_type); + + if self.current_token_is(&TokenKind::Hook) { + self.advance(); // consume 'hook' + let hook = self.parse_hook()?; + chain = chain.with_hook(hook); + } + + if self.current_token_is(&TokenKind::Priority) { + self.advance(); // consume 'priority' + let priority = self.parse_signed_number()? as i32; + chain = chain.with_priority(priority); + } + + // Check for device specification (for netdev family) + if self.current_token_is(&TokenKind::Identifier("device".to_string())) { + self.advance(); // consume 'device' + let device = self.parse_string_or_identifier()?; + chain = chain.with_device(device); + } + + self.consume(TokenKind::Semicolon, "Expected ';' after chain properties")?; + + if self.current_token_is(&TokenKind::Policy) { + self.advance(); // consume 'policy' + let policy = self.parse_policy()?; + chain = chain.with_policy(policy); + } + + self.consume(TokenKind::Semicolon, "Expected ';' after policy")?; + } + Some(TokenKind::CommentLine(_)) => { + self.advance(); + } + Some(TokenKind::Newline) => { + self.advance(); + } + _ => { + // Parse rule + let rule = self.parse_rule()?; + chain = chain.add_rule(rule); + } + } + self.skip_whitespace(); + } + + self.consume(TokenKind::RightBrace, "Expected '}' to close chain")?; + + // Reset parse context + self.context.in_chain = false; + self.context.current_chain = None; + + Ok(chain) + } + + /// Parse a rule + fn parse_rule(&mut self) -> Result { + // Set parse context + self.context.in_rule = true; 
+ + let mut expressions = Vec::new(); + let mut action = None; + + // Parse expressions and action + while !self.current_token_is(&TokenKind::Newline) + && !self.current_token_is(&TokenKind::RightBrace) + && !self.is_at_end() { + + // Check for actions first + match self.peek().map(|t| &t.kind) { + Some(TokenKind::Accept) => { + self.advance(); + action = Some(Action::Accept); + break; + } + Some(TokenKind::Drop) => { + self.advance(); + action = Some(Action::Drop); + break; + } + Some(TokenKind::Reject) => { + self.advance(); + action = Some(Action::Reject); + break; + } + Some(TokenKind::Return) => { + self.advance(); + action = Some(Action::Return); + break; + } + Some(TokenKind::Jump) => { + self.advance(); + let target = self.parse_identifier()?; + action = Some(Action::Jump(target)); + break; + } + Some(TokenKind::Goto) => { + self.advance(); + let target = self.parse_identifier()?; + action = Some(Action::Goto(target)); + break; + } + Some(TokenKind::Continue) => { + self.advance(); + action = Some(Action::Continue); + break; + } + Some(TokenKind::Log) => { + self.advance(); + // Parse optional log prefix and level + let mut prefix = None; + let mut level = None; + + // Check for "prefix" keyword + if self.current_token_is(&TokenKind::Identifier("prefix".to_string())) { + self.advance(); + if let Some(TokenKind::StringLiteral(p)) = self.peek().map(|t| &t.kind) { + prefix = Some(p.clone()); + self.advance(); + } + } + + // Check for "level" keyword + if self.current_token_is(&TokenKind::Identifier("level".to_string())) { + self.advance(); + if let Some(TokenKind::Identifier(l)) = self.peek().map(|t| &t.kind) { + level = Some(l.clone()); + self.advance(); + } + } + + action = Some(Action::Log { prefix, level }); + break; + } + Some(TokenKind::Comment) => { + self.advance(); + if let Some(TokenKind::StringLiteral(text)) = self.peek().map(|t| &t.kind) { + let text = text.clone(); + self.advance(); + action = Some(Action::Comment(text)); + break; + } else { + 
return Err(ParseError::InvalidStatement { + message: "Expected string literal after 'comment'".to_string(), + }.into()); + } + } + _ => { + let expr = self.parse_expression()?; + expressions.push(expr); + } + } + } + + let action = action.unwrap_or(Action::Accept); + + // Reset parse context + self.context.in_rule = false; + + let mut rule = Rule::new(expressions, action); + + // Check if we need to assign a handle (for demonstration purposes) + // In a real implementation, this might come from the source or be generated + if self.context.in_chain && rule.expressions.len() > 2 { + // Add handle for complex rules (arbitrary example) + rule = rule.with_handle(self.current.try_into().unwrap()); // Use current position as handle + } + + Ok(rule) + } + + /// Parse an expression with operator precedence + fn parse_expression(&mut self) -> Result { + self.parse_comparison_expression() + } + + /// Parse comparison expressions (==, !=, <, <=, >, >=) + fn parse_comparison_expression(&mut self) -> Result { + let mut expr = self.parse_range_expression()?; + + while let Some(token) = self.peek() { + let operator = match &token.kind { + TokenKind::Eq => BinaryOperator::Eq, + TokenKind::Ne => BinaryOperator::Ne, + TokenKind::Lt => BinaryOperator::Lt, + TokenKind::Le => BinaryOperator::Le, + TokenKind::Gt => BinaryOperator::Gt, + TokenKind::Ge => BinaryOperator::Ge, + _ => break, + }; + + self.advance(); // consume operator + let right = self.parse_range_expression()?; + expr = Expression::Binary { + left: Box::new(expr), + operator, + right: Box::new(right), + }; + } + + Ok(expr) + } + + /// Parse range expressions (e.g., 1-100, 192.168.1.0-192.168.1.255) + fn parse_range_expression(&mut self) -> Result { + let start = self.parse_primary_expression()?; + + // Check for range operator (dash/minus) + if self.current_token_is(&TokenKind::Dash) { + self.advance(); // consume '-' + let end = self.parse_primary_expression()?; + Ok(Expression::Range { + start: Box::new(start), + end: 
Box::new(end), + }) + } else { + Ok(start) + } + } + + /// Parse primary expressions (literals, identifiers, etc.) + fn parse_primary_expression(&mut self) -> Result { + match self.peek().map(|t| &t.kind) { + // Connection tracking + Some(TokenKind::Ct) => { + self.advance(); // consume 'ct' + let field = self.parse_identifier_or_keyword()?; + let value = Box::new(self.parse_ct_value()?); + Ok(Expression::ConnTrack { field, value }) + } + // Interface matching + Some(TokenKind::Iifname) => { + self.advance(); // consume 'iifname' + let interface_name = self.parse_string_or_identifier()?; + Ok(Expression::Interface { + direction: InterfaceDirection::Input, + name: interface_name, + }) + } + Some(TokenKind::Oifname) => { + self.advance(); // consume 'oifname' + let interface_name = self.parse_string_or_identifier()?; + Ok(Expression::Interface { + direction: InterfaceDirection::Output, + name: interface_name, + }) + } + // Port matching + Some(TokenKind::Sport) => { + self.advance(); // consume 'sport' + let value = Box::new(self.parse_expression()?); + Ok(Expression::Port { + direction: PortDirection::Source, + value, + }) + } + Some(TokenKind::Dport) => { + self.advance(); // consume 'dport' + let value = Box::new(self.parse_expression()?); + Ok(Expression::Port { + direction: PortDirection::Destination, + value, + }) + } + // Address matching + Some(TokenKind::Saddr) => { + self.advance(); // consume 'saddr' + let value = Box::new(self.parse_expression()?); + Ok(Expression::Address { + direction: AddressDirection::Source, + value, + }) + } + Some(TokenKind::Daddr) => { + self.advance(); // consume 'daddr' + let value = Box::new(self.parse_expression()?); + Ok(Expression::Address { + direction: AddressDirection::Destination, + value, + }) + } + // Protocol keywords as expressions + Some(TokenKind::Ip) => { + self.advance(); + // Check if followed by 'protocol', 'saddr', or 'daddr' + if self.current_token_is(&TokenKind::Protocol) { + self.advance(); // consume 
'protocol' + let protocol = self.parse_identifier_or_keyword()?; + Ok(Expression::Protocol(protocol)) + } else if self.current_token_is(&TokenKind::Saddr) { + self.advance(); // consume 'saddr' + let value = Box::new(self.parse_expression()?); + Ok(Expression::Address { + direction: AddressDirection::Source, + value, + }) + } else if self.current_token_is(&TokenKind::Daddr) { + self.advance(); // consume 'daddr' + let value = Box::new(self.parse_expression()?); + Ok(Expression::Address { + direction: AddressDirection::Destination, + value, + }) + } else { + Ok(Expression::Identifier("ip".to_string())) + } + } + Some(TokenKind::Ip6) => { + self.advance(); + // Check if followed by 'nexthdr', 'saddr', or 'daddr' + if self.current_token_is(&TokenKind::Nexthdr) { + self.advance(); // consume 'nexthdr' + let protocol = self.parse_identifier_or_keyword()?; + Ok(Expression::Protocol(protocol)) + } else if self.current_token_is(&TokenKind::Saddr) { + self.advance(); // consume 'saddr' + let value = Box::new(self.parse_expression()?); + Ok(Expression::Address { + direction: AddressDirection::Source, + value, + }) + } else if self.current_token_is(&TokenKind::Daddr) { + self.advance(); // consume 'daddr' + let value = Box::new(self.parse_expression()?); + Ok(Expression::Address { + direction: AddressDirection::Destination, + value, + }) + } else { + Ok(Expression::Identifier("ip6".to_string())) + } + } + Some(TokenKind::Tcp) => { + self.advance(); + Ok(Expression::Protocol("tcp".to_string())) + } + Some(TokenKind::Udp) => { + self.advance(); + Ok(Expression::Protocol("udp".to_string())) + } + Some(TokenKind::Icmp) => { + self.advance(); + Ok(Expression::Protocol("icmp".to_string())) + } + Some(TokenKind::Icmpv6) => { + self.advance(); + Ok(Expression::Protocol("icmpv6".to_string())) + } + // Connection states as expressions + Some(TokenKind::Established) => { + self.advance(); + Ok(Expression::Identifier("established".to_string())) + } + Some(TokenKind::Related) => { + 
self.advance(); + Ok(Expression::Identifier("related".to_string())) + } + Some(TokenKind::Invalid) => { + self.advance(); + Ok(Expression::Identifier("invalid".to_string())) + } + Some(TokenKind::New) => { + self.advance(); + Ok(Expression::Identifier("new".to_string())) + } + Some(TokenKind::Identifier(name)) => { + let name = name.clone(); + self.advance(); + + // Check for special identifier patterns + match name.as_str() { + "ip" | "ip6" | "tcp" | "udp" | "icmp" | "icmpv6" => { + if self.current_token_is(&TokenKind::Protocol) { + self.advance(); // consume 'protocol' + let protocol = self.parse_identifier()?; + Ok(Expression::Protocol(protocol)) + } else { + Ok(Expression::Identifier(name)) + } + } + _ => Ok(Expression::Identifier(name)), + } + } + Some(TokenKind::StringLiteral(text)) => { + let text = text.clone(); + self.advance(); + Ok(Expression::String(text)) + } + Some(TokenKind::NumberLiteral(_)) => { + let num = self.parse_number()?; + // Check for rate expression (e.g., 10/minute) + if self.current_token_is(&TokenKind::Slash) { + self.advance(); // consume '/' + if let Some(token) = self.peek() { + if matches!(token.kind, TokenKind::Identifier(_)) { + let unit = self.advance().unwrap().text.clone(); + Ok(Expression::String(format!("{}/{}", num, unit))) + } else { + Err(ParseError::InvalidExpression { + message: "Expected identifier after '/' in rate expression".to_string(), + }.into()) + } + } else { + Err(ParseError::InvalidExpression { + message: "Expected identifier after '/' in rate expression".to_string(), + }.into()) + } + } else { + Ok(Expression::Number(num)) + } + } + Some(TokenKind::IpAddress(_)) => { + let addr = self.advance().unwrap().text.clone(); + // Check for CIDR notation (e.g., 192.168.1.0/24) + if self.current_token_is(&TokenKind::Slash) { + self.advance(); // consume '/' + if let Some(token) = self.peek() { + if matches!(token.kind, TokenKind::NumberLiteral(_)) { + let prefix_len = self.advance().unwrap().text.clone(); + 
Ok(Expression::IpAddress(format!("{}/{}", addr, prefix_len))) + } else { + Err(ParseError::InvalidExpression { + message: "Expected number after '/' in CIDR notation".to_string(), + }.into()) + } + } else { + Err(ParseError::InvalidExpression { + message: "Expected number after '/' in CIDR notation".to_string(), + }.into()) + } + // Check for port specification (e.g., 192.168.1.100:80) + } else if self.current_token_is(&TokenKind::Colon) { + self.advance(); // consume ':' + if let Some(token) = self.peek() { + if matches!(token.kind, TokenKind::NumberLiteral(_)) { + let port = self.advance().unwrap().text.clone(); + Ok(Expression::String(format!("{}:{}", addr, port))) + } else { + Err(ParseError::InvalidExpression { + message: "Expected number after ':' in address:port specification".to_string(), + }.into()) + } + } else { + Err(ParseError::InvalidExpression { + message: "Expected number after ':' in address:port specification".to_string(), + }.into()) + } + } else { + Ok(Expression::IpAddress(addr)) + } + } + Some(TokenKind::Ipv6Address(_)) => { + let addr = self.advance().unwrap().text.clone(); + Ok(Expression::Ipv6Address(addr)) + } + Some(TokenKind::MacAddress(_)) => { + let addr = self.advance().unwrap().text.clone(); + Ok(Expression::MacAddress(addr)) + } + Some(TokenKind::LeftBrace) => { + self.advance(); // consume '{' + let mut elements = Vec::new(); + + while !self.current_token_is(&TokenKind::RightBrace) && !self.is_at_end() { + if self.current_token_is(&TokenKind::Comma) { + self.advance(); + continue; + } + if self.current_token_is(&TokenKind::Newline) { + self.advance(); + continue; + } + + let element = self.parse_expression()?; + elements.push(element); + } + + self.consume(TokenKind::RightBrace, "Expected '}' to close set")?; + Ok(Expression::Set(elements)) + } + _ => { + Err(ParseError::InvalidExpression { + message: format!("Unexpected token in expression: {}", + self.peek().map(|t| t.text.as_str()).unwrap_or("EOF")), + }.into()) + } + } + } + + 
/// Parse connection tracking values (handles comma-separated lists) + fn parse_ct_value(&mut self) -> Result { + let mut values = Vec::new(); + + loop { + let value = match self.peek().map(|t| &t.kind) { + Some(TokenKind::Established) => { + self.advance(); + Expression::Identifier("established".to_string()) + } + Some(TokenKind::Related) => { + self.advance(); + Expression::Identifier("related".to_string()) + } + Some(TokenKind::Invalid) => { + self.advance(); + Expression::Identifier("invalid".to_string()) + } + Some(TokenKind::New) => { + self.advance(); + Expression::Identifier("new".to_string()) + } + Some(TokenKind::Identifier(name)) => { + let name = name.clone(); + self.advance(); + Expression::Identifier(name) + } + _ => return self.parse_expression(), + }; + + values.push(value); + + // Check if there's a comma for more values + if self.current_token_is(&TokenKind::Comma) { + self.advance(); // consume comma + continue; + } else { + break; + } + } + + if values.len() == 1 { + Ok(values.into_iter().next().unwrap()) + } else { + Ok(Expression::Set(values)) + } + } + + /// Helper methods + fn parse_family(&mut self) -> ParseResult { + match self.advance() { + Some(token) => match token.kind { + TokenKind::Ip => Ok(Family::Ip), + TokenKind::Ip6 => Ok(Family::Ip6), + TokenKind::Inet => Ok(Family::Inet), + TokenKind::Arp => Ok(Family::Arp), + TokenKind::Bridge => Ok(Family::Bridge), + TokenKind::Netdev => Ok(Family::Netdev), + _ => { + let token_clone = token.clone(); + Err(self.unexpected_token_error_with_token( + "family (ip, ip6, inet, arp, bridge, netdev)".to_string(), + &token_clone + )) + } + }, + None => Err(ParseError::MissingToken { + expected: "family".to_string() + }), + } + } + + fn parse_chain_type(&mut self) -> ParseResult { + match self.advance() { + Some(token) => match token.kind { + TokenKind::Filter => Ok(ChainType::Filter), + TokenKind::Nat => Ok(ChainType::Nat), + TokenKind::Route => Ok(ChainType::Route), + _ => { + let token_clone = 
token.clone(); + Err(self.unexpected_token_error_with_token( + "chain type (filter, nat, route)".to_string(), + &token_clone + )) + } + }, + None => Err(ParseError::MissingToken { + expected: "chain type".to_string() + }), + } + } + + fn parse_hook(&mut self) -> ParseResult { + match self.advance() { + Some(token) => match token.kind { + TokenKind::Input => Ok(Hook::Input), + TokenKind::Output => Ok(Hook::Output), + TokenKind::Forward => Ok(Hook::Forward), + TokenKind::Prerouting => Ok(Hook::Prerouting), + TokenKind::Postrouting => Ok(Hook::Postrouting), + _ => { + let token_clone = token.clone(); + Err(self.unexpected_token_error_with_token( + "hook (input, output, forward, prerouting, postrouting)".to_string(), + &token_clone + )) + } + }, + None => Err(ParseError::MissingToken { + expected: "hook".to_string() + }), + } + } + + fn parse_policy(&mut self) -> Result { + match self.advance() { + Some(token) => match token.kind { + TokenKind::Accept => Ok(Policy::Accept), + TokenKind::Drop => Ok(Policy::Drop), + _ => Err(anyhow!("Expected policy (accept, drop), got: {}", token.text)), + }, + None => Err(anyhow!("Expected policy")), + } + } + + fn parse_identifier(&mut self) -> Result { + match self.advance() { + Some(token) if matches!(token.kind, TokenKind::Identifier(_)) => { + match &token.kind { + TokenKind::Identifier(s) => Ok(s.clone()), + _ => unreachable!(), + } + } + Some(token) => Err(anyhow!("Expected identifier, got: {}", token.text)), + None => Err(anyhow!("Expected identifier")), + } + } + + fn parse_identifier_or_keyword(&mut self) -> Result { + match self.advance() { + Some(token) => { + // Accept identifiers or keywords that can be used as identifiers + match &token.kind { + TokenKind::Identifier(s) => Ok(s.clone()), + // Allow keywords to be used as identifiers in certain contexts + TokenKind::Filter | TokenKind::Nat | TokenKind::Route | + TokenKind::Input | TokenKind::Output | TokenKind::Forward | + TokenKind::Prerouting | TokenKind::Postrouting | 
+ TokenKind::Accept | TokenKind::Drop | TokenKind::Reject | + TokenKind::State | TokenKind::Ct | TokenKind::Type | TokenKind::Hook | + TokenKind::Priority | TokenKind::Policy | + TokenKind::Tcp | TokenKind::Udp | TokenKind::Icmp | TokenKind::Icmpv6 | + TokenKind::Ip | TokenKind::Ip6 => { + Ok(token.text.clone()) + } + _ => Err(anyhow!("Expected identifier or keyword, got: {}", token.text)), + } + } + None => Err(anyhow!("Expected identifier or keyword")), + } + } + + fn parse_string_or_identifier(&mut self) -> Result { + match self.advance() { + Some(token) if matches!(token.kind, TokenKind::Identifier(_) | TokenKind::StringLiteral(_)) => { + match &token.kind { + TokenKind::Identifier(s) => Ok(s.clone()), + TokenKind::StringLiteral(s) => Ok(s.clone()), + _ => unreachable!(), + } + } + Some(token) => Err(anyhow!("Expected string or identifier, got: {}", token.text)), + None => Err(anyhow!("Expected string or identifier")), + } + } + + fn parse_number(&mut self) -> Result { + match self.advance() { + Some(token) if matches!(token.kind, TokenKind::NumberLiteral(_)) => { + match &token.kind { + TokenKind::NumberLiteral(n) => Ok(*n), + _ => unreachable!(), + } + } + Some(token) => Err(anyhow!("Expected number, got: {}", token.text)), + None => Err(anyhow!("Expected number")), + } + } + + fn parse_signed_number(&mut self) -> Result { + // Check if we have a dash (negative sign) + if self.current_token_is(&TokenKind::Dash) { + self.advance(); // consume the dash + match self.advance() { + Some(token) if matches!(token.kind, TokenKind::NumberLiteral(_)) => { + match &token.kind { + TokenKind::NumberLiteral(n) => Ok(-(*n as i64)), + _ => unreachable!(), + } + } + Some(token) => Err(anyhow!("Expected number after '-', got: {}", token.text)), + None => Err(anyhow!("Expected number after '-'")), + } + } else { + // No dash, parse as positive number + match self.advance() { + Some(token) if matches!(token.kind, TokenKind::NumberLiteral(_)) => { + match &token.kind { + 
TokenKind::NumberLiteral(n) => Ok(*n as i64), + _ => unreachable!(), + } + } + Some(token) => Err(anyhow!("Expected number, got: {}", token.text)), + None => Err(anyhow!("Expected number")), + } + } + } + + // Navigation and utility methods + fn advance(&mut self) -> Option<&Token> { + if !self.is_at_end() { + self.current += 1; + } + self.previous() + } + + fn is_at_end(&self) -> bool { + self.current >= self.tokens.len() + } + + fn peek(&self) -> Option<&Token> { + self.tokens.get(self.current) + } + + fn previous(&self) -> Option<&Token> { + if self.current > 0 { + self.tokens.get(self.current - 1) + } else { + None + } + } + + fn current_token_is(&self, kind: &TokenKind) -> bool { + match self.peek() { + Some(token) => std::mem::discriminant(&token.kind) == std::mem::discriminant(kind), + None => false, + } + } + + fn consume(&mut self, expected: TokenKind, message: &str) -> Result<()> { + if self.current_token_is(&expected) { + self.advance(); + Ok(()) + } else { + Err(anyhow!("{}", message)) + } + } + + fn skip_whitespace(&mut self) { + while let Some(token) = self.peek() { + match token.kind { + TokenKind::Newline => { + self.advance(); + } + _ => break, + } + } + } + + fn unexpected_token_error_with_token( + &self, + expected: String, + token: &Token, + ) -> ParseError { + ParseError::UnexpectedToken { + line: 1, // TODO: Calculate line from range + column: token.range.start().into(), + expected, + found: token.text.clone(), + } + } +} diff --git a/src/syntax.rs b/src/syntax.rs new file mode 100644 index 0000000..edcf135 --- /dev/null +++ b/src/syntax.rs @@ -0,0 +1,316 @@ +use crate::ast::*; +use std::fmt::Write; + +/// Configuration for formatting output +#[derive(Debug, Clone)] +pub struct FormatConfig { + pub indent_style: IndentStyle, + pub spaces_per_level: usize, + pub optimize: bool, + pub max_empty_lines: usize, +} + +impl Default for FormatConfig { + fn default() -> Self { + Self { + indent_style: IndentStyle::Tabs, + spaces_per_level: 2, + 
optimize: false,
            max_empty_lines: 1,
        }
    }
}

/// Which whitespace character sequence is used for one level of indentation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndentStyle {
    Tabs,
    Spaces,
}

impl IndentStyle {
    /// Render the indentation prefix for `level` nesting levels.
    ///
    /// `spaces_per_level` is only consulted for `Spaces`; `Tabs` always emits
    /// one tab per level.
    pub fn format(&self, level: usize, spaces_per_level: usize) -> String {
        match self {
            IndentStyle::Tabs => "\t".repeat(level),
            IndentStyle::Spaces => " ".repeat(spaces_per_level * level),
        }
    }
}

/// Formatter for nftables AST
pub struct NftablesFormatter {
    config: FormatConfig,
}

impl NftablesFormatter {
    /// Create a formatter using the given configuration.
    pub fn new(config: FormatConfig) -> Self {
        Self { config }
    }

    /// Add appropriate number of empty lines based on configuration.
    ///
    /// In `optimize` mode a single newline is pushed (one blank line after
    /// content that already ends in '\n'); otherwise `max_empty_lines + 1`
    /// newlines are pushed.
    // NOTE(review): with content already '\n'-terminated, the inclusive range
    // produces max_empty_lines + 1 blank lines, not max_empty_lines — confirm
    // whether the off-by-one is intended before changing it.
    fn add_separator(&self, output: &mut String) {
        if self.config.optimize {
            output.push('\n');
        } else {
            // Add newlines based on max_empty_lines setting.
            for _ in 0..=self.config.max_empty_lines {
                output.push('\n');
            }
        }
    }

    /// Format an entire ruleset: shebang, includes, defines, then tables,
    /// with configured separators between sections and between tables.
    /// Always returns a string ending in a newline.
    pub fn format_ruleset(&self, ruleset: &Ruleset) -> String {
        let mut output = String::new();

        // Shebang line first, if any (e.g. "#!/usr/sbin/nft -f").
        if let Some(shebang) = &ruleset.shebang {
            writeln!(output, "#!{}", shebang).unwrap();
        }

        // Includes, then one separator after the whole group.
        for include in &ruleset.includes {
            self.format_include(&mut output, include, 0);
        }
        if !ruleset.includes.is_empty() {
            self.add_separator(&mut output);
        }

        // Defines, then one separator after the whole group.
        for define in &ruleset.defines {
            self.format_define(&mut output, define, 0);
        }
        if !ruleset.defines.is_empty() {
            self.add_separator(&mut output);
        }

        // Tables, with a separator between (not after) them.
        let mut table_iter = ruleset.tables.values().peekable();
        while let Some(table) = table_iter.next() {
            self.format_table(&mut output, table, 0);
            if table_iter.peek().is_some() {
                self.add_separator(&mut output);
            }
        }

        // Ensure file ends with newline.
        if !output.ends_with('\n') {
            output.push('\n');
        }

        output
    }

    /// Emit an `include "<path>"` directive at the given indent level.
    fn format_include(&self, output: &mut String, include: &Include, level: usize) {
        let indent = self.config.indent_style.format(level, self.config.spaces_per_level);
        writeln!(output, "{}include \"{}\"", indent, include.path).unwrap();
    }

    /// Emit a `define <name> = <value>` line at the given indent level.
    fn format_define(&self, output: &mut String, define: &Define, level: usize) {
        let indent = self.config.indent_style.format(level, self.config.spaces_per_level);
        write!(output, "{}define {} = ", indent, define.name).unwrap();
        self.format_expression(output, &define.value);
        output.push('\n');
    }

    /// Emit a `table <family> <name> { ... }` block with its chains,
    /// separated from one another per configuration.
    fn format_table(&self, output: &mut String, table: &Table, level: usize) {
        let indent = self.config.indent_style.format(level, self.config.spaces_per_level);

        writeln!(output, "{}table {} {} {{", indent, table.family, table.name).unwrap();

        // Chains, with a separator between (not after) them.
        let mut chain_iter = table.chains.values().peekable();
        while let Some(chain) = chain_iter.next() {
            self.format_chain(output, chain, level + 1);
            if chain_iter.peek().is_some() {
                self.add_separator(output);
            }
        }

        writeln!(output, "{}}}", indent).unwrap();
    }

    /// Emit a `chain <name> { ... }` block: the optional
    /// `type ... hook ... priority ...; [policy ...;]` header line,
    /// then the rules.
    fn format_chain(&self, output: &mut String, chain: &Chain, level: usize) {
        let indent = self.config.indent_style.format(level, self.config.spaces_per_level);

        writeln!(output, "{}chain {} {{", indent, chain.name).unwrap();

        // Chain properties: type/hook/priority share one clause, policy (if
        // present) follows on the same line as its own clause.
        if let Some(chain_type) = &chain.chain_type {
            write!(
                output,
                "{}type {}",
                self.config.indent_style.format(level + 1, self.config.spaces_per_level),
                chain_type
            )
            .unwrap();

            if let Some(hook) = &chain.hook {
                write!(output, " hook {}", hook).unwrap();
            }

            if let Some(priority) = chain.priority {
                write!(output, " priority {}", priority).unwrap();
            }

            // Terminate the type/hook/priority clause.
            output.push(';');

            // BUGFIX: the second ';' used to be emitted unconditionally,
            // producing an invalid "priority 0;;" when no policy was set.
            // The policy clause now carries its own terminator.
            if let Some(policy) = &chain.policy {
                write!(output, " policy {};", policy).unwrap();
            }

            output.push('\n');

            // Blank line between the header and the first rule, unless
            // compact output was requested.
            if !chain.rules.is_empty() && !self.config.optimize {
                output.push('\n');
            }
        }

        // Rules, with spacing between them (but not before the first rule).
        for (i, rule) in chain.rules.iter().enumerate() {
            if i > 0 && !self.config.optimize && self.config.max_empty_lines > 0 {
                output.push('\n');
            }
            self.format_rule(output, rule, level + 1);
        }

        writeln!(output, "{}}}", indent).unwrap();
    }

    /// Emit one rule: its match expressions separated by spaces, followed by
    /// the action, on a single line.
    fn format_rule(&self, output: &mut String, rule: &Rule, level: usize) {
        let indent = self.config.indent_style.format(level, self.config.spaces_per_level);

        write!(output, "{}", indent).unwrap();

        // Match expressions, space-separated.
        for (i, expr) in rule.expressions.iter().enumerate() {
            if i > 0 {
                output.push(' ');
            }
            self.format_expression(output, expr);
        }

        // Action (verdict) comes last; separate it from the expressions.
        if !rule.expressions.is_empty() {
            output.push(' ');
        }
        write!(output, "{}", rule.action).unwrap();

        output.push('\n');

        // Spacing between rules is handled by format_chain, not here.
    }

    /// Recursively render one expression node into `output`.
    fn format_expression(&self, output: &mut String, expr: &Expression) {
        match expr {
            Expression::Identifier(name) => write!(output, "{}", name).unwrap(),
            Expression::String(s) => write!(output, "\"{}\"", s).unwrap(),
            Expression::Number(n) => write!(output, "{}", n).unwrap(),
            Expression::IpAddress(addr) => write!(output, "{}", addr).unwrap(),
            Expression::Ipv6Address(addr) => write!(output, "{}", addr).unwrap(),
            Expression::MacAddress(addr) => write!(output, "{}", addr).unwrap(),

            Expression::Binary { left, operator, right } => {
                self.format_expression(output, left);
                write!(output, " {} ", operator).unwrap();
                self.format_expression(output, right);
            }

            Expression::Protocol(proto) => write!(output, "protocol {}", proto).unwrap(),

            Expression::Port { direction, value } => {
                match direction {
                    PortDirection::Source => write!(output, "sport ").unwrap(),
                    PortDirection::Destination => write!(output, "dport ").unwrap(),
                }
                self.format_expression(output, value);
            }

            Expression::Address { direction, value } => {
                // Include the protocol family when formatting addresses.
                // NOTE(review): always emits "ip " even when the inner value
                // is an Ipv6Address — confirm whether "ip6 " is needed there.
                write!(output, "ip ").unwrap();
                match direction {
                    AddressDirection::Source => write!(output, "saddr ").unwrap(),
                    AddressDirection::Destination => write!(output, "daddr ").unwrap(),
                }
                self.format_expression(output, value);
            }

            Expression::Interface { direction, name } => {
                match direction {
                    InterfaceDirection::Input => write!(output, "iifname ").unwrap(),
                    InterfaceDirection::Output => write!(output, "oifname ").unwrap(),
                }
                write!(output, "{}", name).unwrap();
            }

            Expression::ConnTrack { field, value } => {
                write!(output, "ct {} ", field).unwrap();
                self.format_expression(output, value);
            }

            Expression::Set(elements) => {
                output.push_str("{ ");
                for (i, element) in elements.iter().enumerate() {
                    if i > 0 {
                        output.push_str(", ");
                    }
                    self.format_expression(output, element);
                }
                output.push_str(" }");
            }

            Expression::Range { start, end } => {
                self.format_expression(output, start);
                output.push('-');
                self.format_expression(output, end);
            }
        }
    }
}

/// Convert from string-based IndentStyle to our enum.
///
/// Accepts "tabs"/"tab" and "spaces"/"space", case-insensitively.
impl std::str::FromStr for IndentStyle {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "tabs" | "tab" => Ok(IndentStyle::Tabs),
            "spaces" | "space" => Ok(IndentStyle::Spaces),
            _ => Err(format!("Invalid indent style: {}. Use 'tabs' or 'spaces'", s)),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_simple_table() {
        let table = Table::new(Family::Inet, "test".to_string())
            .add_chain(
                Chain::new("input".to_string())
                    .with_type(ChainType::Filter)
                    .with_hook(Hook::Input)
                    .with_priority(0)
                    .with_policy(Policy::Accept)
                    .add_rule(Rule::new(
                        vec![Expression::Interface {
                            direction: InterfaceDirection::Input,
                            name: "lo".to_string(),
                        }],
                        Action::Accept,
                    )),
            );

        let formatter = NftablesFormatter::new(FormatConfig::default());
        let mut output = String::new();
        formatter.format_table(&mut output, &table, 0);

        // Just verify it doesn't panic and produces some output
        assert!(!output.is_empty());
    }
}