diff --git a/Cargo.lock b/Cargo.lock index 68c1976e..fdfe07fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -67,6 +67,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "anyhow" +version = "1.0.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" + [[package]] name = "autocfg" version = "1.4.0" @@ -369,12 +375,14 @@ dependencies = [ name = "mrc" version = "0.1.0" dependencies = [ + "anyhow", "clap", "clap_derive", "ipc-channel", "native-tls", "serde", "serde_json", + "thiserror", "tokio", "tokio-native-tls", "tracing", @@ -751,6 +759,26 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.8" diff --git a/Cargo.toml b/Cargo.toml index 72b71ab1..709bc23f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,11 +16,13 @@ name = "server" path = "src/server.rs" [dependencies] +anyhow = "1.0" clap = { version = "4.5", features = ["derive"] } clap_derive = "4.5" ipc-channel = "0.19" serde = { version = "1", features = ["derive"] } serde_json = "1.0" +thiserror = "1.0" tokio = { version = "1.43", features = ["full"] } native-tls = "0.2" tokio-native-tls = "0.3" diff --git a/src/cli.rs b/src/cli.rs index 999aca2a..b27ddb0b 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -3,11 +3,10 @@ use mrc::set_property; use mrc::SOCKET_PATH; use mrc::{ get_property, loadfile, playlist_clear, playlist_move, playlist_next, playlist_prev, - playlist_remove, quit, seek, + playlist_remove, quit, seek, MrcError, Result, }; use serde_json::json; -use std::io::{self, Write}; -use std::path::PathBuf; +use std::{io::{self, Write}, path::PathBuf}; use tracing::{debug, error, info}; #[derive(Parser)] @@ -95,14 +94,17 @@ enum CommandOptions { } #[tokio::main] -async fn main() -> io::Result<()> { +async fn main() -> Result<()> { tracing_subscriber::fmt::init(); let cli = Cli::parse(); if !PathBuf::from(SOCKET_PATH).exists() { debug!(SOCKET_PATH); error!("Error: MPV socket not found. Is MPV running?"); - return Ok(()); + return Err(MrcError::ConnectionError(std::io::Error::new( + std::io::ErrorKind::NotFound, + "MPV socket not found", + ))); } match cli.command { @@ -163,7 +165,9 @@ async fn main() -> io::Result<()> { CommandOptions::List => { info!("Listing playlist items"); if let Some(data) = get_property("playlist", None).await? { - println!("{}", serde_json::to_string_pretty(&data)?); + let pretty_json = serde_json::to_string_pretty(&data) + .map_err(MrcError::ParseError)?; + println!("{}", pretty_json); } } @@ -171,7 +175,7 @@ async fn main() -> io::Result<()> { if filenames.is_empty() { let e = "No files provided to add to the playlist"; error!("{}", e); - return Err(io::Error::new(io::ErrorKind::InvalidInput, e)); + return Err(MrcError::InvalidInput(e.to_string())); } info!("Adding {} files to the playlist", filenames.len()); @@ -206,9 +210,9 @@ async fn main() -> io::Result<()> { loop { print!("mpv> "); - stdout.flush()?; + stdout.flush().map_err(MrcError::ConnectionError)?; let mut input = String::new(); - stdin.read_line(&mut input)?; + stdin.read_line(&mut input).map_err(MrcError::ConnectionError)?; let trimmed = input.trim(); if trimmed.eq_ignore_ascii_case("exit") { @@ -236,6 +240,7 @@ async fn main() -> io::Result<()> { "set ", "Set the specified property to a value", ), + ("help", "Show this help message"), ("exit", "Quit interactive mode"), ]; @@ -301,7 +306,9 @@ async fn main() -> io::Result<()> { ["list"] => { info!("Listing playlist items"); if let Some(data) = get_property("playlist", None).await? { - println!("{}", serde_json::to_string_pretty(&data)?); + let pretty_json = serde_json::to_string_pretty(&data) + .map_err(MrcError::ParseError)?; + println!("{}", pretty_json); } } @@ -331,7 +338,7 @@ async fn main() -> io::Result<()> { _ => { println!("Unknown command: {}", trimmed); - println!("Valid commands: play , pause, stop, next, prev, seek , clear, list, add , get , set , exit"); + println!("Valid commands: play , pause, stop, next, prev, seek , clear, list, add , get , set , help, exit"); } } } diff --git a/src/lib.rs b/src/lib.rs index 2ee6d79a..17138163 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,13 +41,65 @@ //! ## Functions use serde_json::{json, Value}; -use std::io::{self}; +use std::io; +use thiserror::Error; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::UnixStream; use tracing::{debug, error}; pub const SOCKET_PATH: &str = "/tmp/mpvsocket"; +/// Errors that can occur when interacting with the MPV IPC interface. +#[derive(Error, Debug)] +pub enum MrcError { + /// Connection to the MPV socket could not be established. + #[error("failed to connect to MPV socket: {0}")] + ConnectionError(#[from] io::Error), + + /// Error when parsing a JSON response from MPV. + #[error("failed to parse JSON response: {0}")] + ParseError(#[from] serde_json::Error), + + /// Error when a socket operation times out. + #[error("socket operation timed out after {0} seconds")] + SocketTimeout(u64), + + /// Error when MPV returns an error response. + #[error("MPV error: {0}")] + MpvError(String), + + /// Error when trying to use a property that doesn't exist. + #[error("property '{0}' not found")] + PropertyNotFound(String), + + /// Error when the socket response is not valid UTF-8. + #[error("invalid UTF-8 in socket response: {0}")] + InvalidUtf8(#[from] std::string::FromUtf8Error), + + /// Error when a network operation fails. + #[error("network error: {0}")] + NetworkError(String), + + /// Error when the server connection is lost or broken. + #[error("server connection lost: {0}")] + ConnectionLost(String), + + /// Error when a communication protocol is violated. + #[error("protocol error: {0}")] + ProtocolError(String), + + /// Error when invalid input is provided. + #[error("invalid input: {0}")] + InvalidInput(String), + + /// Error related to TLS operations. + #[error("TLS error: {0}")] + TlsError(String), +} + +/// A specialized Result type for MRC operations. +pub type Result = std::result::Result; + /// Sends a generic IPC command to the specified socket and returns the parsed response data. /// /// # Arguments @@ -59,55 +111,84 @@ pub const SOCKET_PATH: &str = "/tmp/mpvsocket"; /// A `Result` containing an `Option` with the parsed response data if successful. /// /// # Errors -/// Returns an error if the connection to the socket fails or if the response cannot be parsed. +/// Returns a `MrcError` if the connection to the socket fails or if the response cannot be parsed. pub async fn send_ipc_command( command: &str, args: &[Value], socket_path: Option<&str>, -) -> io::Result> { +) -> Result> { let socket_path = socket_path.unwrap_or(SOCKET_PATH); debug!( "Sending IPC command: {} with arguments: {:?}", command, args ); - match UnixStream::connect(socket_path).await { - Ok(mut socket) => { - debug!("Connected to socket at {}", socket_path); + // Add timeout for connection + let stream = tokio::time::timeout( + std::time::Duration::from_secs(5), + UnixStream::connect(socket_path), + ) + .await + .map_err(|_| MrcError::SocketTimeout(5))? + .map_err(MrcError::ConnectionError)?; - let mut command_array = vec![json!(command)]; - command_array.extend_from_slice(args); - let message = json!({ "command": command_array }); - let message_str = format!("{}\n", serde_json::to_string(&message)?); - debug!("Serialized message to send with newline: {}", message_str); + let mut socket = stream; + debug!("Connected to socket at {}", socket_path); - socket.write_all(message_str.as_bytes()).await?; - socket.flush().await?; - debug!("Message sent and flushed"); + let mut command_array = vec![json!(command)]; + command_array.extend_from_slice(args); + let message = json!({ "command": command_array }); + let message_str = format!("{}\n", serde_json::to_string(&message)?); + debug!("Serialized message to send with newline: {}", message_str); - let mut response = vec![0; 1024]; - let n = socket.read(&mut response).await?; - let response_str = String::from_utf8_lossy(&response[..n]); - debug!("Raw response: {}", response_str); + // Write with timeout + tokio::time::timeout( + std::time::Duration::from_secs(5), + socket.write_all(message_str.as_bytes()), + ) + .await + .map_err(|_| MrcError::SocketTimeout(5))??; + + // Flush with timeout + tokio::time::timeout( + std::time::Duration::from_secs(5), + socket.flush(), + ) + .await + .map_err(|_| MrcError::SocketTimeout(5))??; + + debug!("Message sent and flushed"); - match serde_json::from_str::(&response_str) { - Ok(json_response) => { - debug!("Parsed IPC response: {:?}", json_response); - Ok(json_response.get("data").cloned()) - } + let mut response = vec![0; 1024]; + + // Read with timeout + let n = tokio::time::timeout( + std::time::Duration::from_secs(5), + socket.read(&mut response), + ) + .await + .map_err(|_| MrcError::SocketTimeout(5))??; + + if n == 0 { + return Err(MrcError::ConnectionLost("Socket closed unexpectedly".into())); + } - Err(e) => { - error!("Failed to parse response: {}", e); - Ok(None) - } - } - } + let response_str = String::from_utf8(response[..n].to_vec())?; + debug!("Raw response: {}", response_str); - Err(e) => { - error!("Failed to connect to MPV socket: {}", e); - Err(e) + let json_response = serde_json::from_str::(&response_str) + .map_err(MrcError::ParseError)?; + + debug!("Parsed IPC response: {:?}", json_response); + + // Check if MPV returned an error + if let Some(error) = json_response.get("error").and_then(|e| e.as_str()) { + if !error.is_empty() { + return Err(MrcError::MpvError(error.to_string())); } } + + Ok(json_response.get("data").cloned()) } /// Represents common MPV commands. @@ -179,7 +260,7 @@ pub async fn set_property( property: &str, value: &Value, socket_path: Option<&str>, -) -> io::Result> { +) -> Result> { send_ipc_command( MpvCommand::SetProperty.as_str(), &[json!(property), value.clone()], @@ -198,7 +279,7 @@ pub async fn set_property( /// /// # Errors /// Returns an error if the connection to the socket fails or the command execution encounters issues. -pub async fn playlist_next(socket_path: Option<&str>) -> io::Result> { +pub async fn playlist_next(socket_path: Option<&str>) -> Result> { send_ipc_command(MpvCommand::PlaylistNext.as_str(), &[], socket_path).await } @@ -212,7 +293,7 @@ pub async fn playlist_next(socket_path: Option<&str>) -> io::Result) -> io::Result> { +pub async fn playlist_prev(socket_path: Option<&str>) -> Result> { send_ipc_command(MpvCommand::PlaylistPrev.as_str(), &[], socket_path).await } @@ -227,7 +308,7 @@ pub async fn playlist_prev(socket_path: Option<&str>) -> io::Result) -> io::Result> { +pub async fn seek(seconds: f64, socket_path: Option<&str>) -> Result> { send_ipc_command(MpvCommand::Seek.as_str(), &[json!(seconds)], socket_path).await } @@ -241,7 +322,7 @@ pub async fn seek(seconds: f64, socket_path: Option<&str>) -> io::Result) -> io::Result> { +pub async fn quit(socket_path: Option<&str>) -> Result> { send_ipc_command(MpvCommand::Quit.as_str(), &[], socket_path).await } @@ -261,7 +342,7 @@ pub async fn playlist_move( from_index: usize, to_index: usize, socket_path: Option<&str>, -) -> io::Result> { +) -> Result> { send_ipc_command( MpvCommand::PlaylistMove.as_str(), &[json!(from_index), json!(to_index)], @@ -284,7 +365,7 @@ pub async fn playlist_move( pub async fn playlist_remove( index: Option, socket_path: Option<&str>, -) -> io::Result> { +) -> Result> { let args = match index { Some(idx) => vec![json!(idx)], None => vec![json!("current")], @@ -302,7 +383,7 @@ pub async fn playlist_remove( /// /// # Errors /// Returns an error if the connection to the socket fails or the command execution encounters issues. -pub async fn playlist_clear(socket_path: Option<&str>) -> io::Result> { +pub async fn playlist_clear(socket_path: Option<&str>) -> Result> { send_ipc_command(MpvCommand::PlaylistClear.as_str(), &[], socket_path).await } @@ -317,7 +398,7 @@ pub async fn playlist_clear(socket_path: Option<&str>) -> io::Result) -> io::Result> { +pub async fn get_property(property: &str, socket_path: Option<&str>) -> Result> { send_ipc_command( MpvCommand::GetProperty.as_str(), &[json!(property)], @@ -342,7 +423,7 @@ pub async fn loadfile( filename: &str, append: bool, socket_path: Option<&str>, -) -> io::Result> { +) -> Result> { let append_flag = if append { json!("append-play") } else { diff --git a/src/server.rs b/src/server.rs index 1ab9ea84..c307c47c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -9,7 +9,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio_native_tls::TlsAcceptor; use tracing::{debug, error, info}; -use mrc::{get_property, playlist_clear, playlist_next, playlist_prev, quit, seek, set_property}; +use mrc::{get_property, playlist_clear, playlist_next, playlist_prev, quit, seek, set_property, MrcError, Result as MrcResult}; #[derive(Parser)] #[command(author, version, about)] @@ -26,11 +26,13 @@ struct Config { async fn handle_connection( stream: tokio::net::TcpStream, acceptor: Arc, -) -> Result<(), Box> { - let mut stream = acceptor.accept(stream).await?; +) -> MrcResult<()> { + let mut stream = acceptor.accept(stream).await + .map_err(|e| MrcError::TlsError(e.to_string()))?; let mut buffer = vec![0; 2048]; - let n = stream.read(&mut buffer).await?; + let n = stream.read(&mut buffer).await + .map_err(MrcError::ConnectionError)?; let request = String::from_utf8_lossy(&buffer[..n]); debug!("Received request:\n{}", request); @@ -87,45 +89,35 @@ async fn handle_connection( Ok(()) } -async fn process_command(command: &str) -> Result { +async fn process_command(command: &str) -> MrcResult { match command { "pause" => { info!("Pausing playback"); - set_property("pause", &json!(true), None) - .await - .map_err(|e| format!("Failed to pause: {:?}", e))?; + set_property("pause", &json!(true), None).await?; Ok("Paused playback\n".to_string()) } "play" => { info!("Unpausing playback"); - set_property("pause", &json!(false), None) - .await - .map_err(|e| format!("Failed to play: {:?}", e))?; + set_property("pause", &json!(false), None).await?; Ok("Resumed playback\n".to_string()) } "stop" => { info!("Stopping playback and quitting MPV"); - quit(None) - .await - .map_err(|e| format!("Failed to stop: {:?}", e))?; + quit(None).await?; Ok("Stopped playback\n".to_string()) } "next" => { info!("Skipping to next item in the playlist"); - playlist_next(None) - .await - .map_err(|e| format!("Failed to skip to next: {:?}", e))?; + playlist_next(None).await?; Ok("Skipped to next item\n".to_string()) } "prev" => { info!("Skipping to previous item in the playlist"); - playlist_prev(None) - .await - .map_err(|e| format!("Failed to skip to previous: {:?}", e))?; + playlist_prev(None).await?; Ok("Skipped to previous item\n".to_string()) } @@ -134,55 +126,56 @@ async fn process_command(command: &str) -> Result { if let Some(seconds) = parts.get(1) { if let Ok(sec) = seconds.parse::() { info!("Seeking to {} seconds", sec); - seek(sec.into(), None) - .await - .map_err(|e| format!("Failed to seek: {:?}", e))?; + seek(sec.into(), None).await?; return Ok(format!("Seeking to {} seconds\n", sec)); } } - Err("Invalid seek command".to_string()) + Err(MrcError::InvalidInput("Invalid seek command".to_string())) } "clear" => { info!("Clearing the playlist"); - playlist_clear(None) - .await - .map_err(|e| format!("Failed to clear playlist: {:?}", e))?; + playlist_clear(None).await?; Ok("Cleared playlist\n".to_string()) } "list" => { info!("Listing playlist items"); match get_property("playlist", None).await { - Ok(Some(data)) => Ok(format!( - "Playlist: {}", - serde_json::to_string_pretty(&data).unwrap() - )), - Ok(None) => Err("No playlist data available".to_string()), - Err(e) => Err(format!("Failed to fetch playlist: {:?}", e)), + Ok(Some(data)) => { + let pretty_json = serde_json::to_string_pretty(&data) + .map_err(MrcError::ParseError)?; + Ok(format!("Playlist: {}", pretty_json)) + }, + Ok(None) => Err(MrcError::PropertyNotFound("playlist".to_string())), + Err(e) => Err(e), } } - _ => Err("Unknown command".to_string()), + _ => Err(MrcError::InvalidInput(format!("Unknown command: {}", command))), } } -fn create_tls_acceptor() -> Result> { +fn create_tls_acceptor() -> MrcResult { let pfx_path = env::var("TLS_PFX_PATH") - .map_err(|_| std::io::Error::new(std::io::ErrorKind::NotFound, "TLS_PFX_PATH not set"))?; + .map_err(|_| MrcError::InvalidInput("TLS_PFX_PATH not set".to_string()))?; let password = env::var("TLS_PASSWORD") - .map_err(|_| std::io::Error::new(std::io::ErrorKind::NotFound, "TLS_PASSWORD not set"))?; + .map_err(|_| MrcError::InvalidInput("TLS_PASSWORD not set".to_string()))?; - let mut file = std::fs::File::open(&pfx_path)?; + let mut file = std::fs::File::open(&pfx_path) + .map_err(MrcError::ConnectionError)?; let mut identity = vec![]; - file.read_to_end(&mut identity)?; + file.read_to_end(&mut identity) + .map_err(MrcError::ConnectionError)?; - let identity = Identity::from_pkcs12(&identity, &password)?; - let native_acceptor = NativeTlsAcceptor::new(identity)?; + let identity = Identity::from_pkcs12(&identity, &password) + .map_err(|e| MrcError::TlsError(e.to_string()))?; + let native_acceptor = NativeTlsAcceptor::new(identity) + .map_err(|e| MrcError::TlsError(e.to_string()))?; Ok(TlsAcceptor::from(native_acceptor)) } #[tokio::main] -async fn main() -> Result<(), Box> { +async fn main() -> MrcResult<()> { tracing_subscriber::fmt::init(); let config = Config::parse(); @@ -191,17 +184,23 @@ async fn main() -> Result<(), Box> { "Error: MPV socket not found at '{}'. Is MPV running?", config.socket ); + return Err(MrcError::ConnectionError(std::io::Error::new( + std::io::ErrorKind::NotFound, + format!("MPV socket not found at '{}'", config.socket), + ))); } info!("Server is starting..."); match create_tls_acceptor() { Ok(acceptor) => { let acceptor = Arc::new(acceptor); - let listener = tokio::net::TcpListener::bind(&config.bind).await?; + let listener = tokio::net::TcpListener::bind(&config.bind).await + .map_err(MrcError::ConnectionError)?; info!("Server is listening on {}", config.bind); loop { - let (stream, _) = listener.accept().await?; + let (stream, _) = listener.accept().await + .map_err(MrcError::ConnectionError)?; info!("New connection accepted."); let acceptor = Arc::clone(&acceptor);