From 5fbc0f45bdad2bd938038d8627bcec8c42c44126 Mon Sep 17 00:00:00 2001 From: NotAShelf Date: Fri, 14 Nov 2025 21:53:51 +0300 Subject: [PATCH] eh: don't load entire files into memory for hash replace; argchecking Signed-off-by: NotAShelf Change-Id: Ie3385f68e70ee7848272010fbd41845e6a6a6964 --- Cargo.toml | 1 + eh/Cargo.toml | 4 +- eh/src/command.rs | 2 + eh/src/error.rs | 3 + eh/src/util.rs | 265 ++++++++++++++++++++++++++++++++++++++++++---- 5 files changed, 254 insertions(+), 21 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e2149c4..dda7e17 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ version = "0.1.2" clap = { default-features = false, features = [ "std", "help", "derive" ], version = "4.5.51" } clap_complete = "4.5.60" regex = "1.12.2" +tempfile = "3.23.0" thiserror = "2.0.17" tracing = "0.1.41" tracing-subscriber = "0.3.20" diff --git a/eh/Cargo.toml b/eh/Cargo.toml index 0b0300b..1e62b9b 100644 --- a/eh/Cargo.toml +++ b/eh/Cargo.toml @@ -13,11 +13,9 @@ name = "eh" [dependencies] clap.workspace = true regex.workspace = true +tempfile.workspace = true thiserror.workspace = true tracing.workspace = true tracing-subscriber.workspace = true walkdir.workspace = true yansi.workspace = true - -[dev-dependencies] -tempfile = "3.0" diff --git a/eh/src/command.rs b/eh/src/command.rs index cf19869..74ca3f1 100644 --- a/eh/src/command.rs +++ b/eh/src/command.rs @@ -54,6 +54,7 @@ impl NixCommand { self } + #[allow(dead_code, reason = "FIXME")] pub fn args(mut self, args: I) -> Self where I: IntoIterator, @@ -63,6 +64,7 @@ impl NixCommand { self } + #[must_use] pub fn args_ref(mut self, args: &[String]) -> Self { self.args.extend(args.iter().cloned()); self diff --git a/eh/src/error.rs b/eh/src/error.rs index 7e255ef..b9c2bfb 100644 --- a/eh/src/error.rs +++ b/eh/src/error.rs @@ -28,6 +28,9 @@ pub enum EhError { #[error("Command execution failed: {command}")] CommandFailed { command: String }, + + #[error("Invalid input: {input} - {reason}")] + InvalidInput { input: String, reason: String }, } pub type Result = std::result::Result; diff --git a/eh/src/util.rs b/eh/src/util.rs index 6f1490d..9896e18 100644 --- a/eh/src/util.rs +++ b/eh/src/util.rs @@ -1,10 +1,10 @@ use std::{ - fs, - io::Write, + io::{BufWriter, Write}, path::{Path, PathBuf}, }; use regex::Regex; +use tempfile::NamedTempFile; use tracing::{info, warn}; use walkdir::WalkDir; use yansi::Paint; @@ -63,14 +63,13 @@ impl NixFileFixer for DefaultNixFileFixer { fn find_nix_files(&self) -> Result> { let files: Vec = WalkDir::new(".") .into_iter() - .filter_map(|entry| entry.ok()) + .filter_map(std::result::Result::ok) .filter(|entry| { entry.file_type().is_file() && entry .path() .extension() - .map(|ext| ext.eq_ignore_ascii_case("nix")) - .unwrap_or(false) + .is_some_and(|ext| ext.eq_ignore_ascii_case("nix")) }) .map(|entry| entry.path().to_path_buf()) .collect(); @@ -82,8 +81,8 @@ impl NixFileFixer for DefaultNixFileFixer { } fn fix_hash_in_file(&self, file_path: &Path, new_hash: &str) -> Result { - let content = fs::read_to_string(file_path)?; - let patterns = [ + // Pre-compile regex patterns once to avoid repeated compilation + let patterns: Vec<(Regex, String)> = [ (r#"hash\s*=\s*"[^"]*""#, format!(r#"hash = "{new_hash}""#)), ( r#"sha256\s*=\s*"[^"]*""#, @@ -93,28 +92,47 @@ impl NixFileFixer for DefaultNixFileFixer { r#"outputHash\s*=\s*"[^"]*""#, format!(r#"outputHash = "{new_hash}""#), ), - ]; - let mut new_content = content; + ] + .into_iter() + .map(|(pattern, replacement)| { + Regex::new(pattern) + .map(|re| (re, replacement)) + .map_err(EhError::Regex) + }) + .collect::>>()?; + + // Read the entire file content + let content = std::fs::read_to_string(file_path)?; let mut replaced = false; - for (pattern, replacement) in &patterns { - let re = Regex::new(pattern)?; - if re.is_match(&new_content) { - new_content = re - .replace_all(&new_content, replacement.as_str()) + let mut result_content = content; + + // Apply replacements + for (re, replacement) in &patterns { + if re.is_match(&result_content) { + result_content = re + .replace_all(&result_content, replacement.as_str()) .into_owned(); replaced = true; } } + + // Write back to file atomically if replaced { - fs::write(file_path, new_content).map_err(|_| { + let temp_file = + NamedTempFile::new_in(file_path.parent().unwrap_or(Path::new(".")))?; + { + let mut writer = BufWriter::new(temp_file.as_file()); + writer.write_all(result_content.as_bytes())?; + writer.flush()?; + } + temp_file.persist(file_path).map_err(|_e| { EhError::HashFixFailed { path: file_path.to_string_lossy().to_string(), } })?; - Ok(true) - } else { - Ok(false) } + + Ok(replaced) } } @@ -156,6 +174,24 @@ fn pre_evaluate(_subcommand: &str, args: &[String]) -> Result { Ok(false) } +fn validate_nix_args(args: &[String]) -> Result<()> { + const DANGEROUS_PATTERNS: &[&str] = &[ + ";", "&&", "||", "|", "`", "$(", "${", ">", "<", ">>", "<<", "2>", "2>>", + ]; + + for arg in args { + for pattern in DANGEROUS_PATTERNS { + if arg.contains(pattern) { + return Err(EhError::InvalidInput { + input: arg.clone(), + reason: format!("contains potentially dangerous pattern: {pattern}"), + }); + } + } + } + Ok(()) +} + /// Shared retry logic for nix commands (build/run/shell). pub fn handle_nix_with_retry( subcommand: &str, @@ -165,6 +201,7 @@ pub fn handle_nix_with_retry( classifier: &dyn NixErrorClassifier, interactive: bool, ) -> Result { + validate_nix_args(args)?; // Pre-evaluate for build commands to catch errors early if !pre_evaluate(subcommand, args)? { return Err(EhError::NixCommandFailed( @@ -297,6 +334,7 @@ pub fn handle_nix_with_retry( code: output.status.code().unwrap_or(1), }) } + pub struct DefaultNixErrorClassifier; impl NixErrorClassifier for DefaultNixErrorClassifier { @@ -310,3 +348,194 @@ impl NixErrorClassifier for DefaultNixErrorClassifier { && stderr.contains("refusing")) } } + +#[cfg(test)] +mod tests { + use std::io::Write; + + use tempfile::NamedTempFile; + + use super::*; + + #[test] + fn test_streaming_hash_replacement() { + let temp_file = NamedTempFile::new().unwrap(); + let file_path = temp_file.path(); + + // Write test content with multiple hash patterns + let test_content = r#"stdenv.mkDerivation { + name = "test-package"; + src = fetchurl { + url = "https://example.com.tar.gz"; + hash = "sha256-oldhash123"; + sha256 = "sha256-oldhash456"; + outputHash = "sha256-oldhash789"; + }; +}"#; + + let mut file = std::fs::File::create(file_path).unwrap(); + file.write_all(test_content.as_bytes()).unwrap(); + file.flush().unwrap(); + + let fixer = DefaultNixFileFixer; + let result = fixer + .fix_hash_in_file(file_path, "sha256-newhash999") + .unwrap(); + + assert!(result, "Hash replacement should return true"); + + // Verify the content was updated + let updated_content = std::fs::read_to_string(file_path).unwrap(); + assert!(updated_content.contains("sha256-newhash999")); + assert!(!updated_content.contains("sha256-oldhash123")); + assert!(!updated_content.contains("sha256-oldhash456")); + assert!(!updated_content.contains("sha256-oldhash789")); + } + + #[test] + fn test_streaming_no_replacement_needed() { + let temp_file = NamedTempFile::new().unwrap(); + let file_path = temp_file.path().to_path_buf(); + + let test_content = r#"stdenv.mkDerivation { + name = "test-package"; + src = fetchurl { + url = "https://example.com.tar.gz"; + }; +}"#; + + { + let mut file = std::fs::File::create(&file_path).unwrap(); + file.write_all(test_content.as_bytes()).unwrap(); + file.flush().unwrap(); + } // File is closed here + + // Test hash replacement + let fixer = DefaultNixFileFixer; + let result = fixer + .fix_hash_in_file(&file_path, "sha256-newhash999") + .unwrap(); + + assert!( + !result, + "Hash replacement should return false when no patterns found" + ); + + // Verify the content was unchanged, ignoring trailing newline differences + let updated_content = std::fs::read_to_string(&file_path).unwrap(); + let normalized_original = test_content.trim_end(); + let normalized_updated = updated_content.trim_end(); + assert_eq!(normalized_updated, normalized_original); + } + + // FIXME: this is a little stupid, but it works + #[test] + fn test_streaming_large_file_handling() { + let temp_file = NamedTempFile::new().unwrap(); + let file_path = temp_file.path(); + let mut file = std::fs::File::create(file_path).unwrap(); + + // Write header with hash + file.write_all(b"stdenv.mkDerivation {\n name = \"large-package\";\n src = fetchurl {\n url = \"https://example.com/large.tar.gz\";\n hash = \"sha256-oldhash\";\n };\n").unwrap(); + + for i in 0..10000 { + writeln!(file, " # Large comment line {} to simulate file size", i) + .unwrap(); + } + + file.flush().unwrap(); + + // Test that streaming can handle large files without memory issues + let fixer = DefaultNixFileFixer; + let result = fixer + .fix_hash_in_file(file_path, "sha256-newhash999") + .unwrap(); + + assert!(result, "Hash replacement should work for large files"); + + // Verify the hash was replaced + let updated_content = std::fs::read_to_string(file_path).unwrap(); + assert!(updated_content.contains("sha256-newhash999")); + assert!(!updated_content.contains("sha256-oldhash")); + } + + #[test] + fn test_streaming_file_permissions_preserved() { + let temp_file = NamedTempFile::new().unwrap(); + let file_path = temp_file.path(); + + // Write test content + let test_content = r#"stdenv.mkDerivation { + name = "test"; + src = fetchurl { + url = "https://example.com"; + hash = "sha256-oldhash"; + }; +}"#; + + let mut file = std::fs::File::create(file_path).unwrap(); + file.write_all(test_content.as_bytes()).unwrap(); + file.flush().unwrap(); + + // Get original permissions + let original_metadata = std::fs::metadata(file_path).unwrap(); + let _original_permissions = original_metadata.permissions(); + + // Test hash replacement + let fixer = DefaultNixFileFixer; + let result = fixer.fix_hash_in_file(file_path, "sha256-newhash").unwrap(); + + assert!(result, "Hash replacement should succeed"); + + // Verify file still exists and has reasonable permissions + let new_metadata = std::fs::metadata(file_path).unwrap(); + assert!( + new_metadata.is_file(), + "File should still exist after replacement" + ); + } + + #[test] + fn test_input_validation_blocks_dangerous_patterns() { + let dangerous_args = vec![ + "package; rm -rf /".to_string(), + "package && echo hacked".to_string(), + "package || echo hacked".to_string(), + "package | cat /etc/passwd".to_string(), + "package `whoami`".to_string(), + "package $(echo hacked lol!)".to_string(), + "package ${HOME}/file".to_string(), + ]; + + for arg in dangerous_args { + let result = validate_nix_args(std::slice::from_ref(&arg)); + assert!(result.is_err(), "Should reject dangerous argument: {}", arg); + + match result.unwrap_err() { + EhError::InvalidInput { input, reason } => { + assert_eq!(input, arg); + assert!(reason.contains("dangerous pattern")); + }, + _ => panic!("Expected InvalidInput error"), + } + } + } + + #[test] + fn test_input_validation_allows_safe_args() { + let safe_args = vec![ + "nixpkgs#hello".to_string(), + "--impure".to_string(), + "--print-build-logs".to_string(), + "/path/to/flake".to_string(), + ".#default".to_string(), + ]; + + let result = validate_nix_args(&safe_args); + assert!( + result.is_ok(), + "Should allow safe arguments: {:?}", + safe_args + ); + } +}