mirror of https://tangled.org/tranquil.farm/tranquil-pds (synced 2026-02-08 21:30:08 +00:00)
fix: concurrent perf improvement
@@ -190,6 +190,7 @@ pub async fn import_repo(
         .ok()
         .and_then(|s| s.parse().ok())
         .unwrap_or(DEFAULT_MAX_BLOCKS);
+    let _write_lock = state.repo_write_locks.lock(user_id).await;
     match apply_import(&state.repo_repo, user_id, root, blocks.clone(), max_blocks).await {
         Ok(import_result) => {
             info!(

@@ -326,6 +326,9 @@ pub async fn apply_writes(
         .ok()
         .flatten()
         .ok_or_else(|| ApiError::InternalError(Some("User not found".into())))?;
+
+    let _write_lock = state.repo_write_locks.lock(user_id).await;
+
     let root_cid_str = state
         .repo_repo
         .get_repo_root_cid_by_user_id(user_id)
@@ -1,5 +1,7 @@
 use crate::api::error::ApiError;
-use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log};
+use crate::api::repo::record::utils::{
+    CommitParams, RecordOp, commit_and_log, get_current_root_cid,
+};
 use crate::api::repo::record::write::{CommitInfo, prepare_repo_write};
 use crate::auth::{Active, Auth, VerifyScope};
 use crate::cid_types::CommitCid;

@@ -56,9 +58,11 @@ pub async fn delete_record(

     let did = repo_auth.did;
     let user_id = repo_auth.user_id;
-    let current_root_cid = repo_auth.current_root_cid;
     let controller_did = repo_auth.controller_did;

+    let _write_lock = state.repo_write_locks.lock(user_id).await;
+    let current_root_cid = get_current_root_cid(&state, user_id).await?;
+
     if let Some(swap_commit) = &input.swap_commit
         && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid)
     {

@@ -238,6 +242,8 @@ pub async fn delete_record_internal(
     collection: &Nsid,
     rkey: &Rkey,
 ) -> Result<(), String> {
+    let _write_lock = state.repo_write_locks.lock(user_id).await;
+
     let root_cid_str = state
         .repo_repo
         .get_repo_root_cid_by_user_id(user_id)
@@ -1,3 +1,5 @@
 use crate::api::error::ApiError;
+use crate::cid_types::CommitCid;
+use crate::state::AppState;
 use crate::types::{Did, Handle, Nsid, Rkey};
 use bytes::Bytes;

@@ -8,9 +10,24 @@ use jacquard_repo::storage::BlockStore;
 use k256::ecdsa::SigningKey;
 use serde_json::{Value, json};
+use std::str::FromStr;
 use tracing::error;
 use tranquil_db_traits::SequenceNumber;
 use uuid::Uuid;

+pub async fn get_current_root_cid(state: &AppState, user_id: Uuid) -> Result<CommitCid, ApiError> {
+    let root_cid_str = state
+        .repo_repo
+        .get_repo_root_cid_by_user_id(user_id)
+        .await
+        .map_err(|e| {
+            error!("DB error fetching repo root: {}", e);
+            ApiError::InternalError(None)
+        })?
+        .ok_or_else(|| ApiError::InternalError(Some("Repo root not found".into())))?;
+    CommitCid::from_str(&root_cid_str)
+        .map_err(|_| ApiError::InternalError(Some("Invalid repo root CID".into())))
+}
+
 pub fn extract_blob_cids(record: &Value) -> Vec<String> {
     let mut blobs = Vec::new();
     extract_blob_cids_recursive(record, &mut blobs);
@@ -328,6 +345,9 @@ pub async fn create_record_internal(
         .await
         .map_err(|e| format!("DB error: {}", e))?
         .ok_or_else(|| "User not found".to_string())?;
+
+    let _write_lock = state.repo_write_locks.lock(user_id).await;
+
     let root_cid_link = state
         .repo_repo
         .get_repo_root_cid_by_user_id(user_id)
@@ -3,6 +3,7 @@ use super::validation_mode::{ValidationMode, deserialize_validation_mode};
 use crate::api::error::ApiError;
 use crate::api::repo::record::utils::{
     CommitParams, RecordOp, commit_and_log, extract_backlinks, extract_blob_cids,
+    get_current_root_cid,
 };
 use crate::auth::{
     Active, Auth, RepoScopeAction, ScopeVerified, VerifyScope, require_not_migrated,

@@ -31,7 +32,6 @@ use uuid::Uuid;
 pub struct RepoWriteAuth {
     pub did: Did,
     pub user_id: Uuid,
-    pub current_root_cid: CommitCid,
     pub is_oauth: bool,
     pub scope: Option<String>,
     pub controller_did: Option<Did>,

@@ -62,24 +62,10 @@ pub async fn prepare_repo_write<A: RepoScopeAction>(
             ApiError::InternalError(None).into_response()
         })?
         .ok_or_else(|| ApiError::InternalError(Some("User not found".into())).into_response())?;
-    let root_cid_str = state
-        .repo_repo
-        .get_repo_root_cid_by_user_id(user_id)
-        .await
-        .map_err(|e| {
-            error!("DB error fetching repo root: {}", e);
-            ApiError::InternalError(None).into_response()
-        })?
-        .ok_or_else(|| {
-            ApiError::InternalError(Some("Repo root not found".into())).into_response()
-        })?;
-    let current_root_cid = CommitCid::from_str(&root_cid_str).map_err(|_| {
-        ApiError::InternalError(Some("Invalid repo root CID".into())).into_response()
-    })?;
+
     Ok(RepoWriteAuth {
         did: principal_did.into_did(),
         user_id,
-        current_root_cid,
         is_oauth: user.is_oauth(),
         scope: user.scope.clone(),
         controller_did: scope_proof.controller_did().map(|c| c.into_did()),

@@ -130,9 +116,11 @@ pub async fn create_record(

     let did = repo_auth.did;
     let user_id = repo_auth.user_id;
-    let current_root_cid = repo_auth.current_root_cid;
     let controller_did = repo_auth.controller_did;

+    let _write_lock = state.repo_write_locks.lock(user_id).await;
+    let current_root_cid = get_current_root_cid(&state, user_id).await?;
+
     if let Some(swap_commit) = &input.swap_commit
         && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid)
     {

@@ -433,9 +421,11 @@ pub async fn put_record(

     let did = repo_auth.did;
     let user_id = repo_auth.user_id;
-    let current_root_cid = repo_auth.current_root_cid;
     let controller_did = repo_auth.controller_did;

+    let _write_lock = state.repo_write_locks.lock(user_id).await;
+    let current_root_cid = get_current_root_cid(&state, user_id).await?;
+
     if let Some(swap_commit) = &input.swap_commit
         && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid)
     {

@@ -16,6 +16,7 @@ pub mod oauth;
 pub mod plc;
 pub mod rate_limit;
 pub mod repo;
+pub mod repo_write_lock;
 pub mod scheduled;
 pub mod sso;
 pub mod state;
crates/tranquil-pds/src/repo_write_lock.rs (new file, 180 lines)
@@ -0,0 +1,180 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, OwnedMutexGuard, RwLock};
use uuid::Uuid;

const SWEEP_INTERVAL: Duration = Duration::from_secs(300);

pub struct RepoWriteLocks {
    locks: Arc<RwLock<HashMap<Uuid, Arc<Mutex<()>>>>>,
}

impl Default for RepoWriteLocks {
    fn default() -> Self {
        Self::new()
    }
}

impl RepoWriteLocks {
    pub fn new() -> Self {
        let locks = Arc::new(RwLock::new(HashMap::new()));
        let sweep_locks = Arc::clone(&locks);
        tokio::spawn(async move {
            sweep_loop(sweep_locks).await;
        });
        Self { locks }
    }

    pub async fn lock(&self, user_id: Uuid) -> OwnedMutexGuard<()> {
        let mutex = {
            let read_guard = self.locks.read().await;
            read_guard.get(&user_id).cloned()
        };

        match mutex {
            Some(m) => m.lock_owned().await,
            None => {
                let mut write_guard = self.locks.write().await;
                let mutex = write_guard
                    .entry(user_id)
                    .or_insert_with(|| Arc::new(Mutex::new(())))
                    .clone();
                drop(write_guard);
                mutex.lock_owned().await
            }
        }
    }
}

async fn sweep_loop(locks: Arc<RwLock<HashMap<Uuid, Arc<Mutex<()>>>>>) {
    tokio::time::sleep(SWEEP_INTERVAL).await;
    let mut write_guard = locks.write().await;
    let before = write_guard.len();
    write_guard.retain(|_, mutex| Arc::strong_count(mutex) > 1);
    let evicted = before - write_guard.len();
    if evicted > 0 {
        tracing::debug!(
            evicted,
            remaining = write_guard.len(),
            "repo write lock sweep"
        );
    }
    drop(write_guard);
    Box::pin(sweep_loop(locks)).await;
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    #[tokio::test]
    async fn test_locks_serialize_same_user() {
        let locks = Arc::new(RepoWriteLocks::new());
        let user_id = Uuid::new_v4();
        let counter = Arc::new(AtomicU32::new(0));
        let max_concurrent = Arc::new(AtomicU32::new(0));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let locks = locks.clone();
                let counter = counter.clone();
                let max_concurrent = max_concurrent.clone();

                tokio::spawn(async move {
                    let _guard = locks.lock(user_id).await;
                    let current = counter.fetch_add(1, Ordering::SeqCst) + 1;
                    max_concurrent.fetch_max(current, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(1)).await;
                    counter.fetch_sub(1, Ordering::SeqCst);
                })
            })
            .collect();

        futures::future::join_all(handles).await;

        assert_eq!(
            max_concurrent.load(Ordering::SeqCst),
            1,
            "Only one task should hold the lock at a time for same user"
        );
    }

    #[tokio::test]
    async fn test_different_users_can_run_concurrently() {
        let locks = Arc::new(RepoWriteLocks::new());
        let user1 = Uuid::new_v4();
        let user2 = Uuid::new_v4();
        let concurrent_count = Arc::new(AtomicU32::new(0));
        let max_concurrent = Arc::new(AtomicU32::new(0));

        let locks1 = locks.clone();
        let count1 = concurrent_count.clone();
        let max1 = max_concurrent.clone();
        let handle1 = tokio::spawn(async move {
            let _guard = locks1.lock(user1).await;
            let current = count1.fetch_add(1, Ordering::SeqCst) + 1;
            max1.fetch_max(current, Ordering::SeqCst);
            tokio::time::sleep(Duration::from_millis(50)).await;
            count1.fetch_sub(1, Ordering::SeqCst);
        });

        tokio::time::sleep(Duration::from_millis(10)).await;

        let locks2 = locks.clone();
        let count2 = concurrent_count.clone();
        let max2 = max_concurrent.clone();
        let handle2 = tokio::spawn(async move {
            let _guard = locks2.lock(user2).await;
            let current = count2.fetch_add(1, Ordering::SeqCst) + 1;
            max2.fetch_max(current, Ordering::SeqCst);
            tokio::time::sleep(Duration::from_millis(50)).await;
            count2.fetch_sub(1, Ordering::SeqCst);
        });

        handle1.await.unwrap();
        handle2.await.unwrap();

        assert_eq!(
            max_concurrent.load(Ordering::SeqCst),
            2,
            "Different users should be able to run concurrently"
        );
    }

    #[tokio::test]
    async fn test_sweep_evicts_idle_entries() {
        let locks = Arc::new(RwLock::new(HashMap::new()));
        let user_id = Uuid::new_v4();

        {
            let mut write_guard = locks.write().await;
            write_guard.insert(user_id, Arc::new(Mutex::new(())));
        }

        assert_eq!(locks.read().await.len(), 1);

        let mut write_guard = locks.write().await;
        write_guard.retain(|_, mutex| Arc::strong_count(mutex) > 1);
        assert_eq!(write_guard.len(), 0, "Idle entry should be evicted");
    }

    #[tokio::test]
    async fn test_sweep_preserves_active_entries() {
        let locks = Arc::new(RwLock::new(HashMap::new()));
        let user_id = Uuid::new_v4();
        let active_mutex = Arc::new(Mutex::new(()));
        let _held_ref = active_mutex.clone();

        {
            let mut write_guard = locks.write().await;
            write_guard.insert(user_id, active_mutex);
        }

        let mut write_guard = locks.write().await;
        write_guard.retain(|_, mutex| Arc::strong_count(mutex) > 1);
        assert_eq!(write_guard.len(), 1, "Active entry should be preserved");
    }
}
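Usage note: the sketch below shows how a caller is expected to hold the per-user guard around a repo write so that writes for the same user are serialized while other users proceed in parallel. It is illustrative only; the crate path tranquil_pds::repo_write_lock, the do_repo_write helper, and the sleep standing in for real commit work are assumptions, not part of the commit.

// Minimal caller sketch for the RepoWriteLocks API added above.
// Assumed crate path; do_repo_write is a hypothetical stand-in for a repo write.
use std::sync::Arc;
use tokio::time::{Duration, sleep};
use tranquil_pds::repo_write_lock::RepoWriteLocks;
use uuid::Uuid;

async fn do_repo_write(locks: &RepoWriteLocks, user_id: Uuid) {
    // The guard is held until it is dropped at the end of this scope,
    // serializing all writers for this user_id.
    let _guard = locks.lock(user_id).await;
    // ... read the current root, apply the write, persist the new root ...
    sleep(Duration::from_millis(1)).await;
}

#[tokio::main]
async fn main() {
    // new() spawns the background sweep task, so it must run inside a Tokio runtime.
    let locks = Arc::new(RepoWriteLocks::new());
    let user = Uuid::new_v4();

    // Two writes for the same user run strictly one after the other.
    let a = tokio::spawn({
        let locks = locks.clone();
        async move { do_repo_write(&locks, user).await }
    });
    let b = tokio::spawn({
        let locks = locks.clone();
        async move { do_repo_write(&locks, user).await }
    });
    a.await.unwrap();
    b.await.unwrap();

    // A write for a different user is not blocked by either of them.
    do_repo_write(&locks, Uuid::new_v4()).await;
}

The fast path in lock() only takes the RwLock read guard and falls back to the write guard to insert a fresh mutex, and the sweeper evicts an entry only when its Arc strong count has dropped to the map's own reference, so a mutex that any task still holds or is waiting on survives the sweep.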
@@ -5,6 +5,7 @@ use crate::circuit_breaker::CircuitBreakers;
 use crate::config::AuthConfig;
 use crate::rate_limit::RateLimiters;
 use crate::repo::PostgresBlockStore;
+use crate::repo_write_lock::RepoWriteLocks;
 use crate::sso::{SsoConfig, SsoManager};
 use crate::storage::{BackupStorage, BlobStorage, create_backup_storage, create_blob_storage};
 use crate::sync::firehose::SequencedEvent;

@@ -38,6 +39,7 @@ pub struct AppState {
     pub backup_storage: Option<Arc<dyn BackupStorage>>,
     pub firehose_tx: broadcast::Sender<SequencedEvent>,
     pub rate_limiters: Arc<RateLimiters>,
+    pub repo_write_locks: Arc<RepoWriteLocks>,
     pub circuit_breakers: Arc<CircuitBreakers>,
     pub cache: Arc<dyn Cache>,
     pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>,

@@ -181,6 +183,7 @@ impl AppState {

         let (firehose_tx, _) = broadcast::channel(firehose_buffer_size);
         let rate_limiters = Arc::new(RateLimiters::new());
+        let repo_write_locks = Arc::new(RepoWriteLocks::new());
         let circuit_breakers = Arc::new(CircuitBreakers::new());
         let (cache, distributed_rate_limiter) = create_cache().await;
         let did_resolver = Arc::new(DidResolver::new());

@@ -209,6 +212,7 @@ impl AppState {
             backup_storage,
             firehose_tx,
             rate_limiters,
+            repo_write_locks,
             circuit_breakers,
             cache,
             distributed_rate_limiter,