mirror of
https://tangled.org/tranquil.farm/tranquil-pds
synced 2026-02-09 05:40:09 +00:00
Half-ass attempt at the local-first appview endpoints like ref impl
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT k.key_bytes, k.encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1",
|
||||
"query": "SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -24,5 +24,5 @@
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "7bb1388dec372fe749462cd9b604e5802b770aeb110462208988141d31c86c92"
|
||||
"hash": "36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b"
|
||||
}
|
||||
23
.sqlx/query-5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288.json
generated
Normal file
23
.sqlx/query-5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288.json
generated
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "val",
|
||||
"type_info": "Int4"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
null
|
||||
]
|
||||
},
|
||||
"hash": "5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4)\n ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()",
|
||||
"query": "INSERT INTO records (repo_id, collection, rkey, record_cid, repo_rev) VALUES ($1, $2, $3, $4, $5)\n ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, repo_rev = $5, created_at = NOW()",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
@@ -8,10 +8,11 @@
|
||||
"Uuid",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "c61fc3b2fbdf6891269908ef21f13dcabdc3b032e9f767becae34ca176df18b6"
|
||||
"hash": "8a9e71f04ec779d5c10d79582cc398529e01be01a83898df3524bb35e3d2ed14"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow'",
|
||||
"query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -18,5 +18,5 @@
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "bf60faafb5c79a149ba237a984f78d068b5d691f6762641412a5aa1517605c04"
|
||||
"hash": "a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc"
|
||||
}
|
||||
47
.sqlx/query-f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e.json
generated
Normal file
47
.sqlx/query-f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e.json
generated
Normal file
@@ -0,0 +1,47 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT record_cid, collection, rkey, created_at, repo_rev\n FROM records\n WHERE repo_id = $1 AND repo_rev > $2\n ORDER BY repo_rev ASC\n LIMIT 10\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "record_cid",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "collection",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "rkey",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "repo_rev",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e"
|
||||
}
|
||||
14
TODO.md
14
TODO.md
@@ -168,15 +168,15 @@ These endpoints need to be implemented at the PDS level (not just proxied to app
|
||||
- [x] Implement `app.bsky.actor.getProfiles` (PDS-level with proxy fallback).
|
||||
|
||||
### Feed (`app.bsky.feed`)
|
||||
These are implemented at PDS level to enable local-first reads:
|
||||
- [ ] Implement `app.bsky.feed.getTimeline` (PDS-level with proxy).
|
||||
- [ ] Implement `app.bsky.feed.getAuthorFeed` (PDS-level with proxy).
|
||||
- [ ] Implement `app.bsky.feed.getActorLikes` (PDS-level with proxy).
|
||||
- [ ] Implement `app.bsky.feed.getPostThread` (PDS-level with proxy).
|
||||
- [ ] Implement `app.bsky.feed.getFeed` (PDS-level with proxy).
|
||||
These are implemented at PDS level to enable local-first reads (read-after-write pattern):
|
||||
- [x] Implement `app.bsky.feed.getTimeline` (PDS-level with proxy + RAW).
|
||||
- [x] Implement `app.bsky.feed.getAuthorFeed` (PDS-level with proxy + RAW).
|
||||
- [x] Implement `app.bsky.feed.getActorLikes` (PDS-level with proxy + RAW).
|
||||
- [x] Implement `app.bsky.feed.getPostThread` (PDS-level with proxy + RAW + NotFound handling).
|
||||
- [x] Implement `app.bsky.feed.getFeed` (proxy to feed generator).
|
||||
|
||||
### Notification (`app.bsky.notification`)
|
||||
- [ ] Implement `app.bsky.notification.registerPush` (push notification registration).
|
||||
- [x] Implement `app.bsky.notification.registerPush` (push notification registration, proxied).
|
||||
|
||||
## Deprecated Sync Endpoints (for compatibility)
|
||||
- [ ] Implement `com.atproto.sync.getCheckout` (deprecated, still needed for compatibility).
|
||||
|
||||
2
migrations/202512211600_add_repo_rev.sql
Normal file
2
migrations/202512211600_add_repo_rev.sql
Normal file
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE records ADD COLUMN repo_rev TEXT;
|
||||
CREATE INDEX idx_records_repo_rev ON records(repo_rev);
|
||||
@@ -125,10 +125,18 @@ pub async fn get_profile(
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
|
||||
|
||||
let auth_did = auth_header.and_then(|h| {
|
||||
let token = crate::auth::extract_bearer_token_from_header(Some(h))?;
|
||||
crate::auth::get_did_from_token(&token).ok()
|
||||
});
|
||||
let auth_did = if let Some(h) = auth_header {
|
||||
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
|
||||
match crate::auth::validate_bearer_token(&state.db, &token).await {
|
||||
Ok(user) => Some(user.did),
|
||||
Err(_) => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut query_params = HashMap::new();
|
||||
query_params.insert("actor".to_string(), params.actor.clone());
|
||||
@@ -167,10 +175,18 @@ pub async fn get_profiles(
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
|
||||
|
||||
let auth_did = auth_header.and_then(|h| {
|
||||
let token = crate::auth::extract_bearer_token_from_header(Some(h))?;
|
||||
crate::auth::get_did_from_token(&token).ok()
|
||||
});
|
||||
let auth_did = if let Some(h) = auth_header {
|
||||
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
|
||||
match crate::auth::validate_bearer_token(&state.db, &token).await {
|
||||
Ok(user) => Some(user.did),
|
||||
Err(_) => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut query_params = HashMap::new();
|
||||
query_params.insert("actors".to_string(), params.actors.clone());
|
||||
|
||||
@@ -44,13 +44,19 @@ pub enum ApiError {
|
||||
InvitesDisabled,
|
||||
DatabaseError,
|
||||
UpstreamFailure,
|
||||
UpstreamTimeout,
|
||||
UpstreamUnavailable(String),
|
||||
UpstreamError { status: u16, error: Option<String>, message: Option<String> },
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
Self::InternalError | Self::DatabaseError | Self::UpstreamFailure => {
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
Self::InternalError | Self::DatabaseError => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::UpstreamFailure | Self::UpstreamUnavailable(_) => StatusCode::BAD_GATEWAY,
|
||||
Self::UpstreamTimeout => StatusCode::GATEWAY_TIMEOUT,
|
||||
Self::UpstreamError { status, .. } => {
|
||||
StatusCode::from_u16(*status).unwrap_or(StatusCode::BAD_GATEWAY)
|
||||
}
|
||||
Self::AuthenticationRequired
|
||||
| Self::AuthenticationFailed
|
||||
@@ -83,7 +89,15 @@ impl ApiError {
|
||||
|
||||
fn error_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::InternalError | Self::DatabaseError | Self::UpstreamFailure => "InternalError",
|
||||
Self::InternalError | Self::DatabaseError => "InternalError",
|
||||
Self::UpstreamFailure | Self::UpstreamUnavailable(_) => "UpstreamFailure",
|
||||
Self::UpstreamTimeout => "UpstreamTimeout",
|
||||
Self::UpstreamError { error, .. } => {
|
||||
if let Some(e) = error {
|
||||
return Box::leak(e.clone().into_boxed_str());
|
||||
}
|
||||
"UpstreamError"
|
||||
}
|
||||
Self::AuthenticationRequired => "AuthenticationRequired",
|
||||
Self::AuthenticationFailed | Self::AuthenticationFailedMsg(_) => "AuthenticationFailed",
|
||||
Self::InvalidToken => "InvalidToken",
|
||||
@@ -116,10 +130,25 @@ impl ApiError {
|
||||
Self::AuthenticationFailedMsg(msg)
|
||||
| Self::ExpiredTokenMsg(msg)
|
||||
| Self::InvalidRequest(msg)
|
||||
| Self::RepoNotFoundMsg(msg) => Some(msg.clone()),
|
||||
| Self::RepoNotFoundMsg(msg)
|
||||
| Self::UpstreamUnavailable(msg) => Some(msg.clone()),
|
||||
Self::UpstreamError { message, .. } => message.clone(),
|
||||
Self::UpstreamTimeout => Some("Upstream service timed out".to_string()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_upstream_response(
|
||||
status: u16,
|
||||
body: &[u8],
|
||||
) -> Self {
|
||||
if let Ok(parsed) = serde_json::from_slice::<serde_json::Value>(body) {
|
||||
let error = parsed.get("error").and_then(|v| v.as_str()).map(String::from);
|
||||
let message = parsed.get("message").and_then(|v| v.as_str()).map(String::from);
|
||||
return Self::UpstreamError { status, error, message };
|
||||
}
|
||||
Self::UpstreamError { status, error: None, message: None }
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
|
||||
158
src/api/feed/actor_likes.rs
Normal file
158
src/api/feed/actor_likes.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use crate::api::read_after_write::{
|
||||
extract_repo_rev, format_munged_response, get_local_lag, get_records_since_rev,
|
||||
proxy_to_appview, FeedOutput, FeedViewPost, LikeRecord, PostView, RecordDescript,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetActorLikesParams {
|
||||
pub actor: String,
|
||||
pub limit: Option<u32>,
|
||||
pub cursor: Option<String>,
|
||||
}
|
||||
|
||||
fn insert_likes_into_feed(feed: &mut Vec<FeedViewPost>, likes: &[RecordDescript<LikeRecord>]) {
|
||||
for like in likes {
|
||||
let like_time = &like.indexed_at.to_rfc3339();
|
||||
let idx = feed
|
||||
.iter()
|
||||
.position(|fi| &fi.post.indexed_at < like_time)
|
||||
.unwrap_or(feed.len());
|
||||
|
||||
let placeholder_post = PostView {
|
||||
uri: like.record.subject.uri.clone(),
|
||||
cid: like.record.subject.cid.clone(),
|
||||
author: crate::api::read_after_write::AuthorView {
|
||||
did: String::new(),
|
||||
handle: String::new(),
|
||||
display_name: None,
|
||||
avatar: None,
|
||||
extra: HashMap::new(),
|
||||
},
|
||||
record: Value::Null,
|
||||
indexed_at: like.indexed_at.to_rfc3339(),
|
||||
embed: None,
|
||||
reply_count: 0,
|
||||
repost_count: 0,
|
||||
like_count: 0,
|
||||
quote_count: 0,
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
|
||||
feed.insert(
|
||||
idx,
|
||||
FeedViewPost {
|
||||
post: placeholder_post,
|
||||
reply: None,
|
||||
reason: None,
|
||||
feed_context: None,
|
||||
extra: HashMap::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_actor_likes(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(params): Query<GetActorLikesParams>,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
|
||||
|
||||
let auth_did = if let Some(h) = auth_header {
|
||||
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
|
||||
match crate::auth::validate_bearer_token(&state.db, &token).await {
|
||||
Ok(user) => Some(user.did),
|
||||
Err(_) => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut query_params = HashMap::new();
|
||||
query_params.insert("actor".to_string(), params.actor.clone());
|
||||
if let Some(limit) = params.limit {
|
||||
query_params.insert("limit".to_string(), limit.to_string());
|
||||
}
|
||||
if let Some(cursor) = ¶ms.cursor {
|
||||
query_params.insert("cursor".to_string(), cursor.clone());
|
||||
}
|
||||
|
||||
let proxy_result =
|
||||
match proxy_to_appview("app.bsky.feed.getActorLikes", &query_params, auth_header).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
if !proxy_result.status.is_success() {
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
|
||||
let rev = match extract_repo_rev(&proxy_result.headers) {
|
||||
Some(r) => r,
|
||||
None => return (proxy_result.status, proxy_result.body).into_response(),
|
||||
};
|
||||
|
||||
let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse actor likes response: {:?}", e);
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let requester_did = match auth_did {
|
||||
Some(d) => d,
|
||||
None => return (StatusCode::OK, Json(feed_output)).into_response(),
|
||||
};
|
||||
|
||||
let actor_did = if params.actor.starts_with("did:") {
|
||||
params.actor.clone()
|
||||
} else {
|
||||
match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", params.actor)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(did)) => did,
|
||||
Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(),
|
||||
Err(e) => {
|
||||
warn!("Database error resolving actor handle: {:?}", e);
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if actor_did != requester_did {
|
||||
return (StatusCode::OK, Json(feed_output)).into_response();
|
||||
}
|
||||
|
||||
let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Failed to get local records: {}", e);
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if local_records.likes.is_empty() {
|
||||
return (StatusCode::OK, Json(feed_output)).into_response();
|
||||
}
|
||||
|
||||
insert_likes_into_feed(&mut feed_output.feed, &local_records.likes);
|
||||
|
||||
let lag = get_local_lag(&local_records);
|
||||
format_munged_response(feed_output, lag)
|
||||
}
|
||||
169
src/api/feed/author_feed.rs
Normal file
169
src/api/feed/author_feed.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use crate::api::read_after_write::{
|
||||
extract_repo_rev, format_local_post, format_munged_response, get_local_lag,
|
||||
get_records_since_rev, insert_posts_into_feed, proxy_to_appview, FeedOutput, FeedViewPost,
|
||||
ProfileRecord, RecordDescript,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetAuthorFeedParams {
|
||||
pub actor: String,
|
||||
pub limit: Option<u32>,
|
||||
pub cursor: Option<String>,
|
||||
pub filter: Option<String>,
|
||||
#[serde(rename = "includePins")]
|
||||
pub include_pins: Option<bool>,
|
||||
}
|
||||
|
||||
fn update_author_profile_in_feed(
|
||||
feed: &mut [FeedViewPost],
|
||||
author_did: &str,
|
||||
local_profile: &RecordDescript<ProfileRecord>,
|
||||
) {
|
||||
for item in feed.iter_mut() {
|
||||
if item.post.author.did == author_did {
|
||||
if let Some(ref display_name) = local_profile.record.display_name {
|
||||
item.post.author.display_name = Some(display_name.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_author_feed(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(params): Query<GetAuthorFeedParams>,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
|
||||
|
||||
let auth_did = if let Some(h) = auth_header {
|
||||
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
|
||||
match crate::auth::validate_bearer_token(&state.db, &token).await {
|
||||
Ok(user) => Some(user.did),
|
||||
Err(_) => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut query_params = HashMap::new();
|
||||
query_params.insert("actor".to_string(), params.actor.clone());
|
||||
if let Some(limit) = params.limit {
|
||||
query_params.insert("limit".to_string(), limit.to_string());
|
||||
}
|
||||
if let Some(cursor) = ¶ms.cursor {
|
||||
query_params.insert("cursor".to_string(), cursor.clone());
|
||||
}
|
||||
if let Some(filter) = ¶ms.filter {
|
||||
query_params.insert("filter".to_string(), filter.clone());
|
||||
}
|
||||
if let Some(include_pins) = params.include_pins {
|
||||
query_params.insert("includePins".to_string(), include_pins.to_string());
|
||||
}
|
||||
|
||||
let proxy_result =
|
||||
match proxy_to_appview("app.bsky.feed.getAuthorFeed", &query_params, auth_header).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
if !proxy_result.status.is_success() {
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
|
||||
let rev = match extract_repo_rev(&proxy_result.headers) {
|
||||
Some(r) => r,
|
||||
None => return (proxy_result.status, proxy_result.body).into_response(),
|
||||
};
|
||||
|
||||
let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse author feed response: {:?}", e);
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let requester_did = match auth_did {
|
||||
Some(d) => d,
|
||||
None => return (StatusCode::OK, Json(feed_output)).into_response(),
|
||||
};
|
||||
|
||||
let actor_did = if params.actor.starts_with("did:") {
|
||||
params.actor.clone()
|
||||
} else {
|
||||
match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", params.actor)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(did)) => did,
|
||||
Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(),
|
||||
Err(e) => {
|
||||
warn!("Database error resolving actor handle: {:?}", e);
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if actor_did != requester_did {
|
||||
return (StatusCode::OK, Json(feed_output)).into_response();
|
||||
}
|
||||
|
||||
let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Failed to get local records: {}", e);
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if local_records.count == 0 {
|
||||
return (StatusCode::OK, Json(feed_output)).into_response();
|
||||
}
|
||||
|
||||
let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(h)) => h,
|
||||
Ok(None) => requester_did.clone(),
|
||||
Err(e) => {
|
||||
warn!("Database error fetching handle: {:?}", e);
|
||||
requester_did.clone()
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(ref local_profile) = local_records.profile {
|
||||
update_author_profile_in_feed(&mut feed_output.feed, &requester_did, local_profile);
|
||||
}
|
||||
|
||||
let local_posts: Vec<_> = local_records
|
||||
.posts
|
||||
.iter()
|
||||
.map(|p| {
|
||||
format_local_post(
|
||||
p,
|
||||
&requester_did,
|
||||
&handle,
|
||||
local_records.profile.as_ref(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
insert_posts_into_feed(&mut feed_output.feed, local_posts);
|
||||
|
||||
let lag = get_local_lag(&local_records);
|
||||
format_munged_response(feed_output, lag)
|
||||
}
|
||||
131
src/api/feed/custom_feed.rs
Normal file
131
src/api/feed/custom_feed.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
use crate::api::proxy_client::{
|
||||
is_ssrf_safe, proxy_client, validate_at_uri, validate_limit, MAX_RESPONSE_SIZE,
|
||||
};
|
||||
use crate::api::ApiError;
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use tracing::{error, info};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetFeedParams {
|
||||
pub feed: String,
|
||||
pub limit: Option<u32>,
|
||||
pub cursor: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn get_feed(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(params): Query<GetFeedParams>,
|
||||
) -> Response {
|
||||
let token = match crate::auth::extract_bearer_token_from_header(
|
||||
headers.get("Authorization").and_then(|h| h.to_str().ok()),
|
||||
) {
|
||||
Some(t) => t,
|
||||
None => return ApiError::AuthenticationRequired.into_response(),
|
||||
};
|
||||
|
||||
if let Err(e) = crate::auth::validate_bearer_token(&state.db, &token).await {
|
||||
return ApiError::from(e).into_response();
|
||||
};
|
||||
|
||||
if let Err(e) = validate_at_uri(¶ms.feed) {
|
||||
return ApiError::InvalidRequest(format!("Invalid feed URI: {}", e)).into_response();
|
||||
}
|
||||
|
||||
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
|
||||
|
||||
let appview_url = match std::env::var("APPVIEW_URL") {
|
||||
Ok(url) => url,
|
||||
Err(_) => {
|
||||
return ApiError::UpstreamUnavailable("No upstream AppView configured".to_string())
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = is_ssrf_safe(&appview_url) {
|
||||
error!("SSRF check failed for appview URL: {}", e);
|
||||
return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e))
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let limit = validate_limit(params.limit, 50, 100);
|
||||
let mut query_params = HashMap::new();
|
||||
query_params.insert("feed".to_string(), params.feed.clone());
|
||||
query_params.insert("limit".to_string(), limit.to_string());
|
||||
if let Some(cursor) = ¶ms.cursor {
|
||||
query_params.insert("cursor".to_string(), cursor.clone());
|
||||
}
|
||||
|
||||
let target_url = format!("{}/xrpc/app.bsky.feed.getFeed", appview_url);
|
||||
info!(target = %target_url, feed = %params.feed, "Proxying getFeed request");
|
||||
|
||||
let client = proxy_client();
|
||||
let mut request_builder = client.get(&target_url).query(&query_params);
|
||||
|
||||
if let Some(auth) = auth_header {
|
||||
request_builder = request_builder.header("Authorization", auth);
|
||||
}
|
||||
|
||||
match request_builder.send().await {
|
||||
Ok(resp) => {
|
||||
let status =
|
||||
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
|
||||
let content_length = resp.content_length().unwrap_or(0);
|
||||
if content_length > MAX_RESPONSE_SIZE {
|
||||
error!(
|
||||
content_length,
|
||||
max = MAX_RESPONSE_SIZE,
|
||||
"getFeed response too large"
|
||||
);
|
||||
return ApiError::UpstreamFailure.into_response();
|
||||
}
|
||||
|
||||
let resp_headers = resp.headers().clone();
|
||||
let body = match resp.bytes().await {
|
||||
Ok(b) => {
|
||||
if b.len() as u64 > MAX_RESPONSE_SIZE {
|
||||
error!(len = b.len(), "getFeed response body exceeded limit");
|
||||
return ApiError::UpstreamFailure.into_response();
|
||||
}
|
||||
b
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Error reading getFeed response");
|
||||
return ApiError::UpstreamFailure.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let mut response_builder = axum::response::Response::builder().status(status);
|
||||
if let Some(ct) = resp_headers.get("content-type") {
|
||||
response_builder = response_builder.header("content-type", ct);
|
||||
}
|
||||
|
||||
match response_builder.body(axum::body::Body::from(body)) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Error building getFeed response");
|
||||
ApiError::UpstreamFailure.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Error proxying getFeed");
|
||||
if e.is_timeout() {
|
||||
ApiError::UpstreamTimeout.into_response()
|
||||
} else if e.is_connect() {
|
||||
ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
|
||||
.into_response()
|
||||
} else {
|
||||
ApiError::UpstreamFailure.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,11 @@
|
||||
mod actor_likes;
|
||||
mod author_feed;
|
||||
mod custom_feed;
|
||||
mod post_thread;
|
||||
mod timeline;
|
||||
|
||||
pub use actor_likes::get_actor_likes;
|
||||
pub use author_feed::get_author_feed;
|
||||
pub use custom_feed::get_feed;
|
||||
pub use post_thread::get_post_thread;
|
||||
pub use timeline::get_timeline;
|
||||
|
||||
322
src/api/feed/post_thread.rs
Normal file
322
src/api/feed/post_thread.rs
Normal file
@@ -0,0 +1,322 @@
|
||||
use crate::api::read_after_write::{
|
||||
extract_repo_rev, format_local_post, format_munged_response, get_local_lag,
|
||||
get_records_since_rev, proxy_to_appview, PostRecord, PostView, RecordDescript,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetPostThreadParams {
|
||||
pub uri: String,
|
||||
pub depth: Option<u32>,
|
||||
#[serde(rename = "parentHeight")]
|
||||
pub parent_height: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ThreadViewPost {
|
||||
#[serde(rename = "$type")]
|
||||
pub thread_type: Option<String>,
|
||||
pub post: PostView,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub parent: Option<Box<ThreadNode>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub replies: Option<Vec<ThreadNode>>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ThreadNode {
|
||||
Post(ThreadViewPost),
|
||||
NotFound(ThreadNotFound),
|
||||
Blocked(ThreadBlocked),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ThreadNotFound {
|
||||
#[serde(rename = "$type")]
|
||||
pub thread_type: String,
|
||||
pub uri: String,
|
||||
pub not_found: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ThreadBlocked {
|
||||
#[serde(rename = "$type")]
|
||||
pub thread_type: String,
|
||||
pub uri: String,
|
||||
pub blocked: bool,
|
||||
pub author: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PostThreadOutput {
|
||||
pub thread: ThreadNode,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub threadgate: Option<Value>,
|
||||
}
|
||||
|
||||
const MAX_THREAD_DEPTH: usize = 10;
|
||||
|
||||
fn add_replies_to_thread(
|
||||
thread: &mut ThreadViewPost,
|
||||
local_posts: &[RecordDescript<PostRecord>],
|
||||
author_did: &str,
|
||||
author_handle: &str,
|
||||
depth: usize,
|
||||
) {
|
||||
if depth >= MAX_THREAD_DEPTH {
|
||||
return;
|
||||
}
|
||||
|
||||
let thread_uri = &thread.post.uri;
|
||||
|
||||
let replies: Vec<_> = local_posts
|
||||
.iter()
|
||||
.filter(|p| {
|
||||
p.record
|
||||
.reply
|
||||
.as_ref()
|
||||
.and_then(|r| r.get("parent"))
|
||||
.and_then(|parent| parent.get("uri"))
|
||||
.and_then(|u| u.as_str())
|
||||
== Some(thread_uri)
|
||||
})
|
||||
.map(|p| {
|
||||
let post_view = format_local_post(p, author_did, author_handle, None);
|
||||
ThreadNode::Post(ThreadViewPost {
|
||||
thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
|
||||
post: post_view,
|
||||
parent: None,
|
||||
replies: None,
|
||||
extra: HashMap::new(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !replies.is_empty() {
|
||||
match &mut thread.replies {
|
||||
Some(existing) => existing.extend(replies),
|
||||
None => thread.replies = Some(replies),
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref mut existing_replies) = thread.replies {
|
||||
for reply in existing_replies.iter_mut() {
|
||||
if let ThreadNode::Post(reply_thread) = reply {
|
||||
add_replies_to_thread(reply_thread, local_posts, author_did, author_handle, depth + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_post_thread(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(params): Query<GetPostThreadParams>,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
|
||||
|
||||
let auth_did = if let Some(h) = auth_header {
|
||||
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
|
||||
match crate::auth::validate_bearer_token(&state.db, &token).await {
|
||||
Ok(user) => Some(user.did),
|
||||
Err(_) => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut query_params = HashMap::new();
|
||||
query_params.insert("uri".to_string(), params.uri.clone());
|
||||
if let Some(depth) = params.depth {
|
||||
query_params.insert("depth".to_string(), depth.to_string());
|
||||
}
|
||||
if let Some(parent_height) = params.parent_height {
|
||||
query_params.insert("parentHeight".to_string(), parent_height.to_string());
|
||||
}
|
||||
|
||||
let proxy_result =
|
||||
match proxy_to_appview("app.bsky.feed.getPostThread", &query_params, auth_header).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
if proxy_result.status == StatusCode::NOT_FOUND {
|
||||
return handle_not_found(&state, ¶ms.uri, auth_did, &proxy_result.headers).await;
|
||||
}
|
||||
|
||||
if !proxy_result.status.is_success() {
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
|
||||
let rev = match extract_repo_rev(&proxy_result.headers) {
|
||||
Some(r) => r,
|
||||
None => return (proxy_result.status, proxy_result.body).into_response(),
|
||||
};
|
||||
|
||||
let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse post thread response: {:?}", e);
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let requester_did = match auth_did {
|
||||
Some(d) => d,
|
||||
None => return (StatusCode::OK, Json(thread_output)).into_response(),
|
||||
};
|
||||
|
||||
let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Failed to get local records: {}", e);
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if local_records.posts.is_empty() {
|
||||
return (StatusCode::OK, Json(thread_output)).into_response();
|
||||
}
|
||||
|
||||
let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(h)) => h,
|
||||
Ok(None) => requester_did.clone(),
|
||||
Err(e) => {
|
||||
warn!("Database error fetching handle: {:?}", e);
|
||||
requester_did.clone()
|
||||
}
|
||||
};
|
||||
|
||||
if let ThreadNode::Post(ref mut thread_post) = thread_output.thread {
|
||||
add_replies_to_thread(thread_post, &local_records.posts, &requester_did, &handle, 0);
|
||||
}
|
||||
|
||||
let lag = get_local_lag(&local_records);
|
||||
format_munged_response(thread_output, lag)
|
||||
}
|
||||
|
||||
async fn handle_not_found(
|
||||
state: &AppState,
|
||||
uri: &str,
|
||||
auth_did: Option<String>,
|
||||
headers: &axum::http::HeaderMap,
|
||||
) -> Response {
|
||||
let rev = match extract_repo_rev(headers) {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Post not found"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let requester_did = match auth_did {
|
||||
Some(d) => d,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Post not found"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect();
|
||||
if uri_parts.len() != 3 {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Post not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let post_did = uri_parts[0];
|
||||
if post_did != requester_did {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Post not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let local_records = match get_records_since_rev(state, &requester_did, &rev).await {
|
||||
Ok(r) => r,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Post not found"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let local_post = local_records.posts.iter().find(|p| p.uri == uri);
|
||||
|
||||
let local_post = match local_post {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Post not found"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(h)) => h,
|
||||
Ok(None) => requester_did.clone(),
|
||||
Err(e) => {
|
||||
warn!("Database error fetching handle: {:?}", e);
|
||||
requester_did.clone()
|
||||
}
|
||||
};
|
||||
|
||||
let post_view = format_local_post(
|
||||
local_post,
|
||||
&requester_did,
|
||||
&handle,
|
||||
local_records.profile.as_ref(),
|
||||
);
|
||||
|
||||
let thread = PostThreadOutput {
|
||||
thread: ThreadNode::Post(ThreadViewPost {
|
||||
thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
|
||||
post: post_view,
|
||||
parent: None,
|
||||
replies: None,
|
||||
extra: HashMap::new(),
|
||||
}),
|
||||
threadgate: None,
|
||||
};
|
||||
|
||||
let lag = get_local_lag(&local_records);
|
||||
format_munged_response(thread, lag)
|
||||
}
|
||||
@@ -1,51 +1,35 @@
|
||||
// Yes, I know, this endpoint is an appview one, not for PDS. Who cares!!
|
||||
// Yes, this only gets posts that our DB/instance knows about. Who cares!!!
|
||||
|
||||
use crate::api::read_after_write::{
|
||||
extract_repo_rev, format_local_post, format_munged_response, get_local_lag,
|
||||
get_records_since_rev, insert_posts_into_feed, proxy_to_appview, FeedOutput, FeedViewPost,
|
||||
PostView,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::State,
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use jacquard_repo::storage::BlockStore;
|
||||
use serde::Serialize;
|
||||
use serde_json::{Value, json};
|
||||
use tracing::error;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct TimelineOutput {
|
||||
pub feed: Vec<FeedViewPost>,
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetTimelineParams {
|
||||
pub algorithm: Option<String>,
|
||||
pub limit: Option<u32>,
|
||||
pub cursor: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct FeedViewPost {
|
||||
pub post: PostView,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PostView {
|
||||
pub uri: String,
|
||||
pub cid: String,
|
||||
pub author: AuthorView,
|
||||
pub record: Value,
|
||||
pub indexed_at: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct AuthorView {
|
||||
pub did: String,
|
||||
pub handle: String,
|
||||
}
|
||||
|
||||
pub async fn get_timeline(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(params): Query<GetTimelineParams>,
|
||||
) -> Response {
|
||||
let token = match crate::auth::extract_bearer_token_from_header(
|
||||
headers.get("Authorization").and_then(|h| h.to_str().ok())
|
||||
headers.get("Authorization").and_then(|h| h.to_str().ok()),
|
||||
) {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
@@ -68,32 +52,131 @@ pub async fn get_timeline(
|
||||
}
|
||||
};
|
||||
|
||||
let user_query = sqlx::query!("SELECT id FROM users WHERE did = $1", auth_user.did)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
match std::env::var("APPVIEW_URL") {
|
||||
Ok(url) if !url.starts_with("http://127.0.0.1") => {
|
||||
return get_timeline_with_appview(&state, &headers, ¶ms, &auth_user.did).await;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let user_id = match user_query {
|
||||
Ok(Some(row)) => row.id,
|
||||
_ => {
|
||||
get_timeline_local_only(&state, &auth_user.did).await
|
||||
}
|
||||
|
||||
async fn get_timeline_with_appview(
|
||||
state: &AppState,
|
||||
headers: &axum::http::HeaderMap,
|
||||
params: &GetTimelineParams,
|
||||
auth_did: &str,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
|
||||
|
||||
let mut query_params = HashMap::new();
|
||||
if let Some(algo) = ¶ms.algorithm {
|
||||
query_params.insert("algorithm".to_string(), algo.clone());
|
||||
}
|
||||
if let Some(limit) = params.limit {
|
||||
query_params.insert("limit".to_string(), limit.to_string());
|
||||
}
|
||||
if let Some(cursor) = ¶ms.cursor {
|
||||
query_params.insert("cursor".to_string(), cursor.clone());
|
||||
}
|
||||
|
||||
let proxy_result =
|
||||
match proxy_to_appview("app.bsky.feed.getTimeline", &query_params, auth_header).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
if !proxy_result.status.is_success() {
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
|
||||
let rev = extract_repo_rev(&proxy_result.headers);
|
||||
if rev.is_none() {
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
let rev = rev.unwrap();
|
||||
|
||||
let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse timeline response: {:?}", e);
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let local_records = match get_records_since_rev(state, auth_did, &rev).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Failed to get local records: {}", e);
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if local_records.count == 0 {
|
||||
return (proxy_result.status, proxy_result.body).into_response();
|
||||
}
|
||||
|
||||
let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", auth_did)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(h)) => h,
|
||||
Ok(None) => auth_did.to_string(),
|
||||
Err(e) => {
|
||||
warn!("Database error fetching handle: {:?}", e);
|
||||
auth_did.to_string()
|
||||
}
|
||||
};
|
||||
|
||||
let local_posts: Vec<_> = local_records
|
||||
.posts
|
||||
.iter()
|
||||
.map(|p| format_local_post(p, auth_did, &handle, local_records.profile.as_ref()))
|
||||
.collect();
|
||||
|
||||
insert_posts_into_feed(&mut feed_output.feed, local_posts);
|
||||
|
||||
let lag = get_local_lag(&local_records);
|
||||
format_munged_response(feed_output, lag)
|
||||
}
|
||||
|
||||
async fn get_timeline_local_only(state: &AppState, auth_did: &str) -> Response {
|
||||
let user_id: uuid::Uuid = match sqlx::query_scalar!(
|
||||
"SELECT id FROM users WHERE did = $1",
|
||||
auth_did
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(id)) => id,
|
||||
Ok(None) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "User not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Database error fetching user: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Database error"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let follows_query = sqlx::query!(
|
||||
"SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow'",
|
||||
"SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000",
|
||||
user_id
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
let follow_cids: Vec<String> = match follows_query {
|
||||
Ok(rows) => rows.iter().map(|r| r.record_cid.clone()).collect(),
|
||||
Err(e) => {
|
||||
error!("Failed to get follows: {:?}", e);
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
@@ -127,7 +210,7 @@ pub async fn get_timeline(
|
||||
if followed_dids.is_empty() {
|
||||
return (
|
||||
StatusCode::OK,
|
||||
Json(TimelineOutput {
|
||||
Json(FeedOutput {
|
||||
feed: vec![],
|
||||
cursor: None,
|
||||
}),
|
||||
@@ -150,8 +233,7 @@ pub async fn get_timeline(
|
||||
|
||||
let posts = match posts_result {
|
||||
Ok(rows) => rows,
|
||||
Err(e) => {
|
||||
error!("Failed to get posts: {:?}", e);
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
@@ -190,15 +272,28 @@ pub async fn get_timeline(
|
||||
post: PostView {
|
||||
uri,
|
||||
cid: record_cid,
|
||||
author: AuthorView {
|
||||
author: crate::api::read_after_write::AuthorView {
|
||||
did: author_did,
|
||||
handle: author_handle,
|
||||
display_name: None,
|
||||
avatar: None,
|
||||
extra: HashMap::new(),
|
||||
},
|
||||
record,
|
||||
indexed_at: created_at.to_rfc3339(),
|
||||
embed: None,
|
||||
reply_count: 0,
|
||||
repost_count: 0,
|
||||
like_count: 0,
|
||||
quote_count: 0,
|
||||
extra: HashMap::new(),
|
||||
},
|
||||
reply: None,
|
||||
reason: None,
|
||||
feed_context: None,
|
||||
extra: HashMap::new(),
|
||||
});
|
||||
}
|
||||
|
||||
(StatusCode::OK, Json(TimelineOutput { feed, cursor: None })).into_response()
|
||||
(StatusCode::OK, Json(FeedOutput { feed, cursor: None })).into_response()
|
||||
}
|
||||
|
||||
@@ -4,9 +4,13 @@ pub mod error;
|
||||
pub mod feed;
|
||||
pub mod identity;
|
||||
pub mod moderation;
|
||||
pub mod notification;
|
||||
pub mod proxy;
|
||||
pub mod proxy_client;
|
||||
pub mod read_after_write;
|
||||
pub mod repo;
|
||||
pub mod server;
|
||||
pub mod validation;
|
||||
|
||||
pub use error::ApiError;
|
||||
pub use proxy_client::{proxy_client, validate_at_uri, validate_did, validate_limit, AtUriParts};
|
||||
|
||||
3
src/api/notification/mod.rs
Normal file
3
src/api/notification/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
mod register_push;
|
||||
|
||||
pub use register_push::register_push;
|
||||
166
src/api/notification/register_push.rs
Normal file
166
src/api/notification/register_push.rs
Normal file
@@ -0,0 +1,166 @@
|
||||
use crate::api::proxy_client::{is_ssrf_safe, proxy_client, validate_did};
|
||||
use crate::api::ApiError;
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tracing::{error, info};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RegisterPushInput {
|
||||
pub service_did: String,
|
||||
pub token: String,
|
||||
pub platform: String,
|
||||
pub app_id: String,
|
||||
}
|
||||
|
||||
const VALID_PLATFORMS: &[&str] = &["ios", "android", "web"];
|
||||
|
||||
pub async fn register_push(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(input): Json<RegisterPushInput>,
|
||||
) -> Response {
|
||||
let token = match crate::auth::extract_bearer_token_from_header(
|
||||
headers.get("Authorization").and_then(|h| h.to_str().ok()),
|
||||
) {
|
||||
Some(t) => t,
|
||||
None => return ApiError::AuthenticationRequired.into_response(),
|
||||
};
|
||||
|
||||
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
|
||||
Ok(user) => user,
|
||||
Err(e) => return ApiError::from(e).into_response(),
|
||||
};
|
||||
|
||||
if let Err(e) = validate_did(&input.service_did) {
|
||||
return ApiError::InvalidRequest(format!("Invalid serviceDid: {}", e)).into_response();
|
||||
}
|
||||
|
||||
if input.token.is_empty() || input.token.len() > 4096 {
|
||||
return ApiError::InvalidRequest("Invalid push token".to_string()).into_response();
|
||||
}
|
||||
|
||||
if !VALID_PLATFORMS.contains(&input.platform.as_str()) {
|
||||
return ApiError::InvalidRequest(format!(
|
||||
"Invalid platform. Must be one of: {}",
|
||||
VALID_PLATFORMS.join(", ")
|
||||
))
|
||||
.into_response();
|
||||
}
|
||||
|
||||
if input.app_id.is_empty() || input.app_id.len() > 256 {
|
||||
return ApiError::InvalidRequest("Invalid appId".to_string()).into_response();
|
||||
}
|
||||
|
||||
let appview_url = match std::env::var("APPVIEW_URL") {
|
||||
Ok(url) => url,
|
||||
Err(_) => {
|
||||
return ApiError::UpstreamUnavailable("No upstream AppView configured".to_string())
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = is_ssrf_safe(&appview_url) {
|
||||
error!("SSRF check failed for appview URL: {}", e);
|
||||
return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e))
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let key_row = match sqlx::query!(
|
||||
"SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1",
|
||||
auth_user.did
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => {
|
||||
error!(did = %auth_user.did, "No signing key found for user");
|
||||
return ApiError::InternalError.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Database error fetching signing key");
|
||||
return ApiError::DatabaseError.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let decrypted_key =
|
||||
match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) {
|
||||
Ok(k) => k,
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Failed to decrypt signing key");
|
||||
return ApiError::InternalError.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let service_token = match crate::auth::create_service_token(
|
||||
&auth_user.did,
|
||||
&input.service_did,
|
||||
"app.bsky.notification.registerPush",
|
||||
&decrypted_key,
|
||||
) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Failed to create service token");
|
||||
return ApiError::InternalError.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let target_url = format!("{}/xrpc/app.bsky.notification.registerPush", appview_url);
|
||||
info!(
|
||||
target = %target_url,
|
||||
service_did = %input.service_did,
|
||||
platform = %input.platform,
|
||||
"Proxying registerPush request"
|
||||
);
|
||||
|
||||
let client = proxy_client();
|
||||
let request_body = json!({
|
||||
"serviceDid": input.service_did,
|
||||
"token": input.token,
|
||||
"platform": input.platform,
|
||||
"appId": input.app_id
|
||||
});
|
||||
|
||||
match client
|
||||
.post(&target_url)
|
||||
.header("Authorization", format!("Bearer {}", service_token))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
let status =
|
||||
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
if status.is_success() {
|
||||
StatusCode::OK.into_response()
|
||||
} else {
|
||||
let body = resp.bytes().await.unwrap_or_default();
|
||||
error!(
|
||||
status = %status,
|
||||
"registerPush upstream error"
|
||||
);
|
||||
ApiError::from_upstream_response(status.as_u16(), &body).into_response()
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Error proxying registerPush");
|
||||
if e.is_timeout() {
|
||||
ApiError::UpstreamTimeout.into_response()
|
||||
} else if e.is_connect() {
|
||||
ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
|
||||
.into_response()
|
||||
} else {
|
||||
ApiError::UpstreamFailure.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -46,21 +46,15 @@ pub async fn proxy_handler(
|
||||
if let Some(token) = crate::auth::extract_bearer_token_from_header(
|
||||
headers.get("Authorization").and_then(|h| h.to_str().ok())
|
||||
) {
|
||||
if let Ok(did) = crate::auth::get_did_from_token(&token) {
|
||||
let key_row = sqlx::query!("SELECT k.key_bytes, k.encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1", did)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
if let Ok(Some(row)) = key_row {
|
||||
if let Ok(decrypted_key) = crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
|
||||
if let Ok(new_token) =
|
||||
crate::auth::create_service_token(&did, aud, &method, &decrypted_key)
|
||||
if let Ok(auth_user) = crate::auth::validate_bearer_token(&state.db, &token).await {
|
||||
if let Some(key_bytes) = auth_user.key_bytes {
|
||||
if let Ok(new_token) =
|
||||
crate::auth::create_service_token(&auth_user.did, aud, &method, &key_bytes)
|
||||
{
|
||||
if let Ok(val) =
|
||||
axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
|
||||
{
|
||||
if let Ok(val) =
|
||||
axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
|
||||
{
|
||||
auth_header_val = Some(val);
|
||||
}
|
||||
auth_header_val = Some(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
252
src/api/proxy_client.rs
Normal file
252
src/api/proxy_client.rs
Normal file
@@ -0,0 +1,252 @@
|
||||
use reqwest::{Client, ClientBuilder, Url};
|
||||
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
use tracing::warn;
|
||||
|
||||
pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
pub const MAX_RESPONSE_SIZE: u64 = 10 * 1024 * 1024;
|
||||
|
||||
static PROXY_CLIENT: OnceLock<Client> = OnceLock::new();
|
||||
|
||||
pub fn proxy_client() -> &'static Client {
|
||||
PROXY_CLIENT.get_or_init(|| {
|
||||
ClientBuilder::new()
|
||||
.timeout(DEFAULT_BODY_TIMEOUT)
|
||||
.connect_timeout(DEFAULT_CONNECT_TIMEOUT)
|
||||
.pool_max_idle_per_host(10)
|
||||
.pool_idle_timeout(Duration::from_secs(90))
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.build()
|
||||
.expect("Failed to build HTTP client - this indicates a TLS or system configuration issue")
|
||||
})
|
||||
}
|
||||
|
||||
pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> {
|
||||
let parsed = Url::parse(url).map_err(|_| SsrfError::InvalidUrl)?;
|
||||
|
||||
let scheme = parsed.scheme();
|
||||
if scheme != "https" {
|
||||
let allow_http = std::env::var("ALLOW_HTTP_PROXY").is_ok()
|
||||
|| url.starts_with("http://127.0.0.1")
|
||||
|| url.starts_with("http://localhost");
|
||||
|
||||
if !allow_http {
|
||||
return Err(SsrfError::InsecureProtocol(scheme.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
let host = parsed.host_str().ok_or(SsrfError::NoHost)?;
|
||||
|
||||
if host == "localhost" {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
if ip.is_loopback() {
|
||||
return Ok(());
|
||||
}
|
||||
if !is_unicast_ip(&ip) {
|
||||
return Err(SsrfError::NonUnicastIp(ip.to_string()));
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let port = parsed.port().unwrap_or(if scheme == "https" { 443 } else { 80 });
|
||||
let socket_addrs: Vec<SocketAddr> = match (host, port).to_socket_addrs() {
|
||||
Ok(addrs) => addrs.collect(),
|
||||
Err(_) => return Err(SsrfError::DnsResolutionFailed(host.to_string())),
|
||||
};
|
||||
|
||||
for addr in &socket_addrs {
|
||||
if !is_unicast_ip(&addr.ip()) {
|
||||
warn!(
|
||||
"DNS resolution for {} returned non-unicast IP: {}",
|
||||
host,
|
||||
addr.ip()
|
||||
);
|
||||
return Err(SsrfError::NonUnicastIp(addr.ip().to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_unicast_ip(ip: &IpAddr) -> bool {
|
||||
match ip {
|
||||
IpAddr::V4(v4) => {
|
||||
!v4.is_loopback()
|
||||
&& !v4.is_broadcast()
|
||||
&& !v4.is_multicast()
|
||||
&& !v4.is_unspecified()
|
||||
&& !v4.is_link_local()
|
||||
&& !is_private_v4(v4)
|
||||
}
|
||||
IpAddr::V6(v6) => !v6.is_loopback() && !v6.is_multicast() && !v6.is_unspecified(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool {
|
||||
let octets = ip.octets();
|
||||
octets[0] == 10
|
||||
|| (octets[0] == 172 && (16..=31).contains(&octets[1]))
|
||||
|| (octets[0] == 192 && octets[1] == 168)
|
||||
|| (octets[0] == 169 && octets[1] == 254)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SsrfError {
|
||||
InvalidUrl,
|
||||
InsecureProtocol(String),
|
||||
NoHost,
|
||||
NonUnicastIp(String),
|
||||
DnsResolutionFailed(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for SsrfError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
SsrfError::InvalidUrl => write!(f, "Invalid URL"),
|
||||
SsrfError::InsecureProtocol(p) => write!(f, "Insecure protocol: {}", p),
|
||||
SsrfError::NoHost => write!(f, "No host in URL"),
|
||||
SsrfError::NonUnicastIp(ip) => write!(f, "Non-unicast IP address: {}", ip),
|
||||
SsrfError::DnsResolutionFailed(host) => write!(f, "DNS resolution failed for: {}", host),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for SsrfError {}
|
||||
|
||||
pub const HEADERS_TO_FORWARD: &[&str] = &[
|
||||
"accept-language",
|
||||
"atproto-accept-labelers",
|
||||
"x-bsky-topics",
|
||||
];
|
||||
|
||||
pub const RESPONSE_HEADERS_TO_FORWARD: &[&str] = &[
|
||||
"atproto-repo-rev",
|
||||
"atproto-content-labelers",
|
||||
"retry-after",
|
||||
"content-type",
|
||||
];
|
||||
|
||||
pub fn validate_at_uri(uri: &str) -> Result<AtUriParts, &'static str> {
|
||||
if !uri.starts_with("at://") {
|
||||
return Err("URI must start with at://");
|
||||
}
|
||||
|
||||
let path = uri.trim_start_matches("at://");
|
||||
let parts: Vec<&str> = path.split('/').collect();
|
||||
|
||||
if parts.is_empty() {
|
||||
return Err("URI missing DID");
|
||||
}
|
||||
|
||||
let did = parts[0];
|
||||
if !did.starts_with("did:") {
|
||||
return Err("Invalid DID in URI");
|
||||
}
|
||||
|
||||
if parts.len() > 1 {
|
||||
let collection = parts[1];
|
||||
if collection.is_empty() || !collection.contains('.') {
|
||||
return Err("Invalid collection NSID");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(AtUriParts {
|
||||
did: did.to_string(),
|
||||
collection: parts.get(1).map(|s| s.to_string()),
|
||||
rkey: parts.get(2).map(|s| s.to_string()),
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AtUriParts {
|
||||
pub did: String,
|
||||
pub collection: Option<String>,
|
||||
pub rkey: Option<String>,
|
||||
}
|
||||
|
||||
pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 {
|
||||
match limit {
|
||||
Some(l) if l == 0 => default,
|
||||
Some(l) if l > max => max,
|
||||
Some(l) => l,
|
||||
None => default,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn validate_did(did: &str) -> Result<(), &'static str> {
|
||||
if !did.starts_with("did:") {
|
||||
return Err("Invalid DID format");
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = did.split(':').collect();
|
||||
if parts.len() < 3 {
|
||||
return Err("DID must have at least method and identifier");
|
||||
}
|
||||
|
||||
let method = parts[1];
|
||||
if method != "plc" && method != "web" {
|
||||
return Err("Unsupported DID method");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ssrf_safe_https() {
|
||||
assert!(is_ssrf_safe("https://api.bsky.app/xrpc/test").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ssrf_blocks_http_by_default() {
|
||||
let result = is_ssrf_safe("http://external.example.com/xrpc/test");
|
||||
assert!(matches!(result, Err(SsrfError::InsecureProtocol(_)) | Err(SsrfError::DnsResolutionFailed(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ssrf_allows_localhost_http() {
|
||||
assert!(is_ssrf_safe("http://127.0.0.1:8080/test").is_ok());
|
||||
assert!(is_ssrf_safe("http://localhost:8080/test").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_at_uri() {
|
||||
let result = validate_at_uri("at://did:plc:test/app.bsky.feed.post/abc123");
|
||||
assert!(result.is_ok());
|
||||
let parts = result.unwrap();
|
||||
assert_eq!(parts.did, "did:plc:test");
|
||||
assert_eq!(parts.collection, Some("app.bsky.feed.post".to_string()));
|
||||
assert_eq!(parts.rkey, Some("abc123".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_at_uri_invalid() {
|
||||
assert!(validate_at_uri("https://example.com").is_err());
|
||||
assert!(validate_at_uri("at://notadid/collection/rkey").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_limit() {
|
||||
assert_eq!(validate_limit(None, 50, 100), 50);
|
||||
assert_eq!(validate_limit(Some(0), 50, 100), 50);
|
||||
assert_eq!(validate_limit(Some(200), 50, 100), 100);
|
||||
assert_eq!(validate_limit(Some(75), 50, 100), 75);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_did() {
|
||||
assert!(validate_did("did:plc:abc123").is_ok());
|
||||
assert!(validate_did("did:web:example.com").is_ok());
|
||||
assert!(validate_did("notadid").is_err());
|
||||
assert!(validate_did("did:unknown:test").is_err());
|
||||
}
|
||||
}
|
||||
433
src/api/read_after_write.rs
Normal file
433
src/api/read_after_write.rs
Normal file
@@ -0,0 +1,433 @@
|
||||
use crate::api::proxy_client::{
|
||||
is_ssrf_safe, proxy_client, MAX_RESPONSE_SIZE, RESPONSE_HEADERS_TO_FORWARD,
|
||||
};
|
||||
use crate::api::ApiError;
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
http::{HeaderMap, HeaderValue, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use jacquard_repo::storage::BlockStore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use tracing::{error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub const REPO_REV_HEADER: &str = "atproto-repo-rev";
|
||||
pub const UPSTREAM_LAG_HEADER: &str = "atproto-upstream-lag";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PostRecord {
|
||||
#[serde(rename = "$type")]
|
||||
pub record_type: Option<String>,
|
||||
pub text: String,
|
||||
pub created_at: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reply: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub embed: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub langs: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub labels: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tags: Option<Vec<String>>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ProfileRecord {
|
||||
#[serde(rename = "$type")]
|
||||
pub record_type: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub display_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub avatar: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub banner: Option<Value>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RecordDescript<T> {
|
||||
pub uri: String,
|
||||
pub cid: String,
|
||||
pub indexed_at: DateTime<Utc>,
|
||||
pub record: T,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct LikeRecord {
|
||||
#[serde(rename = "$type")]
|
||||
pub record_type: Option<String>,
|
||||
pub subject: LikeSubject,
|
||||
pub created_at: String,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct LikeSubject {
|
||||
pub uri: String,
|
||||
pub cid: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct LocalRecords {
|
||||
pub count: usize,
|
||||
pub profile: Option<RecordDescript<ProfileRecord>>,
|
||||
pub posts: Vec<RecordDescript<PostRecord>>,
|
||||
pub likes: Vec<RecordDescript<LikeRecord>>,
|
||||
}
|
||||
|
||||
pub async fn get_records_since_rev(
|
||||
state: &AppState,
|
||||
did: &str,
|
||||
rev: &str,
|
||||
) -> Result<LocalRecords, String> {
|
||||
let mut result = LocalRecords::default();
|
||||
|
||||
let user_id: Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| format!("DB error: {}", e))?
|
||||
.ok_or_else(|| "User not found".to_string())?;
|
||||
|
||||
let rows = sqlx::query!(
|
||||
r#"
|
||||
SELECT record_cid, collection, rkey, created_at, repo_rev
|
||||
FROM records
|
||||
WHERE repo_id = $1 AND repo_rev > $2
|
||||
ORDER BY repo_rev ASC
|
||||
LIMIT 10
|
||||
"#,
|
||||
user_id,
|
||||
rev
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| format!("DB error fetching records: {}", e))?;
|
||||
|
||||
if rows.is_empty() {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
let sanity_check = sqlx::query_scalar!(
|
||||
"SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1",
|
||||
user_id,
|
||||
rev
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| format!("DB error sanity check: {}", e))?;
|
||||
|
||||
if sanity_check.is_none() {
|
||||
warn!("Sanity check failed: no records found before rev {}", rev);
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
for row in rows {
|
||||
result.count += 1;
|
||||
|
||||
let cid: cid::Cid = match row.record_cid.parse() {
|
||||
Ok(c) => c,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let block_bytes = match state.block_store.get(&cid).await {
|
||||
Ok(Some(b)) => b,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let uri = format!("at://{}/{}/{}", did, row.collection, row.rkey);
|
||||
let indexed_at = row.created_at;
|
||||
|
||||
if row.collection == "app.bsky.actor.profile" && row.rkey == "self" {
|
||||
if let Ok(record) = serde_ipld_dagcbor::from_slice::<ProfileRecord>(&block_bytes) {
|
||||
result.profile = Some(RecordDescript {
|
||||
uri,
|
||||
cid: row.record_cid,
|
||||
indexed_at,
|
||||
record,
|
||||
});
|
||||
}
|
||||
} else if row.collection == "app.bsky.feed.post" {
|
||||
if let Ok(record) = serde_ipld_dagcbor::from_slice::<PostRecord>(&block_bytes) {
|
||||
result.posts.push(RecordDescript {
|
||||
uri,
|
||||
cid: row.record_cid,
|
||||
indexed_at,
|
||||
record,
|
||||
});
|
||||
}
|
||||
} else if row.collection == "app.bsky.feed.like" {
|
||||
if let Ok(record) = serde_ipld_dagcbor::from_slice::<LikeRecord>(&block_bytes) {
|
||||
result.likes.push(RecordDescript {
|
||||
uri,
|
||||
cid: row.record_cid,
|
||||
indexed_at,
|
||||
record,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn get_local_lag(local: &LocalRecords) -> Option<i64> {
|
||||
let mut oldest: Option<DateTime<Utc>> = local.profile.as_ref().map(|p| p.indexed_at);
|
||||
|
||||
for post in &local.posts {
|
||||
match oldest {
|
||||
None => oldest = Some(post.indexed_at),
|
||||
Some(o) if post.indexed_at < o => oldest = Some(post.indexed_at),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
for like in &local.likes {
|
||||
match oldest {
|
||||
None => oldest = Some(like.indexed_at),
|
||||
Some(o) if like.indexed_at < o => oldest = Some(like.indexed_at),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
oldest.map(|o| (Utc::now() - o).num_milliseconds())
|
||||
}
|
||||
|
||||
pub fn extract_repo_rev(headers: &HeaderMap) -> Option<String> {
|
||||
headers
|
||||
.get(REPO_REV_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ProxyResponse {
|
||||
pub status: StatusCode,
|
||||
pub headers: HeaderMap,
|
||||
pub body: bytes::Bytes,
|
||||
}
|
||||
|
||||
pub async fn proxy_to_appview(
|
||||
method: &str,
|
||||
params: &HashMap<String, String>,
|
||||
auth_header: Option<&str>,
|
||||
) -> Result<ProxyResponse, Response> {
|
||||
let appview_url = std::env::var("APPVIEW_URL").map_err(|_| {
|
||||
ApiError::UpstreamUnavailable("No upstream AppView configured".to_string()).into_response()
|
||||
})?;
|
||||
|
||||
if let Err(e) = is_ssrf_safe(&appview_url) {
|
||||
error!("SSRF check failed for appview URL: {}", e);
|
||||
return Err(ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e))
|
||||
.into_response());
|
||||
}
|
||||
|
||||
let target_url = format!("{}/xrpc/{}", appview_url, method);
|
||||
info!(target = %target_url, "Proxying request to appview");
|
||||
|
||||
let client = proxy_client();
|
||||
let mut request_builder = client.get(&target_url).query(params);
|
||||
|
||||
if let Some(auth) = auth_header {
|
||||
request_builder = request_builder.header("Authorization", auth);
|
||||
}
|
||||
|
||||
match request_builder.send().await {
|
||||
Ok(resp) => {
|
||||
let status =
|
||||
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
|
||||
let headers: HeaderMap = resp
|
||||
.headers()
|
||||
.iter()
|
||||
.filter(|(k, _)| {
|
||||
RESPONSE_HEADERS_TO_FORWARD
|
||||
.iter()
|
||||
.any(|h| k.as_str().eq_ignore_ascii_case(h))
|
||||
})
|
||||
.filter_map(|(k, v)| {
|
||||
let name = axum::http::HeaderName::try_from(k.as_str()).ok()?;
|
||||
let value = HeaderValue::from_bytes(v.as_bytes()).ok()?;
|
||||
Some((name, value))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let content_length = resp
|
||||
.content_length()
|
||||
.unwrap_or(0);
|
||||
if content_length > MAX_RESPONSE_SIZE {
|
||||
error!(
|
||||
content_length,
|
||||
max = MAX_RESPONSE_SIZE,
|
||||
"Upstream response too large"
|
||||
);
|
||||
return Err(ApiError::UpstreamFailure.into_response());
|
||||
}
|
||||
|
||||
let body = resp.bytes().await.map_err(|e| {
|
||||
error!(error = ?e, "Error reading proxy response body");
|
||||
ApiError::UpstreamFailure.into_response()
|
||||
})?;
|
||||
|
||||
if body.len() as u64 > MAX_RESPONSE_SIZE {
|
||||
error!(
|
||||
len = body.len(),
|
||||
max = MAX_RESPONSE_SIZE,
|
||||
"Upstream response body exceeded size limit"
|
||||
);
|
||||
return Err(ApiError::UpstreamFailure.into_response());
|
||||
}
|
||||
|
||||
Ok(ProxyResponse {
|
||||
status,
|
||||
headers,
|
||||
body,
|
||||
})
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Error sending proxy request");
|
||||
if e.is_timeout() {
|
||||
Err(ApiError::UpstreamTimeout.into_response())
|
||||
} else if e.is_connect() {
|
||||
Err(ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
|
||||
.into_response())
|
||||
} else {
|
||||
Err(ApiError::UpstreamFailure.into_response())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn format_munged_response<T: Serialize>(data: T, lag: Option<i64>) -> Response {
|
||||
let mut response = (StatusCode::OK, Json(data)).into_response();
|
||||
|
||||
if let Some(lag_ms) = lag {
|
||||
if let Ok(header_val) = HeaderValue::from_str(&lag_ms.to_string()) {
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(UPSTREAM_LAG_HEADER, header_val);
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AuthorView {
|
||||
pub did: String,
|
||||
pub handle: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub display_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub avatar: Option<String>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PostView {
|
||||
pub uri: String,
|
||||
pub cid: String,
|
||||
pub author: AuthorView,
|
||||
pub record: Value,
|
||||
pub indexed_at: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub embed: Option<Value>,
|
||||
#[serde(default)]
|
||||
pub reply_count: i64,
|
||||
#[serde(default)]
|
||||
pub repost_count: i64,
|
||||
#[serde(default)]
|
||||
pub like_count: i64,
|
||||
#[serde(default)]
|
||||
pub quote_count: i64,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FeedViewPost {
|
||||
pub post: PostView,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reply: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reason: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub feed_context: Option<String>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FeedOutput {
|
||||
pub feed: Vec<FeedViewPost>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cursor: Option<String>,
|
||||
}
|
||||
|
||||
pub fn format_local_post(
|
||||
descript: &RecordDescript<PostRecord>,
|
||||
author_did: &str,
|
||||
author_handle: &str,
|
||||
profile: Option<&RecordDescript<ProfileRecord>>,
|
||||
) -> PostView {
|
||||
let display_name = profile.and_then(|p| p.record.display_name.clone());
|
||||
|
||||
PostView {
|
||||
uri: descript.uri.clone(),
|
||||
cid: descript.cid.clone(),
|
||||
author: AuthorView {
|
||||
did: author_did.to_string(),
|
||||
handle: author_handle.to_string(),
|
||||
display_name,
|
||||
avatar: None,
|
||||
extra: HashMap::new(),
|
||||
},
|
||||
record: serde_json::to_value(&descript.record).unwrap_or(Value::Null),
|
||||
indexed_at: descript.indexed_at.to_rfc3339(),
|
||||
embed: descript.record.embed.clone(),
|
||||
reply_count: 0,
|
||||
repost_count: 0,
|
||||
like_count: 0,
|
||||
quote_count: 0,
|
||||
extra: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert_posts_into_feed(feed: &mut Vec<FeedViewPost>, posts: Vec<PostView>) {
|
||||
if posts.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let new_items: Vec<FeedViewPost> = posts
|
||||
.into_iter()
|
||||
.map(|post| FeedViewPost {
|
||||
post,
|
||||
reply: None,
|
||||
reason: None,
|
||||
feed_context: None,
|
||||
extra: HashMap::new(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
feed.extend(new_items);
|
||||
feed.sort_by(|a, b| b.post.indexed_at.cmp(&a.post.indexed_at));
|
||||
}
|
||||
@@ -25,7 +25,7 @@ pub async fn commit_and_log(
|
||||
current_root_cid: Option<Cid>,
|
||||
new_mst_root: Cid,
|
||||
ops: Vec<RecordOp>,
|
||||
blocks_cids: &Vec<String>,
|
||||
blocks_cids: &[String],
|
||||
) -> Result<CommitResult, String> {
|
||||
let key_row = sqlx::query!(
|
||||
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
|
||||
@@ -63,16 +63,18 @@ pub async fn commit_and_log(
|
||||
.await
|
||||
.map_err(|e| format!("DB Error (repos): {}", e))?;
|
||||
|
||||
let rev_str = rev.to_string();
|
||||
for op in &ops {
|
||||
match op {
|
||||
RecordOp::Create { collection, rkey, cid } | RecordOp::Update { collection, rkey, cid } => {
|
||||
sqlx::query!(
|
||||
"INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()",
|
||||
"INSERT INTO records (repo_id, collection, rkey, record_cid, repo_rev) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, repo_rev = $5, created_at = NOW()",
|
||||
user_id,
|
||||
collection,
|
||||
rkey,
|
||||
cid.to_string()
|
||||
cid.to_string(),
|
||||
rev_str
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
|
||||
21
src/lib.rs
21
src/lib.rs
@@ -296,11 +296,30 @@ pub fn app(state: AppState) -> Router {
|
||||
"/xrpc/app.bsky.actor.getProfiles",
|
||||
get(api::actor::get_profiles),
|
||||
)
|
||||
// I know I know, I'm not supposed to implement appview endpoints. Leave me be
|
||||
.route(
|
||||
"/xrpc/app.bsky.feed.getTimeline",
|
||||
get(api::feed::get_timeline),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/app.bsky.feed.getAuthorFeed",
|
||||
get(api::feed::get_author_feed),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/app.bsky.feed.getActorLikes",
|
||||
get(api::feed::get_actor_likes),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/app.bsky.feed.getPostThread",
|
||||
get(api::feed::get_post_thread),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/app.bsky.feed.getFeed",
|
||||
get(api::feed::get_feed),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/app.bsky.notification.registerPush",
|
||||
post(api::notification::register_push),
|
||||
)
|
||||
.route("/.well-known/did.json", get(api::identity::well_known_did))
|
||||
.route("/u/{handle}/did.json", get(api::identity::user_did_doc))
|
||||
// OAuth 2.1 endpoints
|
||||
|
||||
149
tests/appview_integration.rs
Normal file
149
tests/appview_integration.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
mod common;
|
||||
|
||||
use common::{base_url, client, create_account_and_login};
|
||||
use reqwest::StatusCode;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_author_feed_returns_appview_data() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, did) = create_account_and_login(&client).await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getAuthorFeed?actor={}",
|
||||
base, did
|
||||
))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let body: Value = res.json().await.unwrap();
|
||||
assert!(body["feed"].is_array(), "Response should have feed array");
|
||||
let feed = body["feed"].as_array().unwrap();
|
||||
assert_eq!(feed.len(), 1, "Feed should have 1 post from appview");
|
||||
assert_eq!(
|
||||
feed[0]["post"]["record"]["text"].as_str(),
|
||||
Some("Author feed post from appview"),
|
||||
"Post text should match appview response"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_actor_likes_returns_appview_data() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, did) = create_account_and_login(&client).await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getActorLikes?actor={}",
|
||||
base, did
|
||||
))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let body: Value = res.json().await.unwrap();
|
||||
assert!(body["feed"].is_array(), "Response should have feed array");
|
||||
let feed = body["feed"].as_array().unwrap();
|
||||
assert_eq!(feed.len(), 1, "Feed should have 1 liked post from appview");
|
||||
assert_eq!(
|
||||
feed[0]["post"]["record"]["text"].as_str(),
|
||||
Some("Liked post from appview"),
|
||||
"Post text should match appview response"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_post_thread_returns_appview_data() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, did) = create_account_and_login(&client).await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getPostThread?uri=at://{}/app.bsky.feed.post/test123",
|
||||
base, did
|
||||
))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let body: Value = res.json().await.unwrap();
|
||||
assert!(body["thread"].is_object(), "Response should have thread object");
|
||||
assert_eq!(
|
||||
body["thread"]["$type"].as_str(),
|
||||
Some("app.bsky.feed.defs#threadViewPost"),
|
||||
"Thread should be a threadViewPost"
|
||||
);
|
||||
assert_eq!(
|
||||
body["thread"]["post"]["record"]["text"].as_str(),
|
||||
Some("Thread post from appview"),
|
||||
"Post text should match appview response"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_feed_returns_appview_data() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, _did) = create_account_and_login(&client).await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test",
|
||||
base
|
||||
))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let body: Value = res.json().await.unwrap();
|
||||
assert!(body["feed"].is_array(), "Response should have feed array");
|
||||
let feed = body["feed"].as_array().unwrap();
|
||||
assert_eq!(feed.len(), 1, "Feed should have 1 post from appview");
|
||||
assert_eq!(
|
||||
feed[0]["post"]["record"]["text"].as_str(),
|
||||
Some("Custom feed post from appview"),
|
||||
"Post text should match appview response"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_push_proxies_to_appview() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, _did) = create_account_and_login(&client).await;
|
||||
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/app.bsky.notification.registerPush",
|
||||
base
|
||||
))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.json(&json!({
|
||||
"serviceDid": "did:web:example.com",
|
||||
"token": "test-push-token",
|
||||
"platform": "ios",
|
||||
"appId": "xyz.bsky.app"
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
@@ -233,6 +233,122 @@ async fn setup_mock_appview(mock_server: &MockServer) {
|
||||
})))
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/xrpc/app.bsky.feed.getTimeline"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("atproto-repo-rev", "0")
|
||||
.set_body_json(json!({
|
||||
"feed": [],
|
||||
"cursor": null
|
||||
}))
|
||||
)
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/xrpc/app.bsky.feed.getAuthorFeed"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("atproto-repo-rev", "0")
|
||||
.set_body_json(json!({
|
||||
"feed": [{
|
||||
"post": {
|
||||
"uri": "at://did:plc:mock-author/app.bsky.feed.post/from-appview-author",
|
||||
"cid": "bafyappview123",
|
||||
"author": {"did": "did:plc:mock-author", "handle": "mock.author"},
|
||||
"record": {
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Author feed post from appview",
|
||||
"createdAt": "2025-01-01T00:00:00Z"
|
||||
},
|
||||
"indexedAt": "2025-01-01T00:00:00Z"
|
||||
}
|
||||
}],
|
||||
"cursor": "author-cursor"
|
||||
})),
|
||||
)
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/xrpc/app.bsky.feed.getActorLikes"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("atproto-repo-rev", "0")
|
||||
.set_body_json(json!({
|
||||
"feed": [{
|
||||
"post": {
|
||||
"uri": "at://did:plc:mock-likes/app.bsky.feed.post/liked-post",
|
||||
"cid": "bafyliked123",
|
||||
"author": {"did": "did:plc:mock-likes", "handle": "mock.likes"},
|
||||
"record": {
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Liked post from appview",
|
||||
"createdAt": "2025-01-01T00:00:00Z"
|
||||
},
|
||||
"indexedAt": "2025-01-01T00:00:00Z"
|
||||
}
|
||||
}],
|
||||
"cursor": null
|
||||
})),
|
||||
)
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/xrpc/app.bsky.feed.getPostThread"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("atproto-repo-rev", "0")
|
||||
.set_body_json(json!({
|
||||
"thread": {
|
||||
"$type": "app.bsky.feed.defs#threadViewPost",
|
||||
"post": {
|
||||
"uri": "at://did:plc:mock/app.bsky.feed.post/thread-post",
|
||||
"cid": "bafythread123",
|
||||
"author": {"did": "did:plc:mock", "handle": "mock.handle"},
|
||||
"record": {
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Thread post from appview",
|
||||
"createdAt": "2025-01-01T00:00:00Z"
|
||||
},
|
||||
"indexedAt": "2025-01-01T00:00:00Z"
|
||||
},
|
||||
"replies": []
|
||||
}
|
||||
})),
|
||||
)
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/xrpc/app.bsky.feed.getFeed"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
||||
"feed": [{
|
||||
"post": {
|
||||
"uri": "at://did:plc:mock-feed/app.bsky.feed.post/custom-feed-post",
|
||||
"cid": "bafyfeed123",
|
||||
"author": {"did": "did:plc:mock-feed", "handle": "mock.feed"},
|
||||
"record": {
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Custom feed post from appview",
|
||||
"createdAt": "2025-01-01T00:00:00Z"
|
||||
},
|
||||
"indexedAt": "2025-01-01T00:00:00Z"
|
||||
}
|
||||
}],
|
||||
"cursor": null
|
||||
})))
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/xrpc/app.bsky.notification.registerPush"))
|
||||
.respond_with(ResponseTemplate::new(200))
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn spawn_app(database_url: String) -> String {
|
||||
|
||||
122
tests/feed.rs
Normal file
122
tests/feed.rs
Normal file
@@ -0,0 +1,122 @@
|
||||
mod common;
|
||||
|
||||
use common::{base_url, client, create_account_and_login};
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_timeline_requires_auth() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
|
||||
let res = client
|
||||
.get(format!("{}/xrpc/app.bsky.feed.getTimeline", base))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), 401);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_author_feed_requires_actor() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, _did) = create_account_and_login(&client).await;
|
||||
|
||||
let res = client
|
||||
.get(format!("{}/xrpc/app.bsky.feed.getAuthorFeed", base))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), 400);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_actor_likes_requires_actor() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, _did) = create_account_and_login(&client).await;
|
||||
|
||||
let res = client
|
||||
.get(format!("{}/xrpc/app.bsky.feed.getActorLikes", base))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), 400);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_post_thread_requires_uri() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, _did) = create_account_and_login(&client).await;
|
||||
|
||||
let res = client
|
||||
.get(format!("{}/xrpc/app.bsky.feed.getPostThread", base))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), 400);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_feed_requires_auth() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test",
|
||||
base
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), 401);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_feed_requires_feed_param() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, _did) = create_account_and_login(&client).await;
|
||||
|
||||
let res = client
|
||||
.get(format!("{}/xrpc/app.bsky.feed.getFeed", base))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), 400);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_push_requires_auth() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/app.bsky.notification.registerPush",
|
||||
base
|
||||
))
|
||||
.json(&json!({
|
||||
"serviceDid": "did:web:example.com",
|
||||
"token": "test-token",
|
||||
"platform": "ios",
|
||||
"appId": "xyz.bsky.app"
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), 401);
|
||||
}
|
||||
@@ -23,7 +23,7 @@ fn write_varint(buf: &mut Vec<u8>, mut value: u64) {
|
||||
async fn test_import_rejects_car_for_different_user() {
|
||||
let client = client();
|
||||
|
||||
let (token_a, did_a) = create_account_and_login(&client).await;
|
||||
let (token_a, _did_a) = create_account_and_login(&client).await;
|
||||
let (_token_b, did_b) = create_account_and_login(&client).await;
|
||||
|
||||
let export_res = client
|
||||
|
||||
Reference in New Issue
Block a user