refactor(api): extract repo write lifecycle to repo_ops

This commit is contained in:
Lewis
2026-03-19 19:51:07 +02:00
committed by Tangled
parent 4d86f026df
commit 7bc90d5e23
5 changed files with 318 additions and 486 deletions

View File

@@ -2,7 +2,6 @@ pub mod batch;
pub mod delete;
pub mod pagination;
pub mod read;
pub mod utils;
pub mod validation;
pub mod validation_mode;
pub mod write;
@@ -13,7 +12,7 @@ pub use validation_mode::ValidationMode;
pub use batch::apply_writes;
pub use delete::{DeleteRecordInput, delete_record, delete_record_internal};
pub use read::{GetRecordInput, ListRecordsInput, ListRecordsOutput, get_record, list_records};
pub use utils::*;
pub use tranquil_pds::repo_ops::*;
pub use write::{
CreateRecordInput, CreateRecordOutput, PutRecordInput, PutRecordOutput, create_record,
prepare_repo_write, put_record,

View File

@@ -1,4 +1,5 @@
use super::pagination::{PaginationDirection, deserialize_pagination_direction};
use crate::common;
use axum::{
Json,
extract::{Query, State},
@@ -59,40 +60,9 @@ pub async fn get_record(
_headers: HeaderMap,
Query(input): Query<GetRecordInput>,
) -> Response {
let hostname_for_handles = tranquil_config::get().server.hostname_without_port();
let user_id_opt = if input.repo.is_did() {
let did: tranquil_pds::types::Did = match input.repo.as_str().parse() {
Ok(d) => d,
Err(_) => return ApiError::InvalidRequest("Invalid DID format".into()).into_response(),
};
state.user_repo.get_id_by_did(&did).await.map_err(|_| ())
} else {
let repo_str = input.repo.as_str();
let handle_str = if !repo_str.contains('.') {
format!("{}.{}", repo_str, hostname_for_handles)
} else {
repo_str.to_string()
};
let handle: tranquil_pds::types::Handle = match handle_str.parse() {
Ok(h) => h,
Err(_) => {
return ApiError::InvalidRequest("Invalid handle format".into()).into_response();
}
};
state
.user_repo
.get_id_by_handle(&handle)
.await
.map_err(|_| ())
};
let user_id: uuid::Uuid = match user_id_opt {
Ok(Some(id)) => id,
Ok(None) => {
return ApiError::RepoNotFound(Some("Repo not found".into())).into_response();
}
Err(_) => {
return ApiError::InternalError(None).into_response();
}
let user_id = match common::resolve_repo_user_id(state.user_repo.as_ref(), &input.repo).await {
Ok(id) => id,
Err(e) => return e.into_response(),
};
let record_row = state
.repo_repo
@@ -158,40 +128,9 @@ pub async fn list_records(
State(state): State<AppState>,
Query(input): Query<ListRecordsInput>,
) -> Response {
let hostname_for_handles = tranquil_config::get().server.hostname_without_port();
let user_id_opt = if input.repo.is_did() {
let did: tranquil_pds::types::Did = match input.repo.as_str().parse() {
Ok(d) => d,
Err(_) => return ApiError::InvalidRequest("Invalid DID format".into()).into_response(),
};
state.user_repo.get_id_by_did(&did).await.map_err(|_| ())
} else {
let repo_str = input.repo.as_str();
let handle_str = if !repo_str.contains('.') {
format!("{}.{}", repo_str, hostname_for_handles)
} else {
repo_str.to_string()
};
let handle: tranquil_pds::types::Handle = match handle_str.parse() {
Ok(h) => h,
Err(_) => {
return ApiError::InvalidRequest("Invalid handle format".into()).into_response();
}
};
state
.user_repo
.get_id_by_handle(&handle)
.await
.map_err(|_| ())
};
let user_id: uuid::Uuid = match user_id_opt {
Ok(Some(id)) => id,
Ok(None) => {
return ApiError::RepoNotFound(Some("Repo not found".into())).into_response();
}
Err(_) => {
return ApiError::InternalError(None).into_response();
}
let user_id = match common::resolve_repo_user_id(state.user_repo.as_ref(), &input.repo).await {
Ok(id) => id,
Err(e) => return e.into_response(),
};
let limit = input.limit.unwrap_or(50).clamp(1, 100);
let limit_i64 = i64::from(limit);

View File

@@ -1,4 +1,3 @@
use axum::response::Response;
use tranquil_pds::api::error::ApiError;
use tranquil_pds::types::{Nsid, Rkey};
use tranquil_pds::validation::{RecordValidator, ValidationError, ValidationStatus};
@@ -8,21 +7,19 @@ pub async fn validate_record_with_status(
collection: &Nsid,
rkey: Option<&Rkey>,
require_lexicon: bool,
) -> Result<ValidationStatus, Box<Response>> {
) -> Result<ValidationStatus, ApiError> {
let registry = tranquil_lexicon::LexiconRegistry::global();
if !registry.has_schema(collection.as_str()) {
let _ = registry.resolve_dynamic(collection.as_str()).await;
}
let validator = RecordValidator::new().require_lexicon(require_lexicon);
match validator.validate_with_rkey(record, collection.as_str(), rkey.map(|r| r.as_str())) {
Ok(status) => Ok(status),
Err(e) => Err(validation_error_to_box_response(e)),
}
validator
.validate_with_rkey(record, collection.as_str(), rkey.map(|v| v.as_str()))
.map_err(validation_error_to_api_error)
}
fn validation_error_to_box_response(e: ValidationError) -> Box<Response> {
use axum::response::IntoResponse;
fn validation_error_to_api_error(e: ValidationError) -> ApiError {
let msg = match e {
ValidationError::MissingType => "Record must have a $type field".to_string(),
ValidationError::TypeMismatch { expected, actual } => {
@@ -44,5 +41,5 @@ fn validation_error_to_box_response(e: ValidationError) -> Box<Response> {
ValidationError::UnknownType(type_name) => format!("Lexicon not found: lex:{}", type_name),
e => e.to_string(),
};
Box::new(ApiError::InvalidRecord(msg).into_response())
ApiError::InvalidRecord(msg)
}

View File

@@ -1,30 +1,21 @@
use super::validation::validate_record_with_status;
use super::validation_mode::{ValidationMode, deserialize_validation_mode};
use crate::repo::record::utils::{
CommitParams, RecordOp, commit_and_log, extract_backlinks, extract_blob_cids,
get_current_root_cid,
};
use axum::{
Json,
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
};
use axum::{Json, extract::State};
use cid::Cid;
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
use jacquard_repo::storage::BlockStore;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::str::FromStr;
use std::sync::Arc;
use tracing::error;
use tranquil_pds::api::error::ApiError;
use tranquil_pds::api::error::{ApiError, DbResultExt};
use tranquil_pds::auth::{
Active, Auth, AuthSource, RepoScopeAction, ScopeVerified, VerifyScope, require_not_migrated,
require_verified_or_delegated,
};
use tranquil_pds::cid_types::CommitCid;
use tranquil_pds::delegation::DelegationActionType;
use tranquil_pds::repo::tracking::TrackingBlockStore;
use tranquil_pds::repo_ops::{
FinalizeParams, RecordOp, begin_repo_write, extract_backlinks, extract_blob_cids,
finalize_repo_write,
};
use tranquil_pds::state::AppState;
use tranquil_pds::types::{AtIdentifier, AtUri, Did, Nsid, Rkey};
use tranquil_pds::validation::ValidationStatus;
@@ -42,13 +33,13 @@ pub async fn prepare_repo_write<A: RepoScopeAction>(
state: &AppState,
scope_proof: &ScopeVerified<'_, A>,
repo: &AtIdentifier,
) -> Result<RepoWriteAuth, Response> {
) -> Result<RepoWriteAuth, ApiError> {
let user = scope_proof.user();
let principal_did = scope_proof.principal_did();
if repo.as_str() != principal_did.as_str() {
return Err(
ApiError::InvalidRepo("Repo does not match authenticated user".into()).into_response(),
);
return Err(ApiError::InvalidRepo(
"Repo does not match authenticated user".into(),
));
}
require_not_migrated(state, principal_did.as_did()).await?;
@@ -58,11 +49,8 @@ pub async fn prepare_repo_write<A: RepoScopeAction>(
.user_repo
.get_id_by_did(principal_did.as_did())
.await
.map_err(|e| {
error!("DB error fetching user: {}", e);
ApiError::InternalError(None).into_response()
})?
.ok_or_else(|| ApiError::InternalError(Some("User not found".into())).into_response())?;
.log_db_err("fetching user for repo write")?
.ok_or(ApiError::InternalError(Some("User not found".into())))?;
Ok(RepoWriteAuth {
did: principal_did.into_did(),
@@ -72,6 +60,7 @@ pub async fn prepare_repo_write<A: RepoScopeAction>(
controller_did: scope_proof.controller_did().map(|c| c.into_did()),
})
}
#[derive(Deserialize)]
#[allow(dead_code)]
pub struct CreateRecordInput {
@@ -84,6 +73,7 @@ pub struct CreateRecordInput {
#[serde(rename = "swapCommit")]
pub swap_commit: Option<String>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CommitInfo {
@@ -100,117 +90,64 @@ pub struct CreateRecordOutput {
#[serde(skip_serializing_if = "Option::is_none")]
pub validation_status: Option<ValidationStatus>,
}
pub async fn create_record(
State(state): State<AppState>,
auth: Auth<Active>,
Json(input): Json<CreateRecordInput>,
) -> Result<Response, tranquil_pds::api::error::ApiError> {
let scope_proof = match auth.verify_repo_create(&input.collection) {
Ok(proof) => proof,
Err(e) => return Ok(e.into_response()),
};
let repo_auth = match prepare_repo_write(&state, &scope_proof, &input.repo).await {
Ok(res) => res,
Err(err_res) => return Ok(err_res),
};
) -> Result<Json<CreateRecordOutput>, ApiError> {
let scope_proof = auth.verify_repo_create(&input.collection)?;
let repo_auth = prepare_repo_write(&state, &scope_proof, &input.repo).await?;
let did = repo_auth.did;
let user_id = repo_auth.user_id;
let controller_did = repo_auth.controller_did;
let _write_lock = state.repo_write_locks.lock(user_id).await;
let current_root_cid = get_current_root_cid(&state, user_id).await?;
if let Some(swap_commit) = &input.swap_commit
&& CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid)
{
return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response());
}
let (ctx, mut mst) = begin_repo_write(&state, user_id, input.swap_commit.as_deref()).await?;
let validation_status = if input.validate.should_skip() {
None
} else {
match validate_record_with_status(
&input.record,
&input.collection,
input.rkey.as_ref(),
input.validate.requires_lexicon(),
Some(
validate_record_with_status(
&input.record,
&input.collection,
input.rkey.as_ref(),
input.validate.requires_lexicon(),
)
.await?,
)
.await
{
Ok(status) => Some(status),
Err(err_response) => return Ok(*err_response),
}
};
let rkey = input.rkey.unwrap_or_else(Rkey::generate);
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
let commit_bytes = match tracking_store.get(current_root_cid.as_cid()).await {
Ok(Some(b)) => b,
_ => {
return Ok(
ApiError::InternalError(Some("Commit block not found".into())).into_response(),
);
}
};
let commit = match Commit::from_cbor(&commit_bytes) {
Ok(c) => c,
_ => {
return Ok(
ApiError::InternalError(Some("Failed to parse commit".into())).into_response(),
);
}
};
let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
let initial_mst_root = commit.data;
let mut ops: Vec<RecordOp> = Vec::new();
let mut conflict_uris_to_cleanup: Vec<AtUri> = Vec::new();
let mut all_old_mst_blocks = std::collections::BTreeMap::new();
if !input.validate.should_skip() {
let record_uri = AtUri::from_parts(&did, &input.collection, &rkey);
let backlinks = extract_backlinks(&record_uri, &input.record);
if !backlinks.is_empty() {
let conflicts = match state
let conflicts = state
.backlink_repo
.get_backlink_conflicts(user_id, &input.collection, &backlinks)
.await
{
Ok(c) => c,
Err(e) => {
error!("Failed to check backlink conflicts: {}", e);
return Ok(ApiError::InternalError(None).into_response());
}
};
.log_db_err("checking backlink conflicts")?;
for conflict_uri in conflicts {
let conflict_rkey = match conflict_uri.rkey() {
Some(r) => Rkey::from(r.to_string()),
None => continue,
};
let conflict_collection = match conflict_uri.collection() {
Some(c) => Nsid::from(c.to_string()),
None => continue,
let (Some(conflict_rkey_str), Some(conflict_col_str)) =
(conflict_uri.rkey(), conflict_uri.collection())
else {
continue;
};
let conflict_rkey = Rkey::from(conflict_rkey_str.to_string());
let conflict_collection = Nsid::from(conflict_col_str.to_string());
let conflict_key = format!("{}/{}", conflict_collection, conflict_rkey);
let prev_cid = match mst.get(&conflict_key).await {
Ok(Some(cid)) => cid,
Ok(None) => continue,
Err(_) => continue,
_ => continue,
};
if mst
.blocks_for_path(&conflict_key, &mut all_old_mst_blocks)
.await
.is_err()
{
error!("Failed to get old MST blocks for conflict {}", conflict_uri);
}
mst = match mst.delete(&conflict_key).await {
Ok(m) => m,
Err(e) => {
@@ -233,42 +170,19 @@ pub async fn create_record(
}
let record_ipld = tranquil_pds::util::json_to_ipld(&input.record);
let mut record_bytes = Vec::new();
if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() {
return Ok(ApiError::InvalidRecord("Failed to serialize record".into()).into_response());
}
let record_cid = match tracking_store.put(&record_bytes).await {
Ok(c) => c,
_ => {
return Ok(
ApiError::InternalError(Some("Failed to save record block".into())).into_response(),
);
}
};
let key = format!("{}/{}", input.collection, rkey);
if mst
.blocks_for_path(&key, &mut all_old_mst_blocks)
let record_bytes = serde_ipld_dagcbor::to_vec(&record_ipld)
.map_err(|_| ApiError::InvalidRecord("Failed to serialize record".into()))?;
let record_cid = ctx
.tracking_store
.put(&record_bytes)
.await
.is_err()
{
error!("Failed to get old MST blocks for new record path");
}
.map_err(|_| ApiError::InternalError(Some("Failed to save record block".into())))?;
let new_mst = match mst.add(&key, record_cid).await {
Ok(m) => m,
_ => {
return Ok(ApiError::InternalError(Some("Failed to add to MST".into())).into_response());
}
};
let new_mst_root = match new_mst.persist().await {
Ok(c) => c,
_ => {
return Ok(
ApiError::InternalError(Some("Failed to persist MST".into())).into_response(),
);
}
};
let key = format!("{}/{}", input.collection, rkey);
mst = mst
.add(&key, record_cid)
.await
.map_err(|_| ApiError::InternalError(Some("Failed to add to MST".into())))?;
ops.push(RecordOp::Create {
collection: input.collection.clone(),
@@ -276,86 +190,55 @@ pub async fn create_record(
cid: record_cid,
});
let mut new_mst_blocks = std::collections::BTreeMap::new();
if new_mst
.blocks_for_path(&key, &mut new_mst_blocks)
.await
.is_err()
{
return Ok(
ApiError::InternalError(Some("Failed to get new MST blocks for path".into()))
.into_response(),
);
}
let mut relevant_blocks = new_mst_blocks.clone();
relevant_blocks.extend(all_old_mst_blocks.iter().map(|(k, v)| (*k, v.clone())));
relevant_blocks.insert(record_cid, bytes::Bytes::new());
let written_cids: Vec<Cid> = tracking_store
.get_all_relevant_cids()
.into_iter()
.chain(relevant_blocks.keys().copied())
.collect::<std::collections::HashSet<_>>()
.into_iter()
let modified_keys: Vec<String> = ops
.iter()
.map(|op| match op {
RecordOp::Create {
collection, rkey, ..
}
| RecordOp::Update {
collection, rkey, ..
}
| RecordOp::Delete {
collection, rkey, ..
} => format!("{}/{}", collection, rkey),
})
.collect();
let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect();
let blob_cids = extract_blob_cids(&input.record);
let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid.into_cid())
.chain(
all_old_mst_blocks
.keys()
.filter(|cid| !new_mst_blocks.contains_key(*cid))
.copied(),
)
.collect();
let commit_result = match commit_and_log(
let commit_result = finalize_repo_write(
&state,
CommitParams {
ctx,
mst,
FinalizeParams {
did: &did,
user_id,
current_root_cid: Some(current_root_cid.into_cid()),
prev_data_cid: Some(initial_mst_root),
new_mst_root,
ops,
blocks_cids: &written_cids_str,
blobs: &blob_cids,
obsolete_cids,
},
)
.await
{
Ok(res) => res,
Err(e) => return Ok(ApiError::from(e).into_response()),
};
for conflict_uri in conflict_uris_to_cleanup {
if let Err(e) = state
.backlink_repo
.remove_backlinks_by_uri(&conflict_uri)
.await
{
error!("Failed to remove backlinks for {}: {}", conflict_uri, e);
}
}
if let Some(ref controller) = controller_did {
let _ = state
.delegation_repo
.log_delegation_action(
&did,
controller,
Some(controller),
DelegationActionType::RepoWrite,
Some(json!({
controller_did: controller_did.as_ref(),
delegation_detail: controller_did.as_ref().map(|_| {
json!({
"action": "create",
"collection": input.collection,
"rkey": rkey
})),
None,
None,
)
.await;
})
}),
ops,
modified_keys: &modified_keys,
blob_cids: &blob_cids,
},
)
.await?;
{
let backlink_repo = state.backlink_repo.clone();
futures::future::join_all(conflict_uris_to_cleanup.iter().map(|uri| {
let backlink_repo = backlink_repo.clone();
async move {
if let Err(e) = backlink_repo.remove_backlinks_by_uri(uri).await {
error!("Failed to remove backlinks for {}: {}", uri, e);
}
}
}))
.await;
}
let created_uri = AtUri::from_parts(&did, &input.collection, &rkey);
@@ -366,20 +249,17 @@ pub async fn create_record(
error!("Failed to add backlinks for {}: {}", created_uri, e);
}
Ok((
StatusCode::OK,
Json(CreateRecordOutput {
uri: created_uri,
cid: record_cid.to_string(),
commit: CommitInfo {
cid: commit_result.commit_cid.to_string(),
rev: commit_result.rev,
},
validation_status,
}),
)
.into_response())
Ok(Json(CreateRecordOutput {
uri: created_uri,
cid: record_cid.to_string(),
commit: CommitInfo {
cid: commit_result.commit_cid.to_string(),
rev: commit_result.rev,
},
validation_status,
}))
}
#[derive(Deserialize)]
#[allow(dead_code)]
pub struct PutRecordInput {
@@ -394,6 +274,7 @@ pub struct PutRecordInput {
#[serde(rename = "swapRecord")]
pub swap_record: Option<String>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PutRecordOutput {
@@ -404,130 +285,77 @@ pub struct PutRecordOutput {
#[serde(skip_serializing_if = "Option::is_none")]
pub validation_status: Option<ValidationStatus>,
}
pub async fn put_record(
State(state): State<AppState>,
auth: Auth<Active>,
Json(input): Json<PutRecordInput>,
) -> Result<Response, tranquil_pds::api::error::ApiError> {
let upsert_proof = match auth.verify_repo_upsert(&input.collection) {
Ok(proof) => proof,
Err(e) => return Ok(e.into_response()),
};
let repo_auth = match prepare_repo_write(&state, &upsert_proof, &input.repo).await {
Ok(res) => res,
Err(err_res) => return Ok(err_res),
};
) -> Result<Json<PutRecordOutput>, ApiError> {
let upsert_proof = auth.verify_repo_upsert(&input.collection)?;
let repo_auth = prepare_repo_write(&state, &upsert_proof, &input.repo).await?;
let did = repo_auth.did;
let user_id = repo_auth.user_id;
let controller_did = repo_auth.controller_did;
let _write_lock = state.repo_write_locks.lock(user_id).await;
let current_root_cid = get_current_root_cid(&state, user_id).await?;
let (ctx, mst) = begin_repo_write(&state, user_id, input.swap_commit.as_deref()).await?;
if let Some(swap_commit) = &input.swap_commit
&& CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid)
{
return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response());
}
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
let commit_bytes = match tracking_store.get(current_root_cid.as_cid()).await {
Ok(Some(b)) => b,
_ => {
return Ok(
ApiError::InternalError(Some("Commit block not found".into())).into_response(),
);
}
};
let commit = match Commit::from_cbor(&commit_bytes) {
Ok(c) => c,
_ => {
return Ok(
ApiError::InternalError(Some("Failed to parse commit".into())).into_response(),
);
}
};
let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
let key = format!("{}/{}", input.collection, input.rkey);
let validation_status = if input.validate.should_skip() {
None
} else {
match validate_record_with_status(
&input.record,
&input.collection,
Some(&input.rkey),
input.validate.requires_lexicon(),
Some(
validate_record_with_status(
&input.record,
&input.collection,
Some(&input.rkey),
input.validate.requires_lexicon(),
)
.await?,
)
.await
{
Ok(status) => Some(status),
Err(err_response) => return Ok(*err_response),
}
};
let key = format!("{}/{}", input.collection, input.rkey);
if let Some(swap_record_str) = &input.swap_record {
let expected_cid = Cid::from_str(swap_record_str).ok();
let actual_cid = mst.get(&key).await.ok().flatten();
if expected_cid != actual_cid {
return Ok(ApiError::InvalidSwap(Some(
return Err(ApiError::InvalidSwap(Some(
"Record has been modified or does not exist".into(),
))
.into_response());
)));
}
}
let existing_cid = mst.get(&key).await.ok().flatten();
let record_ipld = tranquil_pds::util::json_to_ipld(&input.record);
let mut record_bytes = Vec::new();
if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() {
return Ok(ApiError::InvalidRecord("Failed to serialize record".into()).into_response());
}
let record_cid = match tracking_store.put(&record_bytes).await {
Ok(c) => c,
_ => {
return Ok(
ApiError::InternalError(Some("Failed to save record block".into())).into_response(),
);
}
};
let record_bytes = serde_ipld_dagcbor::to_vec(&record_ipld)
.map_err(|_| ApiError::InvalidRecord("Failed to serialize record".into()))?;
let record_cid = ctx
.tracking_store
.put(&record_bytes)
.await
.map_err(|_| ApiError::InternalError(Some("Failed to save record block".into())))?;
if existing_cid == Some(record_cid) {
return Ok((
StatusCode::OK,
Json(PutRecordOutput {
uri: AtUri::from_parts(&did, &input.collection, &input.rkey),
cid: record_cid.to_string(),
commit: None,
validation_status,
}),
)
.into_response());
return Ok(Json(PutRecordOutput {
uri: AtUri::from_parts(&did, &input.collection, &input.rkey),
cid: record_cid.to_string(),
commit: None,
validation_status,
}));
}
let new_mst =
if existing_cid.is_some() {
match mst.update(&key, record_cid).await {
Ok(m) => m,
Err(_) => {
return Ok(ApiError::InternalError(Some("Failed to update MST".into()))
.into_response());
}
}
} else {
match mst.add(&key, record_cid).await {
Ok(m) => m,
Err(_) => {
return Ok(ApiError::InternalError(Some("Failed to add to MST".into()))
.into_response());
}
}
};
let new_mst_root = match new_mst.persist().await {
Ok(c) => c,
Err(_) => {
return Ok(
ApiError::InternalError(Some("Failed to persist MST".into())).into_response(),
);
}
let is_update = existing_cid.is_some();
let new_mst = if is_update {
mst.update(&key, record_cid)
.await
.map_err(|_| ApiError::InternalError(Some("Failed to update MST".into())))?
} else {
mst.add(&key, record_cid)
.await
.map_err(|_| ApiError::InternalError(Some("Failed to add to MST".into())))?
};
let op = if existing_cid.is_some() {
let op = if is_update {
RecordOp::Update {
collection: input.collection.clone(),
rkey: input.rkey.clone(),
@@ -541,100 +369,39 @@ pub async fn put_record(
cid: record_cid,
}
};
let mut new_mst_blocks = std::collections::BTreeMap::new();
let mut old_mst_blocks = std::collections::BTreeMap::new();
if new_mst
.blocks_for_path(&key, &mut new_mst_blocks)
.await
.is_err()
{
return Ok(
ApiError::InternalError(Some("Failed to get new MST blocks for path".into()))
.into_response(),
);
}
if mst
.blocks_for_path(&key, &mut old_mst_blocks)
.await
.is_err()
{
return Ok(
ApiError::InternalError(Some("Failed to get old MST blocks for path".into()))
.into_response(),
);
}
let mut relevant_blocks = new_mst_blocks.clone();
relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone())));
relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes));
let written_cids: Vec<Cid> = tracking_store
.get_all_relevant_cids()
.into_iter()
.chain(relevant_blocks.keys().copied())
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect();
let is_update = existing_cid.is_some();
let modified_keys = [key];
let blob_cids = extract_blob_cids(&input.record);
let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid.into_cid())
.chain(
old_mst_blocks
.keys()
.filter(|cid| !new_mst_blocks.contains_key(*cid))
.copied(),
)
.chain(existing_cid)
.collect();
let commit_result = match commit_and_log(
let commit_result = finalize_repo_write(
&state,
CommitParams {
ctx,
new_mst,
FinalizeParams {
did: &did,
user_id,
current_root_cid: Some(current_root_cid.into_cid()),
prev_data_cid: Some(commit.data),
new_mst_root,
ops: vec![op],
blocks_cids: &written_cids_str,
blobs: &blob_cids,
obsolete_cids,
},
)
.await
{
Ok(res) => res,
Err(e) => return Ok(ApiError::from(e).into_response()),
};
if let Some(ref controller) = controller_did {
let _ = state
.delegation_repo
.log_delegation_action(
&did,
controller,
Some(controller),
DelegationActionType::RepoWrite,
Some(json!({
controller_did: controller_did.as_ref(),
delegation_detail: controller_did.as_ref().map(|_| {
json!({
"action": if is_update { "update" } else { "create" },
"collection": input.collection,
"rkey": input.rkey
})),
None,
None,
)
.await;
}
Ok((
StatusCode::OK,
Json(PutRecordOutput {
uri: AtUri::from_parts(&did, &input.collection, &input.rkey),
cid: record_cid.to_string(),
commit: Some(CommitInfo {
cid: commit_result.commit_cid.to_string(),
rev: commit_result.rev,
})
}),
validation_status,
}),
ops: vec![op],
modified_keys: &modified_keys,
blob_cids: &blob_cids,
},
)
.into_response())
.await?;
Ok(Json(PutRecordOutput {
uri: AtUri::from_parts(&did, &input.collection, &input.rkey),
cid: record_cid.to_string(),
commit: Some(CommitInfo {
cid: commit_result.commit_cid.to_string(),
rev: commit_result.rev,
}),
validation_status,
}))
}

View File

@@ -1,15 +1,19 @@
use crate::api::error::ApiError;
use crate::cid_types::CommitCid;
use crate::repo::tracking::TrackingBlockStore;
use crate::state::AppState;
use crate::types::{Did, Handle, Nsid, Rkey};
use bytes::Bytes;
use cid::Cid;
use jacquard_common::types::{integer::LimitedU32, string::Tid};
use jacquard_repo::commit::Commit;
use jacquard_repo::mst::Mst;
use jacquard_repo::storage::BlockStore;
use k256::ecdsa::SigningKey;
use serde_json::{Value, json};
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::OwnedMutexGuard;
use tracing::error;
use tranquil_db_traits::SequenceNumber;
use uuid::Uuid;
@@ -158,6 +162,135 @@ pub fn extract_backlinks(uri: &AtUri, record: &Value) -> Vec<Backlink> {
}
}
pub struct RepoWriteContext {
pub tracking_store: TrackingBlockStore,
pub current_root_cid: Cid,
pub prev_data_cid: Cid,
pub write_lock: OwnedMutexGuard<()>,
}
pub struct FinalizeParams<'a> {
pub did: &'a Did,
pub user_id: Uuid,
pub controller_did: Option<&'a Did>,
pub delegation_detail: Option<serde_json::Value>,
pub ops: Vec<RecordOp>,
pub modified_keys: &'a [String],
pub blob_cids: &'a [String],
}
pub async fn begin_repo_write(
state: &AppState,
user_id: Uuid,
swap_commit: Option<&str>,
) -> Result<(RepoWriteContext, Mst<TrackingBlockStore>), ApiError> {
let write_lock = state.repo_write_locks.lock(user_id).await;
let root_cid_str = state
.repo_repo
.get_repo_root_cid_by_user_id(user_id)
.await
.map_err(|e| {
error!("DB error fetching repo root: {}", e);
ApiError::InternalError(None)
})?
.ok_or_else(|| ApiError::InternalError(Some("Repo root not found".into())))?;
let current_root_cid = Cid::from_str(root_cid_str.as_str())
.map_err(|_| ApiError::InternalError(Some("Invalid repo root CID".into())))?;
if let Some(expected) = swap_commit {
let expected_cid = Cid::from_str(expected)
.map_err(|_| ApiError::InvalidSwap(Some("Invalid swap commit CID".into())))?;
if expected_cid != current_root_cid {
return Err(ApiError::InvalidSwap(Some("Repo has been modified".into())));
}
}
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
let commit_bytes = tracking_store
.get(&current_root_cid)
.await
.map_err(|e| {
error!("Failed to load commit block: {:?}", e);
ApiError::InternalError(None)
})?
.ok_or_else(|| ApiError::InternalError(Some("Commit block not found".into())))?;
let commit = Commit::from_cbor(&commit_bytes).map_err(|e| {
error!("Failed to parse commit: {:?}", e);
ApiError::InternalError(None)
})?;
let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
let ctx = RepoWriteContext {
tracking_store,
current_root_cid,
prev_data_cid: commit.data,
write_lock,
};
Ok((ctx, mst))
}
pub async fn finalize_repo_write(
state: &AppState,
ctx: RepoWriteContext,
mst: Mst<TrackingBlockStore>,
params: FinalizeParams<'_>,
) -> Result<CommitResult, ApiError> {
let new_mst_root = mst.persist().await.map_err(|e| {
error!("MST persist failed: {:?}", e);
ApiError::InternalError(None)
})?;
let written_cids: Vec<Cid> = ctx
.tracking_store
.get_all_relevant_cids()
.into_iter()
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
let written_cids_str: Vec<String> = written_cids.iter().map(ToString::to_string).collect();
let result = commit_and_log(
state,
CommitParams {
did: params.did,
user_id: params.user_id,
current_root_cid: Some(ctx.current_root_cid),
prev_data_cid: Some(ctx.prev_data_cid),
new_mst_root,
ops: params.ops,
blocks_cids: &written_cids_str,
blobs: params.blob_cids,
obsolete_cids: vec![ctx.current_root_cid],
},
)
.await?;
if let Some(controller_did) = params.controller_did
&& let Some(detail) = params.delegation_detail
&& let Err(e) = state
.delegation_repo
.log_delegation_action(
params.did,
controller_did,
Some(controller_did),
tranquil_db_traits::DelegationActionType::RepoWrite,
Some(detail),
None,
None,
)
.await
{
tracing::warn!("Failed to log delegation audit: {:?}", e);
}
Ok(result)
}
pub fn create_signed_commit(
did: &Did,
data: Cid,
@@ -392,9 +525,6 @@ pub async fn create_record_internal(
rkey: &Rkey,
record: &serde_json::Value,
) -> Result<(String, Cid), CommitError> {
use crate::repo::tracking::TrackingBlockStore;
use jacquard_repo::mst::Mst;
use std::sync::Arc;
let user_id: Uuid = state
.user_repo
.get_id_by_did(did)