fix: improve concurrent repo write performance
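
Serialize repository writes with a per-user async lock so concurrent writes to the same repo no longer race each other, while writes for different users still run in parallel. The current root CID is now read under the lock (get_current_root_cid) rather than during auth, closing the window between the swap_commit check and the commit itself, and a background sweeper evicts idle lock entries every five minutes.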

lewis
2026-02-03 18:33:26 +01:00
committed by Tangled
parent 442ca1434f
commit dbc81a6416
8 changed files with 225 additions and 20 deletions

View File

@@ -190,6 +190,7 @@ pub async fn import_repo(
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(DEFAULT_MAX_BLOCKS);
let _write_lock = state.repo_write_locks.lock(user_id).await;
match apply_import(&state.repo_repo, user_id, root, blocks.clone(), max_blocks).await {
Ok(import_result) => {
info!(

View File

@@ -326,6 +326,9 @@ pub async fn apply_writes(
.ok()
.flatten()
.ok_or_else(|| ApiError::InternalError(Some("User not found".into())))?;
let _write_lock = state.repo_write_locks.lock(user_id).await;
let root_cid_str = state
.repo_repo
.get_repo_root_cid_by_user_id(user_id)

View File

@@ -1,5 +1,7 @@
use crate::api::error::ApiError;
use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log};
use crate::api::repo::record::utils::{
CommitParams, RecordOp, commit_and_log, get_current_root_cid,
};
use crate::api::repo::record::write::{CommitInfo, prepare_repo_write};
use crate::auth::{Active, Auth, VerifyScope};
use crate::cid_types::CommitCid;
@@ -56,9 +58,11 @@ pub async fn delete_record(
let did = repo_auth.did;
let user_id = repo_auth.user_id;
let current_root_cid = repo_auth.current_root_cid;
let controller_did = repo_auth.controller_did;
let _write_lock = state.repo_write_locks.lock(user_id).await;
let current_root_cid = get_current_root_cid(&state, user_id).await?;
if let Some(swap_commit) = &input.swap_commit
&& CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid)
{
@@ -238,6 +242,8 @@ pub async fn delete_record_internal(
collection: &Nsid,
rkey: &Rkey,
) -> Result<(), String> {
let _write_lock = state.repo_write_locks.lock(user_id).await;
let root_cid_str = state
.repo_repo
.get_repo_root_cid_by_user_id(user_id)

View File

@@ -1,3 +1,5 @@
use crate::api::error::ApiError;
use crate::cid_types::CommitCid;
use crate::state::AppState;
use crate::types::{Did, Handle, Nsid, Rkey};
use bytes::Bytes;
@@ -8,9 +10,24 @@ use jacquard_repo::storage::BlockStore;
use k256::ecdsa::SigningKey;
use serde_json::{Value, json};
use std::str::FromStr;
use tracing::error;
use tranquil_db_traits::SequenceNumber;
use uuid::Uuid;
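
/// Reads the user's current repo root commit CID from the database and parses it.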
pub async fn get_current_root_cid(state: &AppState, user_id: Uuid) -> Result<CommitCid, ApiError> {
let root_cid_str = state
.repo_repo
.get_repo_root_cid_by_user_id(user_id)
.await
.map_err(|e| {
error!("DB error fetching repo root: {}", e);
ApiError::InternalError(None)
})?
.ok_or_else(|| ApiError::InternalError(Some("Repo root not found".into())))?;
CommitCid::from_str(&root_cid_str)
.map_err(|_| ApiError::InternalError(Some("Invalid repo root CID".into())))
}
pub fn extract_blob_cids(record: &Value) -> Vec<String> {
let mut blobs = Vec::new();
extract_blob_cids_recursive(record, &mut blobs);
@@ -328,6 +345,9 @@ pub async fn create_record_internal(
.await
.map_err(|e| format!("DB error: {}", e))?
.ok_or_else(|| "User not found".to_string())?;
let _write_lock = state.repo_write_locks.lock(user_id).await;
let root_cid_link = state
.repo_repo
.get_repo_root_cid_by_user_id(user_id)

View File

@@ -3,6 +3,7 @@ use super::validation_mode::{ValidationMode, deserialize_validation_mode};
use crate::api::error::ApiError;
use crate::api::repo::record::utils::{
CommitParams, RecordOp, commit_and_log, extract_backlinks, extract_blob_cids,
get_current_root_cid,
};
use crate::auth::{
Active, Auth, RepoScopeAction, ScopeVerified, VerifyScope, require_not_migrated,
@@ -31,7 +32,6 @@ use uuid::Uuid;
pub struct RepoWriteAuth {
pub did: Did,
pub user_id: Uuid,
pub current_root_cid: CommitCid,
pub is_oauth: bool,
pub scope: Option<String>,
pub controller_did: Option<Did>,
@@ -62,24 +62,10 @@ pub async fn prepare_repo_write<A: RepoScopeAction>(
ApiError::InternalError(None).into_response()
})?
.ok_or_else(|| ApiError::InternalError(Some("User not found".into())).into_response())?;
let root_cid_str = state
.repo_repo
.get_repo_root_cid_by_user_id(user_id)
.await
.map_err(|e| {
error!("DB error fetching repo root: {}", e);
ApiError::InternalError(None).into_response()
})?
.ok_or_else(|| {
ApiError::InternalError(Some("Repo root not found".into())).into_response()
})?;
let current_root_cid = CommitCid::from_str(&root_cid_str).map_err(|_| {
ApiError::InternalError(Some("Invalid repo root CID".into())).into_response()
})?;
Ok(RepoWriteAuth {
did: principal_did.into_did(),
user_id,
current_root_cid,
is_oauth: user.is_oauth(),
scope: user.scope.clone(),
controller_did: scope_proof.controller_did().map(|c| c.into_did()),
@@ -130,9 +116,11 @@ pub async fn create_record(
let did = repo_auth.did;
let user_id = repo_auth.user_id;
let current_root_cid = repo_auth.current_root_cid;
let controller_did = repo_auth.controller_did;
let _write_lock = state.repo_write_locks.lock(user_id).await;
let current_root_cid = get_current_root_cid(&state, user_id).await?;
if let Some(swap_commit) = &input.swap_commit
&& CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid)
{
@@ -433,9 +421,11 @@ pub async fn put_record(
let did = repo_auth.did;
let user_id = repo_auth.user_id;
let current_root_cid = repo_auth.current_root_cid;
let controller_did = repo_auth.controller_did;
let _write_lock = state.repo_write_locks.lock(user_id).await;
let current_root_cid = get_current_root_cid(&state, user_id).await?;
if let Some(swap_commit) = &input.swap_commit
&& CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid)
{

View File

@@ -16,6 +16,7 @@ pub mod oauth;
pub mod plc;
pub mod rate_limit;
pub mod repo;
pub mod repo_write_lock;
pub mod scheduled;
pub mod sso;
pub mod state;

View File

@@ -0,0 +1,180 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, OwnedMutexGuard, RwLock};
use uuid::Uuid;
const SWEEP_INTERVAL: Duration = Duration::from_secs(300);
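/// One async mutex per user, serializing repository writes for that user
/// while letting different users write concurrently. A background task
/// evicts entries that are no longer referenced every `SWEEP_INTERVAL`.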
pub struct RepoWriteLocks {
locks: Arc<RwLock<HashMap<Uuid, Arc<Mutex<()>>>>>,
}
impl Default for RepoWriteLocks {
fn default() -> Self {
Self::new()
}
}
impl RepoWriteLocks {
pub fn new() -> Self {
let locks = Arc::new(RwLock::new(HashMap::new()));
let sweep_locks = Arc::clone(&locks);
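// Detached sweeper task; runs for the lifetime of the process.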
tokio::spawn(async move {
sweep_loop(sweep_locks).await;
});
Self { locks }
}
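/// Acquires the write lock for `user_id`, creating its mutex on first use.
/// The returned owned guard keeps the map entry alive until it is dropped.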
pub async fn lock(&self, user_id: Uuid) -> OwnedMutexGuard<()> {
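// Fast path: look up an existing mutex under the read lock only.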
let mutex = {
let read_guard = self.locks.read().await;
read_guard.get(&user_id).cloned()
};
match mutex {
Some(m) => m.lock_owned().await,
None => {
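// Slow path: insert the mutex under the write lock, then acquire it
// after releasing the map lock so other users are not blocked.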
let mut write_guard = self.locks.write().await;
let mutex = write_guard
.entry(user_id)
.or_insert_with(|| Arc::new(Mutex::new(())))
.clone();
drop(write_guard);
mutex.lock_owned().await
}
}
}
}
/// Periodically evicts map entries whose mutex is no longer referenced.
async fn sweep_loop(locks: Arc<RwLock<HashMap<Uuid, Arc<Mutex<()>>>>>) {
    loop {
        tokio::time::sleep(SWEEP_INTERVAL).await;
        let mut write_guard = locks.write().await;
        let before = write_guard.len();
        // An entry is idle when the map holds the only Arc reference to its mutex.
        write_guard.retain(|_, mutex| Arc::strong_count(mutex) > 1);
        let evicted = before - write_guard.len();
        if evicted > 0 {
            tracing::debug!(
                evicted,
                remaining = write_guard.len(),
                "repo write lock sweep"
            );
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
#[tokio::test]
async fn test_locks_serialize_same_user() {
let locks = Arc::new(RepoWriteLocks::new());
let user_id = Uuid::new_v4();
let counter = Arc::new(AtomicU32::new(0));
let max_concurrent = Arc::new(AtomicU32::new(0));
let handles: Vec<_> = (0..10)
.map(|_| {
let locks = locks.clone();
let counter = counter.clone();
let max_concurrent = max_concurrent.clone();
tokio::spawn(async move {
let _guard = locks.lock(user_id).await;
let current = counter.fetch_add(1, Ordering::SeqCst) + 1;
max_concurrent.fetch_max(current, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(1)).await;
counter.fetch_sub(1, Ordering::SeqCst);
})
})
.collect();
futures::future::join_all(handles).await;
assert_eq!(
max_concurrent.load(Ordering::SeqCst),
1,
"Only one task should hold the lock at a time for same user"
);
}
#[tokio::test]
async fn test_different_users_can_run_concurrently() {
let locks = Arc::new(RepoWriteLocks::new());
let user1 = Uuid::new_v4();
let user2 = Uuid::new_v4();
let concurrent_count = Arc::new(AtomicU32::new(0));
let max_concurrent = Arc::new(AtomicU32::new(0));
let locks1 = locks.clone();
let count1 = concurrent_count.clone();
let max1 = max_concurrent.clone();
let handle1 = tokio::spawn(async move {
let _guard = locks1.lock(user1).await;
let current = count1.fetch_add(1, Ordering::SeqCst) + 1;
max1.fetch_max(current, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(50)).await;
count1.fetch_sub(1, Ordering::SeqCst);
});
tokio::time::sleep(Duration::from_millis(10)).await;
let locks2 = locks.clone();
let count2 = concurrent_count.clone();
let max2 = max_concurrent.clone();
let handle2 = tokio::spawn(async move {
let _guard = locks2.lock(user2).await;
let current = count2.fetch_add(1, Ordering::SeqCst) + 1;
max2.fetch_max(current, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(50)).await;
count2.fetch_sub(1, Ordering::SeqCst);
});
handle1.await.unwrap();
handle2.await.unwrap();
assert_eq!(
max_concurrent.load(Ordering::SeqCst),
2,
"Different users should be able to run concurrently"
);
}
#[tokio::test]
async fn test_sweep_evicts_idle_entries() {
let locks = Arc::new(RwLock::new(HashMap::new()));
let user_id = Uuid::new_v4();
{
let mut write_guard = locks.write().await;
write_guard.insert(user_id, Arc::new(Mutex::new(())));
}
assert_eq!(locks.read().await.len(), 1);
let mut write_guard = locks.write().await;
write_guard.retain(|_, mutex| Arc::strong_count(mutex) > 1);
assert_eq!(write_guard.len(), 0, "Idle entry should be evicted");
}
#[tokio::test]
async fn test_sweep_preserves_active_entries() {
let locks = Arc::new(RwLock::new(HashMap::new()));
let user_id = Uuid::new_v4();
let active_mutex = Arc::new(Mutex::new(()));
let _held_ref = active_mutex.clone();
{
let mut write_guard = locks.write().await;
write_guard.insert(user_id, active_mutex);
}
let mut write_guard = locks.write().await;
write_guard.retain(|_, mutex| Arc::strong_count(mutex) > 1);
assert_eq!(write_guard.len(), 1, "Active entry should be preserved");
}
}

View File

@@ -5,6 +5,7 @@ use crate::circuit_breaker::CircuitBreakers;
use crate::config::AuthConfig;
use crate::rate_limit::RateLimiters;
use crate::repo::PostgresBlockStore;
use crate::repo_write_lock::RepoWriteLocks;
use crate::sso::{SsoConfig, SsoManager};
use crate::storage::{BackupStorage, BlobStorage, create_backup_storage, create_blob_storage};
use crate::sync::firehose::SequencedEvent;
@@ -38,6 +39,7 @@ pub struct AppState {
pub backup_storage: Option<Arc<dyn BackupStorage>>,
pub firehose_tx: broadcast::Sender<SequencedEvent>,
pub rate_limiters: Arc<RateLimiters>,
pub repo_write_locks: Arc<RepoWriteLocks>,
pub circuit_breakers: Arc<CircuitBreakers>,
pub cache: Arc<dyn Cache>,
pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>,
@@ -181,6 +183,7 @@ impl AppState {
let (firehose_tx, _) = broadcast::channel(firehose_buffer_size);
let rate_limiters = Arc::new(RateLimiters::new());
let repo_write_locks = Arc::new(RepoWriteLocks::new());
let circuit_breakers = Arc::new(CircuitBreakers::new());
let (cache, distributed_rate_limiter) = create_cache().await;
let did_resolver = Arc::new(DidResolver::new());
@@ -209,6 +212,7 @@ impl AppState {
backup_storage,
firehose_tx,
rate_limiters,
repo_write_locks,
circuit_breakers,
cache,
distributed_rate_limiter,
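
For reference, the pattern each write handler now follows is: take the per-user guard first, then read the current root and perform the commit while the guard is held. A minimal sketch of that ordering (write_repo and the elided commit step are hypothetical, not part of this commit):

use crate::api::error::ApiError;
use crate::api::repo::record::utils::get_current_root_cid;
use crate::state::AppState;
use uuid::Uuid;

// Hypothetical handler body showing the locking order used above.
async fn write_repo(state: &AppState, user_id: Uuid) -> Result<(), ApiError> {
    // The guard is held until the function returns, covering the root
    // read, the commit, and the root update.
    let _write_lock = state.repo_write_locks.lock(user_id).await;
    let _current_root_cid = get_current_root_cid(state, user_id).await?;
    // ... build, sign, and persist the new commit against the root ...
    Ok(())
}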