diff --git a/crates/tranquil-pds/src/api/repo/import.rs b/crates/tranquil-pds/src/api/repo/import.rs index 6881360..498d9a5 100644 --- a/crates/tranquil-pds/src/api/repo/import.rs +++ b/crates/tranquil-pds/src/api/repo/import.rs @@ -190,6 +190,7 @@ pub async fn import_repo( .ok() .and_then(|s| s.parse().ok()) .unwrap_or(DEFAULT_MAX_BLOCKS); + let _write_lock = state.repo_write_locks.lock(user_id).await; match apply_import(&state.repo_repo, user_id, root, blocks.clone(), max_blocks).await { Ok(import_result) => { info!( diff --git a/crates/tranquil-pds/src/api/repo/record/batch.rs b/crates/tranquil-pds/src/api/repo/record/batch.rs index da71208..a2878b4 100644 --- a/crates/tranquil-pds/src/api/repo/record/batch.rs +++ b/crates/tranquil-pds/src/api/repo/record/batch.rs @@ -326,6 +326,9 @@ pub async fn apply_writes( .ok() .flatten() .ok_or_else(|| ApiError::InternalError(Some("User not found".into())))?; + + let _write_lock = state.repo_write_locks.lock(user_id).await; + let root_cid_str = state .repo_repo .get_repo_root_cid_by_user_id(user_id) diff --git a/crates/tranquil-pds/src/api/repo/record/delete.rs b/crates/tranquil-pds/src/api/repo/record/delete.rs index 032f537..c741975 100644 --- a/crates/tranquil-pds/src/api/repo/record/delete.rs +++ b/crates/tranquil-pds/src/api/repo/record/delete.rs @@ -1,5 +1,7 @@ use crate::api::error::ApiError; -use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log}; +use crate::api::repo::record::utils::{ + CommitParams, RecordOp, commit_and_log, get_current_root_cid, +}; use crate::api::repo::record::write::{CommitInfo, prepare_repo_write}; use crate::auth::{Active, Auth, VerifyScope}; use crate::cid_types::CommitCid; @@ -56,9 +58,11 @@ pub async fn delete_record( let did = repo_auth.did; let user_id = repo_auth.user_id; - let current_root_cid = repo_auth.current_root_cid; let controller_did = repo_auth.controller_did; + let _write_lock = state.repo_write_locks.lock(user_id).await; + let current_root_cid = 
get_current_root_cid(&state, user_id).await?; + if let Some(swap_commit) = &input.swap_commit && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid) { @@ -238,6 +242,8 @@ pub async fn delete_record_internal( collection: &Nsid, rkey: &Rkey, ) -> Result<(), String> { + let _write_lock = state.repo_write_locks.lock(user_id).await; + let root_cid_str = state .repo_repo .get_repo_root_cid_by_user_id(user_id) diff --git a/crates/tranquil-pds/src/api/repo/record/utils.rs b/crates/tranquil-pds/src/api/repo/record/utils.rs index 87a6d72..2215c0f 100644 --- a/crates/tranquil-pds/src/api/repo/record/utils.rs +++ b/crates/tranquil-pds/src/api/repo/record/utils.rs @@ -1,3 +1,5 @@ +use crate::api::error::ApiError; +use crate::cid_types::CommitCid; use crate::state::AppState; use crate::types::{Did, Handle, Nsid, Rkey}; use bytes::Bytes; @@ -8,9 +10,24 @@ use jacquard_repo::storage::BlockStore; use k256::ecdsa::SigningKey; use serde_json::{Value, json}; use std::str::FromStr; +use tracing::error; use tranquil_db_traits::SequenceNumber; use uuid::Uuid; +pub async fn get_current_root_cid(state: &AppState, user_id: Uuid) -> Result<CommitCid, ApiError> { + let root_cid_str = state + .repo_repo + .get_repo_root_cid_by_user_id(user_id) + .await + .map_err(|e| { + error!("DB error fetching repo root: {}", e); + ApiError::InternalError(None) + })? + .ok_or_else(|| ApiError::InternalError(Some("Repo root not found".into())))?; + CommitCid::from_str(&root_cid_str) + .map_err(|_| ApiError::InternalError(Some("Invalid repo root CID".into()))) +} + pub fn extract_blob_cids(record: &Value) -> Vec { let mut blobs = Vec::new(); extract_blob_cids_recursive(record, &mut blobs); @@ -328,6 +345,9 @@ pub async fn create_record_internal( .await .map_err(|e| format!("DB error: {}", e))? 
.ok_or_else(|| "User not found".to_string())?; + + let _write_lock = state.repo_write_locks.lock(user_id).await; + let root_cid_link = state .repo_repo .get_repo_root_cid_by_user_id(user_id) diff --git a/crates/tranquil-pds/src/api/repo/record/write.rs b/crates/tranquil-pds/src/api/repo/record/write.rs index f340ffe..d5e13c6 100644 --- a/crates/tranquil-pds/src/api/repo/record/write.rs +++ b/crates/tranquil-pds/src/api/repo/record/write.rs @@ -3,6 +3,7 @@ use super::validation_mode::{ValidationMode, deserialize_validation_mode}; use crate::api::error::ApiError; use crate::api::repo::record::utils::{ CommitParams, RecordOp, commit_and_log, extract_backlinks, extract_blob_cids, + get_current_root_cid, }; use crate::auth::{ Active, Auth, RepoScopeAction, ScopeVerified, VerifyScope, require_not_migrated, @@ -31,7 +32,6 @@ use uuid::Uuid; pub struct RepoWriteAuth { pub did: Did, pub user_id: Uuid, - pub current_root_cid: CommitCid, pub is_oauth: bool, pub scope: Option, pub controller_did: Option, @@ -62,24 +62,10 @@ pub async fn prepare_repo_write( ApiError::InternalError(None).into_response() })? .ok_or_else(|| ApiError::InternalError(Some("User not found".into())).into_response())?; - let root_cid_str = state - .repo_repo - .get_repo_root_cid_by_user_id(user_id) - .await - .map_err(|e| { - error!("DB error fetching repo root: {}", e); - ApiError::InternalError(None).into_response() - })? 
- .ok_or_else(|| { - ApiError::InternalError(Some("Repo root not found".into())).into_response() - })?; - let current_root_cid = CommitCid::from_str(&root_cid_str).map_err(|_| { - ApiError::InternalError(Some("Invalid repo root CID".into())).into_response() - })?; + Ok(RepoWriteAuth { did: principal_did.into_did(), user_id, - current_root_cid, is_oauth: user.is_oauth(), scope: user.scope.clone(), controller_did: scope_proof.controller_did().map(|c| c.into_did()), @@ -130,9 +116,11 @@ pub async fn create_record( let did = repo_auth.did; let user_id = repo_auth.user_id; - let current_root_cid = repo_auth.current_root_cid; let controller_did = repo_auth.controller_did; + let _write_lock = state.repo_write_locks.lock(user_id).await; + let current_root_cid = get_current_root_cid(&state, user_id).await?; + if let Some(swap_commit) = &input.swap_commit && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid) { @@ -433,9 +421,11 @@ pub async fn put_record( let did = repo_auth.did; let user_id = repo_auth.user_id; - let current_root_cid = repo_auth.current_root_cid; let controller_did = repo_auth.controller_did; + let _write_lock = state.repo_write_locks.lock(user_id).await; + let current_root_cid = get_current_root_cid(&state, user_id).await?; + if let Some(swap_commit) = &input.swap_commit && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid) { diff --git a/crates/tranquil-pds/src/lib.rs b/crates/tranquil-pds/src/lib.rs index 78ef4f3..fcb9f1e 100644 --- a/crates/tranquil-pds/src/lib.rs +++ b/crates/tranquil-pds/src/lib.rs @@ -16,6 +16,7 @@ pub mod oauth; pub mod plc; pub mod rate_limit; pub mod repo; +pub mod repo_write_lock; pub mod scheduled; pub mod sso; pub mod state; diff --git a/crates/tranquil-pds/src/repo_write_lock.rs b/crates/tranquil-pds/src/repo_write_lock.rs new file mode 100644 index 0000000..7c77fa5 --- /dev/null +++ b/crates/tranquil-pds/src/repo_write_lock.rs @@ -0,0 +1,180 @@ +use std::collections::HashMap; +use 
std::sync::Arc; +use std::time::Duration; +use tokio::sync::{Mutex, OwnedMutexGuard, RwLock}; +use uuid::Uuid; + +const SWEEP_INTERVAL: Duration = Duration::from_secs(300); + +pub struct RepoWriteLocks { + locks: Arc<RwLock<HashMap<Uuid, Arc<Mutex<()>>>>>, +} + +impl Default for RepoWriteLocks { + fn default() -> Self { + Self::new() + } +} + +impl RepoWriteLocks { + pub fn new() -> Self { + let locks = Arc::new(RwLock::new(HashMap::new())); + let sweep_locks = Arc::clone(&locks); + tokio::spawn(async move { + sweep_loop(sweep_locks).await; + }); + Self { locks } + } + + pub async fn lock(&self, user_id: Uuid) -> OwnedMutexGuard<()> { + let mutex = { + let read_guard = self.locks.read().await; + read_guard.get(&user_id).cloned() + }; + + match mutex { + Some(m) => m.lock_owned().await, + None => { + let mut write_guard = self.locks.write().await; + let mutex = write_guard + .entry(user_id) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone(); + drop(write_guard); + mutex.lock_owned().await + } + } + } +} + +async fn sweep_loop(locks: Arc<RwLock<HashMap<Uuid, Arc<Mutex<()>>>>>) { + tokio::time::sleep(SWEEP_INTERVAL).await; + let mut write_guard = locks.write().await; + let before = write_guard.len(); + write_guard.retain(|_, mutex| Arc::strong_count(mutex) > 1); + let evicted = before - write_guard.len(); + if evicted > 0 { + tracing::debug!( + evicted, + remaining = write_guard.len(), + "repo write lock sweep" + ); + } + drop(write_guard); + Box::pin(sweep_loop(locks)).await; +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicU32, Ordering}; + use std::time::Duration; + + #[tokio::test] + async fn test_locks_serialize_same_user() { + let locks = Arc::new(RepoWriteLocks::new()); + let user_id = Uuid::new_v4(); + let counter = Arc::new(AtomicU32::new(0)); + let max_concurrent = Arc::new(AtomicU32::new(0)); + + let handles: Vec<_> = (0..10) + .map(|_| { + let locks = locks.clone(); + let counter = counter.clone(); + let max_concurrent = max_concurrent.clone(); + + tokio::spawn(async move { + let _guard 
= locks.lock(user_id).await; + let current = counter.fetch_add(1, Ordering::SeqCst) + 1; + max_concurrent.fetch_max(current, Ordering::SeqCst); + tokio::time::sleep(Duration::from_millis(1)).await; + counter.fetch_sub(1, Ordering::SeqCst); + }) + }) + .collect(); + + futures::future::join_all(handles).await; + + assert_eq!( + max_concurrent.load(Ordering::SeqCst), + 1, + "Only one task should hold the lock at a time for same user" + ); + } + + #[tokio::test] + async fn test_different_users_can_run_concurrently() { + let locks = Arc::new(RepoWriteLocks::new()); + let user1 = Uuid::new_v4(); + let user2 = Uuid::new_v4(); + let concurrent_count = Arc::new(AtomicU32::new(0)); + let max_concurrent = Arc::new(AtomicU32::new(0)); + + let locks1 = locks.clone(); + let count1 = concurrent_count.clone(); + let max1 = max_concurrent.clone(); + let handle1 = tokio::spawn(async move { + let _guard = locks1.lock(user1).await; + let current = count1.fetch_add(1, Ordering::SeqCst) + 1; + max1.fetch_max(current, Ordering::SeqCst); + tokio::time::sleep(Duration::from_millis(50)).await; + count1.fetch_sub(1, Ordering::SeqCst); + }); + + tokio::time::sleep(Duration::from_millis(10)).await; + + let locks2 = locks.clone(); + let count2 = concurrent_count.clone(); + let max2 = max_concurrent.clone(); + let handle2 = tokio::spawn(async move { + let _guard = locks2.lock(user2).await; + let current = count2.fetch_add(1, Ordering::SeqCst) + 1; + max2.fetch_max(current, Ordering::SeqCst); + tokio::time::sleep(Duration::from_millis(50)).await; + count2.fetch_sub(1, Ordering::SeqCst); + }); + + handle1.await.unwrap(); + handle2.await.unwrap(); + + assert_eq!( + max_concurrent.load(Ordering::SeqCst), + 2, + "Different users should be able to run concurrently" + ); + } + + #[tokio::test] + async fn test_sweep_evicts_idle_entries() { + let locks = Arc::new(RwLock::new(HashMap::new())); + let user_id = Uuid::new_v4(); + + { + let mut write_guard = locks.write().await; + write_guard.insert(user_id, 
Arc::new(Mutex::new(()))); + } + + assert_eq!(locks.read().await.len(), 1); + + let mut write_guard = locks.write().await; + write_guard.retain(|_, mutex| Arc::strong_count(mutex) > 1); + assert_eq!(write_guard.len(), 0, "Idle entry should be evicted"); + } + + #[tokio::test] + async fn test_sweep_preserves_active_entries() { + let locks = Arc::new(RwLock::new(HashMap::new())); + let user_id = Uuid::new_v4(); + let active_mutex = Arc::new(Mutex::new(())); + let _held_ref = active_mutex.clone(); + + { + let mut write_guard = locks.write().await; + write_guard.insert(user_id, active_mutex); + } + + let mut write_guard = locks.write().await; + write_guard.retain(|_, mutex| Arc::strong_count(mutex) > 1); + assert_eq!(write_guard.len(), 1, "Active entry should be preserved"); + } +} diff --git a/crates/tranquil-pds/src/state.rs b/crates/tranquil-pds/src/state.rs index 36fc5b9..af7bab0 100644 --- a/crates/tranquil-pds/src/state.rs +++ b/crates/tranquil-pds/src/state.rs @@ -5,6 +5,7 @@ use crate::circuit_breaker::CircuitBreakers; use crate::config::AuthConfig; use crate::rate_limit::RateLimiters; use crate::repo::PostgresBlockStore; +use crate::repo_write_lock::RepoWriteLocks; use crate::sso::{SsoConfig, SsoManager}; use crate::storage::{BackupStorage, BlobStorage, create_backup_storage, create_blob_storage}; use crate::sync::firehose::SequencedEvent; @@ -38,6 +39,7 @@ pub struct AppState { pub backup_storage: Option>, pub firehose_tx: broadcast::Sender<SequencedEvent>, pub rate_limiters: Arc<RateLimiters>, + pub repo_write_locks: Arc<RepoWriteLocks>, pub circuit_breakers: Arc<CircuitBreakers>, pub cache: Arc, pub distributed_rate_limiter: Arc, @@ -181,6 +183,7 @@ impl AppState { let (firehose_tx, _) = broadcast::channel(firehose_buffer_size); let rate_limiters = Arc::new(RateLimiters::new()); + let repo_write_locks = Arc::new(RepoWriteLocks::new()); + let circuit_breakers = Arc::new(CircuitBreakers::new()); let (cache, distributed_rate_limiter) = create_cache().await; let did_resolver = Arc::new(DidResolver::new()); @@ -209,6 +212,7 
@@ impl AppState { backup_storage, firehose_tx, rate_limiters, + repo_write_locks, circuit_breakers, cache, distributed_rate_limiter,