diff --git a/src/api/feed/mod.rs b/src/api/feed/mod.rs new file mode 100644 index 0000000..abd11c2 --- /dev/null +++ b/src/api/feed/mod.rs @@ -0,0 +1,3 @@ +mod timeline; + +pub use timeline::get_timeline; diff --git a/src/api/feed/timeline.rs b/src/api/feed/timeline.rs new file mode 100644 index 0000000..94c361e --- /dev/null +++ b/src/api/feed/timeline.rs @@ -0,0 +1,237 @@ +// Yes, I know, this endpoint is an appview one, not for PDS. Who cares!! +// Yes, this only gets posts that our DB/instance knows about. Who cares!!! + +use crate::state::AppState; +use axum::{ + Json, + extract::State, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use jacquard_repo::storage::BlockStore; +use serde::Serialize; +use serde_json::{Value, json}; +use sqlx::Row; +use tracing::error; + +#[derive(Serialize)] +pub struct TimelineOutput { + pub feed: Vec, + pub cursor: Option, +} + +#[derive(Serialize)] +pub struct FeedViewPost { + pub post: PostView, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PostView { + pub uri: String, + pub cid: String, + pub author: AuthorView, + pub record: Value, + pub indexed_at: String, +} + +#[derive(Serialize)] +pub struct AuthorView { + pub did: String, + pub handle: String, +} + +pub async fn get_timeline( + State(state): State, + headers: axum::http::HeaderMap, +) -> Response { + let auth_header = headers.get("Authorization"); + if auth_header.is_none() { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationRequired"})), + ) + .into_response(); + } + let token = auth_header + .unwrap() + .to_str() + .unwrap_or("") + .replace("Bearer ", ""); + + let session = sqlx::query( + "SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1" + ) + .bind(&token) + .fetch_optional(&state.db) + .await + .unwrap_or(None); + + let (did, key_bytes) = match session { + Some(row) => ( + row.get::("did"), + row.get::, _>("key_bytes"), + ), + None => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed"})), + ) + .into_response(); + } + }; + + if crate::auth::verify_token(&token, &key_bytes).is_err() { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})), + ) + .into_response(); + } + + let user_query = sqlx::query("SELECT id FROM users WHERE did = $1") + .bind(&did) + .fetch_optional(&state.db) + .await; + + let user_id: uuid::Uuid = match user_query { + Ok(Some(row)) => row.get("id"), + _ => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "User not found"})), + ) + .into_response(); + } + }; + + let follows_query = sqlx::query( + "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow'" + ) + .bind(user_id) + .fetch_all(&state.db) + .await; + + let follow_cids: Vec = match follows_query { + Ok(rows) => rows.iter().map(|r| r.get("record_cid")).collect(), + Err(e) => { + error!("Failed to get follows: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let mut followed_dids: Vec = Vec::new(); + for cid_str in follow_cids { + let cid = match cid_str.parse::() { + Ok(c) => c, + Err(_) => continue, + }; + + let block_bytes = match state.block_store.get(&cid).await { + Ok(Some(b)) => b, + _ => continue, + }; + + let record: Value = match serde_ipld_dagcbor::from_slice(&block_bytes) { + Ok(v) => v, + Err(_) => continue, + }; + + if let Some(subject) = record.get("subject").and_then(|s| s.as_str()) { + followed_dids.push(subject.to_string()); + } + } + + if followed_dids.is_empty() { + return ( + StatusCode::OK, + Json(TimelineOutput { + feed: vec![], + cursor: None, + }), + ) + .into_response(); + } + + let placeholders: Vec = followed_dids + .iter() + .enumerate() + .map(|(i, _)| format!("${}", i + 1)) + .collect(); + + let posts_query = format!( + "SELECT r.record_cid, r.rkey, r.created_at, u.did, u.handle + FROM records r + JOIN repos rp ON r.repo_id = rp.user_id + JOIN users u ON rp.user_id = u.id + WHERE u.did IN ({}) AND r.collection = 'app.bsky.feed.post' + ORDER BY r.created_at DESC + LIMIT 50", + placeholders.join(", ") + ); + + let mut query = sqlx::query(&posts_query); + for did in &followed_dids { + query = query.bind(did); + } + + let posts_result = query.fetch_all(&state.db).await; + + let posts = match posts_result { + Ok(rows) => rows, + Err(e) => { + error!("Failed to get posts: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let mut feed: Vec = Vec::new(); + + for row in posts { + let record_cid: String = row.get("record_cid"); + let rkey: String = row.get("rkey"); + let created_at: chrono::DateTime = row.get("created_at"); + let author_did: String = row.get("did"); + let author_handle: String = row.get("handle"); + + let cid = match record_cid.parse::() { + Ok(c) => c, + Err(_) => continue, + }; + + let block_bytes = match state.block_store.get(&cid).await { + Ok(Some(b)) => b, + _ => continue, + }; + + let record: Value = match serde_ipld_dagcbor::from_slice(&block_bytes) { + Ok(v) => v, + Err(_) => continue, + }; + + let uri = format!("at://{}/app.bsky.feed.post/{}", author_did, rkey); + + feed.push(FeedViewPost { + post: PostView { + uri, + cid: record_cid, + author: AuthorView { + did: author_did, + handle: author_handle, + }, + record, + indexed_at: created_at.to_rfc3339(), + }, + }); + } + + (StatusCode::OK, Json(TimelineOutput { feed, cursor: None })).into_response() +} diff --git a/src/api/identity/account.rs b/src/api/identity/account.rs index ba9de90..07f9b48 100644 --- a/src/api/identity/account.rs +++ b/src/api/identity/account.rs @@ -206,10 +206,10 @@ pub async fn create_account( } let mst = Mst::new(Arc::new(state.block_store.clone())); - let mst_root = match mst.root().await { + let mst_root = match mst.persist().await { Ok(c) => c, Err(e) => { - error!("Error creating MST root: {:?}", e); + error!("Error persisting MST: {:?}", e); return ( StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"})), diff --git a/src/api/mod.rs b/src/api/mod.rs index f380623..c481a9a 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,3 +1,4 @@ +pub mod feed; pub mod identity; pub mod proxy; pub mod repo; diff --git a/src/api/repo/record/delete.rs b/src/api/repo/record/delete.rs index ffe555e..9f650b4 100644 --- a/src/api/repo/record/delete.rs +++ b/src/api/repo/record/delete.rs @@ -151,17 +151,21 @@ pub async fn delete_record( // TODO: Check swapRecord if provided? Skipping for brevity/robustness - if let Err(e) = mst.delete(&key).await { - error!("Failed to delete from MST: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to delete from MST: {:?}", e)}))).into_response(); - } + let new_mst = match mst.delete(&key).await { + Ok(m) => m, + Err(e) => { + error!("Failed to delete from MST: {:?}", e); + return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to delete from MST: {:?}", e)}))).into_response(); + } + }; - let new_mst_root = match mst.root().await { + let new_mst_root = match new_mst.persist().await { Ok(c) => c, - Err(_e) => { + Err(e) => { + error!("Failed to persist MST: {:?}", e); return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": "InternalError", "message": "Failed to get new MST root"})), + Json(json!({"error": "InternalError", "message": "Failed to persist MST"})), ) .into_response(); } diff --git a/src/api/repo/record/write.rs b/src/api/repo/record/write.rs index 7c5117e..4e85968 100644 --- a/src/api/repo/record/write.rs +++ b/src/api/repo/record/write.rs @@ -205,19 +205,22 @@ pub async fn create_record( }; let key = format!("{}/{}", collection_nsid, rkey); - if let Err(e) = mst.update(&key, record_cid).await { - error!("Failed to update MST: {:?}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": "InternalError"})), - ) - .into_response(); - } + let new_mst = match mst.add(&key, record_cid).await { + Ok(m) => m, + Err(e) => { + error!("Failed to add to MST: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; - let new_mst_root = match mst.root().await { + let new_mst_root = match new_mst.persist().await { Ok(c) => c, Err(e) => { - error!("Failed to get new MST root: {:?}", e); + error!("Failed to persist MST: {:?}", e); return ( StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"})), @@ -317,6 +320,8 @@ pub struct PutRecordInput { pub record: serde_json::Value, #[serde(rename = "swapCommit")] pub swap_commit: Option, + #[serde(rename = "swapRecord")] + pub swap_record: Option, } #[derive(Serialize)] @@ -490,18 +495,78 @@ pub async fn put_record( }; let key = format!("{}/{}", collection_nsid, rkey); - if let Err(e) = mst.update(&key, record_cid).await { - error!("Failed to update MST: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to update MST: {:?}", e)}))).into_response(); - } - let new_mst_root = match mst.root().await { - Ok(c) => c, + let existing = match mst.get(&key).await { + Ok(v) => v, Err(e) => { - error!("Failed to get new MST root: {:?}", e); + error!("Failed to check MST key: {:?}", e); return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": "InternalError", "message": "Failed to get new MST root"})), + Json( + json!({"error": "InternalError", "message": "Failed to check existing record"}), + ), + ) + .into_response(); + } + }; + + if let Some(swap_record_str) = &input.swap_record { + let swap_record_cid = match Cid::from_str(swap_record_str) { + Ok(c) => c, + Err(_) => { + return ( + StatusCode::BAD_REQUEST, + Json( + json!({"error": "InvalidSwapRecord", "message": "Invalid swapRecord CID"}), + ), + ) + .into_response(); + } + }; + match &existing { + Some(current_cid) if *current_cid != swap_record_cid => { + return ( + StatusCode::CONFLICT, + Json(json!({"error": "InvalidSwap", "message": "Record has been modified"})), + ) + .into_response(); + } + None => { + return ( + StatusCode::CONFLICT, + Json(json!({"error": "InvalidSwap", "message": "Record does not exist"})), + ) + .into_response(); + } + _ => {} + } + } + + let new_mst = if existing.is_some() { + match mst.update(&key, record_cid).await { + Ok(m) => m, + Err(e) => { + error!("Failed to update MST: {:?}", e); + return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to update MST: {:?}", e)}))).into_response(); + } + } + } else { + match mst.add(&key, record_cid).await { + Ok(m) => m, + Err(e) => { + error!("Failed to add to MST: {:?}", e); + return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to add to MST: {:?}", e)}))).into_response(); + } + } + }; + + let new_mst_root = match new_mst.persist().await { + Ok(c) => c, + Err(e) => { + error!("Failed to persist MST: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Failed to persist MST"})), ) .into_response(); } diff --git a/src/lib.rs b/src/lib.rs index 8a3f3ee..c290936 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,6 +65,10 @@ pub fn app(state: AppState) -> Router { "/xrpc/com.atproto.repo.uploadBlob", post(api::repo::upload_blob), ) + .route( + "/xrpc/app.bsky.feed.getTimeline", + get(api::feed::get_timeline), + ) .route("/.well-known/did.json", get(api::identity::well_known_did)) .route("/u/{handle}/did.json", get(api::identity::user_did_doc)) .route("/xrpc/{*method}", any(api::proxy::proxy_handler)) diff --git a/tests/lifecycle.rs b/tests/lifecycle.rs index f53cb44..ec34ad1 100644 --- a/tests/lifecycle.rs +++ b/tests/lifecycle.rs @@ -53,7 +53,6 @@ async fn setup_new_user(handle_prefix: &str) -> (String, String) { } #[tokio::test] -#[ignore] async fn test_post_crud_lifecycle() { let client = client(); let (did, jwt) = setup_new_user("lifecycle-crud").await; @@ -221,7 +220,6 @@ async fn test_post_crud_lifecycle() { } #[tokio::test] -#[ignore] async fn test_record_update_conflict_lifecycle() { let client = client(); let (user_did, user_jwt) = setup_new_user("user-conflict").await; @@ -277,7 +275,7 @@ async fn test_record_update_conflict_lifecycle() { "$type": "app.bsky.actor.profile", "displayName": "Updated Name (v2)" }, - "swapCommit": cid_v1 // <-- Correctly point to v1 + "swapRecord": cid_v1 }); let update_res_v2 = client .post(format!( @@ -308,7 +306,7 @@ async fn test_record_update_conflict_lifecycle() { "$type": "app.bsky.actor.profile", "displayName": "Stale Update (v3)" }, - "swapCommit": cid_v1 + "swapRecord": cid_v1 }); let update_res_v3_stale = client .post(format!( @@ -335,7 +333,7 @@ async fn test_record_update_conflict_lifecycle() { "$type": "app.bsky.actor.profile", "displayName": "Good Update (v3)" }, - "swapCommit": cid_v2 // <-- Correct + "swapRecord": cid_v2 }); let update_res_v3_good = client .post(format!( @@ -448,7 +446,6 @@ async fn create_follow( } #[tokio::test] -#[ignore] async fn test_social_flow_lifecycle() { let client = client();