From f7f77a7e19b73905b27e19b4e5efb5406cf91e7d Mon Sep 17 00:00:00 2001 From: NotAShelf Date: Mon, 2 Jun 2025 12:32:33 +0300 Subject: [PATCH 1/5] cst: expand SyntaxKind enum to include additional nftables syntax variants --- src/cst.rs | 115 +++++++++++++- src/main.rs | 448 ++++++++++++++++++++++------------------------------ 2 files changed, 304 insertions(+), 259 deletions(-) diff --git a/src/cst.rs b/src/cst.rs index 6feaa60..412b379 100644 --- a/src/cst.rs +++ b/src/cst.rs @@ -327,7 +327,120 @@ impl SyntaxKind { match raw.0 { 0 => SyntaxKind::Root, 1 => SyntaxKind::Table, - // ... other variants ... + 2 => SyntaxKind::Chain, + 3 => SyntaxKind::Rule, + 4 => SyntaxKind::Set, + 5 => SyntaxKind::Map, + 6 => SyntaxKind::Element, + 7 => SyntaxKind::Expression, + 8 => SyntaxKind::BinaryExpr, + 9 => SyntaxKind::UnaryExpr, + 10 => SyntaxKind::CallExpr, + 11 => SyntaxKind::SetExpr, + 12 => SyntaxKind::RangeExpr, + 13 => SyntaxKind::Statement, + 14 => SyntaxKind::IncludeStmt, + 15 => SyntaxKind::DefineStmt, + 16 => SyntaxKind::FlushStmt, + 17 => SyntaxKind::AddStmt, + 18 => SyntaxKind::DeleteStmt, + 19 => SyntaxKind::Identifier, + 20 => SyntaxKind::StringLiteral, + 21 => SyntaxKind::NumberLiteral, + 22 => SyntaxKind::IpAddress, + 23 => SyntaxKind::Ipv6Address, + 24 => SyntaxKind::MacAddress, + 25 => SyntaxKind::TableKw, + 26 => SyntaxKind::ChainKw, + 27 => SyntaxKind::RuleKw, + 28 => SyntaxKind::SetKw, + 29 => SyntaxKind::MapKw, + 30 => SyntaxKind::ElementKw, + 31 => SyntaxKind::IncludeKw, + 32 => SyntaxKind::DefineKw, + 33 => SyntaxKind::FlushKw, + 34 => SyntaxKind::AddKw, + 35 => SyntaxKind::DeleteKw, + 36 => SyntaxKind::InsertKw, + 37 => SyntaxKind::ReplaceKw, + 38 => SyntaxKind::FilterKw, + 39 => SyntaxKind::NatKw, + 40 => SyntaxKind::RouteKw, + 41 => SyntaxKind::InputKw, + 42 => SyntaxKind::OutputKw, + 43 => SyntaxKind::ForwardKw, + 44 => SyntaxKind::PreroutingKw, + 45 => SyntaxKind::PostroutingKw, + 46 => SyntaxKind::IpKw, + 47 => SyntaxKind::Ip6Kw, + 48 => 
SyntaxKind::InetKw, + 49 => SyntaxKind::ArpKw, + 50 => SyntaxKind::BridgeKw, + 51 => SyntaxKind::NetdevKw, + 52 => SyntaxKind::TcpKw, + 53 => SyntaxKind::UdpKw, + 54 => SyntaxKind::IcmpKw, + 55 => SyntaxKind::Icmpv6Kw, + 56 => SyntaxKind::SportKw, + 57 => SyntaxKind::DportKw, + 58 => SyntaxKind::SaddrKw, + 59 => SyntaxKind::DaddrKw, + 60 => SyntaxKind::ProtocolKw, + 61 => SyntaxKind::NexthdrKw, + 62 => SyntaxKind::TypeKw, + 63 => SyntaxKind::HookKw, + 64 => SyntaxKind::PriorityKw, + 65 => SyntaxKind::PolicyKw, + 66 => SyntaxKind::IifnameKw, + 67 => SyntaxKind::OifnameKw, + 68 => SyntaxKind::CtKw, + 69 => SyntaxKind::StateKw, + 70 => SyntaxKind::AcceptKw, + 71 => SyntaxKind::DropKw, + 72 => SyntaxKind::RejectKw, + 73 => SyntaxKind::ReturnKw, + 74 => SyntaxKind::JumpKw, + 75 => SyntaxKind::GotoKw, + 76 => SyntaxKind::ContinueKw, + 77 => SyntaxKind::LogKw, + 78 => SyntaxKind::CommentKw, + 79 => SyntaxKind::EstablishedKw, + 80 => SyntaxKind::RelatedKw, + 81 => SyntaxKind::NewKw, + 82 => SyntaxKind::InvalidKw, + 83 => SyntaxKind::EqOp, + 84 => SyntaxKind::NeOp, + 85 => SyntaxKind::LeOp, + 86 => SyntaxKind::GeOp, + 87 => SyntaxKind::LtOp, + 88 => SyntaxKind::GtOp, + 89 => SyntaxKind::LeftBrace, + 90 => SyntaxKind::RightBrace, + 91 => SyntaxKind::LeftParen, + 92 => SyntaxKind::RightParen, + 93 => SyntaxKind::LeftBracket, + 94 => SyntaxKind::RightBracket, + 95 => SyntaxKind::Comma, + 96 => SyntaxKind::Semicolon, + 97 => SyntaxKind::Colon, + 98 => SyntaxKind::Assign, + 99 => SyntaxKind::Dash, + 100 => SyntaxKind::Slash, + 101 => SyntaxKind::Dot, + 102 => SyntaxKind::Whitespace, + 103 => SyntaxKind::Newline, + 104 => SyntaxKind::Comment, + 105 => SyntaxKind::Shebang, + 106 => SyntaxKind::Error, + 107 => SyntaxKind::VmapKw, + 108 => SyntaxKind::NdRouterAdvertKw, + 109 => SyntaxKind::NdNeighborSolicitKw, + 110 => SyntaxKind::NdNeighborAdvertKw, + 111 => SyntaxKind::EchoRequestKw, + 112 => SyntaxKind::DestUnreachableKw, + 113 => SyntaxKind::RouterAdvertisementKw, + 114 => 
SyntaxKind::TimeExceededKw, + 115 => SyntaxKind::ParameterProblemKw, 116 => SyntaxKind::PacketTooBigKw, _ => SyntaxKind::Error, // Fallback to Error for invalid values } diff --git a/src/main.rs b/src/main.rs index 268c3d8..87a35f8 100755 --- a/src/main.rs +++ b/src/main.rs @@ -34,20 +34,6 @@ enum FormatterError { message: String, suggestion: Option, }, - #[error("Unsupported nftables syntax at line {line}, column {column}: {feature}")] - UnsupportedSyntax { - line: usize, - column: usize, - feature: String, - suggestion: Option, - }, - #[error("Invalid nftables syntax at line {line}, column {column}: {message}")] - InvalidSyntax { - line: usize, - column: usize, - message: String, - suggestion: Option, - }, #[error("IO error: {0}")] Io(#[from] io::Error), } @@ -289,7 +275,7 @@ fn process_single_file_format( let mut parser = NftablesParser::new(tokens.clone()); parser .parse() - .map_err(|e| analyze_parse_error(&source, &tokens, &e.to_string()))? + .map_err(|e| convert_parse_error_to_formatter_error(&e, &source, &tokens))? 
}; if debug { @@ -475,77 +461,203 @@ fn process_single_file_lint( Ok(()) } -/// Intelligent error analysis to categorize parse errors and provide location information -fn analyze_parse_error(source: &str, tokens: &[Token], error: &str) -> FormatterError { - // Convert line/column position from token ranges - let lines: Vec<&str> = source.lines().collect(); +/// Convert parser errors to formatter errors with proper location information +fn convert_parse_error_to_formatter_error( + error: &crate::parser::ParseError, + source: &str, + tokens: &[Token], +) -> FormatterError { + use crate::parser::ParseError; - // Look for common error patterns and provide specific messages - if error.contains("unexpected token") || error.contains("expected") { - // Try to find the problematic token - if let Some(error_token) = find_error_token(tokens) { - let (line, column) = position_from_range(&error_token.range, source); - - // Analyze the specific token to categorize the error - match categorize_syntax_error(&error_token, source, &lines) { - ErrorCategory::UnsupportedSyntax { - feature, - suggestion, - } => FormatterError::UnsupportedSyntax { - line, - column, - feature, - suggestion, - }, - ErrorCategory::InvalidSyntax { - message, - suggestion, - } => FormatterError::InvalidSyntax { - line, - column, - message, - suggestion, - }, - ErrorCategory::SyntaxError { - message, - suggestion, - } => FormatterError::SyntaxError { - line, - column, - message, - suggestion, - }, + match error { + ParseError::UnexpectedToken { + line, + column, + expected, + found, + } => FormatterError::SyntaxError { + line: *line, + column: *column, + message: format!("Expected {}, found '{}'", expected, found), + suggestion: None, + }, + ParseError::MissingToken { expected } => { + // Try to find current position from last token + let (line, column) = if let Some(last_token) = tokens.last() { + position_from_range(&last_token.range, source) + } else { + (1, 1) + }; + FormatterError::SyntaxError { + line, 
+ column, + message: format!("Missing token: expected {}", expected), + suggestion: None, + } + } + ParseError::InvalidExpression { message } => { + // Try to find the current token position + let (line, column) = find_current_parse_position(tokens, source); + FormatterError::SyntaxError { + line, + column, + message: format!("Invalid expression: {}", message), + suggestion: None, + } + } + ParseError::InvalidStatement { message } => { + let (line, column) = find_current_parse_position(tokens, source); + FormatterError::SyntaxError { + line, + column, + message: format!("Invalid statement: {}", message), + suggestion: None, + } + } + ParseError::SemanticError { message } => { + let (line, column) = find_current_parse_position(tokens, source); + FormatterError::SyntaxError { + line, + column, + message: format!("Semantic error: {}", message), + suggestion: None, + } + } + ParseError::LexError(lex_error) => { + // Convert lexical errors to formatter errors with location + convert_lex_error_to_formatter_error(lex_error, source) + } + ParseError::AnyhowError(anyhow_error) => { + // For anyhow errors, try to extract location from error message and context + let error_msg = anyhow_error.to_string(); + let (line, column) = find_error_location_from_context(&error_msg, tokens, source); + let suggestion = generate_suggestion_for_error(&error_msg); + + FormatterError::SyntaxError { + line, + column, + message: error_msg, + suggestion, } - } else { - // Fallback to generic parse error - FormatterError::ParseError(error.to_string()) } - } else { - FormatterError::ParseError(error.to_string()) } } -#[derive(Debug)] -enum ErrorCategory { - UnsupportedSyntax { - feature: String, - suggestion: Option, - }, - InvalidSyntax { - message: String, - suggestion: Option, - }, - SyntaxError { - message: String, - suggestion: Option, - }, +/// Find the current parsing position from tokens +fn find_current_parse_position(tokens: &[Token], source: &str) -> (usize, usize) { + // Look for the 
last non-whitespace, non-comment token + for token in tokens.iter().rev() { + match token.kind { + TokenKind::Newline | TokenKind::CommentLine(_) => continue, + _ => return position_from_range(&token.range, source), + } + } + (1, 1) // fallback } -/// Find the first error token in the token stream -fn find_error_token(tokens: &[Token]) -> Option<&Token> { - tokens - .iter() - .find(|token| matches!(token.kind, TokenKind::Error)) +/// Convert lexical errors to formatter errors +fn convert_lex_error_to_formatter_error( + lex_error: &crate::lexer::LexError, + source: &str, +) -> FormatterError { + use crate::lexer::LexError; + + match lex_error { + LexError::InvalidToken { position, text } => { + let (line, column) = offset_to_line_column(*position, source); + FormatterError::SyntaxError { + line, + column, + message: format!("Invalid token: '{}'", text), + suggestion: None, + } + } + LexError::UnterminatedString { position } => { + let (line, column) = offset_to_line_column(*position, source); + FormatterError::SyntaxError { + line, + column, + message: "Unterminated string literal".to_string(), + suggestion: Some("Add closing quote".to_string()), + } + } + LexError::InvalidNumber { position, text } => { + let (line, column) = offset_to_line_column(*position, source); + FormatterError::SyntaxError { + line, + column, + message: format!("Invalid number: '{}'", text), + suggestion: Some("Check number format".to_string()), + } + } + } +} + +/// Convert byte offset to line/column position +fn offset_to_line_column(offset: usize, source: &str) -> (usize, usize) { + let mut line = 1; + let mut column = 1; + + for (i, ch) in source.char_indices() { + if i >= offset { + break; + } + if ch == '\n' { + line += 1; + column = 1; + } else { + column += 1; + } + } + + (line, column) +} + +/// Find error location from context clues in the error message +fn find_error_location_from_context( + error_msg: &str, + tokens: &[Token], + source: &str, +) -> (usize, usize) { + // Look for 
context clues in the error message + if error_msg.contains("Expected string or identifier, got:") { + // Find the problematic token mentioned in the error + if let Some(bad_token_text) = extract_token_from_error_message(error_msg) { + // Find this token in the token stream + for token in tokens { + if token.text == bad_token_text { + return position_from_range(&token.range, source); + } + } + } + } + + // Fallback to finding last meaningful token + find_current_parse_position(tokens, source) +} + +/// Extract the problematic token from error message +fn extract_token_from_error_message(error_msg: &str) -> Option { + // Parse messages like "Expected string or identifier, got: {" + if let Some(got_part) = error_msg.split("got: ").nth(1) { + Some(got_part.trim().to_string()) + } else { + None + } +} + +/// Generate helpful suggestions based on error message +fn generate_suggestion_for_error(error_msg: &str) -> Option { + if error_msg.contains("Expected string or identifier") { + Some( + "Check if you're missing quotes around a string value or have an unexpected character" + .to_string(), + ) + } else if error_msg.contains("Expected") && error_msg.contains("got:") { + Some("Check syntax and ensure proper nftables structure".to_string()) + } else { + None + } } /// Convert TextRange to line/column position @@ -566,186 +678,6 @@ fn position_from_range(range: &text_size::TextRange, source: &str) -> (usize, us (1, 1) // fallback } -/// Categorize syntax errors based on token content and context -fn categorize_syntax_error(token: &Token, source: &str, lines: &[&str]) -> ErrorCategory { - let token_text = &token.text; - let (line_num, _) = position_from_range(&token.range, source); - let line_content = lines.get(line_num.saturating_sub(1)).unwrap_or(&""); - - // Check for unsupported nftables features - if is_unsupported_feature(token_text, line_content) { - let (feature, suggestion) = classify_unsupported_feature(token_text, line_content); - return 
ErrorCategory::UnsupportedSyntax { - feature, - suggestion, - }; - } - - // Check for invalid but supported syntax - if is_invalid_syntax(token_text, line_content) { - let (message, suggestion) = classify_invalid_syntax(token_text, line_content); - return ErrorCategory::InvalidSyntax { - message, - suggestion, - }; - } - - // Default to syntax error - ErrorCategory::SyntaxError { - message: format!("Unexpected token '{}'", token_text), - suggestion: suggest_correction(token_text, line_content), - } -} - -/// Check if the token represents an unsupported nftables feature -fn is_unsupported_feature(token_text: &str, line_content: &str) -> bool { - // List of advanced nftables features that might not be fully supported yet - let unsupported_keywords = [ - "quota", "limit", "counter", "meter", "socket", "fib", "rt", "ipsec", "tunnel", "comp", - "dccp", "sctp", "gre", "esp", "ah", "vlan", "arp", "rateest", "osf", "netdev", "meta", - "exthdr", "payload", "lookup", "dynset", "flow", "hash", "jhash", "symhash", "crc32", - ]; - - unsupported_keywords - .iter() - .any(|&keyword| token_text.contains(keyword) || line_content.contains(keyword)) -} - -/// Check if the syntax is invalid (malformed but within supported features) -fn is_invalid_syntax(token_text: &str, line_content: &str) -> bool { - // Check for common syntax mistakes - if token_text.contains("..") || token_text.contains("::") { - return true; // Double operators usually indicate mistakes - } - - // Check for malformed addresses or ranges - if token_text.contains("/") && !is_valid_cidr(token_text) { - return true; - } - - // Check for malformed brackets/braces - let open_braces = line_content.matches('{').count(); - let close_braces = line_content.matches('}').count(); - if open_braces != close_braces { - return true; - } - - false -} - -/// Classify unsupported feature and provide suggestion -fn classify_unsupported_feature(token_text: &str, line_content: &str) -> (String, Option) { - let feature = if 
token_text.contains("quota") { - ( - "quota management".to_string(), - Some("Use explicit rule counting instead".to_string()), - ) - } else if token_text.contains("limit") { - ( - "rate limiting".to_string(), - Some("Consider using simpler rule-based rate limiting".to_string()), - ) - } else if token_text.contains("counter") { - ( - "packet counters".to_string(), - Some("Use rule-level statistics instead".to_string()), - ) - } else if line_content.contains("meta") { - ( - "meta expressions".to_string(), - Some("Use explicit protocol matching instead".to_string()), - ) - } else { - (format!("advanced feature '{}'", token_text), None) - }; - - feature -} - -/// Classify invalid syntax and provide suggestion -fn classify_invalid_syntax(token_text: &str, line_content: &str) -> (String, Option) { - if token_text.contains("/") && !is_valid_cidr(token_text) { - return ( - "Invalid CIDR notation".to_string(), - Some("Use format like '192.168.1.0/24' or '::1/128'".to_string()), - ); - } - - if token_text.contains("..") { - return ( - "Invalid range operator".to_string(), - Some("Use '-' for ranges like '1000-2000'".to_string()), - ); - } - - if line_content.contains('{') && !line_content.contains('}') { - return ( - "Unmatched opening brace".to_string(), - Some("Ensure all '{' have matching '}'".to_string()), - ); - } - - ( - format!("Malformed token '{}'", token_text), - Some("Check nftables syntax documentation".to_string()), - ) -} - -/// Suggest correction for common typos -fn suggest_correction(token_text: &str, line_content: &str) -> Option { - // Common typos and their corrections - let corrections = [ - ("tabel", "table"), - ("cahin", "chain"), - ("accpet", "accept"), - ("rejct", "reject"), - ("prtocol", "protocol"), - ("addres", "address"), - ("pririty", "priority"), - ("poicy", "policy"), - ]; - - for (typo, correction) in &corrections { - if token_text.contains(typo) { - return Some(format!("Did you mean '{}'?", correction)); - } - } - - // Context-based 
suggestions - if line_content.contains("type") && line_content.contains("hook") { - if !line_content.contains("filter") - && !line_content.contains("nat") - && !line_content.contains("route") - { - return Some("Chain type should be 'filter', 'nat', or 'route'".to_string()); - } - } - - None -} - -/// Validate CIDR notation -fn is_valid_cidr(text: &str) -> bool { - if let Some(slash_pos) = text.find('/') { - let (addr, prefix) = text.split_at(slash_pos); - let prefix = &prefix[1..]; // Remove the '/' - - // Check if prefix is a valid number - if let Ok(prefix_len) = prefix.parse::() { - // Basic validation - IPv4 should be <= 32, IPv6 <= 128 - if addr.contains(':') { - prefix_len <= 128 // IPv6 - } else { - prefix_len <= 32 // IPv4 - } - } else { - false - } - } else { - false - } -} - fn main() -> Result<()> { let args = Args::parse(); From 0eaa21a3ff6e2a8e045acdf92f198168dd35cfd3 Mon Sep 17 00:00:00 2001 From: NotAShelf Date: Mon, 2 Jun 2025 13:46:57 +0300 Subject: [PATCH 2/5] diagnostic: update configuration options to use clap's `ArgAction` --- src/diagnostic.rs | 129 ++++++++++++++++++++++++++++++++++++++++++---- src/main.rs | 8 +-- 2 files changed, 122 insertions(+), 15 deletions(-) diff --git a/src/diagnostic.rs b/src/diagnostic.rs index c12b927..a3a46b0 100644 --- a/src/diagnostic.rs +++ b/src/diagnostic.rs @@ -607,6 +607,11 @@ impl AnalyzerModule for StyleAnalyzer { fn analyze(&self, source: &str, config: &DiagnosticConfig) -> Vec { let mut diagnostics = Vec::new(); + // Only perform style analysis if enabled + if !config.enable_style_warnings { + return diagnostics; + } + // Check for missing shebang if !source.starts_with("#!") { let range = Range::new(Position::new(0, 0), Position::new(0, 0)); @@ -802,14 +807,29 @@ impl StyleAnalyzer { pub struct SemanticAnalyzer; impl AnalyzerModule for SemanticAnalyzer { - fn analyze(&self, source: &str, _config: &DiagnosticConfig) -> Vec { + fn analyze(&self, source: &str, config: &DiagnosticConfig) -> Vec { let mut 
diagnostics = Vec::new(); - // Parse and validate nftables-specific constructs + // Always run semantic validation (syntax/semantic errors) diagnostics.extend(self.validate_table_declarations(source)); - diagnostics.extend(self.validate_chain_declarations(source)); + diagnostics.extend(self.validate_chain_declarations_semantic(source)); diagnostics.extend(self.validate_cidr_notation(source)); - diagnostics.extend(self.check_for_redundant_rules(source)); + + // Best practices checks (only if enabled) + if config.enable_best_practices { + diagnostics.extend(self.validate_chain_best_practices(source)); + diagnostics.extend(self.check_for_redundant_rules(source)); + } + + // Performance hints (only if enabled) + if config.enable_performance_hints { + diagnostics.extend(self.check_performance_hints(source)); + } + + // Security warnings (only if enabled) + if config.enable_security_warnings { + diagnostics.extend(self.check_security_warnings(source)); + } diagnostics } @@ -880,7 +900,7 @@ impl SemanticAnalyzer { diagnostics } - fn validate_chain_declarations(&self, source: &str) -> Vec { + fn validate_chain_declarations_semantic(&self, source: &str) -> Vec { let mut diagnostics = Vec::new(); for (line_idx, line) in source.lines().enumerate() { @@ -888,7 +908,7 @@ impl SemanticAnalyzer { let trimmed = line.trim(); if trimmed.starts_with("type ") && trimmed.contains("hook") { - // Validate chain type and hook + // Validate chain type and hook (semantic validation) if let Some(hook_pos) = trimmed.find("hook") { let hook_part = &trimmed[hook_pos..]; let hook_words: Vec<&str> = hook_part.split_whitespace().collect(); @@ -916,22 +936,109 @@ impl SemanticAnalyzer { } } } + } + } - // Check for missing policy in filter chains - if trimmed.contains("type filter") && !trimmed.contains("policy") { + diagnostics + } + + fn validate_chain_best_practices(&self, source: &str) -> Vec { + let mut diagnostics = Vec::new(); + + for (line_idx, line) in source.lines().enumerate() { + let 
line_num = line_idx as u32; + let trimmed = line.trim(); + + // Check for missing policy in filter chains (best practice) + if trimmed.contains("type filter") && !trimmed.contains("policy") { + let range = Range::new( + Position::new(line_num, 0), + Position::new(line_num, line.len() as u32), + ); + let diagnostic = Diagnostic::new( + range, + DiagnosticSeverity::Warning, + DiagnosticCode::ChainWithoutPolicy, + "Filter chain should have an explicit policy".to_string(), + ); + diagnostics.push(diagnostic); + } + } + + diagnostics + } + + fn check_performance_hints(&self, source: &str) -> Vec { + let mut diagnostics = Vec::new(); + + for (line_idx, line) in source.lines().enumerate() { + let line_num = line_idx as u32; + let trimmed = line.trim(); + + // Check for large sets without timeout (performance hint) + if trimmed.contains("set ") && trimmed.contains("{") && !trimmed.contains("timeout") { + // Simple heuristic: if set definition is long, suggest timeout + if trimmed.len() > 100 { let range = Range::new( Position::new(line_num, 0), Position::new(line_num, line.len() as u32), ); let diagnostic = Diagnostic::new( range, - DiagnosticSeverity::Warning, - DiagnosticCode::ChainWithoutPolicy, - "Filter chain should have an explicit policy".to_string(), + DiagnosticSeverity::Hint, + DiagnosticCode::LargeSetWithoutTimeout, + "Consider adding a timeout to large sets for better performance" + .to_string(), ); diagnostics.push(diagnostic); } } + + // Check for missing counters (performance hint) + if (trimmed.contains(" accept") || trimmed.contains(" drop")) + && !trimmed.contains("counter") + { + let range = Range::new( + Position::new(line_num, 0), + Position::new(line_num, line.len() as u32), + ); + let diagnostic = Diagnostic::new( + range, + DiagnosticSeverity::Hint, + DiagnosticCode::MissingCounters, + "Consider adding counters to rules for monitoring and debugging".to_string(), + ); + diagnostics.push(diagnostic); + } + } + + diagnostics + } + + fn 
check_security_warnings(&self, source: &str) -> Vec { + let mut diagnostics = Vec::new(); + + for (line_idx, line) in source.lines().enumerate() { + let line_num = line_idx as u32; + let trimmed = line.trim(); + + // Check for overly permissive rules (security warning) + if trimmed.contains(" accept") + && (trimmed.contains("0.0.0.0/0") || trimmed.contains("::/0")) + { + let range = Range::new( + Position::new(line_num, 0), + Position::new(line_num, line.len() as u32), + ); + let diagnostic = Diagnostic::new( + range, + DiagnosticSeverity::Warning, + DiagnosticCode::SecurityRisk, + "Rule accepts traffic from anywhere - consider restricting source addresses" + .to_string(), + ); + diagnostics.push(diagnostic); + } } diagnostics diff --git a/src/main.rs b/src/main.rs index 87a35f8..e7fa847 100755 --- a/src/main.rs +++ b/src/main.rs @@ -93,19 +93,19 @@ enum Commands { json: bool, /// Include style warnings in diagnostics - #[arg(long, default_value = "true")] + #[arg(long, action = clap::ArgAction::Set, default_value = "true")] style_warnings: bool, /// Include best practice recommendations in diagnostics - #[arg(long, default_value = "true")] + #[arg(long, action = clap::ArgAction::Set, default_value = "true")] best_practices: bool, /// Include performance hints in diagnostics - #[arg(long, default_value = "true")] + #[arg(long, action = clap::ArgAction::Set, default_value = "true")] performance_hints: bool, /// Include security warnings in diagnostics - #[arg(long, default_value = "true")] + #[arg(long, action = clap::ArgAction::Set, default_value = "true")] security_warnings: bool, /// Diagnostic modules to run (comma-separated: lexical,syntax,style,semantic) From 8d7fcd6ef446966025db1b857351f78d7a1381ff Mon Sep 17 00:00:00 2001 From: NotAShelf Date: Mon, 2 Jun 2025 14:00:48 +0300 Subject: [PATCH 3/5] lint: introduce `LintError`; count diagnostic errors --- src/main.rs | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/main.rs 
b/src/main.rs index e7fa847..4d32351 100755 --- a/src/main.rs +++ b/src/main.rs @@ -38,6 +38,12 @@ enum FormatterError { Io(#[from] io::Error), } +#[derive(Error, Debug)] +enum LintError { + #[error("Lint errors found in {file_count} file(s)")] + DiagnosticErrors { file_count: usize }, +} + #[derive(Parser, Debug, Clone)] #[command( name = "nff", @@ -337,7 +343,7 @@ fn process_lint_command( }; let is_multiple_files = files.len() > 1; - let mut has_errors = false; + let mut error_file_count = 0; for file_path in files { if let Err(e) = process_single_file_lint( @@ -352,16 +358,18 @@ fn process_lint_command( is_multiple_files, ) { eprintln!("Error processing {}: {}", file_path, e); - has_errors = true; + error_file_count += 1; if !is_multiple_files { return Err(e); } } } - // Exit with non-zero code if any file had errors - if has_errors { - std::process::exit(1); + if error_file_count > 0 { + return Err(LintError::DiagnosticErrors { + file_count: error_file_count, + } + .into()); } Ok(()) From d1cad19fead02a13cae3666d89a2be8c931a7d89 Mon Sep 17 00:00:00 2001 From: NotAShelf Date: Mon, 2 Jun 2025 14:11:23 +0300 Subject: [PATCH 4/5] lint: better error handling; count "faulty" files --- src/main.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/main.rs b/src/main.rs index 4d32351..3a34967 100755 --- a/src/main.rs +++ b/src/main.rs @@ -42,6 +42,8 @@ enum FormatterError { enum LintError { #[error("Lint errors found in {file_count} file(s)")] DiagnosticErrors { file_count: usize }, + #[error("File discovery error: {0}")] + FileDiscovery(#[from] anyhow::Error), } #[derive(Parser, Debug, Clone)] From b9d8cb6d5d5469db5a221b98e713232010e87cc4 Mon Sep 17 00:00:00 2001 From: NotAShelf Date: Mon, 2 Jun 2025 14:32:36 +0300 Subject: [PATCH 5/5] cst: eliminate manual synchronization between enum variants and numeric values --- Cargo.lock | 57 ++++++++ Cargo.toml | 3 +- src/cst.rs | 421 +++++++++++++++++++++++++---------------------- 3 files changed, 260 
insertions(+), 221 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1b11d7f..cd5a4c2 100755 --- a/Cargo.lock +++ b/Cargo.lock @@ -281,6 +281,7 @@ dependencies = [ "cstree", "glob", "logos", + "num_enum", "regex", "serde", "serde_json", @@ -288,6 +289,27 @@ dependencies = [ "thiserror", ] +[[package]] +name = "num_enum" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179" +dependencies = [ + "num_enum_derive", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "parking_lot" version = "0.12.3" @@ -311,6 +333,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "proc-macro-crate" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edce586971a4dfaa28950c6f18ed55e0406c1ab88bbce2c6f6293a7aaba73d35" +dependencies = [ + "toml_edit", +] + [[package]] name = "proc-macro2" version = "1.0.95" @@ -487,6 +518,23 @@ dependencies = [ "syn", ] +[[package]] +name = "toml_datetime" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3" + +[[package]] +name = "toml_edit" +version = "0.22.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "310068873db2c5b3e7659d2cc35d21855dbafa50d1ce336397c666e3cb08137e" +dependencies = [ + "indexmap", + "toml_datetime", + "winnow", +] + [[package]] name = "triomphe" version = "0.1.14" @@ -573,3 +621,12 @@ name = "windows_x86_64_msvc" version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" + +[[package]] +name = "winnow" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06928c8748d81b05c9be96aad92e1b6ff01833332f281e8cfca3be4b35fc9ec" +dependencies = [ + "memchr", +] diff --git a/Cargo.toml b/Cargo.toml index 4bb2743..afd902e 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,5 +14,6 @@ cstree = "0.12" text-size = "1.1" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -regex = "1.11.1" +regex = "1.11" glob = "0.3" +num_enum = "0.7" diff --git a/src/cst.rs b/src/cst.rs index 412b379..b36cad0 100644 --- a/src/cst.rs +++ b/src/cst.rs @@ -4,11 +4,15 @@ use crate::lexer::{Token, TokenKind}; use cstree::{RawSyntaxKind, green::GreenNode, util::NodeOrToken}; +use num_enum::{IntoPrimitive, TryFromPrimitive}; use std::fmt; use thiserror::Error; /// nftables syntax node types -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +/// Uses `TryFromPrimitive` for safe conversion from raw values with fallback to `Error`. 
+#[derive( + Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, TryFromPrimitive, IntoPrimitive, +)] #[repr(u16)] pub enum SyntaxKind { // Root and containers @@ -161,116 +165,128 @@ pub enum SyntaxKind { impl From for SyntaxKind { fn from(kind: TokenKind) -> Self { + use TokenKind::*; match kind { - TokenKind::Table => SyntaxKind::TableKw, - TokenKind::Chain => SyntaxKind::ChainKw, - TokenKind::Rule => SyntaxKind::RuleKw, - TokenKind::Set => SyntaxKind::SetKw, - TokenKind::Map => SyntaxKind::MapKw, - TokenKind::Element => SyntaxKind::ElementKw, - TokenKind::Include => SyntaxKind::IncludeKw, - TokenKind::Define => SyntaxKind::DefineKw, - TokenKind::Flush => SyntaxKind::FlushKw, - TokenKind::Add => SyntaxKind::AddKw, - TokenKind::Delete => SyntaxKind::DeleteKw, - TokenKind::Insert => SyntaxKind::InsertKw, - TokenKind::Replace => SyntaxKind::ReplaceKw, + // Keywords -> Kw variants + Table => SyntaxKind::TableKw, + Chain => SyntaxKind::ChainKw, + Rule => SyntaxKind::RuleKw, + Set => SyntaxKind::SetKw, + Map => SyntaxKind::MapKw, + Element => SyntaxKind::ElementKw, + Include => SyntaxKind::IncludeKw, + Define => SyntaxKind::DefineKw, + Flush => SyntaxKind::FlushKw, + Add => SyntaxKind::AddKw, + Delete => SyntaxKind::DeleteKw, + Insert => SyntaxKind::InsertKw, + Replace => SyntaxKind::ReplaceKw, - TokenKind::Filter => SyntaxKind::FilterKw, - TokenKind::Nat => SyntaxKind::NatKw, - TokenKind::Route => SyntaxKind::RouteKw, + // Chain types and hooks + Filter => SyntaxKind::FilterKw, + Nat => SyntaxKind::NatKw, + Route => SyntaxKind::RouteKw, + Input => SyntaxKind::InputKw, + Output => SyntaxKind::OutputKw, + Forward => SyntaxKind::ForwardKw, + Prerouting => SyntaxKind::PreroutingKw, + Postrouting => SyntaxKind::PostroutingKw, - TokenKind::Input => SyntaxKind::InputKw, - TokenKind::Output => SyntaxKind::OutputKw, - TokenKind::Forward => SyntaxKind::ForwardKw, - TokenKind::Prerouting => SyntaxKind::PreroutingKw, - TokenKind::Postrouting => SyntaxKind::PostroutingKw, + 
// Protocols and families + Ip => SyntaxKind::IpKw, + Ip6 => SyntaxKind::Ip6Kw, + Inet => SyntaxKind::InetKw, + Arp => SyntaxKind::ArpKw, + Bridge => SyntaxKind::BridgeKw, + Netdev => SyntaxKind::NetdevKw, + Tcp => SyntaxKind::TcpKw, + Udp => SyntaxKind::UdpKw, + Icmp => SyntaxKind::IcmpKw, + Icmpv6 => SyntaxKind::Icmpv6Kw, - TokenKind::Ip => SyntaxKind::IpKw, - TokenKind::Ip6 => SyntaxKind::Ip6Kw, - TokenKind::Inet => SyntaxKind::InetKw, - TokenKind::Arp => SyntaxKind::ArpKw, - TokenKind::Bridge => SyntaxKind::BridgeKw, - TokenKind::Netdev => SyntaxKind::NetdevKw, - TokenKind::Tcp => SyntaxKind::TcpKw, - TokenKind::Udp => SyntaxKind::UdpKw, - TokenKind::Icmp => SyntaxKind::IcmpKw, - TokenKind::Icmpv6 => SyntaxKind::Icmpv6Kw, + // Match keywords + Sport => SyntaxKind::SportKw, + Dport => SyntaxKind::DportKw, + Saddr => SyntaxKind::SaddrKw, + Daddr => SyntaxKind::DaddrKw, + Protocol => SyntaxKind::ProtocolKw, + Nexthdr => SyntaxKind::NexthdrKw, + Type => SyntaxKind::TypeKw, + Hook => SyntaxKind::HookKw, + Priority => SyntaxKind::PriorityKw, + Policy => SyntaxKind::PolicyKw, + Iifname => SyntaxKind::IifnameKw, + Oifname => SyntaxKind::OifnameKw, + Ct => SyntaxKind::CtKw, + State => SyntaxKind::StateKw, - TokenKind::Sport => SyntaxKind::SportKw, - TokenKind::Dport => SyntaxKind::DportKw, - TokenKind::Saddr => SyntaxKind::SaddrKw, - TokenKind::Daddr => SyntaxKind::DaddrKw, - TokenKind::Protocol => SyntaxKind::ProtocolKw, - TokenKind::Nexthdr => SyntaxKind::NexthdrKw, - TokenKind::Type => SyntaxKind::TypeKw, - TokenKind::Hook => SyntaxKind::HookKw, - TokenKind::Priority => SyntaxKind::PriorityKw, - TokenKind::Policy => SyntaxKind::PolicyKw, - TokenKind::Iifname => SyntaxKind::IifnameKw, - TokenKind::Oifname => SyntaxKind::OifnameKw, - TokenKind::Ct => SyntaxKind::CtKw, - TokenKind::State => SyntaxKind::StateKw, + // Actions + Accept => SyntaxKind::AcceptKw, + Drop => SyntaxKind::DropKw, + Reject => SyntaxKind::RejectKw, + Return => SyntaxKind::ReturnKw, + Jump => 
SyntaxKind::JumpKw, + Goto => SyntaxKind::GotoKw, + Continue => SyntaxKind::ContinueKw, + Log => SyntaxKind::LogKw, + Comment => SyntaxKind::CommentKw, - TokenKind::Accept => SyntaxKind::AcceptKw, - TokenKind::Drop => SyntaxKind::DropKw, - TokenKind::Reject => SyntaxKind::RejectKw, - TokenKind::Return => SyntaxKind::ReturnKw, - TokenKind::Jump => SyntaxKind::JumpKw, - TokenKind::Goto => SyntaxKind::GotoKw, - TokenKind::Continue => SyntaxKind::ContinueKw, - TokenKind::Log => SyntaxKind::LogKw, - TokenKind::Comment => SyntaxKind::CommentKw, + // States + Established => SyntaxKind::EstablishedKw, + Related => SyntaxKind::RelatedKw, + New => SyntaxKind::NewKw, + Invalid => SyntaxKind::InvalidKw, - TokenKind::Established => SyntaxKind::EstablishedKw, - TokenKind::Related => SyntaxKind::RelatedKw, - TokenKind::New => SyntaxKind::NewKw, - TokenKind::Invalid => SyntaxKind::InvalidKw, + // Additional protocol keywords + Vmap => SyntaxKind::VmapKw, + NdRouterAdvert => SyntaxKind::NdRouterAdvertKw, + NdNeighborSolicit => SyntaxKind::NdNeighborSolicitKw, + NdNeighborAdvert => SyntaxKind::NdNeighborAdvertKw, + EchoRequest => SyntaxKind::EchoRequestKw, + DestUnreachable => SyntaxKind::DestUnreachableKw, + RouterAdvertisement => SyntaxKind::RouterAdvertisementKw, + TimeExceeded => SyntaxKind::TimeExceededKw, + ParameterProblem => SyntaxKind::ParameterProblemKw, + PacketTooBig => SyntaxKind::PacketTooBigKw, - TokenKind::Vmap => SyntaxKind::VmapKw, - TokenKind::NdRouterAdvert => SyntaxKind::NdRouterAdvertKw, - TokenKind::NdNeighborSolicit => SyntaxKind::NdNeighborSolicitKw, - TokenKind::NdNeighborAdvert => SyntaxKind::NdNeighborAdvertKw, - TokenKind::EchoRequest => SyntaxKind::EchoRequestKw, - TokenKind::DestUnreachable => SyntaxKind::DestUnreachableKw, - TokenKind::RouterAdvertisement => SyntaxKind::RouterAdvertisementKw, - TokenKind::TimeExceeded => SyntaxKind::TimeExceededKw, - TokenKind::ParameterProblem => SyntaxKind::ParameterProblemKw, - TokenKind::PacketTooBig => 
SyntaxKind::PacketTooBigKw, + // Operators - direct mapping + Eq => SyntaxKind::EqOp, + Ne => SyntaxKind::NeOp, + Le => SyntaxKind::LeOp, + Ge => SyntaxKind::GeOp, + Lt => SyntaxKind::LtOp, + Gt => SyntaxKind::GtOp, - TokenKind::Eq => SyntaxKind::EqOp, - TokenKind::Ne => SyntaxKind::NeOp, - TokenKind::Le => SyntaxKind::LeOp, - TokenKind::Ge => SyntaxKind::GeOp, - TokenKind::Lt => SyntaxKind::LtOp, - TokenKind::Gt => SyntaxKind::GtOp, + // Punctuation - direct mapping + LeftBrace => SyntaxKind::LeftBrace, + RightBrace => SyntaxKind::RightBrace, + LeftParen => SyntaxKind::LeftParen, + RightParen => SyntaxKind::RightParen, + LeftBracket => SyntaxKind::LeftBracket, + RightBracket => SyntaxKind::RightBracket, + Comma => SyntaxKind::Comma, + Semicolon => SyntaxKind::Semicolon, + Colon => SyntaxKind::Colon, + Assign => SyntaxKind::Assign, + Dash => SyntaxKind::Dash, + Slash => SyntaxKind::Slash, + Dot => SyntaxKind::Dot, - TokenKind::LeftBrace => SyntaxKind::LeftBrace, - TokenKind::RightBrace => SyntaxKind::RightBrace, - TokenKind::LeftParen => SyntaxKind::LeftParen, - TokenKind::RightParen => SyntaxKind::RightParen, - TokenKind::LeftBracket => SyntaxKind::LeftBracket, - TokenKind::RightBracket => SyntaxKind::RightBracket, - TokenKind::Comma => SyntaxKind::Comma, - TokenKind::Semicolon => SyntaxKind::Semicolon, - TokenKind::Colon => SyntaxKind::Colon, - TokenKind::Assign => SyntaxKind::Assign, - TokenKind::Dash => SyntaxKind::Dash, - TokenKind::Slash => SyntaxKind::Slash, - TokenKind::Dot => SyntaxKind::Dot, + // Literals - map data-carrying variants to their types + StringLiteral(_) => SyntaxKind::StringLiteral, + NumberLiteral(_) => SyntaxKind::NumberLiteral, + IpAddress(_) => SyntaxKind::IpAddress, + Ipv6Address(_) => SyntaxKind::Ipv6Address, + MacAddress(_) => SyntaxKind::MacAddress, + Identifier(_) => SyntaxKind::Identifier, - TokenKind::StringLiteral(_) => SyntaxKind::StringLiteral, - TokenKind::NumberLiteral(_) => SyntaxKind::NumberLiteral, - 
TokenKind::IpAddress(_) => SyntaxKind::IpAddress, - TokenKind::Ipv6Address(_) => SyntaxKind::Ipv6Address, - TokenKind::MacAddress(_) => SyntaxKind::MacAddress, - TokenKind::Identifier(_) => SyntaxKind::Identifier, + // Special tokens + Newline => SyntaxKind::Newline, + CommentLine(_) => SyntaxKind::Comment, + Shebang(_) => SyntaxKind::Shebang, - TokenKind::Newline => SyntaxKind::Newline, - TokenKind::CommentLine(_) => SyntaxKind::Comment, - TokenKind::Shebang(_) => SyntaxKind::Shebang, - - TokenKind::Error => SyntaxKind::Error, + // Error fallback + Error => SyntaxKind::Error, } } } @@ -324,126 +340,7 @@ impl SyntaxKind { } pub fn from_raw(raw: RawSyntaxKind) -> Self { - match raw.0 { - 0 => SyntaxKind::Root, - 1 => SyntaxKind::Table, - 2 => SyntaxKind::Chain, - 3 => SyntaxKind::Rule, - 4 => SyntaxKind::Set, - 5 => SyntaxKind::Map, - 6 => SyntaxKind::Element, - 7 => SyntaxKind::Expression, - 8 => SyntaxKind::BinaryExpr, - 9 => SyntaxKind::UnaryExpr, - 10 => SyntaxKind::CallExpr, - 11 => SyntaxKind::SetExpr, - 12 => SyntaxKind::RangeExpr, - 13 => SyntaxKind::Statement, - 14 => SyntaxKind::IncludeStmt, - 15 => SyntaxKind::DefineStmt, - 16 => SyntaxKind::FlushStmt, - 17 => SyntaxKind::AddStmt, - 18 => SyntaxKind::DeleteStmt, - 19 => SyntaxKind::Identifier, - 20 => SyntaxKind::StringLiteral, - 21 => SyntaxKind::NumberLiteral, - 22 => SyntaxKind::IpAddress, - 23 => SyntaxKind::Ipv6Address, - 24 => SyntaxKind::MacAddress, - 25 => SyntaxKind::TableKw, - 26 => SyntaxKind::ChainKw, - 27 => SyntaxKind::RuleKw, - 28 => SyntaxKind::SetKw, - 29 => SyntaxKind::MapKw, - 30 => SyntaxKind::ElementKw, - 31 => SyntaxKind::IncludeKw, - 32 => SyntaxKind::DefineKw, - 33 => SyntaxKind::FlushKw, - 34 => SyntaxKind::AddKw, - 35 => SyntaxKind::DeleteKw, - 36 => SyntaxKind::InsertKw, - 37 => SyntaxKind::ReplaceKw, - 38 => SyntaxKind::FilterKw, - 39 => SyntaxKind::NatKw, - 40 => SyntaxKind::RouteKw, - 41 => SyntaxKind::InputKw, - 42 => SyntaxKind::OutputKw, - 43 => SyntaxKind::ForwardKw, - 44 
=> SyntaxKind::PreroutingKw, - 45 => SyntaxKind::PostroutingKw, - 46 => SyntaxKind::IpKw, - 47 => SyntaxKind::Ip6Kw, - 48 => SyntaxKind::InetKw, - 49 => SyntaxKind::ArpKw, - 50 => SyntaxKind::BridgeKw, - 51 => SyntaxKind::NetdevKw, - 52 => SyntaxKind::TcpKw, - 53 => SyntaxKind::UdpKw, - 54 => SyntaxKind::IcmpKw, - 55 => SyntaxKind::Icmpv6Kw, - 56 => SyntaxKind::SportKw, - 57 => SyntaxKind::DportKw, - 58 => SyntaxKind::SaddrKw, - 59 => SyntaxKind::DaddrKw, - 60 => SyntaxKind::ProtocolKw, - 61 => SyntaxKind::NexthdrKw, - 62 => SyntaxKind::TypeKw, - 63 => SyntaxKind::HookKw, - 64 => SyntaxKind::PriorityKw, - 65 => SyntaxKind::PolicyKw, - 66 => SyntaxKind::IifnameKw, - 67 => SyntaxKind::OifnameKw, - 68 => SyntaxKind::CtKw, - 69 => SyntaxKind::StateKw, - 70 => SyntaxKind::AcceptKw, - 71 => SyntaxKind::DropKw, - 72 => SyntaxKind::RejectKw, - 73 => SyntaxKind::ReturnKw, - 74 => SyntaxKind::JumpKw, - 75 => SyntaxKind::GotoKw, - 76 => SyntaxKind::ContinueKw, - 77 => SyntaxKind::LogKw, - 78 => SyntaxKind::CommentKw, - 79 => SyntaxKind::EstablishedKw, - 80 => SyntaxKind::RelatedKw, - 81 => SyntaxKind::NewKw, - 82 => SyntaxKind::InvalidKw, - 83 => SyntaxKind::EqOp, - 84 => SyntaxKind::NeOp, - 85 => SyntaxKind::LeOp, - 86 => SyntaxKind::GeOp, - 87 => SyntaxKind::LtOp, - 88 => SyntaxKind::GtOp, - 89 => SyntaxKind::LeftBrace, - 90 => SyntaxKind::RightBrace, - 91 => SyntaxKind::LeftParen, - 92 => SyntaxKind::RightParen, - 93 => SyntaxKind::LeftBracket, - 94 => SyntaxKind::RightBracket, - 95 => SyntaxKind::Comma, - 96 => SyntaxKind::Semicolon, - 97 => SyntaxKind::Colon, - 98 => SyntaxKind::Assign, - 99 => SyntaxKind::Dash, - 100 => SyntaxKind::Slash, - 101 => SyntaxKind::Dot, - 102 => SyntaxKind::Whitespace, - 103 => SyntaxKind::Newline, - 104 => SyntaxKind::Comment, - 105 => SyntaxKind::Shebang, - 106 => SyntaxKind::Error, - 107 => SyntaxKind::VmapKw, - 108 => SyntaxKind::NdRouterAdvertKw, - 109 => SyntaxKind::NdNeighborSolicitKw, - 110 => SyntaxKind::NdNeighborAdvertKw, - 111 => 
SyntaxKind::EchoRequestKw, - 112 => SyntaxKind::DestUnreachableKw, - 113 => SyntaxKind::RouterAdvertisementKw, - 114 => SyntaxKind::TimeExceededKw, - 115 => SyntaxKind::ParameterProblemKw, - 116 => SyntaxKind::PacketTooBigKw, - _ => SyntaxKind::Error, // Fallback to Error for invalid values - } + Self::try_from(raw.0 as u16).unwrap_or(SyntaxKind::Error) } } @@ -1350,7 +1247,7 @@ mod tests { let mut lexer = NftablesLexer::new(source); let tokens = lexer.tokenize().expect("Tokenization should succeed"); - // CST is now implemented - test that it works + // Test CST construction with basic table syntax let green_tree = CstBuilder::build_tree(&tokens); // Verify the tree was created successfully @@ -1367,4 +1264,88 @@ mod tests { let cst_result = CstBuilder::parse_to_cst(&tokens); assert!(cst_result.is_ok()); } + + #[test] + fn test_num_enum_improvements() { + // Test that from_raw uses num_enum for conversion + // Invalid values fall back to Error variant + + // Test valid conversions + assert_eq!(SyntaxKind::from_raw(RawSyntaxKind(0)), SyntaxKind::Root); + assert_eq!(SyntaxKind::from_raw(RawSyntaxKind(1)), SyntaxKind::Table); + assert_eq!(SyntaxKind::from_raw(RawSyntaxKind(25)), SyntaxKind::TableKw); + assert_eq!(SyntaxKind::from_raw(RawSyntaxKind(106)), SyntaxKind::Error); + assert_eq!( + SyntaxKind::from_raw(RawSyntaxKind(116)), + SyntaxKind::PacketTooBigKw + ); + + // Test invalid values automatically fall back to Error + assert_eq!(SyntaxKind::from_raw(RawSyntaxKind(999)), SyntaxKind::Error); + assert_eq!(SyntaxKind::from_raw(RawSyntaxKind(1000)), SyntaxKind::Error); + + // Test bidirectional conversion + for variant in [ + SyntaxKind::Root, + SyntaxKind::Table, + SyntaxKind::TableKw, + SyntaxKind::Error, + SyntaxKind::PacketTooBigKw, + ] { + let raw = variant.to_raw(); + let converted_back = SyntaxKind::from_raw(raw); + assert_eq!(variant, converted_back); + } + } + + #[test] + fn test_token_kind_conversion_improvements() { + // Test that From conversion is 
complete and correct + use crate::lexer::TokenKind; + + // Test keyword mappings + assert_eq!(SyntaxKind::from(TokenKind::Table), SyntaxKind::TableKw); + assert_eq!(SyntaxKind::from(TokenKind::Chain), SyntaxKind::ChainKw); + assert_eq!(SyntaxKind::from(TokenKind::Accept), SyntaxKind::AcceptKw); + + // Test operators + assert_eq!(SyntaxKind::from(TokenKind::Eq), SyntaxKind::EqOp); + assert_eq!(SyntaxKind::from(TokenKind::Lt), SyntaxKind::LtOp); + + // Test punctuation + assert_eq!( + SyntaxKind::from(TokenKind::LeftBrace), + SyntaxKind::LeftBrace + ); + assert_eq!( + SyntaxKind::from(TokenKind::Semicolon), + SyntaxKind::Semicolon + ); + + // Test literals (with data) + assert_eq!( + SyntaxKind::from(TokenKind::StringLiteral("test".to_string())), + SyntaxKind::StringLiteral + ); + assert_eq!( + SyntaxKind::from(TokenKind::NumberLiteral(42)), + SyntaxKind::NumberLiteral + ); + assert_eq!( + SyntaxKind::from(TokenKind::IpAddress("192.168.1.1".to_string())), + SyntaxKind::IpAddress + ); + assert_eq!( + SyntaxKind::from(TokenKind::Identifier("test".to_string())), + SyntaxKind::Identifier + ); + + // Test special tokens + assert_eq!(SyntaxKind::from(TokenKind::Newline), SyntaxKind::Newline); + assert_eq!( + SyntaxKind::from(TokenKind::CommentLine("# comment".to_string())), + SyntaxKind::Comment + ); + assert_eq!(SyntaxKind::from(TokenKind::Error), SyntaxKind::Error); + } }