diff --git a/Cargo.toml b/Cargo.toml index 395f440..a52b16f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,12 @@ [workspace] members = ["crates/*", "packages/*", "xtask"] -exclude = ["crates/pinakes-core/tests/fixtures/test-plugin"] +exclude = [ + "crates/pinakes-core/tests/fixtures/test-plugin", + "examples/plugins/auto-tagger", + "examples/plugins/text-enrichment", + "examples/plugins/subtitle-detector", + "examples/plugins/cbz-comics", +] resolver = "3" [workspace.package] @@ -15,6 +21,12 @@ rust-version = "1.95.0" # follows nightly Rust # while building any package. pinakes-core = { path = "./crates/pinakes-core" } pinakes-plugin-api = { path = "./crates/pinakes-plugin-api" } +pinakes-migrations = { path = "./crates/pinakes-migrations" } +pinakes-types = { path = "./crates/pinakes-types" } +pinakes-metadata = { path = "./crates/pinakes-metadata" } +pinakes-plugin = { path = "./crates/pinakes-plugin" } +pinakes-enrichment = { path = "./crates/pinakes-enrichment" } +pinakes-sync = { path = "./crates/pinakes-sync" } # Pinakes itself is a REST API server. UI and TUI are official visual components # that connect to the server. Using the API documentation, the user can write @@ -27,53 +39,54 @@ pinakes-tui = { path = "./packages/pinakes-tui" } # Other dependencies. Declaring them in the virtual manifests lets use reuse the crates # without having to track individual crate version across different types of crates. This # also includes *dev* dependencies. 
-tokio = { version = "1.50.0", features = ["full"] } +tokio = { version = "1.52.3", features = ["full"] } tokio-util = { version = "0.7.18", features = ["rt"] } serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.149" -toml = "1.0.7" -clap = { version = "4.6.0", features = ["derive", "env"] } +toml = "1.1.2" +clap = { version = "4.6.1", features = ["derive", "env"] } chrono = { version = "0.4.44", features = ["serde"] } -uuid = { version = "1.22.0", features = ["v7", "serde"] } +uuid = { version = "1.23.1", features = ["v7", "serde"] } thiserror = "2.0.18" anyhow = "1.0.102" tracing = "0.1.44" tracing-subscriber = { version = "0.3.23", features = ["env-filter", "json"] } -blake3 = "1.8.3" -rustc-hash = "2.1.1" +blake3 = "1.8.5" +rustc-hash = "2.1.2" ed25519-dalek = { version = "2.2.0", features = ["std"] } -lofty = "0.23.3" +lofty = "0.24.0" lopdf = "0.40.0" epub = "2.1.5" -matroska = "0.30.0" +matroska = "0.30.1" gray_matter = "0.3.2" kamadak-exif = "0.6.1" -rusqlite = { version = "0.37.0", features = ["bundled", "column_decltype"] } -tokio-postgres = { version = "0.7.16", features = [ +rusqlite = { version = "0.39.0", features = ["bundled", "column_decltype"] } +tokio-postgres = { version = "0.7.17", features = [ "with-uuid-1", "with-chrono-0_4", "with-serde_json-1", ] } deadpool-postgres = "0.14.1" -postgres-types = { version = "0.2.12", features = ["derive"] } -postgres-native-tls = "0.5.2" +postgres-types = { version = "0.2.13", features = ["derive"] } +postgres-native-tls = "0.5.3" native-tls = "0.2.18" -refinery = { version = "0.9.0", features = ["rusqlite", "tokio-postgres"] } +refinery = { version = "0.9.1", features = ["tokio-postgres"] } +rusqlite_migration = "2.5.0" walkdir = "2.5.0" notify = { version = "8.2.0", features = ["macos_fsevent"] } -winnow = "1.0.0" -axum = { version = "0.8.8", features = ["macros", "multipart"] } +winnow = "1.0.3" +axum = { version = "0.8.9", features = ["macros", "multipart"] } axum-server = { version 
= "0.8.0" } tower = "0.5.3" -tower-http = { version = "0.6.8", features = ["cors", "trace", "set-header"] } +tower-http = { version = "0.6.11", features = ["cors", "trace", "set-header"] } governor = "0.10.4" tower_governor = "0.8.0" -reqwest = { version = "0.13.2", features = ["json", "query", "blocking"] } +reqwest = { version = "0.13.3", features = ["json", "query", "blocking"] } url = "2.5" ratatui = "0.30.0" crossterm = "0.29.0" -dioxus = { version = "0.7.3", features = ["desktop", "router"] } -dioxus-core = { version = "0.7.3" } +dioxus = { version = "0.7.9", features = ["desktop", "router"] } +dioxus-core = { version = "0.7.9" } async-trait = "0.1.89" futures = "0.3.32" image = { version = "0.25.10", default-features = false, features = [ @@ -84,24 +97,24 @@ image = { version = "0.25.10", default-features = false, features = [ "tiff", "bmp", ] } -pulldown-cmark = "0.13.3" +pulldown-cmark = "0.13.4" ammonia = "4.1.2" argon2 = { version = "0.5.3", features = ["std"] } mime_guess = "2.0.5" regex = "1.12.3" dioxus-free-icons = { version = "0.10.0", features = ["font-awesome-solid"] } rfd = "0.17.2" -gloo-timers = { version = "0.3.0", features = ["futures"] } -rand = "0.10.0" +gloo-timers = { version = "0.4.0", features = ["futures"] } +rand = "0.10.1" moka = { version = "0.12.15", features = ["future"] } urlencoding = "2.1.3" image_hasher = "3.1.1" percent-encoding = "2.3.2" http = "1.4.0" -wasmtime = { version = "43.0.0", features = ["component-model"] } -wit-bindgen = "0.54.0" +wasmtime = { version = "44.0.1", features = ["component-model"] } +wit-bindgen = "0.57.1" tempfile = "3.27.0" -utoipa = { version = "5.4.0", features = ["axum_extras", "uuid", "chrono"] } +utoipa = { version = "5.5.0", features = ["axum_extras", "uuid", "chrono"] } utoipa-axum = { version = "0.2.0" } utoipa-swagger-ui = { version = "9.0.2", features = ["axum"] } http-body-util = "0.1.3" diff --git a/crates/pinakes-enrichment/Cargo.toml b/crates/pinakes-enrichment/Cargo.toml new file 
mode 100644 index 0000000..bffcdc1 --- /dev/null +++ b/crates/pinakes-enrichment/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "pinakes-enrichment" +edition.workspace = true +version.workspace = true +license.workspace = true + +[dependencies] +pinakes-types = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +url = { workspace = true } +chrono = { workspace = true } +uuid = { workspace = true } +async-trait = { workspace = true } +regex = { workspace = true } +urlencoding = { workspace = true } + +[lints] +workspace = true diff --git a/crates/pinakes-enrichment/src/books.rs b/crates/pinakes-enrichment/src/books.rs new file mode 100644 index 0000000..63b09c6 --- /dev/null +++ b/crates/pinakes-enrichment/src/books.rs @@ -0,0 +1,298 @@ +use std::sync::LazyLock; + +use chrono::Utc; +use pinakes_types::{ + error::{PinakesError, Result}, + model::MediaItem, +}; +use uuid::Uuid; + +use super::{ + EnrichmentSourceType, + ExternalMetadata, + MetadataEnricher, + googlebooks::GoogleBooksClient, + openlibrary::OpenLibraryClient, +}; + +// --- ISBN helper (duplicated from pinakes-core::books to avoid circular dep) +// --- +static ISBN_PATTERNS: LazyLock> = LazyLock::new(|| { + [ + r"ISBN(?:-13)?(?:\s+is|:)?\s*(\d{3}-\d{1,5}-\d{1,7}-\d{1,7}-\d)", + r"ISBN(?:-10)?(?:\s+is|:)?\s*(\d{1,5}-\d{1,7}-\d{1,7}-[\dXx])", + r"ISBN(?:-13)?\s+(\d{13})", + r"ISBN(?:-10)?\s+(\d{9}[\dXx])", + r"\b(\d{3}-\d{1,5}-\d{1,7}-\d{1,7}-\d)\b", + r"\b(\d{1,5}-\d{1,7}-\d{1,7}-[\dXx])\b", + ] + .iter() + .filter_map(|p| regex::Regex::new(p).ok()) + .collect() +}); + +fn extract_isbn_from_text(text: &str) -> Option { + for pattern in ISBN_PATTERNS.iter() { + if let Some(captures) = pattern.captures(text) + && let Some(isbn) = captures.get(1) + { + return Some(isbn.as_str().to_string()); + } + } + None +} + +/// Book enricher that tries `OpenLibrary` first, then falls back to 
Google +/// Books +pub struct BookEnricher { + openlibrary: OpenLibraryClient, + googlebooks: GoogleBooksClient, +} + +impl BookEnricher { + #[must_use] + pub fn new(google_api_key: Option) -> Self { + Self { + openlibrary: OpenLibraryClient::new(), + googlebooks: GoogleBooksClient::new(google_api_key), + } + } + + /// Try to enrich from `OpenLibrary` first + /// + /// # Errors + /// + /// Returns an error if the metadata cannot be serialized. + pub async fn try_openlibrary( + &self, + isbn: &str, + ) -> Result> { + match self.openlibrary.fetch_by_isbn(isbn).await { + Ok(book) => { + let metadata_json = serde_json::to_string(&book).map_err(|e| { + PinakesError::External(format!("Failed to serialize metadata: {e}")) + })?; + + Ok(Some(ExternalMetadata { + id: Uuid::new_v4(), + media_id: pinakes_types::model::MediaId(Uuid::nil()), /* Will be set by caller */ + source: EnrichmentSourceType::OpenLibrary, + external_id: None, + metadata_json, + confidence: calculate_openlibrary_confidence(&book), + last_updated: Utc::now(), + })) + }, + Err(_) => Ok(None), + } + } + + /// Try to enrich from Google Books + /// + /// # Errors + /// + /// Returns an error if the metadata cannot be serialized. 
+ pub async fn try_googlebooks( + &self, + isbn: &str, + ) -> Result> { + match self.googlebooks.fetch_by_isbn(isbn).await { + Ok(books) if !books.is_empty() => { + let book = &books[0]; + let metadata_json = serde_json::to_string(book).map_err(|e| { + PinakesError::External(format!("Failed to serialize metadata: {e}")) + })?; + + Ok(Some(ExternalMetadata { + id: Uuid::new_v4(), + media_id: pinakes_types::model::MediaId(Uuid::nil()), /* Will be set by caller */ + source: EnrichmentSourceType::GoogleBooks, + external_id: Some(book.id.clone()), + metadata_json, + confidence: calculate_googlebooks_confidence(&book.volume_info), + last_updated: Utc::now(), + })) + }, + _ => Ok(None), + } + } + + /// Try to enrich by searching with title and author + /// + /// # Errors + /// + /// Returns an error if the metadata cannot be serialized. + pub async fn enrich_by_search( + &self, + title: &str, + author: Option<&str>, + ) -> Result> { + // Try OpenLibrary search first + if let Ok(results) = self.openlibrary.search(title, author).await + && let Some(result) = results.first() + { + let metadata_json = serde_json::to_string(result).map_err(|e| { + PinakesError::External(format!("Failed to serialize metadata: {e}")) + })?; + + return Ok(Some(ExternalMetadata { + id: Uuid::new_v4(), + media_id: pinakes_types::model::MediaId(Uuid::nil()), + source: EnrichmentSourceType::OpenLibrary, + external_id: result.key.clone(), + metadata_json, + confidence: 0.6, // Lower confidence for search results + last_updated: Utc::now(), + })); + } + + // Fall back to Google Books + if let Ok(results) = self.googlebooks.search(title, author).await + && let Some(book) = results.first() + { + let metadata_json = serde_json::to_string(book).map_err(|e| { + PinakesError::External(format!("Failed to serialize metadata: {e}")) + })?; + + return Ok(Some(ExternalMetadata { + id: Uuid::new_v4(), + media_id: pinakes_types::model::MediaId(Uuid::nil()), + source: EnrichmentSourceType::GoogleBooks, + 
external_id: Some(book.id.clone()), + metadata_json, + confidence: 0.6, + last_updated: Utc::now(), + })); + } + + Ok(None) + } +} + +#[async_trait::async_trait] +impl MetadataEnricher for BookEnricher { + fn source(&self) -> EnrichmentSourceType { + // Returns the preferred source + EnrichmentSourceType::OpenLibrary + } + + async fn enrich(&self, item: &MediaItem) -> Result> { + // Try ISBN-based enrichment first by checking title/description for ISBN + // patterns + if let Some(ref title) = item.title { + if let Some(isbn) = extract_isbn_from_text(title) { + if let Some(mut metadata) = self.try_openlibrary(&isbn).await? { + metadata.media_id = item.id; + return Ok(Some(metadata)); + } + if let Some(mut metadata) = self.try_googlebooks(&isbn).await? { + metadata.media_id = item.id; + return Ok(Some(metadata)); + } + } + + // Fall back to title/author search + let author = item.artist.as_deref(); + return self.enrich_by_search(title, author).await; + } + + // No title available + Ok(None) + } +} + +/// Calculate confidence score for `OpenLibrary` metadata +#[must_use] +pub fn calculate_openlibrary_confidence( + book: &super::openlibrary::OpenLibraryBook, +) -> f64 { + let mut score: f64 = 0.5; // Base score + + if book.title.is_some() { + score += 0.1; + } + if !book.authors.is_empty() { + score += 0.1; + } + if !book.publishers.is_empty() { + score += 0.05; + } + if book.publish_date.is_some() { + score += 0.05; + } + if book.description.is_some() { + score += 0.1; + } + if !book.covers.is_empty() { + score += 0.1; + } + + score.min(1.0) +} + +/// Calculate confidence score for Google Books metadata +#[must_use] +pub fn calculate_googlebooks_confidence( + info: &super::googlebooks::VolumeInfo, +) -> f64 { + let mut score: f64 = 0.5; // Base score + + if info.title.is_some() { + score += 0.1; + } + if !info.authors.is_empty() { + score += 0.1; + } + if info.publisher.is_some() { + score += 0.05; + } + if info.published_date.is_some() { + score += 0.05; + } + if 
info.description.is_some() { + score += 0.1; + } + if info.image_links.is_some() { + score += 0.1; + } + + score.min(1.0) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_openlibrary_confidence_calculation() { + let book = super::super::openlibrary::OpenLibraryBook { + title: Some("Test Book".to_string()), + subtitle: None, + authors: vec![], + publishers: vec![], + publish_date: None, + number_of_pages: None, + subjects: vec![], + covers: vec![], + isbn_10: vec![], + isbn_13: vec![], + series: vec![], + description: None, + languages: vec![], + }; + + let confidence = calculate_openlibrary_confidence(&book); + assert_eq!(confidence, 0.6); // 0.5 base + 0.1 for title + } + + #[test] + fn test_googlebooks_confidence_calculation() { + let info = super::super::googlebooks::VolumeInfo { + title: Some("Test Book".to_string()), + ..Default::default() + }; + + let confidence = calculate_googlebooks_confidence(&info); + assert_eq!(confidence, 0.6); // 0.5 base + 0.1 for title + } +} diff --git a/crates/pinakes-enrichment/src/googlebooks.rs b/crates/pinakes-enrichment/src/googlebooks.rs new file mode 100644 index 0000000..1b4abe2 --- /dev/null +++ b/crates/pinakes-enrichment/src/googlebooks.rs @@ -0,0 +1,294 @@ +use std::fmt::Write as _; + +use pinakes_types::error::{PinakesError, Result}; +use serde::{Deserialize, Serialize}; + +/// Google Books API client for book metadata enrichment +pub struct GoogleBooksClient { + client: reqwest::Client, + api_key: Option, +} + +impl GoogleBooksClient { + /// Create a new `GoogleBooksClient`. + #[must_use] + pub fn new(api_key: Option) -> Self { + let client = reqwest::Client::builder() + .user_agent("Pinakes/1.0") + .timeout(std::time::Duration::from_secs(10)) + .build() + .unwrap_or_else(|_| reqwest::Client::new()); + Self { client, api_key } + } + + /// Fetch book metadata by ISBN + /// + /// # Errors + /// + /// Returns an error if the HTTP request fails or the response cannot be + /// parsed. 
+ pub async fn fetch_by_isbn(&self, isbn: &str) -> Result> { + let mut url = + format!("https://www.googleapis.com/books/v1/volumes?q=isbn:{isbn}"); + + if let Some(ref key) = self.api_key { + let _ = write!(url, "&key={key}"); + } + + let response = self.client.get(&url).send().await.map_err(|e| { + PinakesError::External(format!("Google Books request failed: {e}")) + })?; + + if !response.status().is_success() { + return Err(PinakesError::External(format!( + "Google Books returned status: {}", + response.status() + ))); + } + + let volumes: GoogleBooksResponse = response.json().await.map_err(|e| { + PinakesError::External(format!( + "Failed to parse Google Books response: {e}" + )) + })?; + + Ok(volumes.items) + } + + /// Search for books by title and author + /// + /// # Errors + /// + /// Returns an error if the HTTP request fails or the response cannot be + /// parsed. + pub async fn search( + &self, + title: &str, + author: Option<&str>, + ) -> Result> { + let mut query = format!("intitle:{}", urlencoding::encode(title)); + + if let Some(author) = author { + let _ = write!(query, "+inauthor:{}", urlencoding::encode(author)); + } + + let mut url = format!( + "https://www.googleapis.com/books/v1/volumes?q={query}&maxResults=5" + ); + + if let Some(ref key) = self.api_key { + let _ = write!(url, "&key={key}"); + } + + let response = self.client.get(&url).send().await.map_err(|e| { + PinakesError::External(format!("Google Books search failed: {e}")) + })?; + + if !response.status().is_success() { + return Err(PinakesError::External(format!( + "Google Books search returned status: {}", + response.status() + ))); + } + + let volumes: GoogleBooksResponse = response.json().await.map_err(|e| { + PinakesError::External(format!("Failed to parse search results: {e}")) + })?; + + Ok(volumes.items) + } + + /// Download cover image from Google Books + /// + /// # Errors + /// + /// Returns an error if the HTTP request fails or the response cannot be + /// read. 
+ pub async fn fetch_cover(&self, image_link: &str) -> Result> { + // Replace thumbnail link with higher resolution if possible + let high_res_link = image_link + .replace("&zoom=1", "&zoom=2") + .replace("&edge=curl", ""); + + let response = + self.client.get(&high_res_link).send().await.map_err(|e| { + PinakesError::External(format!("Cover download failed: {e}")) + })?; + + if !response.status().is_success() { + return Err(PinakesError::External(format!( + "Cover download returned status: {}", + response.status() + ))); + } + + response.bytes().await.map(|b| b.to_vec()).map_err(|e| { + PinakesError::External(format!("Failed to read cover data: {e}")) + }) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GoogleBooksResponse { + #[serde(default)] + pub items: Vec, + + #[serde(default)] + pub total_items: i32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GoogleBook { + pub id: String, + + #[serde(default)] + pub volume_info: VolumeInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct VolumeInfo { + #[serde(default)] + pub title: Option, + + #[serde(default)] + pub subtitle: Option, + + #[serde(default)] + pub authors: Vec, + + #[serde(default)] + pub publisher: Option, + + #[serde(default)] + pub published_date: Option, + + #[serde(default)] + pub description: Option, + + #[serde(default)] + pub page_count: Option, + + #[serde(default)] + pub categories: Vec, + + #[serde(default)] + pub average_rating: Option, + + #[serde(default)] + pub ratings_count: Option, + + #[serde(default)] + pub image_links: Option, + + #[serde(default)] + pub language: Option, + + #[serde(default)] + pub industry_identifiers: Vec, + + #[serde(default)] + pub main_category: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImageLinks { + #[serde(default)] + pub small_thumbnail: Option, + + #[serde(default)] + pub thumbnail: Option, + + #[serde(default)] + pub small: Option, + + #[serde(default)] + 
pub medium: Option, + + #[serde(default)] + pub large: Option, + + #[serde(default)] + pub extra_large: Option, +} + +impl ImageLinks { + /// Get the best available image link (highest resolution) + #[must_use] + pub fn best_link(&self) -> Option<&String> { + self + .extra_large + .as_ref() + .or(self.large.as_ref()) + .or(self.medium.as_ref()) + .or(self.small.as_ref()) + .or(self.thumbnail.as_ref()) + .or(self.small_thumbnail.as_ref()) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct IndustryIdentifier { + #[serde(rename = "type")] + pub identifier_type: String, + + pub identifier: String, +} + +impl IndustryIdentifier { + /// Check if this is an ISBN-13 + #[must_use] + pub fn is_isbn13(&self) -> bool { + self.identifier_type == "ISBN_13" + } + + /// Check if this is an ISBN-10 + #[must_use] + pub fn is_isbn10(&self) -> bool { + self.identifier_type == "ISBN_10" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_googlebooks_client_creation() { + let client = GoogleBooksClient::new(None); + assert!(client.api_key.is_none()); + + let client_with_key = GoogleBooksClient::new(Some("test-key".to_string())); + assert_eq!(client_with_key.api_key, Some("test-key".to_string())); + } + + #[test] + fn test_image_links_best_link() { + let links = ImageLinks { + small_thumbnail: Some("small.jpg".to_string()), + thumbnail: Some("thumb.jpg".to_string()), + small: None, + medium: Some("medium.jpg".to_string()), + large: Some("large.jpg".to_string()), + extra_large: None, + }; + + assert_eq!(links.best_link(), Some(&"large.jpg".to_string())); + } + + #[test] + fn test_industry_identifier_type_checks() { + let isbn13 = IndustryIdentifier { + identifier_type: "ISBN_13".to_string(), + identifier: "9780123456789".to_string(), + }; + assert!(isbn13.is_isbn13()); + assert!(!isbn13.is_isbn10()); + + let isbn10 = IndustryIdentifier { + identifier_type: "ISBN_10".to_string(), + identifier: "0123456789".to_string(), + }; + 
assert!(!isbn10.is_isbn13()); + assert!(isbn10.is_isbn10()); + } +} diff --git a/crates/pinakes-enrichment/src/lastfm.rs b/crates/pinakes-enrichment/src/lastfm.rs new file mode 100644 index 0000000..cdcb4bd --- /dev/null +++ b/crates/pinakes-enrichment/src/lastfm.rs @@ -0,0 +1,116 @@ +//! Last.fm metadata enrichment for audio files. + +use std::time::Duration; + +use chrono::Utc; +use pinakes_types::{ + error::{PinakesError, Result}, + model::MediaItem, +}; +use uuid::Uuid; + +use super::{EnrichmentSourceType, ExternalMetadata, MetadataEnricher}; + +pub struct LastFmEnricher { + client: reqwest::Client, + api_key: String, + base_url: String, +} + +impl LastFmEnricher { + /// Create a new `LastFmEnricher`. + #[must_use] + pub fn new(api_key: String) -> Self { + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(10)) + .connect_timeout(Duration::from_secs(5)) + .build() + .unwrap_or_else(|_| reqwest::Client::new()); + Self { + client, + api_key, + base_url: "https://ws.audioscrobbler.com/2.0".to_string(), + } + } +} + +#[async_trait::async_trait] +impl MetadataEnricher for LastFmEnricher { + fn source(&self) -> EnrichmentSourceType { + EnrichmentSourceType::LastFm + } + + async fn enrich(&self, item: &MediaItem) -> Result> { + let artist = match &item.artist { + Some(a) if !a.is_empty() => a, + _ => return Ok(None), + }; + + let title = match &item.title { + Some(t) if !t.is_empty() => t, + _ => return Ok(None), + }; + + let url = format!("{}/", self.base_url); + + let resp = self + .client + .get(&url) + .query(&[ + ("method", "track.getInfo"), + ("api_key", self.api_key.as_str()), + ("artist", artist.as_str()), + ("track", title.as_str()), + ("format", "json"), + ]) + .send() + .await + .map_err(|e| { + PinakesError::MetadataExtraction(format!("Last.fm request failed: {e}")) + })?; + + if !resp.status().is_success() { + return Ok(None); + } + + let body = resp.text().await.map_err(|e| { + PinakesError::MetadataExtraction(format!( + "Last.fm 
response read failed: {e}" + )) + })?; + + let json: serde_json::Value = serde_json::from_str(&body).map_err(|e| { + PinakesError::MetadataExtraction(format!( + "Last.fm JSON parse failed: {e}" + )) + })?; + + // Check for error response + if json.get("error").is_some() { + return Ok(None); + } + + let Some(track) = json.get("track") else { + return Ok(None); + }; + + let mbid = track.get("mbid").and_then(|m| m.as_str()).map(String::from); + let listeners = track + .get("listeners") + .and_then(|l| l.as_str()) + .and_then(|l| l.parse::().ok()) + .unwrap_or(0.0); + // Normalize listeners to confidence (arbitrary scale) + let confidence = (listeners / 1_000_000.0).min(1.0); + + Ok(Some(ExternalMetadata { + id: Uuid::now_v7(), + media_id: item.id, + source: EnrichmentSourceType::LastFm, + external_id: mbid, + metadata_json: body, + confidence, + last_updated: Utc::now(), + })) + } +} diff --git a/crates/pinakes-enrichment/src/lib.rs b/crates/pinakes-enrichment/src/lib.rs new file mode 100644 index 0000000..ed6c3dc --- /dev/null +++ b/crates/pinakes-enrichment/src/lib.rs @@ -0,0 +1,76 @@ +pub mod books; +pub mod googlebooks; +pub mod lastfm; +pub mod musicbrainz; +pub mod openlibrary; +pub mod tmdb; + +use chrono::{DateTime, Utc}; +use pinakes_types::{ + error::Result, + model::{MediaId, MediaItem}, +}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// Externally-sourced metadata for a media item. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExternalMetadata { + pub id: Uuid, + pub media_id: MediaId, + pub source: EnrichmentSourceType, + pub external_id: Option, + pub metadata_json: String, + pub confidence: f64, + pub last_updated: DateTime, +} + +/// Supported enrichment data sources. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum EnrichmentSourceType { + #[serde(rename = "musicbrainz")] + MusicBrainz, + #[serde(rename = "tmdb")] + Tmdb, + #[serde(rename = "lastfm")] + LastFm, + #[serde(rename = "openlibrary")] + OpenLibrary, + #[serde(rename = "googlebooks")] + GoogleBooks, +} + +impl std::fmt::Display for EnrichmentSourceType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + Self::MusicBrainz => "musicbrainz", + Self::Tmdb => "tmdb", + Self::LastFm => "lastfm", + Self::OpenLibrary => "openlibrary", + Self::GoogleBooks => "googlebooks", + }; + write!(f, "{s}") + } +} + +impl std::str::FromStr for EnrichmentSourceType { + type Err = String; + + fn from_str(s: &str) -> std::result::Result { + match s { + "musicbrainz" => Ok(Self::MusicBrainz), + "tmdb" => Ok(Self::Tmdb), + "lastfm" => Ok(Self::LastFm), + "openlibrary" => Ok(Self::OpenLibrary), + "googlebooks" => Ok(Self::GoogleBooks), + _ => Err(format!("unknown enrichment source: {s}")), + } + } +} + +/// Trait for metadata enrichment providers. +#[async_trait::async_trait] +pub trait MetadataEnricher: Send + Sync { + fn source(&self) -> EnrichmentSourceType; + async fn enrich(&self, item: &MediaItem) -> Result>; +} diff --git a/crates/pinakes-enrichment/src/mod.rs b/crates/pinakes-enrichment/src/mod.rs new file mode 100644 index 0000000..527601d --- /dev/null +++ b/crates/pinakes-enrichment/src/mod.rs @@ -0,0 +1,79 @@ +//! Metadata enrichment from external sources. + +pub mod books; +pub mod googlebooks; +pub mod lastfm; +pub mod musicbrainz; +pub mod openlibrary; +pub mod tmdb; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use pinakes_types::{ + error::Result, + model::{MediaId, MediaItem}, +}; + +/// Externally-sourced metadata for a media item. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExternalMetadata { + pub id: Uuid, + pub media_id: MediaId, + pub source: EnrichmentSourceType, + pub external_id: Option, + pub metadata_json: String, + pub confidence: f64, + pub last_updated: DateTime, +} + +/// Supported enrichment data sources. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum EnrichmentSourceType { + #[serde(rename = "musicbrainz")] + MusicBrainz, + #[serde(rename = "tmdb")] + Tmdb, + #[serde(rename = "lastfm")] + LastFm, + #[serde(rename = "openlibrary")] + OpenLibrary, + #[serde(rename = "googlebooks")] + GoogleBooks, +} + +impl std::fmt::Display for EnrichmentSourceType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + Self::MusicBrainz => "musicbrainz", + Self::Tmdb => "tmdb", + Self::LastFm => "lastfm", + Self::OpenLibrary => "openlibrary", + Self::GoogleBooks => "googlebooks", + }; + write!(f, "{s}") + } +} + +impl std::str::FromStr for EnrichmentSourceType { + type Err = String; + + fn from_str(s: &str) -> std::result::Result { + match s { + "musicbrainz" => Ok(Self::MusicBrainz), + "tmdb" => Ok(Self::Tmdb), + "lastfm" => Ok(Self::LastFm), + "openlibrary" => Ok(Self::OpenLibrary), + "googlebooks" => Ok(Self::GoogleBooks), + _ => Err(format!("unknown enrichment source: {s}")), + } + } +} + +/// Trait for metadata enrichment providers. +#[async_trait::async_trait] +pub trait MetadataEnricher: Send + Sync { + fn source(&self) -> EnrichmentSourceType; + async fn enrich(&self, item: &MediaItem) -> Result>; +} diff --git a/crates/pinakes-enrichment/src/musicbrainz.rs b/crates/pinakes-enrichment/src/musicbrainz.rs new file mode 100644 index 0000000..344e6f0 --- /dev/null +++ b/crates/pinakes-enrichment/src/musicbrainz.rs @@ -0,0 +1,148 @@ +//! `MusicBrainz` metadata enrichment for audio files. 
/// Backslash-escape every Lucene query-syntax character in `s` so the text
/// is treated literally; all other characters are copied unchanged.
fn escape_lucene_query(s: &str) -> String {
  // Worst case every character gains a backslash.
  let mut out = String::with_capacity(s.len() * 2);
  for ch in s.chars() {
    let is_special = matches!(
      ch,
      '+'
        | '-'
        | '&'
        | '|'
        | '!'
        | '('
        | ')'
        | '{'
        | '}'
        | '['
        | ']'
        | '^'
        | '"'
        | '~'
        | '*'
        | '?'
        | ':'
        | '\\'
        | '/'
    );
    if is_special {
      out.push('\\');
    }
    out.push(ch);
  }
  out
}
"MusicBrainz request failed: {e}" + )) + })?; + + if !resp.status().is_success() { + let status = resp.status(); + if status == reqwest::StatusCode::TOO_MANY_REQUESTS + || status == reqwest::StatusCode::SERVICE_UNAVAILABLE + { + return Err(PinakesError::MetadataExtraction(format!( + "MusicBrainz rate limited (HTTP {})", + status.as_u16() + ))); + } + return Ok(None); + } + + let body = resp.text().await.map_err(|e| { + PinakesError::MetadataExtraction(format!( + "MusicBrainz response read failed: {e}" + )) + })?; + + // Parse to check if we got results + let json: serde_json::Value = serde_json::from_str(&body).map_err(|e| { + PinakesError::MetadataExtraction(format!( + "MusicBrainz JSON parse failed: {e}" + )) + })?; + + let recordings = json.get("recordings").and_then(|r| r.as_array()); + if recordings.is_none_or(std::vec::Vec::is_empty) { + return Ok(None); + } + + let Some(recordings) = recordings else { + return Ok(None); + }; + let recording = &recordings[0]; + let external_id = recording + .get("id") + .and_then(|id| id.as_str()) + .map(String::from); + let score = (recording + .get("score") + .and_then(serde_json::Value::as_f64) + .unwrap_or(0.0) + / 100.0) + .min(1.0); + + Ok(Some(ExternalMetadata { + id: Uuid::now_v7(), + media_id: item.id, + source: EnrichmentSourceType::MusicBrainz, + external_id, + metadata_json: body, + confidence: score, + last_updated: Utc::now(), + })) + } +} diff --git a/crates/pinakes-enrichment/src/openlibrary.rs b/crates/pinakes-enrichment/src/openlibrary.rs new file mode 100644 index 0000000..0dd4db7 --- /dev/null +++ b/crates/pinakes-enrichment/src/openlibrary.rs @@ -0,0 +1,307 @@ +use std::fmt::Write as _; + +use pinakes_types::error::{PinakesError, Result}; +use serde::{Deserialize, Serialize}; + +/// `OpenLibrary` API client for book metadata enrichment +pub struct OpenLibraryClient { + client: reqwest::Client, + base_url: String, +} + +impl Default for OpenLibraryClient { + fn default() -> Self { + Self::new() + } +} + 
+impl OpenLibraryClient { + /// Create a new `OpenLibraryClient`. + #[must_use] + pub fn new() -> Self { + let client = reqwest::Client::builder() + .user_agent("Pinakes/1.0") + .timeout(std::time::Duration::from_secs(10)) + .build() + .unwrap_or_else(|_| reqwest::Client::new()); + Self { + client, + base_url: "https://openlibrary.org".to_string(), + } + } + + /// Fetch book metadata by ISBN + /// + /// # Errors + /// + /// Returns an error if the HTTP request fails or the response cannot be + /// parsed. + pub async fn fetch_by_isbn(&self, isbn: &str) -> Result { + let url = format!("{}/isbn/{}.json", self.base_url, isbn); + + let response = self.client.get(&url).send().await.map_err(|e| { + PinakesError::External(format!("OpenLibrary request failed: {e}")) + })?; + + if !response.status().is_success() { + return Err(PinakesError::External(format!( + "OpenLibrary returned status: {}", + response.status() + ))); + } + + response.json::().await.map_err(|e| { + PinakesError::External(format!( + "Failed to parse OpenLibrary response: {e}" + )) + }) + } + + /// Search for books by title and author + /// + /// # Errors + /// + /// Returns an error if the HTTP request fails or the response cannot be + /// parsed. 
+ pub async fn search( + &self, + title: &str, + author: Option<&str>, + ) -> Result> { + let mut url = format!( + "{}/search.json?title={}", + self.base_url, + urlencoding::encode(title) + ); + + if let Some(author) = author { + let _ = write!(url, "&author={}", urlencoding::encode(author)); + } + + url.push_str("&limit=5"); + + let response = self.client.get(&url).send().await.map_err(|e| { + PinakesError::External(format!("OpenLibrary search failed: {e}")) + })?; + + if !response.status().is_success() { + return Err(PinakesError::External(format!( + "OpenLibrary search returned status: {}", + response.status() + ))); + } + + let search_response: OpenLibrarySearchResponse = + response.json().await.map_err(|e| { + PinakesError::External(format!("Failed to parse search results: {e}")) + })?; + + Ok(search_response.docs) + } + + /// Fetch cover image by cover ID + /// + /// # Errors + /// + /// Returns an error if the HTTP request fails or the response cannot be + /// read. + pub async fn fetch_cover( + &self, + cover_id: i64, + size: CoverSize, + ) -> Result> { + let size_str = match size { + CoverSize::Small => "S", + CoverSize::Medium => "M", + CoverSize::Large => "L", + }; + + let url = + format!("https://covers.openlibrary.org/b/id/{cover_id}-{size_str}.jpg"); + + let response = self.client.get(&url).send().await.map_err(|e| { + PinakesError::External(format!("Cover download failed: {e}")) + })?; + + if !response.status().is_success() { + return Err(PinakesError::External(format!( + "Cover download returned status: {}", + response.status() + ))); + } + + response.bytes().await.map(|b| b.to_vec()).map_err(|e| { + PinakesError::External(format!("Failed to read cover data: {e}")) + }) + } + + /// Fetch cover by ISBN + /// + /// # Errors + /// + /// Returns an error if the HTTP request fails or the response cannot be + /// read. 
+ pub async fn fetch_cover_by_isbn( + &self, + isbn: &str, + size: CoverSize, + ) -> Result> { + let size_str = match size { + CoverSize::Small => "S", + CoverSize::Medium => "M", + CoverSize::Large => "L", + }; + + let url = + format!("https://covers.openlibrary.org/b/isbn/{isbn}-{size_str}.jpg"); + + let response = self.client.get(&url).send().await.map_err(|e| { + PinakesError::External(format!("Cover download failed: {e}")) + })?; + + if !response.status().is_success() { + return Err(PinakesError::External(format!( + "Cover download returned status: {}", + response.status() + ))); + } + + response.bytes().await.map(|b| b.to_vec()).map_err(|e| { + PinakesError::External(format!("Failed to read cover data: {e}")) + }) + } +} + +#[derive(Debug, Clone, Copy)] +pub enum CoverSize { + Small, // 256x256 + Medium, // 600x800 + Large, // Original +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenLibraryBook { + #[serde(default)] + pub title: Option, + + #[serde(default)] + pub subtitle: Option, + + #[serde(default)] + pub authors: Vec, + + #[serde(default)] + pub publishers: Vec, + + #[serde(default)] + pub publish_date: Option, + + #[serde(default)] + pub number_of_pages: Option, + + #[serde(default)] + pub subjects: Vec, + + #[serde(default)] + pub covers: Vec, + + #[serde(default)] + pub isbn_10: Vec, + + #[serde(default)] + pub isbn_13: Vec, + + #[serde(default)] + pub series: Vec, + + #[serde(default)] + pub description: Option, + + #[serde(default)] + pub languages: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthorRef { + pub key: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LanguageRef { + pub key: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum StringOrObject { + String(String), + Object { value: String }, +} + +impl StringOrObject { + #[must_use] + pub fn as_str(&self) -> &str { + match self { + Self::String(s) => s, + Self::Object { value } 
=> value, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenLibrarySearchResponse { + #[serde(default)] + pub docs: Vec, + + #[serde(default)] + pub num_found: i32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenLibrarySearchResult { + #[serde(default)] + pub key: Option, + + #[serde(default)] + pub title: Option, + + #[serde(default)] + pub author_name: Vec, + + #[serde(default)] + pub first_publish_year: Option, + + #[serde(default)] + pub publisher: Vec, + + #[serde(default)] + pub isbn: Vec, + + #[serde(default)] + pub cover_i: Option, + + #[serde(default)] + pub subject: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_openlibrary_client_creation() { + let client = OpenLibraryClient::new(); + assert_eq!(client.base_url, "https://openlibrary.org"); + } + + #[test] + fn test_string_or_object_parsing() { + let string_desc: StringOrObject = + serde_json::from_str(r#""Simple description""#).unwrap(); + assert_eq!(string_desc.as_str(), "Simple description"); + + let object_desc: StringOrObject = + serde_json::from_str(r#"{"value": "Object description"}"#).unwrap(); + assert_eq!(object_desc.as_str(), "Object description"); + } +} diff --git a/crates/pinakes-enrichment/src/tmdb.rs b/crates/pinakes-enrichment/src/tmdb.rs new file mode 100644 index 0000000..810db2e --- /dev/null +++ b/crates/pinakes-enrichment/src/tmdb.rs @@ -0,0 +1,125 @@ +//! TMDB (The Movie Database) metadata enrichment for video files. + +use std::time::Duration; + +use chrono::Utc; +use pinakes_types::{ + error::{PinakesError, Result}, + model::MediaItem, +}; +use uuid::Uuid; + +use super::{EnrichmentSourceType, ExternalMetadata, MetadataEnricher}; + +pub struct TmdbEnricher { + client: reqwest::Client, + api_key: String, + base_url: String, +} + +impl TmdbEnricher { + /// Create a new `TMDb` enricher. 
+ /// + /// # Panics + /// + /// Panics if the HTTP client cannot be built (programming error in client + /// configuration). + #[must_use] + pub fn new(api_key: String) -> Self { + Self { + client: reqwest::Client::builder() + .timeout(Duration::from_secs(10)) + .connect_timeout(Duration::from_secs(5)) + .build() + .expect("failed to build HTTP client with configured timeouts"), + api_key, + base_url: "https://api.themoviedb.org/3".to_string(), + } + } +} + +#[async_trait::async_trait] +impl MetadataEnricher for TmdbEnricher { + fn source(&self) -> EnrichmentSourceType { + EnrichmentSourceType::Tmdb + } + + async fn enrich(&self, item: &MediaItem) -> Result> { + let title = match &item.title { + Some(t) if !t.is_empty() => t, + _ => return Ok(None), + }; + + let url = format!("{}/search/movie", self.base_url); + + let resp = self + .client + .get(&url) + .query(&[ + ("api_key", &self.api_key), + ("query", &title.clone()), + ("page", &"1".to_string()), + ]) + .send() + .await + .map_err(|e| { + PinakesError::MetadataExtraction(format!("TMDB request failed: {e}")) + })?; + + if !resp.status().is_success() { + let status = resp.status(); + if status == reqwest::StatusCode::UNAUTHORIZED { + return Err(PinakesError::MetadataExtraction( + "TMDB API key is invalid (401)".into(), + )); + } + if status == reqwest::StatusCode::TOO_MANY_REQUESTS { + tracing::warn!("TMDB rate limit exceeded (429)"); + return Ok(None); + } + tracing::debug!(status = %status, "TMDB search returned non-success status"); + return Ok(None); + } + + let body = resp.text().await.map_err(|e| { + PinakesError::MetadataExtraction(format!( + "TMDB response read failed: {e}" + )) + })?; + + let json: serde_json::Value = serde_json::from_str(&body).map_err(|e| { + PinakesError::MetadataExtraction(format!("TMDB JSON parse failed: {e}")) + })?; + + let results = json.get("results").and_then(|r| r.as_array()); + if results.is_none_or(std::vec::Vec::is_empty) { + return Ok(None); + } + + let Some(results) = 
results else { + return Ok(None); + }; + let movie = &results[0]; + let external_id = match movie.get("id").and_then(serde_json::Value::as_i64) + { + Some(id) => id.to_string(), + None => return Ok(None), + }; + let popularity = movie + .get("popularity") + .and_then(serde_json::Value::as_f64) + .unwrap_or(0.0); + // Normalize popularity to 0-1 range (TMDB popularity can be very high) + let confidence = (popularity / 100.0).min(1.0); + + Ok(Some(ExternalMetadata { + id: Uuid::now_v7(), + media_id: item.id, + source: EnrichmentSourceType::Tmdb, + external_id: Some(external_id), + metadata_json: body, + confidence, + last_updated: Utc::now(), + })) + } +} diff --git a/crates/pinakes-metadata/Cargo.toml b/crates/pinakes-metadata/Cargo.toml new file mode 100644 index 0000000..5c26a23 --- /dev/null +++ b/crates/pinakes-metadata/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "pinakes-metadata" +edition.workspace = true +version.workspace = true +license.workspace = true + +[dependencies] +pinakes-types = { workspace = true } +lofty = { workspace = true } +lopdf = { workspace = true } +epub = { workspace = true } +matroska = { workspace = true } +image = { workspace = true } +kamadak-exif = { workspace = true } +gray_matter = { workspace = true } +rustc-hash = { workspace = true } +chrono = { workspace = true } +image_hasher = { workspace = true } +tracing = { workspace = true } +regex = { workspace = true } + +[lints] +workspace = true diff --git a/crates/pinakes-metadata/src/audio.rs b/crates/pinakes-metadata/src/audio.rs new file mode 100644 index 0000000..b33ff3f --- /dev/null +++ b/crates/pinakes-metadata/src/audio.rs @@ -0,0 +1,91 @@ +use std::path::Path; + +use lofty::{ + file::{AudioFile, TaggedFileExt}, + tag::Accessor, +}; +use pinakes_types::{ + error::{PinakesError, Result}, + media_type::{BuiltinMediaType, MediaType}, +}; + +use super::{ExtractedMetadata, MetadataExtractor}; + +pub struct AudioExtractor; + +impl MetadataExtractor for AudioExtractor { + fn 
extract(&self, path: &Path) -> Result { + let tagged_file = lofty::read_from_path(path).map_err(|e| { + PinakesError::MetadataExtraction(format!("audio metadata: {e}")) + })?; + + let mut meta = ExtractedMetadata::default(); + + if let Some(tag) = tagged_file + .primary_tag() + .or_else(|| tagged_file.first_tag()) + { + meta.title = tag.title().map(|s| s.to_string()); + meta.artist = tag.artist().map(|s| s.to_string()); + meta.album = tag.album().map(|s| s.to_string()); + meta.genre = tag.genre().map(|s| s.to_string()); + meta.year = tag.date().map(|ts| i32::from(ts.year)); + } + + if let Some(tag) = tagged_file + .primary_tag() + .or_else(|| tagged_file.first_tag()) + { + if let Some(track) = tag.track() { + meta + .extra + .insert("track_number".to_string(), track.to_string()); + } + if let Some(disc) = tag.disk() { + meta + .extra + .insert("disc_number".to_string(), disc.to_string()); + } + if let Some(comment) = tag.comment() { + meta + .extra + .insert("comment".to_string(), comment.to_string()); + } + } + + let properties = tagged_file.properties(); + let duration = properties.duration(); + if !duration.is_zero() { + meta.duration_secs = Some(duration.as_secs_f64()); + } + + if let Some(bitrate) = properties.audio_bitrate() { + meta + .extra + .insert("bitrate".to_string(), format!("{bitrate} kbps")); + } + if let Some(sample_rate) = properties.sample_rate() { + meta + .extra + .insert("sample_rate".to_string(), format!("{sample_rate} Hz")); + } + if let Some(channels) = properties.channels() { + meta + .extra + .insert("channels".to_string(), channels.to_string()); + } + + Ok(meta) + } + + fn supported_types(&self) -> Vec { + vec![ + MediaType::Builtin(BuiltinMediaType::Mp3), + MediaType::Builtin(BuiltinMediaType::Flac), + MediaType::Builtin(BuiltinMediaType::Ogg), + MediaType::Builtin(BuiltinMediaType::Wav), + MediaType::Builtin(BuiltinMediaType::Aac), + MediaType::Builtin(BuiltinMediaType::Opus), + ] + } +} diff --git 
a/crates/pinakes-metadata/src/document.rs b/crates/pinakes-metadata/src/document.rs new file mode 100644 index 0000000..0752f2f --- /dev/null +++ b/crates/pinakes-metadata/src/document.rs @@ -0,0 +1,460 @@ +use std::{path::Path, sync::LazyLock}; + +use pinakes_types::{ + error::{PinakesError, Result}, + media_type::{BuiltinMediaType, MediaType}, +}; + +use super::{ExtractedMetadata, MetadataExtractor}; + +// --- ISBN helpers (duplicated from pinakes-core::books to avoid circular dep) +// --- + +static ISBN_PATTERNS: LazyLock> = LazyLock::new(|| { + [ + r"ISBN(?:-13)?(?:\s+is|:)?\s*(\d{3}-\d{1,5}-\d{1,7}-\d{1,7}-\d)", + r"ISBN(?:-10)?(?:\s+is|:)?\s*(\d{1,5}-\d{1,7}-\d{1,7}-[\dXx])", + r"ISBN(?:-13)?\s+(\d{13})", + r"ISBN(?:-10)?\s+(\d{9}[\dXx])", + r"\b(\d{3}-\d{1,5}-\d{1,7}-\d{1,7}-\d)\b", + r"\b(\d{1,5}-\d{1,7}-\d{1,7}-[\dXx])\b", + ] + .iter() + .filter_map(|p| regex::Regex::new(p).ok()) + .collect() +}); + +fn extract_isbn_from_text(text: &str) -> Option { + for pattern in ISBN_PATTERNS.iter() { + if let Some(captures) = pattern.captures(text) + && let Some(isbn) = captures.get(1) + && let Ok(normalized) = normalize_isbn(isbn.as_str()) + { + return Some(normalized); + } + } + None +} + +fn normalize_isbn(isbn: &str) -> std::result::Result { + let clean: String = isbn + .chars() + .filter(|c| c.is_ascii_digit() || *c == 'X' || *c == 'x') + .collect(); + + match clean.len() { + 10 => isbn10_to_isbn13(&clean), + 13 => { + if is_valid_isbn13(&clean) { + Ok(clean) + } else { + Err(()) + } + }, + _ => Err(()), + } +} + +fn isbn10_to_isbn13(isbn10: &str) -> std::result::Result { + if isbn10.len() != 10 { + return Err(()); + } + let mut isbn13 = format!("978{}", &isbn10[..9]); + let check_digit = calculate_isbn13_check_digit(&isbn13).ok_or(())?; + isbn13.push_str(&check_digit.to_string()); + Ok(isbn13) +} + +fn calculate_isbn13_check_digit(isbn_without_check: &str) -> Option { + if isbn_without_check.len() != 12 { + return None; + } + let sum: u32 = isbn_without_check + 
.chars() + .enumerate() + .filter_map(|(i, c)| { + c.to_digit(10).map(|d| if i % 2 == 0 { d } else { d * 3 }) + }) + .sum(); + Some((10 - (sum % 10)) % 10) +} + +fn is_valid_isbn13(isbn13: &str) -> bool { + if isbn13.len() != 13 { + return false; + } + let sum: u32 = isbn13 + .chars() + .enumerate() + .filter_map(|(i, c)| { + c.to_digit(10).map(|d| if i % 2 == 0 { d } else { d * 3 }) + }) + .sum(); + sum.is_multiple_of(10) +} + +pub struct DocumentExtractor; + +impl MetadataExtractor for DocumentExtractor { + fn extract(&self, path: &Path) -> Result { + match MediaType::from_path(path) { + Some(MediaType::Builtin(BuiltinMediaType::Pdf)) => extract_pdf(path), + Some(MediaType::Builtin(BuiltinMediaType::Epub)) => extract_epub(path), + Some(MediaType::Builtin(BuiltinMediaType::Djvu)) => extract_djvu(path), + _ => Ok(ExtractedMetadata::default()), + } + } + + fn supported_types(&self) -> Vec { + vec![ + MediaType::Builtin(BuiltinMediaType::Pdf), + MediaType::Builtin(BuiltinMediaType::Epub), + MediaType::Builtin(BuiltinMediaType::Djvu), + ] + } +} + +fn extract_pdf(path: &Path) -> Result { + let doc = lopdf::Document::load(path) + .map_err(|e| PinakesError::MetadataExtraction(format!("PDF load: {e}")))?; + + let mut meta = ExtractedMetadata::default(); + let mut book_meta = pinakes_types::model::BookMetadata::default(); + + // Find the Info dictionary via the trailer + if let Ok(info_ref) = doc.trailer.get(b"Info") { + let info_obj = info_ref + .as_reference() + .map_or(Some(info_ref), |reference| doc.get_object(reference).ok()); + + if let Some(obj) = info_obj + && let Ok(dict) = obj.as_dict() + { + if let Ok(title) = dict.get(b"Title") { + meta.title = pdf_object_to_string(title); + } + if let Ok(author) = dict.get(b"Author") { + let author_str = pdf_object_to_string(author); + meta.artist.clone_from(&author_str); + + // Parse multiple authors if separated by semicolon, comma, or "and" + if let Some(authors_str) = author_str { + book_meta.authors = authors_str + 
.split(&[';', ','][..]) + .flat_map(|part| part.split(" and ")) + .map(|name| name.trim().to_string()) + .filter(|name| !name.is_empty()) + .enumerate() + .map(|(pos, name)| { + let mut author = pinakes_types::model::AuthorInfo::new(name); + author.position = i32::try_from(pos).unwrap_or(i32::MAX); + author + }) + .collect(); + } + } + if let Ok(subject) = dict.get(b"Subject") { + meta.description = pdf_object_to_string(subject); + } + if let Ok(creator) = dict.get(b"Creator") { + meta.extra.insert( + "creator".to_string(), + pdf_object_to_string(creator).unwrap_or_default(), + ); + } + if let Ok(producer) = dict.get(b"Producer") { + meta.extra.insert( + "producer".to_string(), + pdf_object_to_string(producer).unwrap_or_default(), + ); + } + } + } + + // Page count + let pages = doc.get_pages(); + let page_count = pages.len(); + if page_count > 0 { + book_meta.page_count = Some(i32::try_from(page_count).unwrap_or(i32::MAX)); + } + + // Try to extract ISBN from first few pages + // Extract text from up to the first 5 pages and search for ISBN patterns + let mut extracted_text = String::new(); + let max_pages = page_count.min(5); + + for (_page_num, page_id) in pages.iter().take(max_pages) { + if let Ok(content) = doc.get_page_content(*page_id) { + // PDF content streams contain raw operators, but may have text strings + if let Ok(text) = std::str::from_utf8(&content) { + extracted_text.push_str(text); + extracted_text.push(' '); + } + } + } + + // Extract ISBN from the text + if let Some(isbn) = extract_isbn_from_text(&extracted_text) + && let Ok(normalized) = normalize_isbn(&isbn) + { + book_meta.isbn13 = Some(normalized); + book_meta.isbn = Some(isbn); + } + + // Set format + book_meta.format = Some("pdf".to_string()); + + meta.book_metadata = Some(book_meta); + Ok(meta) +} + +fn pdf_object_to_string(obj: &lopdf::Object) -> Option { + match obj { + lopdf::Object::String(bytes, _) => { + Some(String::from_utf8_lossy(bytes).into_owned()) + }, + 
lopdf::Object::Name(name) => { + Some(String::from_utf8_lossy(name).into_owned()) + }, + _ => None, + } +} + +fn extract_epub(path: &Path) -> Result { + let mut doc = epub::doc::EpubDoc::new(path).map_err(|e| { + PinakesError::MetadataExtraction(format!("EPUB parse: {e}")) + })?; + + let mut meta = ExtractedMetadata { + title: doc.mdata("title").map(|item| item.value.clone()), + artist: doc.mdata("creator").map(|item| item.value.clone()), + description: doc.mdata("description").map(|item| item.value.clone()), + ..Default::default() + }; + + let mut book_meta = pinakes_types::model::BookMetadata::default(); + + // Extract basic metadata + if let Some(lang) = doc.mdata("language") { + book_meta.language = Some(lang.value.clone()); + } + if let Some(publisher) = doc.mdata("publisher") { + book_meta.publisher = Some(publisher.value.clone()); + } + if let Some(date) = doc.mdata("date") { + // Try to parse as YYYY-MM-DD or just YYYY + if let Ok(parsed_date) = + chrono::NaiveDate::parse_from_str(&date.value, "%Y-%m-%d") + { + book_meta.publication_date = Some(parsed_date); + } else if let Ok(year) = date.value.parse::() { + book_meta.publication_date = chrono::NaiveDate::from_ymd_opt(year, 1, 1); + } + } + + // Extract authors - iterate through all metadata items + let mut authors = Vec::new(); + let mut position = 0; + for item in &doc.metadata { + if item.property == "creator" || item.property == "dc:creator" { + let mut author = + pinakes_types::model::AuthorInfo::new(item.value.clone()); + author.position = position; + position += 1; + + // Check for file-as in refinements + if let Some(file_as_ref) = item.refinement("file-as") { + author.file_as = Some(file_as_ref.value.clone()); + } + + // Check for role in refinements + if let Some(role_ref) = item.refinement("role") { + author.role.clone_from(&role_ref.value); + } + + authors.push(author); + } + } + book_meta.authors = authors; + + // Extract ISBNs from identifiers + let mut identifiers = 
rustc_hash::FxHashMap::default(); + for item in &doc.metadata { + if item.property == "identifier" || item.property == "dc:identifier" { + // Try to get scheme from refinements + let scheme = item + .refinement("identifier-type") + .map(|r| r.value.to_lowercase()); + + let id_type = match scheme.as_deref() { + Some("isbn" | "isbn-10" | "isbn10") => "isbn", + Some("isbn-13" | "isbn13") => "isbn13", + Some("asin") => "asin", + Some("doi") => "doi", + _ => { + // Fallback: detect from value pattern. + // ISBN-10 = 10 chars bare, ISBN-13 = 13 chars bare, + // hyphenated ISBN-13 = 17 chars (e.g. 978-0-123-45678-9). + // Parentheses required: && binds tighter than ||. + if (item.value.len() == 10 || item.value.len() == 13) + || (item.value.contains('-') + && (item.value.len() == 13 || item.value.len() == 17)) + { + "isbn" + } else { + "other" + } + }, + }; + + // Try to normalize ISBN + if (id_type == "isbn" || id_type == "isbn13") + && let Ok(normalized) = normalize_isbn(&item.value) + { + book_meta.isbn13 = Some(normalized.clone()); + book_meta.isbn = Some(item.value.clone()); + } + + identifiers + .entry(id_type.to_string()) + .or_insert_with(Vec::new) + .push(item.value.clone()); + } + } + book_meta.identifiers = identifiers; + + // Extract Calibre series metadata by parsing the content.opf file + // Try common OPF locations + let opf_paths = vec!["OEBPS/content.opf", "content.opf", "OPS/content.opf"]; + let mut opf_data = None; + for path in opf_paths { + if let Some(data) = doc.get_resource_str_by_path(path) { + opf_data = Some(data); + break; + } + } + + if let Some(opf_content) = opf_data { + // Look for + if let Some(series_start) = opf_content.find("name=\"calibre:series\"") + && let Some(content_start) = + opf_content[series_start..].find("content=\"") + { + let after_content = &opf_content[series_start + content_start + 9..]; + if let Some(quote_end) = after_content.find('"') { + book_meta.series_name = Some(after_content[..quote_end].to_string()); + } + } + 
+ // Look for + if let Some(index_start) = opf_content.find("name=\"calibre:series_index\"") + && let Some(content_start) = opf_content[index_start..].find("content=\"") + { + let after_content = &opf_content[index_start + content_start + 9..]; + if let Some(quote_end) = after_content.find('"') + && let Ok(index) = after_content[..quote_end].parse::() + { + book_meta.series_index = Some(index); + } + } + } + + // Set format + book_meta.format = Some("epub".to_string()); + + meta.book_metadata = Some(book_meta); + Ok(meta) +} + +fn extract_djvu(path: &Path) -> Result { + // DjVu files contain metadata in SEXPR (S-expression) format within + // ANTa/ANTz chunks, or in the DIRM chunk. We parse the raw bytes to + // extract any metadata fields we can find. + + // Guard against loading very large DjVu files into memory. + const MAX_DJVU_SIZE: u64 = 50 * 1024 * 1024; // 50 MB + let file_meta = std::fs::metadata(path) + .map_err(|e| PinakesError::MetadataExtraction(format!("DjVu stat: {e}")))?; + if file_meta.len() > MAX_DJVU_SIZE { + return Ok(ExtractedMetadata::default()); + } + + let data = std::fs::read(path) + .map_err(|e| PinakesError::MetadataExtraction(format!("DjVu read: {e}")))?; + + let mut meta = ExtractedMetadata::default(); + + // DjVu files start with "AT&T" magic followed by FORM:DJVU or FORM:DJVM + if data.len() < 16 { + return Ok(meta); + } + + // Search for metadata annotations in the file. DjVu metadata is stored + // as S-expressions like (metadata (key "value") ...) within ANTa chunks. + let content = String::from_utf8_lossy(&data); + + // Look for (metadata ...) 
blocks + if let Some(meta_start) = content.find("(metadata") { + let remainder = &content[meta_start..]; + // Extract key-value pairs like (title "Some Title") + extract_djvu_field(remainder, "title", &mut meta.title); + extract_djvu_field(remainder, "author", &mut meta.artist); + + let mut desc = None; + extract_djvu_field(remainder, "subject", &mut desc); + if desc.is_none() { + extract_djvu_field(remainder, "description", &mut desc); + } + meta.description = desc; + + let mut year_str = None; + extract_djvu_field(remainder, "year", &mut year_str); + if let Some(ref y) = year_str { + meta.year = y.parse().ok(); + } + + let mut creator = None; + extract_djvu_field(remainder, "creator", &mut creator); + if let Some(c) = creator { + meta.extra.insert("creator".to_string(), c); + } + } + + // Also check for booklet-style metadata that some DjVu encoders write + // outside the metadata SEXPR + if meta.title.is_none() + && let Some(title_start) = content.find("(bookmarks") + { + let remainder = &content[title_start..]; + // First bookmark title is often the document title + if let Some(q1) = remainder.find('"') { + let after_q1 = &remainder[q1 + 1..]; + if let Some(q2) = after_q1.find('"') { + let val = &after_q1[..q2]; + if !val.is_empty() { + meta.title = Some(val.to_string()); + } + } + } + } + + Ok(meta) +} + +fn extract_djvu_field(sexpr: &str, key: &str, out: &mut Option) { + // Look for patterns like (key "value") in the S-expression + let pattern = format!("({key}"); + if let Some(start) = sexpr.find(&pattern) { + let remainder = &sexpr[start + pattern.len()..]; + // Find the quoted value + if let Some(q1) = remainder.find('"') { + let after_q1 = &remainder[q1 + 1..]; + if let Some(q2) = after_q1.find('"') { + let val = &after_q1[..q2]; + if !val.is_empty() { + *out = Some(val.to_string()); + } + } + } + } +} diff --git a/crates/pinakes-metadata/src/image.rs b/crates/pinakes-metadata/src/image.rs new file mode 100644 index 0000000..196ad09 --- /dev/null +++ 
b/crates/pinakes-metadata/src/image.rs @@ -0,0 +1,300 @@ +use std::path::Path; + +use pinakes_types::{ + error::Result, + media_type::{BuiltinMediaType, MediaType}, +}; + +use super::{ExtractedMetadata, MetadataExtractor}; + +pub struct ImageExtractor; + +impl MetadataExtractor for ImageExtractor { + fn extract(&self, path: &Path) -> Result { + let mut meta = ExtractedMetadata::default(); + + let file = std::fs::File::open(path)?; + let mut buf_reader = std::io::BufReader::new(&file); + + let Ok(exif_data) = + exif::Reader::new().read_from_container(&mut buf_reader) + else { + return Ok(meta); + }; + + // Image dimensions + if let Some(width) = exif_data + .get_field(exif::Tag::PixelXDimension, exif::In::PRIMARY) + .or_else(|| exif_data.get_field(exif::Tag::ImageWidth, exif::In::PRIMARY)) + && let Some(w) = field_to_u32(width) + { + meta.extra.insert("width".to_string(), w.to_string()); + } + if let Some(height) = exif_data + .get_field(exif::Tag::PixelYDimension, exif::In::PRIMARY) + .or_else(|| { + exif_data.get_field(exif::Tag::ImageLength, exif::In::PRIMARY) + }) + && let Some(h) = field_to_u32(height) + { + meta.extra.insert("height".to_string(), h.to_string()); + } + + // Camera make and model - set both in top-level fields and extra + if let Some(make) = exif_data.get_field(exif::Tag::Make, exif::In::PRIMARY) + { + let val = make.display_value().to_string().trim().to_string(); + if !val.is_empty() { + meta.camera_make = Some(val.clone()); + meta.extra.insert("camera_make".to_string(), val); + } + } + if let Some(model) = + exif_data.get_field(exif::Tag::Model, exif::In::PRIMARY) + { + let val = model.display_value().to_string().trim().to_string(); + if !val.is_empty() { + meta.camera_model = Some(val.clone()); + meta.extra.insert("camera_model".to_string(), val); + } + } + + // Date taken - parse EXIF date format (YYYY:MM:DD HH:MM:SS) + if let Some(date) = exif_data + .get_field(exif::Tag::DateTimeOriginal, exif::In::PRIMARY) + .or_else(|| 
exif_data.get_field(exif::Tag::DateTime, exif::In::PRIMARY)) + { + let val = date.display_value().to_string(); + if !val.is_empty() { + // Try parsing EXIF format: "YYYY:MM:DD HH:MM:SS" + if let Some(dt) = parse_exif_datetime(&val) { + meta.date_taken = Some(dt); + } + meta.extra.insert("date_taken".to_string(), val); + } + } + + // GPS coordinates - set both in top-level fields and extra + if let (Some(lat), Some(lat_ref), Some(lon), Some(lon_ref)) = ( + exif_data.get_field(exif::Tag::GPSLatitude, exif::In::PRIMARY), + exif_data.get_field(exif::Tag::GPSLatitudeRef, exif::In::PRIMARY), + exif_data.get_field(exif::Tag::GPSLongitude, exif::In::PRIMARY), + exif_data.get_field(exif::Tag::GPSLongitudeRef, exif::In::PRIMARY), + ) && let (Some(lat_val), Some(lon_val)) = + (dms_to_decimal(lat, lat_ref), dms_to_decimal(lon, lon_ref)) + { + meta.latitude = Some(lat_val); + meta.longitude = Some(lon_val); + meta + .extra + .insert("gps_latitude".to_string(), format!("{lat_val:.6}")); + meta + .extra + .insert("gps_longitude".to_string(), format!("{lon_val:.6}")); + } + + // Exposure info + if let Some(iso) = + exif_data.get_field(exif::Tag::PhotographicSensitivity, exif::In::PRIMARY) + { + let val = iso.display_value().to_string(); + if !val.is_empty() { + meta.extra.insert("iso".to_string(), val); + } + } + if let Some(exposure) = + exif_data.get_field(exif::Tag::ExposureTime, exif::In::PRIMARY) + { + let val = exposure.display_value().to_string(); + if !val.is_empty() { + meta.extra.insert("exposure_time".to_string(), val); + } + } + if let Some(aperture) = + exif_data.get_field(exif::Tag::FNumber, exif::In::PRIMARY) + { + let val = aperture.display_value().to_string(); + if !val.is_empty() { + meta.extra.insert("f_number".to_string(), val); + } + } + if let Some(focal) = + exif_data.get_field(exif::Tag::FocalLength, exif::In::PRIMARY) + { + let val = focal.display_value().to_string(); + if !val.is_empty() { + meta.extra.insert("focal_length".to_string(), val); + } + } + + 
// Lens model + if let Some(lens) = + exif_data.get_field(exif::Tag::LensModel, exif::In::PRIMARY) + { + let val = lens.display_value().to_string(); + if !val.is_empty() && val != "\"\"" { + meta + .extra + .insert("lens_model".to_string(), val.trim_matches('"').to_string()); + } + } + + // Flash + if let Some(flash) = + exif_data.get_field(exif::Tag::Flash, exif::In::PRIMARY) + { + let val = flash.display_value().to_string(); + if !val.is_empty() { + meta.extra.insert("flash".to_string(), val); + } + } + + // Orientation + if let Some(orientation) = + exif_data.get_field(exif::Tag::Orientation, exif::In::PRIMARY) + { + let val = orientation.display_value().to_string(); + if !val.is_empty() { + meta.extra.insert("orientation".to_string(), val); + } + } + + // Software + if let Some(software) = + exif_data.get_field(exif::Tag::Software, exif::In::PRIMARY) + { + let val = software.display_value().to_string(); + if !val.is_empty() { + meta.extra.insert("software".to_string(), val); + } + } + + // Image description as title + if let Some(desc) = + exif_data.get_field(exif::Tag::ImageDescription, exif::In::PRIMARY) + { + let val = desc.display_value().to_string(); + if !val.is_empty() && val != "\"\"" { + meta.title = Some(val.trim_matches('"').to_string()); + } + } + + // Artist + if let Some(artist) = + exif_data.get_field(exif::Tag::Artist, exif::In::PRIMARY) + { + let val = artist.display_value().to_string(); + if !val.is_empty() && val != "\"\"" { + meta.artist = Some(val.trim_matches('"').to_string()); + } + } + + // Copyright as description + if let Some(copyright) = + exif_data.get_field(exif::Tag::Copyright, exif::In::PRIMARY) + { + let val = copyright.display_value().to_string(); + if !val.is_empty() && val != "\"\"" { + meta.description = Some(val.trim_matches('"').to_string()); + } + } + + Ok(meta) + } + + fn supported_types(&self) -> Vec { + vec![ + MediaType::Builtin(BuiltinMediaType::Jpeg), + MediaType::Builtin(BuiltinMediaType::Png), + 
MediaType::Builtin(BuiltinMediaType::Gif), + MediaType::Builtin(BuiltinMediaType::Webp), + MediaType::Builtin(BuiltinMediaType::Avif), + MediaType::Builtin(BuiltinMediaType::Tiff), + MediaType::Builtin(BuiltinMediaType::Bmp), + // RAW formats (TIFF-based, kamadak-exif handles these) + MediaType::Builtin(BuiltinMediaType::Cr2), + MediaType::Builtin(BuiltinMediaType::Nef), + MediaType::Builtin(BuiltinMediaType::Arw), + MediaType::Builtin(BuiltinMediaType::Dng), + MediaType::Builtin(BuiltinMediaType::Orf), + MediaType::Builtin(BuiltinMediaType::Rw2), + // HEIC + MediaType::Builtin(BuiltinMediaType::Heic), + ] + } +} + +fn field_to_u32(field: &exif::Field) -> Option { + match &field.value { + exif::Value::Long(v) => v.first().copied(), + exif::Value::Short(v) => v.first().map(|&x| u32::from(x)), + _ => None, + } +} + +fn dms_to_decimal( + dms_field: &exif::Field, + ref_field: &exif::Field, +) -> Option { + if let exif::Value::Rational(ref rationals) = dms_field.value + && rationals.len() >= 3 + { + let degrees = rationals[0].to_f64(); + let minutes = rationals[1].to_f64(); + let seconds = rationals[2].to_f64(); + let mut decimal = degrees + minutes / 60.0 + seconds / 3600.0; + + let ref_str = ref_field.display_value().to_string(); + if ref_str.contains('S') || ref_str.contains('W') { + decimal = -decimal; + } + + return Some(decimal); + } + None +} + +/// Parse EXIF datetime format: "YYYY:MM:DD HH:MM:SS" +fn parse_exif_datetime(s: &str) -> Option> { + use chrono::NaiveDateTime; + + // EXIF format is "YYYY:MM:DD HH:MM:SS" + let s = s.trim().trim_matches('"'); + + // Try standard EXIF format + if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y:%m:%d %H:%M:%S") { + return Some(dt.and_utc()); + } + + // Try ISO format as fallback + if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") { + return Some(dt.and_utc()); + } + + None +} + +/// Generate a perceptual hash for an image file. 
+/// +/// Uses the double-gradient hash algorithm for robust similarity +/// detection. Returns a base64-encoded hash string, or None if the image cannot be +/// processed. +#[must_use] +pub fn generate_perceptual_hash(path: &Path) -> Option { + use image_hasher::{HashAlg, HasherConfig}; + + // Open and decode the image + let img = image::open(path).ok()?; + + // Create hasher with the double-gradient algorithm (good for finding similar images) + let hasher = HasherConfig::new() + .hash_alg(HashAlg::DoubleGradient) + .hash_size(8, 8) // 64-bit hash + .to_hasher(); + + // Generate hash + let hash = hasher.hash_image(&img); + + // Convert to base64 string for storage + Some(hash.to_base64()) +} diff --git a/crates/pinakes-metadata/src/lib.rs b/crates/pinakes-metadata/src/lib.rs new file mode 100644 index 0000000..7a89362 --- /dev/null +++ b/crates/pinakes-metadata/src/lib.rs @@ -0,0 +1,73 @@ +pub mod audio; +pub mod document; +pub mod image; +pub mod markdown; +pub mod video; + +use std::path::Path; + +use pinakes_types::{ + error::Result, + media_type::MediaType, + model::BookMetadata, +}; +use rustc_hash::FxHashMap; + +#[derive(Debug, Clone, Default)] +pub struct ExtractedMetadata { + pub title: Option, + pub artist: Option, + pub album: Option, + pub genre: Option, + pub year: Option, + pub duration_secs: Option, + pub description: Option, + pub extra: FxHashMap, + pub book_metadata: Option, + + // Photo-specific metadata + pub date_taken: Option>, + pub latitude: Option, + pub longitude: Option, + pub camera_make: Option, + pub camera_model: Option, + pub rating: Option, +} + +pub trait MetadataExtractor: Send + Sync { + /// Extract metadata from a file at the given path. + /// + /// # Errors + /// + /// Returns an error if the file cannot be read or parsed. + fn extract(&self, path: &Path) -> Result; + fn supported_types(&self) -> Vec; +} + +/// Extract metadata from a file using the appropriate extractor for the given +/// media type.
+/// +/// # Errors +/// +/// Returns an error if extraction fails. Returns a default `ExtractedMetadata` +/// when no extractor supports the media type. +pub fn extract_metadata( + path: &Path, + media_type: &MediaType, +) -> Result { + let extractors: Vec> = vec![ + Box::new(audio::AudioExtractor), + Box::new(document::DocumentExtractor), + Box::new(video::VideoExtractor), + Box::new(markdown::MarkdownExtractor), + Box::new(image::ImageExtractor), + ]; + + for extractor in &extractors { + if extractor.supported_types().contains(media_type) { + return extractor.extract(path); + } + } + + Ok(ExtractedMetadata::default()) +} diff --git a/crates/pinakes-metadata/src/markdown.rs b/crates/pinakes-metadata/src/markdown.rs new file mode 100644 index 0000000..e9b4b1a --- /dev/null +++ b/crates/pinakes-metadata/src/markdown.rs @@ -0,0 +1,46 @@ +use std::path::Path; + +use pinakes_types::{ + error::Result, + media_type::{BuiltinMediaType, MediaType}, +}; + +use super::{ExtractedMetadata, MetadataExtractor}; + +pub struct MarkdownExtractor; + +impl MetadataExtractor for MarkdownExtractor { + fn extract(&self, path: &Path) -> Result { + let content = std::fs::read_to_string(path)?; + let parsed = + gray_matter::Matter::::new().parse(&content); + + let mut meta = ExtractedMetadata::default(); + + if let Some(data) = parsed.ok().and_then(|p| p.data) + && let gray_matter::Pod::Hash(map) = data + { + if let Some(gray_matter::Pod::String(title)) = map.get("title") { + meta.title = Some(title.clone()); + } + if let Some(gray_matter::Pod::String(author)) = map.get("author") { + meta.artist = Some(author.clone()); + } + if let Some(gray_matter::Pod::String(desc)) = map.get("description") { + meta.description = Some(desc.clone()); + } + if let Some(gray_matter::Pod::String(date)) = map.get("date") { + meta.extra.insert("date".to_string(), date.clone()); + } + } + + Ok(meta) + } + + fn supported_types(&self) -> Vec { + vec![ + MediaType::Builtin(BuiltinMediaType::Markdown), + 
MediaType::Builtin(BuiltinMediaType::PlainText), + ] + } +} diff --git a/crates/pinakes-metadata/src/mod.rs b/crates/pinakes-metadata/src/mod.rs new file mode 100644 index 0000000..403a06b --- /dev/null +++ b/crates/pinakes-metadata/src/mod.rs @@ -0,0 +1,70 @@ +pub mod audio; +pub mod document; +pub mod image; +pub mod markdown; +pub mod video; + +use std::path::Path; + +use rustc_hash::FxHashMap; + +use pinakes_types::{error::Result, media_type::MediaType, model::BookMetadata}; + +#[derive(Debug, Clone, Default)] +pub struct ExtractedMetadata { + pub title: Option, + pub artist: Option, + pub album: Option, + pub genre: Option, + pub year: Option, + pub duration_secs: Option, + pub description: Option, + pub extra: FxHashMap, + pub book_metadata: Option, + + // Photo-specific metadata + pub date_taken: Option>, + pub latitude: Option, + pub longitude: Option, + pub camera_make: Option, + pub camera_model: Option, + pub rating: Option, +} + +pub trait MetadataExtractor: Send + Sync { + /// Extract metadata from a file at the given path. + /// + /// # Errors + /// + /// Returns an error if the file cannot be read or parsed. + fn extract(&self, path: &Path) -> Result; + fn supported_types(&self) -> Vec; +} + +/// Extract metadata from a file using the appropriate extractor for the given +/// media type. +/// +/// # Errors +/// +/// Returns an error if extraction fails. Returns a default `ExtractedMetadata` +/// when no extractor supports the media type.
+pub fn extract_metadata( + path: &Path, + media_type: &MediaType, +) -> Result { + let extractors: Vec> = vec![ + Box::new(audio::AudioExtractor), + Box::new(document::DocumentExtractor), + Box::new(video::VideoExtractor), + Box::new(markdown::MarkdownExtractor), + Box::new(image::ImageExtractor), + ]; + + for extractor in &extractors { + if extractor.supported_types().contains(media_type) { + return extractor.extract(path); + } + } + + Ok(ExtractedMetadata::default()) +} diff --git a/crates/pinakes-metadata/src/video.rs b/crates/pinakes-metadata/src/video.rs new file mode 100644 index 0000000..5720d42 --- /dev/null +++ b/crates/pinakes-metadata/src/video.rs @@ -0,0 +1,129 @@ +use std::path::Path; + +use pinakes_types::{ + error::{PinakesError, Result}, + media_type::{BuiltinMediaType, MediaType}, +}; + +use super::{ExtractedMetadata, MetadataExtractor}; + +pub struct VideoExtractor; + +impl MetadataExtractor for VideoExtractor { + fn extract(&self, path: &Path) -> Result { + match MediaType::from_path(path) { + Some(MediaType::Builtin(BuiltinMediaType::Mkv)) => extract_mkv(path), + Some(MediaType::Builtin(BuiltinMediaType::Mp4)) => extract_mp4(path), + _ => Ok(ExtractedMetadata::default()), + } + } + + fn supported_types(&self) -> Vec { + vec![ + MediaType::Builtin(BuiltinMediaType::Mp4), + MediaType::Builtin(BuiltinMediaType::Mkv), + ] + } +} + +fn extract_mkv(path: &Path) -> Result { + let file = std::fs::File::open(path)?; + let mkv = matroska::Matroska::open(file) + .map_err(|e| PinakesError::MetadataExtraction(format!("MKV parse: {e}")))?; + + let mut meta = ExtractedMetadata { + title: mkv.info.title.clone(), + duration_secs: mkv.info.duration.map(|dur| dur.as_secs_f64()), + ..Default::default() + }; + + // Extract resolution and codec info from tracks + for track in &mkv.tracks { + match &track.settings { + matroska::Settings::Video(v) => { + meta.extra.insert( + "resolution".to_string(), + format!("{}x{}", v.pixel_width, v.pixel_height), + ); + if 
!track.codec_id.is_empty() { + meta + .extra + .insert("video_codec".to_string(), track.codec_id.clone()); + } + }, + matroska::Settings::Audio(a) => { + meta.extra.insert( + "sample_rate".to_string(), + format!("{:.0} Hz", a.sample_rate), + ); + meta + .extra + .insert("channels".to_string(), a.channels.to_string()); + if !track.codec_id.is_empty() { + meta + .extra + .insert("audio_codec".to_string(), track.codec_id.clone()); + } + }, + matroska::Settings::None => {}, + } + } + + Ok(meta) +} + +fn extract_mp4(path: &Path) -> Result { + use lofty::{ + file::{AudioFile, TaggedFileExt}, + tag::Accessor, + }; + + let tagged_file = lofty::read_from_path(path).map_err(|e| { + PinakesError::MetadataExtraction(format!("MP4 metadata: {e}")) + })?; + + let mut meta = ExtractedMetadata::default(); + + if let Some(tag) = tagged_file + .primary_tag() + .or_else(|| tagged_file.first_tag()) + { + meta.title = tag + .title() + .map(|s: std::borrow::Cow<'_, str>| s.to_string()); + meta.artist = tag + .artist() + .map(|s: std::borrow::Cow<'_, str>| s.to_string()); + meta.album = tag + .album() + .map(|s: std::borrow::Cow<'_, str>| s.to_string()); + meta.genre = tag + .genre() + .map(|s: std::borrow::Cow<'_, str>| s.to_string()); + meta.year = tag.date().map(|ts| i32::from(ts.year)); + } + + let properties = tagged_file.properties(); + let duration = properties.duration(); + if !duration.is_zero() { + meta.duration_secs = Some(duration.as_secs_f64()); + } + + if let Some(bitrate) = properties.audio_bitrate() { + meta + .extra + .insert("audio_bitrate".to_string(), format!("{bitrate} kbps")); + } + if let Some(sample_rate) = properties.sample_rate() { + meta + .extra + .insert("sample_rate".to_string(), format!("{sample_rate} Hz")); + } + if let Some(channels) = properties.channels() { + meta + .extra + .insert("channels".to_string(), channels.to_string()); + } + + Ok(meta) +} diff --git a/crates/pinakes-plugin/Cargo.toml b/crates/pinakes-plugin/Cargo.toml new file mode 100644 index 
0000000..d3d6192 --- /dev/null +++ b/crates/pinakes-plugin/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "pinakes-plugin" +edition.workspace = true +version.workspace = true +license.workspace = true + +[dependencies] +pinakes-types = { workspace = true } +pinakes-plugin-api = { workspace = true } +wasmtime = { workspace = true } +ed25519-dalek = { workspace = true } +reqwest = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +anyhow = { workspace = true } +rustc-hash = { workspace = true } +walkdir = { workspace = true } +uuid = { workspace = true } +url = { workspace = true } +blake3 = { workspace = true } +rand = { workspace = true } + +[dev-dependencies] +tempfile = { workspace = true } + +[lints] +workspace = true diff --git a/crates/pinakes-plugin/src/lib.rs b/crates/pinakes-plugin/src/lib.rs new file mode 100644 index 0000000..08d063a --- /dev/null +++ b/crates/pinakes-plugin/src/lib.rs @@ -0,0 +1,15 @@ +pub mod loader; +pub mod registry; +pub mod rpc; +pub mod runtime; +pub mod security; +pub mod signature; + +pub use loader::PluginLoader; +pub use registry::{PluginRegistry, RegisteredPlugin}; +pub use runtime::{WasmPlugin, WasmRuntime}; +pub use security::CapabilityEnforcer; +pub use signature::{SignatureStatus, verify_plugin_signature}; + +mod manager; +pub use manager::{PluginManager, PluginManagerConfig}; diff --git a/crates/pinakes-plugin/src/loader.rs b/crates/pinakes-plugin/src/loader.rs new file mode 100644 index 0000000..f8242e8 --- /dev/null +++ b/crates/pinakes-plugin/src/loader.rs @@ -0,0 +1,432 @@ +//! 
Plugin loader for discovering and loading plugins from the filesystem + +use std::path::{Path, PathBuf}; + +use anyhow::{Result, anyhow}; +use pinakes_plugin_api::PluginManifest; +use tracing::{debug, info, warn}; +use walkdir::WalkDir; + +/// Plugin loader handles discovery and loading of plugins from directories +pub struct PluginLoader { + /// Directories to search for plugins + plugin_dirs: Vec, +} + +impl PluginLoader { + /// Create a new plugin loader + #[must_use] + pub const fn new(plugin_dirs: Vec) -> Self { + Self { plugin_dirs } + } + + /// Discover all plugins in configured directories + /// + /// # Errors + /// + /// Returns an error if a plugin directory cannot be searched. + pub fn discover_plugins(&self) -> Result> { + let mut manifests = Vec::new(); + + for dir in &self.plugin_dirs { + if !dir.exists() { + warn!("Plugin directory does not exist: {:?}", dir); + continue; + } + + info!("Discovering plugins in: {:?}", dir); + + let found = Self::discover_in_directory(dir); + info!("Found {} plugins in {:?}", found.len(), dir); + manifests.extend(found); + } + + Ok(manifests) + } + + /// Discover plugins in a specific directory + fn discover_in_directory(dir: &Path) -> Vec { + let mut manifests = Vec::new(); + + // Walk the directory looking for plugin.toml files + for entry in WalkDir::new(dir) + .max_depth(3) // Don't go too deep + .follow_links(false) + { + let entry = match entry { + Ok(e) => e, + Err(e) => { + warn!("Error reading directory entry: {}", e); + continue; + }, + }; + + let path = entry.path(); + + // Look for plugin.toml files + if path.file_name() == Some(std::ffi::OsStr::new("plugin.toml")) { + debug!("Found plugin manifest: {:?}", path); + + match PluginManifest::from_file(path) { + Ok(manifest) => { + info!("Loaded manifest for plugin: {}", manifest.plugin.name); + manifests.push(manifest); + }, + Err(e) => { + warn!("Failed to load manifest from {:?}: {}", path, e); + }, + } + } + } + + manifests + } + + /// Resolve the WASM 
binary path from a manifest + /// + /// # Errors + /// + /// Returns an error if the WASM binary is not found or its path escapes the + /// plugin directory. + pub fn resolve_wasm_path( + &self, + manifest: &PluginManifest, + ) -> Result { + // The WASM path in the manifest is relative to the manifest file + // We need to search for it in the plugin directories + + for dir in &self.plugin_dirs { + // Look for a directory matching the plugin name + let plugin_dir = dir.join(&manifest.plugin.name); + if !plugin_dir.exists() { + continue; + } + + // Check for plugin.toml in this directory + let manifest_path = plugin_dir.join("plugin.toml"); + if !manifest_path.exists() { + continue; + } + + // Resolve WASM path relative to this directory + let wasm_path = plugin_dir.join(&manifest.plugin.binary.wasm); + if wasm_path.exists() { + // Verify the resolved path is within the plugin directory (prevent path + // traversal) + let canonical_wasm = wasm_path + .canonicalize() + .map_err(|e| anyhow!("Failed to canonicalize WASM path: {e}"))?; + let canonical_plugin_dir = plugin_dir + .canonicalize() + .map_err(|e| anyhow!("Failed to canonicalize plugin dir: {e}"))?; + if !canonical_wasm.starts_with(&canonical_plugin_dir) { + return Err(anyhow!( + "WASM binary path escapes plugin directory: {}", + wasm_path.display() + )); + } + return Ok(canonical_wasm); + } + } + + Err(anyhow!( + "WASM binary not found for plugin: {}", + manifest.plugin.name + )) + } + + /// Download a plugin from a URL + /// + /// # Errors + /// + /// Returns an error if the URL is not HTTPS, no plugin directories are + /// configured, the download fails, the archive is too large, or extraction + /// fails. 
+ pub async fn download_plugin(&self, url: &str) -> Result { + const MAX_PLUGIN_SIZE: u64 = 100 * 1024 * 1024; // 100 MB + + // Only allow HTTPS downloads + if !url.starts_with("https://") { + return Err(anyhow!( + "Only HTTPS URLs are allowed for plugin downloads: {url}" + )); + } + + let dest_dir = self + .plugin_dirs + .first() + .ok_or_else(|| anyhow!("No plugin directories configured"))?; + + std::fs::create_dir_all(dest_dir)?; + + // Download the archive with timeout and size limits + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_mins(5)) + .build() + .map_err(|e| anyhow!("Failed to build HTTP client: {e}"))?; + + let response = client + .get(url) + .send() + .await + .map_err(|e| anyhow!("Failed to download plugin: {e}"))?; + + if !response.status().is_success() { + return Err(anyhow!( + "Plugin download failed with status: {}", + response.status() + )); + } + + // Check content-length header before downloading + if let Some(content_length) = response.content_length() + && content_length > MAX_PLUGIN_SIZE + { + return Err(anyhow!( + "Plugin archive too large: {content_length} bytes (max \ + {MAX_PLUGIN_SIZE} bytes)" + )); + } + + let bytes = response + .bytes() + .await + .map_err(|e| anyhow!("Failed to read plugin response: {e}"))?; + + // Check actual size after download + if bytes.len() as u64 > MAX_PLUGIN_SIZE { + return Err(anyhow!( + "Plugin archive too large: {} bytes (max {} bytes)", + bytes.len(), + MAX_PLUGIN_SIZE + )); + } + + // Write archive to a unique temp file + let temp_archive = + dest_dir.join(format!(".download-{}.tar.gz", uuid::Uuid::now_v7())); + std::fs::write(&temp_archive, &bytes)?; + + // Extract using tar with -C to target directory + let canonical_dest = dest_dir + .canonicalize() + .map_err(|e| anyhow!("Failed to canonicalize dest dir: {e}"))?; + let output = std::process::Command::new("tar") + .args([ + "xzf", + &temp_archive.to_string_lossy(), + "-C", + &canonical_dest.to_string_lossy(), + ]) + 
.output() + .map_err(|e| anyhow!("Failed to extract plugin archive: {e}"))?; + + // Clean up the archive + let _ = std::fs::remove_file(&temp_archive); + + if !output.status.success() { + return Err(anyhow!( + "Failed to extract plugin archive: {}", + String::from_utf8_lossy(&output.stderr) + )); + } + + // Validate that all extracted files are within dest_dir + for entry in WalkDir::new(&canonical_dest).follow_links(false) { + let entry = entry?; + let entry_canonical = entry.path().canonicalize()?; + if !entry_canonical.starts_with(&canonical_dest) { + return Err(anyhow!( + "Extracted file escapes destination directory: {}", + entry.path().display() + )); + } + } + + // Find the extracted plugin directory by looking for plugin.toml + for entry in WalkDir::new(dest_dir).max_depth(2).follow_links(false) { + let entry = entry?; + if entry.file_name() == "plugin.toml" { + let plugin_dir = entry + .path() + .parent() + .ok_or_else(|| anyhow!("Invalid plugin.toml location"))?; + + // Validate the manifest + let manifest = PluginManifest::from_file(entry.path())?; + info!("Downloaded and extracted plugin: {}", manifest.plugin.name); + + return Ok(plugin_dir.to_path_buf()); + } + } + + Err(anyhow!( + "No plugin.toml found after extracting archive from: {url}" + )) + } + + /// Validate a plugin package + /// + /// # Errors + /// + /// Returns an error if the path does not exist, is missing `plugin.toml`, + /// the WASM binary is not found, or the WASM file is invalid. 
+ pub fn validate_plugin_package(&self, path: &Path) -> Result<()> { + // Check that the path exists + if !path.exists() { + return Err(anyhow!("Plugin path does not exist: {}", path.display())); + } + + // Check for plugin.toml + let manifest_path = path.join("plugin.toml"); + if !manifest_path.exists() { + return Err(anyhow!("Missing plugin.toml in {}", path.display())); + } + + // Parse and validate manifest + let manifest = PluginManifest::from_file(&manifest_path)?; + + // Check that WASM binary exists + let wasm_path = path.join(&manifest.plugin.binary.wasm); + if !wasm_path.exists() { + return Err(anyhow!( + "WASM binary not found: {}", + manifest.plugin.binary.wasm + )); + } + + // Verify the WASM path is within the plugin directory (prevent path + // traversal) + let canonical_wasm = wasm_path.canonicalize()?; + let canonical_path = path.canonicalize()?; + if !canonical_wasm.starts_with(&canonical_path) { + return Err(anyhow!( + "WASM binary path escapes plugin directory: {}", + wasm_path.display() + )); + } + + // Validate WASM file + let wasm_bytes = std::fs::read(&wasm_path)?; + if wasm_bytes.len() < 4 || &wasm_bytes[0..4] != b"\0asm" { + return Err(anyhow!("Invalid WASM file: {}", wasm_path.display())); + } + + Ok(()) + } + + /// Get plugin directory path for a given plugin name + #[must_use] + pub fn get_plugin_dir(&self, plugin_name: &str) -> Option { + for dir in &self.plugin_dirs { + let plugin_dir = dir.join(plugin_name); + if plugin_dir.exists() { + return Some(plugin_dir); + } + } + None + } +} + +#[cfg(test)] +mod tests { + use tempfile::TempDir; + + use super::*; + + #[test] + fn test_discover_plugins_empty() { + let temp_dir = TempDir::new().unwrap(); + let loader = PluginLoader::new(vec![temp_dir.path().to_path_buf()]); + + let manifests = loader.discover_plugins().unwrap(); + assert_eq!(manifests.len(), 0); + } + + #[test] + fn test_discover_plugins_with_manifest() { + let temp_dir = TempDir::new().unwrap(); + let plugin_dir = 
temp_dir.path().join("test-plugin"); + std::fs::create_dir(&plugin_dir).unwrap(); + + // Create a valid manifest + let manifest_content = r#" +[plugin] +name = "test-plugin" +version = "1.0.0" +api_version = "1.0" +kind = ["media_type"] + +[plugin.binary] +wasm = "plugin.wasm" +"#; + std::fs::write(plugin_dir.join("plugin.toml"), manifest_content).unwrap(); + + // Create dummy WASM file + std::fs::write(plugin_dir.join("plugin.wasm"), b"\0asm\x01\x00\x00\x00") + .unwrap(); + + let loader = PluginLoader::new(vec![temp_dir.path().to_path_buf()]); + let manifests = loader.discover_plugins().unwrap(); + + assert_eq!(manifests.len(), 1); + assert_eq!(manifests[0].plugin.name, "test-plugin"); + } + + #[test] + fn test_validate_plugin_package() { + let temp_dir = TempDir::new().unwrap(); + let plugin_dir = temp_dir.path().join("test-plugin"); + std::fs::create_dir(&plugin_dir).unwrap(); + + // Create a valid manifest + let manifest_content = r#" +[plugin] +name = "test-plugin" +version = "1.0.0" +api_version = "1.0" +kind = ["media_type"] + +[plugin.binary] +wasm = "plugin.wasm" +"#; + std::fs::write(plugin_dir.join("plugin.toml"), manifest_content).unwrap(); + + let loader = PluginLoader::new(vec![]); + + // Should fail without WASM file + assert!(loader.validate_plugin_package(&plugin_dir).is_err()); + + // Create valid WASM file (magic number only) + std::fs::write(plugin_dir.join("plugin.wasm"), b"\0asm\x01\x00\x00\x00") + .unwrap(); + + // Should succeed now + assert!(loader.validate_plugin_package(&plugin_dir).is_ok()); + } + + #[test] + fn test_validate_invalid_wasm() { + let temp_dir = TempDir::new().unwrap(); + let plugin_dir = temp_dir.path().join("test-plugin"); + std::fs::create_dir(&plugin_dir).unwrap(); + + let manifest_content = r#" +[plugin] +name = "test-plugin" +version = "1.0.0" +api_version = "1.0" +kind = ["media_type"] + +[plugin.binary] +wasm = "plugin.wasm" +"#; + std::fs::write(plugin_dir.join("plugin.toml"), manifest_content).unwrap(); + + // 
Create invalid WASM file + std::fs::write(plugin_dir.join("plugin.wasm"), b"not wasm").unwrap(); + + let loader = PluginLoader::new(vec![]); + assert!(loader.validate_plugin_package(&plugin_dir).is_err()); + } +} diff --git a/crates/pinakes-plugin/src/manager.rs b/crates/pinakes-plugin/src/manager.rs new file mode 100644 index 0000000..22609e2 --- /dev/null +++ b/crates/pinakes-plugin/src/manager.rs @@ -0,0 +1,916 @@ +use std::{path::PathBuf, sync::Arc}; + +use anyhow::Result; +use pinakes_plugin_api::{PluginContext, PluginMetadata}; +use tokio::sync::RwLock; +use tracing::{debug, error, info, warn}; + +use crate::{ + CapabilityEnforcer, + PluginLoader, + PluginRegistry, + RegisteredPlugin, + SignatureStatus, + WasmPlugin, + WasmRuntime, + signature, +}; + +/// Plugin manager coordinates plugin lifecycle and operations +pub struct PluginManager { + /// Plugin registry + registry: Arc>, + + /// WASM runtime for executing plugins + runtime: Arc, + + /// Plugin loader for discovery and loading + loader: PluginLoader, + + /// Capability enforcer for security + enforcer: CapabilityEnforcer, + + /// Plugin data directory + data_dir: PathBuf, + + /// Plugin cache directory + cache_dir: PathBuf, + + /// Configuration + config: PluginManagerConfig, +} + +/// Configuration for the plugin manager +#[derive(Debug, Clone)] +pub struct PluginManagerConfig { + /// Directories to search for plugins + pub plugin_dirs: Vec, + + /// Whether to enable hot-reload (for development) + pub enable_hot_reload: bool, + + /// Whether to allow unsigned plugins + pub allow_unsigned: bool, + + /// Maximum number of concurrent plugin operations + pub max_concurrent_ops: usize, + + /// Plugin timeout in seconds + pub plugin_timeout_secs: u64, + + /// Timeout configuration for different call types + pub timeouts: pinakes_types::config::PluginTimeoutConfig, + + /// Max consecutive failures before circuit breaker disables plugin + pub max_consecutive_failures: u32, + + /// Trusted Ed25519 public keys 
for signature verification (hex-encoded) + pub trusted_keys: Vec, +} + +impl Default for PluginManagerConfig { + fn default() -> Self { + Self { + plugin_dirs: vec![], + enable_hot_reload: false, + allow_unsigned: false, + max_concurrent_ops: 4, + plugin_timeout_secs: 30, + timeouts: + pinakes_types::config::PluginTimeoutConfig::default(), + max_consecutive_failures: 5, + trusted_keys: vec![], + } + } +} + +impl From for PluginManagerConfig { + fn from(cfg: pinakes_types::config::PluginsConfig) -> Self { + Self { + plugin_dirs: cfg.plugin_dirs, + enable_hot_reload: cfg.enable_hot_reload, + allow_unsigned: cfg.allow_unsigned, + max_concurrent_ops: cfg.max_concurrent_ops, + plugin_timeout_secs: cfg.plugin_timeout_secs, + timeouts: cfg.timeouts, + max_consecutive_failures: cfg.max_consecutive_failures, + trusted_keys: cfg.trusted_keys, + } + } +} + +impl PluginManager { + /// Create a new plugin manager + /// + /// # Errors + /// + /// Returns an error if the data or cache directories cannot be created, or + /// if the WASM runtime cannot be initialized. + pub fn new( + data_dir: PathBuf, + cache_dir: PathBuf, + config: PluginManagerConfig, + ) -> Result { + // Ensure directories exist + std::fs::create_dir_all(&data_dir)?; + std::fs::create_dir_all(&cache_dir)?; + + let runtime = Arc::new(WasmRuntime::new()?); + let registry = Arc::new(RwLock::new(PluginRegistry::new())); + let loader = PluginLoader::new(config.plugin_dirs.clone()); + let enforcer = CapabilityEnforcer::new(); + + Ok(Self { + registry, + runtime, + loader, + enforcer, + data_dir, + cache_dir, + config, + }) + } + + /// Discover and load all plugins from configured directories. + /// + /// Plugins are loaded in dependency order: if plugin A declares a + /// dependency on plugin B, B is loaded first. Cycles and missing + /// dependencies are detected and reported as warnings; affected plugins + /// are skipped rather than causing a hard failure. 
+ /// + /// # Errors + /// + /// Returns an error if plugin discovery fails. + pub async fn discover_and_load_all(&self) -> Result> { + info!("Discovering plugins from {:?}", self.config.plugin_dirs); + + let manifests = self.loader.discover_plugins()?; + let ordered = Self::resolve_load_order(&manifests); + let mut loaded_plugins = Vec::new(); + + for manifest in ordered { + match self.load_plugin_from_manifest(&manifest).await { + Ok(plugin_id) => { + info!("Loaded plugin: {}", plugin_id); + loaded_plugins.push(plugin_id); + }, + Err(e) => { + warn!("Failed to load plugin {}: {}", manifest.plugin.name, e); + }, + } + } + + Ok(loaded_plugins) + } + + /// Topological sort of manifests by their declared `dependencies`. + /// + /// Uses Kahn's algorithm. Plugins whose dependencies are missing or form + /// a cycle are logged as warnings and excluded from the result. + fn resolve_load_order( + manifests: &[pinakes_plugin_api::PluginManifest], + ) -> Vec { + use std::collections::VecDeque; + + use rustc_hash::{FxHashMap, FxHashSet}; + + // Index manifests by name for O(1) lookup + let by_name: FxHashMap<&str, usize> = manifests + .iter() + .enumerate() + .map(|(i, m)| (m.plugin.name.as_str(), i)) + .collect(); + + // Check for missing dependencies and warn early + let known: FxHashSet<&str> = by_name.keys().copied().collect(); + for manifest in manifests { + for dep in &manifest.plugin.dependencies { + if !known.contains(dep.as_str()) { + warn!( + "Plugin '{}' depends on '{}' which was not discovered; it will be \ + skipped", + manifest.plugin.name, dep + ); + } + } + } + + // Build adjacency: in_degree[i] = number of deps that must load before i + let mut in_degree = vec![0usize; manifests.len()]; + // dependents[i] = indices that depend on i (i must load before them) + let mut dependents: Vec> = vec![vec![]; manifests.len()]; + + for (i, manifest) in manifests.iter().enumerate() { + for dep in &manifest.plugin.dependencies { + if let Some(&dep_idx) = 
by_name.get(dep.as_str()) { + in_degree[i] += 1; + dependents[dep_idx].push(i); + } else { + // Missing dep: set in_degree impossibly high so it never resolves + in_degree[i] = usize::MAX; + } + } + } + + // Kahn's algorithm + let mut queue: VecDeque = VecDeque::new(); + for (i, °) in in_degree.iter().enumerate() { + if deg == 0 { + queue.push_back(i); + } + } + + let mut result = Vec::with_capacity(manifests.len()); + while let Some(idx) = queue.pop_front() { + result.push(manifests[idx].clone()); + for &dependent in &dependents[idx] { + if in_degree[dependent] == usize::MAX { + continue; // already poisoned by missing dep + } + in_degree[dependent] -= 1; + if in_degree[dependent] == 0 { + queue.push_back(dependent); + } + } + } + + // Anything not in `result` is part of a cycle or has a missing dep + if result.len() < manifests.len() { + let loaded: FxHashSet<&str> = + result.iter().map(|m| m.plugin.name.as_str()).collect(); + for manifest in manifests { + if !loaded.contains(manifest.plugin.name.as_str()) { + warn!( + "Plugin '{}' was skipped due to unresolved dependencies or a \ + dependency cycle", + manifest.plugin.name + ); + } + } + } + + result + } + + /// Load a plugin from a manifest file + /// + /// # Errors + /// + /// Returns an error if the plugin ID is invalid, capability validation + /// fails, the WASM binary cannot be loaded, or the plugin cannot be + /// registered. 
+ async fn load_plugin_from_manifest( + &self, + manifest: &pinakes_plugin_api::PluginManifest, + ) -> Result { + let plugin_id = manifest.plugin_id(); + + // Validate plugin_id to prevent path traversal + if plugin_id.contains('/') + || plugin_id.contains('\\') + || plugin_id.contains("..") + { + return Err(anyhow::anyhow!("Invalid plugin ID: {plugin_id}")); + } + + // Check if already loaded + { + let registry = self.registry.read().await; + if registry.is_loaded(&plugin_id) { + return Ok(plugin_id); + } + } + + // Validate capabilities + let capabilities = manifest.to_capabilities(); + self.enforcer.validate_capabilities(&capabilities)?; + + // Create plugin context + let plugin_data_dir = self.data_dir.join(&plugin_id); + let plugin_cache_dir = self.cache_dir.join(&plugin_id); + tokio::fs::create_dir_all(&plugin_data_dir).await?; + tokio::fs::create_dir_all(&plugin_cache_dir).await?; + + let context = PluginContext { + data_dir: plugin_data_dir, + cache_dir: plugin_cache_dir, + config: manifest + .config + .iter() + .map(|(k, v)| { + ( + k.clone(), + serde_json::to_value(v).unwrap_or_else(|e| { + tracing::warn!( + "failed to serialize config value for key {}: {}", + k, + e + ); + serde_json::Value::Null + }), + ) + }) + .collect(), + capabilities: capabilities.clone(), + }; + + // Load WASM binary + let wasm_path = self.loader.resolve_wasm_path(manifest)?; + + // Verify plugin signature unless unsigned plugins are allowed + if !self.config.allow_unsigned { + let plugin_dir = wasm_path + .parent() + .ok_or_else(|| anyhow::anyhow!("WASM path has no parent directory"))?; + + let trusted_keys: Vec = self + .config + .trusted_keys + .iter() + .filter_map(|hex| { + signature::parse_public_key(hex) + .map_err(|e| warn!("Ignoring malformed trusted key: {e}")) + .ok() + }) + .collect(); + + match signature::verify_plugin_signature( + plugin_dir, + &wasm_path, + &trusted_keys, + )? 
{ + SignatureStatus::Valid => { + debug!("Plugin '{plugin_id}' signature verified"); + }, + SignatureStatus::Unsigned => { + return Err(anyhow::anyhow!( + "Plugin '{plugin_id}' is unsigned and allow_unsigned is false" + )); + }, + SignatureStatus::Invalid(reason) => { + return Err(anyhow::anyhow!( + "Plugin '{plugin_id}' has an invalid signature: {reason}" + )); + }, + } + } + + let wasm_plugin = self.runtime.load_plugin(&wasm_path, context)?; + + // Initialize plugin + let init_succeeded = match wasm_plugin + .call_function("initialize", &[]) + .await + { + Ok(_) => true, + Err(e) => { + tracing::warn!(plugin_id = %plugin_id, "plugin initialization failed: {}", e); + false + }, + }; + + // Register plugin + let metadata = PluginMetadata { + id: plugin_id.clone(), + name: manifest.plugin.name.clone(), + version: manifest.plugin.version.clone(), + author: manifest.plugin.author.clone().unwrap_or_default(), + description: manifest + .plugin + .description + .clone() + .unwrap_or_default(), + api_version: manifest.plugin.api_version.clone(), + capabilities_required: capabilities, + }; + + // Derive manifest_path from the loader's plugin directories + let manifest_path = self + .loader + .get_plugin_dir(&manifest.plugin.name) + .map(|dir| dir.join("plugin.toml")); + + let registered = RegisteredPlugin { + id: plugin_id.clone(), + metadata, + wasm_plugin, + manifest: manifest.clone(), + manifest_path, + enabled: init_succeeded, + }; + + { + let mut registry = self.registry.write().await; + registry.register(registered)?; + } + + Ok(plugin_id) + } + + /// Install a plugin from a file or URL + /// + /// # Errors + /// + /// Returns an error if the plugin cannot be downloaded, the manifest cannot + /// be read, or the plugin cannot be loaded. 
+ pub async fn install_plugin(&self, source: &str) -> Result { + info!("Installing plugin from: {}", source); + + // Download/copy plugin to plugins directory + let plugin_path = + if source.starts_with("http://") || source.starts_with("https://") { + // Download from URL + self.loader.download_plugin(source).await? + } else { + // Copy from local file + PathBuf::from(source) + }; + + // Load the manifest + let manifest_path = plugin_path.join("plugin.toml"); + let manifest = + pinakes_plugin_api::PluginManifest::from_file(&manifest_path)?; + + // Load the plugin + self.load_plugin_from_manifest(&manifest).await + } + + /// Uninstall a plugin + /// + /// # Errors + /// + /// Returns an error if the plugin ID is invalid, the plugin cannot be shut + /// down, cannot be unregistered, or its data directories cannot be removed. + pub async fn uninstall_plugin(&self, plugin_id: &str) -> Result<()> { + // Validate plugin_id to prevent path traversal + if plugin_id.contains('/') + || plugin_id.contains('\\') + || plugin_id.contains("..") + { + return Err(anyhow::anyhow!("Invalid plugin ID: {plugin_id}")); + } + + info!("Uninstalling plugin: {}", plugin_id); + + // Shutdown plugin first + self.shutdown_plugin(plugin_id).await?; + + // Remove from registry + { + let mut registry = self.registry.write().await; + registry.unregister(plugin_id)?; + } + + // Remove plugin data and cache + let plugin_data_dir = self.data_dir.join(plugin_id); + let plugin_cache_dir = self.cache_dir.join(plugin_id); + + if plugin_data_dir.exists() { + std::fs::remove_dir_all(&plugin_data_dir)?; + } + if plugin_cache_dir.exists() { + std::fs::remove_dir_all(&plugin_cache_dir)?; + } + + Ok(()) + } + + /// Enable a plugin + /// + /// # Errors + /// + /// Returns an error if the plugin ID is not found in the registry. 
+ pub async fn enable_plugin(&self, plugin_id: &str) -> Result<()> { + let mut registry = self.registry.write().await; + registry.enable(plugin_id) + } + + /// Disable a plugin + /// + /// # Errors + /// + /// Returns an error if the plugin ID is not found in the registry. + pub async fn disable_plugin(&self, plugin_id: &str) -> Result<()> { + let mut registry = self.registry.write().await; + registry.disable(plugin_id) + } + + /// Shutdown a specific plugin + /// + /// # Errors + /// + /// Returns an error if the plugin ID is not found in the registry. + pub async fn shutdown_plugin(&self, plugin_id: &str) -> Result<()> { + debug!("Shutting down plugin: {}", plugin_id); + + let registry = self.registry.read().await; + if let Some(plugin) = registry.get(plugin_id) { + let _ = plugin.wasm_plugin.call_function("shutdown", &[]).await; + Ok(()) + } else { + Err(anyhow::anyhow!("Plugin not found: {plugin_id}")) + } + } + + /// Shutdown all plugins + /// + /// # Errors + /// + /// This function always returns `Ok(())`. Individual plugin shutdown errors + /// are logged but do not cause the overall operation to fail. 
+ pub async fn shutdown_all(&self) -> Result<()> { + info!("Shutting down all plugins"); + + let plugin_ids: Vec = { + let registry = self.registry.read().await; + registry.list_all().iter().map(|p| p.id.clone()).collect() + }; + + for plugin_id in plugin_ids { + if let Err(e) = self.shutdown_plugin(&plugin_id).await { + error!("Failed to shutdown plugin {}: {}", plugin_id, e); + } + } + + Ok(()) + } + + /// Get list of all registered plugins + pub async fn list_plugins(&self) -> Vec { + let registry = self.registry.read().await; + registry + .list_all() + .iter() + .map(|p| p.metadata.clone()) + .collect() + } + + /// Get plugin metadata by ID + pub async fn get_plugin(&self, plugin_id: &str) -> Option { + let registry = self.registry.read().await; + registry.get(plugin_id).map(|p| p.metadata.clone()) + } + + /// Get enabled plugins of a specific kind, sorted by priority (ascending). + /// + /// # Returns + /// + /// `(plugin_id, priority, kinds, wasm_plugin)` tuples. + pub async fn get_enabled_by_kind_sorted( + &self, + kind: &str, + ) -> Vec<(String, u16, Vec, WasmPlugin)> { + let registry = self.registry.read().await; + let mut plugins: Vec<_> = registry + .get_by_kind(kind) + .into_iter() + .filter(|p| p.enabled) + .map(|p| { + ( + p.id.clone(), + p.manifest.plugin.priority, + p.manifest.plugin.kind.clone(), + p.wasm_plugin.clone(), + ) + }) + .collect(); + drop(registry); + plugins.sort_by_key(|(_, priority, ..)| *priority); + plugins + } + + /// Get a reference to the capability enforcer. + #[must_use] + pub const fn enforcer(&self) -> &CapabilityEnforcer { + &self.enforcer + } + + /// List all UI pages provided by loaded plugins. + /// + /// Returns a vector of `(plugin_id, page)` tuples for all enabled plugins + /// that provide pages in their manifests. Both inline and file-referenced + /// page entries are resolved. 
+ pub async fn list_ui_pages( + &self, + ) -> Vec<(String, pinakes_plugin_api::UiPage)> { + self + .list_ui_pages_with_endpoints() + .await + .into_iter() + .map(|(id, page, _)| (id, page)) + .collect() + } + + /// List all UI pages provided by loaded plugins, including each plugin's + /// declared endpoint allowlist. + /// + /// Returns a vector of `(plugin_id, page, allowed_endpoints)` tuples. The + /// `allowed_endpoints` list mirrors the `required_endpoints` field from the + /// plugin manifest's `[ui]` section. + pub async fn list_ui_pages_with_endpoints( + &self, + ) -> Vec<(String, pinakes_plugin_api::UiPage, Vec)> { + let registry = self.registry.read().await; + let mut pages = Vec::new(); + for plugin in registry.list_all() { + if !plugin.enabled { + continue; + } + let allowed = plugin.manifest.ui.required_endpoints.clone(); + let plugin_dir = plugin + .manifest_path + .as_ref() + .and_then(|p| p.parent()) + .map(std::path::Path::to_path_buf); + let Some(plugin_dir) = plugin_dir else { + for entry in &plugin.manifest.ui.pages { + if let pinakes_plugin_api::manifest::UiPageEntry::Inline(page) = entry + { + pages.push((plugin.id.clone(), (**page).clone(), allowed.clone())); + } + } + continue; + }; + match plugin.manifest.load_ui_pages(&plugin_dir) { + Ok(loaded) => { + for page in loaded { + pages.push((plugin.id.clone(), page, allowed.clone())); + } + }, + Err(e) => { + tracing::warn!( + "Failed to load UI pages for plugin '{}': {e}", + plugin.id + ); + }, + } + } + pages + } + + /// Collect CSS custom property overrides declared by all enabled plugins. + /// + /// When multiple plugins declare the same property name, later-loaded plugins + /// overwrite earlier ones. Returns an empty map if no plugins are loaded or + /// none declare theme extensions. 
+ pub async fn list_ui_theme_extensions( + &self, + ) -> rustc_hash::FxHashMap { + let registry = self.registry.read().await; + let mut merged = rustc_hash::FxHashMap::default(); + for plugin in registry.list_all() { + if !plugin.enabled { + continue; + } + for (k, v) in &plugin.manifest.ui.theme_extensions { + merged.insert(k.clone(), v.clone()); + } + } + merged + } + + /// List all UI widgets provided by loaded plugins. + /// + /// Returns a vector of `(plugin_id, widget)` tuples for all enabled plugins + /// that provide widgets in their manifests. + pub async fn list_ui_widgets( + &self, + ) -> Vec<(String, pinakes_plugin_api::UiWidget)> { + let registry = self.registry.read().await; + let mut widgets = Vec::new(); + for plugin in registry.list_all() { + if !plugin.enabled { + continue; + } + for widget in &plugin.manifest.ui.widgets { + widgets.push((plugin.id.clone(), widget.clone())); + } + } + widgets + } + + /// Check if a plugin is loaded and enabled + pub async fn is_plugin_enabled(&self, plugin_id: &str) -> bool { + let registry = self.registry.read().await; + registry.is_enabled(plugin_id).unwrap_or(false) + } + + /// Reload a plugin (for hot-reload during development) + /// + /// # Errors + /// + /// Returns an error if hot-reload is disabled, the plugin is not found, it + /// cannot be shut down, or the reloaded plugin cannot be registered. 
+ pub async fn reload_plugin(&self, plugin_id: &str) -> Result<()> { + if !self.config.enable_hot_reload { + return Err(anyhow::anyhow!("Hot-reload is disabled")); + } + + info!("Reloading plugin: {}", plugin_id); + + // Re-read the manifest from disk if possible, falling back to cached + // version + let manifest = { + let registry = self.registry.read().await; + let plugin = registry + .get(plugin_id) + .ok_or_else(|| anyhow::anyhow!("Plugin not found"))?; + let manifest = plugin.manifest_path.as_ref().map_or_else( + || plugin.manifest.clone(), + |manifest_path| { + pinakes_plugin_api::PluginManifest::from_file(manifest_path) + .unwrap_or_else(|e| { + warn!( + "Failed to re-read manifest from disk, using cached: {}", + e + ); + plugin.manifest.clone() + }) + }, + ); + drop(registry); + manifest + }; + + // Shutdown and unload current version + self.shutdown_plugin(plugin_id).await?; + { + let mut registry = self.registry.write().await; + registry.unregister(plugin_id)?; + } + + // Reload from manifest + self.load_plugin_from_manifest(&manifest).await?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use tempfile::TempDir; + + use super::*; + + #[tokio::test] + async fn test_plugin_manager_creation() { + let temp_dir = TempDir::new().unwrap(); + let data_dir = temp_dir.path().join("data"); + let cache_dir = temp_dir.path().join("cache"); + + let config = PluginManagerConfig::default(); + let manager = + PluginManager::new(data_dir.clone(), cache_dir.clone(), config); + + assert!(manager.is_ok()); + assert!(data_dir.exists()); + assert!(cache_dir.exists()); + } + + #[tokio::test] + async fn test_list_plugins_empty() { + let temp_dir = TempDir::new().unwrap(); + let data_dir = temp_dir.path().join("data"); + let cache_dir = temp_dir.path().join("cache"); + + let config = PluginManagerConfig::default(); + let manager = PluginManager::new(data_dir, cache_dir, config).unwrap(); + + let plugins = manager.list_plugins().await; + assert_eq!(plugins.len(), 0); + } + + /// 
Build a minimal manifest for dependency resolution tests + fn test_manifest( + name: &str, + deps: Vec, + ) -> pinakes_plugin_api::PluginManifest { + use pinakes_plugin_api::manifest::{PluginBinary, PluginInfo}; + + pinakes_plugin_api::PluginManifest { + plugin: PluginInfo { + name: name.to_string(), + version: "1.0.0".to_string(), + api_version: "1.0".to_string(), + author: None, + description: None, + homepage: None, + license: None, + priority: 500, + kind: vec!["media_type".to_string()], + binary: PluginBinary { + wasm: "plugin.wasm".to_string(), + entrypoint: None, + }, + dependencies: deps, + }, + capabilities: Default::default(), + config: Default::default(), + ui: Default::default(), + } + } + + #[test] + fn test_resolve_load_order_no_deps() { + let manifests = vec![ + test_manifest("alpha", vec![]), + test_manifest("beta", vec![]), + test_manifest("gamma", vec![]), + ]; + + let ordered = PluginManager::resolve_load_order(&manifests); + assert_eq!(ordered.len(), 3); + } + + #[test] + fn test_resolve_load_order_linear_chain() { + // gamma depends on beta, beta depends on alpha + let manifests = vec![ + test_manifest("gamma", vec!["beta".to_string()]), + test_manifest("alpha", vec![]), + test_manifest("beta", vec!["alpha".to_string()]), + ]; + + let ordered = PluginManager::resolve_load_order(&manifests); + assert_eq!(ordered.len(), 3); + + let names: Vec<&str> = + ordered.iter().map(|m| m.plugin.name.as_str()).collect(); + let alpha_pos = names.iter().position(|&n| n == "alpha").unwrap(); + let beta_pos = names.iter().position(|&n| n == "beta").unwrap(); + let gamma_pos = names.iter().position(|&n| n == "gamma").unwrap(); + assert!(alpha_pos < beta_pos, "alpha must load before beta"); + assert!(beta_pos < gamma_pos, "beta must load before gamma"); + } + + #[test] + fn test_resolve_load_order_cycle_detected() { + // A -> B -> C -> A (cycle) + let manifests = vec![ + test_manifest("a", vec!["c".to_string()]), + test_manifest("b", vec!["a".to_string()]), + 
test_manifest("c", vec!["b".to_string()]), + ]; + + let ordered = PluginManager::resolve_load_order(&manifests); + // All three should be excluded due to cycle + assert_eq!(ordered.len(), 0); + } + + #[test] + fn test_resolve_load_order_missing_dependency() { + let manifests = vec![ + test_manifest("good", vec![]), + test_manifest("bad", vec!["nonexistent".to_string()]), + ]; + + let ordered = PluginManager::resolve_load_order(&manifests); + // Only "good" should be loaded; "bad" depends on something missing + assert_eq!(ordered.len(), 1); + assert_eq!(ordered[0].plugin.name, "good"); + } + + #[test] + fn test_resolve_load_order_partial_cycle() { + // "ok" has no deps, "cycle_a" and "cycle_b" form a cycle + let manifests = vec![ + test_manifest("ok", vec![]), + test_manifest("cycle_a", vec!["cycle_b".to_string()]), + test_manifest("cycle_b", vec!["cycle_a".to_string()]), + ]; + + let ordered = PluginManager::resolve_load_order(&manifests); + assert_eq!(ordered.len(), 1); + assert_eq!(ordered[0].plugin.name, "ok"); + } + + #[test] + fn test_resolve_load_order_diamond() { + // Man look at how beautiful my diamond is... 
+ // A + // / \ + // B C + // \ / + // D + let manifests = vec![ + test_manifest("d", vec!["b".to_string(), "c".to_string()]), + test_manifest("b", vec!["a".to_string()]), + test_manifest("c", vec!["a".to_string()]), + test_manifest("a", vec![]), + ]; + + let ordered = PluginManager::resolve_load_order(&manifests); + assert_eq!(ordered.len(), 4); + + let names: Vec<&str> = + ordered.iter().map(|m| m.plugin.name.as_str()).collect(); + let a_pos = names.iter().position(|&n| n == "a").unwrap(); + let b_pos = names.iter().position(|&n| n == "b").unwrap(); + let c_pos = names.iter().position(|&n| n == "c").unwrap(); + let d_pos = names.iter().position(|&n| n == "d").unwrap(); + assert!(a_pos < b_pos); + assert!(a_pos < c_pos); + assert!(b_pos < d_pos); + assert!(c_pos < d_pos); + } +} diff --git a/crates/pinakes-plugin/src/registry.rs b/crates/pinakes-plugin/src/registry.rs new file mode 100644 index 0000000..ce13d86 --- /dev/null +++ b/crates/pinakes-plugin/src/registry.rs @@ -0,0 +1,309 @@ +//! Plugin registry for managing loaded plugins + +use std::path::PathBuf; + +use anyhow::{Result, anyhow}; +use pinakes_plugin_api::{PluginManifest, PluginMetadata}; +use rustc_hash::FxHashMap; + +use super::runtime::WasmPlugin; + +/// A registered plugin with its metadata and runtime state +#[derive(Clone)] +pub struct RegisteredPlugin { + pub id: String, + pub metadata: PluginMetadata, + pub wasm_plugin: WasmPlugin, + pub manifest: PluginManifest, + pub manifest_path: Option, + pub enabled: bool, +} + +/// Plugin registry maintains the state of all loaded plugins +pub struct PluginRegistry { + /// Map of plugin ID to registered plugin + plugins: FxHashMap, +} + +impl PluginRegistry { + /// Create a new empty registry + #[must_use] + pub fn new() -> Self { + Self { + plugins: FxHashMap::default(), + } + } + + /// Register a new plugin + /// + /// # Errors + /// + /// Returns an error if a plugin with the same ID is already registered. 
+ pub fn register(&mut self, plugin: RegisteredPlugin) -> Result<()> { + if self.plugins.contains_key(&plugin.id) { + return Err(anyhow!("Plugin already registered: {}", plugin.id)); + } + + self.plugins.insert(plugin.id.clone(), plugin); + Ok(()) + } + + /// Unregister a plugin by ID + /// + /// # Errors + /// + /// Returns an error if the plugin ID is not found. + pub fn unregister(&mut self, plugin_id: &str) -> Result<()> { + self + .plugins + .remove(plugin_id) + .ok_or_else(|| anyhow!("Plugin not found: {plugin_id}"))?; + Ok(()) + } + + /// Get a plugin by ID + #[must_use] + pub fn get(&self, plugin_id: &str) -> Option<&RegisteredPlugin> { + self.plugins.get(plugin_id) + } + + /// Get a mutable reference to a plugin by ID + pub fn get_mut(&mut self, plugin_id: &str) -> Option<&mut RegisteredPlugin> { + self.plugins.get_mut(plugin_id) + } + + /// Check if a plugin is loaded + #[must_use] + pub fn is_loaded(&self, plugin_id: &str) -> bool { + self.plugins.contains_key(plugin_id) + } + + /// Check if a plugin is enabled. Returns `None` if the plugin is not found. + #[must_use] + pub fn is_enabled(&self, plugin_id: &str) -> Option { + self.plugins.get(plugin_id).map(|p| p.enabled) + } + + /// Enable a plugin + /// + /// # Errors + /// + /// Returns an error if the plugin ID is not found. + pub fn enable(&mut self, plugin_id: &str) -> Result<()> { + let plugin = self + .plugins + .get_mut(plugin_id) + .ok_or_else(|| anyhow!("Plugin not found: {plugin_id}"))?; + + plugin.enabled = true; + Ok(()) + } + + /// Disable a plugin + /// + /// # Errors + /// + /// Returns an error if the plugin ID is not found. 
+ pub fn disable(&mut self, plugin_id: &str) -> Result<()> { + let plugin = self + .plugins + .get_mut(plugin_id) + .ok_or_else(|| anyhow!("Plugin not found: {plugin_id}"))?; + + plugin.enabled = false; + Ok(()) + } + + /// List all registered plugins + #[must_use] + pub fn list_all(&self) -> Vec<&RegisteredPlugin> { + self.plugins.values().collect() + } + + /// List all enabled plugins + #[must_use] + pub fn list_enabled(&self) -> Vec<&RegisteredPlugin> { + self.plugins.values().filter(|p| p.enabled).collect() + } + + /// Get plugins by kind (e.g., "`media_type`", "`metadata_extractor`") + #[must_use] + pub fn get_by_kind(&self, kind: &str) -> Vec<&RegisteredPlugin> { + self + .plugins + .values() + .filter(|p| p.manifest.plugin.kind.iter().any(|k| k == kind)) + .collect() + } + + /// Get count of registered plugins + #[must_use] + pub fn count(&self) -> usize { + self.plugins.len() + } + + /// Get count of enabled plugins + #[must_use] + pub fn count_enabled(&self) -> usize { + self.plugins.values().filter(|p| p.enabled).count() + } +} + +impl Default for PluginRegistry { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use pinakes_plugin_api::{Capabilities, manifest::ManifestCapabilities}; + use rustc_hash::FxHashMap; + + use super::*; + + fn create_test_plugin(id: &str, kind: Vec) -> RegisteredPlugin { + let manifest = PluginManifest { + plugin: pinakes_plugin_api::manifest::PluginInfo { + name: id.to_string(), + version: "1.0.0".to_string(), + api_version: "1.0".to_string(), + author: Some("Test".to_string()), + description: Some("Test plugin".to_string()), + homepage: None, + license: None, + kind, + binary: pinakes_plugin_api::manifest::PluginBinary { + wasm: "test.wasm".to_string(), + entrypoint: None, + }, + dependencies: vec![], + priority: 0, + }, + capabilities: ManifestCapabilities::default(), + config: FxHashMap::default(), + ui: Default::default(), + }; + + RegisteredPlugin { + id: id.to_string(), + metadata: 
PluginMetadata { + id: id.to_string(), + name: id.to_string(), + version: "1.0.0".to_string(), + author: "Test".to_string(), + description: "Test plugin".to_string(), + api_version: "1.0".to_string(), + capabilities_required: Capabilities::default(), + }, + wasm_plugin: WasmPlugin::default(), + manifest, + manifest_path: None, + enabled: true, + } + } + + #[test] + fn test_registry_register_and_get() { + let mut registry = PluginRegistry::new(); + let plugin = + create_test_plugin("test-plugin", vec!["media_type".to_string()]); + + registry.register(plugin).unwrap(); + + assert!(registry.is_loaded("test-plugin")); + assert!(registry.get("test-plugin").is_some()); + } + + #[test] + fn test_registry_duplicate_register() { + let mut registry = PluginRegistry::new(); + let plugin = + create_test_plugin("test-plugin", vec!["media_type".to_string()]); + + registry.register(plugin.clone()).unwrap(); + let result = registry.register(plugin); + + assert!(result.is_err()); + } + + #[test] + fn test_registry_unregister() { + let mut registry = PluginRegistry::new(); + let plugin = + create_test_plugin("test-plugin", vec!["media_type".to_string()]); + + registry.register(plugin).unwrap(); + registry.unregister("test-plugin").unwrap(); + + assert!(!registry.is_loaded("test-plugin")); + } + + #[test] + fn test_registry_enable_disable() { + let mut registry = PluginRegistry::new(); + let plugin = + create_test_plugin("test-plugin", vec!["media_type".to_string()]); + + registry.register(plugin).unwrap(); + assert_eq!(registry.is_enabled("test-plugin"), Some(true)); + + registry.disable("test-plugin").unwrap(); + assert_eq!(registry.is_enabled("test-plugin"), Some(false)); + + registry.enable("test-plugin").unwrap(); + assert_eq!(registry.is_enabled("test-plugin"), Some(true)); + + assert_eq!(registry.is_enabled("nonexistent"), None); + } + + #[test] + fn test_registry_get_by_kind() { + let mut registry = PluginRegistry::new(); + + registry + .register(create_test_plugin("plugin1", 
vec![ + "media_type".to_string(), + ])) + .unwrap(); + registry + .register(create_test_plugin("plugin2", vec![ + "metadata_extractor".to_string(), + ])) + .unwrap(); + registry + .register(create_test_plugin("plugin3", vec![ + "media_type".to_string(), + ])) + .unwrap(); + + let media_type_plugins = registry.get_by_kind("media_type"); + assert_eq!(media_type_plugins.len(), 2); + + let extractor_plugins = registry.get_by_kind("metadata_extractor"); + assert_eq!(extractor_plugins.len(), 1); + } + + #[test] + fn test_registry_counts() { + let mut registry = PluginRegistry::new(); + + registry + .register(create_test_plugin("plugin1", vec![ + "media_type".to_string(), + ])) + .unwrap(); + registry + .register(create_test_plugin("plugin2", vec![ + "media_type".to_string(), + ])) + .unwrap(); + + assert_eq!(registry.count(), 2); + assert_eq!(registry.count_enabled(), 2); + + registry.disable("plugin1").unwrap(); + assert_eq!(registry.count(), 2); + assert_eq!(registry.count_enabled(), 1); + } +} diff --git a/crates/pinakes-plugin/src/rpc.rs b/crates/pinakes-plugin/src/rpc.rs new file mode 100644 index 0000000..e875d11 --- /dev/null +++ b/crates/pinakes-plugin/src/rpc.rs @@ -0,0 +1,240 @@ +//! JSON RPC types for structured plugin function calls. +//! +//! Each extension point maps to well-known exported function names. +//! Requests are serialized to JSON, passed to the plugin, and responses +//! are deserialized from JSON written by the plugin via `host_set_result`. 
+ +use std::path::PathBuf; + +use rustc_hash::FxHashMap; +use serde::{Deserialize, Serialize}; + +/// Request to check if a plugin can handle a file +#[derive(Debug, Serialize)] +pub struct CanHandleRequest { + pub path: PathBuf, + pub mime_type: Option, +} + +/// Response from `can_handle` +#[derive(Debug, Deserialize)] +pub struct CanHandleResponse { + pub can_handle: bool, +} + +/// Media type definition returned by `supported_media_types` +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PluginMediaTypeDefinition { + pub id: String, + pub name: String, + pub category: Option, + pub extensions: Vec, + pub mime_types: Vec, +} + +/// Request to extract metadata from a file +#[derive(Debug, Serialize)] +pub struct ExtractMetadataRequest { + pub path: PathBuf, +} + +/// Metadata response from a plugin (all fields optional for partial results) +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub struct ExtractMetadataResponse { + #[serde(default)] + pub title: Option, + #[serde(default)] + pub artist: Option, + #[serde(default)] + pub album: Option, + #[serde(default)] + pub genre: Option, + #[serde(default)] + pub year: Option, + #[serde(default)] + pub duration_secs: Option, + #[serde(default)] + pub description: Option, + #[serde(default)] + pub extra: FxHashMap, +} + +/// Request to generate a thumbnail +#[derive(Debug, Serialize)] +pub struct GenerateThumbnailRequest { + pub source_path: PathBuf, + pub output_path: PathBuf, + pub max_width: u32, + pub max_height: u32, + pub format: String, +} + +/// Response from thumbnail generation +#[derive(Debug, Deserialize)] +pub struct GenerateThumbnailResponse { + pub path: PathBuf, + pub width: u32, + pub height: u32, + pub format: String, +} + +/// Event sent to event handler plugins +#[derive(Debug, Serialize)] +pub struct HandleEventRequest { + pub event_type: String, + pub payload: serde_json::Value, +} + +/// Search request for search backend plugins +#[derive(Debug, Serialize)] +pub struct 
SearchRequest { + pub query: String, + pub limit: usize, + pub offset: usize, +} + +/// Search response +#[derive(Debug, Clone, Deserialize)] +pub struct SearchResponse { + pub results: Vec, + #[serde(default)] + pub total_count: Option, +} + +/// Individual search result +#[derive(Debug, Clone, Deserialize)] +pub struct SearchResultItem { + pub id: String, + pub score: f64, + pub snippet: Option, +} + +/// Request to index a media item in a search backend +#[derive(Debug, Serialize)] +pub struct IndexItemRequest { + pub id: String, + pub title: Option, + pub artist: Option, + pub album: Option, + pub description: Option, + pub tags: Vec, + pub media_type: String, + pub path: PathBuf, +} + +/// Request to remove a media item from a search backend +#[derive(Debug, Serialize)] +pub struct RemoveItemRequest { + pub id: String, +} + +/// A theme definition returned by a theme provider plugin +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PluginThemeDefinition { + pub id: String, + pub name: String, + pub description: Option, + pub dark: bool, +} + +/// Response from `load_theme` +#[derive(Debug, Clone, Deserialize)] +pub struct LoadThemeResponse { + pub css: Option, + pub colors: FxHashMap, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_metadata_request_serialization() { + let req = ExtractMetadataRequest { + path: "/tmp/test.mp3".into(), + }; + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("/tmp/test.mp3")); + } + + #[test] + fn test_extract_metadata_response_partial() { + let json = r#"{"title":"My Song","extra":{"bpm":"120"}}"#; + let resp: ExtractMetadataResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.title.as_deref(), Some("My Song")); + assert_eq!(resp.artist, None); + assert_eq!(resp.extra.get("bpm").map(String::as_str), Some("120")); + } + + #[test] + fn test_extract_metadata_response_empty() { + let json = "{}"; + let resp: ExtractMetadataResponse = 
serde_json::from_str(json).unwrap(); + assert_eq!(resp.title, None); + assert!(resp.extra.is_empty()); + } + + #[test] + fn test_can_handle_response() { + let json = r#"{"can_handle":true}"#; + let resp: CanHandleResponse = serde_json::from_str(json).unwrap(); + assert!(resp.can_handle); + } + + #[test] + fn test_can_handle_response_false() { + let json = r#"{"can_handle":false}"#; + let resp: CanHandleResponse = serde_json::from_str(json).unwrap(); + assert!(!resp.can_handle); + } + + #[test] + fn test_plugin_media_type_definition_round_trip() { + let def = PluginMediaTypeDefinition { + id: "heif".to_string(), + name: "HEIF Image".to_string(), + category: Some("image".to_string()), + extensions: vec!["heif".to_string(), "heic".to_string()], + mime_types: vec!["image/heif".to_string()], + }; + let json = serde_json::to_string(&def).unwrap(); + let parsed: PluginMediaTypeDefinition = + serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.id, "heif"); + assert_eq!(parsed.extensions.len(), 2); + } + + #[test] + fn test_search_response() { + let json = + r#"{"results":[{"id":"abc","score":0.95,"snippet":"match here"}]}"#; + let resp: SearchResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.results.len(), 1); + assert_eq!(resp.results[0].id, "abc"); + } + + #[test] + fn test_generate_thumbnail_request_serialization() { + let req = GenerateThumbnailRequest { + source_path: "/media/photo.heif".into(), + output_path: "/tmp/thumb.jpg".into(), + max_width: 256, + max_height: 256, + format: "jpeg".to_string(), + }; + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("photo.heif")); + assert!(json.contains("256")); + } + + #[test] + fn test_handle_event_request_serialization() { + let req = HandleEventRequest { + event_type: "MediaImported".to_string(), + payload: serde_json::json!({"id": "abc-123"}), + }; + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("MediaImported")); + assert!(json.contains("abc-123")); 
+ } +} diff --git a/crates/pinakes-plugin/src/runtime.rs b/crates/pinakes-plugin/src/runtime.rs new file mode 100644 index 0000000..e07a1c4 --- /dev/null +++ b/crates/pinakes-plugin/src/runtime.rs @@ -0,0 +1,925 @@ +//! WASM runtime for executing plugins + +use std::{path::Path, sync::Arc}; + +use anyhow::{Result, anyhow}; +use pinakes_plugin_api::PluginContext; +use wasmtime::{ + Caller, + Config, + Engine, + Linker, + Module, + Store, + StoreLimitsBuilder, + Val, + anyhow, +}; + +/// WASM runtime wrapper for executing plugins +pub struct WasmRuntime { + engine: Engine, +} + +impl WasmRuntime { + /// Create a new WASM runtime + /// + /// # Errors + /// + /// Returns an error if the WASM engine cannot be created with the given + /// configuration. + pub fn new() -> Result { + let mut config = Config::new(); + config.wasm_component_model(true); + config.max_wasm_stack(1024 * 1024); // 1MB stack + config.consume_fuel(true); // enable fuel metering for CPU limits + + let engine = Engine::new(&config)?; + + Ok(Self { engine }) + } + + /// Load a plugin from a WASM file + /// + /// # Errors + /// + /// Returns an error if the WASM file does not exist, cannot be read, or + /// cannot be compiled. 
+ pub fn load_plugin( + &self, + wasm_path: &Path, + context: PluginContext, + ) -> Result { + if !wasm_path.exists() { + return Err(anyhow!("WASM file not found: {}", wasm_path.display())); + } + + let wasm_bytes = std::fs::read(wasm_path)?; + let module = Module::new(&self.engine, &wasm_bytes)?; + + Ok(WasmPlugin { + module: Arc::new(module), + context, + }) + } +} + +/// Store data passed to each WASM invocation +pub struct PluginStoreData { + pub context: PluginContext, + pub exchange_buffer: Vec, + pub pending_events: Vec<(String, String)>, + pub limiter: wasmtime::StoreLimits, +} + +/// A loaded WASM plugin instance +#[derive(Clone)] +pub struct WasmPlugin { + module: Arc, + context: PluginContext, +} + +impl WasmPlugin { + /// Get the plugin context + #[must_use] + pub const fn context(&self) -> &PluginContext { + &self.context + } + + /// Execute a plugin function, returning both the result bytes and any + /// events the plugin queued via `host_emit_event`. + /// + /// Creates a fresh store and instance per invocation with host functions + /// linked, calls the requested exported function, drains both the exchange + /// buffer and the pending events list before the store is dropped, and + /// returns both. + /// + /// # Errors + /// + /// Returns an error if the function cannot be found, instantiation fails, + /// or the function call returns an error. 
+ pub async fn call_function_with_events( + &self, + function_name: &str, + params: &[u8], + ) -> Result<(Vec, Vec<(String, String)>)> { + let engine = self.module.engine(); + + // Build memory limiter from capabilities + let memory_limit = self + .context + .capabilities + .max_memory_bytes + .unwrap_or(512 * 1024 * 1024); // default 512 MB + + let limiter = StoreLimitsBuilder::new().memory_size(memory_limit).build(); + + let store_data = PluginStoreData { + context: self.context.clone(), + exchange_buffer: Vec::new(), + pending_events: Vec::new(), + limiter, + }; + let mut store = Store::new(engine, store_data); + store.limiter(|data| &mut data.limiter); + + // Set fuel limit based on capabilities + if let Some(max_cpu_time_ms) = self.context.capabilities.max_cpu_time_ms { + let fuel = max_cpu_time_ms * 100_000; + store.set_fuel(fuel)?; + } else { + store.set_fuel(1_000_000_000)?; + } + + let mut linker = Linker::new(engine); + HostFunctions::setup_linker(&mut linker)?; + + let instance = linker.instantiate_async(&mut store, &self.module).await?; + + let memory = instance.get_memory(&mut store, "memory"); + + // If there are params and memory is available, write them to the module + let mut alloc_offset: i32 = 0; + if !params.is_empty() + && let Some(mem) = &memory + { + // Call the plugin's alloc function if available, otherwise write at + // offset 0 + let offset = if let Ok(alloc) = + instance.get_typed_func::(&mut store, "alloc") + { + let result = alloc + .call_async( + &mut store, + i32::try_from(params.len()).unwrap_or(i32::MAX), + ) + .await?; + if result < 0 { + return Err(anyhow!( + "plugin alloc returned negative offset: {result}" + )); + } + u32::try_from(result).unwrap_or(0) as usize + } else { + 0 + }; + + alloc_offset = i32::try_from(offset).unwrap_or(i32::MAX); + let mem_data = mem.data_mut(&mut store); + if offset + params.len() <= mem_data.len() { + mem_data[offset..offset + params.len()].copy_from_slice(params); + } + } + + let func = + 
instance + .get_func(&mut store, function_name) + .ok_or_else(|| { + anyhow!("exported function '{function_name}' not found") + })?; + + let func_ty = func.ty(&store); + let param_count = func_ty.params().len(); + let result_count = func_ty.results().len(); + + let mut results = vec![Val::I32(0); result_count]; + + // Call with appropriate params based on function signature; convention: + // (ptr, len) + if param_count == 2 && !params.is_empty() { + func + .call_async( + &mut store, + &[ + Val::I32(alloc_offset), + Val::I32(i32::try_from(params.len()).unwrap_or(i32::MAX)), + ], + &mut results, + ) + .await?; + } else if param_count == 0 { + func.call_async(&mut store, &[], &mut results).await?; + } else { + // Generic: fill with zeroes + let params_vals: Vec = + std::iter::repeat_n(Val::I32(0), param_count).collect(); + func + .call_async(&mut store, ¶ms_vals, &mut results) + .await?; + } + + // Drain both buffers before the store is dropped. + let pending_events = std::mem::take(&mut store.data_mut().pending_events); + let exchange = std::mem::take(&mut store.data_mut().exchange_buffer); + + let result = if !exchange.is_empty() { + exchange + } else if let Some(Val::I32(ret)) = results.first() { + ret.to_le_bytes().to_vec() + } else { + Vec::new() + }; + + Ok((result, pending_events)) + } + + /// Execute a plugin function, discarding any events the plugin queued. + /// + /// This is a thin wrapper around [`Self::call_function_with_events`]. + /// + /// # Errors + /// + /// Returns an error if the function cannot be found, instantiation fails, + /// or the function call returns an error. + pub async fn call_function( + &self, + function_name: &str, + params: &[u8], + ) -> Result> { + let (data, _events) = self + .call_function_with_events(function_name, params) + .await?; + Ok(data) + } + + /// Call a plugin function with JSON request/response serialization. + /// + /// Serializes `request` to JSON, calls the named function, deserializes + /// the response. 
Wraps the call with `tokio::time::timeout`. + /// + /// # Errors + /// + /// Returns an error if serialization fails, the call times out, the plugin + /// traps, or the response is malformed JSON. + #[allow(clippy::future_not_send)] // Req doesn't need Sync; called within local tasks + pub async fn call_function_json( + &self, + function_name: &str, + request: &Req, + timeout: std::time::Duration, + ) -> anyhow::Result + where + Req: serde::Serialize, + Resp: serde::de::DeserializeOwned, + { + let request_bytes = serde_json::to_vec(request) + .map_err(|e| anyhow::anyhow!("failed to serialize request: {e}"))?; + + let result = tokio::time::timeout( + timeout, + self.call_function(function_name, &request_bytes), + ) + .await + .map_err(|_| { + anyhow::anyhow!( + "plugin call '{function_name}' timed out after {timeout:?}" + ) + })??; + + serde_json::from_slice(&result).map_err(|e| { + anyhow::anyhow!( + "failed to deserialize response from '{function_name}': {e}" + ) + }) + } + + /// Call a plugin function with JSON serialization, also returning any + /// events the plugin queued via `host_emit_event`. + /// + /// Mirrors [`Self::call_function_json`] but delegates to + /// [`Self::call_function_with_events`] so the pending events list is not + /// discarded before returning. + /// + /// # Errors + /// + /// Returns an error if serialization fails, the call times out, the plugin + /// traps, or the response is malformed JSON. 
+ #[allow(clippy::future_not_send)] // Req doesn't need Sync; called within local tasks + pub async fn call_function_json_with_events( + &self, + function_name: &str, + request: &Req, + timeout: std::time::Duration, + ) -> anyhow::Result<(Resp, Vec<(String, String)>)> + where + Req: serde::Serialize, + Resp: serde::de::DeserializeOwned, + { + let request_bytes = serde_json::to_vec(request) + .map_err(|e| anyhow::anyhow!("failed to serialize request: {e}"))?; + + let (result, pending_events) = tokio::time::timeout( + timeout, + self.call_function_with_events(function_name, &request_bytes), + ) + .await + .map_err(|_| { + anyhow::anyhow!( + "plugin call '{function_name}' timed out after {timeout:?}" + ) + })??; + + let resp = serde_json::from_slice(&result).map_err(|e| { + anyhow::anyhow!( + "failed to deserialize response from '{function_name}': {e}" + ) + })?; + + Ok((resp, pending_events)) + } +} + +#[cfg(test)] +impl Default for WasmPlugin { + fn default() -> Self { + let engine = Engine::default(); + let module = Module::new(&engine, br"(module)").unwrap(); + + Self { + module: Arc::new(module), + context: PluginContext { + data_dir: std::env::temp_dir(), + cache_dir: std::env::temp_dir(), + config: Default::default(), + capabilities: Default::default(), + }, + } + } +} + +/// Host functions that plugins can call +pub struct HostFunctions; + +impl HostFunctions { + /// Registers all host ABI functions (`host_log`, `host_read_file`, + /// `host_write_file`, `host_http_request`, `host_get_config`, + /// `host_get_env`, `host_get_buffer`, `host_set_result`, + /// `host_emit_event`) into the given linker. + /// + /// # Errors + /// + /// Returns an error if any host function cannot be registered in the linker. 
+ pub fn setup_linker(linker: &mut Linker) -> Result<()> { + linker.func_wrap( + "env", + "host_log", + |mut caller: Caller<'_, PluginStoreData>, + level: i32, + ptr: i32, + len: i32| { + if ptr < 0 || len < 0 { + return; + } + let memory = caller + .get_export("memory") + .and_then(wasmtime::Extern::into_memory); + if let Some(mem) = memory { + let data = mem.data(&caller); + let start = u32::try_from(ptr).unwrap_or(0) as usize; + let end = start + u32::try_from(len).unwrap_or(0) as usize; + if end <= data.len() + && let Ok(msg) = std::str::from_utf8(&data[start..end]) + { + match level { + 0 => tracing::error!(plugin = true, "{}", msg), + 1 => tracing::warn!(plugin = true, "{}", msg), + 2 => tracing::info!(plugin = true, "{}", msg), + _ => tracing::debug!(plugin = true, "{}", msg), + } + } + } + }, + )?; + + linker.func_wrap( + "env", + "host_read_file", + |mut caller: Caller<'_, PluginStoreData>, + path_ptr: i32, + path_len: i32| + -> i32 { + if path_ptr < 0 || path_len < 0 { + return -1; + } + let memory = caller + .get_export("memory") + .and_then(wasmtime::Extern::into_memory); + let Some(mem) = memory else { return -1 }; + + let data = mem.data(&caller); + let start = u32::try_from(path_ptr).unwrap_or(0) as usize; + let end = start + u32::try_from(path_len).unwrap_or(0) as usize; + if end > data.len() { + return -1; + } + + let path_str = match std::str::from_utf8(&data[start..end]) { + Ok(s) => s.to_string(), + Err(_) => return -1, + }; + + // Canonicalize path before checking permissions to prevent traversal + let Ok(path) = std::path::Path::new(&path_str).canonicalize() else { + return -1; + }; + + // Check read permission against canonicalized path + let can_read = caller + .data() + .context + .capabilities + .filesystem + .read + .iter() + .any(|allowed| { + allowed.canonicalize().is_ok_and(|a| path.starts_with(a)) + }); + + if !can_read { + tracing::warn!(path = %path_str, "plugin read access denied"); + return -2; + } + + 
std::fs::read(&path).map_or(-1, |contents| { + let len = i32::try_from(contents.len()).unwrap_or(i32::MAX); + caller.data_mut().exchange_buffer = contents; + len + }) + }, + )?; + + linker.func_wrap( + "env", + "host_write_file", + |mut caller: Caller<'_, PluginStoreData>, + path_ptr: i32, + path_len: i32, + data_ptr: i32, + data_len: i32| + -> i32 { + if path_ptr < 0 || path_len < 0 || data_ptr < 0 || data_len < 0 { + return -1; + } + let memory = caller + .get_export("memory") + .and_then(wasmtime::Extern::into_memory); + let Some(mem) = memory else { return -1 }; + + let mem_data = mem.data(&caller); + let path_start = u32::try_from(path_ptr).unwrap_or(0) as usize; + let path_end = + path_start + u32::try_from(path_len).unwrap_or(0) as usize; + let data_start = u32::try_from(data_ptr).unwrap_or(0) as usize; + let data_end = + data_start + u32::try_from(data_len).unwrap_or(0) as usize; + + if path_end > mem_data.len() || data_end > mem_data.len() { + return -1; + } + + let path_str = + match std::str::from_utf8(&mem_data[path_start..path_end]) { + Ok(s) => s.to_string(), + Err(_) => return -1, + }; + let file_data = mem_data[data_start..data_end].to_vec(); + + // Canonicalize path for write (file may not exist yet) + let path = std::path::Path::new(&path_str); + let canonical = if path.exists() { + path.canonicalize().ok() + } else { + path + .parent() + .and_then(|p| p.canonicalize().ok()) + .map(|p| p.join(path.file_name().unwrap_or_default())) + }; + let Some(canonical) = canonical else { + return -1; + }; + + // Check write permission against canonicalized path + let can_write = caller + .data() + .context + .capabilities + .filesystem + .write + .iter() + .any(|allowed| { + allowed + .canonicalize() + .is_ok_and(|a| canonical.starts_with(a)) + }); + + if !can_write { + tracing::warn!(path = %path_str, "plugin write access denied"); + return -2; + } + + match std::fs::write(&canonical, &file_data) { + Ok(()) => 0, + Err(_) => -1, + } + }, + )?; + + 
linker.func_wrap( + "env", + "host_http_request", + |mut caller: Caller<'_, PluginStoreData>, + url_ptr: i32, + url_len: i32| + -> i32 { + if url_ptr < 0 || url_len < 0 { + return -1; + } + let memory = caller + .get_export("memory") + .and_then(wasmtime::Extern::into_memory); + let Some(mem) = memory else { return -1 }; + + let data = mem.data(&caller); + let start = u32::try_from(url_ptr).unwrap_or(0) as usize; + let end = start + u32::try_from(url_len).unwrap_or(0) as usize; + if end > data.len() { + return -1; + } + + let url_str = match std::str::from_utf8(&data[start..end]) { + Ok(s) => s.to_string(), + Err(_) => return -1, + }; + + // Check network permission + if !caller.data().context.capabilities.network.enabled { + tracing::warn!(url = %url_str, "plugin network access denied"); + return -2; + } + + // Check domain whitelist if configured + if let Some(ref allowed) = + caller.data().context.capabilities.network.allowed_domains + { + let parsed = if let Ok(u) = url::Url::parse(&url_str) { + u + } else { + tracing::warn!(url = %url_str, "plugin provided invalid URL"); + return -1; + }; + let domain = parsed.host_str().unwrap_or(""); + + if !allowed.iter().any(|d| d.eq_ignore_ascii_case(domain)) { + tracing::warn!( + url = %url_str, + domain = domain, + "plugin domain not in allowlist" + ); + return -3; + } + } + + // Use block_in_place to avoid blocking the async runtime's thread pool. + // Falls back to a blocking client with timeout if block_in_place is + // unavailable. 
+ let result = std::panic::catch_unwind(|| { + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .map_err(|e| e.to_string())?; + let resp = client + .get(&url_str) + .send() + .await + .map_err(|e| e.to_string())?; + let bytes = resp.bytes().await.map_err(|e| e.to_string())?; + Ok::<_, String>(bytes) + }) + }) + }); + + match result { + Ok(Ok(bytes)) => { + let len = i32::try_from(bytes.len()).unwrap_or(i32::MAX); + caller.data_mut().exchange_buffer = bytes.to_vec(); + len + }, + Ok(Err(_)) => -1, + Err(_) => { + // block_in_place panicked (e.g. current-thread runtime); + // fall back to blocking client with timeout + let Ok(client) = reqwest::blocking::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + else { + return -1; + }; + client.get(&url_str).send().map_or(-1, |resp| { + resp.bytes().map_or(-1, |bytes| { + let len = i32::try_from(bytes.len()).unwrap_or(i32::MAX); + caller.data_mut().exchange_buffer = bytes.to_vec(); + len + }) + }) + }, + } + }, + )?; + + linker.func_wrap( + "env", + "host_get_config", + |mut caller: Caller<'_, PluginStoreData>, + key_ptr: i32, + key_len: i32| + -> i32 { + if key_ptr < 0 || key_len < 0 { + return -1; + } + let memory = caller + .get_export("memory") + .and_then(wasmtime::Extern::into_memory); + let Some(mem) = memory else { return -1 }; + + let data = mem.data(&caller); + let start = u32::try_from(key_ptr).unwrap_or(0) as usize; + let end = start + u32::try_from(key_len).unwrap_or(0) as usize; + if end > data.len() { + return -1; + } + + let key_str = match std::str::from_utf8(&data[start..end]) { + Ok(s) => s.to_string(), + Err(_) => return -1, + }; + + let bytes = caller + .data() + .context + .config + .get(&key_str) + .map(|value| value.to_string().into_bytes()); + bytes.map_or(-1, |b| { + let len = i32::try_from(b.len()).unwrap_or(i32::MAX); + 
caller.data_mut().exchange_buffer = b; + len + }) + }, + )?; + + linker.func_wrap( + "env", + "host_get_env", + |mut caller: Caller<'_, PluginStoreData>, + key_ptr: i32, + key_len: i32| + -> i32 { + if key_ptr < 0 || key_len < 0 { + return -1; + } + let memory = caller + .get_export("memory") + .and_then(wasmtime::Extern::into_memory); + let Some(mem) = memory else { return -1 }; + + let data = mem.data(&caller); + let start = u32::try_from(key_ptr).unwrap_or(0) as usize; + let end = start + u32::try_from(key_len).unwrap_or(0) as usize; + if end > data.len() { + return -1; + } + + let key_str = match std::str::from_utf8(&data[start..end]) { + Ok(s) => s.to_string(), + Err(_) => return -1, + }; + + // Check environment capability + let env_cap = &caller.data().context.capabilities.environment; + if !env_cap.enabled { + tracing::warn!( + var = %key_str, + "plugin environment access denied" + ); + return -2; + } + + // Check against allowed variables list if configured + if let Some(ref allowed) = env_cap.allowed_vars + && !allowed.iter().any(|v| v == &key_str) + { + tracing::warn!( + var = %key_str, + "plugin env var not in allowlist" + ); + return -2; + } + + match std::env::var(&key_str) { + Ok(value) => { + let bytes = value.into_bytes(); + let len = i32::try_from(bytes.len()).unwrap_or(i32::MAX); + caller.data_mut().exchange_buffer = bytes; + len + }, + Err(_) => -1, + } + }, + )?; + + linker.func_wrap( + "env", + "host_get_buffer", + |mut caller: Caller<'_, PluginStoreData>, + dest_ptr: i32, + dest_len: i32| + -> i32 { + if dest_ptr < 0 || dest_len < 0 { + return -1; + } + let buf = caller.data().exchange_buffer.clone(); + let copy_len = + buf.len().min(u32::try_from(dest_len).unwrap_or(0) as usize); + + let memory = caller + .get_export("memory") + .and_then(wasmtime::Extern::into_memory); + let Some(mem) = memory else { return -1 }; + + let mem_data = mem.data_mut(&mut caller); + let start = u32::try_from(dest_ptr).unwrap_or(0) as usize; + if start + copy_len 
> mem_data.len() { + return -1; + } + + mem_data[start..start + copy_len].copy_from_slice(&buf[..copy_len]); + i32::try_from(copy_len).unwrap_or(i32::MAX) + }, + )?; + + linker.func_wrap( + "env", + "host_set_result", + |mut caller: Caller<'_, PluginStoreData>, ptr: i32, len: i32| { + if ptr < 0 || len < 0 { + return; + } + let memory = caller + .get_export("memory") + .and_then(wasmtime::Extern::into_memory); + let Some(mem) = memory else { return }; + + let data = mem.data(&caller); + let start = u32::try_from(ptr).unwrap_or(0) as usize; + let end = start + u32::try_from(len).unwrap_or(0) as usize; + if end <= data.len() { + caller.data_mut().exchange_buffer = data[start..end].to_vec(); + } + }, + )?; + + linker.func_wrap( + "env", + "host_emit_event", + |mut caller: Caller<'_, PluginStoreData>, + type_ptr: i32, + type_len: i32, + payload_ptr: i32, + payload_len: i32| + -> i32 { + const MAX_PENDING_EVENTS: usize = 1000; + + if type_ptr < 0 || type_len < 0 || payload_ptr < 0 || payload_len < 0 { + return -1; + } + let memory = caller + .get_export("memory") + .and_then(wasmtime::Extern::into_memory); + let Some(mem) = memory else { return -1 }; + + let type_start = u32::try_from(type_ptr).unwrap_or(0) as usize; + let type_end = + type_start + u32::try_from(type_len).unwrap_or(0) as usize; + let payload_start = u32::try_from(payload_ptr).unwrap_or(0) as usize; + let payload_end = + payload_start + u32::try_from(payload_len).unwrap_or(0) as usize; + + // Extract owned strings in a block so the immutable borrow of + // `caller` (via `mem.data`) is dropped before `caller.data_mut()`. 
+ let (event_type, payload) = { + let data = mem.data(&caller); + if type_end > data.len() || payload_end > data.len() { + return -1; + } + let event_type = + match std::str::from_utf8(&data[type_start..type_end]) { + Ok(s) => s.to_string(), + Err(_) => return -1, + }; + let payload = + match std::str::from_utf8(&data[payload_start..payload_end]) { + Ok(s) => s.to_string(), + Err(_) => return -1, + }; + (event_type, payload) + }; + + if caller.data().pending_events.len() >= MAX_PENDING_EVENTS { + tracing::warn!("plugin exceeded max pending events limit"); + return -4; + } + + caller.data_mut().pending_events.push((event_type, payload)); + 0 + }, + )?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use pinakes_plugin_api::PluginContext; + use rustc_hash::FxHashMap; + + use super::*; + + #[test] + fn test_wasm_runtime_creation() { + let runtime = WasmRuntime::new(); + assert!(runtime.is_ok()); + } + + #[test] + fn test_host_functions_file_access() { + let mut capabilities = pinakes_plugin_api::Capabilities::default(); + capabilities.filesystem.read.push("/tmp".into()); + capabilities.filesystem.write.push("/tmp/output".into()); + + let context = PluginContext { + data_dir: "/tmp/data".into(), + cache_dir: "/tmp/cache".into(), + config: Default::default(), + capabilities, + }; + + // Verify capability checks work via context fields + let can_read = context + .capabilities + .filesystem + .read + .iter() + .any(|p| Path::new("/tmp/test.txt").starts_with(p)); + assert!(can_read); + + let cant_read = context + .capabilities + .filesystem + .read + .iter() + .any(|p| Path::new("/etc/passwd").starts_with(p)); + assert!(!cant_read); + + let can_write = context + .capabilities + .filesystem + .write + .iter() + .any(|p| Path::new("/tmp/output/file.txt").starts_with(p)); + assert!(can_write); + + let cant_write = context + .capabilities + .filesystem + .write + .iter() + .any(|p| Path::new("/tmp/file.txt").starts_with(p)); + assert!(!cant_write); + } + + #[test] + fn 
test_host_functions_network_access() { + let mut context = PluginContext { + data_dir: "/tmp/data".into(), + cache_dir: "/tmp/cache".into(), + config: FxHashMap::default(), + capabilities: Default::default(), + }; + + assert!(!context.capabilities.network.enabled); + + context.capabilities.network.enabled = true; + assert!(context.capabilities.network.enabled); + } + + #[test] + fn test_linker_setup() { + let engine = Engine::default(); + let mut linker = Linker::::new(&engine); + let result = HostFunctions::setup_linker(&mut linker); + assert!(result.is_ok()); + } +} diff --git a/crates/pinakes-plugin/src/security.rs b/crates/pinakes-plugin/src/security.rs new file mode 100644 index 0000000..6bebb94 --- /dev/null +++ b/crates/pinakes-plugin/src/security.rs @@ -0,0 +1,473 @@ +//! Capability-based security for plugins + +use std::path::{Path, PathBuf}; + +use anyhow::{Result, anyhow}; +use pinakes_plugin_api::Capabilities; + +/// Capability enforcer validates and enforces plugin capabilities +pub struct CapabilityEnforcer { + /// Maximum allowed memory per plugin (bytes) + max_memory_limit: usize, + + /// Maximum allowed CPU time per plugin (milliseconds) + max_cpu_time_limit: u64, + + /// Allowed filesystem read paths (system-wide) + allowed_read_paths: Vec, + + /// Allowed filesystem write paths (system-wide) + allowed_write_paths: Vec, + + /// Whether to allow network access by default + allow_network_default: bool, +} + +impl CapabilityEnforcer { + /// Create a new capability enforcer with default limits + #[must_use] + pub const fn new() -> Self { + Self { + max_memory_limit: 512 * 1024 * 1024, // 512 MB + max_cpu_time_limit: 60 * 1000, // 60 seconds + allowed_read_paths: vec![], + allowed_write_paths: vec![], + allow_network_default: false, + } + } + + /// Set maximum memory limit + #[must_use] + pub const fn with_max_memory(mut self, bytes: usize) -> Self { + self.max_memory_limit = bytes; + self + } + + /// Set maximum CPU time limit + #[must_use] + pub 
const fn with_max_cpu_time(mut self, milliseconds: u64) -> Self { + self.max_cpu_time_limit = milliseconds; + self + } + + /// Add allowed read path + #[must_use] + pub fn allow_read_path(mut self, path: PathBuf) -> Self { + self.allowed_read_paths.push(path); + self + } + + /// Add allowed write path + #[must_use] + pub fn allow_write_path(mut self, path: PathBuf) -> Self { + self.allowed_write_paths.push(path); + self + } + + /// Set default network access policy + #[must_use] + pub const fn with_network_default(mut self, allow: bool) -> Self { + self.allow_network_default = allow; + self + } + + /// Validate capabilities requested by a plugin + /// + /// # Errors + /// + /// Returns an error if the plugin requests capabilities that exceed the + /// configured system limits, such as memory, CPU time, filesystem paths, or + /// network access. + pub fn validate_capabilities( + &self, + capabilities: &Capabilities, + ) -> Result<()> { + // Validate memory limit + if let Some(memory) = capabilities.max_memory_bytes + && memory > self.max_memory_limit + { + return Err(anyhow!( + "Requested memory ({} bytes) exceeds limit ({} bytes)", + memory, + self.max_memory_limit + )); + } + + // Validate CPU time limit + if let Some(cpu_time) = capabilities.max_cpu_time_ms + && cpu_time > self.max_cpu_time_limit + { + return Err(anyhow!( + "Requested CPU time ({} ms) exceeds limit ({} ms)", + cpu_time, + self.max_cpu_time_limit + )); + } + + // Validate filesystem access + self.validate_filesystem_access(capabilities)?; + + // Validate network access + if capabilities.network.enabled && !self.allow_network_default { + return Err(anyhow!( + "Plugin requests network access, but network access is disabled by \ + policy" + )); + } + + Ok(()) + } + + /// Validate filesystem access capabilities + fn validate_filesystem_access( + &self, + capabilities: &Capabilities, + ) -> Result<()> { + // Check read paths + for path in &capabilities.filesystem.read { + if !self.is_read_allowed(path) 
{ + return Err(anyhow!( + "Plugin requests read access to {} which is not in allowed paths", + path.display() + )); + } + } + + // Check write paths + for path in &capabilities.filesystem.write { + if !self.is_write_allowed(path) { + return Err(anyhow!( + "Plugin requests write access to {} which is not in allowed paths", + path.display() + )); + } + } + + Ok(()) + } + + /// Check if a path is allowed for reading + #[must_use] + pub fn is_read_allowed(&self, path: &Path) -> bool { + if self.allowed_read_paths.is_empty() { + return false; // deny-all when unconfigured + } + let Ok(canonical) = path.canonicalize() else { + return false; + }; + self.allowed_read_paths.iter().any(|allowed| { + allowed + .canonicalize() + .is_ok_and(|a| canonical.starts_with(a)) + }) + } + + /// Check if a path is allowed for writing + #[must_use] + pub fn is_write_allowed(&self, path: &Path) -> bool { + if self.allowed_write_paths.is_empty() { + return false; // deny-all when unconfigured + } + let canonical = if path.exists() { + path.canonicalize().ok() + } else { + path + .parent() + .and_then(|p| p.canonicalize().ok()) + .map(|p| p.join(path.file_name().unwrap_or_default())) + }; + let Some(canonical) = canonical else { + return false; + }; + self.allowed_write_paths.iter().any(|allowed| { + allowed + .canonicalize() + .is_ok_and(|a| canonical.starts_with(a)) + }) + } + + /// Check if network access is allowed for a plugin + #[must_use] + pub const fn is_network_allowed(&self, capabilities: &Capabilities) -> bool { + capabilities.network.enabled && self.allow_network_default + } + + /// Check if a specific domain is allowed + #[must_use] + pub fn is_domain_allowed( + &self, + capabilities: &Capabilities, + domain: &str, + ) -> bool { + if !capabilities.network.enabled { + return false; + } + + // If no domain restrictions, allow all domains + if capabilities.network.allowed_domains.is_none() { + return self.allow_network_default; + } + + // Check against allowed domains list + 
capabilities + .network + .allowed_domains + .as_ref() + .is_some_and(|domains| { + domains.iter().any(|d| d.eq_ignore_ascii_case(domain)) + }) + } + + /// Get effective memory limit for a plugin + #[must_use] + pub fn get_memory_limit(&self, capabilities: &Capabilities) -> usize { + capabilities + .max_memory_bytes + .unwrap_or(self.max_memory_limit) + .min(self.max_memory_limit) + } + + /// Get effective CPU time limit for a plugin + #[must_use] + pub fn get_cpu_time_limit(&self, capabilities: &Capabilities) -> u64 { + capabilities + .max_cpu_time_ms + .unwrap_or(self.max_cpu_time_limit) + .min(self.max_cpu_time_limit) + } + + /// Validate that a function call is allowed for a plugin's declared kinds. + /// + /// Defense-in-depth: even though the pipeline filters by kind, this prevents + /// bugs from calling wrong functions on plugins. Returns `true` if allowed. + #[must_use] + pub fn validate_function_call( + &self, + plugin_kinds: &[String], + function_name: &str, + ) -> bool { + match function_name { + // Lifecycle functions are always allowed + "initialize" | "shutdown" | "health_check" => true, + // MediaTypeProvider + "supported_media_types" | "can_handle" => { + plugin_kinds.iter().any(|k| k == "media_type") + }, + // supported_types is shared by metadata_extractor and thumbnail_generator + "supported_types" => { + plugin_kinds + .iter() + .any(|k| k == "metadata_extractor" || k == "thumbnail_generator") + }, + // MetadataExtractor + "extract_metadata" => { + plugin_kinds.iter().any(|k| k == "metadata_extractor") + }, + // ThumbnailGenerator + "generate_thumbnail" => { + plugin_kinds.iter().any(|k| k == "thumbnail_generator") + }, + // SearchBackend + "search" | "index_item" | "remove_item" | "get_stats" => { + plugin_kinds.iter().any(|k| k == "search_backend") + }, + // EventHandler + "interested_events" | "handle_event" => { + plugin_kinds.iter().any(|k| k == "event_handler") + }, + // ThemeProvider + "get_themes" | "load_theme" => { + 
plugin_kinds.iter().any(|k| k == "theme_provider") + }, + // Unknown function names are not allowed + _ => false, + } + } +} + +impl Default for CapabilityEnforcer { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use pinakes_plugin_api::{FilesystemCapability, NetworkCapability}; + + use super::*; + + #[test] + fn test_validate_memory_limit() { + let enforcer = CapabilityEnforcer::new().with_max_memory(100 * 1024 * 1024); // 100 MB + + let mut caps = Capabilities::default(); + caps.max_memory_bytes = Some(50 * 1024 * 1024); // 50 MB - OK + assert!(enforcer.validate_capabilities(&caps).is_ok()); + + caps.max_memory_bytes = Some(200 * 1024 * 1024); // 200 MB - exceeds limit + assert!(enforcer.validate_capabilities(&caps).is_err()); + } + + #[test] + fn test_validate_cpu_time_limit() { + let enforcer = CapabilityEnforcer::new().with_max_cpu_time(30_000); // 30 seconds + + let mut caps = Capabilities::default(); + caps.max_cpu_time_ms = Some(10_000); // 10 seconds - OK + assert!(enforcer.validate_capabilities(&caps).is_ok()); + + caps.max_cpu_time_ms = Some(60_000); // 60 seconds - exceeds limit + assert!(enforcer.validate_capabilities(&caps).is_err()); + } + + #[test] + fn test_filesystem_read_allowed() { + // Use real temp directories so canonicalize works + let tmp = tempfile::tempdir().unwrap(); + let allowed_dir = tmp.path().join("allowed"); + std::fs::create_dir_all(&allowed_dir).unwrap(); + let test_file = allowed_dir.join("test.txt"); + std::fs::write(&test_file, "test").unwrap(); + + let enforcer = CapabilityEnforcer::new().allow_read_path(allowed_dir); + + assert!(enforcer.is_read_allowed(&test_file)); + assert!(!enforcer.is_read_allowed(Path::new("/etc/passwd"))); + } + + #[test] + fn test_filesystem_read_denied_when_empty() { + let enforcer = CapabilityEnforcer::new(); + assert!(!enforcer.is_read_allowed(Path::new("/tmp/test.txt"))); + } + + #[test] + fn test_filesystem_write_allowed() { + let tmp = 
tempfile::tempdir().unwrap(); + let output_dir = tmp.path().join("output"); + std::fs::create_dir_all(&output_dir).unwrap(); + // Existing file in allowed dir + let existing = output_dir.join("file.txt"); + std::fs::write(&existing, "test").unwrap(); + + let enforcer = + CapabilityEnforcer::new().allow_write_path(output_dir.clone()); + + assert!(enforcer.is_write_allowed(&existing)); + // New file in allowed dir (parent exists) + assert!(enforcer.is_write_allowed(&output_dir.join("new_file.txt"))); + assert!(!enforcer.is_write_allowed(Path::new("/etc/config"))); + } + + #[test] + fn test_filesystem_write_denied_when_empty() { + let enforcer = CapabilityEnforcer::new(); + assert!(!enforcer.is_write_allowed(Path::new("/tmp/file.txt"))); + } + + #[test] + fn test_network_allowed() { + let enforcer = CapabilityEnforcer::new().with_network_default(true); + + let mut caps = Capabilities::default(); + caps.network.enabled = true; + + assert!(enforcer.is_network_allowed(&caps)); + + caps.network.enabled = false; + assert!(!enforcer.is_network_allowed(&caps)); + } + + #[test] + fn test_domain_restrictions() { + let enforcer = CapabilityEnforcer::new().with_network_default(true); + + let mut caps = Capabilities::default(); + caps.network.enabled = true; + caps.network.allowed_domains = Some(vec![ + "api.example.com".to_string(), + "cdn.example.com".to_string(), + ]); + + assert!(enforcer.is_domain_allowed(&caps, "api.example.com")); + assert!(enforcer.is_domain_allowed(&caps, "cdn.example.com")); + assert!(!enforcer.is_domain_allowed(&caps, "evil.com")); + } + + #[test] + fn test_get_effective_limits() { + let enforcer = CapabilityEnforcer::new() + .with_max_memory(100 * 1024 * 1024) + .with_max_cpu_time(30_000); + + let mut caps = Capabilities::default(); + + // No limits specified, use the defaults + assert_eq!(enforcer.get_memory_limit(&caps), 100 * 1024 * 1024); + assert_eq!(enforcer.get_cpu_time_limit(&caps), 30_000); + + // Plugin requests lower limits, use plugin's + 
caps.max_memory_bytes = Some(50 * 1024 * 1024); + caps.max_cpu_time_ms = Some(10_000); + assert_eq!(enforcer.get_memory_limit(&caps), 50 * 1024 * 1024); + assert_eq!(enforcer.get_cpu_time_limit(&caps), 10_000); + + // Plugin requests higher limits, cap at system max + caps.max_memory_bytes = Some(200 * 1024 * 1024); + caps.max_cpu_time_ms = Some(60_000); + assert_eq!(enforcer.get_memory_limit(&caps), 100 * 1024 * 1024); + assert_eq!(enforcer.get_cpu_time_limit(&caps), 30_000); + } + + #[test] + fn test_validate_function_call_lifecycle_always_allowed() { + let enforcer = CapabilityEnforcer::new(); + let kinds = vec!["metadata_extractor".to_string()]; + assert!(enforcer.validate_function_call(&kinds, "initialize")); + assert!(enforcer.validate_function_call(&kinds, "shutdown")); + assert!(enforcer.validate_function_call(&kinds, "health_check")); + } + + #[test] + fn test_validate_function_call_metadata_extractor() { + let enforcer = CapabilityEnforcer::new(); + let kinds = vec!["metadata_extractor".to_string()]; + assert!(enforcer.validate_function_call(&kinds, "extract_metadata")); + assert!(enforcer.validate_function_call(&kinds, "supported_types")); + assert!(!enforcer.validate_function_call(&kinds, "search")); + assert!(!enforcer.validate_function_call(&kinds, "generate_thumbnail")); + assert!(!enforcer.validate_function_call(&kinds, "can_handle")); + } + + #[test] + fn test_validate_function_call_multi_kind() { + let enforcer = CapabilityEnforcer::new(); + let kinds = + vec!["media_type".to_string(), "metadata_extractor".to_string()]; + assert!(enforcer.validate_function_call(&kinds, "can_handle")); + assert!(enforcer.validate_function_call(&kinds, "supported_media_types")); + assert!(enforcer.validate_function_call(&kinds, "extract_metadata")); + assert!(!enforcer.validate_function_call(&kinds, "search")); + } + + #[test] + fn test_validate_function_call_unknown_function() { + let enforcer = CapabilityEnforcer::new(); + let kinds = 
vec!["metadata_extractor".to_string()]; + assert!(!enforcer.validate_function_call(&kinds, "unknown_func")); + assert!(!enforcer.validate_function_call(&kinds, "")); + } + + #[test] + fn test_validate_function_call_shared_supported_types() { + let enforcer = CapabilityEnforcer::new(); + let extractor = vec!["metadata_extractor".to_string()]; + let generator = vec!["thumbnail_generator".to_string()]; + let search = vec!["search_backend".to_string()]; + assert!(enforcer.validate_function_call(&extractor, "supported_types")); + assert!(enforcer.validate_function_call(&generator, "supported_types")); + assert!(!enforcer.validate_function_call(&search, "supported_types")); + } +} diff --git a/crates/pinakes-plugin/src/signature.rs b/crates/pinakes-plugin/src/signature.rs new file mode 100644 index 0000000..64f9dc5 --- /dev/null +++ b/crates/pinakes-plugin/src/signature.rs @@ -0,0 +1,252 @@ +//! Plugin signature verification using Ed25519 + BLAKE3 +//! +//! Each plugin directory may contain a `plugin.sig` file alongside its +//! `plugin.toml`. The signature covers the BLAKE3 hash of the WASM binary +//! referenced by the manifest. Verification uses Ed25519 public keys +//! configured as trusted in the server's plugin settings. +//! +//! When `allow_unsigned` is false, plugins _must_ carry a valid signature +//! from one of the trusted keys or they will be rejected at load time. + +use std::path::Path; + +use anyhow::{Result, anyhow}; +use ed25519_dalek::{Signature, Verifier, VerifyingKey}; + +/// Outcome of a signature check on a plugin package. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SignatureStatus { + /// Signature is present and valid against a trusted key. + Valid, + /// No signature file found. + Unsigned, + /// Signature file exists but does not match any trusted key. + Invalid(String), +} + +/// Verify the signature of a plugin's WASM binary. 
+/// +/// Reads `plugin.sig` from `plugin_dir`, computes the BLAKE3 hash of the +/// WASM binary at `wasm_path`, and verifies the signature against each of +/// the `trusted_keys`. The signature file is raw 64-byte Ed25519 signature +/// over the 32-byte BLAKE3 digest. +/// +/// # Errors +/// +/// Returns an error only on I/O failures, never for cryptographic rejection, +/// which is reported via [`SignatureStatus`] instead. +pub fn verify_plugin_signature( + plugin_dir: &Path, + wasm_path: &Path, + trusted_keys: &[VerifyingKey], +) -> Result { + let sig_path = plugin_dir.join("plugin.sig"); + if !sig_path.exists() { + return Ok(SignatureStatus::Unsigned); + } + + let sig_bytes = std::fs::read(&sig_path) + .map_err(|e| anyhow!("failed to read plugin.sig: {e}"))?; + + let signature = Signature::from_slice(&sig_bytes).map_err(|e| { + // Malformed signature file is an invalid signature, not an I/O error + tracing::warn!(path = %sig_path.display(), "malformed plugin.sig: {e}"); + anyhow!("malformed plugin.sig: {e}") + }); + let Ok(signature) = signature else { + return Ok(SignatureStatus::Invalid( + "malformed signature file".to_string(), + )); + }; + + // BLAKE3 hash of the WASM binary is the signed message + let wasm_bytes = std::fs::read(wasm_path) + .map_err(|e| anyhow!("failed to read WASM binary for verification: {e}"))?; + let digest = blake3::hash(&wasm_bytes); + let message = digest.as_bytes(); + + for key in trusted_keys { + if key.verify(message, &signature).is_ok() { + return Ok(SignatureStatus::Valid); + } + } + + Ok(SignatureStatus::Invalid( + "signature did not match any trusted key".to_string(), + )) +} + +/// Parse a hex-encoded Ed25519 public key (64 hex characters = 32 bytes). +/// +/// # Errors +/// +/// Returns an error if the string is not valid hex or is the wrong length. 
+pub fn parse_public_key(hex_str: &str) -> Result { + let hex_str = hex_str.trim(); + if hex_str.len() != 64 { + return Err(anyhow!( + "expected 64 hex characters for Ed25519 public key, got {}", + hex_str.len() + )); + } + + let mut bytes = [0u8; 32]; + for (i, byte) in bytes.iter_mut().enumerate() { + *byte = u8::from_str_radix(&hex_str[i * 2..i * 2 + 2], 16) + .map_err(|e| anyhow!("invalid hex in public key: {e}"))?; + } + + VerifyingKey::from_bytes(&bytes) + .map_err(|e| anyhow!("invalid Ed25519 public key: {e}")) +} + +#[cfg(test)] +mod tests { + use ed25519_dalek::{Signer, SigningKey}; + use rand::RngExt; + + use super::*; + + fn make_keypair() -> (SigningKey, VerifyingKey) { + let secret_bytes: [u8; 32] = rand::rng().random(); + let signing = SigningKey::from_bytes(&secret_bytes); + let verifying = signing.verifying_key(); + (signing, verifying) + } + + #[test] + fn test_verify_unsigned_plugin() { + let dir = tempfile::tempdir().unwrap(); + let wasm_path = dir.path().join("plugin.wasm"); + std::fs::write(&wasm_path, b"\0asm\x01\x00\x00\x00").unwrap(); + + let (_, vk) = make_keypair(); + let status = + verify_plugin_signature(dir.path(), &wasm_path, &[vk]).unwrap(); + assert_eq!(status, SignatureStatus::Unsigned); + } + + #[test] + fn test_verify_valid_signature() { + let dir = tempfile::tempdir().unwrap(); + let wasm_path = dir.path().join("plugin.wasm"); + let wasm_bytes = b"\0asm\x01\x00\x00\x00some_code_here"; + std::fs::write(&wasm_path, wasm_bytes).unwrap(); + + let (sk, vk) = make_keypair(); + + // Sign the BLAKE3 hash of the WASM binary + let digest = blake3::hash(wasm_bytes); + let signature = sk.sign(digest.as_bytes()); + std::fs::write(dir.path().join("plugin.sig"), signature.to_bytes()) + .unwrap(); + + let status = + verify_plugin_signature(dir.path(), &wasm_path, &[vk]).unwrap(); + assert_eq!(status, SignatureStatus::Valid); + } + + #[test] + fn test_verify_wrong_key() { + let dir = tempfile::tempdir().unwrap(); + let wasm_path = 
dir.path().join("plugin.wasm"); + let wasm_bytes = b"\0asm\x01\x00\x00\x00some_code"; + std::fs::write(&wasm_path, wasm_bytes).unwrap(); + + let (sk, _) = make_keypair(); + let (_, wrong_vk) = make_keypair(); + + let digest = blake3::hash(wasm_bytes); + let signature = sk.sign(digest.as_bytes()); + std::fs::write(dir.path().join("plugin.sig"), signature.to_bytes()) + .unwrap(); + + let status = + verify_plugin_signature(dir.path(), &wasm_path, &[wrong_vk]).unwrap(); + assert!(matches!(status, SignatureStatus::Invalid(_))); + } + + #[test] + fn test_verify_tampered_wasm() { + let dir = tempfile::tempdir().unwrap(); + let wasm_path = dir.path().join("plugin.wasm"); + let original = b"\0asm\x01\x00\x00\x00original"; + std::fs::write(&wasm_path, original).unwrap(); + + let (sk, vk) = make_keypair(); + let digest = blake3::hash(original); + let signature = sk.sign(digest.as_bytes()); + std::fs::write(dir.path().join("plugin.sig"), signature.to_bytes()) + .unwrap(); + + // Tamper with the WASM file after signing + std::fs::write(&wasm_path, b"\0asm\x01\x00\x00\x00tampered").unwrap(); + + let status = + verify_plugin_signature(dir.path(), &wasm_path, &[vk]).unwrap(); + assert!(matches!(status, SignatureStatus::Invalid(_))); + } + + #[test] + fn test_verify_malformed_sig_file() { + let dir = tempfile::tempdir().unwrap(); + let wasm_path = dir.path().join("plugin.wasm"); + std::fs::write(&wasm_path, b"\0asm\x01\x00\x00\x00").unwrap(); + + // Write garbage to plugin.sig (wrong length) + std::fs::write(dir.path().join("plugin.sig"), b"not a signature").unwrap(); + + let (_, vk) = make_keypair(); + let status = + verify_plugin_signature(dir.path(), &wasm_path, &[vk]).unwrap(); + assert!(matches!(status, SignatureStatus::Invalid(_))); + } + + #[test] + fn test_verify_multiple_trusted_keys() { + let dir = tempfile::tempdir().unwrap(); + let wasm_path = dir.path().join("plugin.wasm"); + let wasm_bytes = b"\0asm\x01\x00\x00\x00multi_key_test"; + std::fs::write(&wasm_path, 
wasm_bytes).unwrap(); + + let (sk2, vk2) = make_keypair(); + let (_, vk1) = make_keypair(); + let (_, vk3) = make_keypair(); + + // Sign with key 2 + let digest = blake3::hash(wasm_bytes); + let signature = sk2.sign(digest.as_bytes()); + std::fs::write(dir.path().join("plugin.sig"), signature.to_bytes()) + .unwrap(); + + // Verify against [vk1, vk2, vk3]; should find vk2 + let status = + verify_plugin_signature(dir.path(), &wasm_path, &[vk1, vk2, vk3]) + .unwrap(); + assert_eq!(status, SignatureStatus::Valid); + } + + #[test] + fn test_parse_public_key_valid() { + let (_, vk) = make_keypair(); + let hex = hex_encode(vk.as_bytes()); + let parsed = parse_public_key(&hex).unwrap(); + assert_eq!(parsed, vk); + } + + #[test] + fn test_parse_public_key_wrong_length() { + assert!(parse_public_key("abcdef").is_err()); + } + + #[test] + fn test_parse_public_key_invalid_hex() { + let bad = + "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"; + assert!(parse_public_key(bad).is_err()); + } + + fn hex_encode(bytes: &[u8]) -> String { + bytes.iter().map(|b| format!("{b:02x}")).collect() + } +} diff --git a/crates/pinakes-sync/Cargo.toml b/crates/pinakes-sync/Cargo.toml new file mode 100644 index 0000000..b613325 --- /dev/null +++ b/crates/pinakes-sync/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "pinakes-sync" +edition.workspace = true +version.workspace = true +license.workspace = true + +[dependencies] +pinakes-types = { workspace = true } +tokio = { workspace = true } +chrono = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tracing = { workspace = true } +uuid = { workspace = true } +blake3 = { workspace = true } + +[dev-dependencies] +tempfile = { workspace = true } + +[lints] +workspace = true diff --git a/crates/pinakes-sync/src/chunked.rs b/crates/pinakes-sync/src/chunked.rs new file mode 100644 index 0000000..7fb7802 --- /dev/null +++ b/crates/pinakes-sync/src/chunked.rs @@ -0,0 +1,326 @@ +//! 
Chunked upload handling for large file sync. + +use std::path::{Path, PathBuf}; + +use chrono::Utc; +use pinakes_types::error::{PinakesError, Result}; +use tokio::{ + fs, + io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}, +}; +use tracing::{debug, info}; +use uuid::Uuid; + +use super::{ChunkInfo, UploadSession}; + +/// Manager for chunked uploads. +#[derive(Debug, Clone)] +pub struct ChunkedUploadManager { + temp_dir: PathBuf, +} + +impl ChunkedUploadManager { + /// Create a new chunked upload manager. + #[must_use] + pub const fn new(temp_dir: PathBuf) -> Self { + Self { temp_dir } + } + + /// Initialize the temp directory. + /// + /// # Errors + /// + /// Returns an error if the directory cannot be created. + pub async fn init(&self) -> Result<()> { + fs::create_dir_all(&self.temp_dir).await?; + Ok(()) + } + + /// Get the temp file path for an upload session. + #[must_use] + pub fn temp_path(&self, session_id: Uuid) -> PathBuf { + self.temp_dir.join(format!("{session_id}.upload")) + } + + /// Create the temp file for a new upload session. + /// + /// # Errors + /// + /// Returns an error if the file cannot be created or sized. + pub async fn create_temp_file(&self, session: &UploadSession) -> Result<()> { + let path = self.temp_path(session.id); + + // Create a sparse file of the expected size + let file = fs::File::create(&path).await?; + file.set_len(session.expected_size).await?; + + debug!( + session_id = %session.id, + size = session.expected_size, + "created temp file for upload" + ); + + Ok(()) + } + + /// Write a chunk to the temp file. + /// + /// # Errors + /// + /// Returns an error if the session file is not found, the chunk index is out + /// of range, the chunk size is wrong, or the write fails. 
+ pub async fn write_chunk( + &self, + session: &UploadSession, + chunk_index: u64, + data: &[u8], + ) -> Result { + let path = self.temp_path(session.id); + + if !path.exists() { + return Err(PinakesError::UploadSessionNotFound(session.id.to_string())); + } + + // Calculate offset + let offset = chunk_index * session.chunk_size; + + // Validate chunk + if offset >= session.expected_size { + return Err(PinakesError::ChunkOutOfOrder { + expected: session.chunk_count - 1, + actual: chunk_index, + }); + } + + // Calculate expected chunk size + let expected_size = if chunk_index == session.chunk_count - 1 { + // Last chunk may be smaller + session.expected_size - offset + } else { + session.chunk_size + }; + + if data.len() as u64 != expected_size { + return Err(PinakesError::InvalidData(format!( + "chunk {} has wrong size: expected {}, got {}", + chunk_index, + expected_size, + data.len() + ))); + } + + // Write chunk to file at offset + let mut file = fs::OpenOptions::new().write(true).open(&path).await?; + + file.seek(std::io::SeekFrom::Start(offset)).await?; + file.write_all(data).await?; + file.flush().await?; + + // Compute chunk hash + let hash = blake3::hash(data).to_hex().to_string(); + + debug!( + session_id = %session.id, + chunk_index, + offset, + size = data.len(), + "wrote chunk" + ); + + Ok(ChunkInfo { + upload_id: session.id, + chunk_index, + offset, + size: data.len() as u64, + hash, + received_at: Utc::now(), + }) + } + + /// Verify and finalize the upload. + /// + /// Checks that: + /// 1. All chunks are received + /// 2. File size matches expected + /// 3. Content hash matches expected + /// + /// # Errors + /// + /// Returns an error if chunks are missing, the file size does not match, the + /// hash does not match, or the file metadata cannot be read. 
+ pub async fn finalize( + &self, + session: &UploadSession, + received_chunks: &[ChunkInfo], + ) -> Result { + let path = self.temp_path(session.id); + + // Check all chunks received + if received_chunks.len() as u64 != session.chunk_count { + return Err(PinakesError::InvalidData(format!( + "missing chunks: expected {}, got {}", + session.chunk_count, + received_chunks.len() + ))); + } + + // Verify chunk indices + let mut indices: Vec = + received_chunks.iter().map(|c| c.chunk_index).collect(); + indices.sort_unstable(); + for (i, idx) in indices.iter().enumerate() { + if *idx != i as u64 { + return Err(PinakesError::InvalidData(format!( + "chunk {i} missing or out of order" + ))); + } + } + + // Verify file size + let metadata = fs::metadata(&path).await?; + if metadata.len() != session.expected_size { + return Err(PinakesError::InvalidData(format!( + "file size mismatch: expected {}, got {}", + session.expected_size, + metadata.len() + ))); + } + + // Verify content hash + let computed_hash = compute_file_hash(&path).await?; + if computed_hash != session.expected_hash.0 { + return Err(PinakesError::StorageIntegrity(format!( + "hash mismatch: expected {}, computed {}", + session.expected_hash, computed_hash + ))); + } + + info!( + session_id = %session.id, + hash = %session.expected_hash, + size = session.expected_size, + "finalized chunked upload" + ); + + Ok(path) + } + + /// Cancel an upload and clean up temp file. + /// + /// # Errors + /// + /// Returns an error if the temp file cannot be removed. + pub async fn cancel(&self, session_id: Uuid) -> Result<()> { + let path = self.temp_path(session_id); + if path.exists() { + fs::remove_file(&path).await?; + debug!(session_id = %session_id, "cancelled upload, removed temp file"); + } + Ok(()) + } + + /// Clean up expired temp files. + /// + /// # Errors + /// + /// Returns an error if the temp directory cannot be read. 
+ pub async fn cleanup_expired(&self, max_age_hours: u64) -> Result { + let mut count = 0u64; + let max_age = std::time::Duration::from_secs(max_age_hours * 3600); + + let mut entries = fs::read_dir(&self.temp_dir).await?; + while let Some(entry) = entries.next_entry().await? { + let path = entry.path(); + if path.extension().is_some_and(|e| e == "upload") + && let Ok(metadata) = fs::metadata(&path).await + && let Ok(modified) = metadata.modified() + { + let age = std::time::SystemTime::now() + .duration_since(modified) + .unwrap_or_default(); + if age > max_age { + let _ = fs::remove_file(&path).await; + count += 1; + } + } + } + + if count > 0 { + info!(count, "cleaned up expired upload temp files"); + } + Ok(count) + } +} + +/// Compute the BLAKE3 hash of a file. +async fn compute_file_hash(path: &Path) -> Result { + let mut file = fs::File::open(path).await?; + let mut hasher = blake3::Hasher::new(); + let mut buf = vec![0u8; 64 * 1024]; + + loop { + let n = file.read(&mut buf).await?; + if n == 0 { + break; + } + hasher.update(&buf[..n]); + } + + Ok(hasher.finalize().to_hex().to_string()) +} + +#[cfg(test)] +mod tests { + use pinakes_types::model::ContentHash; + use tempfile::tempdir; + + use super::*; + use crate::UploadStatus; + + #[tokio::test] + async fn test_chunked_upload() { + let dir = tempdir().unwrap(); + let manager = ChunkedUploadManager::new(dir.path().to_path_buf()); + manager.init().await.unwrap(); + + // Create test data + let data = b"Hello, World! 
This is test data for chunked upload."; + let hash = blake3::hash(data).to_hex().to_string(); + let chunk_size = 20u64; + + let session = UploadSession { + id: Uuid::now_v7(), + device_id: super::super::DeviceId::new(), + target_path: "/test/file.txt".to_string(), + expected_hash: ContentHash::new(hash.clone()), + expected_size: data.len() as u64, + chunk_size, + chunk_count: (data.len() as u64).div_ceil(chunk_size), + status: UploadStatus::InProgress, + created_at: Utc::now(), + expires_at: Utc::now() + chrono::Duration::hours(24), + last_activity: Utc::now(), + }; + + manager.create_temp_file(&session).await.unwrap(); + + // Write chunks + let mut chunks = Vec::new(); + for i in 0..session.chunk_count { + let start = (i * chunk_size) as usize; + let end = ((i + 1) * chunk_size).min(data.len() as u64) as usize; + let chunk_data = &data[start..end]; + + let chunk = manager.write_chunk(&session, i, chunk_data).await.unwrap(); + chunks.push(chunk); + } + + // Finalize + let final_path = manager.finalize(&session, &chunks).await.unwrap(); + assert!(final_path.exists()); + + // Verify content + let content = fs::read(&final_path).await.unwrap(); + assert_eq!(&content[..], data); + } +} diff --git a/crates/pinakes-sync/src/conflict.rs b/crates/pinakes-sync/src/conflict.rs new file mode 100644 index 0000000..986ccdd --- /dev/null +++ b/crates/pinakes-sync/src/conflict.rs @@ -0,0 +1,148 @@ +//! Conflict detection and resolution for sync. + +use pinakes_types::config::ConflictResolution; + +use super::DeviceSyncState; + +/// Detect if there's a conflict between local and server state. 
+#[must_use] +pub fn detect_conflict(state: &DeviceSyncState) -> Option { + // If either side has no hash, no conflict possible + let local_hash = state.local_hash.as_ref()?; + let server_hash = state.server_hash.as_ref()?; + + // Same hash = no conflict + if local_hash == server_hash { + return None; + } + + // Both have different hashes = conflict + Some(ConflictInfo { + path: state.path.clone(), + local_hash: local_hash.clone(), + server_hash: server_hash.clone(), + local_mtime: state.local_mtime, + server_mtime: state.server_mtime, + }) +} + +/// Information about a detected conflict. +#[derive(Debug, Clone)] +pub struct ConflictInfo { + pub path: String, + pub local_hash: String, + pub server_hash: String, + pub local_mtime: Option, + pub server_mtime: Option, +} + +/// Result of resolving a conflict. +#[derive(Debug, Clone)] +pub enum ConflictOutcome { + /// Use the server version + UseServer, + /// Use the local version (upload it) + UseLocal, + /// Keep both versions (rename one) + KeepBoth { new_local_path: String }, + /// Requires manual intervention + Manual, +} + +/// Resolve a conflict based on the configured strategy. +#[must_use] +pub fn resolve_conflict( + conflict: &ConflictInfo, + resolution: ConflictResolution, +) -> ConflictOutcome { + match resolution { + ConflictResolution::ServerWins => ConflictOutcome::UseServer, + ConflictResolution::ClientWins => ConflictOutcome::UseLocal, + ConflictResolution::KeepBoth => { + let new_path = + generate_conflict_path(&conflict.path, &conflict.local_hash); + ConflictOutcome::KeepBoth { + new_local_path: new_path, + } + }, + ConflictResolution::Manual => ConflictOutcome::Manual, + } +} + +/// Generate a new path for the conflicting local file. 
/// Generate a new path for the conflicting local file.
///
/// Format: `<stem>.conflict-<short_hash>.<ext>`, or
/// `<path>.conflict-<short_hash>` when the file name has no extension.
/// `short_hash` is the first 8 characters of `local_hash` (or all of it if
/// shorter).
///
/// The extension is looked up in the final `/`-separated component only, so a
/// dot in a directory name never moves the file into another directory, and a
/// leading dot (hidden files such as `.bashrc`) is not treated as an
/// extension separator.
fn generate_conflict_path(original_path: &str, local_hash: &str) -> String {
    // Cut on a character boundary; a plain byte slice would panic if byte 8
    // fell inside a multi-byte character.
    let short_hash = local_hash
        .char_indices()
        .nth(8)
        .map_or(local_hash, |(end, _)| &local_hash[..end]);

    // Split off the directory part so only the file name is inspected.
    let name_start = original_path.rfind('/').map_or(0, |slash| slash + 1);
    let (dir, file_name) = original_path.split_at(name_start);

    match file_name.rsplit_once('.') {
        Some((stem, ext)) if !stem.is_empty() => {
            format!("{dir}{stem}.conflict-{short_hash}.{ext}")
        }
        // No extension, or a dotfile: append the marker to the whole path
        _ => format!("{original_path}.conflict-{short_hash}"),
    }
}
sync_status: FileSyncStatus::Conflict, + last_synced_at: None, + conflict_info_json: None, + }; + assert!(detect_conflict(&state_conflict).is_some()); + } +} diff --git a/crates/pinakes-sync/src/lib.rs b/crates/pinakes-sync/src/lib.rs new file mode 100644 index 0000000..e54a4ed --- /dev/null +++ b/crates/pinakes-sync/src/lib.rs @@ -0,0 +1,7 @@ +mod chunked; +mod conflict; +mod models; + +pub use chunked::*; +pub use conflict::*; +pub use models::*; diff --git a/crates/pinakes-sync/src/models.rs b/crates/pinakes-sync/src/models.rs new file mode 100644 index 0000000..229c5e9 --- /dev/null +++ b/crates/pinakes-sync/src/models.rs @@ -0,0 +1,382 @@ +//! Sync domain models. + +use std::fmt; + +use chrono::{DateTime, Utc}; +use pinakes_types::{ + config::ConflictResolution, + model::{ContentHash, MediaId, UserId}, +}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// Unique identifier for a sync device. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct DeviceId(pub Uuid); + +impl DeviceId { + #[must_use] + pub fn new() -> Self { + Self(Uuid::now_v7()) + } +} + +impl Default for DeviceId { + fn default() -> Self { + Self::new() + } +} + +impl fmt::Display for DeviceId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Type of sync device. 
+#[derive( + Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, +)] +#[serde(rename_all = "lowercase")] +pub enum DeviceType { + Desktop, + Mobile, + Tablet, + Server, + #[default] + Other, +} + +impl fmt::Display for DeviceType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Desktop => write!(f, "desktop"), + Self::Mobile => write!(f, "mobile"), + Self::Tablet => write!(f, "tablet"), + Self::Server => write!(f, "server"), + Self::Other => write!(f, "other"), + } + } +} + +impl std::str::FromStr for DeviceType { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "desktop" => Ok(Self::Desktop), + "mobile" => Ok(Self::Mobile), + "tablet" => Ok(Self::Tablet), + "server" => Ok(Self::Server), + "other" => Ok(Self::Other), + _ => Err(format!("unknown device type: {s}")), + } + } +} + +/// A registered sync device. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SyncDevice { + pub id: DeviceId, + pub user_id: UserId, + pub name: String, + pub device_type: DeviceType, + pub client_version: String, + pub os_info: Option, + pub last_sync_at: Option>, + pub last_seen_at: DateTime, + pub sync_cursor: Option, + pub enabled: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +impl SyncDevice { + #[must_use] + pub fn new( + user_id: UserId, + name: String, + device_type: DeviceType, + client_version: String, + ) -> Self { + let now = Utc::now(); + Self { + id: DeviceId::new(), + user_id, + name, + device_type, + client_version, + os_info: None, + last_sync_at: None, + last_seen_at: now, + sync_cursor: None, + enabled: true, + created_at: now, + updated_at: now, + } + } +} + +/// Type of change recorded in the sync log. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum SyncChangeType { + Created, + Modified, + Deleted, + Moved, + MetadataUpdated, +} + +impl fmt::Display for SyncChangeType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Created => write!(f, "created"), + Self::Modified => write!(f, "modified"), + Self::Deleted => write!(f, "deleted"), + Self::Moved => write!(f, "moved"), + Self::MetadataUpdated => write!(f, "metadata_updated"), + } + } +} + +impl std::str::FromStr for SyncChangeType { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "created" => Ok(Self::Created), + "modified" => Ok(Self::Modified), + "deleted" => Ok(Self::Deleted), + "moved" => Ok(Self::Moved), + "metadata_updated" => Ok(Self::MetadataUpdated), + _ => Err(format!("unknown sync change type: {s}")), + } + } +} + +/// An entry in the sync log tracking a change. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SyncLogEntry { + pub id: Uuid, + pub sequence: i64, + pub change_type: SyncChangeType, + pub media_id: Option, + pub path: String, + pub content_hash: Option, + pub file_size: Option, + pub metadata_json: Option, + pub changed_by_device: Option, + pub timestamp: DateTime, +} + +impl SyncLogEntry { + #[must_use] + pub fn new( + change_type: SyncChangeType, + path: String, + media_id: Option, + content_hash: Option, + ) -> Self { + Self { + id: Uuid::now_v7(), + sequence: 0, // Will be assigned by database + change_type, + media_id, + path, + content_hash, + file_size: None, + metadata_json: None, + changed_by_device: None, + timestamp: Utc::now(), + } + } +} + +/// Sync status for a file on a device. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum FileSyncStatus { + Synced, + PendingUpload, + PendingDownload, + Conflict, + Deleted, +} + +impl fmt::Display for FileSyncStatus { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Synced => write!(f, "synced"), + Self::PendingUpload => write!(f, "pending_upload"), + Self::PendingDownload => write!(f, "pending_download"), + Self::Conflict => write!(f, "conflict"), + Self::Deleted => write!(f, "deleted"), + } + } +} + +impl std::str::FromStr for FileSyncStatus { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "synced" => Ok(Self::Synced), + "pending_upload" => Ok(Self::PendingUpload), + "pending_download" => Ok(Self::PendingDownload), + "conflict" => Ok(Self::Conflict), + "deleted" => Ok(Self::Deleted), + _ => Err(format!("unknown file sync status: {s}")), + } + } +} + +/// Sync state for a specific file on a specific device. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeviceSyncState { + pub device_id: DeviceId, + pub path: String, + pub local_hash: Option, + pub server_hash: Option, + pub local_mtime: Option, + pub server_mtime: Option, + pub sync_status: FileSyncStatus, + pub last_synced_at: Option>, + pub conflict_info_json: Option, +} + +/// A sync conflict that needs resolution. 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SyncConflict {
+  pub id: Uuid,
+  pub device_id: DeviceId,
+  pub path: String,
+  pub local_hash: String,
+  pub local_mtime: i64,
+  pub server_hash: String,
+  pub server_mtime: i64,
+  pub detected_at: DateTime<Utc>,
+  // NOTE(review): extracted as `Option>`; presumably `Option<DateTime<Utc>>`
+  // like `detected_at` - confirm against upstream.
+  pub resolved_at: Option<DateTime<Utc>>,
+  // NOTE(review): type argument lost in extraction; restore from upstream.
+  pub resolution: Option,
+}
+
+impl SyncConflict {
+  /// Records a newly detected conflict: fresh UUIDv7 id, `detected_at` set to
+  /// the current UTC time, and no resolution yet (`resolved_at` and
+  /// `resolution` are `None`).
+  #[must_use]
+  pub fn new(
+    device_id: DeviceId,
+    path: String,
+    local_hash: String,
+    local_mtime: i64,
+    server_hash: String,
+    server_mtime: i64,
+  ) -> Self {
+    Self {
+      id: Uuid::now_v7(),
+      device_id,
+      path,
+      local_hash,
+      local_mtime,
+      server_hash,
+      server_mtime,
+      detected_at: Utc::now(),
+      resolved_at: None,
+      resolution: None,
+    }
+  }
+}
+
+/// Status of an upload session.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum UploadStatus {
+  Pending,
+  InProgress,
+  Completed,
+  Failed,
+  Expired,
+  Cancelled,
+}
+
+/// Formats the variant as its snake_case name (`pending`, `in_progress`,
+/// ...), matching the serde representation.
+impl fmt::Display for UploadStatus {
+  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+    f.write_str(match self {
+      Self::Pending => "pending",
+      Self::InProgress => "in_progress",
+      Self::Completed => "completed",
+      Self::Failed => "failed",
+      Self::Expired => "expired",
+      Self::Cancelled => "cancelled",
+    })
+  }
+}
+
+/// Parses the names written by `Display`; the input is lowercased first, so
+/// matching is case-insensitive.
+impl std::str::FromStr for UploadStatus {
+  type Err = String;
+
+  /// # Errors
+  ///
+  /// Returns a message naming the rejected input when it matches no variant.
+  fn from_str(s: &str) -> Result<Self, Self::Err> {
+    match s.to_lowercase().as_str() {
+      "pending" => Ok(Self::Pending),
+      "in_progress" => Ok(Self::InProgress),
+      "completed" => Ok(Self::Completed),
+      "failed" => Ok(Self::Failed),
+      "expired" => Ok(Self::Expired),
+      "cancelled" => Ok(Self::Cancelled),
+      _ => Err(format!("unknown upload status: {s}")),
+    }
+  }
+}
+
+/// A chunked upload session.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct UploadSession {
+  pub id: Uuid,
+  pub device_id: DeviceId,
+  pub target_path: String,
+  pub expected_hash: ContentHash,
+  pub expected_size: u64,
+  pub chunk_size: u64,
+  pub chunk_count: u64,
+  pub status: UploadStatus,
+  pub created_at: DateTime<Utc>,
+  pub expires_at: DateTime<Utc>,
+  pub last_activity: DateTime<Utc>,
+}
+
+impl UploadSession {
+  /// Creates a pending upload session that expires `timeout_hours` from now.
+  ///
+  /// `chunk_count` is `expected_size / chunk_size` rounded up. A `chunk_size`
+  /// of zero yields a `chunk_count` of zero instead of panicking; callers
+  /// should reject such a session before accepting chunks.
+  ///
+  /// A timeout too large to represent clamps `expires_at` to the maximum
+  /// representable UTC instant rather than wrapping or panicking.
+  #[must_use]
+  pub fn new(
+    device_id: DeviceId,
+    target_path: String,
+    expected_hash: ContentHash,
+    expected_size: u64,
+    chunk_size: u64,
+    timeout_hours: u64,
+  ) -> Self {
+    let now = Utc::now();
+    // `u64::div_ceil` panics on a zero divisor, so guard it explicitly.
+    let chunk_count = if chunk_size == 0 {
+      0
+    } else {
+      expected_size.div_ceil(chunk_size)
+    };
+    // Checked at every step: `as i64` would wrap large values negative, and
+    // `Duration::hours` / `+` panic when the result is out of range.
+    let expires_at = i64::try_from(timeout_hours)
+      .ok()
+      .and_then(chrono::Duration::try_hours)
+      .and_then(|timeout| now.checked_add_signed(timeout))
+      .unwrap_or(DateTime::<Utc>::MAX_UTC);
+    Self {
+      id: Uuid::now_v7(),
+      device_id,
+      target_path,
+      expected_hash,
+      expected_size,
+      chunk_size,
+      chunk_count,
+      status: UploadStatus::Pending,
+      created_at: now,
+      expires_at,
+      last_activity: now,
+    }
+  }
+}
+
+/// Information about an uploaded chunk.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChunkInfo {
+  pub upload_id: Uuid,
+  pub chunk_index: u64,
+  pub offset: u64,
+  pub size: u64,
+  pub hash: String,
+  // NOTE(review): extracted as bare `DateTime`; presumably `DateTime<Utc>`
+  // like the session timestamps - confirm against upstream.
+  pub received_at: DateTime<Utc>,
+}
diff --git a/crates/pinakes-types/Cargo.toml b/crates/pinakes-types/Cargo.toml
new file mode 100644
index 0000000..02fee48
--- /dev/null
+++ b/crates/pinakes-types/Cargo.toml
@@ -0,0 +1,18 @@
+[package]
+name = "pinakes-types"
+edition.workspace = true
+version.workspace = true
+license.workspace = true
+
+[dependencies]
+thiserror = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+chrono = { workspace = true }
+uuid = { workspace = true }
+rustc-hash = { workspace = true }
+toml = { workspace = true }
+anyhow = { workspace = true }
+
+[lints]
+workspace = true
diff --git a/crates/pinakes-types/src/config.rs b/crates/pinakes-types/src/config.rs
new file mode 100644
index 0000000..9c131a8
--- /dev/null
+++ b/crates/pinakes-types/src/config.rs
@@ -0,0 +1,1761 @@
+use
std::path::{Path, PathBuf};
+
+use chrono::{DateTime, Datelike, Utc};
+use serde::{Deserialize, Serialize};
+
+/// A schedule for a recurring task.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case", tag = "type")]
+pub enum Schedule {
+  Interval {
+    secs: u64,
+  },
+  Daily {
+    hour: u32,
+    minute: u32,
+  },
+  Weekly {
+    day: u32,
+    hour: u32,
+    minute: u32,
+  },
+}
+
+impl Schedule {
+  /// Returns the first run time strictly after `from`, in UTC.
+  ///
+  /// Yields `None` when `hour`/`minute` do not form a valid time of day, or
+  /// when the result would fall outside the range `chrono` can represent
+  /// (for example an `Interval` of more seconds than a `Duration` can hold).
+  ///
+  /// For `Weekly`, `day` is compared against days-from-Monday (0 = Monday).
+  // NOTE(review): `day` values above 6 are not rejected here while
+  // `display_string` renders them as "Sun" - confirm the intended range.
+  #[must_use]
+  pub fn next_run(&self, from: DateTime<Utc>) -> Option<DateTime<Utc>> {
+    match self {
+      Self::Interval { secs } => {
+        // Checked conversions: saturating to `i64::MAX` seconds would make
+        // `Duration::seconds` panic instead of producing a far-future time.
+        let interval =
+          chrono::Duration::try_seconds(i64::try_from(*secs).ok()?)?;
+        from.checked_add_signed(interval)
+      },
+      Self::Daily { hour, minute } => {
+        let today_utc =
+          from.date_naive().and_hms_opt(*hour, *minute, 0)?.and_utc();
+        if today_utc > from {
+          Some(today_utc)
+        } else {
+          // Today's slot has passed (or is exactly now): use tomorrow's.
+          today_utc.checked_add_signed(chrono::Duration::days(1))
+        }
+      },
+      Self::Weekly { day, hour, minute } => {
+        let current_day = from.weekday().num_days_from_monday();
+        let target_day = *day;
+        let days_ahead = match target_day.cmp(&current_day) {
+          std::cmp::Ordering::Greater => target_day - current_day,
+          std::cmp::Ordering::Less => 7 - (current_day - target_day),
+          std::cmp::Ordering::Equal => {
+            let today =
+              from.date_naive().and_hms_opt(*hour, *minute, 0)?.and_utc();
+            if today > from {
+              return Some(today);
+            }
+            // Same weekday but the slot has passed: wait a full week.
+            7
+          },
+        };
+        let target_date = from
+          .date_naive()
+          .checked_add_signed(chrono::Duration::days(i64::from(days_ahead)))?;
+        Some(target_date.and_hms_opt(*hour, *minute, 0)?.and_utc())
+      },
+    }
+  }
+
+  /// Short human-readable label: `Every 2h` / `Every 5m` / `Every 30s`,
+  /// `Daily HH:MM`, or `<weekday> HH:MM`. Interval values are truncated to
+  /// whole hours or minutes.
+  #[must_use]
+  pub fn display_string(&self) -> String {
+    match self {
+      Self::Interval { secs } => {
+        if *secs >= 3600 {
+          format!("Every {}h", secs / 3600)
+        } else if *secs >= 60 {
+          format!("Every {}m", secs / 60)
+        } else {
+          format!("Every {secs}s")
+        }
+      },
+      Self::Daily { hour, minute } => {
+        format!("Daily {hour:02}:{minute:02}")
+      },
+      Self::Weekly { day, hour, minute } => {
+        let day_name = match day {
+          0 => "Mon",
+          1 => "Tue",
+          2 => "Wed",
+          3 => "Thu",
+          4 => "Fri",
+          5 => "Sat",
+          _ => "Sun",
+        };
+        format!("{day_name} {hour:02}:{minute:02}")
+      },
+    }
+  }
+}
+
+/// Expand environment variables in a string using `std::env::var` for lookup.
+/// Supports both `${VAR_NAME}` and `$VAR_NAME` syntax.
+/// Returns an error if a referenced variable is not set.
+fn expand_env_var_string(input: &str) -> crate::error::Result<String> {
+  expand_env_vars(input, |name| {
+    std::env::var(name).map_err(|_| {
+      crate::error::PinakesError::Config(format!(
+        "environment variable not set: {name}"
+      ))
+    })
+  })
+}
+
+/// Expand environment variables in a string using the provided lookup function.
+/// Supports both `${VAR_NAME}` and `$VAR_NAME` syntax.
+///
+/// `\$` produces a literal `$`; a backslash before any other character is
+/// copied through unchanged.
+///
+/// # Errors
+///
+/// Returns a config error when a `$` is followed by no variable name, when a
+/// `${` is never closed by `}`, or when `lookup` fails for a name.
+fn expand_env_vars(
+  input: &str,
+  lookup: impl Fn(&str) -> crate::error::Result<String>,
+) -> crate::error::Result<String> {
+  let mut result = String::with_capacity(input.len());
+  let mut chars = input.chars().peekable();
+
+  while let Some(ch) = chars.next() {
+    match ch {
+      '$' => {
+        let mut var_name = String::new();
+        if chars.next_if_eq(&'{').is_some() {
+          // `${VAR}` form: everything up to the closing brace is the name.
+          let mut closed = false;
+          for next_ch in chars.by_ref() {
+            if next_ch == '}' {
+              closed = true;
+              break;
+            }
+            var_name.push(next_ch);
+          }
+          // Reaching end of input without `}` is a typo, not a valid
+          // reference; report it instead of expanding the partial name.
+          if !closed {
+            return Err(crate::error::PinakesError::Config(
+              "unterminated `${` in environment variable reference"
+                .to_string(),
+            ));
+          }
+        } else {
+          // `$VAR` form: the name is the longest run of alphanumeric or
+          // underscore characters.
+          while let Some(next_ch) =
+            chars.next_if(|c| c.is_alphanumeric() || *c == '_')
+          {
+            var_name.push(next_ch);
+          }
+        }
+
+        if var_name.is_empty() {
+          return Err(crate::error::PinakesError::Config(
+            "empty environment variable name".to_string(),
+          ));
+        }
+
+        result.push_str(&lookup(&var_name)?);
+      },
+      '\\' => {
+        // Only `\$` is an escape; any other backslash stays in the output
+        // and the character after it is processed normally.
+        if chars.next_if_eq(&'$').is_some() {
+          result.push('$');
+        } else {
+          result.push(ch);
+        }
+      },
+      _ => result.push(ch),
+    }
+  }
+
+  Ok(result)
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + pub storage: StorageConfig, + pub directories: DirectoryConfig, + pub scanning: ScanningConfig, + pub server: ServerConfig, + #[serde(default)] + pub ui: UiConfig, + #[serde(default)] + pub accounts: AccountsConfig, + #[serde(default)] + pub rate_limits: RateLimitConfig, + #[serde(default)] + pub jobs: JobsConfig, + #[serde(default)] + pub thumbnails: ThumbnailConfig, + #[serde(default)] + pub webhooks: Vec, + #[serde(default)] + pub scheduled_tasks: Vec, + #[serde(default)] + pub plugins: PluginsConfig, + #[serde(default)] + pub transcoding: TranscodingConfig, + #[serde(default)] + pub enrichment: EnrichmentConfig, + #[serde(default)] + pub cloud: CloudConfig, + #[serde(default)] + pub analytics: AnalyticsConfig, + #[serde(default)] + pub photos: PhotoConfig, + #[serde(default)] + pub managed_storage: ManagedStorageConfig, + #[serde(default)] + pub sync: SyncConfig, + #[serde(default)] + pub sharing: SharingConfig, + #[serde(default)] + pub trash: TrashConfig, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScheduledTaskConfig { + pub id: String, + pub enabled: bool, + pub schedule: Schedule, + pub last_run: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RateLimitConfig { + /// Global rate limit: requests per second (token replenish interval). + /// Default: 1 (combined with `burst_size=100` gives ~100 req/sec) + #[serde(default = "default_global_per_second")] + pub global_per_second: u64, + /// Global rate limit: burst size (max concurrent requests per IP) + #[serde(default = "default_global_burst")] + pub global_burst_size: u32, + /// Login rate limit: seconds between token replenishment. 
+ /// Default: 12 (one token every 12s, combined with burst=5 gives ~5 req/min) + #[serde(default = "default_login_per_second")] + pub login_per_second: u64, + /// Login rate limit: burst size + #[serde(default = "default_login_burst")] + pub login_burst_size: u32, + /// Search rate limit: seconds between token replenishment. + /// Default: 6 (one token every 6s, combined with burst=10 gives ~10 req/min) + #[serde(default = "default_search_per_second")] + pub search_per_second: u64, + /// Search rate limit: burst size + #[serde(default = "default_search_burst")] + pub search_burst_size: u32, + /// Streaming rate limit: seconds between token replenishment. + /// Default: 60 (one per minute) + #[serde(default = "default_stream_per_second")] + pub stream_per_second: u64, + /// Streaming rate limit: burst size (max concurrent streams) + #[serde(default = "default_stream_burst")] + pub stream_burst_size: u32, + /// Share token rate limit: seconds between token replenishment. + /// Default: 2 + #[serde(default = "default_share_per_second")] + pub share_per_second: u64, + /// Share token rate limit: burst size + #[serde(default = "default_share_burst")] + pub share_burst_size: u32, +} + +const fn default_global_per_second() -> u64 { + 1 +} +const fn default_global_burst() -> u32 { + 100 +} +const fn default_login_per_second() -> u64 { + 12 +} +const fn default_login_burst() -> u32 { + 5 +} +const fn default_search_per_second() -> u64 { + 6 +} +const fn default_search_burst() -> u32 { + 10 +} +const fn default_stream_per_second() -> u64 { + 60 +} +const fn default_stream_burst() -> u32 { + 5 +} +const fn default_share_per_second() -> u64 { + 2 +} +const fn default_share_burst() -> u32 { + 20 +} + +impl Default for RateLimitConfig { + fn default() -> Self { + Self { + global_per_second: default_global_per_second(), + global_burst_size: default_global_burst(), + login_per_second: default_login_per_second(), + login_burst_size: default_login_burst(), + search_per_second: 
default_search_per_second(), + search_burst_size: default_search_burst(), + stream_per_second: default_stream_per_second(), + stream_burst_size: default_stream_burst(), + share_per_second: default_share_per_second(), + share_burst_size: default_share_burst(), + } + } +} + +impl RateLimitConfig { + /// Validate that all rate limit values are positive. + /// + /// # Errors + /// + /// Returns an error string if any rate limit value is zero. + pub fn validate(&self) -> Result<(), String> { + for (name, value) in [ + ("global_per_second", self.global_per_second), + ("global_burst_size", u64::from(self.global_burst_size)), + ("login_per_second", self.login_per_second), + ("login_burst_size", u64::from(self.login_burst_size)), + ("search_per_second", self.search_per_second), + ("search_burst_size", u64::from(self.search_burst_size)), + ("stream_per_second", self.stream_per_second), + ("stream_burst_size", u64::from(self.stream_burst_size)), + ("share_per_second", self.share_per_second), + ("share_burst_size", u64::from(self.share_burst_size)), + ] { + if value == 0 { + return Err(format!("{name} must be > 0")); + } + } + Ok(()) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JobsConfig { + #[serde(default = "default_worker_count")] + pub worker_count: usize, + #[serde(default = "default_cache_ttl")] + pub cache_ttl_secs: u64, + /// Maximum time a job is allowed to run before being cancelled (in seconds). + /// Set to 0 to disable timeout. Default: 3600 (1 hour). 
+ #[serde(default = "default_job_timeout")] + pub job_timeout_secs: u64, +} + +const fn default_worker_count() -> usize { + 2 +} +const fn default_cache_ttl() -> u64 { + 60 +} +const fn default_job_timeout() -> u64 { + 3600 +} + +impl Default for JobsConfig { + fn default() -> Self { + Self { + worker_count: default_worker_count(), + cache_ttl_secs: default_cache_ttl(), + job_timeout_secs: default_job_timeout(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThumbnailConfig { + #[serde(default = "default_thumb_size")] + pub size: u32, + #[serde(default = "default_thumb_quality")] + pub quality: u8, + #[serde(default)] + pub ffmpeg_path: Option, + #[serde(default = "default_video_seek")] + pub video_seek_secs: u32, +} + +const fn default_thumb_size() -> u32 { + 320 +} +const fn default_thumb_quality() -> u8 { + 80 +} +const fn default_video_seek() -> u32 { + 2 +} + +impl Default for ThumbnailConfig { + fn default() -> Self { + Self { + size: default_thumb_size(), + quality: default_thumb_quality(), + ffmpeg_path: None, + video_seek_secs: default_video_seek(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WebhookConfig { + pub url: String, + pub events: Vec, + #[serde(default)] + pub secret: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UiConfig { + #[serde(default = "default_theme")] + pub theme: String, + #[serde(default = "default_view")] + pub default_view: String, + #[serde(default = "default_page_size")] + pub default_page_size: usize, + #[serde(default = "default_view_mode")] + pub default_view_mode: String, + #[serde(default)] + pub auto_play_media: bool, + #[serde(default = "default_true")] + pub show_thumbnails: bool, + #[serde(default)] + pub sidebar_collapsed: bool, +} + +fn default_theme() -> String { + "dark".to_string() +} +fn default_view() -> String { + "library".to_string() +} +const fn default_page_size() -> usize { + 50 +} +fn default_view_mode() -> String { + 
"grid".to_string() +} +const fn default_true() -> bool { + true +} + +impl Default for UiConfig { + fn default() -> Self { + Self { + theme: default_theme(), + default_view: default_view(), + default_page_size: default_page_size(), + default_view_mode: default_view_mode(), + auto_play_media: false, + show_thumbnails: true, + sidebar_collapsed: false, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccountsConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default)] + pub users: Vec, + /// Session expiry in hours. Defaults to 24. + #[serde(default = "default_session_expiry_hours")] + pub session_expiry_hours: u64, +} + +const fn default_session_expiry_hours() -> u64 { + 24 +} + +impl Default for AccountsConfig { + fn default() -> Self { + Self { + enabled: false, + users: Vec::new(), + session_expiry_hours: default_session_expiry_hours(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UserAccount { + pub username: String, + pub password_hash: String, + #[serde(default)] + pub role: UserRole, +} + +#[derive( + Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize, +)] +#[serde(rename_all = "lowercase")] +pub enum UserRole { + Admin, + Editor, + #[default] + Viewer, +} + +impl UserRole { + #[must_use] + pub const fn can_read(self) -> bool { + true + } + + #[must_use] + pub const fn can_write(self) -> bool { + matches!(self, Self::Admin | Self::Editor) + } + + #[must_use] + pub const fn can_admin(self) -> bool { + matches!(self, Self::Admin) + } +} + +impl std::fmt::Display for UserRole { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Admin => write!(f, "admin"), + Self::Editor => write!(f, "editor"), + Self::Viewer => write!(f, "viewer"), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PluginTimeoutConfig { + /// Timeout for capability discovery queries (`supported_types`, + /// `interested_events`) + #[serde(default = 
"default_capability_query_timeout")] + pub capability_query_secs: u64, + /// Timeout for processing calls (`extract_metadata`, `generate_thumbnail`) + #[serde(default = "default_processing_timeout")] + pub processing_secs: u64, + /// Timeout for event handler calls + #[serde(default = "default_event_handler_timeout")] + pub event_handler_secs: u64, +} + +const fn default_capability_query_timeout() -> u64 { + 2 +} + +const fn default_processing_timeout() -> u64 { + 30 +} + +const fn default_event_handler_timeout() -> u64 { + 10 +} + +impl Default for PluginTimeoutConfig { + fn default() -> Self { + Self { + capability_query_secs: default_capability_query_timeout(), + processing_secs: default_processing_timeout(), + event_handler_secs: default_event_handler_timeout(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PluginsConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default = "default_plugin_data_dir")] + pub data_dir: PathBuf, + #[serde(default = "default_plugin_cache_dir")] + pub cache_dir: PathBuf, + #[serde(default)] + pub plugin_dirs: Vec, + #[serde(default)] + pub enable_hot_reload: bool, + #[serde(default)] + pub allow_unsigned: bool, + #[serde(default = "default_max_concurrent_ops")] + pub max_concurrent_ops: usize, + #[serde(default = "default_plugin_timeout")] + pub plugin_timeout_secs: u64, + #[serde(default)] + pub timeouts: PluginTimeoutConfig, + #[serde(default = "default_max_consecutive_failures")] + pub max_consecutive_failures: u32, + + /// Hex-encoded Ed25519 public keys trusted for plugin signature + /// verification. Each entry is 64 hex characters (32 bytes). 
+ #[serde(default)] + pub trusted_keys: Vec, +} + +fn default_plugin_data_dir() -> PathBuf { + Config::default_data_dir().join("plugins").join("data") +} + +fn default_plugin_cache_dir() -> PathBuf { + Config::default_data_dir().join("plugins").join("cache") +} + +const fn default_max_concurrent_ops() -> usize { + 4 +} + +const fn default_plugin_timeout() -> u64 { + 30 +} + +const fn default_max_consecutive_failures() -> u32 { + 5 +} + +impl Default for PluginsConfig { + fn default() -> Self { + Self { + enabled: false, + data_dir: default_plugin_data_dir(), + cache_dir: default_plugin_cache_dir(), + plugin_dirs: vec![], + enable_hot_reload: false, + allow_unsigned: false, + max_concurrent_ops: default_max_concurrent_ops(), + plugin_timeout_secs: default_plugin_timeout(), + timeouts: PluginTimeoutConfig::default(), + max_consecutive_failures: default_max_consecutive_failures(), + trusted_keys: vec![], + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TranscodingConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default)] + pub cache_dir: Option, + #[serde(default = "default_cache_ttl_hours")] + pub cache_ttl_hours: u64, + #[serde(default = "default_max_concurrent_transcodes")] + pub max_concurrent: usize, + #[serde(default)] + pub hardware_acceleration: Option, + #[serde(default)] + pub profiles: Vec, +} + +const fn default_cache_ttl_hours() -> u64 { + 48 +} + +const fn default_max_concurrent_transcodes() -> usize { + 2 +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TranscodeProfile { + pub name: String, + pub video_codec: String, + pub audio_codec: String, + pub max_bitrate_kbps: u32, + pub max_resolution: String, +} + +impl Default for TranscodingConfig { + fn default() -> Self { + Self { + enabled: false, + cache_dir: None, + cache_ttl_hours: default_cache_ttl_hours(), + max_concurrent: default_max_concurrent_transcodes(), + hardware_acceleration: None, + profiles: vec![ + TranscodeProfile { + name: 
"high".to_string(), + video_codec: "h264".to_string(), + audio_codec: "aac".to_string(), + max_bitrate_kbps: 8000, + max_resolution: "1080p".to_string(), + }, + TranscodeProfile { + name: "medium".to_string(), + video_codec: "h264".to_string(), + audio_codec: "aac".to_string(), + max_bitrate_kbps: 4000, + max_resolution: "720p".to_string(), + }, + ], + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct EnrichmentConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default)] + pub auto_enrich_on_import: bool, + #[serde(default)] + pub sources: EnrichmentSources, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct EnrichmentSources { + #[serde(default)] + pub musicbrainz: EnrichmentSource, + #[serde(default)] + pub tmdb: EnrichmentSource, + #[serde(default)] + pub lastfm: EnrichmentSource, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct EnrichmentSource { + #[serde(default)] + pub enabled: bool, + #[serde(default)] + pub api_key: Option, + #[serde(default)] + pub api_endpoint: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CloudConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default = "default_auto_sync_interval")] + pub auto_sync_interval_mins: u64, + #[serde(default)] + pub accounts: Vec, +} + +const fn default_auto_sync_interval() -> u64 { + 60 +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CloudAccount { + pub id: String, + pub provider: String, + #[serde(default)] + pub enabled: bool, + #[serde(default)] + pub sync_rules: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CloudSyncRule { + pub local_path: PathBuf, + pub remote_path: String, + pub direction: CloudSyncDirection, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum CloudSyncDirection { + Upload, + Download, + Bidirectional, +} + +impl Default for CloudConfig { + fn default() -> Self { 
+ Self { + enabled: false, + auto_sync_interval_mins: default_auto_sync_interval(), + accounts: vec![], + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AnalyticsConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default = "default_true")] + pub track_usage: bool, + #[serde(default = "default_retention_days")] + pub retention_days: u64, +} + +const fn default_retention_days() -> u64 { + 90 +} + +impl Default for AnalyticsConfig { + fn default() -> Self { + Self { + enabled: false, + track_usage: true, + retention_days: default_retention_days(), + } + } +} + +/// Feature toggles for photo processing (image analysis features). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PhotoFeatures { + /// Generate perceptual hashes for image duplicate detection (CPU-intensive) + #[serde(default = "default_true")] + pub generate_perceptual_hash: bool, + + /// Automatically create tags from EXIF keywords + #[serde(default)] + pub auto_tag_from_exif: bool, + + /// Generate multi-resolution thumbnails (tiny, grid, preview) + #[serde(default)] + pub multi_resolution_thumbnails: bool, +} + +impl Default for PhotoFeatures { + fn default() -> Self { + Self { + generate_perceptual_hash: true, + auto_tag_from_exif: false, + multi_resolution_thumbnails: false, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PhotoConfig { + /// Feature toggles for photo processing + #[serde(flatten)] + pub features: PhotoFeatures, + + /// Auto-detect photo events/albums based on time and location + #[serde(default)] + pub enable_event_detection: bool, + + /// Minimum number of photos to form an event + #[serde(default = "default_min_event_photos")] + pub min_event_photos: usize, + + /// Maximum time gap between photos in the same event (in seconds) + #[serde(default = "default_event_time_gap")] + pub event_time_gap_secs: i64, + + /// Maximum distance between photos in the same event (in kilometers) + #[serde(default = 
"default_event_distance")] + pub event_max_distance_km: f64, +} + +impl PhotoConfig { + /// Returns true if perceptual hashing is enabled. + #[must_use] + pub const fn generate_perceptual_hash(&self) -> bool { + self.features.generate_perceptual_hash + } + + /// Returns true if auto-tagging from EXIF is enabled. + #[must_use] + pub const fn auto_tag_from_exif(&self) -> bool { + self.features.auto_tag_from_exif + } + + /// Returns true if multi-resolution thumbnails are enabled. + #[must_use] + pub const fn multi_resolution_thumbnails(&self) -> bool { + self.features.multi_resolution_thumbnails + } +} + +const fn default_min_event_photos() -> usize { + 5 +} + +const fn default_event_time_gap() -> i64 { + 2 * 60 * 60 // 2 hours +} + +const fn default_event_distance() -> f64 { + 1.0 // 1 km +} + +impl Default for PhotoConfig { + fn default() -> Self { + Self { + features: PhotoFeatures::default(), + enable_event_detection: false, + min_event_photos: default_min_event_photos(), + event_time_gap_secs: default_event_time_gap(), + event_max_distance_km: default_event_distance(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ManagedStorageConfig { + /// Enable managed storage for file uploads + #[serde(default)] + pub enabled: bool, + /// Directory where managed files are stored + #[serde(default = "default_managed_storage_dir")] + pub storage_dir: PathBuf, + /// Maximum upload size in bytes (default: 10GB) + #[serde(default = "default_max_upload_size")] + pub max_upload_size: u64, + /// Allowed MIME types for uploads (empty = allow all) + #[serde(default)] + pub allowed_mime_types: Vec, + /// Automatically clean up orphaned blobs + #[serde(default = "default_true")] + pub auto_cleanup: bool, + /// Verify file integrity on read + #[serde(default)] + pub verify_on_read: bool, +} + +fn default_managed_storage_dir() -> PathBuf { + Config::default_data_dir().join("managed") +} + +const fn default_max_upload_size() -> u64 { + 10 * 1024 * 1024 * 1024 
// 10GB +} + +impl Default for ManagedStorageConfig { + fn default() -> Self { + Self { + enabled: false, + storage_dir: default_managed_storage_dir(), + max_upload_size: default_max_upload_size(), + allowed_mime_types: vec![], + auto_cleanup: true, + verify_on_read: false, + } + } +} + +#[derive( + Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, +)] +#[serde(rename_all = "snake_case")] +pub enum ConflictResolution { + ServerWins, + ClientWins, + #[default] + KeepBoth, + Manual, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SyncConfig { + /// Enable cross-device sync functionality + #[serde(default)] + pub enabled: bool, + /// Default conflict resolution strategy + #[serde(default)] + pub default_conflict_resolution: ConflictResolution, + /// Maximum file size for sync in MB + #[serde(default = "default_max_sync_file_size")] + pub max_file_size_mb: u64, + /// Chunk size for chunked uploads in KB + #[serde(default = "default_chunk_size")] + pub chunk_size_kb: u64, + /// Upload session timeout in hours + #[serde(default = "default_upload_timeout")] + pub upload_timeout_hours: u64, + /// Maximum concurrent uploads per device + #[serde(default = "default_max_concurrent_uploads")] + pub max_concurrent_uploads: usize, + /// Sync log retention in days + #[serde(default = "default_sync_log_retention")] + pub sync_log_retention_days: u64, + /// Temporary directory for chunked upload storage + #[serde(default = "default_temp_upload_dir")] + pub temp_upload_dir: PathBuf, +} + +const fn default_max_sync_file_size() -> u64 { + 4096 // 4GB +} + +const fn default_chunk_size() -> u64 { + 4096 // 4MB +} + +const fn default_upload_timeout() -> u64 { + 24 // 24 hours +} + +const fn default_max_concurrent_uploads() -> usize { + 3 +} + +const fn default_sync_log_retention() -> u64 { + 90 // 90 days +} + +fn default_temp_upload_dir() -> PathBuf { + Config::default_data_dir().join("temp_uploads") +} + +impl Default for SyncConfig { + fn default() 
-> Self { + Self { + enabled: false, + default_conflict_resolution: ConflictResolution::default(), + max_file_size_mb: default_max_sync_file_size(), + chunk_size_kb: default_chunk_size(), + upload_timeout_hours: default_upload_timeout(), + max_concurrent_uploads: default_max_concurrent_uploads(), + sync_log_retention_days: default_sync_log_retention(), + temp_upload_dir: default_temp_upload_dir(), + } + } +} + +/// Core permission flags for the sharing subsystem. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SharingPermissions { + /// Enable sharing functionality + #[serde(default = "default_true")] + pub enabled: bool, + /// Allow creating public share links + #[serde(default = "default_true")] + pub allow_public_links: bool, + /// Allow users to reshare content shared with them + #[serde(default = "default_true")] + pub allow_reshare: bool, +} + +impl Default for SharingPermissions { + fn default() -> Self { + Self { + enabled: true, + allow_public_links: true, + allow_reshare: true, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SharingConfig { + /// Core permission flags for sharing + #[serde(flatten)] + pub permissions: SharingPermissions, + /// Require password for public share links + #[serde(default)] + pub require_public_link_password: bool, + /// Enable share notifications + #[serde(default = "default_true")] + pub notifications_enabled: bool, + /// Maximum expiry time for public links in hours (0 = unlimited) + #[serde(default)] + pub max_public_link_expiry_hours: u64, + /// Notification retention in days + #[serde(default = "default_notification_retention")] + pub notification_retention_days: u64, + /// Share activity log retention in days + #[serde(default = "default_activity_retention")] + pub activity_retention_days: u64, +} + +impl SharingConfig { + /// Returns true if sharing is enabled. 
+ #[must_use] + pub const fn enabled(&self) -> bool { + self.permissions.enabled + } + + /// Returns true if public links are allowed. + #[must_use] + pub const fn allow_public_links(&self) -> bool { + self.permissions.allow_public_links + } + + /// Returns true if resharing is allowed. + #[must_use] + pub const fn allow_reshare(&self) -> bool { + self.permissions.allow_reshare + } +} + +const fn default_notification_retention() -> u64 { + 30 +} + +const fn default_activity_retention() -> u64 { + 90 +} + +impl Default for SharingConfig { + fn default() -> Self { + Self { + permissions: SharingPermissions::default(), + require_public_link_password: false, + notifications_enabled: true, + max_public_link_expiry_hours: 0, + notification_retention_days: default_notification_retention(), + activity_retention_days: default_activity_retention(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TrashConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default = "default_trash_retention_days")] + pub retention_days: u64, + #[serde(default)] + pub auto_empty: bool, +} + +const fn default_trash_retention_days() -> u64 { + 30 +} + +impl Default for TrashConfig { + fn default() -> Self { + Self { + enabled: false, + retention_days: default_trash_retention_days(), + auto_empty: false, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StorageConfig { + pub backend: StorageBackendType, + pub sqlite: Option, + pub postgres: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum StorageBackendType { + Sqlite, + Postgres, +} + +impl StorageBackendType { + #[must_use] + pub const fn as_str(&self) -> &'static str { + match self { + Self::Sqlite => "sqlite", + Self::Postgres => "postgres", + } + } +} + +impl std::fmt::Display for StorageBackendType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SqliteConfig { + pub path: PathBuf, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PostgresConfig { + pub host: String, + pub port: u16, + pub database: String, + pub username: String, + pub password: String, + pub max_connections: usize, + /// Enable TLS for `PostgreSQL` connections + #[serde(default)] + pub tls_enabled: bool, + /// Verify TLS certificates (default: true) + #[serde(default = "default_true")] + pub tls_verify_ca: bool, + /// Path to custom CA certificate file (PEM format) + #[serde(default)] + pub tls_ca_cert_path: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DirectoryConfig { + pub roots: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScanningConfig { + pub watch: bool, + pub poll_interval_secs: u64, + pub ignore_patterns: Vec, + #[serde(default = "default_import_concurrency")] + pub import_concurrency: usize, +} + +const fn default_import_concurrency() -> usize { + 8 +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerConfig { + pub host: String, + pub port: u16, + /// Optional API key for bearer token authentication. + /// If set, all requests (except /health) must include `Authorization: Bearer + /// `. Can also be set via `PINAKES_API_KEY` environment variable. + pub api_key: Option, + /// Explicitly disable authentication (INSECURE - use only for development). + /// When true, all requests are allowed without authentication. + /// This must be explicitly set to true; empty `api_key` alone is not + /// sufficient. + #[serde(default)] + pub authentication_disabled: bool, + /// Enable CORS (Cross-Origin Resource Sharing). + /// When false, default localhost origins are used. + #[serde(default)] + pub cors_enabled: bool, + /// Allowed CORS origins when `cors_enabled` is true. + /// If empty and `cors_enabled` is true, defaults to localhost origins. 
+ #[serde(default)] + pub cors_origins: Vec, + /// TLS/HTTPS configuration + #[serde(default)] + pub tls: TlsConfig, + /// Enable the Swagger UI at /api/docs. + /// Defaults to true. Set to false to disable in production if desired. + #[serde(default = "default_true")] + pub swagger_ui: bool, +} + +/// TLS/HTTPS configuration for secure connections +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TlsConfig { + /// Enable TLS (HTTPS) + #[serde(default)] + pub enabled: bool, + /// Path to the TLS certificate file (PEM format) + #[serde(default)] + pub cert_path: Option, + /// Path to the TLS private key file (PEM format) + #[serde(default)] + pub key_path: Option, + /// Enable HTTP to HTTPS redirect (starts a second listener on `http_port`) + #[serde(default)] + pub redirect_http: bool, + /// Port for HTTP redirect listener (default: 80) + #[serde(default = "default_http_port")] + pub http_port: u16, + /// Enable HSTS (HTTP Strict Transport Security) header + #[serde(default = "default_true")] + pub hsts_enabled: bool, + /// HSTS max-age in seconds (default: 1 year) + #[serde(default = "default_hsts_max_age")] + pub hsts_max_age: u64, +} + +const fn default_http_port() -> u16 { + 80 +} + +const fn default_hsts_max_age() -> u64 { + 31_536_000 // 1 year in seconds +} + +impl Default for TlsConfig { + fn default() -> Self { + Self { + enabled: false, + cert_path: None, + key_path: None, + redirect_http: false, + http_port: default_http_port(), + hsts_enabled: true, + hsts_max_age: default_hsts_max_age(), + } + } +} + +impl TlsConfig { + /// Validate TLS configuration + /// + /// # Errors + /// + /// Returns an error string if TLS is enabled but required paths are missing + /// or invalid. 
+ pub fn validate(&self) -> Result<(), String> { + if self.enabled { + if self.cert_path.is_none() { + return Err("TLS enabled but cert_path not specified".into()); + } + if self.key_path.is_none() { + return Err("TLS enabled but key_path not specified".into()); + } + if let Some(ref cert_path) = self.cert_path + && !cert_path.exists() + { + return Err(format!( + "TLS certificate file not found: {}", + cert_path.display() + )); + } + if let Some(ref key_path) = self.key_path + && !key_path.exists() + { + return Err(format!("TLS key file not found: {}", key_path.display())); + } + } + Ok(()) + } +} + +impl Config { + /// Load configuration from a TOML file, expanding environment variables in + /// secret fields. + /// + /// # Errors + /// + /// Returns [`crate::error::PinakesError`] if the file cannot be read, parsed, + /// or contains invalid environment variable references. + pub fn from_file(path: &Path) -> crate::error::Result { + let content = std::fs::read_to_string(path).map_err(|e| { + crate::error::PinakesError::Config(format!( + "failed to read config file: {e}" + )) + })?; + let mut config: Self = toml::from_str(&content).map_err(|e| { + crate::error::PinakesError::Config(format!("failed to parse config: {e}")) + })?; + config.expand_env_vars()?; + Ok(config) + } + + /// Expand environment variables in secret fields. + /// Supports ${`VAR_NAME`} and $`VAR_NAME` syntax. 
+ fn expand_env_vars(&mut self) -> crate::error::Result<()> { + // Postgres password + if let Some(ref mut postgres) = self.storage.postgres { + postgres.password = expand_env_var_string(&postgres.password)?; + } + + // Server API key + if let Some(ref api_key) = self.server.api_key { + self.server.api_key = Some(expand_env_var_string(api_key)?); + } + + // Webhook secrets + for webhook in &mut self.webhooks { + if let Some(ref secret) = webhook.secret { + webhook.secret = Some(expand_env_var_string(secret)?); + } + } + + // Enrichment API keys + if let Some(ref api_key) = self.enrichment.sources.musicbrainz.api_key { + self.enrichment.sources.musicbrainz.api_key = + Some(expand_env_var_string(api_key)?); + } + if let Some(ref api_key) = self.enrichment.sources.tmdb.api_key { + self.enrichment.sources.tmdb.api_key = + Some(expand_env_var_string(api_key)?); + } + if let Some(ref api_key) = self.enrichment.sources.lastfm.api_key { + self.enrichment.sources.lastfm.api_key = + Some(expand_env_var_string(api_key)?); + } + + Ok(()) + } + + /// Try loading from file, falling back to defaults if the file doesn't exist. + /// + /// # Errors + /// + /// Returns [`crate::error::PinakesError`] if the file exists but cannot be + /// read or parsed. + pub fn load_or_default(path: &Path) -> crate::error::Result { + if path.exists() { + Self::from_file(path) + } else { + let config = Self::default(); + // Ensure the data directory exists for the default SQLite database + config.ensure_dirs()?; + Ok(config) + } + } + + /// Save the current config to a TOML file. + /// + /// # Errors + /// + /// Returns [`crate::error::PinakesError`] if the file cannot be written or + /// the config cannot be serialized. 
+ pub fn save_to_file(&self, path: &Path) -> crate::error::Result<()> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + let content = toml::to_string_pretty(self).map_err(|e| { + crate::error::PinakesError::Config(format!( + "failed to serialize config: {e}" + )) + })?; + std::fs::write(path, content)?; + Ok(()) + } + + /// Ensure all directories needed by this config exist and are writable. + /// + /// # Errors + /// + /// Returns [`crate::error::PinakesError`] if a required directory cannot be + /// created or is read-only. + pub fn ensure_dirs(&self) -> crate::error::Result<()> { + if let Some(ref sqlite) = self.storage.sqlite + && let Some(parent) = sqlite.path.parent() + { + // Skip if parent is empty string (happens with bare filenames like + // "pinakes.db") + if !parent.as_os_str().is_empty() { + std::fs::create_dir_all(parent)?; + let metadata = std::fs::metadata(parent)?; + if metadata.permissions().readonly() { + return Err(crate::error::PinakesError::Config(format!( + "directory is not writable: {}", + parent.display() + ))); + } + } + } + Ok(()) + } + + /// Returns the default config file path following XDG conventions. + #[must_use] + pub fn default_config_path() -> PathBuf { + std::env::var("XDG_CONFIG_HOME").map_or_else( + |_| { + std::env::var("HOME").map_or_else( + |_| PathBuf::from("pinakes.toml"), + |home| { + PathBuf::from(home) + .join(".config") + .join("pinakes") + .join("pinakes.toml") + }, + ) + }, + |xdg| PathBuf::from(xdg).join("pinakes").join("pinakes.toml"), + ) + } + + /// Validate configuration values for correctness. + /// + /// # Errors + /// + /// Returns an error string if any configuration value is invalid. 
+ pub fn validate(&self) -> Result<(), String> { + if self.server.port == 0 { + return Err("server port cannot be 0".into()); + } + if self.server.host.is_empty() { + return Err("server host cannot be empty".into()); + } + if self.scanning.poll_interval_secs == 0 { + return Err("poll interval cannot be 0".into()); + } + if self.scanning.import_concurrency == 0 + || self.scanning.import_concurrency > 256 + { + return Err("import_concurrency must be between 1 and 256".into()); + } + + // Validate authentication configuration + let has_api_key = + self.server.api_key.as_ref().is_some_and(|k| !k.is_empty()); + let has_accounts = !self.accounts.users.is_empty(); + let auth_disabled = self.server.authentication_disabled; + + if !auth_disabled && !has_api_key && !has_accounts { + return Err( + "authentication is not configured: set an api_key, configure user \ + accounts, or explicitly set authentication_disabled = true" + .into(), + ); + } + + // Empty API key is not allowed (must use authentication_disabled flag) + if let Some(ref api_key) = self.server.api_key + && api_key.is_empty() + { + return Err( + "empty api_key is not allowed. To disable authentication, set \ + authentication_disabled = true instead" + .into(), + ); + } + + // Require TLS when authentication is enabled on non-localhost + let is_localhost = self.server.host == "127.0.0.1" + || self.server.host == "localhost" + || self.server.host == "::1"; + + if (has_api_key || has_accounts) + && !auth_disabled + && !is_localhost + && !self.server.tls.enabled + { + return Err( + "TLS must be enabled when authentication is used on non-localhost \ + hosts. Set server.tls.enabled = true or bind to localhost only" + .into(), + ); + } + + // Validate rate limits + self.rate_limits.validate()?; + + // Validate TLS configuration + self.server.tls.validate()?; + Ok(()) + } + + /// Returns the default data directory following XDG conventions. 
+ #[must_use] + pub fn default_data_dir() -> PathBuf { + std::env::var("XDG_DATA_HOME").map_or_else( + |_| { + std::env::var("HOME").map_or_else( + |_| PathBuf::from("pinakes-data"), + |home| { + PathBuf::from(home) + .join(".local") + .join("share") + .join("pinakes") + }, + ) + }, + |xdg| PathBuf::from(xdg).join("pinakes"), + ) + } +} + +impl Default for Config { + fn default() -> Self { + let data_dir = Self::default_data_dir(); + Self { + storage: StorageConfig { + backend: StorageBackendType::Sqlite, + sqlite: Some(SqliteConfig { + path: data_dir.join("pinakes.db"), + }), + postgres: None, + }, + directories: DirectoryConfig { roots: vec![] }, + scanning: ScanningConfig { + watch: false, + poll_interval_secs: 300, + ignore_patterns: vec![ + ".*".to_string(), + "node_modules".to_string(), + "__pycache__".to_string(), + "target".to_string(), + ], + import_concurrency: default_import_concurrency(), + }, + server: ServerConfig { + host: "127.0.0.1".to_string(), + port: 3000, + api_key: None, + authentication_disabled: false, + cors_enabled: false, + cors_origins: vec![], + tls: TlsConfig::default(), + swagger_ui: true, + }, + ui: UiConfig::default(), + accounts: AccountsConfig::default(), + rate_limits: RateLimitConfig::default(), + jobs: JobsConfig::default(), + thumbnails: ThumbnailConfig::default(), + webhooks: vec![], + scheduled_tasks: vec![], + plugins: PluginsConfig::default(), + transcoding: TranscodingConfig::default(), + enrichment: EnrichmentConfig::default(), + cloud: CloudConfig::default(), + analytics: AnalyticsConfig::default(), + photos: PhotoConfig::default(), + managed_storage: ManagedStorageConfig::default(), + sync: SyncConfig::default(), + sharing: SharingConfig::default(), + trash: TrashConfig::default(), + } + } +} + +#[cfg(test)] +mod tests { + use rustc_hash::FxHashMap; + + use super::*; + + fn test_config_with_concurrency(concurrency: usize) -> Config { + let mut config = Config::default(); + config.scanning.import_concurrency = 
concurrency; + config.server.authentication_disabled = true; // Disable auth for concurrency tests + config + } + + #[test] + fn test_validate_import_concurrency_zero() { + let config = test_config_with_concurrency(0); + assert!(config.validate().is_err()); + assert!( + config + .validate() + .unwrap_err() + .contains("import_concurrency") + ); + } + + #[test] + fn test_validate_import_concurrency_too_high() { + let config = test_config_with_concurrency(257); + assert!(config.validate().is_err()); + assert!( + config + .validate() + .unwrap_err() + .contains("import_concurrency") + ); + } + + #[test] + fn test_validate_import_concurrency_valid() { + let config = test_config_with_concurrency(8); + assert!(config.validate().is_ok()); + } + + #[test] + fn test_validate_import_concurrency_boundary_low() { + let config = test_config_with_concurrency(1); + assert!(config.validate().is_ok()); + } + + #[test] + fn test_validate_import_concurrency_boundary_high() { + let config = test_config_with_concurrency(256); + assert!(config.validate().is_ok()); + } + + // Environment variable expansion tests using expand_env_vars with a + // HashMap lookup. This avoids unsafe std::env::set_var and is + // thread-safe for parallel test execution. 
+ fn test_lookup<'a>( + vars: &'a FxHashMap<&str, &str>, + ) -> impl Fn(&str) -> crate::error::Result + 'a { + move |name| { + vars + .get(name) + .map(std::string::ToString::to_string) + .ok_or_else(|| { + crate::error::PinakesError::Config(format!( + "environment variable not set: {name}" + )) + }) + } + } + + #[test] + fn test_expand_env_var_simple() { + let vars = [("TEST_VAR_SIMPLE", "test_value")] + .into_iter() + .collect::>(); + let result = expand_env_vars("$TEST_VAR_SIMPLE", test_lookup(&vars)); + assert_eq!(result.unwrap(), "test_value"); + } + + #[test] + fn test_expand_env_var_braces() { + let vars = [("TEST_VAR_BRACES", "test_value")] + .into_iter() + .collect::>(); + let result = expand_env_vars("${TEST_VAR_BRACES}", test_lookup(&vars)); + assert_eq!(result.unwrap(), "test_value"); + } + + #[test] + fn test_expand_env_var_embedded() { + let vars = [("TEST_VAR_EMBEDDED", "value")] + .into_iter() + .collect::>(); + let result = + expand_env_vars("prefix_${TEST_VAR_EMBEDDED}_suffix", test_lookup(&vars)); + assert_eq!(result.unwrap(), "prefix_value_suffix"); + } + + #[test] + fn test_expand_env_var_multiple() { + let vars = [("VAR1", "value1"), ("VAR2", "value2")] + .into_iter() + .collect::>(); + let result = expand_env_vars("${VAR1}_${VAR2}", test_lookup(&vars)); + assert_eq!(result.unwrap(), "value1_value2"); + } + + #[test] + fn test_expand_env_var_missing() { + let vars = FxHashMap::default(); + let result = expand_env_vars("${NONEXISTENT_VAR}", test_lookup(&vars)); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("environment variable not set") + ); + } + + #[test] + fn test_expand_env_var_empty_name() { + let vars = FxHashMap::default(); + let result = expand_env_vars("${}", test_lookup(&vars)); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("empty environment variable name") + ); + } + + #[test] + fn test_expand_env_var_escaped() { + let vars = 
FxHashMap::default(); + let result = expand_env_vars("\\$NOT_A_VAR", test_lookup(&vars)); + assert_eq!(result.unwrap(), "$NOT_A_VAR"); + } + + #[test] + fn test_expand_env_var_no_vars() { + let vars = FxHashMap::default(); + let result = expand_env_vars("plain_text", test_lookup(&vars)); + assert_eq!(result.unwrap(), "plain_text"); + } + + #[test] + fn test_expand_env_var_underscore() { + let vars = [("TEST_VAR_NAME", "value")] + .into_iter() + .collect::>(); + let result = expand_env_vars("$TEST_VAR_NAME", test_lookup(&vars)); + assert_eq!(result.unwrap(), "value"); + } + + #[test] + fn test_expand_env_var_mixed_syntax() { + let vars = [("VAR1_MIXED", "v1"), ("VAR2_MIXED", "v2")] + .into_iter() + .collect::>(); + let result = + expand_env_vars("$VAR1_MIXED and ${VAR2_MIXED}", test_lookup(&vars)); + assert_eq!(result.unwrap(), "v1 and v2"); + } +} diff --git a/crates/pinakes-types/src/error.rs b/crates/pinakes-types/src/error.rs new file mode 100644 index 0000000..33f0621 --- /dev/null +++ b/crates/pinakes-types/src/error.rs @@ -0,0 +1,142 @@ +use std::path::PathBuf; + +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum PinakesError { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("database error: {0}")] + Database(String), + + #[error("migration error: {0}")] + Migration(String), + + #[error("configuration error: {0}")] + Config(String), + + #[error("media item not found: {0}")] + NotFound(String), + + #[error("duplicate content hash: {0}")] + DuplicateHash(String), + + #[error("unsupported media type for path: {0}")] + UnsupportedMediaType(PathBuf), + + #[error("metadata extraction failed: {0}")] + MetadataExtraction(String), + + #[error("thumbnail generation failed: {0}")] + ThumbnailGeneration(String), + + #[error("search query parse error: {0}")] + SearchParse(String), + + #[error("file not found at path: {0}")] + FileNotFound(PathBuf), + + #[error("tag not found: {0}")] + TagNotFound(String), + + #[error("collection not found: 
{0}")] + CollectionNotFound(String), + + #[error("invalid operation: {0}")] + InvalidOperation(String), + + #[error("invalid data: {0}")] + InvalidData(String), + + #[error("authentication error: {0}")] + Authentication(String), + + #[error("authorization error: {0}")] + Authorization(String), + + #[error("path not allowed: {0}")] + PathNotAllowed(String), + + #[error("external API error: {0}")] + External(String), + + // Managed Storage errors + #[error("managed storage not enabled")] + ManagedStorageDisabled, + + #[error("upload too large: {0} bytes exceeds limit")] + UploadTooLarge(u64), + + #[error("blob not found: {0}")] + BlobNotFound(String), + + #[error("storage integrity error: {0}")] + StorageIntegrity(String), + + // Sync errors + #[error("sync not enabled")] + SyncDisabled, + + #[error("device not found: {0}")] + DeviceNotFound(String), + + #[error("sync conflict: {0}")] + SyncConflict(String), + + #[error("upload session expired: {0}")] + UploadSessionExpired(String), + + #[error("upload session not found: {0}")] + UploadSessionNotFound(String), + + #[error("chunk out of order: expected {expected}, got {actual}")] + ChunkOutOfOrder { expected: u64, actual: u64 }, + + // Sharing errors + #[error("share not found: {0}")] + ShareNotFound(String), + + #[error("share expired: {0}")] + ShareExpired(String), + + #[error("share password required")] + SharePasswordRequired, + + #[error("share password invalid")] + SharePasswordInvalid, + + #[error("insufficient share permissions")] + InsufficientSharePermissions, + + #[error("serialization error: {0}")] + Serialization(String), + + #[error("external tool `{tool}` failed: {stderr}")] + ExternalTool { tool: String, stderr: String }, + + #[error("subtitle track {index} not found in media")] + SubtitleTrackNotFound { index: u32 }, + + #[error("invalid language code: {0}")] + InvalidLanguageCode(String), +} + +impl From for PinakesError { + fn from(e: serde_json::Error) -> Self { + Self::Serialization(e.to_string()) 
+ } +} + +/// Build a closure that wraps a database error with operation context. +/// +/// Usage: `stmt.execute(params).map_err(db_ctx("insert_media", media_id))?;` +pub fn db_ctx( + operation: &str, + entity: impl std::fmt::Display, +) -> impl FnOnce(E) -> PinakesError { + let context = format!("{operation} [{entity}]"); + move |e| PinakesError::Database(format!("{context}: {e}")) +} + +pub type Result = std::result::Result; diff --git a/crates/pinakes-types/src/lib.rs b/crates/pinakes-types/src/lib.rs new file mode 100644 index 0000000..8b482c4 --- /dev/null +++ b/crates/pinakes-types/src/lib.rs @@ -0,0 +1,4 @@ +pub mod config; +pub mod error; +pub mod media_type; +pub mod model; diff --git a/crates/pinakes-types/src/media_type/builtin.rs b/crates/pinakes-types/src/media_type/builtin.rs new file mode 100644 index 0000000..93701b7 --- /dev/null +++ b/crates/pinakes-types/src/media_type/builtin.rs @@ -0,0 +1,292 @@ +use std::path::Path; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum BuiltinMediaType { + // Audio + Mp3, + Flac, + Ogg, + Wav, + Aac, + Opus, + + // Video + Mp4, + Mkv, + Avi, + Webm, + + // Documents + Pdf, + Epub, + Djvu, + + // Text + Markdown, + PlainText, + + // Images + Jpeg, + Png, + Gif, + Webp, + Svg, + Avif, + Tiff, + Bmp, + + // RAW Images + Cr2, + Nef, + Arw, + Dng, + Orf, + Rw2, + + // HEIC/HEIF + Heic, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum MediaCategory { + Audio, + Video, + Document, + Text, + Image, +} + +impl BuiltinMediaType { + /// Get the unique, stable ID for this media type. 
+ #[must_use] + pub const fn id(&self) -> &'static str { + match self { + Self::Mp3 => "mp3", + Self::Flac => "flac", + Self::Ogg => "ogg", + Self::Wav => "wav", + Self::Aac => "aac", + Self::Opus => "opus", + Self::Mp4 => "mp4", + Self::Mkv => "mkv", + Self::Avi => "avi", + Self::Webm => "webm", + Self::Pdf => "pdf", + Self::Epub => "epub", + Self::Djvu => "djvu", + Self::Markdown => "markdown", + Self::PlainText => "plaintext", + Self::Jpeg => "jpeg", + Self::Png => "png", + Self::Gif => "gif", + Self::Webp => "webp", + Self::Svg => "svg", + Self::Avif => "avif", + Self::Tiff => "tiff", + Self::Bmp => "bmp", + Self::Cr2 => "cr2", + Self::Nef => "nef", + Self::Arw => "arw", + Self::Dng => "dng", + Self::Orf => "orf", + Self::Rw2 => "rw2", + Self::Heic => "heic", + } + } + + /// Get the display name for this media type + #[must_use] + pub fn name(&self) -> String { + match self { + Self::Mp3 => "MP3 Audio".to_string(), + Self::Flac => "FLAC Audio".to_string(), + Self::Ogg => "OGG Audio".to_string(), + Self::Wav => "WAV Audio".to_string(), + Self::Aac => "AAC Audio".to_string(), + Self::Opus => "Opus Audio".to_string(), + Self::Mp4 => "MP4 Video".to_string(), + Self::Mkv => "MKV Video".to_string(), + Self::Avi => "AVI Video".to_string(), + Self::Webm => "WebM Video".to_string(), + Self::Pdf => "PDF Document".to_string(), + Self::Epub => "EPUB eBook".to_string(), + Self::Djvu => "DjVu Document".to_string(), + Self::Markdown => "Markdown".to_string(), + Self::PlainText => "Plain Text".to_string(), + Self::Jpeg => "JPEG Image".to_string(), + Self::Png => "PNG Image".to_string(), + Self::Gif => "GIF Image".to_string(), + Self::Webp => "WebP Image".to_string(), + Self::Svg => "SVG Image".to_string(), + Self::Avif => "AVIF Image".to_string(), + Self::Tiff => "TIFF Image".to_string(), + Self::Bmp => "BMP Image".to_string(), + Self::Cr2 => "Canon RAW (CR2)".to_string(), + Self::Nef => "Nikon RAW (NEF)".to_string(), + Self::Arw => "Sony RAW (ARW)".to_string(), + Self::Dng => 
"Adobe DNG RAW".to_string(), + Self::Orf => "Olympus RAW (ORF)".to_string(), + Self::Rw2 => "Panasonic RAW (RW2)".to_string(), + Self::Heic => "HEIC Image".to_string(), + } + } + + #[must_use] + pub fn from_extension(ext: &str) -> Option { + match ext.to_ascii_lowercase().as_str() { + "mp3" => Some(Self::Mp3), + "flac" => Some(Self::Flac), + "ogg" | "oga" => Some(Self::Ogg), + "wav" => Some(Self::Wav), + "aac" | "m4a" => Some(Self::Aac), + "opus" => Some(Self::Opus), + "mp4" | "m4v" => Some(Self::Mp4), + "mkv" => Some(Self::Mkv), + "avi" => Some(Self::Avi), + "webm" => Some(Self::Webm), + "pdf" => Some(Self::Pdf), + "epub" => Some(Self::Epub), + "djvu" => Some(Self::Djvu), + "md" | "markdown" => Some(Self::Markdown), + "txt" | "text" => Some(Self::PlainText), + "jpg" | "jpeg" => Some(Self::Jpeg), + "png" => Some(Self::Png), + "gif" => Some(Self::Gif), + "webp" => Some(Self::Webp), + "svg" => Some(Self::Svg), + "avif" => Some(Self::Avif), + "tiff" | "tif" => Some(Self::Tiff), + "bmp" => Some(Self::Bmp), + "cr2" => Some(Self::Cr2), + "nef" => Some(Self::Nef), + "arw" => Some(Self::Arw), + "dng" => Some(Self::Dng), + "orf" => Some(Self::Orf), + "rw2" => Some(Self::Rw2), + "heic" | "heif" => Some(Self::Heic), + _ => None, + } + } + + pub fn from_path(path: &Path) -> Option { + path + .extension() + .and_then(|e| e.to_str()) + .and_then(Self::from_extension) + } + + #[must_use] + pub const fn mime_type(&self) -> &'static str { + match self { + Self::Mp3 => "audio/mpeg", + Self::Flac => "audio/flac", + Self::Ogg => "audio/ogg", + Self::Wav => "audio/wav", + Self::Aac => "audio/aac", + Self::Opus => "audio/opus", + Self::Mp4 => "video/mp4", + Self::Mkv => "video/x-matroska", + Self::Avi => "video/x-msvideo", + Self::Webm => "video/webm", + Self::Pdf => "application/pdf", + Self::Epub => "application/epub+zip", + Self::Djvu => "image/vnd.djvu", + Self::Markdown => "text/markdown", + Self::PlainText => "text/plain", + Self::Jpeg => "image/jpeg", + Self::Png => "image/png", 
+ Self::Gif => "image/gif", + Self::Webp => "image/webp", + Self::Svg => "image/svg+xml", + Self::Avif => "image/avif", + Self::Tiff => "image/tiff", + Self::Bmp => "image/bmp", + Self::Cr2 => "image/x-canon-cr2", + Self::Nef => "image/x-nikon-nef", + Self::Arw => "image/x-sony-arw", + Self::Dng => "image/x-adobe-dng", + Self::Orf => "image/x-olympus-orf", + Self::Rw2 => "image/x-panasonic-rw2", + Self::Heic => "image/heic", + } + } + + #[must_use] + pub const fn category(&self) -> MediaCategory { + match self { + Self::Mp3 + | Self::Flac + | Self::Ogg + | Self::Wav + | Self::Aac + | Self::Opus => MediaCategory::Audio, + Self::Mp4 | Self::Mkv | Self::Avi | Self::Webm => MediaCategory::Video, + Self::Pdf | Self::Epub | Self::Djvu => MediaCategory::Document, + Self::Markdown | Self::PlainText => MediaCategory::Text, + Self::Jpeg + | Self::Png + | Self::Gif + | Self::Webp + | Self::Svg + | Self::Avif + | Self::Tiff + | Self::Bmp + | Self::Cr2 + | Self::Nef + | Self::Arw + | Self::Dng + | Self::Orf + | Self::Rw2 + | Self::Heic => MediaCategory::Image, + } + } + + #[must_use] + pub const fn extensions(&self) -> &'static [&'static str] { + match self { + Self::Mp3 => &["mp3"], + Self::Flac => &["flac"], + Self::Ogg => &["ogg", "oga"], + Self::Wav => &["wav"], + Self::Aac => &["aac", "m4a"], + Self::Opus => &["opus"], + Self::Mp4 => &["mp4", "m4v"], + Self::Mkv => &["mkv"], + Self::Avi => &["avi"], + Self::Webm => &["webm"], + Self::Pdf => &["pdf"], + Self::Epub => &["epub"], + Self::Djvu => &["djvu"], + Self::Markdown => &["md", "markdown"], + Self::PlainText => &["txt", "text"], + Self::Jpeg => &["jpg", "jpeg"], + Self::Png => &["png"], + Self::Gif => &["gif"], + Self::Webp => &["webp"], + Self::Svg => &["svg"], + Self::Avif => &["avif"], + Self::Tiff => &["tiff", "tif"], + Self::Bmp => &["bmp"], + Self::Cr2 => &["cr2"], + Self::Nef => &["nef"], + Self::Arw => &["arw"], + Self::Dng => &["dng"], + Self::Orf => &["orf"], + Self::Rw2 => &["rw2"], + Self::Heic => &["heic", 
"heif"], + } + } + + /// Returns true if this is a RAW image format. + #[must_use] + pub const fn is_raw(&self) -> bool { + matches!( + self, + Self::Cr2 | Self::Nef | Self::Arw | Self::Dng | Self::Orf | Self::Rw2 + ) + } +} diff --git a/crates/pinakes-types/src/media_type/mod.rs b/crates/pinakes-types/src/media_type/mod.rs new file mode 100644 index 0000000..2c73ef0 --- /dev/null +++ b/crates/pinakes-types/src/media_type/mod.rs @@ -0,0 +1,281 @@ +//! Media types +//! +//! Supports both +//! built-in media types and plugin-registered custom types. + +use std::path::Path; + +use serde::{Deserialize, Serialize}; + +pub mod builtin; +pub mod registry; + +pub use builtin::{BuiltinMediaType, MediaCategory}; +pub use registry::{MediaTypeDescriptor, MediaTypeRegistry}; + +/// Media type identifier, can be either built-in or custom +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(untagged)] +pub enum MediaType { + /// Built-in media type (backward compatible) + Builtin(BuiltinMediaType), + + /// Custom media type from a plugin + Custom(String), +} + +impl MediaType { + /// Create a new custom media type + pub fn custom(id: impl Into) -> Self { + Self::Custom(id.into()) + } + + /// Get the type ID as a string + #[must_use] + pub fn id(&self) -> String { + match self { + Self::Builtin(b) => b.id().to_string(), + Self::Custom(id) => id.clone(), + } + } + + /// Get the display name for this media type + /// For custom types without a registry, returns the ID as the name + #[must_use] + pub fn name(&self) -> String { + match self { + Self::Builtin(b) => b.name(), + Self::Custom(id) => id.clone(), + } + } + + /// Get the display name for this media type with registry support + #[must_use] + pub fn name_with_registry(&self, registry: &MediaTypeRegistry) -> String { + match self { + Self::Builtin(b) => b.name(), + Self::Custom(id) => { + registry + .get(id) + .map_or_else(|| id.clone(), |d| d.name.clone()) + }, + } + } + + /// Get the category for 
this media type + /// For custom types without a registry, returns [`MediaCategory::Document`] + /// as default + #[must_use] + pub const fn category(&self) -> MediaCategory { + match self { + Self::Builtin(b) => b.category(), + Self::Custom(_) => MediaCategory::Document, + } + } + + /// Get the category for this media type with registry support + #[must_use] + pub fn category_with_registry( + &self, + registry: &MediaTypeRegistry, + ) -> MediaCategory { + match self { + Self::Builtin(b) => b.category(), + Self::Custom(id) => { + registry + .get(id) + .and_then(|d| d.category) + .unwrap_or(MediaCategory::Document) + }, + } + } + + /// Get the MIME type + /// For custom types without a registry, returns "application/octet-stream" + #[must_use] + pub fn mime_type(&self) -> String { + match self { + Self::Builtin(b) => b.mime_type().to_string(), + Self::Custom(_) => "application/octet-stream".to_string(), + } + } + + /// Get the MIME type with registry support + #[must_use] + pub fn mime_type_with_registry( + &self, + registry: &MediaTypeRegistry, + ) -> String { + match self { + Self::Builtin(b) => b.mime_type().to_string(), + Self::Custom(id) => { + registry + .get(id) + .and_then(|d| d.mime_types.first().cloned()) + .unwrap_or_else(|| "application/octet-stream".to_string()) + }, + } + } + + /// Get file extensions + /// For custom types without a registry, returns an empty vec + #[must_use] + pub fn extensions(&self) -> Vec { + match self { + Self::Builtin(b) => { + b.extensions() + .iter() + .map(std::string::ToString::to_string) + .collect() + }, + Self::Custom(_) => vec![], + } + } + + /// Get file extensions with registry support + #[must_use] + pub fn extensions_with_registry( + &self, + registry: &MediaTypeRegistry, + ) -> Vec { + match self { + Self::Builtin(b) => { + b.extensions() + .iter() + .map(std::string::ToString::to_string) + .collect() + }, + Self::Custom(id) => { + registry + .get(id) + .map(|d| d.extensions.clone()) + .unwrap_or_default() + }, + 
} + } + + /// Check if this is a RAW image format + #[must_use] + pub const fn is_raw(&self) -> bool { + match self { + Self::Builtin(b) => b.is_raw(), + Self::Custom(_) => false, + } + } + + /// Resolve a media type from file extension (built-in types only) + /// Use `from_extension_with_registry` for custom types + pub fn from_extension(ext: &str) -> Option { + BuiltinMediaType::from_extension(ext).map(Self::Builtin) + } + + /// Resolve a media type from file extension with registry (includes custom + /// types) + #[must_use] + pub fn from_extension_with_registry( + ext: &str, + registry: &MediaTypeRegistry, + ) -> Option { + // Try built-in types first + if let Some(builtin) = BuiltinMediaType::from_extension(ext) { + return Some(Self::Builtin(builtin)); + } + + // Try registered custom types + registry + .get_by_extension(ext) + .map(|desc| Self::Custom(desc.id.clone())) + } + + /// Resolve a media type from file path (built-in types only) + /// Use `from_path_with_registry` for custom types + pub fn from_path(path: &Path) -> Option { + path + .extension() + .and_then(|e| e.to_str()) + .and_then(Self::from_extension) + } + + /// Resolve a media type from file path with registry (includes custom types) + #[must_use] + pub fn from_path_with_registry( + path: &Path, + registry: &MediaTypeRegistry, + ) -> Option { + path + .extension() + .and_then(|e| e.to_str()) + .and_then(|ext| Self::from_extension_with_registry(ext, registry)) + } +} + +// Implement `From` for easier conversion +impl From for MediaType { + fn from(builtin: BuiltinMediaType) -> Self { + Self::Builtin(builtin) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_builtin_media_type() { + let mt = MediaType::Builtin(BuiltinMediaType::Mp3); + + assert_eq!(mt.id(), "mp3"); + assert_eq!(mt.mime_type(), "audio/mpeg"); + assert_eq!(mt.category(), MediaCategory::Audio); + } + + #[test] + fn test_custom_media_type() { + let mut registry = MediaTypeRegistry::new(); + + let descriptor = 
MediaTypeDescriptor { + id: "heif".to_string(), + name: "HEIF Image".to_string(), + category: Some(MediaCategory::Image), + extensions: vec!["heif".to_string()], + mime_types: vec!["image/heif".to_string()], + plugin_id: Some("heif-plugin".to_string()), + }; + + registry.register(descriptor).unwrap(); + + let mt = MediaType::custom("heif"); + assert_eq!(mt.id(), "heif"); + assert_eq!(mt.mime_type_with_registry(®istry), "image/heif"); + assert_eq!(mt.category_with_registry(®istry), MediaCategory::Image); + } + + #[test] + fn test_from_extension_builtin() { + let registry = MediaTypeRegistry::new(); + let mt = MediaType::from_extension_with_registry("mp3", ®istry); + + assert!(mt.is_some()); + assert_eq!(mt.unwrap(), MediaType::Builtin(BuiltinMediaType::Mp3)); + } + + #[test] + fn test_from_extension_custom() { + let mut registry = MediaTypeRegistry::new(); + + let descriptor = MediaTypeDescriptor { + id: "customformat".to_string(), + name: "Custom Format".to_string(), + category: Some(MediaCategory::Image), + extensions: vec!["xyz".to_string()], + mime_types: vec!["application/x-custom".to_string()], + plugin_id: Some("custom-plugin".to_string()), + }; + + registry.register(descriptor).unwrap(); + + let mt = MediaType::from_extension_with_registry("xyz", ®istry); + assert!(mt.is_some()); + assert_eq!(mt.unwrap(), MediaType::custom("customformat")); + } +} diff --git a/crates/pinakes-types/src/media_type/registry.rs b/crates/pinakes-types/src/media_type/registry.rs new file mode 100644 index 0000000..871f12c --- /dev/null +++ b/crates/pinakes-types/src/media_type/registry.rs @@ -0,0 +1,297 @@ +//! 
Media type registry for managing both built-in and custom media types + +use anyhow::{Result, anyhow}; +use rustc_hash::FxHashMap; +use serde::{Deserialize, Serialize}; + +use super::MediaCategory; + +/// Descriptor for a media type (built-in or custom) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MediaTypeDescriptor { + /// Unique identifier + pub id: String, + + /// Display name + pub name: String, + + /// Category + pub category: Option, + + /// File extensions + pub extensions: Vec, + + /// MIME types + pub mime_types: Vec, + + /// Plugin that registered this type (None for built-in types) + pub plugin_id: Option, +} + +/// Registry for media types +#[derive(Debug, Clone)] +pub struct MediaTypeRegistry { + /// Map of media type ID to descriptor + types: FxHashMap, + + /// Map of extension to media type ID + extension_map: FxHashMap, +} + +impl MediaTypeRegistry { + /// Create a new empty registry + #[must_use] + pub fn new() -> Self { + Self { + types: FxHashMap::default(), + extension_map: FxHashMap::default(), + } + } + + /// Register a new media type + pub fn register(&mut self, descriptor: MediaTypeDescriptor) -> Result<()> { + // Check if ID is already registered + if self.types.contains_key(&descriptor.id) { + return Err(anyhow!("Media type already registered: {}", descriptor.id)); + } + + // Register extensions + for ext in &descriptor.extensions { + let ext_lower = ext.to_lowercase(); + if self.extension_map.contains_key(&ext_lower) { + // Extension already registered - this is OK, we'll use the first one + // In a more sophisticated system, we might track multiple types per + // extension + continue; + } + self.extension_map.insert(ext_lower, descriptor.id.clone()); + } + + // Register the type + self.types.insert(descriptor.id.clone(), descriptor); + + Ok(()) + } + + /// Unregister a media type + pub fn unregister(&mut self, id: &str) -> Result<()> { + let descriptor = self + .types + .remove(id) + .ok_or_else(|| anyhow!("Media type 
not found: {id}"))?; + + // Remove extensions + for ext in &descriptor.extensions { + let ext_lower = ext.to_lowercase(); + if self.extension_map.get(&ext_lower) == Some(&descriptor.id) { + self.extension_map.remove(&ext_lower); + } + } + + Ok(()) + } + + /// Get a media type descriptor by ID + #[must_use] + pub fn get(&self, id: &str) -> Option<&MediaTypeDescriptor> { + self.types.get(id) + } + + /// Get a media type by file extension + #[must_use] + pub fn get_by_extension(&self, ext: &str) -> Option<&MediaTypeDescriptor> { + let ext_lower = ext.to_lowercase(); + self + .extension_map + .get(&ext_lower) + .and_then(|id| self.types.get(id)) + } + + /// List all registered media types + #[must_use] + pub fn list_all(&self) -> Vec<&MediaTypeDescriptor> { + self.types.values().collect() + } + + /// List media types from a specific plugin + #[must_use] + pub fn list_by_plugin(&self, plugin_id: &str) -> Vec<&MediaTypeDescriptor> { + self + .types + .values() + .filter(|d| d.plugin_id.as_deref() == Some(plugin_id)) + .collect() + } + + /// List built-in media types (`plugin_id` is None) + #[must_use] + pub fn list_builtin(&self) -> Vec<&MediaTypeDescriptor> { + self + .types + .values() + .filter(|d| d.plugin_id.is_none()) + .collect() + } + + /// Get count of registered types + #[must_use] + pub fn count(&self) -> usize { + self.types.len() + } + + /// Check if a media type is registered + #[must_use] + pub fn contains(&self, id: &str) -> bool { + self.types.contains_key(id) + } + + /// Unregister all types from a specific plugin + pub fn unregister_plugin(&mut self, plugin_id: &str) -> Result { + let type_ids: Vec = self + .types + .values() + .filter(|d| d.plugin_id.as_deref() == Some(plugin_id)) + .map(|d| d.id.clone()) + .collect(); + + let count = type_ids.len(); + + for id in type_ids { + self.unregister(&id)?; + } + + Ok(count) + } +} + +impl Default for MediaTypeRegistry { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use 
super::*; + + fn create_test_descriptor(id: &str, ext: &str) -> MediaTypeDescriptor { + MediaTypeDescriptor { + id: id.to_string(), + name: format!("{id} Type"), + category: Some(MediaCategory::Document), + extensions: vec![ext.to_string()], + mime_types: vec![format!("application/{}", id)], + plugin_id: Some("test-plugin".to_string()), + } + } + + #[test] + fn test_register_and_get() { + let mut registry = MediaTypeRegistry::new(); + let descriptor = create_test_descriptor("test", "tst"); + + registry.register(descriptor).unwrap(); + + let retrieved = registry.get("test").unwrap(); + assert_eq!(retrieved.id, "test"); + assert_eq!(retrieved.name, "test Type"); + } + + #[test] + fn test_register_duplicate() { + let mut registry = MediaTypeRegistry::new(); + let descriptor = create_test_descriptor("test", "tst"); + + registry.register(descriptor.clone()).unwrap(); + let result = registry.register(descriptor); + + assert!(result.is_err()); + } + + #[test] + fn test_get_by_extension() { + let mut registry = MediaTypeRegistry::new(); + let descriptor = create_test_descriptor("test", "tst"); + + registry.register(descriptor).unwrap(); + + let retrieved = registry.get_by_extension("tst").unwrap(); + assert_eq!(retrieved.id, "test"); + + // Test case insensitivity + let retrieved = registry.get_by_extension("TST").unwrap(); + assert_eq!(retrieved.id, "test"); + } + + #[test] + fn test_unregister() { + let mut registry = MediaTypeRegistry::new(); + let descriptor = create_test_descriptor("test", "tst"); + + registry.register(descriptor).unwrap(); + assert!(registry.contains("test")); + + registry.unregister("test").unwrap(); + assert!(!registry.contains("test")); + + // Extension should also be removed + assert!(registry.get_by_extension("tst").is_none()); + } + + #[test] + fn test_list_by_plugin() { + let mut registry = MediaTypeRegistry::new(); + + let desc1 = MediaTypeDescriptor { + id: "type1".to_string(), + name: "Type 1".to_string(), + category: 
Some(MediaCategory::Document), + extensions: vec!["t1".to_string()], + mime_types: vec!["application/type1".to_string()], + plugin_id: Some("plugin1".to_string()), + }; + + let desc2 = MediaTypeDescriptor { + id: "type2".to_string(), + name: "Type 2".to_string(), + category: Some(MediaCategory::Document), + extensions: vec!["t2".to_string()], + mime_types: vec!["application/type2".to_string()], + plugin_id: Some("plugin2".to_string()), + }; + + registry.register(desc1).unwrap(); + registry.register(desc2).unwrap(); + + let plugin1_types = registry.list_by_plugin("plugin1"); + assert_eq!(plugin1_types.len(), 1); + assert_eq!(plugin1_types[0].id, "type1"); + + let plugin2_types = registry.list_by_plugin("plugin2"); + assert_eq!(plugin2_types.len(), 1); + assert_eq!(plugin2_types[0].id, "type2"); + } + + #[test] + fn test_unregister_plugin() { + let mut registry = MediaTypeRegistry::new(); + + for i in 1..=3 { + let desc = MediaTypeDescriptor { + id: format!("type{i}"), + name: format!("Type {i}"), + category: Some(MediaCategory::Document), + extensions: vec![format!("t{}", i)], + mime_types: vec![format!("application/type{}", i)], + plugin_id: Some("test-plugin".to_string()), + }; + registry.register(desc).unwrap(); + } + + assert_eq!(registry.count(), 3); + + let removed = registry.unregister_plugin("test-plugin").unwrap(); + assert_eq!(removed, 3); + assert_eq!(registry.count(), 0); + } +} diff --git a/crates/pinakes-types/src/model.rs b/crates/pinakes-types/src/model.rs new file mode 100644 index 0000000..f2f2863 --- /dev/null +++ b/crates/pinakes-types/src/model.rs @@ -0,0 +1,688 @@ +use std::{fmt, path::PathBuf}; + +use chrono::{DateTime, Utc}; +use rustc_hash::FxHashMap; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::media_type::MediaType; + +/// Unique identifier for a user account. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct UserId(pub Uuid); + +impl UserId { + #[must_use] + pub fn new() -> Self { + Self(Uuid::now_v7()) + } +} + +impl Default for UserId { + fn default() -> Self { + Self::new() + } +} + +impl fmt::Display for UserId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for UserId { + fn from(id: Uuid) -> Self { + Self(id) + } +} + +/// Unique identifier for a media item. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct MediaId(pub Uuid); + +impl MediaId { + /// Creates a new media ID using `UUIDv7`. + #[must_use] + pub fn new() -> Self { + Self(Uuid::now_v7()) + } +} + +impl fmt::Display for MediaId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl Default for MediaId { + fn default() -> Self { + Self(uuid::Uuid::nil()) + } +} + +/// BLAKE3 content hash for deduplication. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct ContentHash(pub String); + +impl ContentHash { + /// Creates a new content hash from a hex string. 
+ #[must_use] + pub const fn new(hex: String) -> Self { + Self(hex) + } +} + +impl fmt::Display for ContentHash { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Storage mode for media items +#[derive( + Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize, +)] +#[serde(rename_all = "lowercase")] +pub enum StorageMode { + /// File exists on disk, referenced by path + #[default] + External, + /// File is stored in managed content-addressable storage + Managed, +} + +impl fmt::Display for StorageMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::External => write!(f, "external"), + Self::Managed => write!(f, "managed"), + } + } +} + +impl std::str::FromStr for StorageMode { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "external" => Ok(Self::External), + "managed" => Ok(Self::Managed), + _ => Err(format!("unknown storage mode: {s}")), + } + } +} + +/// A blob stored in managed storage (content-addressable) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ManagedBlob { + pub content_hash: ContentHash, + pub file_size: u64, + pub mime_type: String, + pub reference_count: u32, + pub stored_at: DateTime, + pub last_verified: Option>, +} + +/// Result of uploading a file to managed storage +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UploadResult { + pub media_id: MediaId, + pub content_hash: ContentHash, + pub was_duplicate: bool, + pub file_size: u64, +} + +/// Statistics about managed storage +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct ManagedStorageStats { + pub total_blobs: u64, + pub total_size_bytes: u64, + pub unique_size_bytes: u64, + pub deduplication_ratio: f64, + pub managed_media_count: u64, + pub orphaned_blobs: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MediaItem { + pub id: MediaId, + pub path: PathBuf, + pub file_name: 
String, + pub media_type: MediaType, + pub content_hash: ContentHash, + pub file_size: u64, + pub title: Option, + pub artist: Option, + pub album: Option, + pub genre: Option, + pub year: Option, + pub duration_secs: Option, + pub description: Option, + pub thumbnail_path: Option, + pub custom_fields: FxHashMap, + /// File modification time (Unix timestamp in seconds), used for incremental + /// scanning + pub file_mtime: Option, + + // Photo-specific metadata + pub date_taken: Option>, + pub latitude: Option, + pub longitude: Option, + pub camera_make: Option, + pub camera_model: Option, + pub rating: Option, + pub perceptual_hash: Option, + + // Managed storage fields + /// How the file is stored (external on disk or managed in + /// content-addressable storage) + #[serde(default)] + pub storage_mode: StorageMode, + /// Original filename for uploaded files (preserved separately from + /// `file_name`) + pub original_filename: Option, + /// When the file was uploaded to managed storage + pub uploaded_at: Option>, + /// Storage key for looking up the blob (usually same as `content_hash`) + pub storage_key: Option, + + pub created_at: DateTime, + pub updated_at: DateTime, + + /// Soft delete timestamp. If set, the item is in the trash. + pub deleted_at: Option>, + + /// When markdown links were last extracted from this file. + pub links_extracted_at: Option>, +} + +/// A custom field attached to a media item. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CustomField { + pub field_type: CustomFieldType, + pub value: String, +} + +/// Type of custom field value. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum CustomFieldType { + Text, + Number, + Date, + Boolean, +} + +impl CustomFieldType { + #[must_use] + pub const fn as_str(&self) -> &'static str { + match self { + Self::Text => "text", + Self::Number => "number", + Self::Date => "date", + Self::Boolean => "boolean", + } + } +} + +impl std::fmt::Display for CustomFieldType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +/// A tag that can be applied to media items. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Tag { + pub id: Uuid, + pub name: String, + pub parent_id: Option, + pub created_at: DateTime, +} + +/// A collection of media items. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Collection { + pub id: Uuid, + pub name: String, + pub description: Option, + pub kind: CollectionKind, + pub filter_query: Option, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +/// Kind of collection. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum CollectionKind { + Manual, + Virtual, +} + +impl CollectionKind { + #[must_use] + pub const fn as_str(&self) -> &'static str { + match self { + Self::Manual => "manual", + Self::Virtual => "virtual", + } + } +} + +impl std::fmt::Display for CollectionKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +/// A member of a collection with position tracking. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CollectionMember { + pub collection_id: Uuid, + pub media_id: MediaId, + pub position: i32, + pub added_at: DateTime, +} + +/// An audit trail entry. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuditEntry { + pub id: Uuid, + pub media_id: Option, + pub action: AuditAction, + pub details: Option, + pub timestamp: DateTime, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AuditAction { + // Media actions + Imported, + Updated, + Deleted, + Tagged, + Untagged, + AddedToCollection, + RemovedFromCollection, + Opened, + Scanned, + + // Authentication actions + LoginSuccess, + LoginFailed, + Logout, + SessionExpired, + + // Authorization actions + PermissionDenied, + RoleChanged, + LibraryAccessGranted, + LibraryAccessRevoked, + + // User management + UserCreated, + UserUpdated, + UserDeleted, + + // Plugin actions + PluginInstalled, + PluginUninstalled, + PluginEnabled, + PluginDisabled, + + // Configuration actions + ConfigChanged, + RootDirectoryAdded, + RootDirectoryRemoved, + + // Social/Sharing actions + ShareLinkCreated, + ShareLinkAccessed, + + // System actions + DatabaseVacuumed, + DatabaseCleared, + ExportCompleted, + IntegrityCheckCompleted, +} + +impl fmt::Display for AuditAction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + // Media actions + Self::Imported => "imported", + Self::Updated => "updated", + Self::Deleted => "deleted", + Self::Tagged => "tagged", + Self::Untagged => "untagged", + Self::AddedToCollection => "added_to_collection", + Self::RemovedFromCollection => "removed_from_collection", + Self::Opened => "opened", + Self::Scanned => "scanned", + + // Authentication actions + Self::LoginSuccess => "login_success", + Self::LoginFailed => "login_failed", + Self::Logout => "logout", + Self::SessionExpired => "session_expired", + + // Authorization actions + Self::PermissionDenied => "permission_denied", + Self::RoleChanged => "role_changed", + Self::LibraryAccessGranted => "library_access_granted", + Self::LibraryAccessRevoked => "library_access_revoked", + + // User 
management + Self::UserCreated => "user_created", + Self::UserUpdated => "user_updated", + Self::UserDeleted => "user_deleted", + + // Plugin actions + Self::PluginInstalled => "plugin_installed", + Self::PluginUninstalled => "plugin_uninstalled", + Self::PluginEnabled => "plugin_enabled", + Self::PluginDisabled => "plugin_disabled", + + // Configuration actions + Self::ConfigChanged => "config_changed", + Self::RootDirectoryAdded => "root_directory_added", + Self::RootDirectoryRemoved => "root_directory_removed", + + // Social/Sharing actions + Self::ShareLinkCreated => "share_link_created", + Self::ShareLinkAccessed => "share_link_accessed", + + // System actions + Self::DatabaseVacuumed => "database_vacuumed", + Self::DatabaseCleared => "database_cleared", + Self::ExportCompleted => "export_completed", + Self::IntegrityCheckCompleted => "integrity_check_completed", + }; + write!(f, "{s}") + } +} + +/// Pagination parameters for list queries. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Pagination { + pub offset: u64, + pub limit: u64, + pub sort: Option, +} + +impl Pagination { + /// Creates a new pagination instance. + #[must_use] + pub const fn new(offset: u64, limit: u64, sort: Option) -> Self { + Self { + offset, + limit, + sort, + } + } +} + +impl Default for Pagination { + fn default() -> Self { + Self { + offset: 0, + limit: 50, + sort: None, + } + } +} + +/// A saved search query. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SavedSearch { + pub id: Uuid, + pub name: String, + pub query: String, + pub sort_order: Option, + pub created_at: DateTime, +} + +// Book Management Types + +/// Metadata for book-type media. +/// +/// Used both as a DB record (with populated `media_id`, `created_at`, +/// `updated_at`) and as an extraction result (with placeholder values for +/// those fields when the record has not yet been persisted). 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BookMetadata { + pub media_id: MediaId, + pub isbn: Option, + pub isbn13: Option, + pub publisher: Option, + pub language: Option, + pub page_count: Option, + pub publication_date: Option, + pub series_name: Option, + pub series_index: Option, + pub format: Option, + pub authors: Vec, + pub identifiers: FxHashMap>, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +impl Default for BookMetadata { + fn default() -> Self { + let now = Utc::now(); + Self { + media_id: MediaId(uuid::Uuid::nil()), + isbn: None, + isbn13: None, + publisher: None, + language: None, + page_count: None, + publication_date: None, + series_name: None, + series_index: None, + format: None, + authors: Vec::new(), + identifiers: FxHashMap::default(), + created_at: now, + updated_at: now, + } + } +} + +/// Information about a book author. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct AuthorInfo { + pub name: String, + pub role: String, + pub file_as: Option, + pub position: i32, +} + +impl AuthorInfo { + /// Creates a new author with the given name. + #[must_use] + pub fn new(name: String) -> Self { + Self { + name, + role: "author".to_string(), + file_as: None, + position: 0, + } + } + + /// Sets the author's role. + #[must_use] + pub fn with_role(mut self, role: String) -> Self { + self.role = role; + self + } + + #[must_use] + pub fn with_file_as(mut self, file_as: String) -> Self { + self.file_as = Some(file_as); + self + } + + #[must_use] + pub const fn with_position(mut self, position: i32) -> Self { + self.position = position; + self + } +} + +/// Reading progress for a book. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReadingProgress { + pub media_id: MediaId, + pub user_id: Uuid, + pub current_page: i32, + pub total_pages: Option, + pub progress_percent: f64, + pub last_read_at: DateTime, +} + +impl ReadingProgress { + /// Creates a new reading progress entry. 
+ #[must_use] + pub fn new( + media_id: MediaId, + user_id: Uuid, + current_page: i32, + total_pages: Option, + ) -> Self { + let progress_percent = total_pages.map_or(0.0, |total| { + if total > 0 { + (f64::from(current_page) / f64::from(total) * 100.0).min(100.0) + } else { + 0.0 + } + }); + + Self { + media_id, + user_id, + current_page, + total_pages, + progress_percent, + last_read_at: Utc::now(), + } + } +} + +/// Reading status for a book. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ReadingStatus { + ToRead, + Reading, + Completed, + Abandoned, +} + +impl fmt::Display for ReadingStatus { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::ToRead => write!(f, "to_read"), + Self::Reading => write!(f, "reading"), + Self::Completed => write!(f, "completed"), + Self::Abandoned => write!(f, "abandoned"), + } + } +} + +/// Type of markdown link +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum LinkType { + /// Wikilink: [[target]] or [[target|display]] + Wikilink, + /// Markdown link: [text](path) + MarkdownLink, + /// Embed: ![[target]] + Embed, +} + +impl fmt::Display for LinkType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Wikilink => write!(f, "wikilink"), + Self::MarkdownLink => write!(f, "markdown_link"), + Self::Embed => write!(f, "embed"), + } + } +} + +impl std::str::FromStr for LinkType { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "wikilink" => Ok(Self::Wikilink), + "markdown_link" => Ok(Self::MarkdownLink), + "embed" => Ok(Self::Embed), + _ => Err(format!("unknown link type: {s}")), + } + } +} + +/// A markdown link extracted from a file. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MarkdownLink { + pub id: Uuid, + pub source_media_id: MediaId, + /// Raw link target as written in the source (wikilink name or path) + pub target_path: String, + /// Resolved target `media_id` (None if unresolved) + pub target_media_id: Option, + pub link_type: LinkType, + /// Display text for the link + pub link_text: Option, + /// Line number in source file (1-indexed) + pub line_number: Option, + /// Surrounding text for backlink preview + pub context: Option, + pub created_at: DateTime, +} + +/// Information about a backlink (incoming link). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BacklinkInfo { + pub link_id: Uuid, + pub source_id: MediaId, + pub source_title: Option, + pub source_path: String, + pub link_text: Option, + pub line_number: Option, + pub context: Option, + pub link_type: LinkType, +} + +/// Graph data for visualization. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct GraphData { + pub nodes: Vec, + pub edges: Vec, +} + +/// A node in the graph visualization. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GraphNode { + pub id: String, + pub label: String, + pub title: Option, + pub media_type: String, + /// Number of outgoing links from this node + pub link_count: u32, + /// Number of incoming links to this node + pub backlink_count: u32, +} + +/// An edge (link) in the graph visualization. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GraphEdge { + pub source: String, + pub target: String, + pub link_type: LinkType, +}