diff --git a/crates/tranquil-api/src/repo/record/mod.rs b/crates/tranquil-api/src/repo/record/mod.rs index 50e17cd..230548c 100644 --- a/crates/tranquil-api/src/repo/record/mod.rs +++ b/crates/tranquil-api/src/repo/record/mod.rs @@ -2,7 +2,6 @@ pub mod batch; pub mod delete; pub mod pagination; pub mod read; -pub mod utils; pub mod validation; pub mod validation_mode; pub mod write; @@ -13,7 +12,7 @@ pub use validation_mode::ValidationMode; pub use batch::apply_writes; pub use delete::{DeleteRecordInput, delete_record, delete_record_internal}; pub use read::{GetRecordInput, ListRecordsInput, ListRecordsOutput, get_record, list_records}; -pub use utils::*; +pub use tranquil_pds::repo_ops::*; pub use write::{ CreateRecordInput, CreateRecordOutput, PutRecordInput, PutRecordOutput, create_record, prepare_repo_write, put_record, diff --git a/crates/tranquil-api/src/repo/record/read.rs b/crates/tranquil-api/src/repo/record/read.rs index 23abe3f..8ba42ea 100644 --- a/crates/tranquil-api/src/repo/record/read.rs +++ b/crates/tranquil-api/src/repo/record/read.rs @@ -1,4 +1,5 @@ use super::pagination::{PaginationDirection, deserialize_pagination_direction}; +use crate::common; use axum::{ Json, extract::{Query, State}, @@ -59,40 +60,9 @@ pub async fn get_record( _headers: HeaderMap, Query(input): Query, ) -> Response { - let hostname_for_handles = tranquil_config::get().server.hostname_without_port(); - let user_id_opt = if input.repo.is_did() { - let did: tranquil_pds::types::Did = match input.repo.as_str().parse() { - Ok(d) => d, - Err(_) => return ApiError::InvalidRequest("Invalid DID format".into()).into_response(), - }; - state.user_repo.get_id_by_did(&did).await.map_err(|_| ()) - } else { - let repo_str = input.repo.as_str(); - let handle_str = if !repo_str.contains('.') { - format!("{}.{}", repo_str, hostname_for_handles) - } else { - repo_str.to_string() - }; - let handle: tranquil_pds::types::Handle = match handle_str.parse() { - Ok(h) => h, - Err(_) => { - return ApiError::InvalidRequest("Invalid handle format".into()).into_response(); - } - }; - state - .user_repo - .get_id_by_handle(&handle) - .await - .map_err(|_| ()) - }; - let user_id: uuid::Uuid = match user_id_opt { - Ok(Some(id)) => id, - Ok(None) => { - return ApiError::RepoNotFound(Some("Repo not found".into())).into_response(); - } - Err(_) => { - return ApiError::InternalError(None).into_response(); - } + let user_id = match common::resolve_repo_user_id(state.user_repo.as_ref(), &input.repo).await { + Ok(id) => id, + Err(e) => return e.into_response(), }; let record_row = state .repo_repo @@ -158,40 +128,9 @@ pub async fn list_records( State(state): State, Query(input): Query, ) -> Response { - let hostname_for_handles = tranquil_config::get().server.hostname_without_port(); - let user_id_opt = if input.repo.is_did() { - let did: tranquil_pds::types::Did = match input.repo.as_str().parse() { - Ok(d) => d, - Err(_) => return ApiError::InvalidRequest("Invalid DID format".into()).into_response(), - }; - state.user_repo.get_id_by_did(&did).await.map_err(|_| ()) - } else { - let repo_str = input.repo.as_str(); - let handle_str = if !repo_str.contains('.') { - format!("{}.{}", repo_str, hostname_for_handles) - } else { - repo_str.to_string() - }; - let handle: tranquil_pds::types::Handle = match handle_str.parse() { - Ok(h) => h, - Err(_) => { - return ApiError::InvalidRequest("Invalid handle format".into()).into_response(); - } - }; - state - .user_repo - .get_id_by_handle(&handle) - .await - .map_err(|_| ()) - }; - let user_id: uuid::Uuid = match user_id_opt { - Ok(Some(id)) => id, - Ok(None) => { - return ApiError::RepoNotFound(Some("Repo not found".into())).into_response(); - } - Err(_) => { - return ApiError::InternalError(None).into_response(); - } + let user_id = match common::resolve_repo_user_id(state.user_repo.as_ref(), &input.repo).await { + Ok(id) => id, + Err(e) => return e.into_response(), }; let limit = input.limit.unwrap_or(50).clamp(1, 100); let limit_i64 = i64::from(limit); diff --git a/crates/tranquil-api/src/repo/record/validation.rs b/crates/tranquil-api/src/repo/record/validation.rs index d7d7a1d..7c2150c 100644 --- a/crates/tranquil-api/src/repo/record/validation.rs +++ b/crates/tranquil-api/src/repo/record/validation.rs @@ -1,4 +1,3 @@ -use axum::response::Response; use tranquil_pds::api::error::ApiError; use tranquil_pds::types::{Nsid, Rkey}; use tranquil_pds::validation::{RecordValidator, ValidationError, ValidationStatus}; @@ -8,21 +7,19 @@ pub async fn validate_record_with_status( collection: &Nsid, rkey: Option<&Rkey>, require_lexicon: bool, -) -> Result> { +) -> Result { let registry = tranquil_lexicon::LexiconRegistry::global(); if !registry.has_schema(collection.as_str()) { let _ = registry.resolve_dynamic(collection.as_str()).await; } let validator = RecordValidator::new().require_lexicon(require_lexicon); - match validator.validate_with_rkey(record, collection.as_str(), rkey.map(|r| r.as_str())) { - Ok(status) => Ok(status), - Err(e) => Err(validation_error_to_box_response(e)), - } + validator + .validate_with_rkey(record, collection.as_str(), rkey.map(|v| v.as_str())) + .map_err(validation_error_to_api_error) } -fn validation_error_to_box_response(e: ValidationError) -> Box { - use axum::response::IntoResponse; +fn validation_error_to_api_error(e: ValidationError) -> ApiError { let msg = match e { ValidationError::MissingType => "Record must have a $type field".to_string(), ValidationError::TypeMismatch { expected, actual } => { @@ -44,5 +41,5 @@ fn validation_error_to_box_response(e: ValidationError) -> Box { ValidationError::UnknownType(type_name) => format!("Lexicon not found: lex:{}", type_name), e => e.to_string(), }; - Box::new(ApiError::InvalidRecord(msg).into_response()) + ApiError::InvalidRecord(msg) } diff --git a/crates/tranquil-api/src/repo/record/write.rs b/crates/tranquil-api/src/repo/record/write.rs index 5394a68..e1b65b2 100644 --- a/crates/tranquil-api/src/repo/record/write.rs +++ b/crates/tranquil-api/src/repo/record/write.rs @@ -1,30 +1,21 @@ use super::validation::validate_record_with_status; use super::validation_mode::{ValidationMode, deserialize_validation_mode}; -use crate::repo::record::utils::{ - CommitParams, RecordOp, commit_and_log, extract_backlinks, extract_blob_cids, - get_current_root_cid, -}; -use axum::{ - Json, - extract::State, - http::StatusCode, - response::{IntoResponse, Response}, -}; +use axum::{Json, extract::State}; use cid::Cid; -use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; +use jacquard_repo::storage::BlockStore; use serde::{Deserialize, Serialize}; use serde_json::json; use std::str::FromStr; -use std::sync::Arc; use tracing::error; -use tranquil_pds::api::error::ApiError; +use tranquil_pds::api::error::{ApiError, DbResultExt}; use tranquil_pds::auth::{ Active, Auth, AuthSource, RepoScopeAction, ScopeVerified, VerifyScope, require_not_migrated, require_verified_or_delegated, }; -use tranquil_pds::cid_types::CommitCid; -use tranquil_pds::delegation::DelegationActionType; -use tranquil_pds::repo::tracking::TrackingBlockStore; +use tranquil_pds::repo_ops::{ + FinalizeParams, RecordOp, begin_repo_write, extract_backlinks, extract_blob_cids, + finalize_repo_write, +}; use tranquil_pds::state::AppState; use tranquil_pds::types::{AtIdentifier, AtUri, Did, Nsid, Rkey}; use tranquil_pds::validation::ValidationStatus; @@ -42,13 +33,13 @@ pub async fn prepare_repo_write( state: &AppState, scope_proof: &ScopeVerified<'_, A>, repo: &AtIdentifier, -) -> Result { +) -> Result { let user = scope_proof.user(); let principal_did = scope_proof.principal_did(); if repo.as_str() != principal_did.as_str() { - return Err( - ApiError::InvalidRepo("Repo does not match authenticated user".into()).into_response(), - ); + return Err(ApiError::InvalidRepo( + "Repo does not match authenticated user".into(), + )); } require_not_migrated(state, principal_did.as_did()).await?; @@ -58,11 +49,8 @@ pub async fn prepare_repo_write( .user_repo .get_id_by_did(principal_did.as_did()) .await - .map_err(|e| { - error!("DB error fetching user: {}", e); - ApiError::InternalError(None).into_response() - })? - .ok_or_else(|| ApiError::InternalError(Some("User not found".into())).into_response())?; + .log_db_err("fetching user for repo write")? + .ok_or(ApiError::InternalError(Some("User not found".into())))?; Ok(RepoWriteAuth { did: principal_did.into_did(), @@ -72,6 +60,7 @@ pub async fn prepare_repo_write( controller_did: scope_proof.controller_did().map(|c| c.into_did()), }) } + #[derive(Deserialize)] #[allow(dead_code)] pub struct CreateRecordInput { @@ -84,6 +73,7 @@ pub struct CreateRecordInput { #[serde(rename = "swapCommit")] pub swap_commit: Option, } + #[derive(Serialize)] #[serde(rename_all = "camelCase")] pub struct CommitInfo { @@ -100,117 +90,64 @@ pub struct CreateRecordOutput { #[serde(skip_serializing_if = "Option::is_none")] pub validation_status: Option, } + pub async fn create_record( State(state): State, auth: Auth, Json(input): Json, -) -> Result { - let scope_proof = match auth.verify_repo_create(&input.collection) { - Ok(proof) => proof, - Err(e) => return Ok(e.into_response()), - }; - - let repo_auth = match prepare_repo_write(&state, &scope_proof, &input.repo).await { - Ok(res) => res, - Err(err_res) => return Ok(err_res), - }; - +) -> Result, ApiError> { + let scope_proof = auth.verify_repo_create(&input.collection)?; + let repo_auth = prepare_repo_write(&state, &scope_proof, &input.repo).await?; let did = repo_auth.did; let user_id = repo_auth.user_id; let controller_did = repo_auth.controller_did; - let _write_lock = state.repo_write_locks.lock(user_id).await; - let current_root_cid = get_current_root_cid(&state, user_id).await?; - - if let Some(swap_commit) = &input.swap_commit - && CommitCid::from_str(swap_commit).ok().as_ref() != Some(¤t_root_cid) - { - return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); - } + let (ctx, mut mst) = begin_repo_write(&state, user_id, input.swap_commit.as_deref()).await?; let validation_status = if input.validate.should_skip() { None } else { - match validate_record_with_status( - &input.record, - &input.collection, - input.rkey.as_ref(), - input.validate.requires_lexicon(), + Some( + validate_record_with_status( + &input.record, + &input.collection, + input.rkey.as_ref(), + input.validate.requires_lexicon(), + ) + .await?, ) - .await - { - Ok(status) => Some(status), - Err(err_response) => return Ok(*err_response), - } }; + let rkey = input.rkey.unwrap_or_else(Rkey::generate); - - let tracking_store = TrackingBlockStore::new(state.block_store.clone()); - let commit_bytes = match tracking_store.get(current_root_cid.as_cid()).await { - Ok(Some(b)) => b, - _ => { - return Ok( - ApiError::InternalError(Some("Commit block not found".into())).into_response(), - ); - } - }; - let commit = match Commit::from_cbor(&commit_bytes) { - Ok(c) => c, - _ => { - return Ok( - ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), - ); - } - }; - let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); - let initial_mst_root = commit.data; - let mut ops: Vec = Vec::new(); let mut conflict_uris_to_cleanup: Vec = Vec::new(); - let mut all_old_mst_blocks = std::collections::BTreeMap::new(); if !input.validate.should_skip() { let record_uri = AtUri::from_parts(&did, &input.collection, &rkey); let backlinks = extract_backlinks(&record_uri, &input.record); if !backlinks.is_empty() { - let conflicts = match state + let conflicts = state .backlink_repo .get_backlink_conflicts(user_id, &input.collection, &backlinks) .await - { - Ok(c) => c, - Err(e) => { - error!("Failed to check backlink conflicts: {}", e); - return Ok(ApiError::InternalError(None).into_response()); - } - }; + .log_db_err("checking backlink conflicts")?; for conflict_uri in conflicts { - let conflict_rkey = match conflict_uri.rkey() { - Some(r) => Rkey::from(r.to_string()), - None => continue, - }; - let conflict_collection = match conflict_uri.collection() { - Some(c) => Nsid::from(c.to_string()), - None => continue, + let (Some(conflict_rkey_str), Some(conflict_col_str)) = + (conflict_uri.rkey(), conflict_uri.collection()) + else { + continue; }; + let conflict_rkey = Rkey::from(conflict_rkey_str.to_string()); + let conflict_collection = Nsid::from(conflict_col_str.to_string()); let conflict_key = format!("{}/{}", conflict_collection, conflict_rkey); let prev_cid = match mst.get(&conflict_key).await { Ok(Some(cid)) => cid, - Ok(None) => continue, - Err(_) => continue, + _ => continue, }; - if mst - .blocks_for_path(&conflict_key, &mut all_old_mst_blocks) - .await - .is_err() - { - error!("Failed to get old MST blocks for conflict {}", conflict_uri); - } - mst = match mst.delete(&conflict_key).await { Ok(m) => m, Err(e) => { @@ -233,42 +170,19 @@ pub async fn create_record( } let record_ipld = tranquil_pds::util::json_to_ipld(&input.record); - let mut record_bytes = Vec::new(); - if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() { - return Ok(ApiError::InvalidRecord("Failed to serialize record".into()).into_response()); - } - let record_cid = match tracking_store.put(&record_bytes).await { - Ok(c) => c, - _ => { - return Ok( - ApiError::InternalError(Some("Failed to save record block".into())).into_response(), - ); - } - }; - let key = format!("{}/{}", input.collection, rkey); - - if mst - .blocks_for_path(&key, &mut all_old_mst_blocks) + let record_bytes = serde_ipld_dagcbor::to_vec(&record_ipld) + .map_err(|_| ApiError::InvalidRecord("Failed to serialize record".into()))?; + let record_cid = ctx + .tracking_store + .put(&record_bytes) .await - .is_err() - { - error!("Failed to get old MST blocks for new record path"); - } + .map_err(|_| ApiError::InternalError(Some("Failed to save record block".into())))?; - let new_mst = match mst.add(&key, record_cid).await { - Ok(m) => m, - _ => { - return Ok(ApiError::InternalError(Some("Failed to add to MST".into())).into_response()); - } - }; - let new_mst_root = match new_mst.persist().await { - Ok(c) => c, - _ => { - return Ok( - ApiError::InternalError(Some("Failed to persist MST".into())).into_response(), - ); - } - }; + let key = format!("{}/{}", input.collection, rkey); + mst = mst + .add(&key, record_cid) + .await + .map_err(|_| ApiError::InternalError(Some("Failed to add to MST".into())))?; ops.push(RecordOp::Create { collection: input.collection.clone(), @@ -276,86 +190,55 @@ pub async fn create_record( cid: record_cid, }); - let mut new_mst_blocks = std::collections::BTreeMap::new(); - if new_mst - .blocks_for_path(&key, &mut new_mst_blocks) - .await - .is_err() - { - return Ok( - ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) - .into_response(), - ); - } - - let mut relevant_blocks = new_mst_blocks.clone(); - relevant_blocks.extend(all_old_mst_blocks.iter().map(|(k, v)| (*k, v.clone()))); - relevant_blocks.insert(record_cid, bytes::Bytes::new()); - let written_cids: Vec = tracking_store - .get_all_relevant_cids() - .into_iter() - .chain(relevant_blocks.keys().copied()) - .collect::>() - .into_iter() + let modified_keys: Vec = ops + .iter() + .map(|op| match op { + RecordOp::Create { + collection, rkey, .. + } + | RecordOp::Update { + collection, rkey, .. + } + | RecordOp::Delete { + collection, rkey, .. + } => format!("{}/{}", collection, rkey), + }) .collect(); - let written_cids_str: Vec = written_cids.iter().map(|c| c.to_string()).collect(); let blob_cids = extract_blob_cids(&input.record); - let obsolete_cids: Vec = std::iter::once(current_root_cid.into_cid()) - .chain( - all_old_mst_blocks - .keys() - .filter(|cid| !new_mst_blocks.contains_key(*cid)) - .copied(), - ) - .collect(); - let commit_result = match commit_and_log( + let commit_result = finalize_repo_write( &state, - CommitParams { + ctx, + mst, + FinalizeParams { did: &did, user_id, - current_root_cid: Some(current_root_cid.into_cid()), - prev_data_cid: Some(initial_mst_root), - new_mst_root, - ops, - blocks_cids: &written_cids_str, - blobs: &blob_cids, - obsolete_cids, - }, - ) - .await - { - Ok(res) => res, - Err(e) => return Ok(ApiError::from(e).into_response()), - }; - - for conflict_uri in conflict_uris_to_cleanup { - if let Err(e) = state - .backlink_repo - .remove_backlinks_by_uri(&conflict_uri) - .await - { - error!("Failed to remove backlinks for {}: {}", conflict_uri, e); - } - } - - if let Some(ref controller) = controller_did { - let _ = state - .delegation_repo - .log_delegation_action( - &did, - controller, - Some(controller), - DelegationActionType::RepoWrite, - Some(json!({ + controller_did: controller_did.as_ref(), + delegation_detail: controller_did.as_ref().map(|_| { + json!({ "action": "create", "collection": input.collection, "rkey": rkey - })), - None, - None, - ) - .await; + }) + }), + ops, + modified_keys: &modified_keys, + blob_cids: &blob_cids, + }, + ) + .await?; + + { + let backlink_repo = state.backlink_repo.clone(); + futures::future::join_all(conflict_uris_to_cleanup.iter().map(|uri| { + let backlink_repo = backlink_repo.clone(); + async move { + if let Err(e) = backlink_repo.remove_backlinks_by_uri(uri).await { + error!("Failed to remove backlinks for {}: {}", uri, e); + } + } + })) + .await; } let created_uri = AtUri::from_parts(&did, &input.collection, &rkey); @@ -366,20 +249,17 @@ pub async fn create_record( error!("Failed to add backlinks for {}: {}", created_uri, e); } - Ok(( - StatusCode::OK, - Json(CreateRecordOutput { - uri: created_uri, - cid: record_cid.to_string(), - commit: CommitInfo { - cid: commit_result.commit_cid.to_string(), - rev: commit_result.rev, - }, - validation_status, - }), - ) - .into_response()) + Ok(Json(CreateRecordOutput { + uri: created_uri, + cid: record_cid.to_string(), + commit: CommitInfo { + cid: commit_result.commit_cid.to_string(), + rev: commit_result.rev, + }, + validation_status, + })) } + #[derive(Deserialize)] #[allow(dead_code)] pub struct PutRecordInput { @@ -394,6 +274,7 @@ pub struct PutRecordInput { #[serde(rename = "swapRecord")] pub swap_record: Option, } + #[derive(Serialize)] #[serde(rename_all = "camelCase")] pub struct PutRecordOutput { @@ -404,130 +285,77 @@ pub struct PutRecordOutput { #[serde(skip_serializing_if = "Option::is_none")] pub validation_status: Option, } + pub async fn put_record( State(state): State, auth: Auth, Json(input): Json, -) -> Result { - let upsert_proof = match auth.verify_repo_upsert(&input.collection) { - Ok(proof) => proof, - Err(e) => return Ok(e.into_response()), - }; - - let repo_auth = match prepare_repo_write(&state, &upsert_proof, &input.repo).await { - Ok(res) => res, - Err(err_res) => return Ok(err_res), - }; - +) -> Result, ApiError> { + let upsert_proof = auth.verify_repo_upsert(&input.collection)?; + let repo_auth = prepare_repo_write(&state, &upsert_proof, &input.repo).await?; let did = repo_auth.did; let user_id = repo_auth.user_id; let controller_did = repo_auth.controller_did; - let _write_lock = state.repo_write_locks.lock(user_id).await; - let current_root_cid = get_current_root_cid(&state, user_id).await?; + let (ctx, mst) = begin_repo_write(&state, user_id, input.swap_commit.as_deref()).await?; - if let Some(swap_commit) = &input.swap_commit - && CommitCid::from_str(swap_commit).ok().as_ref() != Some(¤t_root_cid) - { - return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); - } - let tracking_store = TrackingBlockStore::new(state.block_store.clone()); - let commit_bytes = match tracking_store.get(current_root_cid.as_cid()).await { - Ok(Some(b)) => b, - _ => { - return Ok( - ApiError::InternalError(Some("Commit block not found".into())).into_response(), - ); - } - }; - let commit = match Commit::from_cbor(&commit_bytes) { - Ok(c) => c, - _ => { - return Ok( - ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), - ); - } - }; - let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); - let key = format!("{}/{}", input.collection, input.rkey); let validation_status = if input.validate.should_skip() { None } else { - match validate_record_with_status( - &input.record, - &input.collection, - Some(&input.rkey), - input.validate.requires_lexicon(), + Some( + validate_record_with_status( + &input.record, + &input.collection, + Some(&input.rkey), + input.validate.requires_lexicon(), + ) + .await?, ) - .await - { - Ok(status) => Some(status), - Err(err_response) => return Ok(*err_response), - } }; + + let key = format!("{}/{}", input.collection, input.rkey); + if let Some(swap_record_str) = &input.swap_record { let expected_cid = Cid::from_str(swap_record_str).ok(); let actual_cid = mst.get(&key).await.ok().flatten(); if expected_cid != actual_cid { - return Ok(ApiError::InvalidSwap(Some( + return Err(ApiError::InvalidSwap(Some( "Record has been modified or does not exist".into(), - )) - .into_response()); + ))); } } + let existing_cid = mst.get(&key).await.ok().flatten(); let record_ipld = tranquil_pds::util::json_to_ipld(&input.record); - let mut record_bytes = Vec::new(); - if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() { - return Ok(ApiError::InvalidRecord("Failed to serialize record".into()).into_response()); - } - let record_cid = match tracking_store.put(&record_bytes).await { - Ok(c) => c, - _ => { - return Ok( - ApiError::InternalError(Some("Failed to save record block".into())).into_response(), - ); - } - }; + let record_bytes = serde_ipld_dagcbor::to_vec(&record_ipld) + .map_err(|_| ApiError::InvalidRecord("Failed to serialize record".into()))?; + let record_cid = ctx + .tracking_store + .put(&record_bytes) + .await + .map_err(|_| ApiError::InternalError(Some("Failed to save record block".into())))?; + if existing_cid == Some(record_cid) { - return Ok(( - StatusCode::OK, - Json(PutRecordOutput { - uri: AtUri::from_parts(&did, &input.collection, &input.rkey), - cid: record_cid.to_string(), - commit: None, - validation_status, - }), - ) - .into_response()); + return Ok(Json(PutRecordOutput { + uri: AtUri::from_parts(&did, &input.collection, &input.rkey), + cid: record_cid.to_string(), + commit: None, + validation_status, + })); } - let new_mst = - if existing_cid.is_some() { - match mst.update(&key, record_cid).await { - Ok(m) => m, - Err(_) => { - return Ok(ApiError::InternalError(Some("Failed to update MST".into())) - .into_response()); - } - } - } else { - match mst.add(&key, record_cid).await { - Ok(m) => m, - Err(_) => { - return Ok(ApiError::InternalError(Some("Failed to add to MST".into())) - .into_response()); - } - } - }; - let new_mst_root = match new_mst.persist().await { - Ok(c) => c, - Err(_) => { - return Ok( - ApiError::InternalError(Some("Failed to persist MST".into())).into_response(), - ); - } + + let is_update = existing_cid.is_some(); + let new_mst = if is_update { + mst.update(&key, record_cid) + .await + .map_err(|_| ApiError::InternalError(Some("Failed to update MST".into())))? + } else { + mst.add(&key, record_cid) + .await + .map_err(|_| ApiError::InternalError(Some("Failed to add to MST".into())))? }; - let op = if existing_cid.is_some() { + + let op = if is_update { RecordOp::Update { collection: input.collection.clone(), rkey: input.rkey.clone(), @@ -541,100 +369,39 @@ pub async fn put_record( cid: record_cid, } }; - let mut new_mst_blocks = std::collections::BTreeMap::new(); - let mut old_mst_blocks = std::collections::BTreeMap::new(); - if new_mst - .blocks_for_path(&key, &mut new_mst_blocks) - .await - .is_err() - { - return Ok( - ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) - .into_response(), - ); - } - if mst - .blocks_for_path(&key, &mut old_mst_blocks) - .await - .is_err() - { - return Ok( - ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) - .into_response(), - ); - } - let mut relevant_blocks = new_mst_blocks.clone(); - relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone()))); - relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes)); - let written_cids: Vec = tracking_store - .get_all_relevant_cids() - .into_iter() - .chain(relevant_blocks.keys().copied()) - .collect::>() - .into_iter() - .collect(); - let written_cids_str: Vec = written_cids.iter().map(|c| c.to_string()).collect(); - let is_update = existing_cid.is_some(); + + let modified_keys = [key]; let blob_cids = extract_blob_cids(&input.record); - let obsolete_cids: Vec = std::iter::once(current_root_cid.into_cid()) - .chain( - old_mst_blocks - .keys() - .filter(|cid| !new_mst_blocks.contains_key(*cid)) - .copied(), - ) - .chain(existing_cid) - .collect(); - let commit_result = match commit_and_log( + + let commit_result = finalize_repo_write( &state, - CommitParams { + ctx, + new_mst, + FinalizeParams { did: &did, user_id, - current_root_cid: Some(current_root_cid.into_cid()), - prev_data_cid: Some(commit.data), - new_mst_root, - ops: vec![op], - blocks_cids: &written_cids_str, - blobs: &blob_cids, - obsolete_cids, - }, - ) - .await - { - Ok(res) => res, - Err(e) => return Ok(ApiError::from(e).into_response()), - }; - - if let Some(ref controller) = controller_did { - let _ = state - .delegation_repo - .log_delegation_action( - &did, - controller, - Some(controller), - DelegationActionType::RepoWrite, - Some(json!({ + controller_did: controller_did.as_ref(), + delegation_detail: controller_did.as_ref().map(|_| { + json!({ "action": if is_update { "update" } else { "create" }, "collection": input.collection, "rkey": input.rkey - })), - None, - None, - ) - .await; - } - - Ok(( - StatusCode::OK, - Json(PutRecordOutput { - uri: AtUri::from_parts(&did, &input.collection, &input.rkey), - cid: record_cid.to_string(), - commit: Some(CommitInfo { - cid: commit_result.commit_cid.to_string(), - rev: commit_result.rev, + }) }), - validation_status, - }), + ops: vec![op], + modified_keys: &modified_keys, + blob_cids: &blob_cids, + }, ) - .into_response()) + .await?; + + Ok(Json(PutRecordOutput { + uri: AtUri::from_parts(&did, &input.collection, &input.rkey), + cid: record_cid.to_string(), + commit: Some(CommitInfo { + cid: commit_result.commit_cid.to_string(), + rev: commit_result.rev, + }), + validation_status, + })) } diff --git a/crates/tranquil-pds/src/repo_ops.rs b/crates/tranquil-pds/src/repo_ops.rs index f72ffa6..83b5cfe 100644 --- a/crates/tranquil-pds/src/repo_ops.rs +++ b/crates/tranquil-pds/src/repo_ops.rs @@ -1,15 +1,19 @@ use crate::api::error::ApiError; use crate::cid_types::CommitCid; +use crate::repo::tracking::TrackingBlockStore; use crate::state::AppState; use crate::types::{Did, Handle, Nsid, Rkey}; use bytes::Bytes; use cid::Cid; use jacquard_common::types::{integer::LimitedU32, string::Tid}; use jacquard_repo::commit::Commit; +use jacquard_repo::mst::Mst; use jacquard_repo::storage::BlockStore; use k256::ecdsa::SigningKey; use serde_json::{Value, json}; use std::str::FromStr; +use std::sync::Arc; +use tokio::sync::OwnedMutexGuard; use tracing::error; use tranquil_db_traits::SequenceNumber; use uuid::Uuid; @@ -158,6 +162,135 @@ pub fn extract_backlinks(uri: &AtUri, record: &Value) -> Vec { } } +pub struct RepoWriteContext { + pub tracking_store: TrackingBlockStore, + pub current_root_cid: Cid, + pub prev_data_cid: Cid, + pub write_lock: OwnedMutexGuard<()>, +} + +pub struct FinalizeParams<'a> { + pub did: &'a Did, + pub user_id: Uuid, + pub controller_did: Option<&'a Did>, + pub delegation_detail: Option, + pub ops: Vec, + pub modified_keys: &'a [String], + pub blob_cids: &'a [String], +} + +pub async fn begin_repo_write( + state: &AppState, + user_id: Uuid, + swap_commit: Option<&str>, +) -> Result<(RepoWriteContext, Mst), ApiError> { + let write_lock = state.repo_write_locks.lock(user_id).await; + + let root_cid_str = state + .repo_repo + .get_repo_root_cid_by_user_id(user_id) + .await + .map_err(|e| { + error!("DB error fetching repo root: {}", e); + ApiError::InternalError(None) + })? + .ok_or_else(|| ApiError::InternalError(Some("Repo root not found".into())))?; + + let current_root_cid = Cid::from_str(root_cid_str.as_str()) + .map_err(|_| ApiError::InternalError(Some("Invalid repo root CID".into())))?; + + if let Some(expected) = swap_commit { + let expected_cid = Cid::from_str(expected) + .map_err(|_| ApiError::InvalidSwap(Some("Invalid swap commit CID".into())))?; + if expected_cid != current_root_cid { + return Err(ApiError::InvalidSwap(Some("Repo has been modified".into()))); + } + } + + let tracking_store = TrackingBlockStore::new(state.block_store.clone()); + let commit_bytes = tracking_store + .get(¤t_root_cid) + .await + .map_err(|e| { + error!("Failed to load commit block: {:?}", e); + ApiError::InternalError(None) + })? + .ok_or_else(|| ApiError::InternalError(Some("Commit block not found".into())))?; + + let commit = Commit::from_cbor(&commit_bytes).map_err(|e| { + error!("Failed to parse commit: {:?}", e); + ApiError::InternalError(None) + })?; + + let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); + + let ctx = RepoWriteContext { + tracking_store, + current_root_cid, + prev_data_cid: commit.data, + write_lock, + }; + + Ok((ctx, mst)) +} + +pub async fn finalize_repo_write( + state: &AppState, + ctx: RepoWriteContext, + mst: Mst, + params: FinalizeParams<'_>, +) -> Result { + let new_mst_root = mst.persist().await.map_err(|e| { + error!("MST persist failed: {:?}", e); + ApiError::InternalError(None) + })?; + + let written_cids: Vec = ctx + .tracking_store + .get_all_relevant_cids() + .into_iter() + .collect::>() + .into_iter() + .collect(); + let written_cids_str: Vec = written_cids.iter().map(ToString::to_string).collect(); + + let result = commit_and_log( + state, + CommitParams { + did: params.did, + user_id: params.user_id, + current_root_cid: Some(ctx.current_root_cid), + prev_data_cid: Some(ctx.prev_data_cid), + new_mst_root, + ops: params.ops, + blocks_cids: &written_cids_str, + blobs: params.blob_cids, + obsolete_cids: vec![ctx.current_root_cid], + }, + ) + .await?; + + if let Some(controller_did) = params.controller_did + && let Some(detail) = params.delegation_detail + && let Err(e) = state + .delegation_repo + .log_delegation_action( + params.did, + controller_did, + Some(controller_did), + tranquil_db_traits::DelegationActionType::RepoWrite, + Some(detail), + None, + None, + ) + .await + { + tracing::warn!("Failed to log delegation audit: {:?}", e); + } + + Ok(result) +} + pub fn create_signed_commit( did: &Did, data: Cid, @@ -392,9 +525,6 @@ pub async fn create_record_internal( rkey: &Rkey, record: &serde_json::Value, ) -> Result<(String, Cid), CommitError> { - use crate::repo::tracking::TrackingBlockStore; - use jacquard_repo::mst::Mst; - use std::sync::Arc; let user_id: Uuid = state .user_repo .get_id_by_did(did)