pinakes-server: TLS support; session persistence and security polish
Signed-off-by: NotAShelf <raf@notashelf.dev> Change-Id: If2c9c3e3af62bbf9f33a97be89ac40bc6a6a6964
This commit is contained in:
parent
758aba0f7a
commit
87a4482576
19 changed files with 1835 additions and 111 deletions
|
|
@ -2,6 +2,78 @@ use std::path::{Path, PathBuf};
|
|||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Expand environment variables in a string.
|
||||
/// Supports both ${VAR_NAME} and $VAR_NAME syntax.
|
||||
/// Returns an error if a referenced variable is not set.
|
||||
fn expand_env_var_string(input: &str) -> crate::error::Result<String> {
|
||||
let mut result = String::new();
|
||||
let mut chars = input.chars().peekable();
|
||||
|
||||
while let Some(ch) = chars.next() {
|
||||
if ch == '$' {
|
||||
// Check if it's ${VAR} or $VAR syntax
|
||||
let use_braces = chars.peek() == Some(&'{');
|
||||
if use_braces {
|
||||
chars.next(); // consume '{'
|
||||
}
|
||||
|
||||
// Collect variable name
|
||||
let mut var_name = String::new();
|
||||
while let Some(&next_ch) = chars.peek() {
|
||||
if use_braces {
|
||||
if next_ch == '}' {
|
||||
chars.next(); // consume '}'
|
||||
break;
|
||||
}
|
||||
var_name.push(next_ch);
|
||||
chars.next();
|
||||
} else {
|
||||
// For $VAR syntax, stop at non-alphanumeric/underscore
|
||||
if next_ch.is_alphanumeric() || next_ch == '_' {
|
||||
var_name.push(next_ch);
|
||||
chars.next();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if var_name.is_empty() {
|
||||
return Err(crate::error::PinakesError::Config(
|
||||
"empty environment variable name".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Look up the environment variable
|
||||
match std::env::var(&var_name) {
|
||||
Ok(value) => result.push_str(&value),
|
||||
Err(_) => {
|
||||
return Err(crate::error::PinakesError::Config(format!(
|
||||
"environment variable not set: {}",
|
||||
var_name
|
||||
)));
|
||||
}
|
||||
}
|
||||
} else if ch == '\\' {
|
||||
// Handle escaped characters
|
||||
if let Some(&next_ch) = chars.peek() {
|
||||
if next_ch == '$' {
|
||||
chars.next(); // consume the escaped $
|
||||
result.push('$');
|
||||
} else {
|
||||
result.push(ch);
|
||||
}
|
||||
} else {
|
||||
result.push(ch);
|
||||
}
|
||||
} else {
|
||||
result.push(ch);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
pub storage: StorageConfig,
|
||||
|
|
@ -456,6 +528,15 @@ pub struct PostgresConfig {
|
|||
pub username: String,
|
||||
pub password: String,
|
||||
pub max_connections: usize,
|
||||
/// Enable TLS for PostgreSQL connections
|
||||
#[serde(default)]
|
||||
pub tls_enabled: bool,
|
||||
/// Verify TLS certificates (default: true)
|
||||
#[serde(default = "default_true")]
|
||||
pub tls_verify_ca: bool,
|
||||
/// Path to custom CA certificate file (PEM format)
|
||||
#[serde(default)]
|
||||
pub tls_ca_cert_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -484,6 +565,11 @@ pub struct ServerConfig {
|
|||
/// If set, all requests (except /health) must include `Authorization: Bearer <key>`.
|
||||
/// Can also be set via `PINAKES_API_KEY` environment variable.
|
||||
pub api_key: Option<String>,
|
||||
/// Explicitly disable authentication (INSECURE - use only for development).
|
||||
/// When true, all requests are allowed without authentication.
|
||||
/// This must be explicitly set to true; empty api_key alone is not sufficient.
|
||||
#[serde(default)]
|
||||
pub authentication_disabled: bool,
|
||||
/// TLS/HTTPS configuration
|
||||
#[serde(default)]
|
||||
pub tls: TlsConfig,
|
||||
|
|
@ -570,8 +656,45 @@ impl Config {
|
|||
let content = std::fs::read_to_string(path).map_err(|e| {
|
||||
crate::error::PinakesError::Config(format!("failed to read config file: {e}"))
|
||||
})?;
|
||||
toml::from_str(&content)
|
||||
.map_err(|e| crate::error::PinakesError::Config(format!("failed to parse config: {e}")))
|
||||
let mut config: Self = toml::from_str(&content).map_err(|e| {
|
||||
crate::error::PinakesError::Config(format!("failed to parse config: {e}"))
|
||||
})?;
|
||||
config.expand_env_vars()?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Expand environment variables in secret fields.
|
||||
/// Supports ${VAR_NAME} and $VAR_NAME syntax.
|
||||
fn expand_env_vars(&mut self) -> crate::error::Result<()> {
|
||||
// Postgres password
|
||||
if let Some(ref mut postgres) = self.storage.postgres {
|
||||
postgres.password = expand_env_var_string(&postgres.password)?;
|
||||
}
|
||||
|
||||
// Server API key
|
||||
if let Some(ref api_key) = self.server.api_key {
|
||||
self.server.api_key = Some(expand_env_var_string(api_key)?);
|
||||
}
|
||||
|
||||
// Webhook secrets
|
||||
for webhook in &mut self.webhooks {
|
||||
if let Some(ref secret) = webhook.secret {
|
||||
webhook.secret = Some(expand_env_var_string(secret)?);
|
||||
}
|
||||
}
|
||||
|
||||
// Enrichment API keys
|
||||
if let Some(ref api_key) = self.enrichment.sources.musicbrainz.api_key {
|
||||
self.enrichment.sources.musicbrainz.api_key = Some(expand_env_var_string(api_key)?);
|
||||
}
|
||||
if let Some(ref api_key) = self.enrichment.sources.tmdb.api_key {
|
||||
self.enrichment.sources.tmdb.api_key = Some(expand_env_var_string(api_key)?);
|
||||
}
|
||||
if let Some(ref api_key) = self.enrichment.sources.lastfm.api_key {
|
||||
self.enrichment.sources.lastfm.api_key = Some(expand_env_var_string(api_key)?);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Try loading from file, falling back to defaults if the file doesn't exist.
|
||||
|
|
@ -643,6 +766,50 @@ impl Config {
|
|||
if self.scanning.import_concurrency == 0 || self.scanning.import_concurrency > 256 {
|
||||
return Err("import_concurrency must be between 1 and 256".into());
|
||||
}
|
||||
|
||||
// Validate authentication configuration
|
||||
let has_api_key = self
|
||||
.server
|
||||
.api_key
|
||||
.as_ref()
|
||||
.map_or(false, |k| !k.is_empty());
|
||||
let has_accounts = !self.accounts.users.is_empty();
|
||||
let auth_disabled = self.server.authentication_disabled;
|
||||
|
||||
if !auth_disabled && !has_api_key && !has_accounts {
|
||||
return Err(
|
||||
"authentication is not configured: set an api_key, configure user accounts, \
|
||||
or explicitly set authentication_disabled = true"
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
// Empty API key is not allowed (must use authentication_disabled flag)
|
||||
if let Some(ref api_key) = self.server.api_key {
|
||||
if api_key.is_empty() {
|
||||
return Err("empty api_key is not allowed. To disable authentication, \
|
||||
set authentication_disabled = true instead"
|
||||
.into());
|
||||
}
|
||||
}
|
||||
|
||||
// Require TLS when authentication is enabled on non-localhost
|
||||
let is_localhost = self.server.host == "127.0.0.1"
|
||||
|| self.server.host == "localhost"
|
||||
|| self.server.host == "::1";
|
||||
|
||||
if (has_api_key || has_accounts)
|
||||
&& !auth_disabled
|
||||
&& !is_localhost
|
||||
&& !self.server.tls.enabled
|
||||
{
|
||||
return Err(
|
||||
"TLS must be enabled when authentication is used on non-localhost hosts. \
|
||||
Set server.tls.enabled = true or bind to localhost only"
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
// Validate TLS configuration
|
||||
self.server.tls.validate()?;
|
||||
Ok(())
|
||||
|
|
@ -690,6 +857,7 @@ impl Default for Config {
|
|||
host: "127.0.0.1".to_string(),
|
||||
port: 3000,
|
||||
api_key: None,
|
||||
authentication_disabled: false,
|
||||
tls: TlsConfig::default(),
|
||||
},
|
||||
ui: UiConfig::default(),
|
||||
|
|
@ -714,6 +882,7 @@ mod tests {
|
|||
fn test_config_with_concurrency(concurrency: usize) -> Config {
|
||||
let mut config = Config::default();
|
||||
config.scanning.import_concurrency = concurrency;
|
||||
config.server.authentication_disabled = true; // Disable auth for concurrency tests
|
||||
config
|
||||
}
|
||||
|
||||
|
|
@ -758,4 +927,125 @@ mod tests {
|
|||
let config = test_config_with_concurrency(256);
|
||||
assert!(config.validate().is_ok());
|
||||
}
|
||||
|
||||
// Environment variable expansion tests

// Bare `$VAR` form resolves from the process environment.
#[test]
fn test_expand_env_var_simple() {
    // set_var/remove_var mutate process-global state (hence `unsafe` in this
    // edition); each test uses a unique variable name to avoid collisions
    // when tests run in parallel.
    unsafe {
        std::env::set_var("TEST_VAR_SIMPLE", "test_value");
    }
    let result = expand_env_var_string("$TEST_VAR_SIMPLE");
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), "test_value");
    unsafe {
        std::env::remove_var("TEST_VAR_SIMPLE");
    }
}
|
||||
|
||||
// Braced `${VAR}` form resolves the same way as the bare form.
#[test]
fn test_expand_env_var_braces() {
    unsafe {
        std::env::set_var("TEST_VAR_BRACES", "test_value");
    }
    let result = expand_env_var_string("${TEST_VAR_BRACES}");
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), "test_value");
    unsafe {
        std::env::remove_var("TEST_VAR_BRACES");
    }
}
|
||||
|
||||
// A braced reference embedded in surrounding text expands in place.
#[test]
fn test_expand_env_var_embedded() {
    unsafe {
        std::env::set_var("TEST_VAR_EMBEDDED", "value");
    }
    let result = expand_env_var_string("prefix_${TEST_VAR_EMBEDDED}_suffix");
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), "prefix_value_suffix");
    unsafe {
        std::env::remove_var("TEST_VAR_EMBEDDED");
    }
}
|
||||
|
||||
// Multiple references in one string are each expanded.
#[test]
fn test_expand_env_var_multiple() {
    unsafe {
        std::env::set_var("VAR1", "value1");
        std::env::set_var("VAR2", "value2");
    }
    let result = expand_env_var_string("${VAR1}_${VAR2}");
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), "value1_value2");
    unsafe {
        std::env::remove_var("VAR1");
        std::env::remove_var("VAR2");
    }
}
|
||||
|
||||
// Referencing an unset variable is a hard error, not a silent empty string.
#[test]
fn test_expand_env_var_missing() {
    let result = expand_env_var_string("${NONEXISTENT_VAR}");
    assert!(result.is_err());
    assert!(
        result
            .unwrap_err()
            .to_string()
            .contains("environment variable not set")
    );
}
|
||||
|
||||
// `${}` with no name inside the braces is rejected.
#[test]
fn test_expand_env_var_empty_name() {
    let result = expand_env_var_string("${}");
    assert!(result.is_err());
    assert!(
        result
            .unwrap_err()
            .to_string()
            .contains("empty environment variable name")
    );
}
|
||||
|
||||
// `\$` (a backslash-dollar in the input) yields a literal `$` and suppresses
// expansion of the following name.
#[test]
fn test_expand_env_var_escaped() {
    let result = expand_env_var_string("\\$NOT_A_VAR");
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), "$NOT_A_VAR");
}
|
||||
|
||||
// Input without any `$` passes through unchanged.
#[test]
fn test_expand_env_var_no_vars() {
    let result = expand_env_var_string("plain_text");
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), "plain_text");
}
|
||||
|
||||
// Underscores are valid name characters in the bare `$VAR` form.
#[test]
fn test_expand_env_var_underscore() {
    unsafe {
        std::env::set_var("TEST_VAR_NAME", "value");
    }
    let result = expand_env_var_string("$TEST_VAR_NAME");
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), "value");
    unsafe {
        std::env::remove_var("TEST_VAR_NAME");
    }
}
|
||||
|
||||
// Bare and braced forms can be mixed in a single string.
#[test]
fn test_expand_env_var_mixed_syntax() {
    unsafe {
        std::env::set_var("VAR1_MIXED", "v1");
        std::env::set_var("VAR2_MIXED", "v2");
    }
    let result = expand_env_var_string("$VAR1_MIXED and ${VAR2_MIXED}");
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), "v1 and v2");
    unsafe {
        std::env::remove_var("VAR1_MIXED");
        std::env::remove_var("VAR2_MIXED");
    }
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use std::collections::{HashMap, HashSet};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
|
@ -5,6 +6,7 @@ use tracing::{info, warn};
|
|||
|
||||
use crate::error::Result;
|
||||
use crate::hash::compute_file_hash;
|
||||
use crate::media_type::MediaType;
|
||||
use crate::model::{ContentHash, MediaId};
|
||||
use crate::storage::DynStorageBackend;
|
||||
|
||||
|
|
@ -66,31 +68,202 @@ impl std::str::FromStr for IntegrityStatus {
|
|||
}
|
||||
}
|
||||
|
||||
/// Detect orphaned media items (files that no longer exist on disk).
|
||||
/// Detect orphaned media items (files that no longer exist on disk),
|
||||
/// untracked files (files on disk not in database), and moved files (same hash, different path).
|
||||
pub async fn detect_orphans(storage: &DynStorageBackend) -> Result<OrphanReport> {
|
||||
let media_paths = storage.list_media_paths().await?;
|
||||
let mut orphaned_ids = Vec::new();
|
||||
let moved_files = Vec::new();
|
||||
|
||||
// Build hash index: ContentHash -> Vec<(MediaId, PathBuf)>
|
||||
let mut hash_index: HashMap<ContentHash, Vec<(MediaId, PathBuf)>> = HashMap::new();
|
||||
for (id, path, hash) in &media_paths {
|
||||
hash_index
|
||||
.entry(hash.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push((*id, path.clone()));
|
||||
}
|
||||
|
||||
// Detect orphaned files (in DB but not on disk)
|
||||
for (id, path, _hash) in &media_paths {
|
||||
if !path.exists() {
|
||||
orphaned_ids.push(*id);
|
||||
}
|
||||
}
|
||||
|
||||
// Detect moved files (orphaned items with same hash existing elsewhere)
|
||||
let moved_files = detect_moved_files(&orphaned_ids, &media_paths, &hash_index);
|
||||
|
||||
// Detect untracked files (on disk but not in DB)
|
||||
let untracked_paths = detect_untracked_files(storage, &media_paths).await?;
|
||||
|
||||
info!(
|
||||
orphaned = orphaned_ids.len(),
|
||||
untracked = untracked_paths.len(),
|
||||
moved = moved_files.len(),
|
||||
total = media_paths.len(),
|
||||
"orphan detection complete"
|
||||
);
|
||||
|
||||
Ok(OrphanReport {
|
||||
orphaned_ids,
|
||||
untracked_paths: Vec::new(),
|
||||
untracked_paths,
|
||||
moved_files,
|
||||
})
|
||||
}
|
||||
|
||||
/// Detect files that appear to have moved (same content hash, different path).
|
||||
fn detect_moved_files(
|
||||
orphaned_ids: &[MediaId],
|
||||
media_paths: &[(MediaId, PathBuf, ContentHash)],
|
||||
hash_index: &HashMap<ContentHash, Vec<(MediaId, PathBuf)>>,
|
||||
) -> Vec<(MediaId, PathBuf, PathBuf)> {
|
||||
let mut moved = Vec::new();
|
||||
|
||||
// Build lookup map for orphaned items: MediaId -> (PathBuf, ContentHash)
|
||||
let orphaned_map: HashMap<MediaId, (PathBuf, ContentHash)> = media_paths
|
||||
.iter()
|
||||
.filter(|(id, _, _)| orphaned_ids.contains(id))
|
||||
.map(|(id, path, hash)| (*id, (path.clone(), hash.clone())))
|
||||
.collect();
|
||||
|
||||
// For each orphaned item, check if there's another file with the same hash
|
||||
for (orphaned_id, (old_path, hash)) in &orphaned_map {
|
||||
if let Some(items_with_hash) = hash_index.get(hash) {
|
||||
// Find other items with same hash that exist on disk
|
||||
for (other_id, new_path) in items_with_hash {
|
||||
// Skip if it's the same item
|
||||
if other_id == orphaned_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if the new path exists
|
||||
if new_path.exists() {
|
||||
moved.push((*orphaned_id, old_path.clone(), new_path.clone()));
|
||||
// Only report first match (most likely candidate)
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
moved
|
||||
}
|
||||
|
||||
/// Detect files on disk that are not tracked in the database.
///
/// Walks every configured root directory, collects paths of supported media
/// files, and returns those that do not appear in `media_paths`.
async fn detect_untracked_files(
    storage: &DynStorageBackend,
    media_paths: &[(MediaId, PathBuf, ContentHash)],
) -> Result<Vec<PathBuf>> {
    // Get root directories; nothing to scan means nothing can be untracked.
    let roots = storage.list_root_dirs().await?;
    if roots.is_empty() {
        return Ok(Vec::new());
    }

    // Build set of tracked paths for fast lookup
    let tracked_paths: HashSet<PathBuf> = media_paths
        .iter()
        .map(|(_, path, _)| path.clone())
        .collect();

    // Ignore patterns are hard-coded for now; the original note says these
    // should eventually come from config. TODO: load from config.
    let ignore_patterns: Vec<String> = vec![
        ".*".to_string(),
        "node_modules".to_string(),
        "__pycache__".to_string(),
        "target".to_string(),
    ];

    // Walk each root on the blocking thread pool, one task per root.
    // NOTE(review): there is no explicit concurrency cap here beyond the
    // runtime's blocking pool limit (an earlier comment claimed a limit of 4).
    let mut filesystem_paths = HashSet::new();
    let mut tasks = tokio::task::JoinSet::new();

    for root in roots {
        // Clone so the 'static closure owns its copy.
        let ignore_patterns = ignore_patterns.clone();
        tasks.spawn_blocking(move || -> Result<Vec<PathBuf>> {
            let mut paths = Vec::new();

            let walker = walkdir::WalkDir::new(&root)
                .follow_links(false)
                .into_iter()
                .filter_entry(|e| {
                    // Ignore patterns are applied to DIRECTORY names only;
                    // pruned directories are never descended into. Files
                    // always pass this filter.
                    if e.file_type().is_dir() {
                        let name = e.file_name().to_string_lossy();
                        for pattern in &ignore_patterns {
                            if pattern.starts_with("*.") {
                                // Extension pattern: "*.bak" prunes dirs ending in "bak"
                                if let Some(ext) = pattern.strip_prefix("*.") {
                                    if name.ends_with(ext) {
                                        return false;
                                    }
                                }
                            } else if pattern.contains('*') {
                                // Glob pattern - simplified: strip the stars and
                                // prune on substring containment.
                                let pattern_without_stars = pattern.replace('*', "");
                                if name.contains(&pattern_without_stars) {
                                    return false;
                                }
                            } else if name.as_ref() == pattern
                                || name.starts_with(&format!("{pattern}."))
                            {
                                // Exact match, or pattern followed by a dot
                                // (so ".*" prunes all dot-directories).
                                return false;
                            }
                        }
                    }
                    true
                });

            for entry in walker {
                match entry {
                    Ok(entry) => {
                        let path = entry.path();

                        // Only process files
                        if !path.is_file() {
                            continue;
                        }

                        // Keep only paths whose extension maps to a known media type.
                        if MediaType::from_path(path).is_some() {
                            paths.push(path.to_path_buf());
                        }
                    }
                    Err(e) => {
                        // Unreadable entries are logged and skipped, not fatal.
                        warn!(error = %e, "failed to read directory entry");
                    }
                }
            }

            Ok(paths)
        });
    }

    // Collect results from all tasks; per-root failures are logged and the
    // remaining roots still contribute.
    while let Some(result) = tasks.join_next().await {
        match result {
            Ok(Ok(paths)) => {
                filesystem_paths.extend(paths);
            }
            Ok(Err(e)) => {
                warn!(error = %e, "failed to walk directory");
            }
            Err(e) => {
                warn!(error = %e, "task join error");
            }
        }
    }

    // Compute set difference: filesystem - tracked
    let untracked: Vec<PathBuf> = filesystem_paths
        .difference(&tracked_paths)
        .cloned()
        .collect();

    Ok(untracked)
}
|
||||
|
||||
/// Resolve orphaned media items by deleting them from the database.
|
||||
pub async fn resolve_orphans(
|
||||
storage: &DynStorageBackend,
|
||||
|
|
|
|||
|
|
@ -31,6 +31,18 @@ pub struct DatabaseStats {
|
|||
pub backend_name: String,
|
||||
}
|
||||
|
||||
/// Session data for database-backed session storage.
#[derive(Debug, Clone)]
pub struct SessionData {
    /// Opaque token used as the lookup key for the session.
    pub session_token: String,
    /// Backing user id, if any (presumably absent for non-account sessions —
    /// TODO confirm against the auth layer).
    pub user_id: Option<String>,
    /// Username the session was issued for.
    pub username: String,
    /// Role granted to this session, stored as a plain string.
    pub role: String,
    /// When the session was created.
    pub created_at: DateTime<Utc>,
    /// When the session ceases to be valid.
    pub expires_at: DateTime<Utc>,
    /// Last use of the session; refreshed by `touch_session`.
    pub last_accessed: DateTime<Utc>,
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait StorageBackend: Send + Sync + 'static {
|
||||
// Migrations
|
||||
|
|
@ -412,6 +424,28 @@ pub trait StorageBackend: Send + Sync + 'static {
|
|||
progress: f32,
|
||||
) -> Result<()>;
|
||||
async fn cleanup_expired_transcodes(&self, before: DateTime<Utc>) -> Result<u64>;
|
||||
|
||||
// ===== Session Management =====
|
||||
/// Create a new session in the database
|
||||
async fn create_session(&self, session: &SessionData) -> Result<()>;
|
||||
|
||||
/// Get a session by its token, returns None if not found or expired
|
||||
async fn get_session(&self, session_token: &str) -> Result<Option<SessionData>>;
|
||||
|
||||
/// Update the last_accessed timestamp for a session
|
||||
async fn touch_session(&self, session_token: &str) -> Result<()>;
|
||||
|
||||
/// Delete a specific session
|
||||
async fn delete_session(&self, session_token: &str) -> Result<()>;
|
||||
|
||||
/// Delete all sessions for a specific user
|
||||
async fn delete_user_sessions(&self, username: &str) -> Result<u64>;
|
||||
|
||||
/// Delete all expired sessions (where expires_at < now)
|
||||
async fn delete_expired_sessions(&self) -> Result<u64>;
|
||||
|
||||
/// List all active sessions (optionally filtered by username)
|
||||
async fn list_active_sessions(&self, username: Option<&str>) -> Result<Vec<SessionData>>;
|
||||
}
|
||||
|
||||
/// Comprehensive library statistics.
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ use std::path::PathBuf;
|
|||
|
||||
use chrono::Utc;
|
||||
use deadpool_postgres::{Config as PoolConfig, Pool, Runtime};
|
||||
use native_tls::TlsConnector;
|
||||
use postgres_native_tls::MakeTlsConnector;
|
||||
use tokio_postgres::types::ToSql;
|
||||
use tokio_postgres::{NoTls, Row};
|
||||
use uuid::Uuid;
|
||||
|
|
@ -27,19 +29,72 @@ impl PostgresBackend {
|
|||
pool_config.user = Some(config.username.clone());
|
||||
pool_config.password = Some(config.password.clone());
|
||||
|
||||
let pool = pool_config
|
||||
.create_pool(Some(Runtime::Tokio1), NoTls)
|
||||
.map_err(|e| {
|
||||
PinakesError::Database(format!("failed to create connection pool: {e}"))
|
||||
if config.tls_enabled {
|
||||
// Build TLS connector
|
||||
let mut tls_builder = TlsConnector::builder();
|
||||
|
||||
// Load custom CA certificate if provided
|
||||
if let Some(ref ca_cert_path) = config.tls_ca_cert_path {
|
||||
let cert_bytes = std::fs::read(ca_cert_path).map_err(|e| {
|
||||
PinakesError::Config(format!(
|
||||
"failed to read CA certificate file {}: {e}",
|
||||
ca_cert_path.display()
|
||||
))
|
||||
})?;
|
||||
let cert = native_tls::Certificate::from_pem(&cert_bytes).map_err(|e| {
|
||||
PinakesError::Config(format!(
|
||||
"failed to parse CA certificate {}: {e}",
|
||||
ca_cert_path.display()
|
||||
))
|
||||
})?;
|
||||
tls_builder.add_root_certificate(cert);
|
||||
}
|
||||
|
||||
// Configure certificate validation
|
||||
if !config.tls_verify_ca {
|
||||
tracing::warn!(
|
||||
"PostgreSQL TLS certificate verification disabled - this is insecure!"
|
||||
);
|
||||
tls_builder.danger_accept_invalid_certs(true);
|
||||
}
|
||||
|
||||
let connector = tls_builder.build().map_err(|e| {
|
||||
PinakesError::Database(format!("failed to build TLS connector: {e}"))
|
||||
})?;
|
||||
let tls = MakeTlsConnector::new(connector);
|
||||
|
||||
let pool = pool_config
|
||||
.create_pool(Some(Runtime::Tokio1), tls)
|
||||
.map_err(|e| {
|
||||
PinakesError::Database(format!("failed to create connection pool: {e}"))
|
||||
})?;
|
||||
|
||||
// Verify connectivity
|
||||
let _ = pool.get().await.map_err(|e| {
|
||||
PinakesError::Database(format!("failed to connect to postgres: {e}"))
|
||||
})?;
|
||||
|
||||
// Verify connectivity
|
||||
let _ = pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| PinakesError::Database(format!("failed to connect to postgres: {e}")))?;
|
||||
tracing::info!("PostgreSQL connection established with TLS");
|
||||
Ok(Self { pool })
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"PostgreSQL TLS is disabled - connection is unencrypted. \
|
||||
Set postgres.tls_enabled = true to enable encryption."
|
||||
);
|
||||
|
||||
Ok(Self { pool })
|
||||
let pool = pool_config
|
||||
.create_pool(Some(Runtime::Tokio1), NoTls)
|
||||
.map_err(|e| {
|
||||
PinakesError::Database(format!("failed to create connection pool: {e}"))
|
||||
})?;
|
||||
|
||||
// Verify connectivity
|
||||
let _ = pool.get().await.map_err(|e| {
|
||||
PinakesError::Database(format!("failed to connect to postgres: {e}"))
|
||||
})?;
|
||||
|
||||
Ok(Self { pool })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3229,6 +3284,167 @@ impl StorageBackend for PostgresBackend {
|
|||
.await?;
|
||||
Ok(affected)
|
||||
}
|
||||
|
||||
// ===== Session Management =====

/// Insert a new session row into the `sessions` table.
async fn create_session(&self, session: &crate::storage::SessionData) -> Result<()> {
    let client = self
        .pool
        .get()
        .await
        .map_err(|e| PinakesError::Database(format!("pool error: {e}")))?;

    // Timestamps are bound as native values; tokio-postgres maps
    // DateTime<Utc> to TIMESTAMPTZ.
    client
        .execute(
            "INSERT INTO sessions (session_token, user_id, username, role, created_at, expires_at, last_accessed)
             VALUES ($1, $2, $3, $4, $5, $6, $7)",
            &[
                &session.session_token,
                &session.user_id,
                &session.username,
                &session.role,
                &session.created_at,
                &session.expires_at,
                &session.last_accessed,
            ],
        )
        .await?;
    Ok(())
}
|
||||
|
||||
async fn get_session(
|
||||
&self,
|
||||
session_token: &str,
|
||||
) -> Result<Option<crate::storage::SessionData>> {
|
||||
let client = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| PinakesError::Database(format!("pool error: {e}")))?;
|
||||
|
||||
let row = client
|
||||
.query_opt(
|
||||
"SELECT session_token, user_id, username, role, created_at, expires_at, last_accessed
|
||||
FROM sessions WHERE session_token = $1",
|
||||
&[&session_token],
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(row.map(|r| crate::storage::SessionData {
|
||||
session_token: r.get(0),
|
||||
user_id: r.get(1),
|
||||
username: r.get(2),
|
||||
role: r.get(3),
|
||||
created_at: r.get(4),
|
||||
expires_at: r.get(5),
|
||||
last_accessed: r.get(6),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn touch_session(&self, session_token: &str) -> Result<()> {
|
||||
let client = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| PinakesError::Database(format!("pool error: {e}")))?;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
client
|
||||
.execute(
|
||||
"UPDATE sessions SET last_accessed = $1 WHERE session_token = $2",
|
||||
&[&now, &session_token],
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_session(&self, session_token: &str) -> Result<()> {
|
||||
let client = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| PinakesError::Database(format!("pool error: {e}")))?;
|
||||
|
||||
client
|
||||
.execute(
|
||||
"DELETE FROM sessions WHERE session_token = $1",
|
||||
&[&session_token],
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete every session belonging to `username`; returns the number of rows removed.
async fn delete_user_sessions(&self, username: &str) -> Result<u64> {
    let client = self
        .pool
        .get()
        .await
        .map_err(|e| PinakesError::Database(format!("pool error: {e}")))?;

    let affected = client
        .execute("DELETE FROM sessions WHERE username = $1", &[&username])
        .await?;
    Ok(affected)
}
|
||||
|
||||
/// Purge sessions whose expiry is already in the past; returns rows removed.
async fn delete_expired_sessions(&self) -> Result<u64> {
    let client = self
        .pool
        .get()
        .await
        .map_err(|e| PinakesError::Database(format!("pool error: {e}")))?;

    let now = chrono::Utc::now();
    let affected = client
        .execute("DELETE FROM sessions WHERE expires_at < $1", &[&now])
        .await?;
    Ok(affected)
}
|
||||
|
||||
/// List sessions that have not yet expired, newest activity first,
/// optionally restricted to a single username.
async fn list_active_sessions(
    &self,
    username: Option<&str>,
) -> Result<Vec<crate::storage::SessionData>> {
    let client = self
        .pool
        .get()
        .await
        .map_err(|e| PinakesError::Database(format!("pool error: {e}")))?;

    // Two near-identical statements: tokio-postgres parameters are
    // positional, so the username-filtered variant needs its own query text.
    let now = chrono::Utc::now();
    let rows = if let Some(user) = username {
        client
            .query(
                "SELECT session_token, user_id, username, role, created_at, expires_at, last_accessed
                 FROM sessions WHERE expires_at > $1 AND username = $2
                 ORDER BY last_accessed DESC",
                &[&now, &user],
            )
            .await?
    } else {
        client
            .query(
                "SELECT session_token, user_id, username, role, created_at, expires_at, last_accessed
                 FROM sessions WHERE expires_at > $1
                 ORDER BY last_accessed DESC",
                &[&now],
            )
            .await?
    };

    // Column order matches the SELECT list above.
    Ok(rows
        .into_iter()
        .map(|r| crate::storage::SessionData {
            session_token: r.get(0),
            user_id: r.get(1),
            username: r.get(2),
            role: r.get(3),
            created_at: r.get(4),
            expires_at: r.get(5),
            last_accessed: r.get(6),
        })
        .collect())
}
|
||||
}
|
||||
|
||||
impl PostgresBackend {
|
||||
|
|
|
|||
|
|
@ -3580,6 +3580,227 @@ impl StorageBackend for SqliteBackend {
|
|||
.map_err(|_| PinakesError::Database("cleanup_expired_transcodes timed out".into()))?
|
||||
.map_err(|e: tokio::task::JoinError| PinakesError::Database(e.to_string()))?
|
||||
}
|
||||
|
||||
// ===== Session Management =====

/// Insert a new session row. Timestamps are stored as RFC 3339 strings,
/// matching how `get_session` parses them back.
async fn create_session(&self, session: &crate::storage::SessionData) -> Result<()> {
    // Clone everything up front: the spawn_blocking closure must be 'static.
    let conn = self.conn.clone();
    let session_token = session.session_token.clone();
    let user_id = session.user_id.clone();
    let username = session.username.clone();
    let role = session.role.clone();
    let created_at = session.created_at.to_rfc3339();
    let expires_at = session.expires_at.to_rfc3339();
    let last_accessed = session.last_accessed.to_rfc3339();

    let fut = tokio::task::spawn_blocking(move || {
        let db = conn.lock().map_err(|e| {
            PinakesError::Database(format!("failed to acquire database lock: {}", e))
        })?;
        db.execute(
            "INSERT INTO sessions (session_token, user_id, username, role, created_at, expires_at, last_accessed)
             VALUES (?, ?, ?, ?, ?, ?, ?)",
            params![
                &session_token,
                &user_id,
                &username,
                &role,
                &created_at,
                &expires_at,
                &last_accessed
            ],
        )?;
        Ok(())
    });
    // First `?` handles the timeout, second handles the JoinError; the final
    // expression is the closure's own Result.
    tokio::time::timeout(std::time::Duration::from_secs(10), fut)
        .await
        .map_err(|_| PinakesError::Database("create_session timed out".into()))?
        .map_err(|e: tokio::task::JoinError| PinakesError::Database(e.to_string()))?
}
|
||||
|
||||
async fn get_session(
|
||||
&self,
|
||||
session_token: &str,
|
||||
) -> Result<Option<crate::storage::SessionData>> {
|
||||
let conn = self.conn.clone();
|
||||
let token = session_token.to_string();
|
||||
|
||||
let fut = tokio::task::spawn_blocking(move || {
|
||||
let db = conn.lock().map_err(|e| {
|
||||
PinakesError::Database(format!("failed to acquire database lock: {}", e))
|
||||
})?;
|
||||
|
||||
let result = db
|
||||
.query_row(
|
||||
"SELECT session_token, user_id, username, role, created_at, expires_at, last_accessed
|
||||
FROM sessions WHERE session_token = ?",
|
||||
[&token],
|
||||
|row| {
|
||||
let created_at_str: String = row.get(4)?;
|
||||
let expires_at_str: String = row.get(5)?;
|
||||
let last_accessed_str: String = row.get(6)?;
|
||||
|
||||
Ok(crate::storage::SessionData {
|
||||
session_token: row.get(0)?,
|
||||
user_id: row.get(1)?,
|
||||
username: row.get(2)?,
|
||||
role: row.get(3)?,
|
||||
created_at: chrono::DateTime::parse_from_rfc3339(&created_at_str)
|
||||
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
|
||||
.with_timezone(&chrono::Utc),
|
||||
expires_at: chrono::DateTime::parse_from_rfc3339(&expires_at_str)
|
||||
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
|
||||
.with_timezone(&chrono::Utc),
|
||||
last_accessed: chrono::DateTime::parse_from_rfc3339(&last_accessed_str)
|
||||
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
|
||||
.with_timezone(&chrono::Utc),
|
||||
})
|
||||
},
|
||||
)
|
||||
.optional()?;
|
||||
|
||||
Ok(result)
|
||||
});
|
||||
tokio::time::timeout(std::time::Duration::from_secs(10), fut)
|
||||
.await
|
||||
.map_err(|_| PinakesError::Database("get_session timed out".into()))?
|
||||
.map_err(|e: tokio::task::JoinError| PinakesError::Database(e.to_string()))?
|
||||
}
|
||||
|
||||
async fn touch_session(&self, session_token: &str) -> Result<()> {
|
||||
let conn = self.conn.clone();
|
||||
let token = session_token.to_string();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
let fut = tokio::task::spawn_blocking(move || {
|
||||
let db = conn.lock().map_err(|e| {
|
||||
PinakesError::Database(format!("failed to acquire database lock: {}", e))
|
||||
})?;
|
||||
db.execute(
|
||||
"UPDATE sessions SET last_accessed = ? WHERE session_token = ?",
|
||||
params![&now, &token],
|
||||
)?;
|
||||
Ok(())
|
||||
});
|
||||
tokio::time::timeout(std::time::Duration::from_secs(10), fut)
|
||||
.await
|
||||
.map_err(|_| PinakesError::Database("touch_session timed out".into()))?
|
||||
.map_err(|e: tokio::task::JoinError| PinakesError::Database(e.to_string()))?
|
||||
}
|
||||
|
||||
async fn delete_session(&self, session_token: &str) -> Result<()> {
|
||||
let conn = self.conn.clone();
|
||||
let token = session_token.to_string();
|
||||
|
||||
let fut = tokio::task::spawn_blocking(move || {
|
||||
let db = conn.lock().map_err(|e| {
|
||||
PinakesError::Database(format!("failed to acquire database lock: {}", e))
|
||||
})?;
|
||||
db.execute("DELETE FROM sessions WHERE session_token = ?", [&token])?;
|
||||
Ok(())
|
||||
});
|
||||
tokio::time::timeout(std::time::Duration::from_secs(10), fut)
|
||||
.await
|
||||
.map_err(|_| PinakesError::Database("delete_session timed out".into()))?
|
||||
.map_err(|e: tokio::task::JoinError| PinakesError::Database(e.to_string()))?
|
||||
}
|
||||
|
||||
async fn delete_user_sessions(&self, username: &str) -> Result<u64> {
|
||||
let conn = self.conn.clone();
|
||||
let user = username.to_string();
|
||||
|
||||
let fut = tokio::task::spawn_blocking(move || {
|
||||
let db = conn.lock().map_err(|e| {
|
||||
PinakesError::Database(format!("failed to acquire database lock: {}", e))
|
||||
})?;
|
||||
let affected = db.execute("DELETE FROM sessions WHERE username = ?", [&user])?;
|
||||
Ok(affected as u64)
|
||||
});
|
||||
tokio::time::timeout(std::time::Duration::from_secs(10), fut)
|
||||
.await
|
||||
.map_err(|_| PinakesError::Database("delete_user_sessions timed out".into()))?
|
||||
.map_err(|e: tokio::task::JoinError| PinakesError::Database(e.to_string()))?
|
||||
}
|
||||
|
||||
async fn delete_expired_sessions(&self) -> Result<u64> {
|
||||
let conn = self.conn.clone();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
let fut = tokio::task::spawn_blocking(move || {
|
||||
let db = conn.lock().map_err(|e| {
|
||||
PinakesError::Database(format!("failed to acquire database lock: {}", e))
|
||||
})?;
|
||||
let affected = db.execute("DELETE FROM sessions WHERE expires_at < ?", [&now])?;
|
||||
Ok(affected as u64)
|
||||
});
|
||||
tokio::time::timeout(std::time::Duration::from_secs(10), fut)
|
||||
.await
|
||||
.map_err(|_| PinakesError::Database("delete_expired_sessions timed out".into()))?
|
||||
.map_err(|e: tokio::task::JoinError| PinakesError::Database(e.to_string()))?
|
||||
}
|
||||
|
||||
async fn list_active_sessions(
|
||||
&self,
|
||||
username: Option<&str>,
|
||||
) -> Result<Vec<crate::storage::SessionData>> {
|
||||
let conn = self.conn.clone();
|
||||
let user_filter = username.map(|s| s.to_string());
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
let fut = tokio::task::spawn_blocking(move || {
|
||||
let db = conn.lock().map_err(|e| {
|
||||
PinakesError::Database(format!("failed to acquire database lock: {}", e))
|
||||
})?;
|
||||
|
||||
let (query, params): (&str, Vec<String>) = if let Some(user) = user_filter {
|
||||
(
|
||||
"SELECT session_token, user_id, username, role, created_at, expires_at, last_accessed
|
||||
FROM sessions WHERE expires_at > ? AND username = ?
|
||||
ORDER BY last_accessed DESC",
|
||||
vec![now, user],
|
||||
)
|
||||
} else {
|
||||
(
|
||||
"SELECT session_token, user_id, username, role, created_at, expires_at, last_accessed
|
||||
FROM sessions WHERE expires_at > ?
|
||||
ORDER BY last_accessed DESC",
|
||||
vec![now],
|
||||
)
|
||||
};
|
||||
|
||||
let mut stmt = db.prepare(query)?;
|
||||
let param_refs: Vec<&dyn rusqlite::ToSql> =
|
||||
params.iter().map(|p| p as &dyn rusqlite::ToSql).collect();
|
||||
let rows = stmt.query_map(¶m_refs[..], |row| {
|
||||
let created_at_str: String = row.get(4)?;
|
||||
let expires_at_str: String = row.get(5)?;
|
||||
let last_accessed_str: String = row.get(6)?;
|
||||
|
||||
Ok(crate::storage::SessionData {
|
||||
session_token: row.get(0)?,
|
||||
user_id: row.get(1)?,
|
||||
username: row.get(2)?,
|
||||
role: row.get(3)?,
|
||||
created_at: chrono::DateTime::parse_from_rfc3339(&created_at_str)
|
||||
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
|
||||
.with_timezone(&chrono::Utc),
|
||||
expires_at: chrono::DateTime::parse_from_rfc3339(&expires_at_str)
|
||||
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
|
||||
.with_timezone(&chrono::Utc),
|
||||
last_accessed: chrono::DateTime::parse_from_rfc3339(&last_accessed_str)
|
||||
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
|
||||
.with_timezone(&chrono::Utc),
|
||||
})
|
||||
})?;
|
||||
|
||||
rows.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.map_err(|e| e.into())
|
||||
});
|
||||
tokio::time::timeout(std::time::Duration::from_secs(10), fut)
|
||||
.await
|
||||
.map_err(|_| PinakesError::Database("list_active_sessions timed out".into()))?
|
||||
.map_err(|e: tokio::task::JoinError| PinakesError::Database(e.to_string()))?
|
||||
}
|
||||
}
|
||||
|
||||
// Needed for `query_row(...).optional()`
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue