mirror of
https://github.com/NotAShelf/stash.git
synced 2026-04-12 22:17:41 +00:00
stash: async db operations; make hashes deterministic
Signed-off-by: NotAShelf <raf@notashelf.dev> Change-Id: Iccc9980fa13a752e0e6c9fb630c28ba96a6a6964
This commit is contained in:
parent
7184c8b682
commit
95bf1766ce
6 changed files with 815 additions and 362 deletions
541
Cargo.lock
generated
541
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -15,6 +15,7 @@ path = "src/main.rs"
|
|||
|
||||
[dependencies]
|
||||
base64 = "0.22.1"
|
||||
blocking = "1.6.2"
|
||||
clap = { version = "4.5.60", features = [ "derive", "env" ] }
|
||||
clap-verbosity-flag = "3.0.4"
|
||||
color-eyre = "0.6.5"
|
||||
|
|
@ -43,6 +44,7 @@ wayland-protocols-wlr = { version = "0.3.10", default-features = false, optional
|
|||
wl-clipboard-rs = "0.9.3"
|
||||
|
||||
[dev-dependencies]
|
||||
futures = "0.3.32"
|
||||
tempfile = "3.26.0"
|
||||
|
||||
[features]
|
||||
|
|
|
|||
|
|
@ -1,9 +1,32 @@
|
|||
use std::{
|
||||
collections::{BinaryHeap, hash_map::DefaultHasher},
|
||||
hash::{Hash, Hasher},
|
||||
io::Read,
|
||||
time::Duration,
|
||||
};
|
||||
use std::{collections::BinaryHeap, io::Read, time::Duration};
|
||||
|
||||
/// FNV-1a hasher for deterministic hashing across process runs.
|
||||
/// Unlike DefaultHasher (SipHash), this produces stable hashes.
|
||||
struct Fnv1aHasher {
|
||||
state: u64,
|
||||
}
|
||||
|
||||
impl Fnv1aHasher {
|
||||
const FNV_OFFSET: u64 = 0xCBF29CE484222325;
|
||||
const FNV_PRIME: u64 = 0x100000001B3;
|
||||
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
state: Self::FNV_OFFSET,
|
||||
}
|
||||
}
|
||||
|
||||
fn write(&mut self, bytes: &[u8]) {
|
||||
for byte in bytes {
|
||||
self.state ^= *byte as u64;
|
||||
self.state = self.state.wrapping_mul(Self::FNV_PRIME);
|
||||
}
|
||||
}
|
||||
|
||||
fn finish(&self) -> u64 {
|
||||
self.state
|
||||
}
|
||||
}
|
||||
|
||||
use smol::Timer;
|
||||
use wl_clipboard_rs::{
|
||||
|
|
@ -17,7 +40,7 @@ use wl_clipboard_rs::{
|
|||
},
|
||||
};
|
||||
|
||||
use crate::db::{ClipboardDb, SqliteClipboardDb};
|
||||
use crate::db::{SqliteClipboardDb, nonblocking::AsyncClipboardDb};
|
||||
|
||||
/// Wrapper to provide [`Ord`] implementation for `f64` by negating values.
|
||||
/// This allows [`BinaryHeap`], which is a max-heap, to function as a min-heap.
|
||||
|
|
@ -97,6 +120,16 @@ impl ExpirationQueue {
|
|||
}
|
||||
expired
|
||||
}
|
||||
|
||||
/// Check if the queue is empty
|
||||
fn is_empty(&self) -> bool {
|
||||
self.heap.is_empty()
|
||||
}
|
||||
|
||||
/// Get the number of entries in the queue
|
||||
fn len(&self) -> usize {
|
||||
self.heap.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Get clipboard contents using the source application's preferred MIME type.
|
||||
|
|
@ -177,7 +210,7 @@ fn negotiate_mime_type(
|
|||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub trait WatchCommand {
|
||||
fn watch(
|
||||
async fn watch(
|
||||
&self,
|
||||
max_dedupe_search: u64,
|
||||
max_items: u64,
|
||||
|
|
@ -190,7 +223,7 @@ pub trait WatchCommand {
|
|||
}
|
||||
|
||||
impl WatchCommand for SqliteClipboardDb {
|
||||
fn watch(
|
||||
async fn watch(
|
||||
&self,
|
||||
max_dedupe_search: u64,
|
||||
max_items: u64,
|
||||
|
|
@ -200,207 +233,203 @@ impl WatchCommand for SqliteClipboardDb {
|
|||
min_size: Option<usize>,
|
||||
max_size: usize,
|
||||
) {
|
||||
smol::block_on(async {
|
||||
log::info!(
|
||||
"Starting clipboard watch daemon with MIME type preference: \
|
||||
{mime_type_preference}"
|
||||
);
|
||||
let async_db = AsyncClipboardDb::new(self.db_path.clone());
|
||||
log::info!(
|
||||
"Starting clipboard watch daemon with MIME type preference: \
|
||||
{mime_type_preference}"
|
||||
);
|
||||
|
||||
// Build expiration queue from existing entries
|
||||
let mut exp_queue = ExpirationQueue::new();
|
||||
if let Ok(Some((expires_at, id))) = self.get_next_expiration() {
|
||||
exp_queue.push(expires_at, id);
|
||||
// Load remaining expirations (exclude already-marked expired entries)
|
||||
let mut stmt = self
|
||||
.conn
|
||||
.prepare(
|
||||
"SELECT expires_at, id FROM clipboard WHERE expires_at IS NOT \
|
||||
NULL AND (is_expired IS NULL OR is_expired = 0) ORDER BY \
|
||||
expires_at ASC",
|
||||
)
|
||||
.ok();
|
||||
if let Some(ref mut stmt) = stmt {
|
||||
let mut rows = stmt.query([]).ok();
|
||||
if let Some(ref mut rows) = rows {
|
||||
while let Ok(Some(row)) = rows.next() {
|
||||
if let (Ok(exp), Ok(row_id)) =
|
||||
(row.get::<_, f64>(0), row.get::<_, i64>(1))
|
||||
{
|
||||
// Skip first entry which is already added
|
||||
if exp_queue
|
||||
.heap
|
||||
.iter()
|
||||
.any(|(_, existing_id)| *existing_id == row_id)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
exp_queue.push(exp, row_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Build expiration queue from existing entries
|
||||
let mut exp_queue = ExpirationQueue::new();
|
||||
|
||||
// Load all expirations from database asynchronously
|
||||
match async_db.load_all_expirations().await {
|
||||
Ok(expirations) => {
|
||||
for (expires_at, id) in expirations {
|
||||
exp_queue.push(expires_at, id);
|
||||
}
|
||||
}
|
||||
|
||||
// We use hashes for comparison instead of storing full contents
|
||||
let mut last_hash: Option<u64> = None;
|
||||
let mut buf = Vec::with_capacity(4096);
|
||||
|
||||
// Helper to hash clipboard contents
|
||||
let hash_contents = |data: &[u8]| -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
data.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
};
|
||||
|
||||
// Initialize with current clipboard using smart MIME negotiation
|
||||
if let Ok((mut reader, _)) = negotiate_mime_type(mime_type_preference) {
|
||||
buf.clear();
|
||||
if reader.read_to_end(&mut buf).is_ok() && !buf.is_empty() {
|
||||
last_hash = Some(hash_contents(&buf));
|
||||
if !exp_queue.is_empty() {
|
||||
log::info!("Loaded {} expirations from database", exp_queue.len());
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
log::warn!("Failed to load expirations: {e}");
|
||||
},
|
||||
}
|
||||
|
||||
// We use hashes for comparison instead of storing full contents
|
||||
let mut last_hash: Option<u64> = None;
|
||||
let mut buf = Vec::with_capacity(4096);
|
||||
|
||||
// Helper to hash clipboard contents using FNV-1a (deterministic across
|
||||
// runs)
|
||||
let hash_contents = |data: &[u8]| -> u64 {
|
||||
let mut hasher = Fnv1aHasher::new();
|
||||
hasher.write(data);
|
||||
hasher.finish()
|
||||
};
|
||||
|
||||
// Initialize with current clipboard using smart MIME negotiation
|
||||
if let Ok((mut reader, _)) = negotiate_mime_type(mime_type_preference) {
|
||||
buf.clear();
|
||||
if reader.read_to_end(&mut buf).is_ok() && !buf.is_empty() {
|
||||
last_hash = Some(hash_contents(&buf));
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
// Process any pending expirations
|
||||
if let Some(next_exp) = exp_queue.peek_next() {
|
||||
let now = SqliteClipboardDb::now();
|
||||
if next_exp <= now {
|
||||
// Expired entries to process
|
||||
let expired_ids = exp_queue.pop_expired(now);
|
||||
for id in expired_ids {
|
||||
// Verify entry still exists and get its content_hash
|
||||
let expired_hash: Option<i64> = self
|
||||
.conn
|
||||
.query_row(
|
||||
"SELECT content_hash FROM clipboard WHERE id = ?1",
|
||||
[id],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.ok();
|
||||
let poll_interval = Duration::from_millis(500);
|
||||
|
||||
if let Some(stored_hash) = expired_hash {
|
||||
// Mark as expired
|
||||
self
|
||||
.conn
|
||||
.execute(
|
||||
"UPDATE clipboard SET is_expired = 1 WHERE id = ?1",
|
||||
[id],
|
||||
)
|
||||
.ok();
|
||||
loop {
|
||||
// Process any pending expirations that are due now
|
||||
if let Some(next_exp) = exp_queue.peek_next() {
|
||||
let now = SqliteClipboardDb::now();
|
||||
if next_exp <= now {
|
||||
// Expired entries to process
|
||||
let expired_ids = exp_queue.pop_expired(now);
|
||||
for id in expired_ids {
|
||||
// Verify entry still exists and get its content_hash
|
||||
let expired_hash: Option<i64> =
|
||||
match async_db.get_content_hash(id).await {
|
||||
Ok(hash) => hash,
|
||||
Err(e) => {
|
||||
log::warn!("Failed to get content hash for entry {id}: {e}");
|
||||
None
|
||||
},
|
||||
};
|
||||
|
||||
if let Some(stored_hash) = expired_hash {
|
||||
// Mark as expired
|
||||
if let Err(e) = async_db.mark_expired(id).await {
|
||||
log::warn!("Failed to mark entry {id} as expired: {e}");
|
||||
} else {
|
||||
log::info!("Entry {id} marked as expired");
|
||||
}
|
||||
|
||||
// Check if this expired entry is currently in the clipboard
|
||||
if let Ok((mut reader, _)) =
|
||||
negotiate_mime_type(mime_type_preference)
|
||||
// Check if this expired entry is currently in the clipboard
|
||||
if let Ok((mut reader, _)) =
|
||||
negotiate_mime_type(mime_type_preference)
|
||||
{
|
||||
let mut current_buf = Vec::new();
|
||||
if reader.read_to_end(&mut current_buf).is_ok()
|
||||
&& !current_buf.is_empty()
|
||||
{
|
||||
let mut current_buf = Vec::new();
|
||||
if reader.read_to_end(&mut current_buf).is_ok()
|
||||
&& !current_buf.is_empty()
|
||||
{
|
||||
let current_hash = hash_contents(¤t_buf);
|
||||
// Compare as i64 (database stores as i64)
|
||||
if current_hash as i64 == stored_hash {
|
||||
// Clear the clipboard since expired content is still
|
||||
// there
|
||||
let mut opts = Options::new();
|
||||
opts.clipboard(
|
||||
wl_clipboard_rs::copy::ClipboardType::Regular,
|
||||
let current_hash = hash_contents(¤t_buf);
|
||||
// Convert stored i64 to u64 for comparison (preserves bit
|
||||
// pattern)
|
||||
if current_hash == stored_hash as u64 {
|
||||
// Clear the clipboard since expired content is still
|
||||
// there
|
||||
let mut opts = Options::new();
|
||||
opts
|
||||
.clipboard(wl_clipboard_rs::copy::ClipboardType::Regular);
|
||||
if opts
|
||||
.copy(
|
||||
Source::Bytes(Vec::new().into()),
|
||||
CopyMimeType::Autodetect,
|
||||
)
|
||||
.is_ok()
|
||||
{
|
||||
log::info!(
|
||||
"Cleared clipboard containing expired entry {id}"
|
||||
);
|
||||
last_hash = None; // reset tracked hash
|
||||
} else {
|
||||
log::warn!(
|
||||
"Failed to clear clipboard for expired entry {id}"
|
||||
);
|
||||
if opts
|
||||
.copy(
|
||||
Source::Bytes(Vec::new().into()),
|
||||
CopyMimeType::Autodetect,
|
||||
)
|
||||
.is_ok()
|
||||
{
|
||||
log::info!(
|
||||
"Cleared clipboard containing expired entry {id}"
|
||||
);
|
||||
last_hash = None; // reset tracked hash
|
||||
} else {
|
||||
log::warn!(
|
||||
"Failed to clear clipboard for expired entry {id}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Sleep *precisely* until next expiration
|
||||
let sleep_duration = next_exp - now;
|
||||
Timer::after(Duration::from_secs_f64(sleep_duration)).await;
|
||||
continue; // skip normal poll, process expirations first
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normal clipboard polling
|
||||
match negotiate_mime_type(mime_type_preference) {
|
||||
Ok((mut reader, _mime_type)) => {
|
||||
buf.clear();
|
||||
if let Err(e) = reader.read_to_end(&mut buf) {
|
||||
log::error!("Failed to read clipboard contents: {e}");
|
||||
Timer::after(Duration::from_millis(500)).await;
|
||||
continue;
|
||||
}
|
||||
// Normal clipboard polling (always run, even when expirations are
|
||||
// pending)
|
||||
match negotiate_mime_type(mime_type_preference) {
|
||||
Ok((mut reader, _mime_type)) => {
|
||||
buf.clear();
|
||||
if let Err(e) = reader.read_to_end(&mut buf) {
|
||||
log::error!("Failed to read clipboard contents: {e}");
|
||||
Timer::after(Duration::from_millis(500)).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Only store if changed and not empty
|
||||
if !buf.is_empty() {
|
||||
let current_hash = hash_contents(&buf);
|
||||
if last_hash != Some(current_hash) {
|
||||
match self.store_entry(
|
||||
&buf[..],
|
||||
// Only store if changed and not empty
|
||||
if !buf.is_empty() {
|
||||
let current_hash = hash_contents(&buf);
|
||||
if last_hash != Some(current_hash) {
|
||||
// Clone buf for the async operation since it needs 'static
|
||||
let buf_clone = buf.clone();
|
||||
match async_db
|
||||
.store_entry(
|
||||
buf_clone,
|
||||
max_dedupe_search,
|
||||
max_items,
|
||||
Some(excluded_apps),
|
||||
Some(excluded_apps.to_vec()),
|
||||
min_size,
|
||||
max_size,
|
||||
) {
|
||||
Ok(id) => {
|
||||
log::info!("Stored new clipboard entry (id: {id})");
|
||||
last_hash = Some(current_hash);
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(id) => {
|
||||
log::info!("Stored new clipboard entry (id: {id})");
|
||||
last_hash = Some(current_hash);
|
||||
|
||||
// Set expiration if configured
|
||||
if let Some(duration) = expire_after {
|
||||
let expires_at =
|
||||
SqliteClipboardDb::now() + duration.as_secs_f64();
|
||||
self.set_expiration(id, expires_at).ok();
|
||||
// Set expiration if configured
|
||||
if let Some(duration) = expire_after {
|
||||
let expires_at =
|
||||
SqliteClipboardDb::now() + duration.as_secs_f64();
|
||||
if let Err(e) =
|
||||
async_db.set_expiration(id, expires_at).await
|
||||
{
|
||||
log::warn!(
|
||||
"Failed to set expiration for entry {id}: {e}"
|
||||
);
|
||||
} else {
|
||||
exp_queue.push(expires_at, id);
|
||||
}
|
||||
},
|
||||
Err(crate::db::StashError::ExcludedByApp(_)) => {
|
||||
log::info!("Clipboard entry excluded by app filter");
|
||||
last_hash = Some(current_hash);
|
||||
},
|
||||
Err(crate::db::StashError::Store(ref msg))
|
||||
if msg.contains("Excluded by app filter") =>
|
||||
{
|
||||
log::info!("Clipboard entry excluded by app filter");
|
||||
last_hash = Some(current_hash);
|
||||
},
|
||||
Err(e) => {
|
||||
log::error!("Failed to store clipboard entry: {e}");
|
||||
last_hash = Some(current_hash);
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(crate::db::StashError::ExcludedByApp(_)) => {
|
||||
log::info!("Clipboard entry excluded by app filter");
|
||||
last_hash = Some(current_hash);
|
||||
},
|
||||
Err(crate::db::StashError::Store(ref msg))
|
||||
if msg.contains("Excluded by app filter") =>
|
||||
{
|
||||
log::info!("Clipboard entry excluded by app filter");
|
||||
last_hash = Some(current_hash);
|
||||
},
|
||||
Err(e) => {
|
||||
log::error!("Failed to store clipboard entry: {e}");
|
||||
last_hash = Some(current_hash);
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
let error_msg = e.to_string();
|
||||
if !error_msg.contains("empty") {
|
||||
log::error!("Failed to get clipboard contents: {e}");
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// Normal poll interval (only if no expirations pending)
|
||||
if exp_queue.peek_next().is_none() {
|
||||
Timer::after(Duration::from_millis(500)).await;
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
let error_msg = e.to_string();
|
||||
if !error_msg.contains("empty") {
|
||||
log::error!("Failed to get clipboard contents: {e}");
|
||||
}
|
||||
},
|
||||
}
|
||||
});
|
||||
|
||||
// Calculate sleep time: min of poll interval and time until next
|
||||
// expiration
|
||||
let sleep_duration = if let Some(next_exp) = exp_queue.peek_next() {
|
||||
let now = SqliteClipboardDb::now();
|
||||
let time_to_exp = (next_exp - now).max(0.0);
|
||||
poll_interval.min(Duration::from_secs_f64(time_to_exp))
|
||||
} else {
|
||||
poll_interval
|
||||
};
|
||||
Timer::after(sleep_duration).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,44 @@
|
|||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
env,
|
||||
fmt,
|
||||
fs,
|
||||
hash::{Hash, Hasher},
|
||||
io::{BufRead, BufReader, Read, Write},
|
||||
path::PathBuf,
|
||||
str,
|
||||
sync::OnceLock,
|
||||
};
|
||||
|
||||
pub mod nonblocking;
|
||||
|
||||
/// FNV-1a hasher for deterministic hashing across process runs.
|
||||
/// Unlike DefaultHasher (SipHash with random seed), this produces stable
|
||||
/// hashes.
|
||||
pub struct Fnv1aHasher {
|
||||
state: u64,
|
||||
}
|
||||
|
||||
impl Fnv1aHasher {
|
||||
const FNV_OFFSET: u64 = 0xCBF29CE484222325;
|
||||
const FNV_PRIME: u64 = 0x100000001B3;
|
||||
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
state: Self::FNV_OFFSET,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write(&mut self, bytes: &[u8]) {
|
||||
for byte in bytes {
|
||||
self.state ^= *byte as u64;
|
||||
self.state = self.state.wrapping_mul(Self::FNV_PRIME);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn finish(&self) -> u64 {
|
||||
self.state
|
||||
}
|
||||
}
|
||||
|
||||
use base64::prelude::*;
|
||||
use log::{debug, error, info, warn};
|
||||
use mime_sniffer::MimeTypeSniffer;
|
||||
|
|
@ -210,11 +240,15 @@ impl fmt::Display for Entry {
|
|||
}
|
||||
|
||||
pub struct SqliteClipboardDb {
|
||||
pub conn: Connection,
|
||||
pub conn: Connection,
|
||||
pub db_path: PathBuf,
|
||||
}
|
||||
|
||||
impl SqliteClipboardDb {
|
||||
pub fn new(mut conn: Connection) -> Result<Self, StashError> {
|
||||
pub fn new(
|
||||
mut conn: Connection,
|
||||
db_path: PathBuf,
|
||||
) -> Result<Self, StashError> {
|
||||
conn
|
||||
.pragma_update(None, "synchronous", "OFF")
|
||||
.map_err(|e| {
|
||||
|
|
@ -449,7 +483,7 @@ impl SqliteClipboardDb {
|
|||
// focused window state.
|
||||
#[cfg(feature = "use-toplevel")]
|
||||
crate::wayland::init_wayland_state();
|
||||
Ok(Self { conn })
|
||||
Ok(Self { conn, db_path })
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -535,8 +569,8 @@ impl ClipboardDb for SqliteClipboardDb {
|
|||
}
|
||||
|
||||
// Calculate content hash for deduplication
|
||||
let mut hasher = DefaultHasher::new();
|
||||
buf.hash(&mut hasher);
|
||||
let mut hasher = Fnv1aHasher::new();
|
||||
hasher.write(&buf);
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
let content_hash = hasher.finish() as i64;
|
||||
|
||||
|
|
@ -940,20 +974,6 @@ impl SqliteClipboardDb {
|
|||
.map_err(|e| StashError::Trim(e.to_string().into()))
|
||||
}
|
||||
|
||||
/// Get the earliest expiration (timestamp, id) for heap initialization
|
||||
pub fn get_next_expiration(&self) -> Result<Option<(f64, i64)>, StashError> {
|
||||
match self.conn.query_row(
|
||||
"SELECT expires_at, id FROM clipboard WHERE expires_at IS NOT NULL \
|
||||
ORDER BY expires_at ASC LIMIT 1",
|
||||
[],
|
||||
|row| Ok((row.get(0)?, row.get(1)?)),
|
||||
) {
|
||||
Ok(result) => Ok(Some(result)),
|
||||
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
|
||||
Err(e) => Err(StashError::Store(e.to_string().into())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set expiration timestamp for an entry
|
||||
pub fn set_expiration(
|
||||
&self,
|
||||
|
|
@ -1338,7 +1358,8 @@ mod tests {
|
|||
fn test_db() -> SqliteClipboardDb {
|
||||
let conn =
|
||||
Connection::open_in_memory().expect("Failed to open in-memory db");
|
||||
SqliteClipboardDb::new(conn).expect("Failed to create test database")
|
||||
SqliteClipboardDb::new(conn, PathBuf::from(":memory:"))
|
||||
.expect("Failed to create test database")
|
||||
}
|
||||
|
||||
fn get_schema_version(conn: &Connection) -> rusqlite::Result<i64> {
|
||||
|
|
@ -1369,7 +1390,8 @@ mod tests {
|
|||
let db_path = temp_dir.path().join("test_fresh.db");
|
||||
let conn = Connection::open(&db_path).expect("Failed to open database");
|
||||
|
||||
let db = SqliteClipboardDb::new(conn).expect("Failed to create database");
|
||||
let db = SqliteClipboardDb::new(conn, PathBuf::from(":memory:"))
|
||||
.expect("Failed to create database");
|
||||
|
||||
assert_eq!(
|
||||
get_schema_version(&db.conn).expect("Failed to get schema version"),
|
||||
|
|
@ -1419,7 +1441,8 @@ mod tests {
|
|||
|
||||
assert_eq!(get_schema_version(&conn).expect("Failed to get version"), 0);
|
||||
|
||||
let db = SqliteClipboardDb::new(conn).expect("Failed to create database");
|
||||
let db = SqliteClipboardDb::new(conn, PathBuf::from(":memory:"))
|
||||
.expect("Failed to create database");
|
||||
|
||||
assert_eq!(
|
||||
get_schema_version(&db.conn)
|
||||
|
|
@ -1461,7 +1484,8 @@ mod tests {
|
|||
)
|
||||
.expect("Failed to insert data");
|
||||
|
||||
let db = SqliteClipboardDb::new(conn).expect("Failed to create database");
|
||||
let db = SqliteClipboardDb::new(conn, PathBuf::from(":memory:"))
|
||||
.expect("Failed to create database");
|
||||
|
||||
assert_eq!(
|
||||
get_schema_version(&db.conn)
|
||||
|
|
@ -1504,7 +1528,8 @@ mod tests {
|
|||
)
|
||||
.expect("Failed to insert data");
|
||||
|
||||
let db = SqliteClipboardDb::new(conn).expect("Failed to create database");
|
||||
let db = SqliteClipboardDb::new(conn, PathBuf::from(":memory:"))
|
||||
.expect("Failed to create database");
|
||||
|
||||
assert_eq!(
|
||||
get_schema_version(&db.conn)
|
||||
|
|
@ -1535,12 +1560,13 @@ mod tests {
|
|||
)
|
||||
.expect("Failed to create table");
|
||||
|
||||
let db = SqliteClipboardDb::new(conn).expect("Failed to create database");
|
||||
let db = SqliteClipboardDb::new(conn, PathBuf::from(":memory:"))
|
||||
.expect("Failed to create database");
|
||||
let version_after_first =
|
||||
get_schema_version(&db.conn).expect("Failed to get version");
|
||||
|
||||
let db2 =
|
||||
SqliteClipboardDb::new(db.conn).expect("Failed to create database again");
|
||||
let db2 = SqliteClipboardDb::new(db.conn, db.db_path)
|
||||
.expect("Failed to create database again");
|
||||
let version_after_second =
|
||||
get_schema_version(&db2.conn).expect("Failed to get version");
|
||||
|
||||
|
|
@ -1553,7 +1579,8 @@ mod tests {
|
|||
let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
|
||||
let db_path = temp_dir.path().join("test_store.db");
|
||||
let conn = Connection::open(&db_path).expect("Failed to open database");
|
||||
let db = SqliteClipboardDb::new(conn).expect("Failed to create database");
|
||||
let db = SqliteClipboardDb::new(conn, PathBuf::from(":memory:"))
|
||||
.expect("Failed to create database");
|
||||
|
||||
let test_data = b"Hello, World!";
|
||||
let cursor = std::io::Cursor::new(test_data.to_vec());
|
||||
|
|
@ -1589,7 +1616,8 @@ mod tests {
|
|||
let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
|
||||
let db_path = temp_dir.path().join("test_copy.db");
|
||||
let conn = Connection::open(&db_path).expect("Failed to open database");
|
||||
let db = SqliteClipboardDb::new(conn).expect("Failed to create database");
|
||||
let db = SqliteClipboardDb::new(conn, PathBuf::from(":memory:"))
|
||||
.expect("Failed to create database");
|
||||
|
||||
let test_data = b"Test content for copy";
|
||||
let cursor = std::io::Cursor::new(test_data.to_vec());
|
||||
|
|
@ -1608,8 +1636,8 @@ mod tests {
|
|||
|
||||
std::thread::sleep(std::time::Duration::from_millis(1100));
|
||||
|
||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||
test_data.hash(&mut hasher);
|
||||
let mut hasher = Fnv1aHasher::new();
|
||||
hasher.write(test_data);
|
||||
let content_hash = hasher.finish() as i64;
|
||||
|
||||
let now = std::time::SystemTime::now()
|
||||
|
|
@ -1670,7 +1698,8 @@ mod tests {
|
|||
)
|
||||
.expect("Failed to insert data");
|
||||
|
||||
let db = SqliteClipboardDb::new(conn).expect("Failed to create database");
|
||||
let db = SqliteClipboardDb::new(conn, PathBuf::from(":memory:"))
|
||||
.expect("Failed to create database");
|
||||
|
||||
assert_eq!(
|
||||
get_schema_version(&db.conn).expect("Failed to get version"),
|
||||
|
|
|
|||
141
src/db/nonblocking.rs
Normal file
141
src/db/nonblocking.rs
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
use std::path::PathBuf;
|
||||
|
||||
use rusqlite::OptionalExtension;
|
||||
|
||||
use crate::db::{ClipboardDb, SqliteClipboardDb, StashError};
|
||||
|
||||
/// Async wrapper for database operations that runs blocking operations
|
||||
/// on a thread pool to avoid blocking the async runtime.
|
||||
///
|
||||
/// Since rusqlite::Connection is not Send, we store the database path
|
||||
/// and open a new connection for each operation.
|
||||
pub struct AsyncClipboardDb {
|
||||
db_path: PathBuf,
|
||||
}
|
||||
|
||||
impl AsyncClipboardDb {
|
||||
pub fn new(db_path: PathBuf) -> Self {
|
||||
Self { db_path }
|
||||
}
|
||||
|
||||
pub async fn store_entry(
|
||||
&self,
|
||||
data: Vec<u8>,
|
||||
max_dedupe_search: u64,
|
||||
max_items: u64,
|
||||
excluded_apps: Option<Vec<String>>,
|
||||
min_size: Option<usize>,
|
||||
max_size: usize,
|
||||
) -> Result<i64, StashError> {
|
||||
let path = self.db_path.clone();
|
||||
blocking::unblock(move || {
|
||||
let db = Self::open_db_internal(&path)?;
|
||||
db.store_entry(
|
||||
std::io::Cursor::new(data),
|
||||
max_dedupe_search,
|
||||
max_items,
|
||||
excluded_apps.as_deref(),
|
||||
min_size,
|
||||
max_size,
|
||||
)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn set_expiration(
|
||||
&self,
|
||||
id: i64,
|
||||
expires_at: f64,
|
||||
) -> Result<(), StashError> {
|
||||
let path = self.db_path.clone();
|
||||
blocking::unblock(move || {
|
||||
let db = Self::open_db_internal(&path)?;
|
||||
db.set_expiration(id, expires_at)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn load_all_expirations(
|
||||
&self,
|
||||
) -> Result<Vec<(f64, i64)>, StashError> {
|
||||
let path = self.db_path.clone();
|
||||
blocking::unblock(move || {
|
||||
let db = Self::open_db_internal(&path)?;
|
||||
let mut stmt = db
|
||||
.conn
|
||||
.prepare(
|
||||
"SELECT expires_at, id FROM clipboard WHERE expires_at IS NOT NULL \
|
||||
AND (is_expired IS NULL OR is_expired = 0) ORDER BY expires_at ASC",
|
||||
)
|
||||
.map_err(|e| StashError::ListDecode(e.to_string().into()))?;
|
||||
|
||||
let mut rows = stmt
|
||||
.query([])
|
||||
.map_err(|e| StashError::ListDecode(e.to_string().into()))?;
|
||||
let mut expirations = Vec::new();
|
||||
|
||||
while let Some(row) = rows
|
||||
.next()
|
||||
.map_err(|e| StashError::ListDecode(e.to_string().into()))?
|
||||
{
|
||||
let exp = row
|
||||
.get::<_, f64>(0)
|
||||
.map_err(|e| StashError::ListDecode(e.to_string().into()))?;
|
||||
let id = row
|
||||
.get::<_, i64>(1)
|
||||
.map_err(|e| StashError::ListDecode(e.to_string().into()))?;
|
||||
expirations.push((exp, id));
|
||||
}
|
||||
Ok(expirations)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_content_hash(
|
||||
&self,
|
||||
id: i64,
|
||||
) -> Result<Option<i64>, StashError> {
|
||||
let path = self.db_path.clone();
|
||||
blocking::unblock(move || {
|
||||
let db = Self::open_db_internal(&path)?;
|
||||
let result: Option<i64> = db
|
||||
.conn
|
||||
.query_row(
|
||||
"SELECT content_hash FROM clipboard WHERE id = ?1",
|
||||
[id],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.optional()
|
||||
.map_err(|e| StashError::ListDecode(e.to_string().into()))?;
|
||||
Ok(result)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn mark_expired(&self, id: i64) -> Result<(), StashError> {
|
||||
let path = self.db_path.clone();
|
||||
blocking::unblock(move || {
|
||||
let db = Self::open_db_internal(&path)?;
|
||||
db.conn
|
||||
.execute("UPDATE clipboard SET is_expired = 1 WHERE id = ?1", [id])
|
||||
.map_err(|e| StashError::Store(e.to_string().into()))?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
fn open_db_internal(path: &PathBuf) -> Result<SqliteClipboardDb, StashError> {
|
||||
let conn = rusqlite::Connection::open(path).map_err(|e| {
|
||||
StashError::Store(format!("Failed to open database: {e}").into())
|
||||
})?;
|
||||
SqliteClipboardDb::new(conn, path.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for AsyncClipboardDb {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
db_path: self.db_path.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -228,7 +228,7 @@ fn main() -> color_eyre::eyre::Result<()> {
|
|||
}
|
||||
|
||||
let conn = rusqlite::Connection::open(&db_path)?;
|
||||
let db = db::SqliteClipboardDb::new(conn)?;
|
||||
let db = db::SqliteClipboardDb::new(conn, db_path)?;
|
||||
|
||||
match cli.command {
|
||||
Some(Command::Store) => {
|
||||
|
|
@ -476,7 +476,8 @@ fn main() -> color_eyre::eyre::Result<()> {
|
|||
&mime_type,
|
||||
cli.min_size,
|
||||
cli.max_size,
|
||||
);
|
||||
)
|
||||
.await;
|
||||
},
|
||||
|
||||
None => {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue