Half-assed attempt at the local-first AppView endpoints, following the reference implementation

This commit is contained in:
lewis
2025-12-12 00:28:57 +02:00
parent e90488fbad
commit 2ededf32a6
27 changed files with 2355 additions and 94 deletions

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "SELECT k.key_bytes, k.encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1",
"query": "SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1",
"describe": {
"columns": [
{
@@ -24,5 +24,5 @@
true
]
},
"hash": "7bb1388dec372fe749462cd9b604e5802b770aeb110462208988141d31c86c92"
"hash": "36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b"
}

View File

@@ -0,0 +1,23 @@
{
"db_name": "PostgreSQL",
"query": "SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "val",
"type_info": "Int4"
}
],
"parameters": {
"Left": [
"Uuid",
"Text"
]
},
"nullable": [
null
]
},
"hash": "5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4)\n ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()",
"query": "INSERT INTO records (repo_id, collection, rkey, record_cid, repo_rev) VALUES ($1, $2, $3, $4, $5)\n ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, repo_rev = $5, created_at = NOW()",
"describe": {
"columns": [],
"parameters": {
@@ -8,10 +8,11 @@
"Uuid",
"Text",
"Text",
"Text",
"Text"
]
},
"nullable": []
},
"hash": "c61fc3b2fbdf6891269908ef21f13dcabdc3b032e9f767becae34ca176df18b6"
"hash": "8a9e71f04ec779d5c10d79582cc398529e01be01a83898df3524bb35e3d2ed14"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow'",
"query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000",
"describe": {
"columns": [
{
@@ -18,5 +18,5 @@
false
]
},
"hash": "bf60faafb5c79a149ba237a984f78d068b5d691f6762641412a5aa1517605c04"
"hash": "a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc"
}

View File

@@ -0,0 +1,47 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT record_cid, collection, rkey, created_at, repo_rev\n FROM records\n WHERE repo_id = $1 AND repo_rev > $2\n ORDER BY repo_rev ASC\n LIMIT 10\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "record_cid",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "collection",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "rkey",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 4,
"name": "repo_rev",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Uuid",
"Text"
]
},
"nullable": [
false,
false,
false,
false,
true
]
},
"hash": "f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e"
}

14
TODO.md
View File

@@ -168,15 +168,15 @@ These endpoints need to be implemented at the PDS level (not just proxied to app
- [x] Implement `app.bsky.actor.getProfiles` (PDS-level with proxy fallback).
### Feed (`app.bsky.feed`)
These are implemented at PDS level to enable local-first reads:
- [ ] Implement `app.bsky.feed.getTimeline` (PDS-level with proxy).
- [ ] Implement `app.bsky.feed.getAuthorFeed` (PDS-level with proxy).
- [ ] Implement `app.bsky.feed.getActorLikes` (PDS-level with proxy).
- [ ] Implement `app.bsky.feed.getPostThread` (PDS-level with proxy).
- [ ] Implement `app.bsky.feed.getFeed` (PDS-level with proxy).
These are implemented at PDS level to enable local-first reads (read-after-write pattern):
- [x] Implement `app.bsky.feed.getTimeline` (PDS-level with proxy + RAW).
- [x] Implement `app.bsky.feed.getAuthorFeed` (PDS-level with proxy + RAW).
- [x] Implement `app.bsky.feed.getActorLikes` (PDS-level with proxy + RAW).
- [x] Implement `app.bsky.feed.getPostThread` (PDS-level with proxy + RAW + NotFound handling).
- [x] Implement `app.bsky.feed.getFeed` (proxy to feed generator).
### Notification (`app.bsky.notification`)
- [ ] Implement `app.bsky.notification.registerPush` (push notification registration).
- [x] Implement `app.bsky.notification.registerPush` (push notification registration, proxied).
## Deprecated Sync Endpoints (for compatibility)
- [ ] Implement `com.atproto.sync.getCheckout` (deprecated, still needed for compatibility).

View File

@@ -0,0 +1,2 @@
ALTER TABLE records ADD COLUMN repo_rev TEXT;
CREATE INDEX idx_records_repo_rev ON records(repo_rev);

View File

@@ -125,10 +125,18 @@ pub async fn get_profile(
) -> Response {
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
let auth_did = auth_header.and_then(|h| {
let token = crate::auth::extract_bearer_token_from_header(Some(h))?;
crate::auth::get_did_from_token(&token).ok()
});
let auth_did = if let Some(h) = auth_header {
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
match crate::auth::validate_bearer_token(&state.db, &token).await {
Ok(user) => Some(user.did),
Err(_) => None,
}
} else {
None
}
} else {
None
};
let mut query_params = HashMap::new();
query_params.insert("actor".to_string(), params.actor.clone());
@@ -167,10 +175,18 @@ pub async fn get_profiles(
) -> Response {
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
let auth_did = auth_header.and_then(|h| {
let token = crate::auth::extract_bearer_token_from_header(Some(h))?;
crate::auth::get_did_from_token(&token).ok()
});
let auth_did = if let Some(h) = auth_header {
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
match crate::auth::validate_bearer_token(&state.db, &token).await {
Ok(user) => Some(user.did),
Err(_) => None,
}
} else {
None
}
} else {
None
};
let mut query_params = HashMap::new();
query_params.insert("actors".to_string(), params.actors.clone());

View File

@@ -44,13 +44,19 @@ pub enum ApiError {
InvitesDisabled,
DatabaseError,
UpstreamFailure,
UpstreamTimeout,
UpstreamUnavailable(String),
UpstreamError { status: u16, error: Option<String>, message: Option<String> },
}
impl ApiError {
fn status_code(&self) -> StatusCode {
match self {
Self::InternalError | Self::DatabaseError | Self::UpstreamFailure => {
StatusCode::INTERNAL_SERVER_ERROR
Self::InternalError | Self::DatabaseError => StatusCode::INTERNAL_SERVER_ERROR,
Self::UpstreamFailure | Self::UpstreamUnavailable(_) => StatusCode::BAD_GATEWAY,
Self::UpstreamTimeout => StatusCode::GATEWAY_TIMEOUT,
Self::UpstreamError { status, .. } => {
StatusCode::from_u16(*status).unwrap_or(StatusCode::BAD_GATEWAY)
}
Self::AuthenticationRequired
| Self::AuthenticationFailed
@@ -83,7 +89,15 @@ impl ApiError {
fn error_name(&self) -> &'static str {
match self {
Self::InternalError | Self::DatabaseError | Self::UpstreamFailure => "InternalError",
Self::InternalError | Self::DatabaseError => "InternalError",
Self::UpstreamFailure | Self::UpstreamUnavailable(_) => "UpstreamFailure",
Self::UpstreamTimeout => "UpstreamTimeout",
Self::UpstreamError { error, .. } => {
if let Some(e) = error {
return Box::leak(e.clone().into_boxed_str());
}
"UpstreamError"
}
Self::AuthenticationRequired => "AuthenticationRequired",
Self::AuthenticationFailed | Self::AuthenticationFailedMsg(_) => "AuthenticationFailed",
Self::InvalidToken => "InvalidToken",
@@ -116,10 +130,25 @@ impl ApiError {
Self::AuthenticationFailedMsg(msg)
| Self::ExpiredTokenMsg(msg)
| Self::InvalidRequest(msg)
| Self::RepoNotFoundMsg(msg) => Some(msg.clone()),
| Self::RepoNotFoundMsg(msg)
| Self::UpstreamUnavailable(msg) => Some(msg.clone()),
Self::UpstreamError { message, .. } => message.clone(),
Self::UpstreamTimeout => Some("Upstream service timed out".to_string()),
_ => None,
}
}
pub fn from_upstream_response(
status: u16,
body: &[u8],
) -> Self {
if let Ok(parsed) = serde_json::from_slice::<serde_json::Value>(body) {
let error = parsed.get("error").and_then(|v| v.as_str()).map(String::from);
let message = parsed.get("message").and_then(|v| v.as_str()).map(String::from);
return Self::UpstreamError { status, error, message };
}
Self::UpstreamError { status, error: None, message: None }
}
}
impl IntoResponse for ApiError {

158
src/api/feed/actor_likes.rs Normal file
View File

@@ -0,0 +1,158 @@
use crate::api::read_after_write::{
extract_repo_rev, format_munged_response, get_local_lag, get_records_since_rev,
proxy_to_appview, FeedOutput, FeedViewPost, LikeRecord, PostView, RecordDescript,
};
use crate::state::AppState;
use axum::{
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::Deserialize;
use serde_json::Value;
use std::collections::HashMap;
use tracing::warn;
#[derive(Deserialize)]
pub struct GetActorLikesParams {
pub actor: String,
pub limit: Option<u32>,
pub cursor: Option<String>,
}
/// Splice locally-written like records into a proxied likes feed.
///
/// The AppView feed is ordered newest-first by `indexedAt`, so each local
/// like is inserted just before the first feed entry older than the like.
/// Only the liked post's URI/CID are known locally, so a skeleton
/// `PostView` (empty author, null record, zeroed counts) stands in until
/// the AppView catches up.
fn insert_likes_into_feed(feed: &mut Vec<FeedViewPost>, likes: &[RecordDescript<LikeRecord>]) {
    for like in likes {
        let ts = like.indexed_at.to_rfc3339();
        // RFC 3339 timestamps compare lexicographically in chronological order.
        let insert_at = feed
            .iter()
            .position(|entry| entry.post.indexed_at < ts)
            .unwrap_or(feed.len());
        let stub = PostView {
            uri: like.record.subject.uri.clone(),
            cid: like.record.subject.cid.clone(),
            author: crate::api::read_after_write::AuthorView {
                did: String::new(),
                handle: String::new(),
                display_name: None,
                avatar: None,
                extra: HashMap::new(),
            },
            record: Value::Null,
            indexed_at: ts.clone(),
            embed: None,
            reply_count: 0,
            repost_count: 0,
            like_count: 0,
            quote_count: 0,
            extra: HashMap::new(),
        };
        let entry = FeedViewPost {
            post: stub,
            reply: None,
            reason: None,
            feed_context: None,
            extra: HashMap::new(),
        };
        feed.insert(insert_at, entry);
    }
}
/// `app.bsky.feed.getActorLikes` handler: proxy the request to the AppView,
/// then apply the read-after-write merge — if the requester is viewing
/// their OWN likes, splice in any likes written locally after the repo
/// revision the AppView reported (via the response header parsed by
/// `extract_repo_rev`).
///
/// Error-path convention throughout: on any local failure after a
/// successful proxy call, fall back to returning the untouched upstream
/// response rather than surfacing a local error.
pub async fn get_actor_likes(
    State(state): State<AppState>,
    headers: axum::http::HeaderMap,
    Query(params): Query<GetActorLikesParams>,
) -> Response {
    let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
    // Authentication is optional here: an invalid/absent token just means
    // no read-after-write merge, not a 401.
    let auth_did = if let Some(h) = auth_header {
        if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
            match crate::auth::validate_bearer_token(&state.db, &token).await {
                Ok(user) => Some(user.did),
                Err(_) => None,
            }
        } else {
            None
        }
    } else {
        None
    };
    // Forward the caller's query parameters to the AppView unchanged.
    let mut query_params = HashMap::new();
    query_params.insert("actor".to_string(), params.actor.clone());
    if let Some(limit) = params.limit {
        query_params.insert("limit".to_string(), limit.to_string());
    }
    if let Some(cursor) = &params.cursor {
        query_params.insert("cursor".to_string(), cursor.clone());
    }
    let proxy_result =
        match proxy_to_appview("app.bsky.feed.getActorLikes", &query_params, auth_header).await {
            Ok(r) => r,
            Err(e) => return e,
        };
    // Upstream errors are passed through verbatim.
    if !proxy_result.status.is_success() {
        return (proxy_result.status, proxy_result.body).into_response();
    }
    // Without a repo revision from upstream we cannot compute local lag;
    // return the upstream body as-is.
    let rev = match extract_repo_rev(&proxy_result.headers) {
        Some(r) => r,
        None => return (proxy_result.status, proxy_result.body).into_response(),
    };
    let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) {
        Ok(f) => f,
        Err(e) => {
            warn!("Failed to parse actor likes response: {:?}", e);
            return (proxy_result.status, proxy_result.body).into_response();
        }
    };
    // Unauthenticated: nothing local to merge.
    let requester_did = match auth_did {
        Some(d) => d,
        None => return (StatusCode::OK, Json(feed_output)).into_response(),
    };
    // Resolve the `actor` param (handle or DID) to a DID so we can tell
    // whether the requester is looking at their own likes.
    let actor_did = if params.actor.starts_with("did:") {
        params.actor.clone()
    } else {
        match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", params.actor)
            .fetch_optional(&state.db)
            .await
        {
            Ok(Some(did)) => did,
            // Handle unknown locally — not one of our users, no merge.
            Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(),
            Err(e) => {
                warn!("Database error resolving actor handle: {:?}", e);
                return (proxy_result.status, proxy_result.body).into_response();
            }
        }
    };
    // Read-after-write only applies to the requester's own repo.
    if actor_did != requester_did {
        return (StatusCode::OK, Json(feed_output)).into_response();
    }
    let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
        Ok(r) => r,
        Err(e) => {
            warn!("Failed to get local records: {}", e);
            return (proxy_result.status, proxy_result.body).into_response();
        }
    };
    if local_records.likes.is_empty() {
        return (StatusCode::OK, Json(feed_output)).into_response();
    }
    insert_likes_into_feed(&mut feed_output.feed, &local_records.likes);
    let lag = get_local_lag(&local_records);
    format_munged_response(feed_output, lag)
}

169
src/api/feed/author_feed.rs Normal file
View File

@@ -0,0 +1,169 @@
use crate::api::read_after_write::{
extract_repo_rev, format_local_post, format_munged_response, get_local_lag,
get_records_since_rev, insert_posts_into_feed, proxy_to_appview, FeedOutput, FeedViewPost,
ProfileRecord, RecordDescript,
};
use crate::state::AppState;
use axum::{
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::Deserialize;
use std::collections::HashMap;
use tracing::warn;
#[derive(Deserialize)]
pub struct GetAuthorFeedParams {
pub actor: String,
pub limit: Option<u32>,
pub cursor: Option<String>,
pub filter: Option<String>,
#[serde(rename = "includePins")]
pub include_pins: Option<bool>,
}
fn update_author_profile_in_feed(
feed: &mut [FeedViewPost],
author_did: &str,
local_profile: &RecordDescript<ProfileRecord>,
) {
for item in feed.iter_mut() {
if item.post.author.did == author_did {
if let Some(ref display_name) = local_profile.record.display_name {
item.post.author.display_name = Some(display_name.clone());
}
}
}
}
/// `app.bsky.feed.getAuthorFeed` handler: proxy to the AppView, then — if
/// the requester is viewing their OWN feed — splice in posts written
/// locally after the repo revision the AppView reported, and overlay any
/// local profile (display name) change onto their entries.
///
/// Error-path convention: on any local failure after a successful proxy
/// call, return the untouched upstream response rather than a local error.
pub async fn get_author_feed(
    State(state): State<AppState>,
    headers: axum::http::HeaderMap,
    Query(params): Query<GetAuthorFeedParams>,
) -> Response {
    let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
    // Optional auth: an invalid/absent token disables the merge, not the request.
    let auth_did = if let Some(h) = auth_header {
        if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
            match crate::auth::validate_bearer_token(&state.db, &token).await {
                Ok(user) => Some(user.did),
                Err(_) => None,
            }
        } else {
            None
        }
    } else {
        None
    };
    // Forward the caller's query parameters to the AppView unchanged.
    let mut query_params = HashMap::new();
    query_params.insert("actor".to_string(), params.actor.clone());
    if let Some(limit) = params.limit {
        query_params.insert("limit".to_string(), limit.to_string());
    }
    if let Some(cursor) = &params.cursor {
        query_params.insert("cursor".to_string(), cursor.clone());
    }
    if let Some(filter) = &params.filter {
        query_params.insert("filter".to_string(), filter.clone());
    }
    if let Some(include_pins) = params.include_pins {
        query_params.insert("includePins".to_string(), include_pins.to_string());
    }
    let proxy_result =
        match proxy_to_appview("app.bsky.feed.getAuthorFeed", &query_params, auth_header).await {
            Ok(r) => r,
            Err(e) => return e,
        };
    // Upstream errors pass through verbatim.
    if !proxy_result.status.is_success() {
        return (proxy_result.status, proxy_result.body).into_response();
    }
    // No repo revision header → cannot compute lag; return upstream body as-is.
    let rev = match extract_repo_rev(&proxy_result.headers) {
        Some(r) => r,
        None => return (proxy_result.status, proxy_result.body).into_response(),
    };
    let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) {
        Ok(f) => f,
        Err(e) => {
            warn!("Failed to parse author feed response: {:?}", e);
            return (proxy_result.status, proxy_result.body).into_response();
        }
    };
    // Unauthenticated: nothing local to merge.
    let requester_did = match auth_did {
        Some(d) => d,
        None => return (StatusCode::OK, Json(feed_output)).into_response(),
    };
    // Resolve `actor` (handle or DID) to a DID to detect "own feed".
    let actor_did = if params.actor.starts_with("did:") {
        params.actor.clone()
    } else {
        match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", params.actor)
            .fetch_optional(&state.db)
            .await
        {
            Ok(Some(did)) => did,
            // Handle unknown locally — not one of our users, no merge.
            Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(),
            Err(e) => {
                warn!("Database error resolving actor handle: {:?}", e);
                return (proxy_result.status, proxy_result.body).into_response();
            }
        }
    };
    // Read-after-write only applies to the requester's own repo.
    if actor_did != requester_did {
        return (StatusCode::OK, Json(feed_output)).into_response();
    }
    let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
        Ok(r) => r,
        Err(e) => {
            warn!("Failed to get local records: {}", e);
            return (proxy_result.status, proxy_result.body).into_response();
        }
    };
    if local_records.count == 0 {
        return (StatusCode::OK, Json(feed_output)).into_response();
    }
    // Handle lookup is best-effort: fall back to the DID on miss or DB error.
    let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
        .fetch_optional(&state.db)
        .await
    {
        Ok(Some(h)) => h,
        Ok(None) => requester_did.clone(),
        Err(e) => {
            warn!("Database error fetching handle: {:?}", e);
            requester_did.clone()
        }
    };
    // Apply a pending local profile change to existing upstream entries first,
    // then insert the locally-written posts themselves.
    if let Some(ref local_profile) = local_records.profile {
        update_author_profile_in_feed(&mut feed_output.feed, &requester_did, local_profile);
    }
    let local_posts: Vec<_> = local_records
        .posts
        .iter()
        .map(|p| {
            format_local_post(
                p,
                &requester_did,
                &handle,
                local_records.profile.as_ref(),
            )
        })
        .collect();
    insert_posts_into_feed(&mut feed_output.feed, local_posts);
    let lag = get_local_lag(&local_records);
    format_munged_response(feed_output, lag)
}

131
src/api/feed/custom_feed.rs Normal file
View File

@@ -0,0 +1,131 @@
use crate::api::proxy_client::{
is_ssrf_safe, proxy_client, validate_at_uri, validate_limit, MAX_RESPONSE_SIZE,
};
use crate::api::ApiError;
use crate::state::AppState;
use axum::{
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::Deserialize;
use std::collections::HashMap;
use tracing::{error, info};
#[derive(Deserialize)]
pub struct GetFeedParams {
pub feed: String,
pub limit: Option<u32>,
pub cursor: Option<String>,
}
/// `app.bsky.feed.getFeed` handler: authenticate the caller, validate the
/// feed AT-URI, then proxy the request to the configured AppView
/// (`APPVIEW_URL`) after an SSRF safety check. No read-after-write merge
/// here — custom feeds come from the feed generator, so the upstream body
/// is streamed back (capped at `MAX_RESPONSE_SIZE`).
pub async fn get_feed(
    State(state): State<AppState>,
    headers: axum::http::HeaderMap,
    Query(params): Query<GetFeedParams>,
) -> Response {
    // Unlike the other feed endpoints, auth is mandatory here.
    let token = match crate::auth::extract_bearer_token_from_header(
        headers.get("Authorization").and_then(|h| h.to_str().ok()),
    ) {
        Some(t) => t,
        None => return ApiError::AuthenticationRequired.into_response(),
    };
    if let Err(e) = crate::auth::validate_bearer_token(&state.db, &token).await {
        return ApiError::from(e).into_response();
    };
    if let Err(e) = validate_at_uri(&params.feed) {
        return ApiError::InvalidRequest(format!("Invalid feed URI: {}", e)).into_response();
    }
    let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
    let appview_url = match std::env::var("APPVIEW_URL") {
        Ok(url) => url,
        Err(_) => {
            return ApiError::UpstreamUnavailable("No upstream AppView configured".to_string())
                .into_response();
        }
    };
    // Guard against the env var pointing at internal/loopback targets.
    if let Err(e) = is_ssrf_safe(&appview_url) {
        error!("SSRF check failed for appview URL: {}", e);
        return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e))
            .into_response();
    }
    // Clamp limit to [default=50, max=100] before forwarding.
    let limit = validate_limit(params.limit, 50, 100);
    let mut query_params = HashMap::new();
    query_params.insert("feed".to_string(), params.feed.clone());
    query_params.insert("limit".to_string(), limit.to_string());
    if let Some(cursor) = &params.cursor {
        query_params.insert("cursor".to_string(), cursor.clone());
    }
    let target_url = format!("{}/xrpc/app.bsky.feed.getFeed", appview_url);
    info!(target = %target_url, feed = %params.feed, "Proxying getFeed request");
    let client = proxy_client();
    let mut request_builder = client.get(&target_url).query(&query_params);
    if let Some(auth) = auth_header {
        request_builder = request_builder.header("Authorization", auth);
    }
    match request_builder.send().await {
        Ok(resp) => {
            let status =
                StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
            // First check the declared Content-Length (0 when absent)…
            let content_length = resp.content_length().unwrap_or(0);
            if content_length > MAX_RESPONSE_SIZE {
                error!(
                    content_length,
                    max = MAX_RESPONSE_SIZE,
                    "getFeed response too large"
                );
                return ApiError::UpstreamFailure.into_response();
            }
            let resp_headers = resp.headers().clone();
            let body = match resp.bytes().await {
                Ok(b) => {
                    // …then re-check the actual body size, since Content-Length
                    // may be missing or lie (e.g. chunked responses).
                    if b.len() as u64 > MAX_RESPONSE_SIZE {
                        error!(len = b.len(), "getFeed response body exceeded limit");
                        return ApiError::UpstreamFailure.into_response();
                    }
                    b
                }
                Err(e) => {
                    error!(error = ?e, "Error reading getFeed response");
                    return ApiError::UpstreamFailure.into_response();
                }
            };
            // Only Content-Type is forwarded from the upstream response headers.
            let mut response_builder = axum::response::Response::builder().status(status);
            if let Some(ct) = resp_headers.get("content-type") {
                response_builder = response_builder.header("content-type", ct);
            }
            match response_builder.body(axum::body::Body::from(body)) {
                Ok(r) => r,
                Err(e) => {
                    error!(error = ?e, "Error building getFeed response");
                    ApiError::UpstreamFailure.into_response()
                }
            }
        }
        Err(e) => {
            // Map transport failures to distinct gateway errors.
            error!(error = ?e, "Error proxying getFeed");
            if e.is_timeout() {
                ApiError::UpstreamTimeout.into_response()
            } else if e.is_connect() {
                ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
                    .into_response()
            } else {
                ApiError::UpstreamFailure.into_response()
            }
        }
    }
}

View File

@@ -1,3 +1,11 @@
mod actor_likes;
mod author_feed;
mod custom_feed;
mod post_thread;
mod timeline;
pub use actor_likes::get_actor_likes;
pub use author_feed::get_author_feed;
pub use custom_feed::get_feed;
pub use post_thread::get_post_thread;
pub use timeline::get_timeline;

322
src/api/feed/post_thread.rs Normal file
View File

@@ -0,0 +1,322 @@
use crate::api::read_after_write::{
extract_repo_rev, format_local_post, format_munged_response, get_local_lag,
get_records_since_rev, proxy_to_appview, PostRecord, PostView, RecordDescript,
};
use crate::state::AppState;
use axum::{
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use tracing::warn;
/// Query parameters for `app.bsky.feed.getPostThread`.
#[derive(Deserialize)]
pub struct GetPostThreadParams {
    // AT-URI of the post anchoring the thread.
    pub uri: String,
    // How many levels of replies to include below the anchor.
    pub depth: Option<u32>,
    // How many levels of parents to include above the anchor.
    #[serde(rename = "parentHeight")]
    pub parent_height: Option<u32>,
}
/// A resolved post node in a thread, mirroring
/// `app.bsky.feed.defs#threadViewPost`. Unknown upstream fields are kept
/// in `extra` so they round-trip through deserialize/serialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThreadViewPost {
    #[serde(rename = "$type")]
    pub thread_type: Option<String>,
    pub post: PostView,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent: Option<Box<ThreadNode>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub replies: Option<Vec<ThreadNode>>,
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// One node in the thread tree: a post, a not-found tombstone, or a
/// blocked-author tombstone. `untagged` — variants are distinguished by
/// shape, tried in declaration order during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ThreadNode {
    Post(ThreadViewPost),
    NotFound(ThreadNotFound),
    Blocked(ThreadBlocked),
}
/// `app.bsky.feed.defs#notFoundPost` tombstone.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThreadNotFound {
    #[serde(rename = "$type")]
    pub thread_type: String,
    pub uri: String,
    pub not_found: bool,
}
/// `app.bsky.feed.defs#blockedPost` tombstone; `author` is passed through
/// as opaque JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThreadBlocked {
    #[serde(rename = "$type")]
    pub thread_type: String,
    pub uri: String,
    pub blocked: bool,
    pub author: Value,
}
/// Top-level response body for `getPostThread`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostThreadOutput {
    pub thread: ThreadNode,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub threadgate: Option<Value>,
}
// Recursion cap for splicing local replies into a thread (guards against
// pathological / cyclic reply chains).
const MAX_THREAD_DEPTH: usize = 10;
fn add_replies_to_thread(
thread: &mut ThreadViewPost,
local_posts: &[RecordDescript<PostRecord>],
author_did: &str,
author_handle: &str,
depth: usize,
) {
if depth >= MAX_THREAD_DEPTH {
return;
}
let thread_uri = &thread.post.uri;
let replies: Vec<_> = local_posts
.iter()
.filter(|p| {
p.record
.reply
.as_ref()
.and_then(|r| r.get("parent"))
.and_then(|parent| parent.get("uri"))
.and_then(|u| u.as_str())
== Some(thread_uri)
})
.map(|p| {
let post_view = format_local_post(p, author_did, author_handle, None);
ThreadNode::Post(ThreadViewPost {
thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
post: post_view,
parent: None,
replies: None,
extra: HashMap::new(),
})
})
.collect();
if !replies.is_empty() {
match &mut thread.replies {
Some(existing) => existing.extend(replies),
None => thread.replies = Some(replies),
}
}
if let Some(ref mut existing_replies) = thread.replies {
for reply in existing_replies.iter_mut() {
if let ThreadNode::Post(reply_thread) = reply {
add_replies_to_thread(reply_thread, local_posts, author_did, author_handle, depth + 1);
}
}
}
}
/// `app.bsky.feed.getPostThread` handler: proxy to the AppView, then — for
/// an authenticated requester — splice locally-written replies into the
/// returned thread. A 404 from upstream is handed to `handle_not_found`,
/// which can synthesize a thread from a post that only exists locally.
///
/// Error-path convention: on any local failure after a successful proxy
/// call, return the untouched upstream response.
pub async fn get_post_thread(
    State(state): State<AppState>,
    headers: axum::http::HeaderMap,
    Query(params): Query<GetPostThreadParams>,
) -> Response {
    let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
    // Optional auth: an invalid/absent token disables the merge, not the request.
    let auth_did = if let Some(h) = auth_header {
        if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
            match crate::auth::validate_bearer_token(&state.db, &token).await {
                Ok(user) => Some(user.did),
                Err(_) => None,
            }
        } else {
            None
        }
    } else {
        None
    };
    // Forward the caller's query parameters to the AppView unchanged.
    let mut query_params = HashMap::new();
    query_params.insert("uri".to_string(), params.uri.clone());
    if let Some(depth) = params.depth {
        query_params.insert("depth".to_string(), depth.to_string());
    }
    if let Some(parent_height) = params.parent_height {
        query_params.insert("parentHeight".to_string(), parent_height.to_string());
    }
    let proxy_result =
        match proxy_to_appview("app.bsky.feed.getPostThread", &query_params, auth_header).await {
            Ok(r) => r,
            Err(e) => return e,
        };
    // The anchor post may be so fresh the AppView hasn't indexed it yet —
    // try to serve it from the local repo before giving up.
    if proxy_result.status == StatusCode::NOT_FOUND {
        return handle_not_found(&state, &params.uri, auth_did, &proxy_result.headers).await;
    }
    if !proxy_result.status.is_success() {
        return (proxy_result.status, proxy_result.body).into_response();
    }
    // No repo revision header → cannot compute lag; return upstream body as-is.
    let rev = match extract_repo_rev(&proxy_result.headers) {
        Some(r) => r,
        None => return (proxy_result.status, proxy_result.body).into_response(),
    };
    let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) {
        Ok(t) => t,
        Err(e) => {
            warn!("Failed to parse post thread response: {:?}", e);
            return (proxy_result.status, proxy_result.body).into_response();
        }
    };
    // Unauthenticated: nothing local to merge.
    let requester_did = match auth_did {
        Some(d) => d,
        None => return (StatusCode::OK, Json(thread_output)).into_response(),
    };
    let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
        Ok(r) => r,
        Err(e) => {
            warn!("Failed to get local records: {}", e);
            return (proxy_result.status, proxy_result.body).into_response();
        }
    };
    if local_records.posts.is_empty() {
        return (StatusCode::OK, Json(thread_output)).into_response();
    }
    // Handle lookup is best-effort: fall back to the DID on miss or DB error.
    let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
        .fetch_optional(&state.db)
        .await
    {
        Ok(Some(h)) => h,
        Ok(None) => requester_did.clone(),
        Err(e) => {
            warn!("Database error fetching handle: {:?}", e);
            requester_did.clone()
        }
    };
    // Only a resolved post can take replies; tombstone roots are left alone.
    if let ThreadNode::Post(ref mut thread_post) = thread_output.thread {
        add_replies_to_thread(thread_post, &local_records.posts, &requester_did, &handle, 0);
    }
    let lag = get_local_lag(&local_records);
    format_munged_response(thread_output, lag)
}
/// Local fallback for `getPostThread` when the AppView answers 404.
///
/// If the requested post belongs to the authenticated requester and exists
/// in the local repo past the AppView's reported revision, synthesize a
/// single-post thread from the local record instead of surfacing the 404.
/// In every other case, return the canonical NotFound error.
async fn handle_not_found(
    state: &AppState,
    uri: &str,
    auth_did: Option<String>,
    headers: &axum::http::HeaderMap,
) -> Response {
    // Canonical 404 payload, matching the upstream AppView error shape.
    fn not_found() -> Response {
        (
            StatusCode::NOT_FOUND,
            Json(json!({"error": "NotFound", "message": "Post not found"})),
        )
            .into_response()
    }
    // Without a repo revision we cannot query local records newer than it.
    let rev = match extract_repo_rev(headers) {
        Some(r) => r,
        None => return not_found(),
    };
    // Anonymous requesters never get the local fallback.
    let requester_did = match auth_did {
        Some(d) => d,
        None => return not_found(),
    };
    // AT-URI layout: at://<did>/<collection>/<rkey>. Only the requester's
    // own posts can be served from the local repo.
    let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect();
    if uri_parts.len() != 3 || uri_parts[0] != requester_did {
        return not_found();
    }
    let local_records = match get_records_since_rev(state, &requester_did, &rev).await {
        Ok(r) => r,
        Err(_) => return not_found(),
    };
    let local_post = match local_records.posts.iter().find(|p| p.uri == uri) {
        Some(p) => p,
        None => return not_found(),
    };
    // Handle lookup is best-effort: fall back to the DID on miss or DB error.
    let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
        .fetch_optional(&state.db)
        .await
    {
        Ok(Some(h)) => h,
        Ok(None) => requester_did.clone(),
        Err(e) => {
            warn!("Database error fetching handle: {:?}", e);
            requester_did.clone()
        }
    };
    let post_view = format_local_post(
        local_post,
        &requester_did,
        &handle,
        local_records.profile.as_ref(),
    );
    // Wrap the lone local post as a parentless, replyless thread root.
    let thread = PostThreadOutput {
        thread: ThreadNode::Post(ThreadViewPost {
            thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
            post: post_view,
            parent: None,
            replies: None,
            extra: HashMap::new(),
        }),
        threadgate: None,
    };
    format_munged_response(thread, get_local_lag(&local_records))
}

View File

@@ -1,51 +1,35 @@
// Yes, I know, this endpoint is an appview one, not for PDS. Who cares!!
// Yes, this only gets posts that our DB/instance knows about. Who cares!!!
use crate::api::read_after_write::{
extract_repo_rev, format_local_post, format_munged_response, get_local_lag,
get_records_since_rev, insert_posts_into_feed, proxy_to_appview, FeedOutput, FeedViewPost,
PostView,
};
use crate::state::AppState;
use axum::{
Json,
extract::State,
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use jacquard_repo::storage::BlockStore;
use serde::Serialize;
use serde_json::{Value, json};
use tracing::error;
use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashMap;
use tracing::warn;
#[derive(Serialize)]
pub struct TimelineOutput {
pub feed: Vec<FeedViewPost>,
#[derive(Deserialize)]
pub struct GetTimelineParams {
pub algorithm: Option<String>,
pub limit: Option<u32>,
pub cursor: Option<String>,
}
#[derive(Serialize)]
pub struct FeedViewPost {
pub post: PostView,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PostView {
pub uri: String,
pub cid: String,
pub author: AuthorView,
pub record: Value,
pub indexed_at: String,
}
#[derive(Serialize)]
pub struct AuthorView {
pub did: String,
pub handle: String,
}
pub async fn get_timeline(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
Query(params): Query<GetTimelineParams>,
) -> Response {
let token = match crate::auth::extract_bearer_token_from_header(
headers.get("Authorization").and_then(|h| h.to_str().ok())
headers.get("Authorization").and_then(|h| h.to_str().ok()),
) {
Some(t) => t,
None => {
@@ -68,32 +52,131 @@ pub async fn get_timeline(
}
};
let user_query = sqlx::query!("SELECT id FROM users WHERE did = $1", auth_user.did)
.fetch_optional(&state.db)
.await;
match std::env::var("APPVIEW_URL") {
Ok(url) if !url.starts_with("http://127.0.0.1") => {
return get_timeline_with_appview(&state, &headers, &params, &auth_user.did).await;
}
_ => {}
}
let user_id = match user_query {
Ok(Some(row)) => row.id,
_ => {
get_timeline_local_only(&state, &auth_user.did).await
}
/// Read-after-write timeline: fetch the timeline from the upstream AppView,
/// then splice in any records written locally after the repo rev the AppView
/// reports having indexed.
///
/// Falls back to returning the upstream response verbatim whenever any step
/// of the munging pipeline fails (missing rev header, parse error, DB error,
/// or no newer local records) — degraded freshness is preferred over a 5xx.
async fn get_timeline_with_appview(
    state: &AppState,
    headers: &axum::http::HeaderMap,
    params: &GetTimelineParams,
    auth_did: &str,
) -> Response {
    // Forward the caller's bearer token unchanged to the AppView.
    let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
    let mut query_params = HashMap::new();
    if let Some(algo) = &params.algorithm {
        query_params.insert("algorithm".to_string(), algo.clone());
    }
    if let Some(limit) = params.limit {
        query_params.insert("limit".to_string(), limit.to_string());
    }
    if let Some(cursor) = &params.cursor {
        query_params.insert("cursor".to_string(), cursor.clone());
    }
    let proxy_result =
        match proxy_to_appview("app.bsky.feed.getTimeline", &query_params, auth_header).await {
            Ok(r) => r,
            Err(e) => return e,
        };
    // Upstream errors are passed through untouched.
    if !proxy_result.status.is_success() {
        return (proxy_result.status, proxy_result.body).into_response();
    }
    // Without an atproto-repo-rev header we cannot tell which local records
    // the AppView has already seen, so skip munging entirely.
    let rev = extract_repo_rev(&proxy_result.headers);
    if rev.is_none() {
        return (proxy_result.status, proxy_result.body).into_response();
    }
    let rev = rev.unwrap();
    let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) {
        Ok(f) => f,
        Err(e) => {
            warn!("Failed to parse timeline response: {:?}", e);
            return (proxy_result.status, proxy_result.body).into_response();
        }
    };
    let local_records = match get_records_since_rev(state, auth_did, &rev).await {
        Ok(r) => r,
        Err(e) => {
            warn!("Failed to get local records: {}", e);
            return (proxy_result.status, proxy_result.body).into_response();
        }
    };
    if local_records.count == 0 {
        return (proxy_result.status, proxy_result.body).into_response();
    }
    // Handle lookup is best-effort; fall back to showing the DID as the
    // handle rather than failing the whole request.
    let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", auth_did)
        .fetch_optional(&state.db)
        .await
    {
        Ok(Some(h)) => h,
        Ok(None) => auth_did.to_string(),
        Err(e) => {
            warn!("Database error fetching handle: {:?}", e);
            auth_did.to_string()
        }
    };
    let local_posts: Vec<_> = local_records
        .posts
        .iter()
        .map(|p| format_local_post(p, auth_did, &handle, local_records.profile.as_ref()))
        .collect();
    insert_posts_into_feed(&mut feed_output.feed, local_posts);
    // Surface how stale the upstream view is via the upstream-lag header.
    let lag = get_local_lag(&local_records);
    format_munged_response(feed_output, lag)
}
async fn get_timeline_local_only(state: &AppState, auth_did: &str) -> Response {
let user_id: uuid::Uuid = match sqlx::query_scalar!(
"SELECT id FROM users WHERE did = $1",
auth_did
)
.fetch_optional(&state.db)
.await
{
Ok(Some(id)) => id,
Ok(None) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "User not found"})),
)
.into_response();
}
Err(e) => {
warn!("Database error fetching user: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Database error"})),
)
.into_response();
}
};
let follows_query = sqlx::query!(
"SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow'",
"SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000",
user_id
)
.fetch_all(&state.db)
.await;
.fetch_all(&state.db)
.await;
let follow_cids: Vec<String> = match follows_query {
Ok(rows) => rows.iter().map(|r| r.record_cid.clone()).collect(),
Err(e) => {
error!("Failed to get follows: {:?}", e);
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
@@ -127,7 +210,7 @@ pub async fn get_timeline(
if followed_dids.is_empty() {
return (
StatusCode::OK,
Json(TimelineOutput {
Json(FeedOutput {
feed: vec![],
cursor: None,
}),
@@ -150,8 +233,7 @@ pub async fn get_timeline(
let posts = match posts_result {
Ok(rows) => rows,
Err(e) => {
error!("Failed to get posts: {:?}", e);
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
@@ -190,15 +272,28 @@ pub async fn get_timeline(
post: PostView {
uri,
cid: record_cid,
author: AuthorView {
author: crate::api::read_after_write::AuthorView {
did: author_did,
handle: author_handle,
display_name: None,
avatar: None,
extra: HashMap::new(),
},
record,
indexed_at: created_at.to_rfc3339(),
embed: None,
reply_count: 0,
repost_count: 0,
like_count: 0,
quote_count: 0,
extra: HashMap::new(),
},
reply: None,
reason: None,
feed_context: None,
extra: HashMap::new(),
});
}
(StatusCode::OK, Json(TimelineOutput { feed, cursor: None })).into_response()
(StatusCode::OK, Json(FeedOutput { feed, cursor: None })).into_response()
}

View File

@@ -4,9 +4,13 @@ pub mod error;
pub mod feed;
pub mod identity;
pub mod moderation;
pub mod notification;
pub mod proxy;
pub mod proxy_client;
pub mod read_after_write;
pub mod repo;
pub mod server;
pub mod validation;
pub use error::ApiError;
pub use proxy_client::{proxy_client, validate_at_uri, validate_did, validate_limit, AtUriParts};

View File

@@ -0,0 +1,3 @@
// Handlers for app.bsky.notification.* XRPC endpoints.
mod register_push;
pub use register_push::register_push;

View File

@@ -0,0 +1,166 @@
use crate::api::proxy_client::{is_ssrf_safe, proxy_client, validate_did};
use crate::api::ApiError;
use crate::state::AppState;
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
Json,
};
use serde::Deserialize;
use serde_json::json;
use tracing::{error, info};
/// Request body for `app.bsky.notification.registerPush`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RegisterPushInput {
    // DID of the push-notification service the token is registered with.
    pub service_did: String,
    // Opaque device push token issued by the platform.
    pub token: String,
    // One of VALID_PLATFORMS below.
    pub platform: String,
    // Application identifier (e.g. bundle id) the token belongs to.
    pub app_id: String,
}
// Platforms accepted by the lexicon for registerPush.
const VALID_PLATFORMS: &[&str] = &["ios", "android", "web"];
/// Proxy `app.bsky.notification.registerPush` to the upstream AppView,
/// re-signing the request with a per-user service token.
///
/// Pipeline (order is observable to clients via which error is returned
/// first): authenticate caller → validate input fields → resolve and
/// SSRF-check APPVIEW_URL → mint a service JWT with the user's repo signing
/// key → POST the registration upstream.
pub async fn register_push(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(input): Json<RegisterPushInput>,
) -> Response {
    let token = match crate::auth::extract_bearer_token_from_header(
        headers.get("Authorization").and_then(|h| h.to_str().ok()),
    ) {
        Some(t) => t,
        None => return ApiError::AuthenticationRequired.into_response(),
    };
    let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
        Ok(user) => user,
        Err(e) => return ApiError::from(e).into_response(),
    };
    if let Err(e) = validate_did(&input.service_did) {
        return ApiError::InvalidRequest(format!("Invalid serviceDid: {}", e)).into_response();
    }
    // 4096 is a defensive cap on the opaque push token length.
    if input.token.is_empty() || input.token.len() > 4096 {
        return ApiError::InvalidRequest("Invalid push token".to_string()).into_response();
    }
    if !VALID_PLATFORMS.contains(&input.platform.as_str()) {
        return ApiError::InvalidRequest(format!(
            "Invalid platform. Must be one of: {}",
            VALID_PLATFORMS.join(", ")
        ))
        .into_response();
    }
    if input.app_id.is_empty() || input.app_id.len() > 256 {
        return ApiError::InvalidRequest("Invalid appId".to_string()).into_response();
    }
    let appview_url = match std::env::var("APPVIEW_URL") {
        Ok(url) => url,
        Err(_) => {
            return ApiError::UpstreamUnavailable("No upstream AppView configured".to_string())
                .into_response();
        }
    };
    // APPVIEW_URL is operator-supplied, but check anyway before making an
    // outbound request with user-influenced content.
    if let Err(e) = is_ssrf_safe(&appview_url) {
        error!("SSRF check failed for appview URL: {}", e);
        return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e))
            .into_response();
    }
    // The user's encrypted repo signing key is needed to mint the
    // inter-service token the AppView expects.
    let key_row = match sqlx::query!(
        "SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1",
        auth_user.did
    )
    .fetch_optional(&state.db)
    .await
    {
        Ok(Some(row)) => row,
        Ok(None) => {
            error!(did = %auth_user.did, "No signing key found for user");
            return ApiError::InternalError.into_response();
        }
        Err(e) => {
            error!(error = ?e, "Database error fetching signing key");
            return ApiError::DatabaseError.into_response();
        }
    };
    let decrypted_key =
        match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) {
            Ok(k) => k,
            Err(e) => {
                error!(error = ?e, "Failed to decrypt signing key");
                return ApiError::InternalError.into_response();
            }
        };
    // Service token is scoped to this single lexicon method and audience.
    let service_token = match crate::auth::create_service_token(
        &auth_user.did,
        &input.service_did,
        "app.bsky.notification.registerPush",
        &decrypted_key,
    ) {
        Ok(t) => t,
        Err(e) => {
            error!(error = ?e, "Failed to create service token");
            return ApiError::InternalError.into_response();
        }
    };
    let target_url = format!("{}/xrpc/app.bsky.notification.registerPush", appview_url);
    info!(
        target = %target_url,
        service_did = %input.service_did,
        platform = %input.platform,
        "Proxying registerPush request"
    );
    let client = proxy_client();
    let request_body = json!({
        "serviceDid": input.service_did,
        "token": input.token,
        "platform": input.platform,
        "appId": input.app_id
    });
    match client
        .post(&target_url)
        .header("Authorization", format!("Bearer {}", service_token))
        .header("Content-Type", "application/json")
        .json(&request_body)
        .send()
        .await
    {
        Ok(resp) => {
            let status =
                StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
            if status.is_success() {
                // registerPush has no response body; an empty 200 suffices.
                StatusCode::OK.into_response()
            } else {
                let body = resp.bytes().await.unwrap_or_default();
                error!(
                    status = %status,
                    "registerPush upstream error"
                );
                ApiError::from_upstream_response(status.as_u16(), &body).into_response()
            }
        }
        Err(e) => {
            error!(error = ?e, "Error proxying registerPush");
            if e.is_timeout() {
                ApiError::UpstreamTimeout.into_response()
            } else if e.is_connect() {
                ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
                    .into_response()
            } else {
                ApiError::UpstreamFailure.into_response()
            }
        }
    }
}

View File

@@ -46,21 +46,15 @@ pub async fn proxy_handler(
if let Some(token) = crate::auth::extract_bearer_token_from_header(
headers.get("Authorization").and_then(|h| h.to_str().ok())
) {
if let Ok(did) = crate::auth::get_did_from_token(&token) {
let key_row = sqlx::query!("SELECT k.key_bytes, k.encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1", did)
.fetch_optional(&state.db)
.await;
if let Ok(Some(row)) = key_row {
if let Ok(decrypted_key) = crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
if let Ok(new_token) =
crate::auth::create_service_token(&did, aud, &method, &decrypted_key)
if let Ok(auth_user) = crate::auth::validate_bearer_token(&state.db, &token).await {
if let Some(key_bytes) = auth_user.key_bytes {
if let Ok(new_token) =
crate::auth::create_service_token(&auth_user.did, aud, &method, &key_bytes)
{
if let Ok(val) =
axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
{
if let Ok(val) =
axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
{
auth_header_val = Some(val);
}
auth_header_val = Some(val);
}
}
}

252
src/api/proxy_client.rs Normal file
View File

@@ -0,0 +1,252 @@
use reqwest::{Client, ClientBuilder, Url};
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::sync::OnceLock;
use std::time::Duration;
use tracing::warn;
// Proxy timeout budget. NOTE(review): DEFAULT_HEADERS_TIMEOUT is not applied
// to the client below (reqwest's `timeout` is a whole-request deadline) —
// confirm whether a separate headers timeout is still intended.
pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_secs(10);
pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(30);
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
// Hard cap on upstream response bodies accepted by the proxy (10 MiB).
pub const MAX_RESPONSE_SIZE: u64 = 10 * 1024 * 1024;
static PROXY_CLIENT: OnceLock<Client> = OnceLock::new();
/// Process-wide HTTP client for proxying to upstream services.
///
/// Built once and reused so connection pooling works across requests.
/// Redirects are disabled: following them could sidestep the SSRF checks
/// performed on the original URL.
pub fn proxy_client() -> &'static Client {
    PROXY_CLIENT.get_or_init(|| {
        ClientBuilder::new()
            .timeout(DEFAULT_BODY_TIMEOUT)
            .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
            .pool_max_idle_per_host(10)
            .pool_idle_timeout(Duration::from_secs(90))
            .redirect(reqwest::redirect::Policy::none())
            .build()
            .expect("Failed to build HTTP client - this indicates a TLS or system configuration issue")
    })
}
/// Validate that `url` is safe to proxy to (no SSRF into internal networks).
///
/// Rules:
/// - `https` is always acceptable at the scheme level.
/// - Plain `http` is only allowed for exactly-local hosts (`localhost`,
///   `127.0.0.1`) or when the `ALLOW_HTTP_PROXY` env var is set.
/// - Literal IPs and every DNS resolution result must be unicast, public
///   addresses; loopback is explicitly permitted for local development.
///
/// # Errors
/// Returns an [`SsrfError`] describing the first check that failed.
pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> {
    let parsed = Url::parse(url).map_err(|_| SsrfError::InvalidUrl)?;
    let scheme = parsed.scheme();
    let host = parsed.host_str().ok_or(SsrfError::NoHost)?;
    if scheme != "https" {
        // Compare the PARSED host exactly. The previous raw-string prefix
        // check (`url.starts_with("http://127.0.0.1")`) also matched hosts
        // such as `127.0.0.1.attacker.example`, allowing plain-http proxying
        // to arbitrary lookalike domains.
        let is_local_host = host == "localhost" || host == "127.0.0.1";
        let allow_http = std::env::var("ALLOW_HTTP_PROXY").is_ok() || is_local_host;
        if !allow_http {
            return Err(SsrfError::InsecureProtocol(scheme.to_string()));
        }
    }
    if host == "localhost" {
        return Ok(());
    }
    if let Ok(ip) = host.parse::<IpAddr>() {
        // Loopback literals are allowed for local development setups.
        if ip.is_loopback() {
            return Ok(());
        }
        if !is_unicast_ip(&ip) {
            return Err(SsrfError::NonUnicastIp(ip.to_string()));
        }
        return Ok(());
    }
    // Resolve the hostname and vet every returned address; otherwise a DNS
    // name pointing at an internal IP would bypass the literal-IP check.
    let port = parsed.port().unwrap_or(if scheme == "https" { 443 } else { 80 });
    let socket_addrs: Vec<SocketAddr> = match (host, port).to_socket_addrs() {
        Ok(addrs) => addrs.collect(),
        Err(_) => return Err(SsrfError::DnsResolutionFailed(host.to_string())),
    };
    for addr in &socket_addrs {
        if !is_unicast_ip(&addr.ip()) {
            warn!(
                "DNS resolution for {} returned non-unicast IP: {}",
                host,
                addr.ip()
            );
            return Err(SsrfError::NonUnicastIp(addr.ip().to_string()));
        }
    }
    Ok(())
}
/// True when `ip` is an ordinary routable unicast address — i.e. none of the
/// special-purpose categories (loopback, broadcast, multicast, unspecified,
/// link-local) and, for v4, not in a private range.
fn is_unicast_ip(ip: &IpAddr) -> bool {
    match *ip {
        IpAddr::V4(v4) => {
            let special = v4.is_loopback()
                || v4.is_broadcast()
                || v4.is_multicast()
                || v4.is_unspecified()
                || v4.is_link_local()
                || is_private_v4(&v4);
            !special
        }
        IpAddr::V6(v6) => !(v6.is_loopback() || v6.is_multicast() || v6.is_unspecified()),
    }
}
/// True for RFC 1918 private ranges (10/8, 172.16/12, 192.168/16) plus the
/// 169.254/16 link-local range.
fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool {
    matches!(
        ip.octets(),
        [10, ..] | [172, 16..=31, ..] | [192, 168, ..] | [169, 254, ..]
    )
}
/// Reasons a URL failed the SSRF safety check.
#[derive(Debug, Clone)]
pub enum SsrfError {
    InvalidUrl,
    InsecureProtocol(String),
    NoHost,
    NonUnicastIp(String),
    DnsResolutionFailed(String),
}

impl std::fmt::Display for SsrfError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidUrl => f.write_str("Invalid URL"),
            Self::InsecureProtocol(proto) => write!(f, "Insecure protocol: {}", proto),
            Self::NoHost => f.write_str("No host in URL"),
            Self::NonUnicastIp(ip) => write!(f, "Non-unicast IP address: {}", ip),
            Self::DnsResolutionFailed(host) => {
                write!(f, "DNS resolution failed for: {}", host)
            }
        }
    }
}

impl std::error::Error for SsrfError {}
// Request headers copied from the client onto proxied upstream requests.
pub const HEADERS_TO_FORWARD: &[&str] = &[
    "accept-language",
    "atproto-accept-labelers",
    "x-bsky-topics",
];
// Response headers copied back from the upstream to the client; includes
// atproto-repo-rev, which drives the read-after-write munging.
pub const RESPONSE_HEADERS_TO_FORWARD: &[&str] = &[
    "atproto-repo-rev",
    "atproto-content-labelers",
    "retry-after",
    "content-type",
];
/// Parse and loosely validate an `at://` URI into DID / collection / rkey.
///
/// # Errors
/// Returns a static message when the scheme, DID, collection NSID, or record
/// key is malformed.
pub fn validate_at_uri(uri: &str) -> Result<AtUriParts, &'static str> {
    // `strip_prefix` removes the scheme exactly once. The previous
    // `trim_start_matches("at://")` stripped a REPEATED prefix, so a
    // malformed "at://at://did:..." URI was silently accepted.
    let path = uri.strip_prefix("at://").ok_or("URI must start with at://")?;
    let parts: Vec<&str> = path.split('/').collect();
    // split() always yields at least one element, so parts[0] is safe;
    // an empty DID fails the prefix check below.
    let did = parts[0];
    if !did.starts_with("did:") {
        return Err("Invalid DID in URI");
    }
    if let Some(collection) = parts.get(1) {
        if collection.is_empty() || !collection.contains('.') {
            return Err("Invalid collection NSID");
        }
    }
    // Previously "at://did/coll/" produced rkey Some(""); reject it.
    if let Some(rkey) = parts.get(2) {
        if rkey.is_empty() {
            return Err("Invalid record key");
        }
    }
    Ok(AtUriParts {
        did: did.to_string(),
        collection: parts.get(1).map(|s| s.to_string()),
        rkey: parts.get(2).map(|s| s.to_string()),
    })
}
/// Components of a validated `at://` URI.
#[derive(Debug, Clone)]
pub struct AtUriParts {
    pub did: String,
    pub collection: Option<String>,
    pub rkey: Option<String>,
}
/// Normalize a caller-supplied page-size limit: missing or zero falls back
/// to `default`, anything above `max` is clamped to `max`.
pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 {
    match limit {
        None | Some(0) => default,
        Some(l) => l.min(max),
    }
}
/// Validate the shape of a DID: `did:<method>:<identifier>` with a supported
/// method (`plc` or `web`) and a non-empty identifier.
///
/// # Errors
/// Returns a static message describing the first rule violated.
pub fn validate_did(did: &str) -> Result<(), &'static str> {
    if !did.starts_with("did:") {
        return Err("Invalid DID format");
    }
    let parts: Vec<&str> = did.split(':').collect();
    if parts.len() < 3 {
        return Err("DID must have at least method and identifier");
    }
    let method = parts[1];
    if method != "plc" && method != "web" {
        return Err("Unsupported DID method");
    }
    // "did:plc:" splits into three parts and previously slipped through the
    // length check despite having no identifier at all.
    if parts[2].is_empty() {
        return Err("DID identifier must not be empty");
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    // HTTPS upstreams pass the scheme check unconditionally.
    #[test]
    fn test_ssrf_safe_https() {
        assert!(is_ssrf_safe("https://api.bsky.app/xrpc/test").is_ok());
    }
    // Plain http to a non-local host is rejected. DnsResolutionFailed is
    // accepted too because the check may fail earlier in offline CI.
    #[test]
    fn test_ssrf_blocks_http_by_default() {
        let result = is_ssrf_safe("http://external.example.com/xrpc/test");
        assert!(matches!(result, Err(SsrfError::InsecureProtocol(_)) | Err(SsrfError::DnsResolutionFailed(_))));
    }
    // Local-development escape hatch: loopback hosts may use plain http.
    #[test]
    fn test_ssrf_allows_localhost_http() {
        assert!(is_ssrf_safe("http://127.0.0.1:8080/test").is_ok());
        assert!(is_ssrf_safe("http://localhost:8080/test").is_ok());
    }
    // Happy path: all three at-uri components are extracted.
    #[test]
    fn test_validate_at_uri() {
        let result = validate_at_uri("at://did:plc:test/app.bsky.feed.post/abc123");
        assert!(result.is_ok());
        let parts = result.unwrap();
        assert_eq!(parts.did, "did:plc:test");
        assert_eq!(parts.collection, Some("app.bsky.feed.post".to_string()));
        assert_eq!(parts.rkey, Some("abc123".to_string()));
    }
    // Wrong scheme and non-DID authority are both rejected.
    #[test]
    fn test_validate_at_uri_invalid() {
        assert!(validate_at_uri("https://example.com").is_err());
        assert!(validate_at_uri("at://notadid/collection/rkey").is_err());
    }
    // None/zero fall back to the default; values above max are clamped.
    #[test]
    fn test_validate_limit() {
        assert_eq!(validate_limit(None, 50, 100), 50);
        assert_eq!(validate_limit(Some(0), 50, 100), 50);
        assert_eq!(validate_limit(Some(200), 50, 100), 100);
        assert_eq!(validate_limit(Some(75), 50, 100), 75);
    }
    // Only plc and web DID methods are supported.
    #[test]
    fn test_validate_did() {
        assert!(validate_did("did:plc:abc123").is_ok());
        assert!(validate_did("did:web:example.com").is_ok());
        assert!(validate_did("notadid").is_err());
        assert!(validate_did("did:unknown:test").is_err());
    }
}

433
src/api/read_after_write.rs Normal file
View File

@@ -0,0 +1,433 @@
use crate::api::proxy_client::{
is_ssrf_safe, proxy_client, MAX_RESPONSE_SIZE, RESPONSE_HEADERS_TO_FORWARD,
};
use crate::api::ApiError;
use crate::state::AppState;
use axum::{
http::{HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
Json,
};
use chrono::{DateTime, Utc};
use jacquard_repo::storage::BlockStore;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use tracing::{error, info, warn};
use uuid::Uuid;
pub const REPO_REV_HEADER: &str = "atproto-repo-rev";
pub const UPSTREAM_LAG_HEADER: &str = "atproto-upstream-lag";
/// `app.bsky.feed.post` record as stored in the repo. Unknown fields are
/// preserved through the flattened `extra` map so round-tripping is lossless.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PostRecord {
    #[serde(rename = "$type")]
    pub record_type: Option<String>,
    pub text: String,
    pub created_at: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reply: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embed: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub langs: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub labels: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<Vec<String>>,
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// `app.bsky.actor.profile` record (rkey "self"); unknown fields preserved.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ProfileRecord {
    #[serde(rename = "$type")]
    pub record_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub avatar: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub banner: Option<Value>,
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// A decoded record plus the addressing metadata (uri/cid) and the local
/// write time used for lag calculation and feed ordering.
#[derive(Debug, Clone)]
pub struct RecordDescript<T> {
    pub uri: String,
    pub cid: String,
    pub indexed_at: DateTime<Utc>,
    pub record: T,
}
/// `app.bsky.feed.like` record; unknown fields preserved via `extra`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LikeRecord {
    #[serde(rename = "$type")]
    pub record_type: Option<String>,
    pub subject: LikeSubject,
    pub created_at: String,
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// Strong reference (uri + cid) to the post a like points at.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LikeSubject {
    pub uri: String,
    pub cid: String,
}
/// Local records written after the upstream's reported repo rev, bucketed by
/// collection. `count` counts every newer record, including collections that
/// are not decoded into one of the buckets below.
#[derive(Debug, Default)]
pub struct LocalRecords {
    pub count: usize,
    pub profile: Option<RecordDescript<ProfileRecord>>,
    pub posts: Vec<RecordDescript<PostRecord>>,
    pub likes: Vec<RecordDescript<LikeRecord>>,
}
/// Load the user's records written after repo rev `rev`, decoding posts,
/// likes, and the profile record from the block store.
///
/// Records whose CID fails to parse, whose block is missing, or whose CBOR
/// does not decode are skipped (but still counted in `result.count`).
///
/// # Errors
/// Returns a human-readable message on database failures or unknown DID.
pub async fn get_records_since_rev(
    state: &AppState,
    did: &str,
    rev: &str,
) -> Result<LocalRecords, String> {
    let mut result = LocalRecords::default();
    let user_id: Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
        .fetch_optional(&state.db)
        .await
        .map_err(|e| format!("DB error: {}", e))?
        .ok_or_else(|| "User not found".to_string())?;
    // repo_rev is compared as text; assumes revs are TID strings whose
    // lexicographic order matches chronological order — TODO confirm.
    // LIMIT 10 caps how many local records are spliced into one response.
    let rows = sqlx::query!(
        r#"
        SELECT record_cid, collection, rkey, created_at, repo_rev
        FROM records
        WHERE repo_id = $1 AND repo_rev > $2
        ORDER BY repo_rev ASC
        LIMIT 10
        "#,
        user_id,
        rev
    )
    .fetch_all(&state.db)
    .await
    .map_err(|e| format!("DB error fetching records: {}", e))?;
    if rows.is_empty() {
        return Ok(result);
    }
    // Sanity check: if NO record exists at or before `rev`, the upstream rev
    // doesn't correspond to anything we know about, so don't splice.
    let sanity_check = sqlx::query_scalar!(
        "SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1",
        user_id,
        rev
    )
    .fetch_optional(&state.db)
    .await
    .map_err(|e| format!("DB error sanity check: {}", e))?;
    if sanity_check.is_none() {
        warn!("Sanity check failed: no records found before rev {}", rev);
        return Ok(result);
    }
    for row in rows {
        result.count += 1;
        let cid: cid::Cid = match row.record_cid.parse() {
            Ok(c) => c,
            Err(_) => continue,
        };
        let block_bytes = match state.block_store.get(&cid).await {
            Ok(Some(b)) => b,
            _ => continue,
        };
        let uri = format!("at://{}/{}/{}", did, row.collection, row.rkey);
        // Local write time stands in for the AppView's indexedAt.
        let indexed_at = row.created_at;
        if row.collection == "app.bsky.actor.profile" && row.rkey == "self" {
            if let Ok(record) = serde_ipld_dagcbor::from_slice::<ProfileRecord>(&block_bytes) {
                result.profile = Some(RecordDescript {
                    uri,
                    cid: row.record_cid,
                    indexed_at,
                    record,
                });
            }
        } else if row.collection == "app.bsky.feed.post" {
            if let Ok(record) = serde_ipld_dagcbor::from_slice::<PostRecord>(&block_bytes) {
                result.posts.push(RecordDescript {
                    uri,
                    cid: row.record_cid,
                    indexed_at,
                    record,
                });
            }
        } else if row.collection == "app.bsky.feed.like" {
            if let Ok(record) = serde_ipld_dagcbor::from_slice::<LikeRecord>(&block_bytes) {
                result.likes.push(RecordDescript {
                    uri,
                    cid: row.record_cid,
                    indexed_at,
                    record,
                });
            }
        }
    }
    Ok(result)
}
/// Milliseconds between now and the oldest local record not yet reflected
/// upstream; `None` when there are no local records at all.
pub fn get_local_lag(local: &LocalRecords) -> Option<i64> {
    let profile_time = local.profile.as_ref().map(|p| p.indexed_at);
    let post_times = local.posts.iter().map(|p| p.indexed_at);
    let like_times = local.likes.iter().map(|l| l.indexed_at);
    profile_time
        .into_iter()
        .chain(post_times)
        .chain(like_times)
        .min()
        .map(|oldest| (Utc::now() - oldest).num_milliseconds())
}
/// Pull the `atproto-repo-rev` header out of an upstream response, if it is
/// present and valid UTF-8.
pub fn extract_repo_rev(headers: &HeaderMap) -> Option<String> {
    let value = headers.get(REPO_REV_HEADER)?;
    let rev = value.to_str().ok()?;
    Some(rev.to_owned())
}
/// Buffered response from the upstream AppView: status, the forwardable
/// subset of headers, and the full body bytes.
#[derive(Debug)]
pub struct ProxyResponse {
    pub status: StatusCode,
    // Only headers in RESPONSE_HEADERS_TO_FORWARD are retained.
    pub headers: HeaderMap,
    pub body: bytes::Bytes,
}
/// GET an XRPC `method` on the configured upstream AppView, returning the
/// buffered response for further munging.
///
/// # Errors
/// Returns a ready-to-send error `Response` when APPVIEW_URL is unset or
/// unsafe, the request fails, or the body exceeds MAX_RESPONSE_SIZE.
pub async fn proxy_to_appview(
    method: &str,
    params: &HashMap<String, String>,
    auth_header: Option<&str>,
) -> Result<ProxyResponse, Response> {
    let appview_url = std::env::var("APPVIEW_URL").map_err(|_| {
        ApiError::UpstreamUnavailable("No upstream AppView configured".to_string()).into_response()
    })?;
    if let Err(e) = is_ssrf_safe(&appview_url) {
        error!("SSRF check failed for appview URL: {}", e);
        return Err(ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e))
            .into_response());
    }
    let target_url = format!("{}/xrpc/{}", appview_url, method);
    info!(target = %target_url, "Proxying request to appview");
    let client = proxy_client();
    let mut request_builder = client.get(&target_url).query(params);
    if let Some(auth) = auth_header {
        request_builder = request_builder.header("Authorization", auth);
    }
    match request_builder.send().await {
        Ok(resp) => {
            let status =
                StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
            // Keep only the allow-listed headers, converted into axum types.
            let headers: HeaderMap = resp
                .headers()
                .iter()
                .filter(|(k, _)| {
                    RESPONSE_HEADERS_TO_FORWARD
                        .iter()
                        .any(|h| k.as_str().eq_ignore_ascii_case(h))
                })
                .filter_map(|(k, v)| {
                    let name = axum::http::HeaderName::try_from(k.as_str()).ok()?;
                    let value = HeaderValue::from_bytes(v.as_bytes()).ok()?;
                    Some((name, value))
                })
                .collect();
            // First size check uses the declared Content-Length (0 when the
            // upstream streams without one)...
            let content_length = resp
                .content_length()
                .unwrap_or(0);
            if content_length > MAX_RESPONSE_SIZE {
                error!(
                    content_length,
                    max = MAX_RESPONSE_SIZE,
                    "Upstream response too large"
                );
                return Err(ApiError::UpstreamFailure.into_response());
            }
            let body = resp.bytes().await.map_err(|e| {
                error!(error = ?e, "Error reading proxy response body");
                ApiError::UpstreamFailure.into_response()
            })?;
            // ...second check re-validates the actual bytes read, covering
            // responses that lied about (or omitted) Content-Length.
            if body.len() as u64 > MAX_RESPONSE_SIZE {
                error!(
                    len = body.len(),
                    max = MAX_RESPONSE_SIZE,
                    "Upstream response body exceeded size limit"
                );
                return Err(ApiError::UpstreamFailure.into_response());
            }
            Ok(ProxyResponse {
                status,
                headers,
                body,
            })
        }
        Err(e) => {
            error!(error = ?e, "Error sending proxy request");
            if e.is_timeout() {
                Err(ApiError::UpstreamTimeout.into_response())
            } else if e.is_connect() {
                Err(ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
                    .into_response())
            } else {
                Err(ApiError::UpstreamFailure.into_response())
            }
        }
    }
}
/// Serialize the munged payload as a 200 response, attaching the
/// `atproto-upstream-lag` header when a lag measurement is available.
pub fn format_munged_response<T: Serialize>(data: T, lag: Option<i64>) -> Response {
    let mut response = (StatusCode::OK, Json(data)).into_response();
    // An unrepresentable header value is silently dropped rather than
    // failing the response.
    let lag_header = lag.and_then(|ms| HeaderValue::from_str(&ms.to_string()).ok());
    if let Some(value) = lag_header {
        response.headers_mut().insert(UPSTREAM_LAG_HEADER, value);
    }
    response
}
/// Minimal `app.bsky.actor.defs#profileViewBasic`; unknown fields pass
/// through `extra` so upstream data is not dropped on re-serialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AuthorView {
    pub did: String,
    pub handle: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub avatar: Option<String>,
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// `app.bsky.feed.defs#postView`. Counts default to 0 when the upstream
/// omits them; locally-synthesized posts always carry zero counts.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PostView {
    pub uri: String,
    pub cid: String,
    pub author: AuthorView,
    pub record: Value,
    pub indexed_at: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embed: Option<Value>,
    #[serde(default)]
    pub reply_count: i64,
    #[serde(default)]
    pub repost_count: i64,
    #[serde(default)]
    pub like_count: i64,
    #[serde(default)]
    pub quote_count: i64,
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// `app.bsky.feed.defs#feedViewPost`: a post plus optional reply/repost
/// context; unknown fields preserved via `extra`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FeedViewPost {
    pub post: PostView,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reply: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub feed_context: Option<String>,
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// Paged feed payload shared by the feed endpoints (getTimeline, getFeed, …).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedOutput {
    pub feed: Vec<FeedViewPost>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cursor: Option<String>,
}
/// Build a `PostView` for a locally-written post that the upstream AppView
/// has not indexed yet. Engagement counts are necessarily zero and only the
/// display name can be borrowed from a pending local profile record.
pub fn format_local_post(
    descript: &RecordDescript<PostRecord>,
    author_did: &str,
    author_handle: &str,
    profile: Option<&RecordDescript<ProfileRecord>>,
) -> PostView {
    let author = AuthorView {
        did: author_did.to_string(),
        handle: author_handle.to_string(),
        display_name: profile.and_then(|p| p.record.display_name.clone()),
        avatar: None,
        extra: HashMap::new(),
    };
    // A record that fails JSON conversion degrades to null rather than
    // dropping the post.
    let record_json = serde_json::to_value(&descript.record).unwrap_or(Value::Null);
    PostView {
        uri: descript.uri.clone(),
        cid: descript.cid.clone(),
        author,
        record: record_json,
        indexed_at: descript.indexed_at.to_rfc3339(),
        embed: descript.record.embed.clone(),
        reply_count: 0,
        repost_count: 0,
        like_count: 0,
        quote_count: 0,
        extra: HashMap::new(),
    }
}
/// Merge locally-synthesized posts into an upstream feed, then re-sort the
/// whole feed newest-first by `indexedAt` (RFC 3339 strings compare
/// chronologically as text).
pub fn insert_posts_into_feed(feed: &mut Vec<FeedViewPost>, posts: Vec<PostView>) {
    if posts.is_empty() {
        return;
    }
    // Local posts have no reply/repost context of their own.
    feed.extend(posts.into_iter().map(|post| FeedViewPost {
        post,
        reply: None,
        reason: None,
        feed_context: None,
        extra: HashMap::new(),
    }));
    feed.sort_by(|a, b| b.post.indexed_at.cmp(&a.post.indexed_at));
}

View File

@@ -25,7 +25,7 @@ pub async fn commit_and_log(
current_root_cid: Option<Cid>,
new_mst_root: Cid,
ops: Vec<RecordOp>,
blocks_cids: &Vec<String>,
blocks_cids: &[String],
) -> Result<CommitResult, String> {
let key_row = sqlx::query!(
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
@@ -63,16 +63,18 @@ pub async fn commit_and_log(
.await
.map_err(|e| format!("DB Error (repos): {}", e))?;
let rev_str = rev.to_string();
for op in &ops {
match op {
RecordOp::Create { collection, rkey, cid } | RecordOp::Update { collection, rkey, cid } => {
sqlx::query!(
"INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4)
ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()",
"INSERT INTO records (repo_id, collection, rkey, record_cid, repo_rev) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, repo_rev = $5, created_at = NOW()",
user_id,
collection,
rkey,
cid.to_string()
cid.to_string(),
rev_str
)
.execute(&mut *tx)
.await

View File

@@ -296,11 +296,30 @@ pub fn app(state: AppState) -> Router {
"/xrpc/app.bsky.actor.getProfiles",
get(api::actor::get_profiles),
)
// I know I know, I'm not supposed to implement appview endpoints. Leave me be
.route(
"/xrpc/app.bsky.feed.getTimeline",
get(api::feed::get_timeline),
)
.route(
"/xrpc/app.bsky.feed.getAuthorFeed",
get(api::feed::get_author_feed),
)
.route(
"/xrpc/app.bsky.feed.getActorLikes",
get(api::feed::get_actor_likes),
)
.route(
"/xrpc/app.bsky.feed.getPostThread",
get(api::feed::get_post_thread),
)
.route(
"/xrpc/app.bsky.feed.getFeed",
get(api::feed::get_feed),
)
.route(
"/xrpc/app.bsky.notification.registerPush",
post(api::notification::register_push),
)
.route("/.well-known/did.json", get(api::identity::well_known_did))
.route("/u/{handle}/did.json", get(api::identity::user_did_doc))
// OAuth 2.1 endpoints

View File

@@ -0,0 +1,149 @@
mod common;
use common::{base_url, client, create_account_and_login};
use reqwest::StatusCode;
use serde_json::{json, Value};
// getAuthorFeed is proxied through to the (mock) AppView; expects the single
// fixture post the mock returns for this endpoint.
#[tokio::test]
async fn test_get_author_feed_returns_appview_data() {
    let client = client();
    let base = base_url().await;
    let (jwt, did) = create_account_and_login(&client).await;
    let res = client
        .get(format!(
            "{}/xrpc/app.bsky.feed.getAuthorFeed?actor={}",
            base, did
        ))
        .header("Authorization", format!("Bearer {}", jwt))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    assert!(body["feed"].is_array(), "Response should have feed array");
    let feed = body["feed"].as_array().unwrap();
    assert_eq!(feed.len(), 1, "Feed should have 1 post from appview");
    assert_eq!(
        feed[0]["post"]["record"]["text"].as_str(),
        Some("Author feed post from appview"),
        "Post text should match appview response"
    );
}
// getActorLikes passthrough; expects the mock's single liked-post fixture.
#[tokio::test]
async fn test_get_actor_likes_returns_appview_data() {
    let client = client();
    let base = base_url().await;
    let (jwt, did) = create_account_and_login(&client).await;
    let res = client
        .get(format!(
            "{}/xrpc/app.bsky.feed.getActorLikes?actor={}",
            base, did
        ))
        .header("Authorization", format!("Bearer {}", jwt))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    assert!(body["feed"].is_array(), "Response should have feed array");
    let feed = body["feed"].as_array().unwrap();
    assert_eq!(feed.len(), 1, "Feed should have 1 liked post from appview");
    assert_eq!(
        feed[0]["post"]["record"]["text"].as_str(),
        Some("Liked post from appview"),
        "Post text should match appview response"
    );
}
// getPostThread passthrough; checks both the thread $type and fixture text.
#[tokio::test]
async fn test_get_post_thread_returns_appview_data() {
    let client = client();
    let base = base_url().await;
    let (jwt, did) = create_account_and_login(&client).await;
    let res = client
        .get(format!(
            "{}/xrpc/app.bsky.feed.getPostThread?uri=at://{}/app.bsky.feed.post/test123",
            base, did
        ))
        .header("Authorization", format!("Bearer {}", jwt))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    assert!(body["thread"].is_object(), "Response should have thread object");
    assert_eq!(
        body["thread"]["$type"].as_str(),
        Some("app.bsky.feed.defs#threadViewPost"),
        "Thread should be a threadViewPost"
    );
    assert_eq!(
        body["thread"]["post"]["record"]["text"].as_str(),
        Some("Thread post from appview"),
        "Post text should match appview response"
    );
}
// getFeed (custom feed generator) passthrough with a fixture feed URI.
#[tokio::test]
async fn test_get_feed_returns_appview_data() {
    let client = client();
    let base = base_url().await;
    let (jwt, _did) = create_account_and_login(&client).await;
    let res = client
        .get(format!(
            "{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test",
            base
        ))
        .header("Authorization", format!("Bearer {}", jwt))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    assert!(body["feed"].is_array(), "Response should have feed array");
    let feed = body["feed"].as_array().unwrap();
    assert_eq!(feed.len(), 1, "Feed should have 1 post from appview");
    assert_eq!(
        feed[0]["post"]["record"]["text"].as_str(),
        Some("Custom feed post from appview"),
        "Post text should match appview response"
    );
}
// registerPush should be re-signed and forwarded; success is an empty 200.
#[tokio::test]
async fn test_register_push_proxies_to_appview() {
    let client = client();
    let base = base_url().await;
    let (jwt, _did) = create_account_and_login(&client).await;
    let res = client
        .post(format!(
            "{}/xrpc/app.bsky.notification.registerPush",
            base
        ))
        .header("Authorization", format!("Bearer {}", jwt))
        .json(&json!({
            "serviceDid": "did:web:example.com",
            "token": "test-push-token",
            "platform": "ios",
            "appId": "xyz.bsky.app"
        }))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
}

View File

@@ -233,6 +233,122 @@ async fn setup_mock_appview(mock_server: &MockServer) {
})))
.mount(mock_server)
.await;
// Mock app.bsky.feed.getTimeline: empty feed with a null cursor.
// The `atproto-repo-rev` header is set on the response — presumably
// consumed by the PDS's read-after-write freshness check; TODO confirm
// against the proxy handler.
Mock::given(method("GET"))
.and(path("/xrpc/app.bsky.feed.getTimeline"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("atproto-repo-rev", "0")
.set_body_json(json!({
"feed": [],
"cursor": null
}))
)
.mount(mock_server)
.await;
// Mock app.bsky.feed.getAuthorFeed: one canned post plus a pagination
// cursor, so tests can assert the proxied author-feed payload.
Mock::given(method("GET"))
.and(path("/xrpc/app.bsky.feed.getAuthorFeed"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("atproto-repo-rev", "0")
.set_body_json(json!({
"feed": [{
"post": {
"uri": "at://did:plc:mock-author/app.bsky.feed.post/from-appview-author",
"cid": "bafyappview123",
"author": {"did": "did:plc:mock-author", "handle": "mock.author"},
"record": {
"$type": "app.bsky.feed.post",
"text": "Author feed post from appview",
"createdAt": "2025-01-01T00:00:00Z"
},
"indexedAt": "2025-01-01T00:00:00Z"
}
}],
"cursor": "author-cursor"
})),
)
.mount(mock_server)
.await;
// Mock app.bsky.feed.getActorLikes: a single liked post, no cursor.
Mock::given(method("GET"))
.and(path("/xrpc/app.bsky.feed.getActorLikes"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("atproto-repo-rev", "0")
.set_body_json(json!({
"feed": [{
"post": {
"uri": "at://did:plc:mock-likes/app.bsky.feed.post/liked-post",
"cid": "bafyliked123",
"author": {"did": "did:plc:mock-likes", "handle": "mock.likes"},
"record": {
"$type": "app.bsky.feed.post",
"text": "Liked post from appview",
"createdAt": "2025-01-01T00:00:00Z"
},
"indexedAt": "2025-01-01T00:00:00Z"
}
}],
"cursor": null
})),
)
.mount(mock_server)
.await;
// Mock app.bsky.feed.getPostThread: a single threadViewPost with no
// replies; the `text` here is asserted verbatim by the thread test.
Mock::given(method("GET"))
.and(path("/xrpc/app.bsky.feed.getPostThread"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("atproto-repo-rev", "0")
.set_body_json(json!({
"thread": {
"$type": "app.bsky.feed.defs#threadViewPost",
"post": {
"uri": "at://did:plc:mock/app.bsky.feed.post/thread-post",
"cid": "bafythread123",
"author": {"did": "did:plc:mock", "handle": "mock.handle"},
"record": {
"$type": "app.bsky.feed.post",
"text": "Thread post from appview",
"createdAt": "2025-01-01T00:00:00Z"
},
"indexedAt": "2025-01-01T00:00:00Z"
},
"replies": []
}
})),
)
.mount(mock_server)
.await;
// Mock app.bsky.feed.getFeed: one custom-feed post.
// NOTE(review): unlike the sibling feed mocks above, this response omits
// the `atproto-repo-rev` header — confirm whether getFeed proxying is
// meant to skip the rev check or whether the header was forgotten here.
Mock::given(method("GET"))
.and(path("/xrpc/app.bsky.feed.getFeed"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"feed": [{
"post": {
"uri": "at://did:plc:mock-feed/app.bsky.feed.post/custom-feed-post",
"cid": "bafyfeed123",
"author": {"did": "did:plc:mock-feed", "handle": "mock.feed"},
"record": {
"$type": "app.bsky.feed.post",
"text": "Custom feed post from appview",
"createdAt": "2025-01-01T00:00:00Z"
},
"indexedAt": "2025-01-01T00:00:00Z"
}
}],
"cursor": null
})))
.mount(mock_server)
.await;
// Mock app.bsky.notification.registerPush: bare 200 with no body, since
// the proxy test only checks the forwarded status code.
Mock::given(method("POST"))
.and(path("/xrpc/app.bsky.notification.registerPush"))
.respond_with(ResponseTemplate::new(200))
.mount(mock_server)
.await;
}
async fn spawn_app(database_url: String) -> String {

122
tests/feed.rs Normal file
View File

@@ -0,0 +1,122 @@
mod common;
use common::{base_url, client, create_account_and_login};
use serde_json::json;
/// An unauthenticated getTimeline request must be rejected with 401.
#[tokio::test]
async fn test_get_timeline_requires_auth() {
    let http = client();
    let base = base_url().await;
    let url = format!("{base}/xrpc/app.bsky.feed.getTimeline");
    // Deliberately omit the Authorization header.
    let response = http.get(url).send().await.unwrap();
    assert_eq!(response.status().as_u16(), 401);
}
/// getAuthorFeed without the required `actor` query parameter returns 400,
/// even for an authenticated caller.
#[tokio::test]
async fn test_get_author_feed_requires_actor() {
    let http = client();
    let base = base_url().await;
    let (jwt, _did) = create_account_and_login(&http).await;
    let response = http
        .get(format!("{base}/xrpc/app.bsky.feed.getAuthorFeed"))
        .header("Authorization", format!("Bearer {jwt}"))
        .send()
        .await
        .unwrap();
    assert_eq!(response.status().as_u16(), 400);
}
/// getActorLikes without the required `actor` query parameter returns 400,
/// even for an authenticated caller.
#[tokio::test]
async fn test_get_actor_likes_requires_actor() {
    let http = client();
    let base = base_url().await;
    let (jwt, _did) = create_account_and_login(&http).await;
    let response = http
        .get(format!("{base}/xrpc/app.bsky.feed.getActorLikes"))
        .header("Authorization", format!("Bearer {jwt}"))
        .send()
        .await
        .unwrap();
    assert_eq!(response.status().as_u16(), 400);
}
/// getPostThread without the required `uri` query parameter returns 400,
/// even for an authenticated caller.
#[tokio::test]
async fn test_get_post_thread_requires_uri() {
    let http = client();
    let base = base_url().await;
    let (jwt, _did) = create_account_and_login(&http).await;
    let response = http
        .get(format!("{base}/xrpc/app.bsky.feed.getPostThread"))
        .header("Authorization", format!("Bearer {jwt}"))
        .send()
        .await
        .unwrap();
    assert_eq!(response.status().as_u16(), 400);
}
/// getFeed without credentials is rejected with 401 even when a valid
/// `feed` parameter is supplied.
#[tokio::test]
async fn test_get_feed_requires_auth() {
    let http = client();
    let base = base_url().await;
    let url = format!(
        "{base}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test"
    );
    // No Authorization header on purpose.
    let response = http.get(url).send().await.unwrap();
    assert_eq!(response.status().as_u16(), 401);
}
/// getFeed without the required `feed` query parameter returns 400,
/// even for an authenticated caller.
#[tokio::test]
async fn test_get_feed_requires_feed_param() {
    let http = client();
    let base = base_url().await;
    let (jwt, _did) = create_account_and_login(&http).await;
    let response = http
        .get(format!("{base}/xrpc/app.bsky.feed.getFeed"))
        .header("Authorization", format!("Bearer {jwt}"))
        .send()
        .await
        .unwrap();
    assert_eq!(response.status().as_u16(), 400);
}
/// registerPush must reject unauthenticated callers with 401 even when
/// the request body is well-formed.
#[tokio::test]
async fn test_register_push_requires_auth() {
    let http = client();
    let base = base_url().await;
    let payload = json!({
        "serviceDid": "did:web:example.com",
        "token": "test-token",
        "platform": "ios",
        "appId": "xyz.bsky.app"
    });
    // No Authorization header on purpose.
    let response = http
        .post(format!("{base}/xrpc/app.bsky.notification.registerPush"))
        .json(&payload)
        .send()
        .await
        .unwrap();
    assert_eq!(response.status().as_u16(), 401);
}

View File

@@ -23,7 +23,7 @@ fn write_varint(buf: &mut Vec<u8>, mut value: u64) {
async fn test_import_rejects_car_for_different_user() {
let client = client();
let (token_a, did_a) = create_account_and_login(&client).await;
let (token_a, _did_a) = create_account_and_login(&client).await;
let (_token_b, did_b) = create_account_and_login(&client).await;
let export_res = client