diff --git a/.sqlx/query-7bb1388dec372fe749462cd9b604e5802b770aeb110462208988141d31c86c92.json b/.sqlx/query-36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b.json similarity index 66% rename from .sqlx/query-7bb1388dec372fe749462cd9b604e5802b770aeb110462208988141d31c86c92.json rename to .sqlx/query-36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b.json index e5184b9..1f4d972 100644 --- a/.sqlx/query-7bb1388dec372fe749462cd9b604e5802b770aeb110462208988141d31c86c92.json +++ b/.sqlx/query-36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT k.key_bytes, k.encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1", + "query": "SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1", "describe": { "columns": [ { @@ -24,5 +24,5 @@ true ] }, - "hash": "7bb1388dec372fe749462cd9b604e5802b770aeb110462208988141d31c86c92" + "hash": "36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b" } diff --git a/.sqlx/query-5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288.json b/.sqlx/query-5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288.json new file mode 100644 index 0000000..4096224 --- /dev/null +++ b/.sqlx/query-5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288.json @@ -0,0 +1,23 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "val", + "type_info": "Int4" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Text" + ] + }, + "nullable": [ + null + ] + }, + "hash": "5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288" +} diff --git a/.sqlx/query-c61fc3b2fbdf6891269908ef21f13dcabdc3b032e9f767becae34ca176df18b6.json b/.sqlx/query-8a9e71f04ec779d5c10d79582cc398529e01be01a83898df3524bb35e3d2ed14.json similarity index 50% rename from .sqlx/query-c61fc3b2fbdf6891269908ef21f13dcabdc3b032e9f767becae34ca176df18b6.json rename to .sqlx/query-8a9e71f04ec779d5c10d79582cc398529e01be01a83898df3524bb35e3d2ed14.json index a892f59..8d08950 100644 --- a/.sqlx/query-c61fc3b2fbdf6891269908ef21f13dcabdc3b032e9f767becae34ca176df18b6.json +++ b/.sqlx/query-8a9e71f04ec779d5c10d79582cc398529e01be01a83898df3524bb35e3d2ed14.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4)\n ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()", + "query": "INSERT INTO records (repo_id, collection, rkey, record_cid, repo_rev) VALUES ($1, $2, $3, $4, $5)\n ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, repo_rev = $5, created_at = NOW()", "describe": { "columns": [], "parameters": { @@ -8,10 +8,11 @@ "Uuid", "Text", "Text", + "Text", "Text" ] }, "nullable": [] }, - "hash": "c61fc3b2fbdf6891269908ef21f13dcabdc3b032e9f767becae34ca176df18b6" + "hash": "8a9e71f04ec779d5c10d79582cc398529e01be01a83898df3524bb35e3d2ed14" } diff --git a/.sqlx/query-bf60faafb5c79a149ba237a984f78d068b5d691f6762641412a5aa1517605c04.json b/.sqlx/query-a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc.json similarity index 72% rename from .sqlx/query-bf60faafb5c79a149ba237a984f78d068b5d691f6762641412a5aa1517605c04.json rename to .sqlx/query-a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc.json index 9d483f0..ae031f0 
100644 --- a/.sqlx/query-bf60faafb5c79a149ba237a984f78d068b5d691f6762641412a5aa1517605c04.json +++ b/.sqlx/query-a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow'", + "query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000", "describe": { "columns": [ { @@ -18,5 +18,5 @@ false ] }, - "hash": "bf60faafb5c79a149ba237a984f78d068b5d691f6762641412a5aa1517605c04" + "hash": "a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc" } diff --git a/.sqlx/query-f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e.json b/.sqlx/query-f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e.json new file mode 100644 index 0000000..e63f03c --- /dev/null +++ b/.sqlx/query-f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e.json @@ -0,0 +1,47 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT record_cid, collection, rkey, created_at, repo_rev\n FROM records\n WHERE repo_id = $1 AND repo_rev > $2\n ORDER BY repo_rev ASC\n LIMIT 10\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "record_cid", + "type_info": "Text" + }, + { + "ordinal": 1, + "name": "collection", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "rkey", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 4, + "name": "repo_rev", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Text" + ] + }, + "nullable": [ + false, + false, + false, + false, + true + ] + }, + "hash": "f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e" +} diff --git a/TODO.md b/TODO.md index 130235e..da4cad6 100644 --- a/TODO.md +++ b/TODO.md @@ -168,15 +168,15 @@ These endpoints need to be implemented at the PDS level (not just proxied to app - [x] Implement `app.bsky.actor.getProfiles` (PDS-level with proxy fallback). ### Feed (`app.bsky.feed`) -These are implemented at PDS level to enable local-first reads: -- [ ] Implement `app.bsky.feed.getTimeline` (PDS-level with proxy). -- [ ] Implement `app.bsky.feed.getAuthorFeed` (PDS-level with proxy). -- [ ] Implement `app.bsky.feed.getActorLikes` (PDS-level with proxy). -- [ ] Implement `app.bsky.feed.getPostThread` (PDS-level with proxy). -- [ ] Implement `app.bsky.feed.getFeed` (PDS-level with proxy). +These are implemented at PDS level to enable local-first reads (read-after-write pattern): +- [x] Implement `app.bsky.feed.getTimeline` (PDS-level with proxy + RAW). +- [x] Implement `app.bsky.feed.getAuthorFeed` (PDS-level with proxy + RAW). +- [x] Implement `app.bsky.feed.getActorLikes` (PDS-level with proxy + RAW). +- [x] Implement `app.bsky.feed.getPostThread` (PDS-level with proxy + RAW + NotFound handling). +- [x] Implement `app.bsky.feed.getFeed` (proxy to feed generator). ### Notification (`app.bsky.notification`) -- [ ] Implement `app.bsky.notification.registerPush` (push notification registration). +- [x] Implement `app.bsky.notification.registerPush` (push notification registration, proxied). ## Deprecated Sync Endpoints (for compatibility) - [ ] Implement `com.atproto.sync.getCheckout` (deprecated, still needed for compatibility). 
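The feed endpoints checked off above share the read-after-write (RAW) merge implemented in `src/api/read_after_write.rs` later in this diff: proxy the query to the AppView, read the `atproto-repo-rev` response header, pull any local records written after that revision, and splice them into the upstream page. A minimal, self-contained sketch of the merge step only (hypothetical simplified `Post` type, not the PDS's actual `FeedViewPost`/`PostView`):

```rust
// Hypothetical, simplified illustration of the insert_posts_into_feed step:
// append local posts the AppView has not indexed yet, then re-sort
// newest-first. RFC 3339 timestamps in the same offset sort
// lexicographically, so string comparison suffices here.
struct Post {
    uri: String,
    indexed_at: String,
}

fn merge_read_after_write(upstream: &mut Vec<Post>, local: Vec<Post>) {
    if local.is_empty() {
        return;
    }
    upstream.extend(local);
    upstream.sort_by(|a, b| b.indexed_at.cmp(&a.indexed_at));
}

fn main() {
    let mut feed = vec![Post {
        uri: "at://did:plc:x/app.bsky.feed.post/old".into(),
        indexed_at: "2025-12-20T10:00:00Z".into(),
    }];
    let local = vec![Post {
        uri: "at://did:plc:x/app.bsky.feed.post/new".into(),
        indexed_at: "2025-12-21T09:00:00Z".into(),
    }];
    merge_read_after_write(&mut feed, local);
    assert_eq!(feed[0].uri, "at://did:plc:x/app.bsky.feed.post/new");
}
```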
diff --git a/migrations/202512211600_add_repo_rev.sql b/migrations/202512211600_add_repo_rev.sql new file mode 100644 index 0000000..6e81b79 --- /dev/null +++ b/migrations/202512211600_add_repo_rev.sql @@ -0,0 +1,2 @@ +ALTER TABLE records ADD COLUMN repo_rev TEXT; +CREATE INDEX idx_records_repo_rev ON records(repo_rev); diff --git a/src/api/actor/profile.rs b/src/api/actor/profile.rs index e195052..247f275 100644 --- a/src/api/actor/profile.rs +++ b/src/api/actor/profile.rs @@ -125,10 +125,18 @@ pub async fn get_profile( ) -> Response { let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); - let auth_did = auth_header.and_then(|h| { - let token = crate::auth::extract_bearer_token_from_header(Some(h))?; - crate::auth::get_did_from_token(&token).ok() - }); + let auth_did = if let Some(h) = auth_header { + if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { + match crate::auth::validate_bearer_token(&state.db, &token).await { + Ok(user) => Some(user.did), + Err(_) => None, + } + } else { + None + } + } else { + None + }; let mut query_params = HashMap::new(); query_params.insert("actor".to_string(), params.actor.clone()); @@ -167,10 +175,18 @@ pub async fn get_profiles( ) -> Response { let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); - let auth_did = auth_header.and_then(|h| { - let token = crate::auth::extract_bearer_token_from_header(Some(h))?; - crate::auth::get_did_from_token(&token).ok() - }); + let auth_did = if let Some(h) = auth_header { + if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { + match crate::auth::validate_bearer_token(&state.db, &token).await { + Ok(user) => Some(user.did), + Err(_) => None, + } + } else { + None + } + } else { + None + }; let mut query_params = HashMap::new(); query_params.insert("actors".to_string(), params.actors.clone()); diff --git a/src/api/error.rs b/src/api/error.rs index f30f2bd..2ab5425 100644 --- a/src/api/error.rs +++ b/src/api/error.rs @@ -44,13 +44,19 @@ pub enum ApiError { InvitesDisabled, DatabaseError, UpstreamFailure, + UpstreamTimeout, + UpstreamUnavailable(String), + UpstreamError { status: u16, error: Option<String>, message: Option<String> }, } impl ApiError { fn status_code(&self) -> StatusCode { match self { - Self::InternalError | Self::DatabaseError | Self::UpstreamFailure => { - StatusCode::INTERNAL_SERVER_ERROR + Self::InternalError | Self::DatabaseError => StatusCode::INTERNAL_SERVER_ERROR, + Self::UpstreamFailure | Self::UpstreamUnavailable(_) => StatusCode::BAD_GATEWAY, + Self::UpstreamTimeout => StatusCode::GATEWAY_TIMEOUT, + Self::UpstreamError { status, .. } => { + StatusCode::from_u16(*status).unwrap_or(StatusCode::BAD_GATEWAY) } Self::AuthenticationRequired | Self::AuthenticationFailed @@ -83,7 +89,15 @@ impl ApiError { fn error_name(&self) -> &'static str { match self { - Self::InternalError | Self::DatabaseError | Self::UpstreamFailure => "InternalError", + Self::InternalError | Self::DatabaseError => "InternalError", + Self::UpstreamFailure | Self::UpstreamUnavailable(_) => "UpstreamFailure", + Self::UpstreamTimeout => "UpstreamTimeout", + Self::UpstreamError { error, .. 
} => { + if let Some(e) = error { + return Box::leak(e.clone().into_boxed_str()); + } + "UpstreamError" + } Self::AuthenticationRequired => "AuthenticationRequired", Self::AuthenticationFailed | Self::AuthenticationFailedMsg(_) => "AuthenticationFailed", Self::InvalidToken => "InvalidToken", @@ -116,10 +130,25 @@ impl ApiError { Self::AuthenticationFailedMsg(msg) | Self::ExpiredTokenMsg(msg) | Self::InvalidRequest(msg) - | Self::RepoNotFoundMsg(msg) => Some(msg.clone()), + | Self::RepoNotFoundMsg(msg) + | Self::UpstreamUnavailable(msg) => Some(msg.clone()), + Self::UpstreamError { message, .. } => message.clone(), + Self::UpstreamTimeout => Some("Upstream service timed out".to_string()), _ => None, } } + + pub fn from_upstream_response( + status: u16, + body: &[u8], + ) -> Self { + if let Ok(parsed) = serde_json::from_slice::<serde_json::Value>(body) { + let error = parsed.get("error").and_then(|v| v.as_str()).map(String::from); + let message = parsed.get("message").and_then(|v| v.as_str()).map(String::from); + return Self::UpstreamError { status, error, message }; + } + Self::UpstreamError { status, error: None, message: None } + } } impl IntoResponse for ApiError { diff --git a/src/api/feed/actor_likes.rs b/src/api/feed/actor_likes.rs new file mode 100644 index 0000000..1f0120f --- /dev/null +++ b/src/api/feed/actor_likes.rs @@ -0,0 +1,158 @@ +use crate::api::read_after_write::{ + extract_repo_rev, format_munged_response, get_local_lag, get_records_since_rev, + proxy_to_appview, FeedOutput, FeedViewPost, LikeRecord, PostView, RecordDescript, +}; +use crate::state::AppState; +use axum::{ + extract::{Query, State}, + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde::Deserialize; +use serde_json::Value; +use std::collections::HashMap; +use tracing::warn; + +#[derive(Deserialize)] +pub struct GetActorLikesParams { + pub actor: String, + pub limit: Option<u32>, + pub cursor: Option<String>, +} + +fn insert_likes_into_feed(feed: &mut Vec<FeedViewPost>, likes: &[RecordDescript<LikeRecord>]) { + for like in likes { + let like_time = &like.indexed_at.to_rfc3339(); + let idx = feed + .iter() + .position(|fi| &fi.post.indexed_at < like_time) + .unwrap_or(feed.len()); + + let placeholder_post = PostView { + uri: like.record.subject.uri.clone(), + cid: like.record.subject.cid.clone(), + author: crate::api::read_after_write::AuthorView { + did: String::new(), + handle: String::new(), + display_name: None, + avatar: None, + extra: HashMap::new(), + }, + record: Value::Null, + indexed_at: like.indexed_at.to_rfc3339(), + embed: None, + reply_count: 0, + repost_count: 0, + like_count: 0, + quote_count: 0, + extra: HashMap::new(), + }; + + feed.insert( + idx, + FeedViewPost { + post: placeholder_post, + reply: None, + reason: None, + feed_context: None, + extra: HashMap::new(), + }, + ); + } +} + +pub async fn get_actor_likes( + State(state): State<AppState>, + headers: axum::http::HeaderMap, + Query(params): Query<GetActorLikesParams>, +) -> Response { + let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); + + let auth_did = if let Some(h) = auth_header { + if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { + match crate::auth::validate_bearer_token(&state.db, &token).await { + Ok(user) => Some(user.did), + Err(_) => None, + } + } else { + None + } + } else { + None + }; + + let mut query_params = HashMap::new(); + query_params.insert("actor".to_string(), params.actor.clone()); + if let Some(limit) = params.limit { + query_params.insert("limit".to_string(), limit.to_string()); + } + if let Some(cursor) = 
&params.cursor { + query_params.insert("cursor".to_string(), cursor.clone()); + } + + let proxy_result = + match proxy_to_appview("app.bsky.feed.getActorLikes", &query_params, auth_header).await { + Ok(r) => r, + Err(e) => return e, + }; + + if !proxy_result.status.is_success() { + return (proxy_result.status, proxy_result.body).into_response(); + } + + let rev = match extract_repo_rev(&proxy_result.headers) { + Some(r) => r, + None => return (proxy_result.status, proxy_result.body).into_response(), + }; + + let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) { + Ok(f) => f, + Err(e) => { + warn!("Failed to parse actor likes response: {:?}", e); + return (proxy_result.status, proxy_result.body).into_response(); + } + }; + + let requester_did = match auth_did { + Some(d) => d, + None => return (StatusCode::OK, Json(feed_output)).into_response(), + }; + + let actor_did = if params.actor.starts_with("did:") { + params.actor.clone() + } else { + match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", params.actor) + .fetch_optional(&state.db) + .await + { + Ok(Some(did)) => did, + Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(), + Err(e) => { + warn!("Database error resolving actor handle: {:?}", e); + return (proxy_result.status, proxy_result.body).into_response(); + } + } + }; + + if actor_did != requester_did { + return (StatusCode::OK, Json(feed_output)).into_response(); + } + + let local_records = match get_records_since_rev(&state, &requester_did, &rev).await { + Ok(r) => r, + Err(e) => { + warn!("Failed to get local records: {}", e); + return (proxy_result.status, proxy_result.body).into_response(); + } + }; + + if local_records.likes.is_empty() { + return (StatusCode::OK, Json(feed_output)).into_response(); + } + + insert_likes_into_feed(&mut feed_output.feed, &local_records.likes); + + let lag = get_local_lag(&local_records); + format_munged_response(feed_output, lag) +} diff --git a/src/api/feed/author_feed.rs b/src/api/feed/author_feed.rs new file mode 100644 index 0000000..829c92a --- /dev/null +++ b/src/api/feed/author_feed.rs @@ -0,0 +1,169 @@ +use crate::api::read_after_write::{ + extract_repo_rev, format_local_post, format_munged_response, get_local_lag, + get_records_since_rev, insert_posts_into_feed, proxy_to_appview, FeedOutput, FeedViewPost, + ProfileRecord, RecordDescript, +}; +use crate::state::AppState; +use axum::{ + extract::{Query, State}, + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde::Deserialize; +use std::collections::HashMap; +use tracing::warn; + +#[derive(Deserialize)] +pub struct GetAuthorFeedParams { + pub actor: String, + pub limit: Option<u32>, + pub cursor: Option<String>, + pub filter: Option<String>, + #[serde(rename = "includePins")] + pub include_pins: Option<bool>, +} + +fn update_author_profile_in_feed( + feed: &mut [FeedViewPost], + author_did: &str, + local_profile: &RecordDescript<ProfileRecord>, +) { + for item in feed.iter_mut() { + if item.post.author.did == author_did { + if let Some(ref display_name) = local_profile.record.display_name { + item.post.author.display_name = Some(display_name.clone()); + } + } + } +} + +pub async fn get_author_feed( + State(state): State<AppState>, + headers: axum::http::HeaderMap, + Query(params): Query<GetAuthorFeedParams>, +) -> Response { + let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); + + let auth_did = if let Some(h) = auth_header { + if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { + match 
crate::auth::validate_bearer_token(&state.db, &token).await { + Ok(user) => Some(user.did), + Err(_) => None, + } + } else { + None + } + } else { + None + }; + + let mut query_params = HashMap::new(); + query_params.insert("actor".to_string(), params.actor.clone()); + if let Some(limit) = params.limit { + query_params.insert("limit".to_string(), limit.to_string()); + } + if let Some(cursor) = ¶ms.cursor { + query_params.insert("cursor".to_string(), cursor.clone()); + } + if let Some(filter) = ¶ms.filter { + query_params.insert("filter".to_string(), filter.clone()); + } + if let Some(include_pins) = params.include_pins { + query_params.insert("includePins".to_string(), include_pins.to_string()); + } + + let proxy_result = + match proxy_to_appview("app.bsky.feed.getAuthorFeed", &query_params, auth_header).await { + Ok(r) => r, + Err(e) => return e, + }; + + if !proxy_result.status.is_success() { + return (proxy_result.status, proxy_result.body).into_response(); + } + + let rev = match extract_repo_rev(&proxy_result.headers) { + Some(r) => r, + None => return (proxy_result.status, proxy_result.body).into_response(), + }; + + let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) { + Ok(f) => f, + Err(e) => { + warn!("Failed to parse author feed response: {:?}", e); + return (proxy_result.status, proxy_result.body).into_response(); + } + }; + + let requester_did = match auth_did { + Some(d) => d, + None => return (StatusCode::OK, Json(feed_output)).into_response(), + }; + + let actor_did = if params.actor.starts_with("did:") { + params.actor.clone() + } else { + match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", params.actor) + .fetch_optional(&state.db) + .await + { + Ok(Some(did)) => did, + Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(), + Err(e) => { + warn!("Database error resolving actor handle: {:?}", e); + return (proxy_result.status, proxy_result.body).into_response(); + } + } + }; + + if actor_did != requester_did { + return (StatusCode::OK, Json(feed_output)).into_response(); + } + + let local_records = match get_records_since_rev(&state, &requester_did, &rev).await { + Ok(r) => r, + Err(e) => { + warn!("Failed to get local records: {}", e); + return (proxy_result.status, proxy_result.body).into_response(); + } + }; + + if local_records.count == 0 { + return (StatusCode::OK, Json(feed_output)).into_response(); + } + + let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) + .fetch_optional(&state.db) + .await + { + Ok(Some(h)) => h, + Ok(None) => requester_did.clone(), + Err(e) => { + warn!("Database error fetching handle: {:?}", e); + requester_did.clone() + } + }; + + if let Some(ref local_profile) = local_records.profile { + update_author_profile_in_feed(&mut feed_output.feed, &requester_did, local_profile); + } + + let local_posts: Vec<_> = local_records + .posts + .iter() + .map(|p| { + format_local_post( + p, + &requester_did, + &handle, + local_records.profile.as_ref(), + ) + }) + .collect(); + + insert_posts_into_feed(&mut feed_output.feed, local_posts); + + let lag = get_local_lag(&local_records); + format_munged_response(feed_output, lag) +} diff --git a/src/api/feed/custom_feed.rs b/src/api/feed/custom_feed.rs new file mode 100644 index 0000000..e1ad855 --- /dev/null +++ b/src/api/feed/custom_feed.rs @@ -0,0 +1,131 @@ +use crate::api::proxy_client::{ + is_ssrf_safe, proxy_client, validate_at_uri, validate_limit, MAX_RESPONSE_SIZE, +}; +use crate::api::ApiError; 
+use crate::state::AppState; +use axum::{ + extract::{Query, State}, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use serde::Deserialize; +use std::collections::HashMap; +use tracing::{error, info}; + +#[derive(Deserialize)] +pub struct GetFeedParams { + pub feed: String, + pub limit: Option<u32>, + pub cursor: Option<String>, +} + +pub async fn get_feed( + State(state): State<AppState>, + headers: axum::http::HeaderMap, + Query(params): Query<GetFeedParams>, +) -> Response { + let token = match crate::auth::extract_bearer_token_from_header( + headers.get("Authorization").and_then(|h| h.to_str().ok()), + ) { + Some(t) => t, + None => return ApiError::AuthenticationRequired.into_response(), + }; + + if let Err(e) = crate::auth::validate_bearer_token(&state.db, &token).await { + return ApiError::from(e).into_response(); + }; + + if let Err(e) = validate_at_uri(&params.feed) { + return ApiError::InvalidRequest(format!("Invalid feed URI: {}", e)).into_response(); + } + + let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); + + let appview_url = match std::env::var("APPVIEW_URL") { + Ok(url) => url, + Err(_) => { + return ApiError::UpstreamUnavailable("No upstream AppView configured".to_string()) + .into_response(); + } + }; + + if let Err(e) = is_ssrf_safe(&appview_url) { + error!("SSRF check failed for appview URL: {}", e); + return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)) + .into_response(); + } + + let limit = validate_limit(params.limit, 50, 100); + let mut query_params = HashMap::new(); + query_params.insert("feed".to_string(), params.feed.clone()); + query_params.insert("limit".to_string(), limit.to_string()); + if let Some(cursor) = &params.cursor { + query_params.insert("cursor".to_string(), cursor.clone()); + } + + let target_url = format!("{}/xrpc/app.bsky.feed.getFeed", appview_url); + info!(target = %target_url, feed = %params.feed, "Proxying getFeed request"); + + let client = proxy_client(); + let mut request_builder = client.get(&target_url).query(&query_params); + + if let Some(auth) = auth_header { + request_builder = request_builder.header("Authorization", auth); + } + + match request_builder.send().await { + Ok(resp) => { + let status = + StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); + + let content_length = resp.content_length().unwrap_or(0); + if content_length > MAX_RESPONSE_SIZE { + error!( + content_length, + max = MAX_RESPONSE_SIZE, + "getFeed response too large" + ); + return ApiError::UpstreamFailure.into_response(); + } + + let resp_headers = resp.headers().clone(); + let body = match resp.bytes().await { + Ok(b) => { + if b.len() as u64 > MAX_RESPONSE_SIZE { + error!(len = b.len(), "getFeed response body exceeded limit"); + return ApiError::UpstreamFailure.into_response(); + } + b + } + Err(e) => { + error!(error = ?e, "Error reading getFeed response"); + return ApiError::UpstreamFailure.into_response(); + } + }; + + let mut response_builder = axum::response::Response::builder().status(status); + if let Some(ct) = resp_headers.get("content-type") { + response_builder = response_builder.header("content-type", ct); + } + + match response_builder.body(axum::body::Body::from(body)) { + Ok(r) => r, + Err(e) => { + error!(error = ?e, "Error building getFeed response"); + ApiError::UpstreamFailure.into_response() + } + } + } + Err(e) => { + error!(error = ?e, "Error proxying getFeed"); + if e.is_timeout() { + ApiError::UpstreamTimeout.into_response() + } else if e.is_connect() { + 
ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string()) + .into_response() + } else { + ApiError::UpstreamFailure.into_response() + } + } + } +} diff --git a/src/api/feed/mod.rs b/src/api/feed/mod.rs index abd11c2..ea4c0a2 100644 --- a/src/api/feed/mod.rs +++ b/src/api/feed/mod.rs @@ -1,3 +1,11 @@ +mod actor_likes; +mod author_feed; +mod custom_feed; +mod post_thread; mod timeline; +pub use actor_likes::get_actor_likes; +pub use author_feed::get_author_feed; +pub use custom_feed::get_feed; +pub use post_thread::get_post_thread; pub use timeline::get_timeline; diff --git a/src/api/feed/post_thread.rs b/src/api/feed/post_thread.rs new file mode 100644 index 0000000..1a938f5 --- /dev/null +++ b/src/api/feed/post_thread.rs @@ -0,0 +1,322 @@ +use crate::api::read_after_write::{ + extract_repo_rev, format_local_post, format_munged_response, get_local_lag, + get_records_since_rev, proxy_to_appview, PostRecord, PostView, RecordDescript, +}; +use crate::state::AppState; +use axum::{ + extract::{Query, State}, + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::collections::HashMap; +use tracing::warn; + +#[derive(Deserialize)] +pub struct GetPostThreadParams { + pub uri: String, + pub depth: Option<u32>, + #[serde(rename = "parentHeight")] + pub parent_height: Option<u32>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ThreadViewPost { + #[serde(rename = "$type")] + pub thread_type: Option<String>, + pub post: PostView, + #[serde(skip_serializing_if = "Option::is_none")] + pub parent: Option<Box<ThreadNode>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub replies: Option<Vec<ThreadNode>>, + #[serde(flatten)] + pub extra: HashMap<String, Value>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ThreadNode { + Post(ThreadViewPost), + NotFound(ThreadNotFound), + Blocked(ThreadBlocked), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ThreadNotFound { + #[serde(rename = "$type")] + pub thread_type: String, + pub uri: String, + pub not_found: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ThreadBlocked { + #[serde(rename = "$type")] + pub thread_type: String, + pub uri: String, + pub blocked: bool, + pub author: Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PostThreadOutput { + pub thread: ThreadNode, + #[serde(skip_serializing_if = "Option::is_none")] + pub threadgate: Option<Value>, +} + +const MAX_THREAD_DEPTH: usize = 10; + +fn add_replies_to_thread( + thread: &mut ThreadViewPost, + local_posts: &[RecordDescript<PostRecord>], + author_did: &str, + author_handle: &str, + depth: usize, +) { + if depth >= MAX_THREAD_DEPTH { + return; + } + + let thread_uri = &thread.post.uri; + + let replies: Vec<_> = local_posts + .iter() + .filter(|p| { + p.record + .reply + .as_ref() + .and_then(|r| r.get("parent")) + .and_then(|parent| parent.get("uri")) + .and_then(|u| u.as_str()) + == Some(thread_uri) + }) + .map(|p| { + let post_view = format_local_post(p, author_did, author_handle, None); + ThreadNode::Post(ThreadViewPost { + thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()), + post: post_view, + parent: None, + replies: None, + extra: HashMap::new(), + }) + }) + .collect(); + + if !replies.is_empty() { + match &mut thread.replies { + Some(existing) => existing.extend(replies), + None => thread.replies = Some(replies), + } + 
} + + if let Some(ref mut existing_replies) = thread.replies { + for reply in existing_replies.iter_mut() { + if let ThreadNode::Post(reply_thread) = reply { + add_replies_to_thread(reply_thread, local_posts, author_did, author_handle, depth + 1); + } + } + } +} + +pub async fn get_post_thread( + State(state): State<AppState>, + headers: axum::http::HeaderMap, + Query(params): Query<GetPostThreadParams>, +) -> Response { + let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); + + let auth_did = if let Some(h) = auth_header { + if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { + match crate::auth::validate_bearer_token(&state.db, &token).await { + Ok(user) => Some(user.did), + Err(_) => None, + } + } else { + None + } + } else { + None + }; + + let mut query_params = HashMap::new(); + query_params.insert("uri".to_string(), params.uri.clone()); + if let Some(depth) = params.depth { + query_params.insert("depth".to_string(), depth.to_string()); + } + if let Some(parent_height) = params.parent_height { + query_params.insert("parentHeight".to_string(), parent_height.to_string()); + } + + let proxy_result = + match proxy_to_appview("app.bsky.feed.getPostThread", &query_params, auth_header).await { + Ok(r) => r, + Err(e) => return e, + }; + + if proxy_result.status == StatusCode::NOT_FOUND { + return handle_not_found(&state, &params.uri, auth_did, &proxy_result.headers).await; + } + + if !proxy_result.status.is_success() { + return (proxy_result.status, proxy_result.body).into_response(); + } + + let rev = match extract_repo_rev(&proxy_result.headers) { + Some(r) => r, + None => return (proxy_result.status, proxy_result.body).into_response(), + }; + + let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) { + Ok(t) => t, + Err(e) => { + warn!("Failed to parse post thread response: {:?}", e); + return (proxy_result.status, proxy_result.body).into_response(); + } + }; + + let requester_did = match auth_did { + Some(d) => d, + None => return (StatusCode::OK, Json(thread_output)).into_response(), + }; + + let local_records = match get_records_since_rev(&state, &requester_did, &rev).await { + Ok(r) => r, + Err(e) => { + warn!("Failed to get local records: {}", e); + return (proxy_result.status, proxy_result.body).into_response(); + } + }; + + if local_records.posts.is_empty() { + return (StatusCode::OK, Json(thread_output)).into_response(); + } + + let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) + .fetch_optional(&state.db) + .await + { + Ok(Some(h)) => h, + Ok(None) => requester_did.clone(), + Err(e) => { + warn!("Database error fetching handle: {:?}", e); + requester_did.clone() + } + }; + + if let ThreadNode::Post(ref mut thread_post) = thread_output.thread { + add_replies_to_thread(thread_post, &local_records.posts, &requester_did, &handle, 0); + } + + let lag = get_local_lag(&local_records); + format_munged_response(thread_output, lag) +} + +async fn handle_not_found( + state: &AppState, + uri: &str, + auth_did: Option<String>, + headers: &axum::http::HeaderMap, +) -> Response { + let rev = match extract_repo_rev(headers) { + Some(r) => r, + None => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "NotFound", "message": "Post not found"})), + ) + .into_response() + } + }; + + let requester_did = match auth_did { + Some(d) => d, + None => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "NotFound", "message": "Post not found"})), + ) + .into_response() + } + }; + + let uri_parts: 
Vec<&str> = uri.trim_start_matches("at://").split('/').collect(); + if uri_parts.len() != 3 { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "NotFound", "message": "Post not found"})), + ) + .into_response(); + } + + let post_did = uri_parts[0]; + if post_did != requester_did { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "NotFound", "message": "Post not found"})), + ) + .into_response(); + } + + let local_records = match get_records_since_rev(state, &requester_did, &rev).await { + Ok(r) => r, + Err(_) => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "NotFound", "message": "Post not found"})), + ) + .into_response() + } + }; + + let local_post = local_records.posts.iter().find(|p| p.uri == uri); + + let local_post = match local_post { + Some(p) => p, + None => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "NotFound", "message": "Post not found"})), + ) + .into_response() + } + }; + + let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) + .fetch_optional(&state.db) + .await + { + Ok(Some(h)) => h, + Ok(None) => requester_did.clone(), + Err(e) => { + warn!("Database error fetching handle: {:?}", e); + requester_did.clone() + } + }; + + let post_view = format_local_post( + local_post, + &requester_did, + &handle, + local_records.profile.as_ref(), + ); + + let thread = PostThreadOutput { + thread: ThreadNode::Post(ThreadViewPost { + thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()), + post: post_view, + parent: None, + replies: None, + extra: HashMap::new(), + }), + threadgate: None, + }; + + let lag = get_local_lag(&local_records); + format_munged_response(thread, lag) +} diff --git a/src/api/feed/timeline.rs b/src/api/feed/timeline.rs index 078aed3..e8ed2a5 100644 --- a/src/api/feed/timeline.rs +++ b/src/api/feed/timeline.rs @@ -1,51 +1,35 @@ -// Yes, I know, this endpoint is an appview one, not for PDS. Who cares!! -// Yes, this only gets posts that our DB/instance knows about. Who cares!!! 
- +use crate::api::read_after_write::{ + extract_repo_rev, format_local_post, format_munged_response, get_local_lag, + get_records_since_rev, insert_posts_into_feed, proxy_to_appview, FeedOutput, FeedViewPost, + PostView, +}; use crate::state::AppState; use axum::{ - Json, - extract::State, + extract::{Query, State}, http::StatusCode, response::{IntoResponse, Response}, + Json, }; use jacquard_repo::storage::BlockStore; -use serde::Serialize; -use serde_json::{Value, json}; -use tracing::error; +use serde::Deserialize; +use serde_json::{json, Value}; +use std::collections::HashMap; +use tracing::warn; -#[derive(Serialize)] -pub struct TimelineOutput { - pub feed: Vec<FeedViewPost>, +#[derive(Deserialize)] +pub struct GetTimelineParams { + pub algorithm: Option<String>, + pub limit: Option<u32>, pub cursor: Option<String>, } -#[derive(Serialize)] -pub struct FeedViewPost { - pub post: PostView, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -pub struct PostView { - pub uri: String, - pub cid: String, - pub author: AuthorView, - pub record: Value, - pub indexed_at: String, -} - -#[derive(Serialize)] -pub struct AuthorView { - pub did: String, - pub handle: String, -} - pub async fn get_timeline( State(state): State<AppState>, headers: axum::http::HeaderMap, + Query(params): Query<GetTimelineParams>, ) -> Response { let token = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()) + headers.get("Authorization").and_then(|h| h.to_str().ok()), ) { Some(t) => t, None => { @@ -68,32 +52,131 @@ pub async fn get_timeline( } }; - let user_query = sqlx::query!("SELECT id FROM users WHERE did = $1", auth_user.did) - .fetch_optional(&state.db) - .await; + match std::env::var("APPVIEW_URL") { + Ok(url) if !url.starts_with("http://127.0.0.1") => { + return get_timeline_with_appview(&state, &headers, &params, &auth_user.did).await; + } + _ => {} + } - let user_id = match user_query { - Ok(Some(row)) => row.id, - _ => { + get_timeline_local_only(&state, &auth_user.did).await +} + +async fn get_timeline_with_appview( + state: &AppState, + headers: &axum::http::HeaderMap, + params: &GetTimelineParams, + auth_did: &str, +) -> Response { + let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); + + let mut query_params = HashMap::new(); + if let Some(algo) = &params.algorithm { + query_params.insert("algorithm".to_string(), algo.clone()); + } + if let Some(limit) = params.limit { + query_params.insert("limit".to_string(), limit.to_string()); + } + if let Some(cursor) = &params.cursor { + query_params.insert("cursor".to_string(), cursor.clone()); + } + + let proxy_result = + match proxy_to_appview("app.bsky.feed.getTimeline", &query_params, auth_header).await { + Ok(r) => r, + Err(e) => return e, + }; + + if !proxy_result.status.is_success() { + return (proxy_result.status, proxy_result.body).into_response(); + } + + let rev = extract_repo_rev(&proxy_result.headers); + if rev.is_none() { + return (proxy_result.status, proxy_result.body).into_response(); + } + let rev = rev.unwrap(); + + let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) { + Ok(f) => f, + Err(e) => { + warn!("Failed to parse timeline response: {:?}", e); + return (proxy_result.status, proxy_result.body).into_response(); + } + }; + + let local_records = match get_records_since_rev(state, auth_did, &rev).await { + Ok(r) => r, + Err(e) => { + warn!("Failed to get local records: {}", e); + return (proxy_result.status, proxy_result.body).into_response(); + } + }; + + if local_records.count == 0 { 
+ return (proxy_result.status, proxy_result.body).into_response(); + } + + let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", auth_did) + .fetch_optional(&state.db) + .await + { + Ok(Some(h)) => h, + Ok(None) => auth_did.to_string(), + Err(e) => { + warn!("Database error fetching handle: {:?}", e); + auth_did.to_string() + } + }; + + let local_posts: Vec<_> = local_records + .posts + .iter() + .map(|p| format_local_post(p, auth_did, &handle, local_records.profile.as_ref())) + .collect(); + + insert_posts_into_feed(&mut feed_output.feed, local_posts); + + let lag = get_local_lag(&local_records); + format_munged_response(feed_output, lag) +} + +async fn get_timeline_local_only(state: &AppState, auth_did: &str) -> Response { + let user_id: uuid::Uuid = match sqlx::query_scalar!( + "SELECT id FROM users WHERE did = $1", + auth_did + ) + .fetch_optional(&state.db) + .await + { + Ok(Some(id)) => id, + Ok(None) => { return ( StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "User not found"})), ) .into_response(); } + Err(e) => { + warn!("Database error fetching user: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Database error"})), + ) + .into_response(); + } }; let follows_query = sqlx::query!( - "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow'", + "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000", user_id ) - .fetch_all(&state.db) - .await; + .fetch_all(&state.db) + .await; let follow_cids: Vec<String> = match follows_query { Ok(rows) => rows.iter().map(|r| r.record_cid.clone()).collect(), - Err(e) => { - error!("Failed to get follows: {:?}", e); + Err(_) => { return ( StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"})), @@ -127,7 +210,7 @@ pub async fn get_timeline( if followed_dids.is_empty() { return ( StatusCode::OK, - Json(TimelineOutput { + Json(FeedOutput { feed: vec![], cursor: None, }), @@ -150,8 +233,7 @@ pub async fn get_timeline( let posts = match posts_result { Ok(rows) => rows, - Err(e) => { - error!("Failed to get posts: {:?}", e); + Err(_) => { return ( StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"})), @@ -190,15 +272,28 @@ pub async fn get_timeline( post: PostView { uri, cid: record_cid, - author: AuthorView { + author: crate::api::read_after_write::AuthorView { did: author_did, handle: author_handle, + display_name: None, + avatar: None, + extra: HashMap::new(), }, record, indexed_at: created_at.to_rfc3339(), + embed: None, + reply_count: 0, + repost_count: 0, + like_count: 0, + quote_count: 0, + extra: HashMap::new(), }, + reply: None, + reason: None, + feed_context: None, + extra: HashMap::new(), }); } - (StatusCode::OK, Json(TimelineOutput { feed, cursor: None })).into_response() + (StatusCode::OK, Json(FeedOutput { feed, cursor: None })).into_response() } diff --git a/src/api/mod.rs b/src/api/mod.rs index ebd7535..61bba8e 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -4,9 +4,13 @@ pub mod error; pub mod feed; pub mod identity; pub mod moderation; +pub mod notification; pub mod proxy; +pub mod proxy_client; +pub mod read_after_write; pub mod repo; pub mod server; pub mod validation; pub use error::ApiError; +pub use proxy_client::{proxy_client, validate_at_uri, validate_did, validate_limit, AtUriParts}; diff --git a/src/api/notification/mod.rs b/src/api/notification/mod.rs new file mode 100644 index 
0000000..0d60b0a --- /dev/null +++ b/src/api/notification/mod.rs @@ -0,0 +1,3 @@ +mod register_push; + +pub use register_push::register_push; diff --git a/src/api/notification/register_push.rs b/src/api/notification/register_push.rs new file mode 100644 index 0000000..7da7444 --- /dev/null +++ b/src/api/notification/register_push.rs @@ -0,0 +1,166 @@ +use crate::api::proxy_client::{is_ssrf_safe, proxy_client, validate_did}; +use crate::api::ApiError; +use crate::state::AppState; +use axum::{ + extract::State, + http::{HeaderMap, StatusCode}, + response::{IntoResponse, Response}, + Json, +}; +use serde::Deserialize; +use serde_json::json; +use tracing::{error, info}; + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RegisterPushInput { + pub service_did: String, + pub token: String, + pub platform: String, + pub app_id: String, +} + +const VALID_PLATFORMS: &[&str] = &["ios", "android", "web"]; + +pub async fn register_push( + State(state): State<AppState>, + headers: HeaderMap, + Json(input): Json<RegisterPushInput>, +) -> Response { + let token = match crate::auth::extract_bearer_token_from_header( + headers.get("Authorization").and_then(|h| h.to_str().ok()), + ) { + Some(t) => t, + None => return ApiError::AuthenticationRequired.into_response(), + }; + + let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { + Ok(user) => user, + Err(e) => return ApiError::from(e).into_response(), + }; + + if let Err(e) = validate_did(&input.service_did) { + return ApiError::InvalidRequest(format!("Invalid serviceDid: {}", e)).into_response(); + } + + if input.token.is_empty() || input.token.len() > 4096 { + return ApiError::InvalidRequest("Invalid push token".to_string()).into_response(); + } + + if !VALID_PLATFORMS.contains(&input.platform.as_str()) { + return ApiError::InvalidRequest(format!( + "Invalid platform. 
Must be one of: {}", + VALID_PLATFORMS.join(", ") + )) + .into_response(); + } + + if input.app_id.is_empty() || input.app_id.len() > 256 { + return ApiError::InvalidRequest("Invalid appId".to_string()).into_response(); + } + + let appview_url = match std::env::var("APPVIEW_URL") { + Ok(url) => url, + Err(_) => { + return ApiError::UpstreamUnavailable("No upstream AppView configured".to_string()) + .into_response(); + } + }; + + if let Err(e) = is_ssrf_safe(&appview_url) { + error!("SSRF check failed for appview URL: {}", e); + return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)) + .into_response(); + } + + let key_row = match sqlx::query!( + "SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1", + auth_user.did + ) + .fetch_optional(&state.db) + .await + { + Ok(Some(row)) => row, + Ok(None) => { + error!(did = %auth_user.did, "No signing key found for user"); + return ApiError::InternalError.into_response(); + } + Err(e) => { + error!(error = ?e, "Database error fetching signing key"); + return ApiError::DatabaseError.into_response(); + } + }; + + let decrypted_key = + match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) { + Ok(k) => k, + Err(e) => { + error!(error = ?e, "Failed to decrypt signing key"); + return ApiError::InternalError.into_response(); + } + }; + + let service_token = match crate::auth::create_service_token( + &auth_user.did, + &input.service_did, + "app.bsky.notification.registerPush", + &decrypted_key, + ) { + Ok(t) => t, + Err(e) => { + error!(error = ?e, "Failed to create service token"); + return ApiError::InternalError.into_response(); + } + }; + + let target_url = format!("{}/xrpc/app.bsky.notification.registerPush", appview_url); + info!( + target = %target_url, + service_did = %input.service_did, + platform = %input.platform, + "Proxying registerPush request" + ); + + let client = proxy_client(); + let request_body = json!({ + "serviceDid": input.service_did, + "token": input.token, + "platform": input.platform, + "appId": input.app_id + }); + + match client + .post(&target_url) + .header("Authorization", format!("Bearer {}", service_token)) + .header("Content-Type", "application/json") + .json(&request_body) + .send() + .await + { + Ok(resp) => { + let status = + StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); + if status.is_success() { + StatusCode::OK.into_response() + } else { + let body = resp.bytes().await.unwrap_or_default(); + error!( + status = %status, + "registerPush upstream error" + ); + ApiError::from_upstream_response(status.as_u16(), &body).into_response() + } + } + Err(e) => { + error!(error = ?e, "Error proxying registerPush"); + if e.is_timeout() { + ApiError::UpstreamTimeout.into_response() + } else if e.is_connect() { + ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string()) + .into_response() + } else { + ApiError::UpstreamFailure.into_response() + } + } + } +} diff --git a/src/api/proxy.rs b/src/api/proxy.rs index 397871f..89e8c2f 100644 --- a/src/api/proxy.rs +++ b/src/api/proxy.rs @@ -46,21 +46,15 @@ pub async fn proxy_handler( if let Some(token) = crate::auth::extract_bearer_token_from_header( headers.get("Authorization").and_then(|h| h.to_str().ok()) ) { - if let Ok(did) = crate::auth::get_did_from_token(&token) { - let key_row = sqlx::query!("SELECT k.key_bytes, k.encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1", did) - 
.fetch_optional(&state.db) - .await; - - if let Ok(Some(row)) = key_row { - if let Ok(decrypted_key) = crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { - if let Ok(new_token) = - crate::auth::create_service_token(&did, aud, &method, &decrypted_key) + if let Ok(auth_user) = crate::auth::validate_bearer_token(&state.db, &token).await { + if let Some(key_bytes) = auth_user.key_bytes { + if let Ok(new_token) = + crate::auth::create_service_token(&auth_user.did, aud, &method, &key_bytes) + { + if let Ok(val) = + axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) { - if let Ok(val) = - axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) - { - auth_header_val = Some(val); - } + auth_header_val = Some(val); } } } diff --git a/src/api/proxy_client.rs b/src/api/proxy_client.rs new file mode 100644 index 0000000..47ab3a8 --- /dev/null +++ b/src/api/proxy_client.rs @@ -0,0 +1,252 @@ +use reqwest::{Client, ClientBuilder, Url}; +use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; +use std::sync::OnceLock; +use std::time::Duration; +use tracing::warn; + +pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_secs(10); +pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(30); +pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); +pub const MAX_RESPONSE_SIZE: u64 = 10 * 1024 * 1024; + +static PROXY_CLIENT: OnceLock = OnceLock::new(); + +pub fn proxy_client() -> &'static Client { + PROXY_CLIENT.get_or_init(|| { + ClientBuilder::new() + .timeout(DEFAULT_BODY_TIMEOUT) + .connect_timeout(DEFAULT_CONNECT_TIMEOUT) + .pool_max_idle_per_host(10) + .pool_idle_timeout(Duration::from_secs(90)) + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Failed to build HTTP client - this indicates a TLS or system configuration issue") + }) +} + +pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> { + let parsed = Url::parse(url).map_err(|_| SsrfError::InvalidUrl)?; + + let scheme = parsed.scheme(); + if scheme != "https" { + let allow_http = std::env::var("ALLOW_HTTP_PROXY").is_ok() + || url.starts_with("http://127.0.0.1") + || url.starts_with("http://localhost"); + + if !allow_http { + return Err(SsrfError::InsecureProtocol(scheme.to_string())); + } + } + + let host = parsed.host_str().ok_or(SsrfError::NoHost)?; + + if host == "localhost" { + return Ok(()); + } + + if let Ok(ip) = host.parse::() { + if ip.is_loopback() { + return Ok(()); + } + if !is_unicast_ip(&ip) { + return Err(SsrfError::NonUnicastIp(ip.to_string())); + } + return Ok(()); + } + + let port = parsed.port().unwrap_or(if scheme == "https" { 443 } else { 80 }); + let socket_addrs: Vec = match (host, port).to_socket_addrs() { + Ok(addrs) => addrs.collect(), + Err(_) => return Err(SsrfError::DnsResolutionFailed(host.to_string())), + }; + + for addr in &socket_addrs { + if !is_unicast_ip(&addr.ip()) { + warn!( + "DNS resolution for {} returned non-unicast IP: {}", + host, + addr.ip() + ); + return Err(SsrfError::NonUnicastIp(addr.ip().to_string())); + } + } + + Ok(()) +} + +fn is_unicast_ip(ip: &IpAddr) -> bool { + match ip { + IpAddr::V4(v4) => { + !v4.is_loopback() + && !v4.is_broadcast() + && !v4.is_multicast() + && !v4.is_unspecified() + && !v4.is_link_local() + && !is_private_v4(v4) + } + IpAddr::V6(v6) => !v6.is_loopback() && !v6.is_multicast() && !v6.is_unspecified(), + } +} + +fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool { + let octets = ip.octets(); + octets[0] == 10 + || (octets[0] == 172 && (16..=31).contains(&octets[1])) + || (octets[0] == 
192 && octets[1] == 168) + || (octets[0] == 169 && octets[1] == 254) +} + +#[derive(Debug, Clone)] +pub enum SsrfError { + InvalidUrl, + InsecureProtocol(String), + NoHost, + NonUnicastIp(String), + DnsResolutionFailed(String), +} + +impl std::fmt::Display for SsrfError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SsrfError::InvalidUrl => write!(f, "Invalid URL"), + SsrfError::InsecureProtocol(p) => write!(f, "Insecure protocol: {}", p), + SsrfError::NoHost => write!(f, "No host in URL"), + SsrfError::NonUnicastIp(ip) => write!(f, "Non-unicast IP address: {}", ip), + SsrfError::DnsResolutionFailed(host) => write!(f, "DNS resolution failed for: {}", host), + } + } +} + +impl std::error::Error for SsrfError {} + +pub const HEADERS_TO_FORWARD: &[&str] = &[ + "accept-language", + "atproto-accept-labelers", + "x-bsky-topics", +]; + +pub const RESPONSE_HEADERS_TO_FORWARD: &[&str] = &[ + "atproto-repo-rev", + "atproto-content-labelers", + "retry-after", + "content-type", +]; + +pub fn validate_at_uri(uri: &str) -> Result { + if !uri.starts_with("at://") { + return Err("URI must start with at://"); + } + + let path = uri.trim_start_matches("at://"); + let parts: Vec<&str> = path.split('/').collect(); + + if parts.is_empty() { + return Err("URI missing DID"); + } + + let did = parts[0]; + if !did.starts_with("did:") { + return Err("Invalid DID in URI"); + } + + if parts.len() > 1 { + let collection = parts[1]; + if collection.is_empty() || !collection.contains('.') { + return Err("Invalid collection NSID"); + } + } + + Ok(AtUriParts { + did: did.to_string(), + collection: parts.get(1).map(|s| s.to_string()), + rkey: parts.get(2).map(|s| s.to_string()), + }) +} + +#[derive(Debug, Clone)] +pub struct AtUriParts { + pub did: String, + pub collection: Option, + pub rkey: Option, +} + +pub fn validate_limit(limit: Option, default: u32, max: u32) -> u32 { + match limit { + Some(l) if l == 0 => default, + Some(l) if l > max => max, + Some(l) => l, + None => default, + } +} + +pub fn validate_did(did: &str) -> Result<(), &'static str> { + if !did.starts_with("did:") { + return Err("Invalid DID format"); + } + + let parts: Vec<&str> = did.split(':').collect(); + if parts.len() < 3 { + return Err("DID must have at least method and identifier"); + } + + let method = parts[1]; + if method != "plc" && method != "web" { + return Err("Unsupported DID method"); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ssrf_safe_https() { + assert!(is_ssrf_safe("https://api.bsky.app/xrpc/test").is_ok()); + } + + #[test] + fn test_ssrf_blocks_http_by_default() { + let result = is_ssrf_safe("http://external.example.com/xrpc/test"); + assert!(matches!(result, Err(SsrfError::InsecureProtocol(_)) | Err(SsrfError::DnsResolutionFailed(_)))); + } + + #[test] + fn test_ssrf_allows_localhost_http() { + assert!(is_ssrf_safe("http://127.0.0.1:8080/test").is_ok()); + assert!(is_ssrf_safe("http://localhost:8080/test").is_ok()); + } + + #[test] + fn test_validate_at_uri() { + let result = validate_at_uri("at://did:plc:test/app.bsky.feed.post/abc123"); + assert!(result.is_ok()); + let parts = result.unwrap(); + assert_eq!(parts.did, "did:plc:test"); + assert_eq!(parts.collection, Some("app.bsky.feed.post".to_string())); + assert_eq!(parts.rkey, Some("abc123".to_string())); + } + + #[test] + fn test_validate_at_uri_invalid() { + assert!(validate_at_uri("https://example.com").is_err()); + assert!(validate_at_uri("at://notadid/collection/rkey").is_err()); + } + + 
#[test] + fn test_validate_limit() { + assert_eq!(validate_limit(None, 50, 100), 50); + assert_eq!(validate_limit(Some(0), 50, 100), 50); + assert_eq!(validate_limit(Some(200), 50, 100), 100); + assert_eq!(validate_limit(Some(75), 50, 100), 75); + } + + #[test] + fn test_validate_did() { + assert!(validate_did("did:plc:abc123").is_ok()); + assert!(validate_did("did:web:example.com").is_ok()); + assert!(validate_did("notadid").is_err()); + assert!(validate_did("did:unknown:test").is_err()); + } +} diff --git a/src/api/read_after_write.rs b/src/api/read_after_write.rs new file mode 100644 index 0000000..939d115 --- /dev/null +++ b/src/api/read_after_write.rs @@ -0,0 +1,433 @@ +use crate::api::proxy_client::{ + is_ssrf_safe, proxy_client, MAX_RESPONSE_SIZE, RESPONSE_HEADERS_TO_FORWARD, +}; +use crate::api::ApiError; +use crate::state::AppState; +use axum::{ + http::{HeaderMap, HeaderValue, StatusCode}, + response::{IntoResponse, Response}, + Json, +}; +use chrono::{DateTime, Utc}; +use jacquard_repo::storage::BlockStore; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; +use tracing::{error, info, warn}; +use uuid::Uuid; + +pub const REPO_REV_HEADER: &str = "atproto-repo-rev"; +pub const UPSTREAM_LAG_HEADER: &str = "atproto-upstream-lag"; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PostRecord { + #[serde(rename = "$type")] + pub record_type: Option, + pub text: String, + pub created_at: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub reply: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub embed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub langs: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub labels: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tags: Option>, + #[serde(flatten)] + pub extra: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ProfileRecord { + #[serde(rename = "$type")] + pub record_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub display_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub avatar: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub banner: Option, + #[serde(flatten)] + pub extra: HashMap, +} + +#[derive(Debug, Clone)] +pub struct RecordDescript { + pub uri: String, + pub cid: String, + pub indexed_at: DateTime, + pub record: T, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LikeRecord { + #[serde(rename = "$type")] + pub record_type: Option, + pub subject: LikeSubject, + pub created_at: String, + #[serde(flatten)] + pub extra: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LikeSubject { + pub uri: String, + pub cid: String, +} + +#[derive(Debug, Default)] +pub struct LocalRecords { + pub count: usize, + pub profile: Option>, + pub posts: Vec>, + pub likes: Vec>, +} + +pub async fn get_records_since_rev( + state: &AppState, + did: &str, + rev: &str, +) -> Result { + let mut result = LocalRecords::default(); + + let user_id: Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) + .fetch_optional(&state.db) + .await + .map_err(|e| format!("DB error: {}", e))? 
+ .ok_or_else(|| "User not found".to_string())?; + + let rows = sqlx::query!( + r#" + SELECT record_cid, collection, rkey, created_at, repo_rev + FROM records + WHERE repo_id = $1 AND repo_rev > $2 + ORDER BY repo_rev ASC + LIMIT 10 + "#, + user_id, + rev + ) + .fetch_all(&state.db) + .await + .map_err(|e| format!("DB error fetching records: {}", e))?; + + if rows.is_empty() { + return Ok(result); + } + + let sanity_check = sqlx::query_scalar!( + "SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1", + user_id, + rev + ) + .fetch_optional(&state.db) + .await + .map_err(|e| format!("DB error sanity check: {}", e))?; + + if sanity_check.is_none() { + warn!("Sanity check failed: no records found before rev {}", rev); + return Ok(result); + } + + for row in rows { + result.count += 1; + + let cid: cid::Cid = match row.record_cid.parse() { + Ok(c) => c, + Err(_) => continue, + }; + + let block_bytes = match state.block_store.get(&cid).await { + Ok(Some(b)) => b, + _ => continue, + }; + + let uri = format!("at://{}/{}/{}", did, row.collection, row.rkey); + let indexed_at = row.created_at; + + if row.collection == "app.bsky.actor.profile" && row.rkey == "self" { + if let Ok(record) = serde_ipld_dagcbor::from_slice::(&block_bytes) { + result.profile = Some(RecordDescript { + uri, + cid: row.record_cid, + indexed_at, + record, + }); + } + } else if row.collection == "app.bsky.feed.post" { + if let Ok(record) = serde_ipld_dagcbor::from_slice::(&block_bytes) { + result.posts.push(RecordDescript { + uri, + cid: row.record_cid, + indexed_at, + record, + }); + } + } else if row.collection == "app.bsky.feed.like" { + if let Ok(record) = serde_ipld_dagcbor::from_slice::(&block_bytes) { + result.likes.push(RecordDescript { + uri, + cid: row.record_cid, + indexed_at, + record, + }); + } + } + } + + Ok(result) +} + +pub fn get_local_lag(local: &LocalRecords) -> Option { + let mut oldest: Option> = local.profile.as_ref().map(|p| p.indexed_at); + + for post in &local.posts { + match oldest { + None => oldest = Some(post.indexed_at), + Some(o) if post.indexed_at < o => oldest = Some(post.indexed_at), + _ => {} + } + } + + for like in &local.likes { + match oldest { + None => oldest = Some(like.indexed_at), + Some(o) if like.indexed_at < o => oldest = Some(like.indexed_at), + _ => {} + } + } + + oldest.map(|o| (Utc::now() - o).num_milliseconds()) +} + +pub fn extract_repo_rev(headers: &HeaderMap) -> Option { + headers + .get(REPO_REV_HEADER) + .and_then(|h| h.to_str().ok()) + .map(|s| s.to_string()) +} + +#[derive(Debug)] +pub struct ProxyResponse { + pub status: StatusCode, + pub headers: HeaderMap, + pub body: bytes::Bytes, +} + +pub async fn proxy_to_appview( + method: &str, + params: &HashMap, + auth_header: Option<&str>, +) -> Result { + let appview_url = std::env::var("APPVIEW_URL").map_err(|_| { + ApiError::UpstreamUnavailable("No upstream AppView configured".to_string()).into_response() + })?; + + if let Err(e) = is_ssrf_safe(&appview_url) { + error!("SSRF check failed for appview URL: {}", e); + return Err(ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)) + .into_response()); + } + + let target_url = format!("{}/xrpc/{}", appview_url, method); + info!(target = %target_url, "Proxying request to appview"); + + let client = proxy_client(); + let mut request_builder = client.get(&target_url).query(params); + + if let Some(auth) = auth_header { + request_builder = request_builder.header("Authorization", auth); + } + + match request_builder.send().await { + Ok(resp) 
+#[derive(Debug)]
+pub struct ProxyResponse {
+    pub status: StatusCode,
+    pub headers: HeaderMap,
+    pub body: bytes::Bytes,
+}
+
+pub async fn proxy_to_appview(
+    method: &str,
+    params: &HashMap<String, String>,
+    auth_header: Option<&str>,
+) -> Result<ProxyResponse, Response> {
+    let appview_url = std::env::var("APPVIEW_URL").map_err(|_| {
+        ApiError::UpstreamUnavailable("No upstream AppView configured".to_string()).into_response()
+    })?;
+
+    if let Err(e) = is_ssrf_safe(&appview_url) {
+        error!("SSRF check failed for appview URL: {}", e);
+        return Err(ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e))
+            .into_response());
+    }
+
+    let target_url = format!("{}/xrpc/{}", appview_url, method);
+    info!(target = %target_url, "Proxying request to appview");
+
+    let client = proxy_client();
+    let mut request_builder = client.get(&target_url).query(params);
+
+    if let Some(auth) = auth_header {
+        request_builder = request_builder.header("Authorization", auth);
+    }
+
+    match request_builder.send().await {
+        Ok(resp) => {
+            let status =
+                StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
+
+            let headers: HeaderMap = resp
+                .headers()
+                .iter()
+                .filter(|(k, _)| {
+                    RESPONSE_HEADERS_TO_FORWARD
+                        .iter()
+                        .any(|h| k.as_str().eq_ignore_ascii_case(h))
+                })
+                .filter_map(|(k, v)| {
+                    let name = axum::http::HeaderName::try_from(k.as_str()).ok()?;
+                    let value = HeaderValue::from_bytes(v.as_bytes()).ok()?;
+                    Some((name, value))
+                })
+                .collect();
+
+            // A missing Content-Length falls through to the post-read check below.
+            let content_length = resp.content_length().unwrap_or(0);
+            if content_length > MAX_RESPONSE_SIZE {
+                error!(
+                    content_length,
+                    max = MAX_RESPONSE_SIZE,
+                    "Upstream response too large"
+                );
+                return Err(ApiError::UpstreamFailure.into_response());
+            }
+
+            let body = resp.bytes().await.map_err(|e| {
+                error!(error = ?e, "Error reading proxy response body");
+                ApiError::UpstreamFailure.into_response()
+            })?;
+
+            // Content-Length can be absent or wrong, so re-check the actual body.
+            if body.len() as u64 > MAX_RESPONSE_SIZE {
+                error!(
+                    len = body.len(),
+                    max = MAX_RESPONSE_SIZE,
+                    "Upstream response body exceeded size limit"
+                );
+                return Err(ApiError::UpstreamFailure.into_response());
+            }
+
+            Ok(ProxyResponse {
+                status,
+                headers,
+                body,
+            })
+        }
+        Err(e) => {
+            error!(error = ?e, "Error sending proxy request");
+            if e.is_timeout() {
+                Err(ApiError::UpstreamTimeout.into_response())
+            } else if e.is_connect() {
+                Err(ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
+                    .into_response())
+            } else {
+                Err(ApiError::UpstreamFailure.into_response())
+            }
+        }
+    }
+}
+
+/// Wraps locally munged data in a 200 response, attaching the upstream lag
+/// header when the lag is known.
+pub fn format_munged_response<T: Serialize>(data: T, lag: Option<i64>) -> Response {
+    let mut response = (StatusCode::OK, Json(data)).into_response();
+
+    if let Some(lag_ms) = lag {
+        if let Ok(header_val) = HeaderValue::from_str(&lag_ms.to_string()) {
+            response
+                .headers_mut()
+                .insert(UPSTREAM_LAG_HEADER, header_val);
+        }
+    }
+
+    response
+}
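+// Loosely-typed views of app.bsky.feed.defs shapes: unknown fields are kept
+// in `extra` via #[serde(flatten)] so upstream JSON survives a round-trip.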
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AuthorView {
+    pub did: String,
+    pub handle: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub display_name: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub avatar: Option<String>,
+    #[serde(flatten)]
+    pub extra: HashMap<String, Value>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PostView {
+    pub uri: String,
+    pub cid: String,
+    pub author: AuthorView,
+    pub record: Value,
+    pub indexed_at: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub embed: Option<Value>,
+    #[serde(default)]
+    pub reply_count: i64,
+    #[serde(default)]
+    pub repost_count: i64,
+    #[serde(default)]
+    pub like_count: i64,
+    #[serde(default)]
+    pub quote_count: i64,
+    #[serde(flatten)]
+    pub extra: HashMap<String, Value>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct FeedViewPost {
+    pub post: PostView,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reply: Option<Value>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reason: Option<Value>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub feed_context: Option<String>,
+    #[serde(flatten)]
+    pub extra: HashMap<String, Value>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct FeedOutput {
+    pub feed: Vec<FeedViewPost>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cursor: Option<String>,
+}
+
+pub fn format_local_post(
+    descript: &RecordDescript<PostRecord>,
+    author_did: &str,
+    author_handle: &str,
+    profile: Option<&RecordDescript<ProfileRecord>>,
+) -> PostView {
+    let display_name = profile.and_then(|p| p.record.display_name.clone());
+
+    PostView {
+        uri: descript.uri.clone(),
+        cid: descript.cid.clone(),
+        author: AuthorView {
+            did: author_did.to_string(),
+            handle: author_handle.to_string(),
+            display_name,
+            avatar: None,
+            extra: HashMap::new(),
+        },
+        record: serde_json::to_value(&descript.record).unwrap_or(Value::Null),
+        indexed_at: descript.indexed_at.to_rfc3339(),
+        embed: descript.record.embed.clone(),
+        reply_count: 0,
+        repost_count: 0,
+        like_count: 0,
+        quote_count: 0,
+        extra: HashMap::new(),
+    }
+}
+
+pub fn insert_posts_into_feed(feed: &mut Vec<FeedViewPost>, posts: Vec<PostView>) {
+    if posts.is_empty() {
+        return;
+    }
+
+    let new_items: Vec<FeedViewPost> = posts
+        .into_iter()
+        .map(|post| FeedViewPost {
+            post,
+            reply: None,
+            reason: None,
+            feed_context: None,
+            extra: HashMap::new(),
+        })
+        .collect();
+
+    feed.extend(new_items);
+    // RFC 3339 timestamps compare lexicographically, so this sorts newest-first.
+    feed.sort_by(|a, b| b.post.indexed_at.cmp(&a.post.indexed_at));
+}
diff --git a/src/api/repo/record/utils.rs b/src/api/repo/record/utils.rs
index b5fae51..5dd8cb5 100644
--- a/src/api/repo/record/utils.rs
+++ b/src/api/repo/record/utils.rs
@@ -25,7 +25,7 @@ pub async fn commit_and_log(
     current_root_cid: Option<Cid>,
     new_mst_root: Cid,
     ops: Vec<RecordOp>,
-    blocks_cids: &Vec<String>,
+    blocks_cids: &[String],
 ) -> Result {
     let key_row = sqlx::query!(
         "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
@@ -63,16 +63,18 @@ pub async fn commit_and_log(
         .await
         .map_err(|e| format!("DB Error (repos): {}", e))?;
 
+    let rev_str = rev.to_string();
     for op in &ops {
         match op {
             RecordOp::Create { collection, rkey, cid } | RecordOp::Update { collection, rkey, cid } => {
                 sqlx::query!(
-                    "INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4)
-                     ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()",
+                    "INSERT INTO records (repo_id, collection, rkey, record_cid, repo_rev) VALUES ($1, $2, $3, $4, $5)
+                     ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, repo_rev = $5, created_at = NOW()",
                     user_id,
                     collection,
                     rkey,
-                    cid.to_string()
+                    cid.to_string(),
+                    rev_str
                 )
                 .execute(&mut *tx)
                 .await
diff --git a/src/lib.rs b/src/lib.rs
index b1ffdb7..e84a740 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -296,11 +296,30 @@ pub fn app(state: AppState) -> Router {
             "/xrpc/app.bsky.actor.getProfiles",
             get(api::actor::get_profiles),
         )
-        // I know I know, I'm not supposed to implement appview endpoints. Leave me be
         .route(
             "/xrpc/app.bsky.feed.getTimeline",
             get(api::feed::get_timeline),
         )
+        .route(
+            "/xrpc/app.bsky.feed.getAuthorFeed",
+            get(api::feed::get_author_feed),
+        )
+        .route(
+            "/xrpc/app.bsky.feed.getActorLikes",
+            get(api::feed::get_actor_likes),
+        )
+        .route(
+            "/xrpc/app.bsky.feed.getPostThread",
+            get(api::feed::get_post_thread),
+        )
+        .route(
+            "/xrpc/app.bsky.feed.getFeed",
+            get(api::feed::get_feed),
+        )
+        .route(
+            "/xrpc/app.bsky.notification.registerPush",
+            post(api::notification::register_push),
+        )
         .route("/.well-known/did.json", get(api::identity::well_known_did))
         .route("/u/{handle}/did.json", get(api::identity::user_did_doc))
         // OAuth 2.1 endpoints
diff --git a/tests/appview_integration.rs b/tests/appview_integration.rs
new file mode 100644
index 0000000..49527ca
--- /dev/null
+++ b/tests/appview_integration.rs
@@ -0,0 +1,149 @@
+mod common;
+
+use common::{base_url, client, create_account_and_login};
+use reqwest::StatusCode;
+use serde_json::{json, Value};
+
+#[tokio::test]
+async fn test_get_author_feed_returns_appview_data() {
+    let client = client();
+    let base = base_url().await;
+    let (jwt, did) = create_account_and_login(&client).await;
+
+    let res = client
+        .get(format!(
+            "{}/xrpc/app.bsky.feed.getAuthorFeed?actor={}",
+            base, did
+        ))
+        .header("Authorization", format!("Bearer {}", jwt))
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(res.status(), StatusCode::OK);
+
+    let body: Value = res.json().await.unwrap();
+    assert!(body["feed"].is_array(), "Response should have feed array");
+    let feed = body["feed"].as_array().unwrap();
+    assert_eq!(feed.len(), 1, "Feed should have 1 post from appview");
+    assert_eq!(
+        feed[0]["post"]["record"]["text"].as_str(),
+        Some("Author feed post from appview"),
+        "Post text should match appview response"
+    );
+}
+
+#[tokio::test]
+async fn test_get_actor_likes_returns_appview_data() {
+    let client = client();
+    let base = base_url().await;
+    let (jwt, did) = create_account_and_login(&client).await;
+
+    let res = client
+        .get(format!(
+            "{}/xrpc/app.bsky.feed.getActorLikes?actor={}",
+            base, did
+        ))
+        .header("Authorization", format!("Bearer {}", jwt))
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(res.status(), StatusCode::OK);
+
+    let body: Value = res.json().await.unwrap();
+    assert!(body["feed"].is_array(), "Response should have feed array");
+    let feed = body["feed"].as_array().unwrap();
+    assert_eq!(feed.len(), 1, "Feed should have 1 liked post from appview");
+    assert_eq!(
+        feed[0]["post"]["record"]["text"].as_str(),
+        Some("Liked post from appview"),
+        "Post text should match appview response"
+    );
+}
+
+#[tokio::test]
+async fn test_get_post_thread_returns_appview_data() {
+    let client = client();
+    let base = base_url().await;
+    let (jwt, did) = create_account_and_login(&client).await;
+
+    let res = client
+        .get(format!(
+            "{}/xrpc/app.bsky.feed.getPostThread?uri=at://{}/app.bsky.feed.post/test123",
+            base, did
+        ))
+        .header("Authorization", format!("Bearer {}", jwt))
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(res.status(), StatusCode::OK);
+
+    let body: Value = res.json().await.unwrap();
+    assert!(body["thread"].is_object(), "Response should have thread object");
+    assert_eq!(
+        body["thread"]["$type"].as_str(),
+        Some("app.bsky.feed.defs#threadViewPost"),
+        "Thread should be a threadViewPost"
+    );
+    assert_eq!(
+        body["thread"]["post"]["record"]["text"].as_str(),
+        Some("Thread post from appview"),
+        "Post text should match appview response"
+    );
+}
+
+#[tokio::test]
+async fn test_get_feed_returns_appview_data() {
+    let client = client();
+    let base = base_url().await;
+    let (jwt, _did) = create_account_and_login(&client).await;
+
+    let res = client
+        .get(format!(
+            "{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test",
+            base
+        ))
+        .header("Authorization", format!("Bearer {}", jwt))
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(res.status(), StatusCode::OK);
+
+    let body: Value = res.json().await.unwrap();
+    assert!(body["feed"].is_array(), "Response should have feed array");
+    let feed = body["feed"].as_array().unwrap();
+    assert_eq!(feed.len(), 1, "Feed should have 1 post from appview");
+    assert_eq!(
+        feed[0]["post"]["record"]["text"].as_str(),
+        Some("Custom feed post from appview"),
+        "Post text should match appview response"
+    );
+}
+
+#[tokio::test]
+async fn test_register_push_proxies_to_appview() {
+    let client = client();
+    let base = base_url().await;
+    let (jwt, _did) = create_account_and_login(&client).await;
+
+    let res = client
+        .post(format!(
+            "{}/xrpc/app.bsky.notification.registerPush",
+            base
+        ))
+        .header("Authorization", format!("Bearer {}", jwt))
+        .json(&json!({
+            "serviceDid": "did:web:example.com",
+            "token": "test-push-token",
+            "platform": "ios",
+            "appId": "xyz.bsky.app"
+        }))
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(res.status(), StatusCode::OK);
+}
diff --git a/tests/common/mod.rs b/tests/common/mod.rs
index 56e0d02..7d03bc5 100644
--- a/tests/common/mod.rs
+++ b/tests/common/mod.rs
@@ -233,6 +233,122 @@ async fn setup_mock_appview(mock_server: &MockServer) {
         })))
         .mount(mock_server)
         .await;
+
+    Mock::given(method("GET"))
+        .and(path("/xrpc/app.bsky.feed.getTimeline"))
+        .respond_with(
+            ResponseTemplate::new(200)
+                .insert_header("atproto-repo-rev", "0")
+                .set_body_json(json!({
+                    "feed": [],
+                    "cursor": null
+                }))
+        )
+        .mount(mock_server)
+        .await;
+
+    Mock::given(method("GET"))
+        .and(path("/xrpc/app.bsky.feed.getAuthorFeed"))
+        .respond_with(
+            ResponseTemplate::new(200)
+                .insert_header("atproto-repo-rev", "0")
+                .set_body_json(json!({
+                    "feed": [{
+                        "post": {
+                            "uri": "at://did:plc:mock-author/app.bsky.feed.post/from-appview-author",
+                            "cid": "bafyappview123",
+                            "author": {"did": "did:plc:mock-author", "handle": "mock.author"},
+                            "record": {
+                                "$type": "app.bsky.feed.post",
+                                "text": "Author feed post from appview",
+                                "createdAt": "2025-01-01T00:00:00Z"
+                            },
+                            "indexedAt": "2025-01-01T00:00:00Z"
+                        }
+                    }],
+                    "cursor": "author-cursor"
+                })),
+        )
+        .mount(mock_server)
+        .await;
+
+    Mock::given(method("GET"))
+        .and(path("/xrpc/app.bsky.feed.getActorLikes"))
+        .respond_with(
+            ResponseTemplate::new(200)
+                .insert_header("atproto-repo-rev", "0")
+                .set_body_json(json!({
+                    "feed": [{
+                        "post": {
+                            "uri": "at://did:plc:mock-likes/app.bsky.feed.post/liked-post",
+                            "cid": "bafyliked123",
+                            "author": {"did": "did:plc:mock-likes", "handle": "mock.likes"},
+                            "record": {
+                                "$type": "app.bsky.feed.post",
+                                "text": "Liked post from appview",
+                                "createdAt": "2025-01-01T00:00:00Z"
+                            },
+                            "indexedAt": "2025-01-01T00:00:00Z"
+                        }
+                    }],
+                    "cursor": null
+                })),
+        )
+        .mount(mock_server)
+        .await;
+
+    Mock::given(method("GET"))
+        .and(path("/xrpc/app.bsky.feed.getPostThread"))
+        .respond_with(
+            ResponseTemplate::new(200)
+                .insert_header("atproto-repo-rev", "0")
+                .set_body_json(json!({
+                    "thread": {
+                        "$type": "app.bsky.feed.defs#threadViewPost",
+                        "post": {
+                            "uri": "at://did:plc:mock/app.bsky.feed.post/thread-post",
+                            "cid": "bafythread123",
+                            "author": {"did": "did:plc:mock", "handle": "mock.handle"},
+                            "record": {
+                                "$type": "app.bsky.feed.post",
+                                "text": "Thread post from appview",
+                                "createdAt": "2025-01-01T00:00:00Z"
+                            },
+                            "indexedAt": "2025-01-01T00:00:00Z"
+                        },
+                        "replies": []
+                    }
+                })),
+        )
+        .mount(mock_server)
+        .await;
+
+    Mock::given(method("GET"))
+        .and(path("/xrpc/app.bsky.feed.getFeed"))
+        .respond_with(ResponseTemplate::new(200).set_body_json(json!({
+            "feed": [{
+                "post": {
+                    "uri": "at://did:plc:mock-feed/app.bsky.feed.post/custom-feed-post",
+                    "cid": "bafyfeed123",
+                    "author": {"did": "did:plc:mock-feed", "handle": "mock.feed"},
+                    "record": {
+                        "$type": "app.bsky.feed.post",
+                        "text": "Custom feed post from appview",
+                        "createdAt": "2025-01-01T00:00:00Z"
+                    },
+                    "indexedAt": "2025-01-01T00:00:00Z"
+                }
+            }],
+            "cursor": null
+        })))
+        .mount(mock_server)
+        .await;
+
+    Mock::given(method("POST"))
+        .and(path("/xrpc/app.bsky.notification.registerPush"))
+        .respond_with(ResponseTemplate::new(200))
+        .mount(mock_server)
+        .await;
 }
 
 async fn spawn_app(database_url: String) -> String {
diff --git a/tests/feed.rs b/tests/feed.rs
new file mode 100644
index 0000000..33491f2
--- /dev/null
+++ b/tests/feed.rs
@@ -0,0 +1,122 @@
+mod common;
+
+use common::{base_url, client, create_account_and_login};
+use serde_json::json;
+
+#[tokio::test]
+async fn test_get_timeline_requires_auth() {
+    let client = client();
+    let base = base_url().await;
+
+    let res = client
+        .get(format!("{}/xrpc/app.bsky.feed.getTimeline", base))
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(res.status(), 401);
+}
+
+#[tokio::test]
+async fn test_get_author_feed_requires_actor() {
+    let client = client();
+    let base = base_url().await;
+    let (jwt, _did) = create_account_and_login(&client).await;
+
+    let res = client
+        .get(format!("{}/xrpc/app.bsky.feed.getAuthorFeed", base))
+        .header("Authorization", format!("Bearer {}", jwt))
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(res.status(), 400);
+}
+
+#[tokio::test]
+async fn test_get_actor_likes_requires_actor() {
+    let client = client();
+    let base = base_url().await;
+    let (jwt, _did) = create_account_and_login(&client).await;
+
+    let res = client
+        .get(format!("{}/xrpc/app.bsky.feed.getActorLikes", base))
+        .header("Authorization", format!("Bearer {}", jwt))
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(res.status(), 400);
+}
+
+#[tokio::test]
+async fn test_get_post_thread_requires_uri() {
+    let client = client();
+    let base = base_url().await;
+    let (jwt, _did) = create_account_and_login(&client).await;
+
+    let res = client
+        .get(format!("{}/xrpc/app.bsky.feed.getPostThread", base))
+        .header("Authorization", format!("Bearer {}", jwt))
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(res.status(), 400);
+}
+
+#[tokio::test]
+async fn test_get_feed_requires_auth() {
+    let client = client();
+    let base = base_url().await;
+
+    let res = client
+        .get(format!(
+            "{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test",
+            base
+        ))
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(res.status(), 401);
+}
+
+#[tokio::test]
+async fn test_get_feed_requires_feed_param() {
+    let client = client();
+    let base = base_url().await;
+    let (jwt, _did) = create_account_and_login(&client).await;
+
+    let res = client
+        .get(format!("{}/xrpc/app.bsky.feed.getFeed", base))
+        .header("Authorization", format!("Bearer {}", jwt))
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(res.status(), 400);
+}
+
+#[tokio::test]
+async fn test_register_push_requires_auth() {
+    let client = client();
+    let base = base_url().await;
+
+    let res = client
+        .post(format!(
+            "{}/xrpc/app.bsky.notification.registerPush",
+            base
+        ))
+        .json(&json!({
+            "serviceDid": "did:web:example.com",
+            "token": "test-token",
+            "platform": "ios",
+            "appId": "xyz.bsky.app"
+        }))
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(res.status(), 401);
+}
diff --git a/tests/import_verification.rs b/tests/import_verification.rs
index 189ecb4..a8d91ed 100644
--- a/tests/import_verification.rs
+++ b/tests/import_verification.rs
@@ -23,7 +23,7 @@ fn write_varint(buf: &mut Vec<u8>, mut value: u64) {
 async fn test_import_rejects_car_for_different_user() {
     let client = client();
 
-    let (token_a, did_a) = create_account_and_login(&client).await;
+    let (token_a, _did_a) = create_account_and_login(&client).await;
     let (_token_b, did_b) = create_account_and_login(&client).await;
 
     let export_res = client