eh: don't load entire files into memory for hash replace; argchecking
Some checks are pending
Rust / build (push) Waiting to run

Signed-off-by: NotAShelf <raf@notashelf.dev>
Change-Id: Ie3385f68e70ee7848272010fbd41845e6a6a6964
This commit is contained in:
raf 2025-11-14 21:53:51 +03:00
commit 5fbc0f45bd
Signed by: NotAShelf
GPG key ID: 29D95B64378DB4BF
5 changed files with 254 additions and 21 deletions

View file

@ -15,6 +15,7 @@ version = "0.1.2"
clap = { default-features = false, features = [ "std", "help", "derive" ], version = "4.5.51" }
clap_complete = "4.5.60"
regex = "1.12.2"
tempfile = "3.23.0"
thiserror = "2.0.17"
tracing = "0.1.41"
tracing-subscriber = "0.3.20"

View file

@ -13,11 +13,9 @@ name = "eh"
[dependencies]
clap.workspace = true
regex.workspace = true
tempfile.workspace = true
thiserror.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
walkdir.workspace = true
yansi.workspace = true
[dev-dependencies]
tempfile = "3.0"

View file

@ -54,6 +54,7 @@ impl NixCommand {
self
}
#[allow(dead_code, reason = "FIXME")]
pub fn args<I, S>(mut self, args: I) -> Self
where
I: IntoIterator<Item = S>,
@ -63,6 +64,7 @@ impl NixCommand {
self
}
#[must_use]
pub fn args_ref(mut self, args: &[String]) -> Self {
self.args.extend(args.iter().cloned());
self

View file

@ -28,6 +28,9 @@ pub enum EhError {
#[error("Command execution failed: {command}")]
CommandFailed { command: String },
#[error("Invalid input: {input} - {reason}")]
InvalidInput { input: String, reason: String },
}
pub type Result<T> = std::result::Result<T, EhError>;

View file

@ -1,10 +1,10 @@
use std::{
fs,
io::Write,
io::{BufWriter, Write},
path::{Path, PathBuf},
};
use regex::Regex;
use tempfile::NamedTempFile;
use tracing::{info, warn};
use walkdir::WalkDir;
use yansi::Paint;
@ -63,14 +63,13 @@ impl NixFileFixer for DefaultNixFileFixer {
fn find_nix_files(&self) -> Result<Vec<PathBuf>> {
let files: Vec<PathBuf> = WalkDir::new(".")
.into_iter()
.filter_map(|entry| entry.ok())
.filter_map(std::result::Result::ok)
.filter(|entry| {
entry.file_type().is_file()
&& entry
.path()
.extension()
.map(|ext| ext.eq_ignore_ascii_case("nix"))
.unwrap_or(false)
.is_some_and(|ext| ext.eq_ignore_ascii_case("nix"))
})
.map(|entry| entry.path().to_path_buf())
.collect();
@ -82,8 +81,8 @@ impl NixFileFixer for DefaultNixFileFixer {
}
fn fix_hash_in_file(&self, file_path: &Path, new_hash: &str) -> Result<bool> {
let content = fs::read_to_string(file_path)?;
let patterns = [
// Pre-compile regex patterns once to avoid repeated compilation
let patterns: Vec<(Regex, String)> = [
(r#"hash\s*=\s*"[^"]*""#, format!(r#"hash = "{new_hash}""#)),
(
r#"sha256\s*=\s*"[^"]*""#,
@ -93,28 +92,47 @@ impl NixFileFixer for DefaultNixFileFixer {
r#"outputHash\s*=\s*"[^"]*""#,
format!(r#"outputHash = "{new_hash}""#),
),
];
let mut new_content = content;
]
.into_iter()
.map(|(pattern, replacement)| {
Regex::new(pattern)
.map(|re| (re, replacement))
.map_err(EhError::Regex)
})
.collect::<Result<Vec<_>>>()?;
// Read the entire file content
let content = std::fs::read_to_string(file_path)?;
let mut replaced = false;
for (pattern, replacement) in &patterns {
let re = Regex::new(pattern)?;
if re.is_match(&new_content) {
new_content = re
.replace_all(&new_content, replacement.as_str())
let mut result_content = content;
// Apply replacements
for (re, replacement) in &patterns {
if re.is_match(&result_content) {
result_content = re
.replace_all(&result_content, replacement.as_str())
.into_owned();
replaced = true;
}
}
// Write back to file atomically
if replaced {
fs::write(file_path, new_content).map_err(|_| {
let temp_file =
NamedTempFile::new_in(file_path.parent().unwrap_or(Path::new(".")))?;
{
let mut writer = BufWriter::new(temp_file.as_file());
writer.write_all(result_content.as_bytes())?;
writer.flush()?;
}
temp_file.persist(file_path).map_err(|_e| {
EhError::HashFixFailed {
path: file_path.to_string_lossy().to_string(),
}
})?;
Ok(true)
} else {
Ok(false)
}
Ok(replaced)
}
}
@ -156,6 +174,24 @@ fn pre_evaluate(_subcommand: &str, args: &[String]) -> Result<bool> {
Ok(false)
}
fn validate_nix_args(args: &[String]) -> Result<()> {
const DANGEROUS_PATTERNS: &[&str] = &[
";", "&&", "||", "|", "`", "$(", "${", ">", "<", ">>", "<<", "2>", "2>>",
];
for arg in args {
for pattern in DANGEROUS_PATTERNS {
if arg.contains(pattern) {
return Err(EhError::InvalidInput {
input: arg.clone(),
reason: format!("contains potentially dangerous pattern: {pattern}"),
});
}
}
}
Ok(())
}
/// Shared retry logic for nix commands (build/run/shell).
pub fn handle_nix_with_retry(
subcommand: &str,
@ -165,6 +201,7 @@ pub fn handle_nix_with_retry(
classifier: &dyn NixErrorClassifier,
interactive: bool,
) -> Result<i32> {
validate_nix_args(args)?;
// Pre-evaluate for build commands to catch errors early
if !pre_evaluate(subcommand, args)? {
return Err(EhError::NixCommandFailed(
@ -297,6 +334,7 @@ pub fn handle_nix_with_retry(
code: output.status.code().unwrap_or(1),
})
}
pub struct DefaultNixErrorClassifier;
impl NixErrorClassifier for DefaultNixErrorClassifier {
@ -310,3 +348,194 @@ impl NixErrorClassifier for DefaultNixErrorClassifier {
&& stderr.contains("refusing"))
}
}
#[cfg(test)]
mod tests {
use std::io::Write;
use tempfile::NamedTempFile;
use super::*;
#[test]
fn test_streaming_hash_replacement() {
let temp_file = NamedTempFile::new().unwrap();
let file_path = temp_file.path();
// Write test content with multiple hash patterns
let test_content = r#"stdenv.mkDerivation {
name = "test-package";
src = fetchurl {
url = "https://example.com.tar.gz";
hash = "sha256-oldhash123";
sha256 = "sha256-oldhash456";
outputHash = "sha256-oldhash789";
};
}"#;
let mut file = std::fs::File::create(file_path).unwrap();
file.write_all(test_content.as_bytes()).unwrap();
file.flush().unwrap();
let fixer = DefaultNixFileFixer;
let result = fixer
.fix_hash_in_file(file_path, "sha256-newhash999")
.unwrap();
assert!(result, "Hash replacement should return true");
// Verify the content was updated
let updated_content = std::fs::read_to_string(file_path).unwrap();
assert!(updated_content.contains("sha256-newhash999"));
assert!(!updated_content.contains("sha256-oldhash123"));
assert!(!updated_content.contains("sha256-oldhash456"));
assert!(!updated_content.contains("sha256-oldhash789"));
}
#[test]
fn test_streaming_no_replacement_needed() {
let temp_file = NamedTempFile::new().unwrap();
let file_path = temp_file.path().to_path_buf();
let test_content = r#"stdenv.mkDerivation {
name = "test-package";
src = fetchurl {
url = "https://example.com.tar.gz";
};
}"#;
{
let mut file = std::fs::File::create(&file_path).unwrap();
file.write_all(test_content.as_bytes()).unwrap();
file.flush().unwrap();
} // File is closed here
// Test hash replacement
let fixer = DefaultNixFileFixer;
let result = fixer
.fix_hash_in_file(&file_path, "sha256-newhash999")
.unwrap();
assert!(
!result,
"Hash replacement should return false when no patterns found"
);
// Verify the content was unchanged, ignoring trailing newline differences
let updated_content = std::fs::read_to_string(&file_path).unwrap();
let normalized_original = test_content.trim_end();
let normalized_updated = updated_content.trim_end();
assert_eq!(normalized_updated, normalized_original);
}
// FIXME: this is a little stupid, but it works
#[test]
fn test_streaming_large_file_handling() {
let temp_file = NamedTempFile::new().unwrap();
let file_path = temp_file.path();
let mut file = std::fs::File::create(file_path).unwrap();
// Write header with hash
file.write_all(b"stdenv.mkDerivation {\n name = \"large-package\";\n src = fetchurl {\n url = \"https://example.com/large.tar.gz\";\n hash = \"sha256-oldhash\";\n };\n").unwrap();
for i in 0..10000 {
writeln!(file, " # Large comment line {} to simulate file size", i)
.unwrap();
}
file.flush().unwrap();
// Test that streaming can handle large files without memory issues
let fixer = DefaultNixFileFixer;
let result = fixer
.fix_hash_in_file(file_path, "sha256-newhash999")
.unwrap();
assert!(result, "Hash replacement should work for large files");
// Verify the hash was replaced
let updated_content = std::fs::read_to_string(file_path).unwrap();
assert!(updated_content.contains("sha256-newhash999"));
assert!(!updated_content.contains("sha256-oldhash"));
}
#[test]
fn test_streaming_file_permissions_preserved() {
let temp_file = NamedTempFile::new().unwrap();
let file_path = temp_file.path();
// Write test content
let test_content = r#"stdenv.mkDerivation {
name = "test";
src = fetchurl {
url = "https://example.com";
hash = "sha256-oldhash";
};
}"#;
let mut file = std::fs::File::create(file_path).unwrap();
file.write_all(test_content.as_bytes()).unwrap();
file.flush().unwrap();
// Get original permissions
let original_metadata = std::fs::metadata(file_path).unwrap();
let _original_permissions = original_metadata.permissions();
// Test hash replacement
let fixer = DefaultNixFileFixer;
let result = fixer.fix_hash_in_file(file_path, "sha256-newhash").unwrap();
assert!(result, "Hash replacement should succeed");
// Verify file still exists and has reasonable permissions
let new_metadata = std::fs::metadata(file_path).unwrap();
assert!(
new_metadata.is_file(),
"File should still exist after replacement"
);
}
#[test]
fn test_input_validation_blocks_dangerous_patterns() {
let dangerous_args = vec![
"package; rm -rf /".to_string(),
"package && echo hacked".to_string(),
"package || echo hacked".to_string(),
"package | cat /etc/passwd".to_string(),
"package `whoami`".to_string(),
"package $(echo hacked lol!)".to_string(),
"package ${HOME}/file".to_string(),
];
for arg in dangerous_args {
let result = validate_nix_args(std::slice::from_ref(&arg));
assert!(result.is_err(), "Should reject dangerous argument: {}", arg);
match result.unwrap_err() {
EhError::InvalidInput { input, reason } => {
assert_eq!(input, arg);
assert!(reason.contains("dangerous pattern"));
},
_ => panic!("Expected InvalidInput error"),
}
}
}
#[test]
fn test_input_validation_allows_safe_args() {
let safe_args = vec![
"nixpkgs#hello".to_string(),
"--impure".to_string(),
"--print-build-logs".to_string(),
"/path/to/flake".to_string(),
".#default".to_string(),
];
let result = validate_nix_args(&safe_args);
assert!(
result.is_ok(),
"Should allow safe arguments: {:?}",
safe_args
);
}
}