pinakes: import in parallel; various UI improvements

Signed-off-by: NotAShelf <raf@notashelf.dev>
Change-Id: I1eb47cd79cd4145c56af966f6756fe1d6a6a6964
This commit is contained in:
raf 2026-02-03 10:31:20 +03:00
commit 116fe7b059
Signed by: NotAShelf
GPG key ID: 29D95B64378DB4BF
42 changed files with 4316 additions and 316 deletions

View file

@ -5,16 +5,27 @@ use axum::extract::DefaultBodyLimit;
use axum::http::{HeaderValue, Method, header};
use axum::middleware;
use axum::routing::{delete, get, patch, post, put};
use tower::ServiceBuilder;
use tower_governor::GovernorLayer;
use tower_governor::governor::GovernorConfigBuilder;
use tower_http::cors::CorsLayer;
use tower_http::set_header::SetResponseHeaderLayer;
use tower_http::trace::TraceLayer;
use crate::auth;
use crate::routes;
use crate::state::AppState;
/// Create the router with optional TLS configuration for HSTS headers
pub fn create_router(state: AppState) -> Router {
create_router_with_tls(state, None)
}
/// Create the router with TLS configuration for security headers
pub fn create_router_with_tls(
state: AppState,
tls_config: Option<&pinakes_core::config::TlsConfig>,
) -> Router {
// Global rate limit: 100 requests/sec per IP
let global_governor = Arc::new(
GovernorConfigBuilder::default()
@ -41,11 +52,16 @@ pub fn create_router(state: AppState) -> Router {
});
// Public routes (no auth required)
let public_routes = Router::new().route("/s/{token}", get(routes::social::access_shared_media));
let public_routes = Router::new()
.route("/s/{token}", get(routes::social::access_shared_media))
// Kubernetes-style health probes (no auth required for orchestration)
.route("/health/live", get(routes::health::liveness))
.route("/health/ready", get(routes::health::readiness));
// Read-only routes: any authenticated user (Viewer+)
let viewer_routes = Router::new()
.route("/health", get(routes::health::health))
.route("/health/detailed", get(routes::health::health_detailed))
.route("/media/count", get(routes::media::get_media_count))
.route("/media", get(routes::media::list_media))
.route("/media/{id}", get(routes::media::get_media))
@ -393,7 +409,40 @@ pub fn create_router(state: AppState) -> Router {
.merge(public_routes)
.merge(protected_api);
Router::new()
// Build security headers layer
let security_headers = ServiceBuilder::new()
// Prevent MIME type sniffing
.layer(SetResponseHeaderLayer::overriding(
header::X_CONTENT_TYPE_OPTIONS,
HeaderValue::from_static("nosniff"),
))
// Prevent clickjacking
.layer(SetResponseHeaderLayer::overriding(
header::X_FRAME_OPTIONS,
HeaderValue::from_static("DENY"),
))
// XSS protection (legacy but still useful for older browsers)
.layer(SetResponseHeaderLayer::overriding(
header::HeaderName::from_static("x-xss-protection"),
HeaderValue::from_static("1; mode=block"),
))
// Referrer policy
.layer(SetResponseHeaderLayer::overriding(
header::REFERRER_POLICY,
HeaderValue::from_static("strict-origin-when-cross-origin"),
))
// Permissions policy (disable unnecessary features)
.layer(SetResponseHeaderLayer::overriding(
header::HeaderName::from_static("permissions-policy"),
HeaderValue::from_static("geolocation=(), microphone=(), camera=()"),
))
// Content Security Policy for API responses
.layer(SetResponseHeaderLayer::overriding(
header::CONTENT_SECURITY_POLICY,
HeaderValue::from_static("default-src 'none'; frame-ancestors 'none'"),
));
let router = Router::new()
.nest("/api/v1", full_api)
.layer(DefaultBodyLimit::max(10 * 1024 * 1024))
.layer(GovernorLayer {
@ -401,5 +450,26 @@ pub fn create_router(state: AppState) -> Router {
})
.layer(TraceLayer::new_for_http())
.layer(cors)
.with_state(state)
.layer(security_headers);
// Add HSTS header when TLS is enabled
if let Some(tls) = tls_config {
if tls.enabled && tls.hsts_enabled {
let hsts_value = format!("max-age={}; includeSubDomains", tls.hsts_max_age);
let hsts_header = HeaderValue::from_str(&hsts_value).unwrap_or_else(|_| {
HeaderValue::from_static("max-age=31536000; includeSubDomains")
});
router
.layer(SetResponseHeaderLayer::overriding(
header::STRICT_TRANSPORT_SECURITY,
hsts_header,
))
.with_state(state)
} else {
router.with_state(state)
}
} else {
router.with_state(state)
}
}

View file

@ -2,6 +2,9 @@ use std::path::PathBuf;
use std::sync::Arc;
use anyhow::Result;
use axum::Router;
use axum::response::Redirect;
use axum::routing::any;
use clap::Parser;
use tokio::sync::RwLock;
use tracing::info;
@ -202,6 +205,7 @@ async fn main() -> Result<()> {
scanning: false,
files_found: total_found,
files_processed: total_processed,
files_skipped: 0,
errors: all_errors,
}
})
@ -459,7 +463,7 @@ async fn main() -> Result<()> {
let state = AppState {
storage: storage.clone(),
config: config_arc,
config: config_arc.clone(),
config_path: Some(config_path),
scan_progress: pinakes_core::scan::ScanProgress::new(),
sessions: Arc::new(RwLock::new(std::collections::HashMap::new())),
@ -489,23 +493,124 @@ async fn main() -> Result<()> {
});
}
let router = app::create_router(state);
let config_read = config_arc.read().await;
let tls_config = config_read.server.tls.clone();
drop(config_read);
info!(addr = %addr, "server listening");
let listener = tokio::net::TcpListener::bind(&addr).await?;
// Create router with TLS config for HSTS headers
let router = if tls_config.enabled {
app::create_router_with_tls(state, Some(&tls_config))
} else {
app::create_router(state)
};
axum::serve(
listener,
router.into_make_service_with_connect_info::<std::net::SocketAddr>(),
)
.with_graceful_shutdown(shutdown_signal())
.await?;
if tls_config.enabled {
// TLS/HTTPS mode
let cert_path = tls_config
.cert_path
.as_ref()
.ok_or_else(|| anyhow::anyhow!("TLS enabled but cert_path not specified"))?;
let key_path = tls_config
.key_path
.as_ref()
.ok_or_else(|| anyhow::anyhow!("TLS enabled but key_path not specified"))?;
info!(addr = %addr, cert = %cert_path.display(), "server listening with TLS");
// Configure TLS
let tls_config_builder =
axum_server::tls_rustls::RustlsConfig::from_pem_file(cert_path, key_path).await?;
// Start HTTP redirect server if configured
if tls_config.redirect_http {
let http_addr = format!(
"{}:{}",
config_arc.read().await.server.host,
tls_config.http_port
);
let https_port = config_arc.read().await.server.port;
let https_host = config_arc.read().await.server.host.clone();
let redirect_router = create_https_redirect_router(https_host, https_port);
let shutdown = shutdown_token.clone();
tokio::spawn(async move {
let listener = match tokio::net::TcpListener::bind(&http_addr).await {
Ok(l) => l,
Err(e) => {
tracing::warn!(error = %e, addr = %http_addr, "failed to bind HTTP redirect listener");
return;
}
};
info!(addr = %http_addr, "HTTP redirect server listening");
let server = axum::serve(
listener,
redirect_router.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
tokio::select! {
result = server => {
if let Err(e) = result {
tracing::warn!(error = %e, "HTTP redirect server error");
}
}
_ = shutdown.cancelled() => {
info!("HTTP redirect server shutting down");
}
}
});
}
// Start HTTPS server with graceful shutdown via Handle
let addr_parsed: std::net::SocketAddr = addr.parse()?;
let handle = axum_server::Handle::new();
let shutdown_handle = handle.clone();
// Spawn a task to trigger graceful shutdown
tokio::spawn(async move {
shutdown_signal().await;
shutdown_handle.graceful_shutdown(Some(std::time::Duration::from_secs(30)));
});
axum_server::bind_rustls(addr_parsed, tls_config_builder)
.handle(handle)
.serve(router.into_make_service_with_connect_info::<std::net::SocketAddr>())
.await?;
} else {
// Plain HTTP mode
info!(addr = %addr, "server listening");
let listener = tokio::net::TcpListener::bind(&addr).await?;
axum::serve(
listener,
router.into_make_service_with_connect_info::<std::net::SocketAddr>(),
)
.with_graceful_shutdown(shutdown_signal())
.await?;
}
shutdown_token.cancel();
info!("server shut down");
Ok(())
}
/// Create a router that redirects all HTTP requests to HTTPS
fn create_https_redirect_router(https_host: String, https_port: u16) -> Router {
Router::new().fallback(any(move |uri: axum::http::Uri| {
let https_host = https_host.clone();
async move {
let path_and_query = uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");
let https_url = if https_port == 443 {
format!("https://{}{}", https_host, path_and_query)
} else {
format!("https://{}:{}{}", https_host, https_port, path_and_query)
};
Redirect::permanent(&https_url)
}
}))
}
async fn shutdown_signal() {
let ctrl_c = async {
match tokio::signal::ctrl_c().await {

View file

@ -5,6 +5,12 @@ use axum::http::{HeaderMap, StatusCode};
use crate::dto::{LoginRequest, LoginResponse, UserInfoResponse};
use crate::state::AppState;
/// Dummy password hash to use for timing-safe comparison when user doesn't exist.
/// This is a valid argon2 hash that will always fail verification but takes
/// similar time to verify as a real hash, preventing timing attacks that could
/// reveal whether a username exists.
const DUMMY_HASH: &str = "$argon2id$v=19$m=19456,t=2,p=1$VGltaW5nU2FmZUR1bW15$c2ltdWxhdGVkX2hhc2hfZm9yX3RpbWluZ19zYWZldHk";
pub async fn login(
State(state): State<AppState>,
Json(req): Json<LoginRequest>,
@ -25,27 +31,47 @@ pub async fn login(
.iter()
.find(|u| u.username == req.username);
let user = match user {
Some(u) => u,
None => {
tracing::warn!(username = %req.username, "login failed: unknown user");
return Err(StatusCode::UNAUTHORIZED);
}
// Always perform password verification to prevent timing attacks.
// If the user doesn't exist, we verify against a dummy hash to ensure
// consistent response times regardless of whether the username exists.
use argon2::password_hash::PasswordVerifier;
let (hash_to_verify, user_found) = match user {
Some(u) => (&u.password_hash as &str, true),
None => (DUMMY_HASH, false),
};
// Verify password using argon2
use argon2::password_hash::PasswordVerifier;
let hash = &user.password_hash;
let parsed_hash = argon2::password_hash::PasswordHash::new(hash)
let parsed_hash = argon2::password_hash::PasswordHash::new(hash_to_verify)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let valid = argon2::Argon2::default()
let password_valid = argon2::Argon2::default()
.verify_password(req.password.as_bytes(), &parsed_hash)
.is_ok();
if !valid {
tracing::warn!(username = %req.username, "login failed: invalid password");
// Authentication fails if user wasn't found OR password was invalid
if !user_found || !password_valid {
// Log different messages for debugging but return same error
if !user_found {
tracing::warn!(username = %req.username, "login failed: unknown user");
} else {
tracing::warn!(username = %req.username, "login failed: invalid password");
}
// Record failed login attempt in audit log
let _ = pinakes_core::audit::record_action(
&state.storage,
None,
pinakes_core::model::AuditAction::LoginFailed,
Some(format!("username: {}", req.username)),
)
.await;
return Err(StatusCode::UNAUTHORIZED);
}
// At this point we know the user exists and password is valid
let user = user.expect("user should exist at this point");
// Generate session token
use rand::Rng;
let token: String = rand::rng()
@ -72,6 +98,15 @@ pub async fn login(
tracing::info!(username = %username, role = %role, "login successful");
// Record successful login in audit log
let _ = pinakes_core::audit::record_action(
&state.storage,
None,
pinakes_core::model::AuditAction::LoginSuccess,
Some(format!("username: {}, role: {}", username, role)),
)
.await;
Ok(Json(LoginResponse {
token,
username,
@ -81,8 +116,24 @@ pub async fn login(
pub async fn logout(State(state): State<AppState>, headers: HeaderMap) -> StatusCode {
if let Some(token) = extract_bearer_token(&headers) {
let sessions = state.sessions.read().await;
let username = sessions.get(token).map(|s| s.username.clone());
drop(sessions);
let mut sessions = state.sessions.write().await;
sessions.remove(token);
drop(sessions);
// Record logout in audit log
if let Some(user) = username {
let _ = pinakes_core::audit::record_action(
&state.storage,
None,
pinakes_core::model::AuditAction::Logout,
Some(format!("username: {}", user)),
)
.await;
}
}
StatusCode::OK
}

View file

@ -1,8 +1,221 @@
use axum::Json;
use std::time::Instant;
pub async fn health() -> Json<serde_json::Value> {
Json(serde_json::json!({
"status": "ok",
"version": env!("CARGO_PKG_VERSION"),
}))
use axum::Json;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use serde::{Deserialize, Serialize};
use crate::state::AppState;
/// Basic health check response
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
pub status: String,
pub version: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub database: Option<DatabaseHealth>,
#[serde(skip_serializing_if = "Option::is_none")]
pub filesystem: Option<FilesystemHealth>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache: Option<CacheHealth>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct DatabaseHealth {
pub status: String,
pub latency_ms: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub media_count: Option<u64>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FilesystemHealth {
pub status: String,
pub roots_configured: usize,
pub roots_accessible: usize,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CacheHealth {
pub hit_rate: f64,
pub total_entries: u64,
pub responses_size: u64,
pub queries_size: u64,
pub media_size: u64,
}
/// Comprehensive health check - includes database, filesystem, and cache status
pub async fn health(State(state): State<AppState>) -> Json<HealthResponse> {
let mut response = HealthResponse {
status: "ok".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
database: None,
filesystem: None,
cache: None,
};
// Check database health
let db_start = Instant::now();
let db_health = match state.storage.count_media().await {
Ok(count) => DatabaseHealth {
status: "ok".to_string(),
latency_ms: db_start.elapsed().as_millis() as u64,
media_count: Some(count),
},
Err(e) => {
response.status = "degraded".to_string();
DatabaseHealth {
status: format!("error: {}", e),
latency_ms: db_start.elapsed().as_millis() as u64,
media_count: None,
}
}
};
response.database = Some(db_health);
// Check filesystem health (root directories)
let roots = match state.storage.list_root_dirs().await {
Ok(r) => r,
Err(_) => Vec::new(),
};
let roots_accessible = roots.iter().filter(|r| r.exists()).count();
if roots_accessible < roots.len() {
response.status = "degraded".to_string();
}
response.filesystem = Some(FilesystemHealth {
status: if roots_accessible == roots.len() {
"ok"
} else {
"degraded"
}
.to_string(),
roots_configured: roots.len(),
roots_accessible,
});
// Get cache statistics
let cache_stats = state.cache.stats();
response.cache = Some(CacheHealth {
hit_rate: cache_stats.overall_hit_rate(),
total_entries: cache_stats.total_entries(),
responses_size: cache_stats.responses.size,
queries_size: cache_stats.queries.size,
media_size: cache_stats.media.size,
});
Json(response)
}
/// Liveness probe - just checks if the server is running
/// Returns 200 OK if the server process is alive
pub async fn liveness() -> impl IntoResponse {
(
StatusCode::OK,
Json(serde_json::json!({
"status": "alive"
})),
)
}
/// Readiness probe - checks if the server can serve requests
/// Returns 200 OK if database is accessible
pub async fn readiness(State(state): State<AppState>) -> impl IntoResponse {
// Check database connectivity
let db_start = Instant::now();
match state.storage.count_media().await {
Ok(_) => {
let latency = db_start.elapsed().as_millis() as u64;
(
StatusCode::OK,
Json(serde_json::json!({
"status": "ready",
"database_latency_ms": latency
})),
)
}
Err(e) => (
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"status": "not_ready",
"reason": e.to_string()
})),
),
}
}
/// Detailed health check for monitoring dashboards
#[derive(Debug, Serialize, Deserialize)]
pub struct DetailedHealthResponse {
pub status: String,
pub version: String,
pub uptime_seconds: u64,
pub database: DatabaseHealth,
pub filesystem: FilesystemHealth,
pub cache: CacheHealth,
pub jobs: JobsHealth,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct JobsHealth {
pub pending: usize,
pub running: usize,
}
pub async fn health_detailed(State(state): State<AppState>) -> Json<DetailedHealthResponse> {
// Check database
let db_start = Instant::now();
let (db_status, media_count) = match state.storage.count_media().await {
Ok(count) => ("ok".to_string(), Some(count)),
Err(e) => (format!("error: {}", e), None),
};
let db_latency = db_start.elapsed().as_millis() as u64;
// Check filesystem
let roots = state.storage.list_root_dirs().await.unwrap_or_default();
let roots_accessible = roots.iter().filter(|r| r.exists()).count();
// Get cache stats
let cache_stats = state.cache.stats();
// Get job queue stats
let job_stats = state.job_queue.stats().await;
let overall_status = if db_status == "ok" && roots_accessible == roots.len() {
"ok"
} else {
"degraded"
};
Json(DetailedHealthResponse {
status: overall_status.to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
uptime_seconds: 0, // Could track server start time
database: DatabaseHealth {
status: db_status,
latency_ms: db_latency,
media_count,
},
filesystem: FilesystemHealth {
status: if roots_accessible == roots.len() {
"ok"
} else {
"degraded"
}
.to_string(),
roots_configured: roots.len(),
roots_accessible,
},
cache: CacheHealth {
hit_rate: cache_stats.overall_hit_rate(),
total_entries: cache_stats.total_entries(),
responses_size: cache_stats.responses.size,
queries_size: cache_stats.queries.size,
media_size: cache_stats.media.size,
},
jobs: JobsHealth {
pending: job_stats.pending,
running: job_stats.running,
},
})
}