diff --git a/.envrc b/.envrc index 3550a30f..23f9a4ef 100644 --- a/.envrc +++ b/.envrc @@ -1 +1 @@ -use flake +use flake . --substituters "https://cache.nixos.org" diff --git a/src/cli.rs b/src/cli.rs index b27ddb0b..2634a6b9 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,12 +1,14 @@ use clap::{Parser, Subcommand}; -use mrc::set_property; use mrc::SOCKET_PATH; use mrc::{ - get_property, loadfile, playlist_clear, playlist_move, playlist_next, playlist_prev, - playlist_remove, quit, seek, MrcError, Result, + MrcError, Result, get_property, loadfile, playlist_clear, playlist_move, playlist_next, + playlist_prev, playlist_remove, quit, seek, set_property, }; use serde_json::json; -use std::{io::{self, Write}, path::PathBuf}; +use std::{ + io::{self, Write}, + path::PathBuf, +}; use tracing::{debug, error, info}; #[derive(Parser)] @@ -165,8 +167,8 @@ async fn main() -> Result<()> { CommandOptions::List => { info!("Listing playlist items"); if let Some(data) = get_property("playlist", None).await? { - let pretty_json = serde_json::to_string_pretty(&data) - .map_err(MrcError::ParseError)?; + let pretty_json = + serde_json::to_string_pretty(&data).map_err(MrcError::ParseError)?; println!("{}", pretty_json); } } @@ -212,7 +214,9 @@ async fn main() -> Result<()> { print!("mpv> "); stdout.flush().map_err(MrcError::ConnectionError)?; let mut input = String::new(); - stdin.read_line(&mut input).map_err(MrcError::ConnectionError)?; + stdin + .read_line(&mut input) + .map_err(MrcError::ConnectionError)?; let trimmed = input.trim(); if trimmed.eq_ignore_ascii_case("exit") { @@ -338,7 +342,9 @@ async fn main() -> Result<()> { _ => { println!("Unknown command: {}", trimmed); - println!("Valid commands: play , pause, stop, next, prev, seek , clear, list, add , get , set , help, exit"); + println!( + "Valid commands: play , pause, stop, next, prev, seek , clear, list, add , get , set , help, exit" + ); } } } diff --git a/src/lib.rs b/src/lib.rs index 17138163..64d6eed5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,7 +40,7 @@ //! //! ## Functions -use serde_json::{json, Value}; +use serde_json::{Value, json}; use std::io; use thiserror::Error; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -55,43 +55,43 @@ pub enum MrcError { /// Connection to the MPV socket could not be established. #[error("failed to connect to MPV socket: {0}")] ConnectionError(#[from] io::Error), - + /// Error when parsing a JSON response from MPV. #[error("failed to parse JSON response: {0}")] ParseError(#[from] serde_json::Error), - + /// Error when a socket operation times out. #[error("socket operation timed out after {0} seconds")] SocketTimeout(u64), - + /// Error when MPV returns an error response. #[error("MPV error: {0}")] MpvError(String), - + /// Error when trying to use a property that doesn't exist. #[error("property '{0}' not found")] PropertyNotFound(String), - + /// Error when the socket response is not valid UTF-8. #[error("invalid UTF-8 in socket response: {0}")] InvalidUtf8(#[from] std::string::FromUtf8Error), - + /// Error when a network operation fails. #[error("network error: {0}")] NetworkError(String), - + /// Error when the server connection is lost or broken. #[error("server connection lost: {0}")] ConnectionLost(String), - + /// Error when a communication protocol is violated. #[error("protocol error: {0}")] ProtocolError(String), - + /// Error when invalid input is provided. #[error("invalid input: {0}")] InvalidInput(String), - + /// Error related to TLS operations. #[error("TLS error: {0}")] TlsError(String), @@ -100,6 +100,81 @@ pub enum MrcError { /// A specialized Result type for MRC operations. pub type Result = std::result::Result; +/// Connects to the MPV IPC socket with timeout. +async fn connect_to_socket(socket_path: &str) -> Result { + debug!("Connecting to socket at {}", socket_path); + + tokio::time::timeout( + std::time::Duration::from_secs(5), + UnixStream::connect(socket_path), + ) + .await + .map_err(|_| MrcError::SocketTimeout(5))? + .map_err(MrcError::ConnectionError) +} + +/// Sends a command message to the socket with timeout. +async fn send_message(socket: &mut UnixStream, command: &str, args: &[Value]) -> Result<()> { + let mut command_array = vec![json!(command)]; + command_array.extend_from_slice(args); + let message = json!({ "command": command_array }); + let message_str = format!("{}\n", serde_json::to_string(&message)?); + + debug!("Serialized message to send with newline: {}", message_str); + + // Write with timeout + tokio::time::timeout( + std::time::Duration::from_secs(5), + socket.write_all(message_str.as_bytes()), + ) + .await + .map_err(|_| MrcError::SocketTimeout(5))??; + + // Flush with timeout + tokio::time::timeout(std::time::Duration::from_secs(5), socket.flush()) + .await + .map_err(|_| MrcError::SocketTimeout(5))??; + + debug!("Message sent and flushed"); + Ok(()) +} + +/// Reads and parses the response from the socket. +async fn read_response(socket: &mut UnixStream) -> Result { + let mut response = vec![0; 1024]; + + // Read with timeout + let n = tokio::time::timeout( + std::time::Duration::from_secs(5), + socket.read(&mut response), + ) + .await + .map_err(|_| MrcError::SocketTimeout(5))??; + + if n == 0 { + return Err(MrcError::ConnectionLost( + "Socket closed unexpectedly".into(), + )); + } + + let response_str = String::from_utf8(response[..n].to_vec())?; + debug!("Raw response: {}", response_str); + + let json_response = + serde_json::from_str::(&response_str).map_err(MrcError::ParseError)?; + + debug!("Parsed IPC response: {:?}", json_response); + + // Check if MPV returned an error + if let Some(error) = json_response.get("error").and_then(|e| e.as_str()) { + if !error.is_empty() { + return Err(MrcError::MpvError(error.to_string())); + } + } + + Ok(json_response) +} + /// Sends a generic IPC command to the specified socket and returns the parsed response data. /// /// # Arguments @@ -123,71 +198,10 @@ pub async fn send_ipc_command( command, args ); - // Add timeout for connection - let stream = tokio::time::timeout( - std::time::Duration::from_secs(5), - UnixStream::connect(socket_path), - ) - .await - .map_err(|_| MrcError::SocketTimeout(5))? - .map_err(MrcError::ConnectionError)?; + let mut socket = connect_to_socket(socket_path).await?; + send_message(&mut socket, command, args).await?; + let json_response = read_response(&mut socket).await?; - let mut socket = stream; - debug!("Connected to socket at {}", socket_path); - - let mut command_array = vec![json!(command)]; - command_array.extend_from_slice(args); - let message = json!({ "command": command_array }); - let message_str = format!("{}\n", serde_json::to_string(&message)?); - debug!("Serialized message to send with newline: {}", message_str); - - // Write with timeout - tokio::time::timeout( - std::time::Duration::from_secs(5), - socket.write_all(message_str.as_bytes()), - ) - .await - .map_err(|_| MrcError::SocketTimeout(5))??; - - // Flush with timeout - tokio::time::timeout( - std::time::Duration::from_secs(5), - socket.flush(), - ) - .await - .map_err(|_| MrcError::SocketTimeout(5))??; - - debug!("Message sent and flushed"); - - let mut response = vec![0; 1024]; - - // Read with timeout - let n = tokio::time::timeout( - std::time::Duration::from_secs(5), - socket.read(&mut response), - ) - .await - .map_err(|_| MrcError::SocketTimeout(5))??; - - if n == 0 { - return Err(MrcError::ConnectionLost("Socket closed unexpectedly".into())); - } - - let response_str = String::from_utf8(response[..n].to_vec())?; - debug!("Raw response: {}", response_str); - - let json_response = serde_json::from_str::(&response_str) - .map_err(MrcError::ParseError)?; - - debug!("Parsed IPC response: {:?}", json_response); - - // Check if MPV returned an error - if let Some(error) = json_response.get("error").and_then(|e| e.as_str()) { - if !error.is_empty() { - return Err(MrcError::MpvError(error.to_string())); - } - } - Ok(json_response.get("data").cloned()) } diff --git a/src/server.rs b/src/server.rs index c307c47c..3c0dd01a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -9,7 +9,10 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio_native_tls::TlsAcceptor; use tracing::{debug, error, info}; -use mrc::{get_property, playlist_clear, playlist_next, playlist_prev, quit, seek, set_property, MrcError, Result as MrcResult}; +use mrc::{ + MrcError, Result as MrcResult, get_property, playlist_clear, playlist_next, playlist_prev, + quit, seek, set_property, +}; #[derive(Parser)] #[command(author, version, about)] @@ -27,11 +30,15 @@ async fn handle_connection( stream: tokio::net::TcpStream, acceptor: Arc, ) -> MrcResult<()> { - let mut stream = acceptor.accept(stream).await + let mut stream = acceptor + .accept(stream) + .await .map_err(|e| MrcError::TlsError(e.to_string()))?; let mut buffer = vec![0; 2048]; - let n = stream.read(&mut buffer).await + let n = stream + .read(&mut buffer) + .await .map_err(MrcError::ConnectionError)?; let request = String::from_utf8_lossy(&buffer[..n]); @@ -143,15 +150,18 @@ async fn process_command(command: &str) -> MrcResult { info!("Listing playlist items"); match get_property("playlist", None).await { Ok(Some(data)) => { - let pretty_json = serde_json::to_string_pretty(&data) - .map_err(MrcError::ParseError)?; + let pretty_json = + serde_json::to_string_pretty(&data).map_err(MrcError::ParseError)?; Ok(format!("Playlist: {}", pretty_json)) - }, + } Ok(None) => Err(MrcError::PropertyNotFound("playlist".to_string())), Err(e) => Err(e), } } - _ => Err(MrcError::InvalidInput(format!("Unknown command: {}", command))), + _ => Err(MrcError::InvalidInput(format!( + "Unknown command: {}", + command + ))), } } @@ -161,16 +171,15 @@ fn create_tls_acceptor() -> MrcResult { let password = env::var("TLS_PASSWORD") .map_err(|_| MrcError::InvalidInput("TLS_PASSWORD not set".to_string()))?; - let mut file = std::fs::File::open(&pfx_path) - .map_err(MrcError::ConnectionError)?; + let mut file = std::fs::File::open(&pfx_path).map_err(MrcError::ConnectionError)?; let mut identity = vec![]; file.read_to_end(&mut identity) .map_err(MrcError::ConnectionError)?; let identity = Identity::from_pkcs12(&identity, &password) .map_err(|e| MrcError::TlsError(e.to_string()))?; - let native_acceptor = NativeTlsAcceptor::new(identity) - .map_err(|e| MrcError::TlsError(e.to_string()))?; + let native_acceptor = + NativeTlsAcceptor::new(identity).map_err(|e| MrcError::TlsError(e.to_string()))?; Ok(TlsAcceptor::from(native_acceptor)) } @@ -194,13 +203,13 @@ async fn main() -> MrcResult<()> { match create_tls_acceptor() { Ok(acceptor) => { let acceptor = Arc::new(acceptor); - let listener = tokio::net::TcpListener::bind(&config.bind).await + let listener = tokio::net::TcpListener::bind(&config.bind) + .await .map_err(MrcError::ConnectionError)?; info!("Server is listening on {}", config.bind); loop { - let (stream, _) = listener.accept().await - .map_err(MrcError::ConnectionError)?; + let (stream, _) = listener.accept().await.map_err(MrcError::ConnectionError)?; info!("New connection accepted."); let acceptor = Arc::clone(&acceptor);