diff --git a/crates/tranquil-pds/src/sync/import.rs b/crates/tranquil-pds/src/sync/import.rs index e8fe823..d486aff 100644 --- a/crates/tranquil-pds/src/sync/import.rs +++ b/crates/tranquil-pds/src/sync/import.rs @@ -192,77 +192,55 @@ fn walk_mst_node( prev_key: &[u8], records: &mut Vec, ) -> Result<(), ImportError> { + use super::mst::{entries, left_child, parse_mst_entry, reconstruct_key}; + let block = blocks .get(cid) .ok_or_else(|| ImportError::BlockNotFound(cid.to_string()))?; - let value: Ipld = serde_ipld_dagcbor::from_slice(block) + let node: Ipld = serde_ipld_dagcbor::from_slice(block) .map_err(|e| ImportError::InvalidCbor(e.to_string()))?; - if let Ipld::Map(ref obj) = value { - if let Some(Ipld::Link(left_cid)) = obj.get("l") { - walk_mst_node(blocks, left_cid, prev_key, records)?; - } + if let Some(left_cid) = left_child(&node) { + walk_mst_node(blocks, &left_cid, prev_key, records)?; + } - let mut current_key = prev_key.to_vec(); + let mut current_key = prev_key.to_vec(); - if let Some(Ipld::List(entries)) = obj.get("e") { - for entry in entries { - if let Ipld::Map(entry_obj) = entry { - let prefix_len = entry_obj - .get("p") - .and_then(|p| match p { - Ipld::Integer(n) => usize::try_from(*n).ok(), - _ => None, - }) - .unwrap_or(0); + if let Some(entry_list) = entries(&node) { + entry_list + .iter() + .filter_map(parse_mst_entry) + .try_for_each(|entry| { + if let Some(ref suffix) = entry.key_suffix { + reconstruct_key(&mut current_key, entry.prefix_len, suffix); + } - let key_suffix = entry_obj.get("k").and_then(|k| { - if let Ipld::Bytes(b) = k { - Some(b.clone()) - } else { - None - } - }); + if let Some(tree_cid) = entry.subtree { + walk_mst_node(blocks, &tree_cid, ¤t_key, records)?; + } - if let Some(suffix) = key_suffix { - current_key.truncate(prefix_len); - current_key.extend_from_slice(&suffix); - } - - if let Some(Ipld::Link(tree_cid)) = entry_obj.get("t") { - walk_mst_node(blocks, tree_cid, ¤t_key, records)?; - } - - let record_cid = entry_obj.get("v").and_then(|v| { - if let Ipld::Link(cid) = v { - Some(*cid) - } else { - None - } - }); - - if let Some(record_cid) = record_cid - && let Ok(full_key) = String::from_utf8(current_key.clone()) - && let Some(record_block) = blocks.get(&record_cid) - && let Ok(record_value) = - serde_ipld_dagcbor::from_slice::(record_block) - { - let blob_refs = find_blob_refs_ipld(&record_value, 0); - let parts: Vec<&str> = full_key.split('/').collect(); - if parts.len() >= 2 { - let collection = parts[..parts.len() - 1].join("/"); - let rkey = parts[parts.len() - 1].to_string(); - records.push(ImportedRecord { - collection, - rkey, - cid: record_cid, - blob_refs, - }); - } + if let Some(record_cid) = entry.value + && let Ok(full_key) = String::from_utf8(current_key.clone()) + && let Some(record_block) = blocks.get(&record_cid) + && let Ok(record_value) = + serde_ipld_dagcbor::from_slice::(record_block) + { + let blob_refs = find_blob_refs_ipld(&record_value, 0); + let parts: Vec<&str> = full_key.split('/').collect(); + if parts.len() >= 2 { + let collection = parts[..parts.len() - 1].join("/"); + let rkey = parts[parts.len() - 1].to_string(); + records.push(ImportedRecord { + collection, + rkey, + cid: record_cid, + blob_refs, + }); } } - } - } + + Ok::<_, ImportError>(()) + })?; } Ok(()) } diff --git a/crates/tranquil-pds/src/sync/mod.rs b/crates/tranquil-pds/src/sync/mod.rs index 077c5f2..3dfbdbe 100644 --- a/crates/tranquil-pds/src/sync/mod.rs +++ b/crates/tranquil-pds/src/sync/mod.rs @@ -2,6 +2,7 @@ pub mod car; pub mod firehose; pub mod frame; pub mod import; +pub mod mst; pub mod util; pub mod verify; diff --git a/crates/tranquil-pds/src/sync/mst.rs b/crates/tranquil-pds/src/sync/mst.rs new file mode 100644 index 0000000..6a5c76d --- /dev/null +++ b/crates/tranquil-pds/src/sync/mst.rs @@ -0,0 +1,67 @@ +use cid::Cid; +use ipld_core::ipld::Ipld; + +pub struct MstEntry { + pub prefix_len: usize, + pub key_suffix: Option>, + pub subtree: Option, + pub value: Option, +} + +pub fn parse_mst_entry(entry: &Ipld) -> Option { + let obj = match entry { + Ipld::Map(m) => m, + _ => return None, + }; + let prefix_len = obj + .get("p") + .and_then(|p| match p { + Ipld::Integer(n) => usize::try_from(*n).ok(), + _ => None, + }) + .unwrap_or(0); + let key_suffix = obj.get("k").and_then(|k| match k { + Ipld::Bytes(b) => Some(b.clone()), + Ipld::String(s) => Some(s.as_bytes().to_vec()), + _ => None, + }); + let subtree = obj.get("t").and_then(|t| match t { + Ipld::Link(cid) => Some(*cid), + _ => None, + }); + let value = obj.get("v").and_then(|v| match v { + Ipld::Link(cid) => Some(*cid), + _ => None, + }); + Some(MstEntry { + prefix_len, + key_suffix, + subtree, + value, + }) +} + +pub fn left_child(node: &Ipld) -> Option { + match node { + Ipld::Map(obj) => match obj.get("l") { + Some(Ipld::Link(cid)) => Some(*cid), + _ => None, + }, + _ => None, + } +} + +pub fn entries(node: &Ipld) -> Option<&Vec> { + match node { + Ipld::Map(obj) => match obj.get("e") { + Some(Ipld::List(entries)) => Some(entries), + _ => None, + }, + _ => None, + } +} + +pub fn reconstruct_key(prev_key: &mut Vec, prefix_len: usize, suffix: &[u8]) { + prev_key.truncate(prefix_len); + prev_key.extend_from_slice(suffix); +} diff --git a/crates/tranquil-pds/src/sync/util.rs b/crates/tranquil-pds/src/sync/util.rs index b4962ad..d6aa9c2 100644 --- a/crates/tranquil-pds/src/sync/util.rs +++ b/crates/tranquil-pds/src/sync/util.rs @@ -205,23 +205,38 @@ fn format_atproto_time(dt: chrono::DateTime) -> String { dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string() } -fn format_identity_event(event: &SequencedEvent) -> Result, SyncFrameError> { - let frame = IdentityFrame { - did: event.did.clone(), - handle: event.handle.as_ref().map(|h| h.to_string()), - seq: event.seq.as_i64(), - time: format_atproto_time(event.created_at), - }; - let header = FrameHeader { - op: 1, - t: FrameType::Identity, - }; - let mut bytes = Vec::with_capacity(256); - serde_ipld_dagcbor::to_writer(&mut bytes, &header)?; - serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; +fn serialize_cbor_pair( + header: &H, + payload: &P, + capacity: usize, +) -> Result, SyncFrameError> { + let mut bytes = Vec::with_capacity(capacity); + serde_ipld_dagcbor::to_writer(&mut bytes, header)?; + serde_ipld_dagcbor::to_writer(&mut bytes, payload)?; Ok(bytes) } +fn serialize_event_frame( + frame_type: FrameType, + payload: &P, + capacity: usize, +) -> Result, SyncFrameError> { + serialize_cbor_pair(&FrameHeader { op: 1, t: frame_type }, payload, capacity) +} + +fn format_identity_event(event: &SequencedEvent) -> Result, SyncFrameError> { + serialize_event_frame( + FrameType::Identity, + &IdentityFrame { + did: event.did.clone(), + handle: event.handle.as_ref().map(|h| h.to_string()), + seq: event.seq.as_i64(), + time: format_atproto_time(event.created_at), + }, + 256, + ) +} + fn format_account_event(event: &SequencedEvent) -> Result, SyncFrameError> { let frame = AccountFrame { did: event.did.clone(), @@ -230,13 +245,7 @@ fn format_account_event(event: &SequencedEvent) -> Result, SyncFrameErro seq: event.seq.as_i64(), time: format_atproto_time(event.created_at), }; - let header = FrameHeader { - op: 1, - t: FrameType::Account, - }; - let mut bytes = Vec::with_capacity(256); - serde_ipld_dagcbor::to_writer(&mut bytes, &header)?; - serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; + let bytes = serialize_event_frame(FrameType::Account, &frame, 256)?; let hex_str: String = bytes.iter().map(|b| format!("{:02x}", b)).collect(); tracing::info!( did = %frame.did, @@ -269,33 +278,27 @@ async fn format_sync_event( extract_rev_from_commit_bytes(&commit_bytes).ok_or(SyncFrameError::RevExtraction)? }; let car_bytes = write_car_blocks(commit_cid, Some(commit_bytes), BTreeMap::new()).await?; - let frame = SyncFrame { - did: event.did.clone(), - rev, - blocks: car_bytes, - seq: event.seq.as_i64(), - time: format_atproto_time(event.created_at), - }; - let header = FrameHeader { - op: 1, - t: FrameType::Sync, - }; - let mut bytes = Vec::with_capacity(512); - serde_ipld_dagcbor::to_writer(&mut bytes, &header)?; - serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; - Ok(bytes) + serialize_event_frame( + FrameType::Sync, + &SyncFrame { + did: event.did.clone(), + rev, + blocks: car_bytes, + seq: event.seq.as_i64(), + time: format_atproto_time(event.created_at), + }, + 512, + ) } -pub async fn format_event_for_sending( - state: &AppState, - event: SequencedEvent, -) -> Result, SyncFrameError> { - match event.event_type { - RepoEventType::Identity => return format_identity_event(&event), - RepoEventType::Account => return format_account_event(&event), - RepoEventType::Sync => return format_sync_event(state, &event).await, - RepoEventType::Commit => {} - } +struct CommitEventContext { + frame: CommitFrame, + commit_cid: Cid, + prev_cid: Option, + block_cids: Vec, +} + +fn prepare_commit_event(event: SequencedEvent) -> Result { let block_cids_str = event.blocks_cids.clone().unwrap_or_default(); let prev_cid_link = event.prev_cid.clone(); let prev_data_cid_link = event.prev_data_cid.clone(); @@ -314,47 +317,81 @@ pub async fn format_event_for_sending( let prev_cid = prev_cid_link .as_ref() .and_then(|c| Cid::from_str(c.as_str()).ok()); - let mut all_cids: Vec = block_cids_str + let mut block_cids: Vec = block_cids_str .iter() .filter_map(|s| Cid::from_str(s).ok()) .filter(|c| Some(*c) != prev_cid) .collect(); - if !all_cids.contains(&commit_cid) { - all_cids.push(commit_cid); + if !block_cids.contains(&commit_cid) { + block_cids.push(commit_cid); } - if let Some(ref pc) = prev_cid + Ok(CommitEventContext { + frame, + commit_cid, + prev_cid, + block_cids, + }) +} + +fn partition_blocks( + block_cids: impl IntoIterator, + commit_cid: Cid, +) -> (Option, BTreeMap) { + let (commit_data, other_blocks): (Vec<_>, Vec<_>) = block_cids + .into_iter() + .partition(|(cid, _)| *cid == commit_cid); + let commit_bytes = commit_data.into_iter().next().map(|(_, data)| data); + let other = other_blocks.into_iter().collect(); + (commit_bytes, other) +} + +async fn finalize_commit_frame( + mut frame: CommitFrame, + commit_cid: Cid, + commit_bytes: Option, + other_blocks: BTreeMap, +) -> Result, SyncFrameError> { + if let Some(ref cb) = commit_bytes + && let Some(rev) = extract_rev_from_commit_bytes(cb) + { + frame.rev = rev; + } + frame.blocks = write_car_blocks(commit_cid, commit_bytes, other_blocks).await?; + let capacity = frame.blocks.len() + 512; + serialize_event_frame(FrameType::Commit, &frame, capacity) +} + +pub async fn format_event_for_sending( + state: &AppState, + event: SequencedEvent, +) -> Result, SyncFrameError> { + match event.event_type { + RepoEventType::Identity => return format_identity_event(&event), + RepoEventType::Account => return format_account_event(&event), + RepoEventType::Sync => return format_sync_event(state, &event).await, + RepoEventType::Commit => {} + } + let ctx = prepare_commit_event(event)?; + let mut frame = ctx.frame; + if let Some(ref pc) = ctx.prev_cid && let Ok(Some(prev_bytes)) = state.block_store.get(pc).await && let Some(rev) = extract_rev_from_commit_bytes(&prev_bytes) { frame.since = Some(rev); } - let car_bytes = if !all_cids.is_empty() { - let fetched = state.block_store.get_many(&all_cids).await?; - let (commit_data, other_blocks): (Vec<_>, Vec<_>) = all_cids - .iter() - .zip(fetched.iter()) - .filter_map(|(cid, data_opt)| data_opt.as_ref().map(|data| (*cid, data.clone()))) - .partition(|(cid, _)| *cid == commit_cid); - let commit_bytes = commit_data.into_iter().next().map(|(_, data)| data); - if let Some(ref cb) = commit_bytes - && let Some(rev) = extract_rev_from_commit_bytes(cb) - { - frame.rev = rev; - } - let blocks: std::collections::BTreeMap = other_blocks.into_iter().collect(); - write_car_blocks(commit_cid, commit_bytes, blocks).await? - } else { - Vec::new() - }; - frame.blocks = car_bytes; - let header = FrameHeader { - op: 1, - t: FrameType::Commit, - }; - let mut bytes = Vec::with_capacity(frame.blocks.len() + 512); - serde_ipld_dagcbor::to_writer(&mut bytes, &header)?; - serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; - Ok(bytes) + if ctx.block_cids.is_empty() { + frame.blocks = Vec::new(); + let capacity = frame.blocks.len() + 512; + return serialize_event_frame(FrameType::Commit, &frame, capacity); + } + let fetched = state.block_store.get_many(&ctx.block_cids).await?; + let resolved = ctx + .block_cids + .iter() + .zip(fetched.iter()) + .filter_map(|(cid, data_opt)| data_opt.as_ref().map(|data| (*cid, data.clone()))); + let (commit_bytes, other_blocks) = partition_blocks(resolved, ctx.commit_cid); + finalize_commit_frame(frame, ctx.commit_cid, commit_bytes, other_blocks).await } pub async fn prefetch_blocks_for_events( @@ -413,21 +450,17 @@ fn format_sync_event_with_prefetched( Some(commit_bytes.clone()), BTreeMap::new(), ))?; - let frame = SyncFrame { - did: event.did.clone(), - rev, - blocks: car_bytes, - seq: event.seq.as_i64(), - time: format_atproto_time(event.created_at), - }; - let header = FrameHeader { - op: 1, - t: FrameType::Sync, - }; - let mut bytes = Vec::new(); - serde_ipld_dagcbor::to_writer(&mut bytes, &header)?; - serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; - Ok(bytes) + serialize_event_frame( + FrameType::Sync, + &SyncFrame { + did: event.did.clone(), + rev, + blocks: car_bytes, + seq: event.seq.as_i64(), + time: format_atproto_time(event.created_at), + }, + 512, + ) } pub async fn format_event_with_prefetched_blocks( @@ -440,94 +473,51 @@ pub async fn format_event_with_prefetched_blocks( RepoEventType::Sync => return format_sync_event_with_prefetched(&event, prefetched), RepoEventType::Commit => {} } - let block_cids_str = event.blocks_cids.clone().unwrap_or_default(); - let prev_cid_link = event.prev_cid.clone(); - let prev_data_cid_link = event.prev_data_cid.clone(); - let mut frame: CommitFrame = - event - .try_into() - .map_err(|e: crate::sync::frame::CommitFrameError| { - SyncFrameError::InvalidEvent(e.to_string()) - })?; - if let Some(ref pdc) = prev_data_cid_link - && let Ok(cid) = Cid::from_str(pdc.as_str()) - { - frame.prev_data = Some(cid); - } - let commit_cid = frame.commit; - let prev_cid = prev_cid_link - .as_ref() - .and_then(|c| Cid::from_str(c.as_str()).ok()); - let mut all_cids: Vec = block_cids_str - .iter() - .filter_map(|s| Cid::from_str(s).ok()) - .filter(|c| Some(*c) != prev_cid) - .collect(); - if !all_cids.contains(&commit_cid) { - all_cids.push(commit_cid); - } - if let Some(commit_bytes) = prefetched.get(&commit_cid) - && let Some(rev) = extract_rev_from_commit_bytes(commit_bytes) - { - frame.rev = rev; - } - if let Some(ref pc) = prev_cid + let ctx = prepare_commit_event(event)?; + let mut frame = ctx.frame; + if let Some(ref pc) = ctx.prev_cid && let Some(prev_bytes) = prefetched.get(pc) && let Some(rev) = extract_rev_from_commit_bytes(prev_bytes) { frame.since = Some(rev); } - let car_bytes = if !all_cids.is_empty() { - let (commit_data, other_blocks): (Vec<_>, Vec<_>) = all_cids - .into_iter() - .filter_map(|cid| prefetched.get(&cid).map(|data| (cid, data.clone()))) - .partition(|(cid, _)| *cid == commit_cid); - let commit_bytes_for_car = commit_data.into_iter().next().map(|(_, data)| data); - let blocks: BTreeMap = other_blocks.into_iter().collect(); - write_car_blocks(commit_cid, commit_bytes_for_car, blocks).await? - } else { - Vec::new() - }; - frame.blocks = car_bytes; - let header = FrameHeader { - op: 1, - t: FrameType::Commit, - }; - let mut bytes = Vec::with_capacity(frame.blocks.len() + 512); - serde_ipld_dagcbor::to_writer(&mut bytes, &header)?; - serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; - Ok(bytes) + if ctx.block_cids.is_empty() { + frame.blocks = Vec::new(); + let capacity = frame.blocks.len() + 512; + return serialize_event_frame(FrameType::Commit, &frame, capacity); + } + let resolved = ctx + .block_cids + .into_iter() + .filter_map(|cid| prefetched.get(&cid).map(|data| (cid, data.clone()))); + let (commit_bytes, other_blocks) = partition_blocks(resolved, ctx.commit_cid); + finalize_commit_frame(frame, ctx.commit_cid, commit_bytes, other_blocks).await } pub fn format_info_frame( name: InfoFrameName, message: Option<&str>, ) -> Result, SyncFrameError> { - let header = FrameHeader { - op: 1, - t: FrameType::Info, - }; - let frame = InfoFrame { - name, - message: message.map(String::from), - }; - let mut bytes = Vec::with_capacity(128); - serde_ipld_dagcbor::to_writer(&mut bytes, &header)?; - serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; - Ok(bytes) + serialize_event_frame( + FrameType::Info, + &InfoFrame { + name, + message: message.map(String::from), + }, + 128, + ) } pub fn format_error_frame( error: ErrorFrameName, message: Option<&str>, ) -> Result, SyncFrameError> { - let header = ErrorFrameHeader { op: -1 }; - let frame = ErrorFrameBody { - error, - message: message.map(String::from), - }; - let mut bytes = Vec::with_capacity(128); - serde_ipld_dagcbor::to_writer(&mut bytes, &header)?; - serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; - Ok(bytes) + serialize_cbor_pair( + &ErrorFrameHeader { op: -1 }, + &ErrorFrameBody { + error, + message: message.map(String::from), + }, + 128, + ) } diff --git a/crates/tranquil-pds/src/sync/verify.rs b/crates/tranquil-pds/src/sync/verify.rs index 44f59d6..43571f0 100644 --- a/crates/tranquil-pds/src/sync/verify.rs +++ b/crates/tranquil-pds/src/sync/verify.rs @@ -205,6 +205,7 @@ impl CarVerifier { data_cid: &Cid, blocks: &HashMap, ) -> Result<(), VerifyError> { + use super::mst::{entries, left_child, parse_mst_entry}; use ipld_core::ipld::Ipld; let mut stack = vec![*data_cid]; @@ -227,65 +228,56 @@ impl CarVerifier { .ok_or_else(|| VerifyError::BlockNotFound(cid.to_string()))?; let node: Ipld = serde_ipld_dagcbor::from_slice(block) .map_err(|e| VerifyError::InvalidCbor(e.to_string()))?; - if let Ipld::Map(ref obj) = node { - if let Some(Ipld::Link(left_cid)) = obj.get("l") { - if !blocks.contains_key(left_cid) { - return Err(VerifyError::BlockNotFound(format!( - "MST left pointer {} not in CAR", - left_cid - ))); - } - stack.push(*left_cid); + if let Some(left_cid) = left_child(&node) { + if !blocks.contains_key(&left_cid) { + return Err(VerifyError::BlockNotFound(format!( + "MST left pointer {} not in CAR", + left_cid + ))); } - if let Some(Ipld::List(entries)) = obj.get("e") { - let mut last_full_key: Vec = Vec::new(); - for entry in entries { - if let Ipld::Map(entry_obj) = entry { - let prefix_len = entry_obj - .get("p") - .and_then(|p| match p { - Ipld::Integer(i) => usize::try_from(*i).ok(), - _ => None, - }) - .unwrap_or(0); - let key_suffix = entry_obj.get("k").and_then(|k| match k { - Ipld::Bytes(b) => Some(b.clone()), - Ipld::String(s) => Some(s.as_bytes().to_vec()), - _ => None, - }); - if let Some(suffix) = key_suffix { - let mut full_key = Vec::new(); - if prefix_len > 0 && prefix_len <= last_full_key.len() { - full_key.extend_from_slice(&last_full_key[..prefix_len]); - } - full_key.extend_from_slice(&suffix); - if !last_full_key.is_empty() && full_key <= last_full_key { - return Err(VerifyError::MstValidationFailed( - "MST keys not in sorted order".to_string(), - )); - } - last_full_key = full_key; - } - if let Some(Ipld::Link(tree_cid)) = entry_obj.get("t") { - if !blocks.contains_key(tree_cid) { - return Err(VerifyError::BlockNotFound(format!( - "MST subtree {} not in CAR", - tree_cid - ))); - } - stack.push(*tree_cid); - } - if let Some(Ipld::Link(value_cid)) = entry_obj.get("v") - && !blocks.contains_key(value_cid) + stack.push(left_cid); + } + if let Some(entry_list) = entries(&node) { + let mut last_full_key: Vec = Vec::new(); + entry_list + .iter() + .filter_map(parse_mst_entry) + .try_for_each(|entry| { + if let Some(ref suffix) = entry.key_suffix { + let mut full_key = Vec::new(); + if entry.prefix_len > 0 + && entry.prefix_len <= last_full_key.len() { - warn!( - "Record block {} referenced in MST not in CAR (may be expected for partial export)", - value_cid - ); + full_key + .extend_from_slice(&last_full_key[..entry.prefix_len]); } + full_key.extend_from_slice(suffix); + if !last_full_key.is_empty() && full_key <= last_full_key { + return Err(VerifyError::MstValidationFailed( + "MST keys not in sorted order".to_string(), + )); + } + last_full_key = full_key; } - } - } + if let Some(tree_cid) = entry.subtree { + if !blocks.contains_key(&tree_cid) { + return Err(VerifyError::BlockNotFound(format!( + "MST subtree {} not in CAR", + tree_cid + ))); + } + stack.push(tree_cid); + } + if let Some(value_cid) = entry.value + && !blocks.contains_key(&value_cid) + { + warn!( + "Record block {} referenced in MST not in CAR (may be expected for partial export)", + value_cid + ); + } + Ok(()) + })?; } } debug!(