diff --git a/.env.example b/.env.example index 598aa45..1b1c622 100644 --- a/.env.example +++ b/.env.example @@ -48,25 +48,13 @@ AWS_SECRET_ACCESS_KEY=minioadmin # Optional: rotation key for PLC operations (defaults to user's key) # PLC_ROTATION_KEY=did:key:... # ============================================================================= -# AppView Federation +# DID Resolution +# ============================================================================= +# Cache TTL for resolved DID documents (default: 300 seconds) +# DID_CACHE_TTL_SECS=300 +# ============================================================================= +# Relays # ============================================================================= -# AppViews are resolved via DID-based discovery. Configure by mapping lexicon -# namespaces to AppView DIDs. The DID document is fetched and the service -# endpoint is extracted automatically. -# -# Format: APPVIEW_DID_= -# Where uses underscores instead of dots (e.g., APP_BSKY for app.bsky) -# -# Default: app.bsky and com.atproto -> did:web:api.bsky.app -# -# Examples: -# APPVIEW_DID_APP_BSKY=did:web:api.bsky.app -# APPVIEW_DID_COM_WHTWND=did:web:whtwnd.com -# APPVIEW_DID_BLUE_ZIO=did:plc:some-custom-appview -# -# Cache TTL for resolved AppView endpoints (default: 300 seconds) -# APPVIEW_CACHE_TTL_SECS=300 -# # Comma-separated list of relay URLs to notify via requestCrawl # CRAWLERS=https://bsky.network,https://relay.upcloud.world # ============================================================================= diff --git a/.sqlx/query-36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b.json b/.sqlx/query-36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b.json deleted file mode 100644 index 1f4d972..0000000 --- a/.sqlx/query-36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "key_bytes", - "type_info": "Bytea" - }, - { - "ordinal": 1, - "name": "encryption_version", - "type_info": "Int4" - } - ], - "parameters": { - "Left": [ - "Text" - ] - }, - "nullable": [ - false, - true - ] - }, - "hash": "36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b" -} diff --git a/.sqlx/query-4bc1e1169d95eac340756b1ba01680caa980514d77cc7b41361c2400f6a5f456.json b/.sqlx/query-4bc1e1169d95eac340756b1ba01680caa980514d77cc7b41361c2400f6a5f456.json deleted file mode 100644 index e71e6d2..0000000 --- a/.sqlx/query-4bc1e1169d95eac340756b1ba01680caa980514d77cc7b41361c2400f6a5f456.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT r.record_cid, r.rkey, r.created_at, u.did, u.handle\n FROM records r\n JOIN repos rp ON r.repo_id = rp.user_id\n JOIN users u ON rp.user_id = u.id\n WHERE u.did = ANY($1) AND r.collection = 'app.bsky.feed.post'\n ORDER BY r.created_at DESC\n LIMIT 50", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "record_cid", - "type_info": "Text" - }, - { - "ordinal": 1, - "name": "rkey", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "created_at", - "type_info": "Timestamptz" - }, - { - "ordinal": 3, - "name": "did", - "type_info": "Text" - }, - { - "ordinal": 4, - "name": "handle", - "type_info": "Text" - } - ], - "parameters": { - "Left": [ - "TextArray" - ] - }, - "nullable": [ - false, - false, - false, - false, - false - ] - }, - "hash": "4bc1e1169d95eac340756b1ba01680caa980514d77cc7b41361c2400f6a5f456" -} diff --git a/.sqlx/query-5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288.json b/.sqlx/query-5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288.json deleted file mode 100644 index 4096224..0000000 --- a/.sqlx/query-5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "val", - "type_info": "Int4" - } - ], - "parameters": { - "Left": [ - "Uuid", - "Text" - ] - }, - "nullable": [ - null - ] - }, - "hash": "5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288" -} diff --git a/.sqlx/query-94e290ff1acc15ccb8fd57fce25c7a4eea1e45c7339145d5af2741cc04348c8f.json b/.sqlx/query-94e290ff1acc15ccb8fd57fce25c7a4eea1e45c7339145d5af2741cc04348c8f.json deleted file mode 100644 index 5ac320b..0000000 --- a/.sqlx/query-94e290ff1acc15ccb8fd57fce25c7a4eea1e45c7339145d5af2741cc04348c8f.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.actor.profile' AND rkey = 'self'", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "record_cid", - "type_info": "Text" - } - ], - "parameters": { - "Left": [ - "Uuid" - ] - }, - "nullable": [ - false - ] - }, - "hash": "94e290ff1acc15ccb8fd57fce25c7a4eea1e45c7339145d5af2741cc04348c8f" -} diff --git a/.sqlx/query-a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc.json b/.sqlx/query-a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc.json deleted file mode 100644 index ae031f0..0000000 --- a/.sqlx/query-a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "record_cid", - "type_info": "Text" - } - ], - "parameters": { - "Left": [ - "Uuid" - ] - }, - "nullable": [ - false - ] - }, - "hash": "a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc" -} diff --git a/.sqlx/query-f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e.json b/.sqlx/query-f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e.json deleted file mode 100644 index e63f03c..0000000 --- a/.sqlx/query-f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT record_cid, collection, rkey, created_at, repo_rev\n FROM records\n WHERE repo_id = $1 AND repo_rev > $2\n ORDER BY repo_rev ASC\n LIMIT 10\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "record_cid", - "type_info": "Text" - }, - { - "ordinal": 1, - "name": "collection", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "rkey", - "type_info": "Text" - }, - { - "ordinal": 3, - "name": "created_at", - "type_info": "Timestamptz" - }, - { - "ordinal": 4, - "name": "repo_rev", - "type_info": "Text" - } - ], - "parameters": { - "Left": [ - "Uuid", - "Text" - ] - }, - "nullable": [ - false, - false, - false, - false, - true - ] - }, - "hash": "f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e" -} diff --git a/TODO.md b/TODO.md index 3b969ab..f1d7d3f 100644 --- a/TODO.md +++ b/TODO.md @@ -38,19 +38,20 @@ Accounts controlled by other accounts rather than having their own password. Whe - [ ] Log all actions with both actor DID and controller DID - [ ] Audit log view for delegated account owners -### Passkey support -Modern passwordless authentication using WebAuthn/FIDO2, alongside or instead of passwords. +### Passkeys and 2FA +Modern passwordless authentication using WebAuthn/FIDO2, plus TOTP for defense in depth. - [ ] passkeys table (id, did, credential_id, public_key, sign_count, created_at, last_used, friendly_name) -- [ ] Generate WebAuthn registration challenge -- [ ] Verify attestation response and store credential -- [ ] UI for registering new passkey from settings -- [ ] Detect if account has passkeys during OAuth authorize -- [ ] Offer passkey option alongside password -- [ ] Generate authentication challenge and verify assertion -- [ ] Update sign count (replay protection) -- [ ] Allow creating account with passkey instead of password -- [ ] List/rename/remove passkeys in settings +- [ ] user_totp table (did, secret_encrypted, verified, created_at, last_used) +- [ ] WebAuthn registration challenge generation and attestation verification +- [ ] TOTP secret generation with QR code setup flow +- [ ] Backup codes (hashed, one-time use) with recovery flow +- [ ] OAuth authorize flow: password → 2FA (if enabled) → passkey (as alternative) +- [ ] Passkey-only account creation (no password) +- [ ] Settings UI for managing passkeys, TOTP, backup codes +- [ ] Trusted devices option (remember this browser) +- [ ] Rate limit 2FA attempts +- [ ] Re-auth for sensitive actions (email change, adding new auth methods) ### Private/encrypted data Records that only authorized parties can see and decrypt. Requires key federation between PDSes. @@ -65,6 +66,22 @@ Records that only authorized parties can see and decrypt. Requires key federatio - [ ] Protocol for sharing decryption keys between PDSes - [ ] Handle key rotation and revocation +### Plugin system +Extensible architecture allowing third-party plugins to add functionality, like minecraft mods or browser extensions. + +- [ ] Research: survey Fabric/Forge, VS Code, Grafana, Caddy plugin architectures +- [ ] Evaluate rust approaches: WASM, dynamic linking, subprocess IPC, embedded scripting (Lua/Rhai) +- [ ] Define security model (sandboxing, permissions, resource limits) +- [ ] Plugin manifest format (name, version, deps, permissions, hooks) +- [ ] Plugin discovery, loading, lifecycle (enable/disable/hot reload) +- [ ] Error isolation (bad plugin shouldn't crash PDS) +- [ ] Extension points: request middleware, record lifecycle hooks, custom XRPC endpoints +- [ ] Extension points: custom lexicons, storage backends, auth providers, notification channels +- [ ] Extension points: firehose consumers (react to repo events) +- [ ] Plugin SDK crate with traits and helpers +- [ ] Example plugins: custom feed algorithm, content filter, S3 backup +- [ ] Plugin registry with signature verification and version compatibility + --- ## Completed diff --git a/src/api/actor/mod.rs b/src/api/actor/mod.rs index 1002757..4854235 100644 --- a/src/api/actor/mod.rs +++ b/src/api/actor/mod.rs @@ -1,5 +1,3 @@ mod preferences; -mod profile; pub use preferences::{get_preferences, put_preferences}; -pub use profile::{get_profile, get_profiles}; diff --git a/src/api/actor/profile.rs b/src/api/actor/profile.rs deleted file mode 100644 index 60ba0c0..0000000 --- a/src/api/actor/profile.rs +++ /dev/null @@ -1,290 +0,0 @@ -use crate::api::proxy_client::proxy_client; -use crate::state::AppState; -use axum::{ - Json, - extract::{Query, RawQuery, State}, - http::StatusCode, - response::{IntoResponse, Response}, -}; -use jacquard_repo::storage::BlockStore; -use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; -use std::collections::HashMap; -use tracing::{error, info}; - -#[derive(Deserialize)] -pub struct GetProfileParams { - pub actor: String, -} - -#[derive(Serialize, Deserialize, Clone)] -#[serde(rename_all = "camelCase")] -pub struct ProfileViewDetailed { - pub did: String, - pub handle: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub display_name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub avatar: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub banner: Option, - #[serde(flatten)] - pub extra: HashMap, -} - -#[derive(Serialize, Deserialize)] -pub struct GetProfilesOutput { - pub profiles: Vec, -} - -async fn get_local_profile_record(state: &AppState, did: &str) -> Option { - let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) - .fetch_optional(&state.db) - .await - .ok()??; - let record_row = sqlx::query!( - "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.actor.profile' AND rkey = 'self'", - user_id - ) - .fetch_optional(&state.db) - .await - .ok()??; - let cid: cid::Cid = record_row.record_cid.parse().ok()?; - let block_bytes = state.block_store.get(&cid).await.ok()??; - serde_ipld_dagcbor::from_slice(&block_bytes).ok() -} - -fn munge_profile_with_local(profile: &mut ProfileViewDetailed, local_record: &Value) { - if let Some(display_name) = local_record.get("displayName").and_then(|v| v.as_str()) { - profile.display_name = Some(display_name.to_string()); - } - if let Some(description) = local_record.get("description").and_then(|v| v.as_str()) { - profile.description = Some(description.to_string()); - } -} - -async fn proxy_to_appview( - state: &AppState, - method: &str, - params: &HashMap, - auth_did: &str, - auth_key_bytes: Option<&[u8]>, -) -> Result<(StatusCode, Value), Response> { - let resolved = match state.appview_registry.get_appview_for_method(method).await { - Some(r) => r, - None => { - return Err(( - StatusCode::BAD_GATEWAY, - Json( - json!({"error": "UpstreamError", "message": "No upstream AppView configured"}), - ), - ) - .into_response()); - } - }; - let target_url = format!("{}/xrpc/{}", resolved.url, method); - info!("Proxying GET request to {}", target_url); - let client = proxy_client(); - let request_builder = client.get(&target_url).query(params); - proxy_request(request_builder, auth_did, auth_key_bytes, method, &resolved.did).await -} - -async fn proxy_to_appview_raw( - state: &AppState, - method: &str, - raw_query: Option<&str>, - auth_did: &str, - auth_key_bytes: Option<&[u8]>, -) -> Result<(StatusCode, Value), Response> { - let resolved = match state.appview_registry.get_appview_for_method(method).await { - Some(r) => r, - None => { - return Err(( - StatusCode::BAD_GATEWAY, - Json( - json!({"error": "UpstreamError", "message": "No upstream AppView configured"}), - ), - ) - .into_response()); - } - }; - let target_url = match raw_query { - Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q), - None => format!("{}/xrpc/{}", resolved.url, method), - }; - info!("Proxying GET request to {}", target_url); - let client = proxy_client(); - let request_builder = client.get(&target_url); - proxy_request(request_builder, auth_did, auth_key_bytes, method, &resolved.did).await -} - -async fn proxy_request( - mut request_builder: reqwest::RequestBuilder, - auth_did: &str, - auth_key_bytes: Option<&[u8]>, - method: &str, - appview_did: &str, -) -> Result<(StatusCode, Value), Response> { - if let Some(key_bytes) = auth_key_bytes { - match crate::auth::create_service_token(auth_did, appview_did, method, key_bytes) { - Ok(service_token) => { - request_builder = - request_builder.header("Authorization", format!("Bearer {}", service_token)); - } - Err(e) => { - error!("Failed to create service token: {:?}", e); - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": "InternalError"})), - ) - .into_response()); - } - } - } - match request_builder.send().await { - Ok(resp) => { - let status = - StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); - match resp.json::().await { - Ok(body) => Ok((status, body)), - Err(e) => { - error!("Error parsing proxy response: {:?}", e); - Err(( - StatusCode::BAD_GATEWAY, - Json(json!({"error": "UpstreamError"})), - ) - .into_response()) - } - } - } - Err(e) => { - error!("Error sending proxy request: {:?}", e); - if e.is_timeout() { - Err(( - StatusCode::GATEWAY_TIMEOUT, - Json(json!({"error": "UpstreamTimeout"})), - ) - .into_response()) - } else { - Err(( - StatusCode::BAD_GATEWAY, - Json(json!({"error": "UpstreamError"})), - ) - .into_response()) - } - } - } -} - -pub async fn get_profile( - State(state): State, - headers: axum::http::HeaderMap, - Query(params): Query, -) -> Response { - let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); - let auth_user = if let Some(h) = auth_header { - if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { - crate::auth::validate_bearer_token(&state.db, &token) - .await - .ok() - } else { - None - } - } else { - None - }; - let auth_did = auth_user.as_ref().map(|u| u.did.clone()); - let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone()); - let mut query_params = HashMap::new(); - query_params.insert("actor".to_string(), params.actor.clone()); - let (status, body) = match proxy_to_appview( - &state, - "app.bsky.actor.getProfile", - &query_params, - auth_did.as_deref().unwrap_or(""), - auth_key_bytes.as_deref(), - ) - .await - { - Ok(r) => r, - Err(e) => return e, - }; - if !status.is_success() { - return (status, Json(body)).into_response(); - } - let mut profile: ProfileViewDetailed = match serde_json::from_value(body) { - Ok(p) => p, - Err(_) => { - return ( - StatusCode::BAD_GATEWAY, - Json(json!({"error": "UpstreamError", "message": "Invalid profile response"})), - ) - .into_response(); - } - }; - if let Some(ref did) = auth_did - && profile.did == *did - && let Some(local_record) = get_local_profile_record(&state, did).await { - munge_profile_with_local(&mut profile, &local_record); - } - (StatusCode::OK, Json(profile)).into_response() -} - -pub async fn get_profiles( - State(state): State, - headers: axum::http::HeaderMap, - RawQuery(raw_query): RawQuery, -) -> Response { - let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); - let auth_user = if let Some(h) = auth_header { - if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { - crate::auth::validate_bearer_token(&state.db, &token) - .await - .ok() - } else { - None - } - } else { - None - }; - let auth_did = auth_user.as_ref().map(|u| u.did.clone()); - let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone()); - let (status, body) = match proxy_to_appview_raw( - &state, - "app.bsky.actor.getProfiles", - raw_query.as_deref(), - auth_did.as_deref().unwrap_or(""), - auth_key_bytes.as_deref(), - ) - .await - { - Ok(r) => r, - Err(e) => return e, - }; - if !status.is_success() { - return (status, Json(body)).into_response(); - } - let mut output: GetProfilesOutput = match serde_json::from_value(body) { - Ok(p) => p, - Err(_) => { - return ( - StatusCode::BAD_GATEWAY, - Json(json!({"error": "UpstreamError", "message": "Invalid profiles response"})), - ) - .into_response(); - } - }; - if let Some(ref did) = auth_did { - for profile in &mut output.profiles { - if profile.did == *did { - if let Some(local_record) = get_local_profile_record(&state, did).await { - munge_profile_with_local(profile, &local_record); - } - break; - } - } - } - (StatusCode::OK, Json(output)).into_response() -} diff --git a/src/api/feed/actor_likes.rs b/src/api/feed/actor_likes.rs deleted file mode 100644 index e234384..0000000 --- a/src/api/feed/actor_likes.rs +++ /dev/null @@ -1,158 +0,0 @@ -use crate::api::read_after_write::{ - FeedOutput, FeedViewPost, LikeRecord, PostView, RecordDescript, extract_repo_rev, - format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview_via_registry, -}; -use crate::state::AppState; -use axum::{ - Json, - extract::{Query, State}, - http::StatusCode, - response::{IntoResponse, Response}, -}; -use serde::Deserialize; -use serde_json::Value; -use std::collections::HashMap; -use tracing::warn; - -#[derive(Deserialize)] -pub struct GetActorLikesParams { - pub actor: String, - pub limit: Option, - pub cursor: Option, -} - -fn insert_likes_into_feed(feed: &mut Vec, likes: &[RecordDescript]) { - for like in likes { - let like_time = &like.indexed_at.to_rfc3339(); - let idx = feed - .iter() - .position(|fi| &fi.post.indexed_at < like_time) - .unwrap_or(feed.len()); - let placeholder_post = PostView { - uri: like.record.subject.uri.clone(), - cid: like.record.subject.cid.clone(), - author: crate::api::read_after_write::AuthorView { - did: String::new(), - handle: String::new(), - display_name: None, - avatar: None, - extra: HashMap::new(), - }, - record: Value::Null, - indexed_at: like.indexed_at.to_rfc3339(), - embed: None, - reply_count: 0, - repost_count: 0, - like_count: 0, - quote_count: 0, - extra: HashMap::new(), - }; - feed.insert( - idx, - FeedViewPost { - post: placeholder_post, - reply: None, - reason: None, - feed_context: None, - extra: HashMap::new(), - }, - ); - } -} - -pub async fn get_actor_likes( - State(state): State, - headers: axum::http::HeaderMap, - Query(params): Query, -) -> Response { - let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); - let auth_user = if let Some(h) = auth_header { - if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { - crate::auth::validate_bearer_token(&state.db, &token) - .await - .ok() - } else { - None - } - } else { - None - }; - let auth_did = auth_user.as_ref().map(|u| u.did.clone()); - let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone()); - let mut query_params = HashMap::new(); - query_params.insert("actor".to_string(), params.actor.clone()); - if let Some(limit) = params.limit { - query_params.insert("limit".to_string(), limit.to_string()); - } - if let Some(cursor) = ¶ms.cursor { - query_params.insert("cursor".to_string(), cursor.clone()); - } - let proxy_result = match proxy_to_appview_via_registry( - &state, - "app.bsky.feed.getActorLikes", - &query_params, - auth_did.as_deref().unwrap_or(""), - auth_key_bytes.as_deref(), - ) - .await - { - Ok(r) => r, - Err(e) => return e, - }; - if !proxy_result.status.is_success() { - return proxy_result.into_response(); - } - let rev = match extract_repo_rev(&proxy_result.headers) { - Some(r) => r, - None => return proxy_result.into_response(), - }; - let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) { - Ok(f) => f, - Err(e) => { - warn!("Failed to parse actor likes response: {:?}", e); - return proxy_result.into_response(); - } - }; - let requester_did = match &auth_did { - Some(d) => d.clone(), - None => return (StatusCode::OK, Json(feed_output)).into_response(), - }; - let actor_did = if params.actor.starts_with("did:") { - params.actor.clone() - } else { - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); - let suffix = format!(".{}", hostname); - let short_handle = if params.actor.ends_with(&suffix) { - params.actor.strip_suffix(&suffix).unwrap_or(¶ms.actor) - } else { - ¶ms.actor - }; - match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", short_handle) - .fetch_optional(&state.db) - .await - { - Ok(Some(did)) => did, - Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(), - Err(e) => { - warn!("Database error resolving actor handle: {:?}", e); - return proxy_result.into_response(); - } - } - }; - if actor_did != requester_did { - return (StatusCode::OK, Json(feed_output)).into_response(); - } - let local_records = match get_records_since_rev(&state, &requester_did, &rev).await { - Ok(r) => r, - Err(e) => { - warn!("Failed to get local records: {}", e); - return proxy_result.into_response(); - } - }; - if local_records.likes.is_empty() { - return (StatusCode::OK, Json(feed_output)).into_response(); - } - insert_likes_into_feed(&mut feed_output.feed, &local_records.likes); - let lag = get_local_lag(&local_records); - format_munged_response(feed_output, lag) -} diff --git a/src/api/feed/author_feed.rs b/src/api/feed/author_feed.rs deleted file mode 100644 index 2b99a05..0000000 --- a/src/api/feed/author_feed.rs +++ /dev/null @@ -1,160 +0,0 @@ -use crate::api::read_after_write::{ - FeedOutput, FeedViewPost, ProfileRecord, RecordDescript, extract_repo_rev, format_local_post, - format_munged_response, get_local_lag, get_records_since_rev, insert_posts_into_feed, - proxy_to_appview_via_registry, -}; -use crate::state::AppState; -use axum::{ - Json, - extract::{Query, State}, - http::StatusCode, - response::{IntoResponse, Response}, -}; -use serde::Deserialize; -use std::collections::HashMap; -use tracing::warn; - -#[derive(Deserialize)] -pub struct GetAuthorFeedParams { - pub actor: String, - pub limit: Option, - pub cursor: Option, - pub filter: Option, - #[serde(rename = "includePins")] - pub include_pins: Option, -} - -fn update_author_profile_in_feed( - feed: &mut [FeedViewPost], - author_did: &str, - local_profile: &RecordDescript, -) { - for item in feed.iter_mut() { - if item.post.author.did == author_did - && let Some(ref display_name) = local_profile.record.display_name { - item.post.author.display_name = Some(display_name.clone()); - } - } -} - -pub async fn get_author_feed( - State(state): State, - headers: axum::http::HeaderMap, - Query(params): Query, -) -> Response { - let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); - let auth_user = if let Some(h) = auth_header { - if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { - crate::auth::validate_bearer_token(&state.db, &token) - .await - .ok() - } else { - None - } - } else { - None - }; - let auth_did = auth_user.as_ref().map(|u| u.did.clone()); - let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone()); - let mut query_params = HashMap::new(); - query_params.insert("actor".to_string(), params.actor.clone()); - if let Some(limit) = params.limit { - query_params.insert("limit".to_string(), limit.to_string()); - } - if let Some(cursor) = ¶ms.cursor { - query_params.insert("cursor".to_string(), cursor.clone()); - } - if let Some(filter) = ¶ms.filter { - query_params.insert("filter".to_string(), filter.clone()); - } - if let Some(include_pins) = params.include_pins { - query_params.insert("includePins".to_string(), include_pins.to_string()); - } - let proxy_result = match proxy_to_appview_via_registry( - &state, - "app.bsky.feed.getAuthorFeed", - &query_params, - auth_did.as_deref().unwrap_or(""), - auth_key_bytes.as_deref(), - ) - .await - { - Ok(r) => r, - Err(e) => return e, - }; - if !proxy_result.status.is_success() { - return proxy_result.into_response(); - } - let rev = match extract_repo_rev(&proxy_result.headers) { - Some(r) => r, - None => return proxy_result.into_response(), - }; - let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) { - Ok(f) => f, - Err(e) => { - warn!("Failed to parse author feed response: {:?}", e); - return proxy_result.into_response(); - } - }; - let requester_did = match &auth_did { - Some(d) => d.clone(), - None => return (StatusCode::OK, Json(feed_output)).into_response(), - }; - let actor_did = if params.actor.starts_with("did:") { - params.actor.clone() - } else { - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); - let suffix = format!(".{}", hostname); - let short_handle = if params.actor.ends_with(&suffix) { - params.actor.strip_suffix(&suffix).unwrap_or(¶ms.actor) - } else { - ¶ms.actor - }; - match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", short_handle) - .fetch_optional(&state.db) - .await - { - Ok(Some(did)) => did, - Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(), - Err(e) => { - warn!("Database error resolving actor handle: {:?}", e); - return proxy_result.into_response(); - } - } - }; - if actor_did != requester_did { - return (StatusCode::OK, Json(feed_output)).into_response(); - } - let local_records = match get_records_since_rev(&state, &requester_did, &rev).await { - Ok(r) => r, - Err(e) => { - warn!("Failed to get local records: {}", e); - return proxy_result.into_response(); - } - }; - if local_records.count == 0 { - return (StatusCode::OK, Json(feed_output)).into_response(); - } - let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) - .fetch_optional(&state.db) - .await - { - Ok(Some(h)) => h, - Ok(None) => requester_did.clone(), - Err(e) => { - warn!("Database error fetching handle: {:?}", e); - requester_did.clone() - } - }; - if let Some(ref local_profile) = local_records.profile { - update_author_profile_in_feed(&mut feed_output.feed, &requester_did, local_profile); - } - let local_posts: Vec<_> = local_records - .posts - .iter() - .map(|p| format_local_post(p, &requester_did, &handle, local_records.profile.as_ref())) - .collect(); - insert_posts_into_feed(&mut feed_output.feed, local_posts); - let lag = get_local_lag(&local_records); - format_munged_response(feed_output, lag) -} diff --git a/src/api/feed/custom_feed.rs b/src/api/feed/custom_feed.rs deleted file mode 100644 index 6a02137..0000000 --- a/src/api/feed/custom_feed.rs +++ /dev/null @@ -1,131 +0,0 @@ -use crate::api::ApiError; -use crate::api::proxy_client::{ - MAX_RESPONSE_SIZE, is_ssrf_safe, proxy_client, validate_at_uri, validate_limit, -}; -use crate::state::AppState; -use axum::{ - extract::{Query, State}, - http::StatusCode, - response::{IntoResponse, Response}, -}; -use serde::Deserialize; -use std::collections::HashMap; -use tracing::{error, info}; - -#[derive(Deserialize)] -pub struct GetFeedParams { - pub feed: String, - pub limit: Option, - pub cursor: Option, -} - -pub async fn get_feed( - State(state): State, - headers: axum::http::HeaderMap, - Query(params): Query, -) -> Response { - let token = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => return ApiError::AuthenticationRequired.into_response(), - }; - let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { - Ok(user) => user, - Err(e) => return ApiError::from(e).into_response(), - }; - if let Err(e) = validate_at_uri(¶ms.feed) { - return ApiError::InvalidRequest(format!("Invalid feed URI: {}", e)).into_response(); - } - let resolved = match state.appview_registry.get_appview_for_method("app.bsky.feed.getFeed").await { - Some(r) => r, - None => { - return ApiError::UpstreamUnavailable("No upstream AppView configured for app.bsky.feed.getFeed".to_string()) - .into_response(); - } - }; - if let Err(e) = is_ssrf_safe(&resolved.url) { - error!("SSRF check failed for appview URL: {}", e); - return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)) - .into_response(); - } - let limit = validate_limit(params.limit, 50, 100); - let mut query_params = HashMap::new(); - query_params.insert("feed".to_string(), params.feed.clone()); - query_params.insert("limit".to_string(), limit.to_string()); - if let Some(cursor) = ¶ms.cursor { - query_params.insert("cursor".to_string(), cursor.clone()); - } - let target_url = format!("{}/xrpc/app.bsky.feed.getFeed", resolved.url); - info!(target = %target_url, feed = %params.feed, "Proxying getFeed request"); - let client = proxy_client(); - let mut request_builder = client.get(&target_url).query(&query_params); - if let Some(key_bytes) = auth_user.key_bytes.as_ref() { - match crate::auth::create_service_token( - &auth_user.did, - &resolved.did, - "app.bsky.feed.getFeed", - key_bytes, - ) { - Ok(service_token) => { - request_builder = - request_builder.header("Authorization", format!("Bearer {}", service_token)); - } - Err(e) => { - error!(error = ?e, "Failed to create service token for getFeed"); - return ApiError::InternalError.into_response(); - } - } - } - match request_builder.send().await { - Ok(resp) => { - let status = - StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); - let content_length = resp.content_length().unwrap_or(0); - if content_length > MAX_RESPONSE_SIZE { - error!( - content_length, - max = MAX_RESPONSE_SIZE, - "getFeed response too large" - ); - return ApiError::UpstreamFailure.into_response(); - } - let resp_headers = resp.headers().clone(); - let body = match resp.bytes().await { - Ok(b) => { - if b.len() as u64 > MAX_RESPONSE_SIZE { - error!(len = b.len(), "getFeed response body exceeded limit"); - return ApiError::UpstreamFailure.into_response(); - } - b - } - Err(e) => { - error!(error = ?e, "Error reading getFeed response"); - return ApiError::UpstreamFailure.into_response(); - } - }; - let mut response_builder = axum::response::Response::builder().status(status); - if let Some(ct) = resp_headers.get("content-type") { - response_builder = response_builder.header("content-type", ct); - } - match response_builder.body(axum::body::Body::from(body)) { - Ok(r) => r, - Err(e) => { - error!(error = ?e, "Error building getFeed response"); - ApiError::UpstreamFailure.into_response() - } - } - } - Err(e) => { - error!(error = ?e, "Error proxying getFeed"); - if e.is_timeout() { - ApiError::UpstreamTimeout.into_response() - } else if e.is_connect() { - ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string()) - .into_response() - } else { - ApiError::UpstreamFailure.into_response() - } - } - } -} diff --git a/src/api/feed/mod.rs b/src/api/feed/mod.rs deleted file mode 100644 index ea4c0a2..0000000 --- a/src/api/feed/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod actor_likes; -mod author_feed; -mod custom_feed; -mod post_thread; -mod timeline; - -pub use actor_likes::get_actor_likes; -pub use author_feed::get_author_feed; -pub use custom_feed::get_feed; -pub use post_thread::get_post_thread; -pub use timeline::get_timeline; diff --git a/src/api/feed/post_thread.rs b/src/api/feed/post_thread.rs deleted file mode 100644 index a680d6a..0000000 --- a/src/api/feed/post_thread.rs +++ /dev/null @@ -1,315 +0,0 @@ -use crate::api::read_after_write::{ - PostRecord, PostView, RecordDescript, extract_repo_rev, format_local_post, - format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview_via_registry, -}; -use crate::state::AppState; -use axum::{ - Json, - extract::{Query, State}, - http::StatusCode, - response::{IntoResponse, Response}, -}; -use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; -use std::collections::HashMap; -use tracing::warn; - -#[derive(Deserialize)] -pub struct GetPostThreadParams { - pub uri: String, - pub depth: Option, - #[serde(rename = "parentHeight")] - pub parent_height: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ThreadViewPost { - #[serde(rename = "$type")] - pub thread_type: Option, - pub post: PostView, - #[serde(skip_serializing_if = "Option::is_none")] - pub parent: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub replies: Option>, - #[serde(flatten)] - pub extra: HashMap, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] -pub enum ThreadNode { - Post(Box), - NotFound(ThreadNotFound), - Blocked(ThreadBlocked), -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ThreadNotFound { - #[serde(rename = "$type")] - pub thread_type: String, - pub uri: String, - pub not_found: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ThreadBlocked { - #[serde(rename = "$type")] - pub thread_type: String, - pub uri: String, - pub blocked: bool, - pub author: Value, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PostThreadOutput { - pub thread: ThreadNode, - #[serde(skip_serializing_if = "Option::is_none")] - pub threadgate: Option, -} - -const MAX_THREAD_DEPTH: usize = 10; - -fn add_replies_to_thread( - thread: &mut ThreadViewPost, - local_posts: &[RecordDescript], - author_did: &str, - author_handle: &str, - depth: usize, -) { - if depth >= MAX_THREAD_DEPTH { - return; - } - let thread_uri = &thread.post.uri; - let replies: Vec<_> = local_posts - .iter() - .filter(|p| { - p.record - .reply - .as_ref() - .and_then(|r| r.get("parent")) - .and_then(|parent| parent.get("uri")) - .and_then(|u| u.as_str()) - == Some(thread_uri) - }) - .map(|p| { - let post_view = format_local_post(p, author_did, author_handle, None); - ThreadNode::Post(Box::new(ThreadViewPost { - thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()), - post: post_view, - parent: None, - replies: None, - extra: HashMap::new(), - })) - }) - .collect(); - if !replies.is_empty() { - match &mut thread.replies { - Some(existing) => existing.extend(replies), - None => thread.replies = Some(replies), - } - } - if let Some(ref mut existing_replies) = thread.replies { - for reply in existing_replies.iter_mut() { - if let ThreadNode::Post(reply_thread) = reply { - add_replies_to_thread( - reply_thread, - local_posts, - author_did, - author_handle, - depth + 1, - ); - } - } - } -} - -pub async fn get_post_thread( - State(state): State, - headers: axum::http::HeaderMap, - Query(params): Query, -) -> Response { - let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); - let auth_user = if let Some(h) = auth_header { - if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { - crate::auth::validate_bearer_token(&state.db, &token) - .await - .ok() - } else { - None - } - } else { - None - }; - let auth_did = auth_user.as_ref().map(|u| u.did.clone()); - let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone()); - let mut query_params = HashMap::new(); - query_params.insert("uri".to_string(), params.uri.clone()); - if let Some(depth) = params.depth { - query_params.insert("depth".to_string(), depth.to_string()); - } - if let Some(parent_height) = params.parent_height { - query_params.insert("parentHeight".to_string(), parent_height.to_string()); - } - let proxy_result = match proxy_to_appview_via_registry( - &state, - "app.bsky.feed.getPostThread", - &query_params, - auth_did.as_deref().unwrap_or(""), - auth_key_bytes.as_deref(), - ) - .await - { - Ok(r) => r, - Err(e) => return e, - }; - if proxy_result.status == StatusCode::NOT_FOUND { - return handle_not_found(&state, ¶ms.uri, auth_did, &proxy_result.headers).await; - } - if !proxy_result.status.is_success() { - return proxy_result.into_response(); - } - let rev = match extract_repo_rev(&proxy_result.headers) { - Some(r) => r, - None => return proxy_result.into_response(), - }; - let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) { - Ok(t) => t, - Err(e) => { - warn!("Failed to parse post thread response: {:?}", e); - return proxy_result.into_response(); - } - }; - let requester_did = match auth_did { - Some(d) => d, - None => return (StatusCode::OK, Json(thread_output)).into_response(), - }; - let local_records = match get_records_since_rev(&state, &requester_did, &rev).await { - Ok(r) => r, - Err(e) => { - warn!("Failed to get local records: {}", e); - return proxy_result.into_response(); - } - }; - if local_records.posts.is_empty() { - return (StatusCode::OK, Json(thread_output)).into_response(); - } - let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) - .fetch_optional(&state.db) - .await - { - Ok(Some(h)) => h, - Ok(None) => requester_did.clone(), - Err(e) => { - warn!("Database error fetching handle: {:?}", e); - requester_did.clone() - } - }; - if let ThreadNode::Post(ref mut thread_post) = thread_output.thread { - add_replies_to_thread( - thread_post, - &local_records.posts, - &requester_did, - &handle, - 0, - ); - } - let lag = get_local_lag(&local_records); - format_munged_response(thread_output, lag) -} - -async fn handle_not_found( - state: &AppState, - uri: &str, - auth_did: Option, - headers: &axum::http::HeaderMap, -) -> Response { - let rev = match extract_repo_rev(headers) { - Some(r) => r, - None => { - return ( - StatusCode::NOT_FOUND, - Json(json!({"error": "NotFound", "message": "Post not found"})), - ) - .into_response(); - } - }; - let requester_did = match auth_did { - Some(d) => d, - None => { - return ( - StatusCode::NOT_FOUND, - Json(json!({"error": "NotFound", "message": "Post not found"})), - ) - .into_response(); - } - }; - let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect(); - if uri_parts.len() != 3 { - return ( - StatusCode::NOT_FOUND, - Json(json!({"error": "NotFound", "message": "Post not found"})), - ) - .into_response(); - } - let post_did = uri_parts[0]; - if post_did != requester_did { - return ( - StatusCode::NOT_FOUND, - Json(json!({"error": "NotFound", "message": "Post not found"})), - ) - .into_response(); - } - let local_records = match get_records_since_rev(state, &requester_did, &rev).await { - Ok(r) => r, - Err(_) => { - return ( - StatusCode::NOT_FOUND, - Json(json!({"error": "NotFound", "message": "Post not found"})), - ) - .into_response(); - } - }; - let local_post = local_records.posts.iter().find(|p| p.uri == uri); - let local_post = match local_post { - Some(p) => p, - None => { - return ( - StatusCode::NOT_FOUND, - Json(json!({"error": "NotFound", "message": "Post not found"})), - ) - .into_response(); - } - }; - let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) - .fetch_optional(&state.db) - .await - { - Ok(Some(h)) => h, - Ok(None) => requester_did.clone(), - Err(e) => { - warn!("Database error fetching handle: {:?}", e); - requester_did.clone() - } - }; - let post_view = format_local_post( - local_post, - &requester_did, - &handle, - local_records.profile.as_ref(), - ); - let thread = PostThreadOutput { - thread: ThreadNode::Post(Box::new(ThreadViewPost { - thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()), - post: post_view, - parent: None, - replies: None, - extra: HashMap::new(), - })), - threadgate: None, - }; - let lag = get_local_lag(&local_records); - format_munged_response(thread, lag) -} diff --git a/src/api/feed/timeline.rs b/src/api/feed/timeline.rs deleted file mode 100644 index 0acc1b0..0000000 --- a/src/api/feed/timeline.rs +++ /dev/null @@ -1,275 +0,0 @@ -use crate::api::read_after_write::{ - FeedOutput, FeedViewPost, PostView, extract_repo_rev, format_local_post, - format_munged_response, get_local_lag, get_records_since_rev, insert_posts_into_feed, - proxy_to_appview_via_registry, -}; -use crate::state::AppState; -use axum::{ - Json, - extract::{Query, State}, - http::StatusCode, - response::{IntoResponse, Response}, -}; -use jacquard_repo::storage::BlockStore; -use serde::Deserialize; -use serde_json::{Value, json}; -use std::collections::HashMap; -use tracing::warn; - -#[derive(Deserialize)] -pub struct GetTimelineParams { - pub algorithm: Option, - pub limit: Option, - pub cursor: Option, -} - -pub async fn get_timeline( - State(state): State, - headers: axum::http::HeaderMap, - Query(params): Query, -) -> Response { - let token = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } - }; - let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { - Ok(user) => user, - Err(_) => { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationFailed"})), - ) - .into_response(); - } - }; - if state.appview_registry.get_appview_for_method("app.bsky.feed.getTimeline").await.is_some() { - return get_timeline_with_appview( - &state, - ¶ms, - &auth_user.did, - auth_user.key_bytes.as_deref(), - ) - .await; - } - get_timeline_local_only(&state, &auth_user.did).await -} - -async fn get_timeline_with_appview( - state: &AppState, - params: &GetTimelineParams, - auth_did: &str, - auth_key_bytes: Option<&[u8]>, -) -> Response { - let mut query_params = HashMap::new(); - if let Some(algo) = ¶ms.algorithm { - query_params.insert("algorithm".to_string(), algo.clone()); - } - if let Some(limit) = params.limit { - query_params.insert("limit".to_string(), limit.to_string()); - } - if let Some(cursor) = ¶ms.cursor { - query_params.insert("cursor".to_string(), cursor.clone()); - } - let proxy_result = match proxy_to_appview_via_registry( - state, - "app.bsky.feed.getTimeline", - &query_params, - auth_did, - auth_key_bytes, - ) - .await - { - Ok(r) => r, - Err(e) => return e, - }; - if !proxy_result.status.is_success() { - return proxy_result.into_response(); - } - let rev = extract_repo_rev(&proxy_result.headers); - if rev.is_none() { - return proxy_result.into_response(); - } - let rev = rev.unwrap(); - let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) { - Ok(f) => f, - Err(e) => { - warn!("Failed to parse timeline response: {:?}", e); - return proxy_result.into_response(); - } - }; - let local_records = match get_records_since_rev(state, auth_did, &rev).await { - Ok(r) => r, - Err(e) => { - warn!("Failed to get local records: {}", e); - return proxy_result.into_response(); - } - }; - if local_records.count == 0 { - return proxy_result.into_response(); - } - let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", auth_did) - .fetch_optional(&state.db) - .await - { - Ok(Some(h)) => h, - Ok(None) => auth_did.to_string(), - Err(e) => { - warn!("Database error fetching handle: {:?}", e); - auth_did.to_string() - } - }; - let local_posts: Vec<_> = local_records - .posts - .iter() - .map(|p| format_local_post(p, auth_did, &handle, local_records.profile.as_ref())) - .collect(); - insert_posts_into_feed(&mut feed_output.feed, local_posts); - let lag = get_local_lag(&local_records); - format_munged_response(feed_output, lag) -} - -async fn get_timeline_local_only(state: &AppState, auth_did: &str) -> Response { - let user_id: uuid::Uuid = - match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", auth_did) - .fetch_optional(&state.db) - .await - { - Ok(Some(id)) => id, - Ok(None) => { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": "InternalError", "message": "User not found"})), - ) - .into_response(); - } - Err(e) => { - warn!("Database error fetching user: {:?}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": "InternalError", "message": "Database error"})), - ) - .into_response(); - } - }; - let follows_query = sqlx::query!( - "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000", - user_id - ) - .fetch_all(&state.db) - .await; - let follow_cids: Vec = match follows_query { - Ok(rows) => rows.iter().map(|r| r.record_cid.clone()).collect(), - Err(_) => { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": "InternalError"})), - ) - .into_response(); - } - }; - let mut followed_dids: Vec = Vec::new(); - for cid_str in follow_cids { - let cid = match cid_str.parse::() { - Ok(c) => c, - Err(_) => continue, - }; - let block_bytes = match state.block_store.get(&cid).await { - Ok(Some(b)) => b, - _ => continue, - }; - let record: Value = match serde_ipld_dagcbor::from_slice(&block_bytes) { - Ok(v) => v, - Err(_) => continue, - }; - if let Some(subject) = record.get("subject").and_then(|s| s.as_str()) { - followed_dids.push(subject.to_string()); - } - } - if followed_dids.is_empty() { - return ( - StatusCode::OK, - Json(FeedOutput { - feed: vec![], - cursor: None, - }), - ) - .into_response(); - } - let posts_result = sqlx::query!( - "SELECT r.record_cid, r.rkey, r.created_at, u.did, u.handle - FROM records r - JOIN repos rp ON r.repo_id = rp.user_id - JOIN users u ON rp.user_id = u.id - WHERE u.did = ANY($1) AND r.collection = 'app.bsky.feed.post' - ORDER BY r.created_at DESC - LIMIT 50", - &followed_dids - ) - .fetch_all(&state.db) - .await; - let posts = match posts_result { - Ok(rows) => rows, - Err(_) => { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": "InternalError"})), - ) - .into_response(); - } - }; - let mut feed: Vec = Vec::new(); - for row in posts { - let record_cid: String = row.record_cid; - let rkey: String = row.rkey; - let created_at: chrono::DateTime = row.created_at; - let author_did: String = row.did; - let author_handle: String = row.handle; - let cid = match record_cid.parse::() { - Ok(c) => c, - Err(_) => continue, - }; - let block_bytes = match state.block_store.get(&cid).await { - Ok(Some(b)) => b, - _ => continue, - }; - let record: Value = match serde_ipld_dagcbor::from_slice(&block_bytes) { - Ok(v) => v, - Err(_) => continue, - }; - let uri = format!("at://{}/app.bsky.feed.post/{}", author_did, rkey); - feed.push(FeedViewPost { - post: PostView { - uri, - cid: record_cid, - author: crate::api::read_after_write::AuthorView { - did: author_did, - handle: author_handle, - display_name: None, - avatar: None, - extra: HashMap::new(), - }, - record, - indexed_at: created_at.to_rfc3339(), - embed: None, - reply_count: 0, - repost_count: 0, - like_count: 0, - quote_count: 0, - extra: HashMap::new(), - }, - reply: None, - reason: None, - feed_context: None, - extra: HashMap::new(), - }); - } - (StatusCode::OK, Json(FeedOutput { feed, cursor: None })).into_response() -} diff --git a/src/api/mod.rs b/src/api/mod.rs index 190d7a6..9d8e756 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,14 +1,11 @@ pub mod actor; pub mod admin; pub mod error; -pub mod feed; pub mod identity; pub mod moderation; -pub mod notification; pub mod notification_prefs; pub mod proxy; pub mod proxy_client; -pub mod read_after_write; pub mod repo; pub mod server; pub mod temp; diff --git a/src/api/notification/mod.rs b/src/api/notification/mod.rs deleted file mode 100644 index 0d60b0a..0000000 --- a/src/api/notification/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod register_push; - -pub use register_push::register_push; diff --git a/src/api/notification/register_push.rs b/src/api/notification/register_push.rs deleted file mode 100644 index 1c8b3d4..0000000 --- a/src/api/notification/register_push.rs +++ /dev/null @@ -1,153 +0,0 @@ -use crate::api::ApiError; -use crate::api::proxy_client::{is_ssrf_safe, proxy_client, validate_did}; -use crate::state::AppState; -use axum::{ - Json, - extract::State, - http::{HeaderMap, StatusCode}, - response::{IntoResponse, Response}, -}; -use serde::Deserialize; -use serde_json::json; -use tracing::{error, info}; - -#[derive(Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct RegisterPushInput { - pub service_did: String, - pub token: String, - pub platform: String, - pub app_id: String, -} - -const VALID_PLATFORMS: &[&str] = &["ios", "android", "web"]; - -pub async fn register_push( - State(state): State, - headers: HeaderMap, - Json(input): Json, -) -> Response { - let token = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => return ApiError::AuthenticationRequired.into_response(), - }; - let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { - Ok(user) => user, - Err(e) => return ApiError::from(e).into_response(), - }; - if let Err(e) = validate_did(&input.service_did) { - return ApiError::InvalidRequest(format!("Invalid serviceDid: {}", e)).into_response(); - } - if input.token.is_empty() || input.token.len() > 4096 { - return ApiError::InvalidRequest("Invalid push token".to_string()).into_response(); - } - if !VALID_PLATFORMS.contains(&input.platform.as_str()) { - return ApiError::InvalidRequest(format!( - "Invalid platform. Must be one of: {}", - VALID_PLATFORMS.join(", ") - )) - .into_response(); - } - if input.app_id.is_empty() || input.app_id.len() > 256 { - return ApiError::InvalidRequest("Invalid appId".to_string()).into_response(); - } - let resolved = match state.appview_registry.get_appview_for_method("app.bsky.notification.registerPush").await { - Some(r) => r, - None => { - return ApiError::UpstreamUnavailable("No upstream AppView configured for app.bsky.notification.registerPush".to_string()) - .into_response(); - } - }; - if let Err(e) = is_ssrf_safe(&resolved.url) { - error!("SSRF check failed for appview URL: {}", e); - return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)) - .into_response(); - } - let key_row = match sqlx::query!( - "SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1", - auth_user.did - ) - .fetch_optional(&state.db) - .await - { - Ok(Some(row)) => row, - Ok(None) => { - error!(did = %auth_user.did, "No signing key found for user"); - return ApiError::InternalError.into_response(); - } - Err(e) => { - error!(error = ?e, "Database error fetching signing key"); - return ApiError::DatabaseError.into_response(); - } - }; - let decrypted_key = - match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) { - Ok(k) => k, - Err(e) => { - error!(error = ?e, "Failed to decrypt signing key"); - return ApiError::InternalError.into_response(); - } - }; - let service_token = match crate::auth::create_service_token( - &auth_user.did, - &input.service_did, - "app.bsky.notification.registerPush", - &decrypted_key, - ) { - Ok(t) => t, - Err(e) => { - error!(error = ?e, "Failed to create service token"); - return ApiError::InternalError.into_response(); - } - }; - let target_url = format!("{}/xrpc/app.bsky.notification.registerPush", resolved.url); - info!( - target = %target_url, - service_did = %input.service_did, - platform = %input.platform, - "Proxying registerPush request" - ); - let client = proxy_client(); - let request_body = json!({ - "serviceDid": input.service_did, - "token": input.token, - "platform": input.platform, - "appId": input.app_id - }); - match client - .post(&target_url) - .header("Authorization", format!("Bearer {}", service_token)) - .header("Content-Type", "application/json") - .json(&request_body) - .send() - .await - { - Ok(resp) => { - let status = - StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); - if status.is_success() { - StatusCode::OK.into_response() - } else { - let body = resp.bytes().await.unwrap_or_default(); - error!( - status = %status, - "registerPush upstream error" - ); - ApiError::from_upstream_response(status.as_u16(), &body).into_response() - } - } - Err(e) => { - error!(error = ?e, "Error proxying registerPush"); - if e.is_timeout() { - ApiError::UpstreamTimeout.into_response() - } else if e.is_connect() { - ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string()) - .into_response() - } else { - ApiError::UpstreamFailure.into_response() - } - } - } -} diff --git a/src/api/proxy.rs b/src/api/proxy.rs index 8b5a798..31725d6 100644 --- a/src/api/proxy.rs +++ b/src/api/proxy.rs @@ -1,11 +1,13 @@ use crate::api::proxy_client::proxy_client; use crate::state::AppState; use axum::{ + Json, body::Bytes, extract::{Path, RawQuery, State}, http::{HeaderMap, Method, StatusCode}, response::{IntoResponse, Response}, }; +use serde_json::json; use tracing::{error, info, warn}; pub async fn proxy_handler( @@ -16,65 +18,80 @@ pub async fn proxy_handler( RawQuery(query): RawQuery, body: Bytes, ) -> Response { - let proxy_header = headers + let proxy_header = match headers .get("atproto-proxy") .and_then(|h| h.to_str().ok()) - .map(|s| s.to_string()); - let (appview_url, service_aud) = match &proxy_header { - Some(did_str) => { - let did_without_fragment = did_str.split('#').next().unwrap_or(did_str).to_string(); - match state.appview_registry.resolve_appview_did(&did_without_fragment).await { - Some(resolved) => (resolved.url, Some(resolved.did)), - None => { - error!(did = %did_str, "Could not resolve service DID"); - return (StatusCode::BAD_GATEWAY, "Could not resolve service DID") - .into_response(); - } - } - } + { + Some(h) => h.to_string(), None => { - match state.appview_registry.get_appview_for_method(&method).await { - Some(resolved) => (resolved.url, Some(resolved.did)), - None => { - return (StatusCode::BAD_GATEWAY, "No upstream AppView configured for this method") - .into_response(); - } - } + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": "Missing required atproto-proxy header" + })), + ) + .into_response(); } }; + + let did = proxy_header.split('#').next().unwrap_or(&proxy_header); + let resolved = match state.did_resolver.resolve_did(did).await { + Some(r) => r, + None => { + error!(did = %did, "Could not resolve service DID"); + return ( + StatusCode::BAD_GATEWAY, + Json(json!({ + "error": "UpstreamFailure", + "message": "Could not resolve service DID" + })), + ) + .into_response(); + } + }; + let target_url = match &query { - Some(q) => format!("{}/xrpc/{}?{}", appview_url, method, q), - None => format!("{}/xrpc/{}", appview_url, method), + Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q), + None => format!("{}/xrpc/{}", resolved.url, method), }; info!("Proxying {} request to {}", method_verb, target_url); + let client = proxy_client(); let mut request_builder = client.request(method_verb, &target_url); + let mut auth_header_val = headers.get("Authorization").cloned(); - if let Some(aud) = &service_aud { - if let Some(token) = crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - match crate::auth::validate_bearer_token(&state.db, &token).await { - Ok(auth_user) => { - if let Some(key_bytes) = auth_user.key_bytes { - match crate::auth::create_service_token(&auth_user.did, aud, &method, &key_bytes) { - Ok(new_token) => { - if let Ok(val) = axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) { - auth_header_val = Some(val); - } - } - Err(e) => { - warn!("Failed to create service token: {:?}", e); + if let Some(token) = crate::auth::extract_bearer_token_from_header( + headers.get("Authorization").and_then(|h| h.to_str().ok()), + ) { + match crate::auth::validate_bearer_token(&state.db, &token).await { + Ok(auth_user) => { + if let Some(key_bytes) = auth_user.key_bytes { + match crate::auth::create_service_token( + &auth_user.did, + &resolved.did, + &method, + &key_bytes, + ) { + Ok(new_token) => { + if let Ok(val) = + axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) + { + auth_header_val = Some(val); } } + Err(e) => { + warn!("Failed to create service token: {:?}", e); + } } } - Err(e) => { - warn!("Token validation failed: {:?}", e); - } + } + Err(e) => { + warn!("Token validation failed: {:?}", e); } } } + if let Some(val) = auth_header_val { request_builder = request_builder.header("Authorization", val); } @@ -86,6 +103,7 @@ pub async fn proxy_handler( if !body.is_empty() { request_builder = request_builder.body(body); } + match request_builder.send().await { Ok(resp) => { let status = resp.status(); diff --git a/src/api/read_after_write.rs b/src/api/read_after_write.rs deleted file mode 100644 index 4970a7d..0000000 --- a/src/api/read_after_write.rs +++ /dev/null @@ -1,456 +0,0 @@ -use crate::api::ApiError; -use crate::api::proxy_client::{ - MAX_RESPONSE_SIZE, RESPONSE_HEADERS_TO_FORWARD, is_ssrf_safe, proxy_client, -}; -use crate::state::AppState; -use axum::{ - Json, - http::{HeaderMap, HeaderValue, StatusCode}, - response::{IntoResponse, Response}, -}; -use bytes::Bytes; -use chrono::{DateTime, Utc}; -use cid::Cid; -use jacquard_repo::storage::BlockStore; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::collections::HashMap; -use tracing::{error, info, warn}; -use uuid::Uuid; - -pub const REPO_REV_HEADER: &str = "atproto-repo-rev"; -pub const UPSTREAM_LAG_HEADER: &str = "atproto-upstream-lag"; - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct PostRecord { - #[serde(rename = "$type")] - pub record_type: Option, - pub text: String, - pub created_at: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub reply: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub embed: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub langs: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub labels: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tags: Option>, - #[serde(flatten)] - pub extra: HashMap, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ProfileRecord { - #[serde(rename = "$type")] - pub record_type: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub display_name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub avatar: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub banner: Option, - #[serde(flatten)] - pub extra: HashMap, -} - -#[derive(Debug, Clone)] -pub struct RecordDescript { - pub uri: String, - pub cid: String, - pub indexed_at: DateTime, - pub record: T, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct LikeRecord { - #[serde(rename = "$type")] - pub record_type: Option, - pub subject: LikeSubject, - pub created_at: String, - #[serde(flatten)] - pub extra: HashMap, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct LikeSubject { - pub uri: String, - pub cid: String, -} - -#[derive(Debug, Default)] -pub struct LocalRecords { - pub count: usize, - pub profile: Option>, - pub posts: Vec>, - pub likes: Vec>, -} - -pub async fn get_records_since_rev( - state: &AppState, - did: &str, - rev: &str, -) -> Result { - let mut result = LocalRecords::default(); - let user_id: Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) - .fetch_optional(&state.db) - .await - .map_err(|e| format!("DB error: {}", e))? - .ok_or_else(|| "User not found".to_string())?; - let rows = sqlx::query!( - r#" - SELECT record_cid, collection, rkey, created_at, repo_rev - FROM records - WHERE repo_id = $1 AND repo_rev > $2 - ORDER BY repo_rev ASC - LIMIT 10 - "#, - user_id, - rev - ) - .fetch_all(&state.db) - .await - .map_err(|e| format!("DB error fetching records: {}", e))?; - if rows.is_empty() { - return Ok(result); - } - let sanity_check = sqlx::query_scalar!( - "SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1", - user_id, - rev - ) - .fetch_optional(&state.db) - .await - .map_err(|e| format!("DB error sanity check: {}", e))?; - if sanity_check.is_none() { - warn!("Sanity check failed: no records found before rev {}", rev); - return Ok(result); - } - struct RowData { - cid_str: String, - collection: String, - rkey: String, - created_at: DateTime, - } - let mut row_data: Vec = Vec::with_capacity(rows.len()); - let mut cids: Vec = Vec::with_capacity(rows.len()); - for row in &rows { - if let Ok(cid) = row.record_cid.parse::() { - cids.push(cid); - row_data.push(RowData { - cid_str: row.record_cid.clone(), - collection: row.collection.clone(), - rkey: row.rkey.clone(), - created_at: row.created_at, - }); - } - } - let blocks: Vec> = state - .block_store - .get_many(&cids) - .await - .map_err(|e| format!("Error fetching blocks: {}", e))?; - for (data, block_opt) in row_data.into_iter().zip(blocks.into_iter()) { - let block_bytes = match block_opt { - Some(b) => b, - None => continue, - }; - result.count += 1; - let uri = format!("at://{}/{}/{}", did, data.collection, data.rkey); - if data.collection == "app.bsky.actor.profile" && data.rkey == "self" { - if let Ok(record) = serde_ipld_dagcbor::from_slice::(&block_bytes) { - result.profile = Some(RecordDescript { - uri, - cid: data.cid_str, - indexed_at: data.created_at, - record, - }); - } - } else if data.collection == "app.bsky.feed.post" { - if let Ok(record) = serde_ipld_dagcbor::from_slice::(&block_bytes) { - result.posts.push(RecordDescript { - uri, - cid: data.cid_str, - indexed_at: data.created_at, - record, - }); - } - } else if data.collection == "app.bsky.feed.like" - && let Ok(record) = serde_ipld_dagcbor::from_slice::(&block_bytes) { - result.likes.push(RecordDescript { - uri, - cid: data.cid_str, - indexed_at: data.created_at, - record, - }); - } - } - Ok(result) -} - -pub fn get_local_lag(local: &LocalRecords) -> Option { - let mut oldest: Option> = local.profile.as_ref().map(|p| p.indexed_at); - for post in &local.posts { - match oldest { - None => oldest = Some(post.indexed_at), - Some(o) if post.indexed_at < o => oldest = Some(post.indexed_at), - _ => {} - } - } - for like in &local.likes { - match oldest { - None => oldest = Some(like.indexed_at), - Some(o) if like.indexed_at < o => oldest = Some(like.indexed_at), - _ => {} - } - } - oldest.map(|o| (Utc::now() - o).num_milliseconds()) -} - -pub fn extract_repo_rev(headers: &HeaderMap) -> Option { - headers - .get(REPO_REV_HEADER) - .and_then(|h| h.to_str().ok()) - .map(|s| s.to_string()) -} - -#[derive(Debug)] -pub struct ProxyResponse { - pub status: StatusCode, - pub headers: HeaderMap, - pub body: bytes::Bytes, -} - -impl ProxyResponse { - pub fn into_response(self) -> Response { - let mut response = Response::builder().status(self.status); - for (key, value) in self.headers.iter() { - response = response.header(key, value); - } - response.body(axum::body::Body::from(self.body)).unwrap() - } -} - -pub async fn proxy_to_appview_via_registry( - state: &AppState, - method: &str, - params: &HashMap, - auth_did: &str, - auth_key_bytes: Option<&[u8]>, -) -> Result { - let resolved = state.appview_registry.get_appview_for_method(method).await.ok_or_else(|| { - ApiError::UpstreamUnavailable(format!("No AppView configured for method: {}", method)).into_response() - })?; - proxy_to_appview_with_url(method, params, auth_did, auth_key_bytes, &resolved.url, &resolved.did).await -} - -pub async fn proxy_to_appview_with_url( - method: &str, - params: &HashMap, - auth_did: &str, - auth_key_bytes: Option<&[u8]>, - appview_url: &str, - appview_did: &str, -) -> Result { - if let Err(e) = is_ssrf_safe(appview_url) { - error!("SSRF check failed for appview URL: {}", e); - return Err( - ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)).into_response(), - ); - } - let target_url = format!("{}/xrpc/{}", appview_url, method); - info!(target = %target_url, "Proxying request to appview"); - let client = proxy_client(); - let mut request_builder = client.get(&target_url).query(params); - if let Some(key_bytes) = auth_key_bytes { - match crate::auth::create_service_token(auth_did, appview_did, method, key_bytes) { - Ok(service_token) => { - request_builder = - request_builder.header("Authorization", format!("Bearer {}", service_token)); - } - Err(e) => { - error!(error = ?e, "Failed to create service token"); - return Err(ApiError::InternalError.into_response()); - } - } - } - match request_builder.send().await { - Ok(resp) => { - let status = - StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); - let headers: HeaderMap = resp - .headers() - .iter() - .filter(|(k, _)| { - RESPONSE_HEADERS_TO_FORWARD - .iter() - .any(|h| k.as_str().eq_ignore_ascii_case(h)) - }) - .filter_map(|(k, v)| { - let name = axum::http::HeaderName::try_from(k.as_str()).ok()?; - let value = HeaderValue::from_bytes(v.as_bytes()).ok()?; - Some((name, value)) - }) - .collect(); - let content_length = resp.content_length().unwrap_or(0); - if content_length > MAX_RESPONSE_SIZE { - error!( - content_length, - max = MAX_RESPONSE_SIZE, - "Upstream response too large" - ); - return Err(ApiError::UpstreamFailure.into_response()); - } - let body = resp.bytes().await.map_err(|e| { - error!(error = ?e, "Error reading proxy response body"); - ApiError::UpstreamFailure.into_response() - })?; - if body.len() as u64 > MAX_RESPONSE_SIZE { - error!( - len = body.len(), - max = MAX_RESPONSE_SIZE, - "Upstream response body exceeded size limit" - ); - return Err(ApiError::UpstreamFailure.into_response()); - } - Ok(ProxyResponse { - status, - headers, - body, - }) - } - Err(e) => { - error!(error = ?e, "Error sending proxy request"); - if e.is_timeout() { - Err(ApiError::UpstreamTimeout.into_response()) - } else if e.is_connect() { - Err( - ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string()) - .into_response(), - ) - } else { - Err(ApiError::UpstreamFailure.into_response()) - } - } - } -} - -pub fn format_munged_response(data: T, lag: Option) -> Response { - let mut response = (StatusCode::OK, Json(data)).into_response(); - if let Some(lag_ms) = lag - && let Ok(header_val) = HeaderValue::from_str(&lag_ms.to_string()) { - response - .headers_mut() - .insert(UPSTREAM_LAG_HEADER, header_val); - } - response -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct AuthorView { - pub did: String, - pub handle: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub display_name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub avatar: Option, - #[serde(flatten)] - pub extra: HashMap, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct PostView { - pub uri: String, - pub cid: String, - pub author: AuthorView, - pub record: Value, - pub indexed_at: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub embed: Option, - #[serde(default)] - pub reply_count: i64, - #[serde(default)] - pub repost_count: i64, - #[serde(default)] - pub like_count: i64, - #[serde(default)] - pub quote_count: i64, - #[serde(flatten)] - pub extra: HashMap, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct FeedViewPost { - pub post: PostView, - #[serde(skip_serializing_if = "Option::is_none")] - pub reply: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub reason: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub feed_context: Option, - #[serde(flatten)] - pub extra: HashMap, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FeedOutput { - pub feed: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub cursor: Option, -} - -pub fn format_local_post( - descript: &RecordDescript, - author_did: &str, - author_handle: &str, - profile: Option<&RecordDescript>, -) -> PostView { - let display_name = profile.and_then(|p| p.record.display_name.clone()); - PostView { - uri: descript.uri.clone(), - cid: descript.cid.clone(), - author: AuthorView { - did: author_did.to_string(), - handle: author_handle.to_string(), - display_name, - avatar: None, - extra: HashMap::new(), - }, - record: serde_json::to_value(&descript.record).unwrap_or(Value::Null), - indexed_at: descript.indexed_at.to_rfc3339(), - embed: descript.record.embed.clone(), - reply_count: 0, - repost_count: 0, - like_count: 0, - quote_count: 0, - extra: HashMap::new(), - } -} - -pub fn insert_posts_into_feed(feed: &mut Vec, posts: Vec) { - if posts.is_empty() { - return; - } - let new_items: Vec = posts - .into_iter() - .map(|post| FeedViewPost { - post, - reply: None, - reason: None, - feed_context: None, - extra: HashMap::new(), - }) - .collect(); - feed.extend(new_items); - feed.sort_by(|a, b| b.post.indexed_at.cmp(&a.post.indexed_at)); -} diff --git a/src/api/repo/meta.rs b/src/api/repo/meta.rs index 56b11b7..0f5c253 100644 --- a/src/api/repo/meta.rs +++ b/src/api/repo/meta.rs @@ -1,75 +1,21 @@ -use crate::api::proxy_client::proxy_client; use crate::state::AppState; use axum::{ Json, - extract::{Query, RawQuery, State}, + extract::{Query, State}, http::StatusCode, response::{IntoResponse, Response}, }; use serde::Deserialize; use serde_json::json; -use tracing::{error, info}; #[derive(Deserialize)] pub struct DescribeRepoInput { pub repo: String, } -async fn proxy_describe_repo_to_appview(state: &AppState, raw_query: Option<&str>) -> Response { - let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.describeRepo").await { - Some(r) => r, - None => { - return ( - StatusCode::NOT_FOUND, - Json(json!({"error": "NotFound", "message": "Repo not found"})), - ) - .into_response(); - } - }; - let target_url = match raw_query { - Some(q) => format!("{}/xrpc/com.atproto.repo.describeRepo?{}", resolved.url, q), - None => format!("{}/xrpc/com.atproto.repo.describeRepo", resolved.url), - }; - info!("Proxying describeRepo to AppView: {}", target_url); - let client = proxy_client(); - match client.get(&target_url).send().await { - Ok(resp) => { - let status = - StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); - let content_type = resp - .headers() - .get("content-type") - .and_then(|v| v.to_str().ok()) - .map(|s| s.to_string()); - match resp.bytes().await { - Ok(body) => { - let mut builder = Response::builder().status(status); - if let Some(ct) = content_type { - builder = builder.header("content-type", ct); - } - builder - .body(axum::body::Body::from(body)) - .unwrap_or_else(|_| { - (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() - }) - } - Err(e) => { - error!("Error reading AppView response: {:?}", e); - (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response() - } - } - } - Err(e) => { - error!("Error proxying to AppView: {:?}", e); - (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response() - } - } -} - pub async fn describe_repo( State(state): State, Query(input): Query, - RawQuery(raw_query): RawQuery, ) -> Response { let user_row = if input.repo.starts_with("did:") { sqlx::query!( @@ -90,8 +36,19 @@ pub async fn describe_repo( }; let (user_id, handle, did) = match user_row { Ok(Some((id, handle, did))) => (id, handle, did), - _ => { - return proxy_describe_repo_to_appview(&state, raw_query.as_deref()).await; + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "RepoNotFound", "message": "Repo not found"})), + ) + .into_response(); + } + Err(_) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); } }; let collections_query = sqlx::query!( diff --git a/src/api/repo/record/delete.rs b/src/api/repo/record/delete.rs index 0ad9ee2..63cf961 100644 --- a/src/api/repo/record/delete.rs +++ b/src/api/repo/record/delete.rs @@ -31,10 +31,11 @@ pub struct DeleteRecordInput { pub async fn delete_record( State(state): State, headers: HeaderMap, + axum::extract::OriginalUri(uri): axum::extract::OriginalUri, Json(input): Json, ) -> Response { let (did, user_id, current_root_cid) = - match prepare_repo_write(&state, &headers, &input.repo).await { + match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { Ok(res) => res, Err(err_res) => return err_res, }; diff --git a/src/api/repo/record/read.rs b/src/api/repo/record/read.rs index 635927f..a8ab69c 100644 --- a/src/api/repo/record/read.rs +++ b/src/api/repo/record/read.rs @@ -1,8 +1,7 @@ -use crate::api::proxy_client::proxy_client; use crate::state::AppState; use axum::{ Json, - extract::{Query, RawQuery, State}, + extract::{Query, State}, http::StatusCode, response::{IntoResponse, Response}, }; @@ -12,7 +11,7 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::HashMap; use std::str::FromStr; -use tracing::{error, info}; +use tracing::error; #[derive(Deserialize)] pub struct GetRecordInput { @@ -22,69 +21,9 @@ pub struct GetRecordInput { pub cid: Option, } -async fn proxy_get_record_to_appview(state: &AppState, raw_query: Option<&str>) -> Response { - let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.getRecord").await { - Some(r) => r, - None => { - return ( - StatusCode::NOT_FOUND, - Json(json!({"error": "NotFound", "message": "Repo not found"})), - ) - .into_response(); - } - }; - let target_url = match raw_query { - Some(q) => format!("{}/xrpc/com.atproto.repo.getRecord?{}", resolved.url, q), - None => format!("{}/xrpc/com.atproto.repo.getRecord", resolved.url), - }; - info!("Proxying getRecord to AppView: {}", target_url); - let client = proxy_client(); - match client.get(&target_url).send().await { - Ok(resp) => { - let status = - StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); - let content_type = resp - .headers() - .get("content-type") - .and_then(|v| v.to_str().ok()) - .map(|s| s.to_string()); - match resp.bytes().await { - Ok(body) => { - let mut builder = Response::builder().status(status); - if let Some(ct) = content_type { - builder = builder.header("content-type", ct); - } - builder - .body(axum::body::Body::from(body)) - .unwrap_or_else(|_| { - (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() - }) - } - Err(e) => { - error!("Error reading AppView response: {:?}", e); - ( - StatusCode::BAD_GATEWAY, - Json(json!({"error": "UpstreamError"})), - ) - .into_response() - } - } - } - Err(e) => { - error!("Error proxying to AppView: {:?}", e); - ( - StatusCode::BAD_GATEWAY, - Json(json!({"error": "UpstreamError"})), - ) - .into_response() - } - } -} - pub async fn get_record( State(state): State, Query(input): Query, - RawQuery(raw_query): RawQuery, ) -> Response { let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); let user_id_opt = if input.repo.starts_with("did:") { @@ -106,8 +45,19 @@ pub async fn get_record( }; let user_id: uuid::Uuid = match user_id_opt { Ok(Some(id)) => id, - _ => { - return proxy_get_record_to_appview(&state, raw_query.as_deref()).await; + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "RepoNotFound", "message": "Repo not found"})), + ) + .into_response(); + } + Err(_) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); } }; let record_row = sqlx::query!( @@ -192,61 +142,9 @@ pub struct ListRecordsOutput { pub records: Vec, } -async fn proxy_list_records_to_appview(state: &AppState, raw_query: Option<&str>) -> Response { - let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.listRecords").await { - Some(r) => r, - None => { - return ( - StatusCode::NOT_FOUND, - Json(json!({"error": "NotFound", "message": "Repo not found"})), - ) - .into_response(); - } - }; - let target_url = match raw_query { - Some(q) => format!("{}/xrpc/com.atproto.repo.listRecords?{}", resolved.url, q), - None => format!("{}/xrpc/com.atproto.repo.listRecords", resolved.url), - }; - info!("Proxying listRecords to AppView: {}", target_url); - let client = proxy_client(); - match client.get(&target_url).send().await { - Ok(resp) => { - let status = - StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); - let content_type = resp - .headers() - .get("content-type") - .and_then(|v| v.to_str().ok()) - .map(|s| s.to_string()); - match resp.bytes().await { - Ok(body) => { - let mut builder = Response::builder().status(status); - if let Some(ct) = content_type { - builder = builder.header("content-type", ct); - } - builder - .body(axum::body::Body::from(body)) - .unwrap_or_else(|_| { - (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() - }) - } - Err(e) => { - error!("Error reading AppView response: {:?}", e); - (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response() - } - } - } - Err(e) => { - error!("Error proxying to AppView: {:?}", e); - (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response() - } - } -} - pub async fn list_records( State(state): State, Query(input): Query, - RawQuery(raw_query): RawQuery, ) -> Response { let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); let user_id_opt = if input.repo.starts_with("did:") { @@ -268,8 +166,19 @@ pub async fn list_records( }; let user_id: uuid::Uuid = match user_id_opt { Ok(Some(id)) => id, - _ => { - return proxy_list_records_to_appview(&state, raw_query.as_deref()).await; + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "RepoNotFound", "message": "Repo not found"})), + ) + .into_response(); + } + Err(_) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); } }; let limit = input.limit.unwrap_or(50).clamp(1, 100); diff --git a/src/api/repo/record/write.rs b/src/api/repo/record/write.rs index 9d91cbb..a698f2d 100644 --- a/src/api/repo/record/write.rs +++ b/src/api/repo/record/write.rs @@ -56,8 +56,10 @@ pub async fn prepare_repo_write( state: &AppState, headers: &HeaderMap, repo_did: &str, + http_method: &str, + http_uri: &str, ) -> Result<(String, Uuid, Cid), Response> { - let token = crate::auth::extract_bearer_token_from_header( + let extracted = crate::auth::extract_auth_token_from_header( headers.get("Authorization").and_then(|h| h.to_str().ok()), ) .ok_or_else(|| { @@ -67,15 +69,26 @@ pub async fn prepare_repo_write( ) .into_response() })?; - let auth_user = crate::auth::validate_bearer_token(&state.db, &token) - .await - .map_err(|_| { - ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationFailed"})), - ) - .into_response() - })?; + let dpop_proof = headers + .get("DPoP") + .and_then(|h| h.to_str().ok()); + let auth_user = crate::auth::validate_token_with_dpop( + &state.db, + &extracted.token, + extracted.is_dpop, + dpop_proof, + http_method, + http_uri, + false, + ) + .await + .map_err(|e| { + ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": e.to_string()})), + ) + .into_response() + })?; if repo_did != auth_user.did { return Err(( StatusCode::FORBIDDEN, @@ -172,10 +185,11 @@ pub struct CreateRecordOutput { pub async fn create_record( State(state): State, headers: HeaderMap, + axum::extract::OriginalUri(uri): axum::extract::OriginalUri, Json(input): Json, ) -> Response { let (did, user_id, current_root_cid) = - match prepare_repo_write(&state, &headers, &input.repo).await { + match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { Ok(res) => res, Err(err_res) => return err_res, }; @@ -339,10 +353,11 @@ pub struct PutRecordOutput { pub async fn put_record( State(state): State, headers: HeaderMap, + axum::extract::OriginalUri(uri): axum::extract::OriginalUri, Json(input): Json, ) -> Response { let (did, user_id, current_root_cid) = - match prepare_repo_write(&state, &headers, &input.repo).await { + match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { Ok(res) => res, Err(err_res) => return err_res, }; diff --git a/src/appview/mod.rs b/src/appview/mod.rs index 9cb8d0c..93fb4f8 100644 --- a/src/appview/mod.rs +++ b/src/appview/mod.rs @@ -1,6 +1,7 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::RwLock; use tracing::{debug, error, info, warn}; @@ -22,24 +23,28 @@ pub struct DidService { } #[derive(Clone)] -struct CachedAppView { +struct CachedDid { url: String, did: String, resolved_at: Instant, } -pub struct AppViewRegistry { - namespace_to_did: HashMap, - did_cache: RwLock>, +#[derive(Debug, Clone)] +pub struct ResolvedService { + pub url: String, + pub did: String, +} + +pub struct DidResolver { + did_cache: RwLock>, client: Client, cache_ttl: Duration, plc_directory_url: String, } -impl Clone for AppViewRegistry { +impl Clone for DidResolver { fn clone(&self) -> Self { Self { - namespace_to_did: self.namespace_to_did.clone(), did_cache: RwLock::new(HashMap::new()), client: self.client.clone(), cache_ttl: self.cache_ttl, @@ -48,31 +53,9 @@ impl Clone for AppViewRegistry { } } -#[derive(Debug, Clone)] -pub struct ResolvedAppView { - pub url: String, - pub did: String, -} - -impl AppViewRegistry { +impl DidResolver { pub fn new() -> Self { - let mut namespace_to_did = HashMap::new(); - - let bsky_did = std::env::var("APPVIEW_DID_BSKY") - .unwrap_or_else(|_| "did:web:api.bsky.app".to_string()); - namespace_to_did.insert("app.bsky".to_string(), bsky_did.clone()); - namespace_to_did.insert("com.atproto".to_string(), bsky_did); - - for (key, value) in std::env::vars() { - if let Some(namespace) = key.strip_prefix("APPVIEW_DID_") { - let namespace = namespace.to_lowercase().replace('_', "."); - if namespace != "bsky" { - namespace_to_did.insert(namespace, value); - } - } - } - - let cache_ttl_secs: u64 = std::env::var("APPVIEW_CACHE_TTL_SECS") + let cache_ttl_secs: u64 = std::env::var("DID_CACHE_TTL_SECS") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(300); @@ -87,16 +70,9 @@ impl AppViewRegistry { .build() .unwrap_or_else(|_| Client::new()); - info!( - "AppView registry initialized with {} namespace mappings", - namespace_to_did.len() - ); - for (ns, did) in &namespace_to_did { - debug!(" {} -> {}", ns, did); - } + info!("DID resolver initialized"); Self { - namespace_to_did, did_cache: RwLock::new(HashMap::new()), client, cache_ttl: Duration::from_secs(cache_ttl_secs), @@ -104,45 +80,12 @@ impl AppViewRegistry { } } - pub fn register_namespace(&mut self, namespace: &str, did: &str) { - info!("Registering AppView: {} -> {}", namespace, did); - self.namespace_to_did - .insert(namespace.to_string(), did.to_string()); - } - - pub async fn get_appview_for_method(&self, method: &str) -> Option { - let namespace = self.extract_namespace(method)?; - self.get_appview_for_namespace(&namespace).await - } - - pub async fn get_appview_for_namespace(&self, namespace: &str) -> Option { - let did = self.get_did_for_namespace(namespace)?; - self.resolve_appview_did(&did).await - } - - pub fn get_did_for_namespace(&self, namespace: &str) -> Option { - if let Some(did) = self.namespace_to_did.get(namespace) { - return Some(did.clone()); - } - - let mut parts: Vec<&str> = namespace.split('.').collect(); - while !parts.is_empty() { - let prefix = parts.join("."); - if let Some(did) = self.namespace_to_did.get(&prefix) { - return Some(did.clone()); - } - parts.pop(); - } - - None - } - - pub async fn resolve_appview_did(&self, did: &str) -> Option { + pub async fn resolve_did(&self, did: &str) -> Option { { let cache = self.did_cache.read().await; if let Some(cached) = cache.get(did) { if cached.resolved_at.elapsed() < self.cache_ttl { - return Some(ResolvedAppView { + return Some(ResolvedService { url: cached.url.clone(), did: cached.did.clone(), }); @@ -156,7 +99,7 @@ impl AppViewRegistry { let mut cache = self.did_cache.write().await; cache.insert( did.to_string(), - CachedAppView { + CachedDid { url: resolved.url.clone(), did: resolved.did.clone(), resolved_at: Instant::now(), @@ -167,7 +110,7 @@ impl AppViewRegistry { Some(resolved) } - async fn resolve_did_internal(&self, did: &str) -> Option { + async fn resolve_did_internal(&self, did: &str) -> Option { let did_doc = if did.starts_with("did:web:") { self.resolve_did_web(did).await } else if did.starts_with("did:plc:") { @@ -185,7 +128,7 @@ impl AppViewRegistry { } }; - self.extract_appview_endpoint(&doc) + self.extract_service_endpoint(&doc) } async fn resolve_did_web(&self, did: &str) -> Result { @@ -275,13 +218,13 @@ impl AppViewRegistry { .map_err(|e| format!("Failed to parse DID document: {}", e)) } - fn extract_appview_endpoint(&self, doc: &DidDocument) -> Option { + fn extract_service_endpoint(&self, doc: &DidDocument) -> Option { for service in &doc.service { if service.service_type == "AtprotoAppView" || service.id.contains("atproto_appview") || service.id.ends_with("#bsky_appview") { - return Some(ResolvedAppView { + return Some(ResolvedService { url: service.service_endpoint.clone(), did: doc.id.clone(), }); @@ -290,7 +233,7 @@ impl AppViewRegistry { for service in &doc.service { if service.service_type.contains("AppView") || service.id.contains("appview") { - return Some(ResolvedAppView { + return Some(ResolvedService { url: service.service_endpoint.clone(), did: doc.id.clone(), }); @@ -303,7 +246,7 @@ impl AppViewRegistry { "No explicit AppView service found for {}, using first service: {}", doc.id, service.service_endpoint ); - return Some(ResolvedAppView { + return Some(ResolvedService { url: service.service_endpoint.clone(), did: doc.id.clone(), }); @@ -326,7 +269,7 @@ impl AppViewRegistry { "No service found for {}, deriving URL from DID: {}://{}", doc.id, scheme, base_host ); - return Some(ResolvedAppView { + return Some(ResolvedService { url: format!("{}://{}", scheme, base_host), did: doc.id.clone(), }); @@ -335,79 +278,18 @@ impl AppViewRegistry { None } - fn extract_namespace(&self, method: &str) -> Option { - let parts: Vec<&str> = method.split('.').collect(); - if parts.len() >= 2 { - Some(format!("{}.{}", parts[0], parts[1])) - } else { - None - } - } - - pub fn list_namespaces(&self) -> Vec<(String, String)> { - self.namespace_to_did - .iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect() - } - pub async fn invalidate_cache(&self, did: &str) { let mut cache = self.did_cache.write().await; cache.remove(did); } - - pub async fn invalidate_all_cache(&self) { - let mut cache = self.did_cache.write().await; - cache.clear(); - } } -impl Default for AppViewRegistry { +impl Default for DidResolver { fn default() -> Self { Self::new() } } -pub async fn get_appview_url_for_method(registry: &AppViewRegistry, method: &str) -> Option { - registry.get_appview_for_method(method).await.map(|r| r.url) -} - -pub async fn get_appview_did_for_method(registry: &AppViewRegistry, method: &str) -> Option { - registry.get_appview_for_method(method).await.map(|r| r.did) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_extract_namespace() { - let registry = AppViewRegistry::new(); - assert_eq!( - registry.extract_namespace("app.bsky.actor.getProfile"), - Some("app.bsky".to_string()) - ); - assert_eq!( - registry.extract_namespace("com.atproto.repo.createRecord"), - Some("com.atproto".to_string()) - ); - assert_eq!( - registry.extract_namespace("com.whtwnd.blog.getPost"), - Some("com.whtwnd".to_string()) - ); - assert_eq!(registry.extract_namespace("invalid"), None); - } - - #[test] - fn test_get_did_for_namespace() { - let mut registry = AppViewRegistry::new(); - registry.register_namespace("com.whtwnd", "did:web:whtwnd.com"); - - assert!(registry.get_did_for_namespace("app.bsky").is_some()); - assert_eq!( - registry.get_did_for_namespace("com.whtwnd"), - Some("did:web:whtwnd.com".to_string()) - ); - assert!(registry.get_did_for_namespace("unknown.namespace").is_none()); - } +pub fn create_did_resolver() -> Arc { + Arc::new(DidResolver::new()) } diff --git a/src/lib.rs b/src/lib.rs index f3b988b..3f2c2f7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -317,35 +317,6 @@ pub fn app(state: AppState) -> Router { "/xrpc/app.bsky.actor.putPreferences", post(api::actor::put_preferences), ) - .route( - "/xrpc/app.bsky.actor.getProfile", - get(api::actor::get_profile), - ) - .route( - "/xrpc/app.bsky.actor.getProfiles", - get(api::actor::get_profiles), - ) - .route( - "/xrpc/app.bsky.feed.getTimeline", - get(api::feed::get_timeline), - ) - .route( - "/xrpc/app.bsky.feed.getAuthorFeed", - get(api::feed::get_author_feed), - ) - .route( - "/xrpc/app.bsky.feed.getActorLikes", - get(api::feed::get_actor_likes), - ) - .route( - "/xrpc/app.bsky.feed.getPostThread", - get(api::feed::get_post_thread), - ) - .route("/xrpc/app.bsky.feed.getFeed", get(api::feed::get_feed)) - .route( - "/xrpc/app.bsky.notification.registerPush", - post(api::notification::register_push), - ) .route("/.well-known/did.json", get(api::identity::well_known_did)) .route( "/.well-known/atproto-did", diff --git a/src/state.rs b/src/state.rs index 43a3624..0fd51d6 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,4 +1,4 @@ -use crate::appview::AppViewRegistry; +use crate::appview::DidResolver; use crate::cache::{Cache, DistributedRateLimiter, create_cache}; use crate::circuit_breaker::CircuitBreakers; use crate::config::AuthConfig; @@ -20,7 +20,7 @@ pub struct AppState { pub circuit_breakers: Arc, pub cache: Arc, pub distributed_rate_limiter: Arc, - pub appview_registry: Arc, + pub did_resolver: Arc, } pub enum RateLimitKind { @@ -87,7 +87,7 @@ impl AppState { let rate_limiters = Arc::new(RateLimiters::new()); let circuit_breakers = Arc::new(CircuitBreakers::new()); let (cache, distributed_rate_limiter) = create_cache().await; - let appview_registry = Arc::new(AppViewRegistry::new()); + let did_resolver = Arc::new(DidResolver::new()); Self { db, @@ -98,7 +98,7 @@ impl AppState { circuit_breakers, cache, distributed_rate_limiter, - appview_registry, + did_resolver, } } diff --git a/tests/account_notifications.rs b/tests/account_notifications.rs index ad96219..279760e 100644 --- a/tests/account_notifications.rs +++ b/tests/account_notifications.rs @@ -170,8 +170,9 @@ async fn test_update_email_via_notification_prefs() { let pool = get_pool().await; let (token, did) = create_account_and_login(&client).await; + let unique_email = format!("newemail_{}@example.com", uuid::Uuid::new_v4()); let prefs = json!({ - "email": "newemail@example.com" + "email": unique_email }); let resp = client .post(format!("{}/xrpc/com.bspds.account.updateNotificationPrefs", base)) @@ -217,5 +218,5 @@ async fn test_update_email_via_notification_prefs() { .await .unwrap(); let body: Value = resp.json().await.unwrap(); - assert_eq!(body["email"], "newemail@example.com"); + assert_eq!(body["email"], unique_email); } diff --git a/tests/admin_search.rs b/tests/admin_search.rs index b555adf..92da44e 100644 --- a/tests/admin_search.rs +++ b/tests/admin_search.rs @@ -12,7 +12,7 @@ async fn test_search_accounts_as_admin() { let (user_did, _) = setup_new_user("search-target").await; let res = client .get(format!( - "{}/xrpc/com.atproto.admin.searchAccounts", + "{}/xrpc/com.atproto.admin.searchAccounts?limit=1000", base_url().await )) .bearer_auth(&admin_jwt) @@ -24,7 +24,7 @@ async fn test_search_accounts_as_admin() { let accounts = body["accounts"].as_array().expect("accounts should be array"); assert!(!accounts.is_empty(), "Should return some accounts"); let found = accounts.iter().any(|a| a["did"].as_str() == Some(&user_did)); - assert!(found, "Should find the created user in results"); + assert!(found, "Should find the created user in results (DID: {})", user_did); } #[tokio::test] @@ -111,6 +111,7 @@ async fn test_search_accounts_pagination() { #[tokio::test] async fn test_search_accounts_requires_admin() { let client = client(); + let _ = create_account_and_login(&client).await; let (_, user_jwt) = setup_new_user("search-nonadmin").await; let res = client .get(format!( diff --git a/tests/appview_integration.rs b/tests/appview_integration.rs deleted file mode 100644 index d8dab0c..0000000 --- a/tests/appview_integration.rs +++ /dev/null @@ -1,135 +0,0 @@ -mod common; - -use common::{base_url, client, create_account_and_login}; -use reqwest::StatusCode; -use serde_json::{Value, json}; - -#[tokio::test] -async fn test_get_author_feed_returns_appview_data() { - let client = client(); - let base = base_url().await; - let (jwt, did) = create_account_and_login(&client).await; - let res = client - .get(format!( - "{}/xrpc/app.bsky.feed.getAuthorFeed?actor={}", - base, did - )) - .header("Authorization", format!("Bearer {}", jwt)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - assert!(body["feed"].is_array(), "Response should have feed array"); - let feed = body["feed"].as_array().unwrap(); - assert_eq!(feed.len(), 1, "Feed should have 1 post from appview"); - assert_eq!( - feed[0]["post"]["record"]["text"].as_str(), - Some("Author feed post from appview"), - "Post text should match appview response" - ); -} - -#[tokio::test] -async fn test_get_actor_likes_returns_appview_data() { - let client = client(); - let base = base_url().await; - let (jwt, did) = create_account_and_login(&client).await; - let res = client - .get(format!( - "{}/xrpc/app.bsky.feed.getActorLikes?actor={}", - base, did - )) - .header("Authorization", format!("Bearer {}", jwt)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - assert!(body["feed"].is_array(), "Response should have feed array"); - let feed = body["feed"].as_array().unwrap(); - assert_eq!(feed.len(), 1, "Feed should have 1 liked post from appview"); - assert_eq!( - feed[0]["post"]["record"]["text"].as_str(), - Some("Liked post from appview"), - "Post text should match appview response" - ); -} - -#[tokio::test] -async fn test_get_post_thread_returns_appview_data() { - let client = client(); - let base = base_url().await; - let (jwt, did) = create_account_and_login(&client).await; - let res = client - .get(format!( - "{}/xrpc/app.bsky.feed.getPostThread?uri=at://{}/app.bsky.feed.post/test123", - base, did - )) - .header("Authorization", format!("Bearer {}", jwt)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - assert!( - body["thread"].is_object(), - "Response should have thread object" - ); - assert_eq!( - body["thread"]["$type"].as_str(), - Some("app.bsky.feed.defs#threadViewPost"), - "Thread should be a threadViewPost" - ); - assert_eq!( - body["thread"]["post"]["record"]["text"].as_str(), - Some("Thread post from appview"), - "Post text should match appview response" - ); -} - -#[tokio::test] -async fn test_get_feed_returns_appview_data() { - let client = client(); - let base = base_url().await; - let (jwt, _did) = create_account_and_login(&client).await; - let res = client - .get(format!( - "{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test", - base - )) - .header("Authorization", format!("Bearer {}", jwt)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - assert!(body["feed"].is_array(), "Response should have feed array"); - let feed = body["feed"].as_array().unwrap(); - assert_eq!(feed.len(), 1, "Feed should have 1 post from appview"); - assert_eq!( - feed[0]["post"]["record"]["text"].as_str(), - Some("Custom feed post from appview"), - "Post text should match appview response" - ); -} - -#[tokio::test] -async fn test_register_push_proxies_to_appview() { - let client = client(); - let base = base_url().await; - let (jwt, _did) = create_account_and_login(&client).await; - let res = client - .post(format!("{}/xrpc/app.bsky.notification.registerPush", base)) - .header("Authorization", format!("Bearer {}", jwt)) - .json(&json!({ - "serviceDid": "did:web:example.com", - "token": "test-push-token", - "platform": "ios", - "appId": "xyz.bsky.app" - })) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); -} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 84dcc63..36d311b 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -141,9 +141,6 @@ async fn setup_with_external_infra() -> String { let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri); let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A")); setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await; - unsafe { - std::env::set_var("APPVIEW_DID_APP_BSKY", &mock_did); - } MOCK_APPVIEW.set(mock_server).ok(); spawn_app(database_url).await } @@ -194,9 +191,6 @@ async fn setup_with_testcontainers() -> String { let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri); let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A")); setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await; - unsafe { - std::env::set_var("APPVIEW_DID_APP_BSKY", &mock_did); - } MOCK_APPVIEW.set(mock_server).ok(); S3_CONTAINER.set(s3_container).ok(); let container = Postgres::default() @@ -238,134 +232,7 @@ async fn setup_mock_did_document(mock_server: &MockServer, did: &str, service_en .await; } -async fn setup_mock_appview(mock_server: &MockServer) { - Mock::given(method("GET")) - .and(path("/xrpc/app.bsky.actor.getProfile")) - .respond_with(ResponseTemplate::new(200).set_body_json(json!({ - "handle": "mock.handle", - "did": "did:plc:mock", - "displayName": "Mock User" - }))) - .mount(mock_server) - .await; - Mock::given(method("GET")) - .and(path("/xrpc/app.bsky.actor.searchActors")) - .respond_with(ResponseTemplate::new(200).set_body_json(json!({ - "actors": [], - "cursor": null - }))) - .mount(mock_server) - .await; - Mock::given(method("GET")) - .and(path("/xrpc/app.bsky.feed.getTimeline")) - .respond_with( - ResponseTemplate::new(200) - .insert_header("atproto-repo-rev", "0") - .set_body_json(json!({ - "feed": [], - "cursor": null - })), - ) - .mount(mock_server) - .await; - Mock::given(method("GET")) - .and(path("/xrpc/app.bsky.feed.getAuthorFeed")) - .respond_with( - ResponseTemplate::new(200) - .insert_header("atproto-repo-rev", "0") - .set_body_json(json!({ - "feed": [{ - "post": { - "uri": "at://did:plc:mock-author/app.bsky.feed.post/from-appview-author", - "cid": "bafyappview123", - "author": {"did": "did:plc:mock-author", "handle": "mock.author"}, - "record": { - "$type": "app.bsky.feed.post", - "text": "Author feed post from appview", - "createdAt": "2025-01-01T00:00:00Z" - }, - "indexedAt": "2025-01-01T00:00:00Z" - } - }], - "cursor": "author-cursor" - })), - ) - .mount(mock_server) - .await; - Mock::given(method("GET")) - .and(path("/xrpc/app.bsky.feed.getActorLikes")) - .respond_with( - ResponseTemplate::new(200) - .insert_header("atproto-repo-rev", "0") - .set_body_json(json!({ - "feed": [{ - "post": { - "uri": "at://did:plc:mock-likes/app.bsky.feed.post/liked-post", - "cid": "bafyliked123", - "author": {"did": "did:plc:mock-likes", "handle": "mock.likes"}, - "record": { - "$type": "app.bsky.feed.post", - "text": "Liked post from appview", - "createdAt": "2025-01-01T00:00:00Z" - }, - "indexedAt": "2025-01-01T00:00:00Z" - } - }], - "cursor": null - })), - ) - .mount(mock_server) - .await; - Mock::given(method("GET")) - .and(path("/xrpc/app.bsky.feed.getPostThread")) - .respond_with( - ResponseTemplate::new(200) - .insert_header("atproto-repo-rev", "0") - .set_body_json(json!({ - "thread": { - "$type": "app.bsky.feed.defs#threadViewPost", - "post": { - "uri": "at://did:plc:mock/app.bsky.feed.post/thread-post", - "cid": "bafythread123", - "author": {"did": "did:plc:mock", "handle": "mock.handle"}, - "record": { - "$type": "app.bsky.feed.post", - "text": "Thread post from appview", - "createdAt": "2025-01-01T00:00:00Z" - }, - "indexedAt": "2025-01-01T00:00:00Z" - }, - "replies": [] - } - })), - ) - .mount(mock_server) - .await; - Mock::given(method("GET")) - .and(path("/xrpc/app.bsky.feed.getFeed")) - .respond_with(ResponseTemplate::new(200).set_body_json(json!({ - "feed": [{ - "post": { - "uri": "at://did:plc:mock-feed/app.bsky.feed.post/custom-feed-post", - "cid": "bafyfeed123", - "author": {"did": "did:plc:mock-feed", "handle": "mock.feed"}, - "record": { - "$type": "app.bsky.feed.post", - "text": "Custom feed post from appview", - "createdAt": "2025-01-01T00:00:00Z" - }, - "indexedAt": "2025-01-01T00:00:00Z" - } - }], - "cursor": null - }))) - .mount(mock_server) - .await; - Mock::given(method("POST")) - .and(path("/xrpc/app.bsky.notification.registerPush")) - .respond_with(ResponseTemplate::new(200)) - .mount(mock_server) - .await; +async fn setup_mock_appview(_mock_server: &MockServer) { } async fn spawn_app(database_url: String) -> String { diff --git a/tests/feed.rs b/tests/feed.rs deleted file mode 100644 index 376d0c5..0000000 --- a/tests/feed.rs +++ /dev/null @@ -1,104 +0,0 @@ -mod common; -use common::{base_url, client, create_account_and_login}; -use serde_json::json; - -#[tokio::test] -async fn test_get_timeline_requires_auth() { - let client = client(); - let base = base_url().await; - let res = client - .get(format!("{}/xrpc/app.bsky.feed.getTimeline", base)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), 401); -} - -#[tokio::test] -async fn test_get_author_feed_requires_actor() { - let client = client(); - let base = base_url().await; - let (jwt, _did) = create_account_and_login(&client).await; - let res = client - .get(format!("{}/xrpc/app.bsky.feed.getAuthorFeed", base)) - .header("Authorization", format!("Bearer {}", jwt)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), 400); -} - -#[tokio::test] -async fn test_get_actor_likes_requires_actor() { - let client = client(); - let base = base_url().await; - let (jwt, _did) = create_account_and_login(&client).await; - let res = client - .get(format!("{}/xrpc/app.bsky.feed.getActorLikes", base)) - .header("Authorization", format!("Bearer {}", jwt)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), 400); -} - -#[tokio::test] -async fn test_get_post_thread_requires_uri() { - let client = client(); - let base = base_url().await; - let (jwt, _did) = create_account_and_login(&client).await; - let res = client - .get(format!("{}/xrpc/app.bsky.feed.getPostThread", base)) - .header("Authorization", format!("Bearer {}", jwt)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), 400); -} - -#[tokio::test] -async fn test_get_feed_requires_auth() { - let client = client(); - let base = base_url().await; - let res = client - .get(format!( - "{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test", - base - )) - .send() - .await - .unwrap(); - assert_eq!(res.status(), 401); -} - -#[tokio::test] -async fn test_get_feed_requires_feed_param() { - let client = client(); - let base = base_url().await; - let (jwt, _did) = create_account_and_login(&client).await; - let res = client - .get(format!("{}/xrpc/app.bsky.feed.getFeed", base)) - .header("Authorization", format!("Bearer {}", jwt)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), 400); -} - -#[tokio::test] -async fn test_register_push_requires_auth() { - let client = client(); - let base = base_url().await; - let res = client - .post(format!("{}/xrpc/app.bsky.notification.registerPush", base)) - .json(&json!({ - "serviceDid": "did:web:example.com", - "token": "test-token", - "platform": "ios", - "appId": "xyz.bsky.app" - })) - .send() - .await - .unwrap(); - assert_eq!(res.status(), 401); -} diff --git a/tests/image_processing.rs b/tests/image_processing.rs index 5f7011f..486ee58 100644 --- a/tests/image_processing.rs +++ b/tests/image_processing.rs @@ -8,223 +8,154 @@ use std::io::Cursor; fn create_test_png(width: u32, height: u32) -> Vec { let img = DynamicImage::new_rgb8(width, height); let mut buf = Vec::new(); - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png) - .unwrap(); + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap(); buf } fn create_test_jpeg(width: u32, height: u32) -> Vec { let img = DynamicImage::new_rgb8(width, height); let mut buf = Vec::new(); - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg) - .unwrap(); + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap(); buf } fn create_test_gif(width: u32, height: u32) -> Vec { let img = DynamicImage::new_rgb8(width, height); let mut buf = Vec::new(); - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif) - .unwrap(); + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap(); buf } fn create_test_webp(width: u32, height: u32) -> Vec { let img = DynamicImage::new_rgb8(width, height); let mut buf = Vec::new(); - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP) - .unwrap(); + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap(); buf } #[test] -fn test_process_png() { +fn test_format_support() { let processor = ImageProcessor::new(); - let data = create_test_png(500, 500); - let result = processor.process(&data, "image/png").unwrap(); + + let png = create_test_png(500, 500); + let result = processor.process(&png, "image/png").unwrap(); assert_eq!(result.original.width, 500); assert_eq!(result.original.height, 500); -} -#[test] -fn test_process_jpeg() { - let processor = ImageProcessor::new(); - let data = create_test_jpeg(400, 300); - let result = processor.process(&data, "image/jpeg").unwrap(); + let jpeg = create_test_jpeg(400, 300); + let result = processor.process(&jpeg, "image/jpeg").unwrap(); assert_eq!(result.original.width, 400); assert_eq!(result.original.height, 300); -} -#[test] -fn test_process_gif() { - let processor = ImageProcessor::new(); - let data = create_test_gif(200, 200); - let result = processor.process(&data, "image/gif").unwrap(); + let gif = create_test_gif(200, 200); + let result = processor.process(&gif, "image/gif").unwrap(); assert_eq!(result.original.width, 200); - assert_eq!(result.original.height, 200); -} -#[test] -fn test_process_webp() { - let processor = ImageProcessor::new(); - let data = create_test_webp(300, 200); - let result = processor.process(&data, "image/webp").unwrap(); + let webp = create_test_webp(300, 200); + let result = processor.process(&webp, "image/webp").unwrap(); assert_eq!(result.original.width, 300); - assert_eq!(result.original.height, 200); } #[test] -fn test_thumbnail_feed_size() { +fn test_thumbnail_generation() { let processor = ImageProcessor::new(); - let data = create_test_png(800, 600); - let result = processor.process(&data, "image/png").unwrap(); - let thumb = result - .thumbnail_feed - .expect("Should generate feed thumbnail for large image"); - assert!(thumb.width <= THUMB_SIZE_FEED); - assert!(thumb.height <= THUMB_SIZE_FEED); + + let small = create_test_png(100, 100); + let result = processor.process(&small, "image/png").unwrap(); + assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail"); + assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail"); + + let medium = create_test_png(500, 500); + let result = processor.process(&medium, "image/png").unwrap(); + assert!(result.thumbnail_feed.is_some(), "Medium image should have feed thumbnail"); + assert!(result.thumbnail_full.is_none(), "Medium image should NOT have full thumbnail"); + + let large = create_test_png(2000, 2000); + let result = processor.process(&large, "image/png").unwrap(); + assert!(result.thumbnail_feed.is_some(), "Large image should have feed thumbnail"); + assert!(result.thumbnail_full.is_some(), "Large image should have full thumbnail"); + let thumb = result.thumbnail_feed.unwrap(); + assert!(thumb.width <= THUMB_SIZE_FEED && thumb.height <= THUMB_SIZE_FEED); + let full = result.thumbnail_full.unwrap(); + assert!(full.width <= THUMB_SIZE_FULL && full.height <= THUMB_SIZE_FULL); + + let at_feed = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED); + let above_feed = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1); + assert!(processor.process(&at_feed, "image/png").unwrap().thumbnail_feed.is_none()); + assert!(processor.process(&above_feed, "image/png").unwrap().thumbnail_feed.is_some()); + + let at_full = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL); + let above_full = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1); + assert!(processor.process(&at_full, "image/png").unwrap().thumbnail_full.is_none()); + assert!(processor.process(&above_full, "image/png").unwrap().thumbnail_full.is_some()); + + let disabled = ImageProcessor::new().with_thumbnails(false); + let result = disabled.process(&large, "image/png").unwrap(); + assert!(result.thumbnail_feed.is_none() && result.thumbnail_full.is_none()); } #[test] -fn test_thumbnail_full_size() { - let processor = ImageProcessor::new(); - let data = create_test_png(2000, 1500); - let result = processor.process(&data, "image/png").unwrap(); - let thumb = result - .thumbnail_full - .expect("Should generate full thumbnail for large image"); - assert!(thumb.width <= THUMB_SIZE_FULL); - assert!(thumb.height <= THUMB_SIZE_FULL); +fn test_output_format_conversion() { + let png = create_test_png(300, 300); + let jpeg = create_test_jpeg(300, 300); + + let webp_proc = ImageProcessor::new().with_output_format(OutputFormat::WebP); + assert_eq!(webp_proc.process(&png, "image/png").unwrap().original.mime_type, "image/webp"); + + let jpeg_proc = ImageProcessor::new().with_output_format(OutputFormat::Jpeg); + assert_eq!(jpeg_proc.process(&png, "image/png").unwrap().original.mime_type, "image/jpeg"); + + let png_proc = ImageProcessor::new().with_output_format(OutputFormat::Png); + assert_eq!(png_proc.process(&jpeg, "image/jpeg").unwrap().original.mime_type, "image/png"); } #[test] -fn test_no_thumbnail_small_image() { - let processor = ImageProcessor::new(); - let data = create_test_png(100, 100); - let result = processor.process(&data, "image/png").unwrap(); - assert!( - result.thumbnail_feed.is_none(), - "Small image should not get feed thumbnail" - ); - assert!( - result.thumbnail_full.is_none(), - "Small image should not get full thumbnail" - ); -} - -#[test] -fn test_webp_conversion() { - let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP); - let data = create_test_png(300, 300); - let result = processor.process(&data, "image/png").unwrap(); - assert_eq!(result.original.mime_type, "image/webp"); -} - -#[test] -fn test_jpeg_output_format() { - let processor = ImageProcessor::new().with_output_format(OutputFormat::Jpeg); - let data = create_test_png(300, 300); - let result = processor.process(&data, "image/png").unwrap(); - assert_eq!(result.original.mime_type, "image/jpeg"); -} - -#[test] -fn test_png_output_format() { - let processor = ImageProcessor::new().with_output_format(OutputFormat::Png); - let data = create_test_jpeg(300, 300); - let result = processor.process(&data, "image/jpeg").unwrap(); - assert_eq!(result.original.mime_type, "image/png"); -} - -#[test] -fn test_max_dimension_enforced() { - let processor = ImageProcessor::new().with_max_dimension(1000); - let data = create_test_png(2000, 2000); - let result = processor.process(&data, "image/png"); - assert!(matches!(result, Err(ImageError::TooLarge { .. }))); - if let Err(ImageError::TooLarge { - width, - height, - max_dimension, - }) = result - { - assert_eq!(width, 2000); - assert_eq!(height, 2000); - assert_eq!(max_dimension, 1000); - } -} - -#[test] -fn test_file_size_limit() { - let processor = ImageProcessor::new().with_max_file_size(100); - let data = create_test_png(500, 500); - let result = processor.process(&data, "image/png"); - assert!(matches!(result, Err(ImageError::FileTooLarge { .. }))); - if let Err(ImageError::FileTooLarge { size, max_size }) = result { - assert!(size > 100); - assert_eq!(max_size, 100); - } -} - -#[test] -fn test_default_max_file_size() { +fn test_size_and_dimension_limits() { assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024); + + let max_dim = ImageProcessor::new().with_max_dimension(1000); + let large = create_test_png(2000, 2000); + let result = max_dim.process(&large, "image/png"); + assert!(matches!(result, Err(ImageError::TooLarge { width: 2000, height: 2000, max_dimension: 1000 }))); + + let max_file = ImageProcessor::new().with_max_file_size(100); + let data = create_test_png(500, 500); + let result = max_file.process(&data, "image/png"); + assert!(matches!(result, Err(ImageError::FileTooLarge { max_size: 100, .. }))); } #[test] -fn test_unsupported_format_rejected() { +fn test_error_handling() { let processor = ImageProcessor::new(); - let data = b"this is not an image"; - let result = processor.process(data, "application/octet-stream"); + + let result = processor.process(b"this is not an image", "application/octet-stream"); assert!(matches!(result, Err(ImageError::UnsupportedFormat(_)))); -} -#[test] -fn test_corrupted_image_handling() { - let processor = ImageProcessor::new(); - let data = b"\x89PNG\r\n\x1a\ncorrupted data here"; - let result = processor.process(data, "image/png"); + let result = processor.process(b"\x89PNG\r\n\x1a\ncorrupted data here", "image/png"); assert!(matches!(result, Err(ImageError::DecodeError(_)))); } #[test] -fn test_aspect_ratio_preserved_landscape() { +fn test_aspect_ratio_preservation() { let processor = ImageProcessor::new(); - let data = create_test_png(1600, 800); - let result = processor.process(&data, "image/png").unwrap(); - let thumb = result.thumbnail_full.expect("Should have thumbnail"); + + let landscape = create_test_png(1600, 800); + let result = processor.process(&landscape, "image/png").unwrap(); + let thumb = result.thumbnail_full.unwrap(); let original_ratio = 1600.0 / 800.0; let thumb_ratio = thumb.width as f64 / thumb.height as f64; - assert!( - (original_ratio - thumb_ratio).abs() < 0.1, - "Aspect ratio should be preserved" - ); -} + assert!((original_ratio - thumb_ratio).abs() < 0.1); -#[test] -fn test_aspect_ratio_preserved_portrait() { - let processor = ImageProcessor::new(); - let data = create_test_png(800, 1600); - let result = processor.process(&data, "image/png").unwrap(); - let thumb = result.thumbnail_full.expect("Should have thumbnail"); + let portrait = create_test_png(800, 1600); + let result = processor.process(&portrait, "image/png").unwrap(); + let thumb = result.thumbnail_full.unwrap(); let original_ratio = 800.0 / 1600.0; let thumb_ratio = thumb.width as f64 / thumb.height as f64; - assert!( - (original_ratio - thumb_ratio).abs() < 0.1, - "Aspect ratio should be preserved" - ); + assert!((original_ratio - thumb_ratio).abs() < 0.1); } #[test] -fn test_mime_type_detection_auto() { - let processor = ImageProcessor::new(); - let data = create_test_png(100, 100); - let result = processor.process(&data, "application/octet-stream"); - assert!(result.is_ok(), "Should detect PNG format from data"); -} - -#[test] -fn test_is_supported_mime_type() { +fn test_utilities_and_builder() { assert!(ImageProcessor::is_supported_mime_type("image/jpeg")); assert!(ImageProcessor::is_supported_mime_type("image/jpg")); assert!(ImageProcessor::is_supported_mime_type("image/png")); @@ -235,35 +166,16 @@ fn test_is_supported_mime_type() { assert!(!ImageProcessor::is_supported_mime_type("image/bmp")); assert!(!ImageProcessor::is_supported_mime_type("image/tiff")); assert!(!ImageProcessor::is_supported_mime_type("text/plain")); - assert!(!ImageProcessor::is_supported_mime_type("application/json")); -} -#[test] -fn test_strip_exif() { - let data = create_test_jpeg(100, 100); - let result = ImageProcessor::strip_exif(&data); - assert!(result.is_ok()); - let stripped = result.unwrap(); + let data = create_test_png(100, 100); + let processor = ImageProcessor::new(); + let result = processor.process(&data, "application/octet-stream"); + assert!(result.is_ok(), "Should detect PNG format from data"); + + let jpeg = create_test_jpeg(100, 100); + let stripped = ImageProcessor::strip_exif(&jpeg).unwrap(); assert!(!stripped.is_empty()); -} -#[test] -fn test_with_thumbnails_disabled() { - let processor = ImageProcessor::new().with_thumbnails(false); - let data = create_test_png(2000, 2000); - let result = processor.process(&data, "image/png").unwrap(); - assert!( - result.thumbnail_feed.is_none(), - "Thumbnails should be disabled" - ); - assert!( - result.thumbnail_full.is_none(), - "Thumbnails should be disabled" - ); -} - -#[test] -fn test_builder_chaining() { let processor = ImageProcessor::new() .with_max_dimension(2048) .with_max_file_size(5 * 1024 * 1024) @@ -272,79 +184,6 @@ fn test_builder_chaining() { let data = create_test_png(500, 500); let result = processor.process(&data, "image/png").unwrap(); assert_eq!(result.original.mime_type, "image/jpeg"); -} - -#[test] -fn test_processed_image_fields() { - let processor = ImageProcessor::new(); - let data = create_test_png(500, 500); - let result = processor.process(&data, "image/png").unwrap(); assert!(!result.original.data.is_empty()); - assert!(!result.original.mime_type.is_empty()); - assert!(result.original.width > 0); - assert!(result.original.height > 0); -} - -#[test] -fn test_only_feed_thumbnail_for_medium_images() { - let processor = ImageProcessor::new(); - let data = create_test_png(500, 500); - let result = processor.process(&data, "image/png").unwrap(); - assert!( - result.thumbnail_feed.is_some(), - "Should have feed thumbnail" - ); - assert!( - result.thumbnail_full.is_none(), - "Should NOT have full thumbnail for 500px image" - ); -} - -#[test] -fn test_both_thumbnails_for_large_images() { - let processor = ImageProcessor::new(); - let data = create_test_png(2000, 2000); - let result = processor.process(&data, "image/png").unwrap(); - assert!( - result.thumbnail_feed.is_some(), - "Should have feed thumbnail" - ); - assert!( - result.thumbnail_full.is_some(), - "Should have full thumbnail for 2000px image" - ); -} - -#[test] -fn test_exact_threshold_boundary_feed() { - let processor = ImageProcessor::new(); - let at_threshold = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED); - let result = processor.process(&at_threshold, "image/png").unwrap(); - assert!( - result.thumbnail_feed.is_none(), - "Exact threshold should not generate thumbnail" - ); - let above_threshold = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1); - let result = processor.process(&above_threshold, "image/png").unwrap(); - assert!( - result.thumbnail_feed.is_some(), - "Above threshold should generate thumbnail" - ); -} - -#[test] -fn test_exact_threshold_boundary_full() { - let processor = ImageProcessor::new(); - let at_threshold = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL); - let result = processor.process(&at_threshold, "image/png").unwrap(); - assert!( - result.thumbnail_full.is_none(), - "Exact threshold should not generate thumbnail" - ); - let above_threshold = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1); - let result = processor.process(&above_threshold, "image/png").unwrap(); - assert!( - result.thumbnail_full.is_some(), - "Above threshold should generate thumbnail" - ); + assert!(result.original.width > 0 && result.original.height > 0); } diff --git a/tests/jwt_security.rs b/tests/jwt_security.rs index 80f1f71..136ec81 100644 --- a/tests/jwt_security.rs +++ b/tests/jwt_security.rs @@ -38,78 +38,55 @@ fn create_unsigned_jwt(header: &Value, claims: &Value) -> String { } #[test] -fn test_jwt_security_forged_signature_rejected() { +fn test_signature_attacks() { let key_bytes = generate_user_key(); let did = "did:plc:test"; let token = create_access_token(did, &key_bytes).expect("create token"); let parts: Vec<&str> = token.split('.').collect(); + let forged_signature = URL_SAFE_NO_PAD.encode(&[0u8; 64]); let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_signature); let result = verify_access_token(&forged_token, &key_bytes); assert!(result.is_err(), "Forged signature must be rejected"); - let err_msg = result.err().unwrap().to_string(); - assert!( - err_msg.contains("signature") || err_msg.contains("Signature"), - "Error should mention signature: {}", - err_msg - ); -} + assert!(result.err().unwrap().to_string().to_lowercase().contains("signature")); -#[test] -fn test_jwt_security_modified_payload_rejected() { - let key_bytes = generate_user_key(); - let did = "did:plc:legitimate"; - let token = create_access_token(did, &key_bytes).expect("create token"); - let parts: Vec<&str> = token.split('.').collect(); let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); payload["sub"] = json!("did:plc:attacker"); let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); - let result = verify_access_token(&modified_token, &key_bytes); - assert!(result.is_err(), "Modified payload must be rejected"); + assert!(verify_access_token(&modified_token, &key_bytes).is_err(), "Modified payload must be rejected"); + + let sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); + let truncated_sig = URL_SAFE_NO_PAD.encode(&sig_bytes[..32]); + let truncated_token = format!("{}.{}.{}", parts[0], parts[1], truncated_sig); + assert!(verify_access_token(&truncated_token, &key_bytes).is_err(), "Truncated signature must be rejected"); + + let mut extended_sig = sig_bytes.clone(); + extended_sig.extend_from_slice(&[0u8; 32]); + let extended_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&extended_sig)); + assert!(verify_access_token(&extended_token, &key_bytes).is_err(), "Extended signature must be rejected"); + + let key_bytes_user2 = generate_user_key(); + assert!(verify_access_token(&token, &key_bytes_user2).is_err(), "Token signed with different key must be rejected"); } #[test] -fn test_jwt_security_algorithm_none_attack_rejected() { +fn test_algorithm_substitution_attacks() { let key_bytes = generate_user_key(); let did = "did:plc:test"; - let header = json!({ - "alg": "none", - "typ": TOKEN_TYPE_ACCESS - }); - let claims = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "jti": "attacker-token-1", - "scope": SCOPE_ACCESS - }); - let malicious_token = create_unsigned_jwt(&header, &claims); - let result = verify_access_token(&malicious_token, &key_bytes); - assert!(result.is_err(), "Algorithm 'none' attack must be rejected"); -} -#[test] -fn test_jwt_security_algorithm_substitution_hs256_rejected() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let header = json!({ - "alg": "HS256", - "typ": TOKEN_TYPE_ACCESS - }); + let none_header = json!({ "alg": "none", "typ": TOKEN_TYPE_ACCESS }); let claims = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "jti": "attacker-token-2", - "scope": SCOPE_ACCESS + "iss": did, "sub": did, "aud": "did:web:test.pds", + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, + "jti": "attack-token", "scope": SCOPE_ACCESS }); - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); + let none_token = create_unsigned_jwt(&none_header, &claims); + assert!(verify_access_token(&none_token, &key_bytes).is_err(), "Algorithm 'none' must be rejected"); + + let hs256_header = json!({ "alg": "HS256", "typ": TOKEN_TYPE_ACCESS }); + let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&hs256_header).unwrap()); let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap()); use hmac::{Hmac, Mac}; type HmacSha256 = Hmac; @@ -117,925 +94,378 @@ fn test_jwt_security_algorithm_substitution_hs256_rejected() { let mut mac = HmacSha256::new_from_slice(&key_bytes).unwrap(); mac.update(message.as_bytes()); let hmac_sig = mac.finalize().into_bytes(); - let signature_b64 = URL_SAFE_NO_PAD.encode(&hmac_sig); - let malicious_token = format!("{}.{}", message, signature_b64); - let result = verify_access_token(&malicious_token, &key_bytes); - assert!( - result.is_err(), - "HS256 algorithm substitution must be rejected" - ); + let hs256_token = format!("{}.{}", message, URL_SAFE_NO_PAD.encode(&hmac_sig)); + assert!(verify_access_token(&hs256_token, &key_bytes).is_err(), "HS256 substitution must be rejected"); + + for (alg, sig_len) in [("RS256", 256), ("ES256", 64)] { + let header = json!({ "alg": alg, "typ": TOKEN_TYPE_ACCESS }); + let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); + let fake_sig = URL_SAFE_NO_PAD.encode(&vec![1u8; sig_len]); + let token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); + assert!(verify_access_token(&token, &key_bytes).is_err(), "{} substitution must be rejected", alg); + } } #[test] -fn test_jwt_security_algorithm_substitution_rs256_rejected() { +fn test_token_type_confusion() { let key_bytes = generate_user_key(); let did = "did:plc:test"; - let header = json!({ - "alg": "RS256", - "typ": TOKEN_TYPE_ACCESS - }); - let claims = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "jti": "attacker-token-3", - "scope": SCOPE_ACCESS - }); - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); - let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap()); - let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 256]); - let malicious_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); - let result = verify_access_token(&malicious_token, &key_bytes); - assert!( - result.is_err(), - "RS256 algorithm substitution must be rejected" - ); -} -#[test] -fn test_jwt_security_algorithm_substitution_es256_rejected() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let header = json!({ - "alg": "ES256", - "typ": TOKEN_TYPE_ACCESS - }); - let claims = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "jti": "attacker-token-4", - "scope": SCOPE_ACCESS - }); - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); - let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap()); - let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]); - let malicious_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); - let result = verify_access_token(&malicious_token, &key_bytes); - assert!( - result.is_err(), - "ES256 (P-256) algorithm substitution must be rejected (we use ES256K/secp256k1)" - ); -} - -#[test] -fn test_jwt_security_token_type_confusion_refresh_as_access() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; let refresh_token = create_refresh_token(did, &key_bytes).expect("create refresh token"); let result = verify_access_token(&refresh_token, &key_bytes); - assert!( - result.is_err(), - "Refresh token must not be accepted as access token" - ); - let err_msg = result.err().unwrap().to_string(); - assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg); -} + assert!(result.is_err(), "Refresh token as access must be rejected"); + assert!(result.err().unwrap().to_string().contains("Invalid token type")); -#[test] -fn test_jwt_security_token_type_confusion_access_as_refresh() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; let access_token = create_access_token(did, &key_bytes).expect("create access token"); let result = verify_refresh_token(&access_token, &key_bytes); - assert!( - result.is_err(), - "Access token must not be accepted as refresh token" - ); - let err_msg = result.err().unwrap().to_string(); - assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg); + assert!(result.is_err(), "Access token as refresh must be rejected"); + assert!(result.err().unwrap().to_string().contains("Invalid token type")); + + let service_token = create_service_token(did, "did:web:target", "com.example.method", &key_bytes).unwrap(); + assert!(verify_access_token(&service_token, &key_bytes).is_err(), "Service token as access must be rejected"); } #[test] -fn test_jwt_security_token_type_confusion_service_as_access() { +fn test_scope_validation() { let key_bytes = generate_user_key(); let did = "did:plc:test"; - let service_token = - create_service_token(did, "did:web:target", "com.example.method", &key_bytes) - .expect("create service token"); - let result = verify_access_token(&service_token, &key_bytes); - assert!( - result.is_err(), - "Service token must not be accepted as access token" - ); -} + let header = json!({ "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS }); -#[test] -fn test_jwt_security_scope_manipulation_attack() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let header = json!({ - "alg": "ES256K", - "typ": TOKEN_TYPE_ACCESS + let invalid_scope = json!({ + "iss": did, "sub": did, "aud": "did:web:test.pds", + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, + "jti": "test", "scope": "admin.all" }); - let claims = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "jti": "scope-attack-token", - "scope": "admin.all" - }); - let malicious_token = create_custom_jwt(&header, &claims, &key_bytes); - let result = verify_access_token(&malicious_token, &key_bytes); - assert!(result.is_err(), "Invalid scope must be rejected"); - let err_msg = result.err().unwrap().to_string(); - assert!( - err_msg.contains("Invalid token scope"), - "Error: {}", - err_msg - ); -} + let result = verify_access_token(&create_custom_jwt(&header, &invalid_scope, &key_bytes), &key_bytes); + assert!(result.is_err() && result.err().unwrap().to_string().contains("Invalid token scope")); -#[test] -fn test_jwt_security_empty_scope_rejected() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let header = json!({ - "alg": "ES256K", - "typ": TOKEN_TYPE_ACCESS + let empty_scope = json!({ + "iss": did, "sub": did, "aud": "did:web:test.pds", + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, + "jti": "test", "scope": "" }); - let claims = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "jti": "empty-scope-token", - "scope": "" + assert!(verify_access_token(&create_custom_jwt(&header, &empty_scope, &key_bytes), &key_bytes).is_err()); + + let missing_scope = json!({ + "iss": did, "sub": did, "aud": "did:web:test.pds", + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, + "jti": "test" }); - let token = create_custom_jwt(&header, &claims, &key_bytes); - let result = verify_access_token(&token, &key_bytes); - assert!( - result.is_err(), - "Empty scope must be rejected for access tokens" - ); -} + assert!(verify_access_token(&create_custom_jwt(&header, &missing_scope, &key_bytes), &key_bytes).is_err()); -#[test] -fn test_jwt_security_missing_scope_rejected() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let header = json!({ - "alg": "ES256K", - "typ": TOKEN_TYPE_ACCESS - }); - let claims = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "jti": "no-scope-token" - }); - let token = create_custom_jwt(&header, &claims, &key_bytes); - let result = verify_access_token(&token, &key_bytes); - assert!( - result.is_err(), - "Missing scope must be rejected for access tokens" - ); -} - -#[test] -fn test_jwt_security_expired_token_rejected() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let header = json!({ - "alg": "ES256K", - "typ": TOKEN_TYPE_ACCESS - }); - let claims = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": Utc::now().timestamp() - 7200, - "exp": Utc::now().timestamp() - 3600, - "jti": "expired-token", - "scope": SCOPE_ACCESS - }); - let expired_token = create_custom_jwt(&header, &claims, &key_bytes); - let result = verify_access_token(&expired_token, &key_bytes); - assert!(result.is_err(), "Expired token must be rejected"); - let err_msg = result.err().unwrap().to_string(); - assert!(err_msg.contains("expired"), "Error: {}", err_msg); -} - -#[test] -fn test_jwt_security_future_iat_accepted() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let header = json!({ - "alg": "ES256K", - "typ": TOKEN_TYPE_ACCESS - }); - let claims = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": Utc::now().timestamp() + 60, - "exp": Utc::now().timestamp() + 7200, - "jti": "future-iat-token", - "scope": SCOPE_ACCESS - }); - let token = create_custom_jwt(&header, &claims, &key_bytes); - let result = verify_access_token(&token, &key_bytes); - assert!( - result.is_ok(), - "Slight future iat should be accepted for clock skew tolerance" - ); -} - -#[test] -fn test_jwt_security_cross_user_key_attack() { - let key_bytes_user1 = generate_user_key(); - let key_bytes_user2 = generate_user_key(); - let did = "did:plc:user1"; - let token = create_access_token(did, &key_bytes_user1).expect("create token"); - let result = verify_access_token(&token, &key_bytes_user2); - assert!( - result.is_err(), - "Token signed by user1's key must not verify with user2's key" - ); -} - -#[test] -fn test_jwt_security_signature_truncation_rejected() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let token = create_access_token(did, &key_bytes).expect("create token"); - let parts: Vec<&str> = token.split('.').collect(); - let sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); - let truncated_sig = URL_SAFE_NO_PAD.encode(&sig_bytes[..32]); - let truncated_token = format!("{}.{}.{}", parts[0], parts[1], truncated_sig); - let result = verify_access_token(&truncated_token, &key_bytes); - assert!(result.is_err(), "Truncated signature must be rejected"); -} - -#[test] -fn test_jwt_security_signature_extension_rejected() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let token = create_access_token(did, &key_bytes).expect("create token"); - let parts: Vec<&str> = token.split('.').collect(); - let mut sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); - sig_bytes.extend_from_slice(&[0u8; 32]); - let extended_sig = URL_SAFE_NO_PAD.encode(&sig_bytes); - let extended_token = format!("{}.{}.{}", parts[0], parts[1], extended_sig); - let result = verify_access_token(&extended_token, &key_bytes); - assert!(result.is_err(), "Extended signature must be rejected"); -} - -#[test] -fn test_jwt_security_malformed_tokens_rejected() { - let key_bytes = generate_user_key(); - let malformed_tokens = vec![ - "", - "not-a-token", - "one.two", - "one.two.three.four", - "....", - "eyJhbGciOiJFUzI1NksifQ", - "eyJhbGciOiJFUzI1NksifQ.", - "eyJhbGciOiJFUzI1NksifQ..", - ".eyJzdWIiOiJ0ZXN0In0.", - "!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig", - "eyJhbGciOiJFUzI1NksifQ.!!invalid!!.sig", - ]; - for token in malformed_tokens { - let result = verify_access_token(token, &key_bytes); - assert!( - result.is_err(), - "Malformed token '{}' must be rejected", - if token.len() > 40 { - &token[..40] - } else { - token - } - ); - } -} - -#[test] -fn test_jwt_security_missing_required_claims_rejected() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let test_cases = vec![ - ( - json!({ - "iss": did, - "sub": did, - "aud": "did:web:test", - "iat": Utc::now().timestamp(), - "scope": SCOPE_ACCESS - }), - "exp", - ), - ( - json!({ - "iss": did, - "sub": did, - "aud": "did:web:test", - "exp": Utc::now().timestamp() + 3600, - "scope": SCOPE_ACCESS - }), - "iat", - ), - ( - json!({ - "iss": did, - "aud": "did:web:test", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "scope": SCOPE_ACCESS - }), - "sub", - ), - ]; - for (claims, missing_claim) in test_cases { - let header = json!({ - "alg": "ES256K", - "typ": TOKEN_TYPE_ACCESS + for scope in [SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED] { + let claims = json!({ + "iss": did, "sub": did, "aud": "did:web:test.pds", + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, + "jti": "test", "scope": scope }); - let token = create_custom_jwt(&header, &claims, &key_bytes); - let result = verify_access_token(&token, &key_bytes); - assert!( - result.is_err(), - "Token missing '{}' claim must be rejected", - missing_claim - ); + assert!(verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok()); } + + let refresh_scope = json!({ + "iss": did, "sub": did, "aud": "did:web:test.pds", + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, + "jti": "test", "scope": SCOPE_REFRESH + }); + assert!(verify_access_token(&create_custom_jwt(&header, &refresh_scope, &key_bytes), &key_bytes).is_err()); } #[test] -fn test_jwt_security_invalid_header_json_rejected() { +fn test_expiration_and_timing() { let key_bytes = generate_user_key(); + let did = "did:plc:test"; + let header = json!({ "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS }); + let now = Utc::now().timestamp(); + + let expired = json!({ + "iss": did, "sub": did, "aud": "did:web:test.pds", + "iat": now - 7200, "exp": now - 3600, "jti": "test", "scope": SCOPE_ACCESS + }); + let result = verify_access_token(&create_custom_jwt(&header, &expired, &key_bytes), &key_bytes); + assert!(result.is_err() && result.err().unwrap().to_string().contains("expired")); + + let future_iat = json!({ + "iss": did, "sub": did, "aud": "did:web:test.pds", + "iat": now + 60, "exp": now + 7200, "jti": "test", "scope": SCOPE_ACCESS + }); + assert!(verify_access_token(&create_custom_jwt(&header, &future_iat, &key_bytes), &key_bytes).is_ok()); + + let just_expired = json!({ + "iss": did, "sub": did, "aud": "did:web:test.pds", + "iat": now - 10, "exp": now - 1, "jti": "test", "scope": SCOPE_ACCESS + }); + assert!(verify_access_token(&create_custom_jwt(&header, &just_expired, &key_bytes), &key_bytes).is_err()); + + let far_future = json!({ + "iss": did, "sub": did, "aud": "did:web:test.pds", + "iat": now, "exp": i64::MAX, "jti": "test", "scope": SCOPE_ACCESS + }); + let _ = verify_access_token(&create_custom_jwt(&header, &far_future, &key_bytes), &key_bytes); + + let negative_iat = json!({ + "iss": did, "sub": did, "aud": "did:web:test.pds", + "iat": -1000000000i64, "exp": now + 3600, "jti": "test", "scope": SCOPE_ACCESS + }); + let _ = verify_access_token(&create_custom_jwt(&header, &negative_iat, &key_bytes), &key_bytes); +} + +#[test] +fn test_malformed_tokens() { + let key_bytes = generate_user_key(); + + for token in ["", "not-a-token", "one.two", "one.two.three.four", "....", + "eyJhbGciOiJFUzI1NksifQ", "eyJhbGciOiJFUzI1NksifQ.", "eyJhbGciOiJFUzI1NksifQ..", + ".eyJzdWIiOiJ0ZXN0In0.", "!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig"] { + assert!(verify_access_token(token, &key_bytes).is_err(), "Malformed token must be rejected"); + } + let invalid_header = URL_SAFE_NO_PAD.encode("{not valid json}"); let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"sub":"test"}"#); let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]); - let malicious_token = format!("{}.{}.{}", invalid_header, claims_b64, fake_sig); - let result = verify_access_token(&malicious_token, &key_bytes); - assert!(result.is_err(), "Invalid header JSON must be rejected"); -} + assert!(verify_access_token(&format!("{}.{}.{}", invalid_header, claims_b64, fake_sig), &key_bytes).is_err()); -#[test] -fn test_jwt_security_invalid_claims_json_rejected() { - let key_bytes = generate_user_key(); let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"at+jwt"}"#); let invalid_claims = URL_SAFE_NO_PAD.encode("{not valid json}"); - let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]); - let malicious_token = format!("{}.{}.{}", header_b64, invalid_claims, fake_sig); - let result = verify_access_token(&malicious_token, &key_bytes); - assert!(result.is_err(), "Invalid claims JSON must be rejected"); + assert!(verify_access_token(&format!("{}.{}.{}", header_b64, invalid_claims, fake_sig), &key_bytes).is_err()); } #[test] -fn test_jwt_security_header_injection_attack() { +fn test_claim_validation() { let key_bytes = generate_user_key(); let did = "did:plc:test"; - let header = json!({ - "alg": "ES256K", - "typ": TOKEN_TYPE_ACCESS, - "kid": "../../../../../../etc/passwd", - "jku": "https://attacker.com/keys" - }); - let claims = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "jti": "header-injection-token", - "scope": SCOPE_ACCESS - }); - let token = create_custom_jwt(&header, &claims, &key_bytes); - let result = verify_access_token(&token, &key_bytes); - assert!( - result.is_ok(), - "Extra header fields should not cause issues (we ignore them)" - ); -} + let header = json!({ "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS }); -#[test] -fn test_jwt_security_claims_type_confusion() { - let key_bytes = generate_user_key(); - let header = json!({ - "alg": "ES256K", - "typ": TOKEN_TYPE_ACCESS + let missing_exp = json!({ + "iss": did, "sub": did, "aud": "did:web:test", + "iat": Utc::now().timestamp(), "scope": SCOPE_ACCESS }); - let claims = json!({ - "iss": 12345, - "sub": ["did:plc:test"], - "aud": {"url": "did:web:test"}, - "iat": "not a number", - "exp": "also not a number", - "jti": null, - "scope": SCOPE_ACCESS - }); - let token = create_custom_jwt(&header, &claims, &key_bytes); - let result = verify_access_token(&token, &key_bytes); - assert!(result.is_err(), "Claims with wrong types must be rejected"); -} + assert!(verify_access_token(&create_custom_jwt(&header, &missing_exp, &key_bytes), &key_bytes).is_err()); -#[test] -fn test_jwt_security_unicode_injection_in_claims() { - let key_bytes = generate_user_key(); - let header = json!({ - "alg": "ES256K", - "typ": TOKEN_TYPE_ACCESS + let missing_iat = json!({ + "iss": did, "sub": did, "aud": "did:web:test", + "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS }); - let claims = json!({ - "iss": "did:plc:test\u{0000}attacker", - "sub": "did:plc:test\u{202E}rekatta", - "aud": "did:web:test.pds", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "jti": "unicode-injection", - "scope": SCOPE_ACCESS + assert!(verify_access_token(&create_custom_jwt(&header, &missing_iat, &key_bytes), &key_bytes).is_err()); + + let missing_sub = json!({ + "iss": did, "aud": "did:web:test", + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS }); - let token = create_custom_jwt(&header, &claims, &key_bytes); - let result = verify_access_token(&token, &key_bytes); - if result.is_ok() { - let data = result.unwrap(); - assert!( - !data.claims.sub.contains('\0'), - "Null bytes in claims should be sanitized or rejected" - ); + assert!(verify_access_token(&create_custom_jwt(&header, &missing_sub, &key_bytes), &key_bytes).is_err()); + + let wrong_types = json!({ + "iss": 12345, "sub": ["did:plc:test"], "aud": {"url": "did:web:test"}, + "iat": "not a number", "exp": "also not a number", "jti": null, "scope": SCOPE_ACCESS + }); + assert!(verify_access_token(&create_custom_jwt(&header, &wrong_types, &key_bytes), &key_bytes).is_err()); + + let unicode_injection = json!({ + "iss": "did:plc:test\u{0000}attacker", "sub": "did:plc:test\u{202E}rekatta", + "aud": "did:web:test.pds", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, + "jti": "test", "scope": SCOPE_ACCESS + }); + if let Ok(data) = verify_access_token(&create_custom_jwt(&header, &unicode_injection, &key_bytes), &key_bytes) { + assert!(!data.claims.sub.contains('\0')); } } #[test] -fn test_jwt_security_signature_verification_is_constant_time() { +fn test_did_and_jti_extraction() { + let key_bytes = generate_user_key(); + let did = "did:plc:legitimate"; + let token = create_access_token(did, &key_bytes).expect("create token"); + + assert_eq!(get_did_from_token(&token).unwrap(), did); + assert!(get_did_from_token("invalid").is_err()); + assert!(get_did_from_token("a.b").is_err()); + assert!(get_did_from_token("").is_err()); + + let jti = get_jti_from_token(&token).unwrap(); + assert!(!jti.is_empty()); + assert!(get_jti_from_token("invalid").is_err()); + + let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K"}"#); + let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"iss":"did:plc:iss","sub":"did:plc:sub"}"#); + let fake_sig = URL_SAFE_NO_PAD.encode(&[0u8; 64]); + let unverified = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); + assert_eq!(get_did_from_token(&unverified).unwrap(), "did:plc:sub"); + + let no_jti_claims = URL_SAFE_NO_PAD.encode(r#"{"iss":"did:plc:test"}"#); + assert!(get_jti_from_token(&format!("{}.{}.{}", header_b64, no_jti_claims, fake_sig)).is_err()); +} + +#[test] +fn test_header_injection_and_constant_time() { let key_bytes = generate_user_key(); let did = "did:plc:test"; + + let header = json!({ + "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS, + "kid": "../../../../../../etc/passwd", "jku": "https://attacker.com/keys" + }); + let claims = json!({ + "iss": did, "sub": did, "aud": "did:web:test.pds", + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, + "jti": "test", "scope": SCOPE_ACCESS + }); + assert!(verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok()); + let valid_token = create_access_token(did, &key_bytes).expect("create token"); let parts: Vec<&str> = valid_token.split('.').collect(); let mut almost_valid = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); almost_valid[0] ^= 1; - let almost_valid_sig = URL_SAFE_NO_PAD.encode(&almost_valid); - let almost_valid_token = format!("{}.{}.{}", parts[0], parts[1], almost_valid_sig); - let completely_invalid_sig = URL_SAFE_NO_PAD.encode(&[0xFFu8; 64]); - let completely_invalid_token = format!("{}.{}.{}", parts[0], parts[1], completely_invalid_sig); - let _result1 = verify_access_token(&almost_valid_token, &key_bytes); - let _result2 = verify_access_token(&completely_invalid_token, &key_bytes); - assert!( - true, - "Signature verification should use constant-time comparison (timing attack prevention)" - ); -} - -#[test] -fn test_jwt_security_valid_scopes_accepted() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let valid_scopes = vec![SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]; - for scope in valid_scopes { - let header = json!({ - "alg": "ES256K", - "typ": TOKEN_TYPE_ACCESS - }); - let claims = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "jti": format!("scope-test-{}", scope), - "scope": scope - }); - let token = create_custom_jwt(&header, &claims, &key_bytes); - let result = verify_access_token(&token, &key_bytes); - assert!(result.is_ok(), "Valid scope '{}' should be accepted", scope); - } -} - -#[test] -fn test_jwt_security_refresh_token_scope_rejected_as_access() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let header = json!({ - "alg": "ES256K", - "typ": TOKEN_TYPE_ACCESS - }); - let claims = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "jti": "refresh-scope-access-typ", - "scope": SCOPE_REFRESH - }); - let token = create_custom_jwt(&header, &claims, &key_bytes); - let result = verify_access_token(&token, &key_bytes); - assert!( - result.is_err(), - "Refresh scope with access token type must be rejected" - ); -} - -#[test] -fn test_jwt_security_get_did_extraction_safe() { - let key_bytes = generate_user_key(); - let did = "did:plc:legitimate"; - let token = create_access_token(did, &key_bytes).expect("create token"); - let extracted = get_did_from_token(&token).expect("extract did"); - assert_eq!(extracted, did); - assert!(get_did_from_token("invalid").is_err()); - assert!(get_did_from_token("a.b").is_err()); - assert!(get_did_from_token("").is_err()); - let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K"}"#); - let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"iss":"did:plc:iss","sub":"did:plc:sub"}"#); - let fake_sig = URL_SAFE_NO_PAD.encode(&[0u8; 64]); - let unverified_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); - let extracted_unsafe = get_did_from_token(&unverified_token).expect("extract unsafe"); - assert_eq!( - extracted_unsafe, "did:plc:sub", - "get_did_from_token extracts sub without verification (by design for lookup)" - ); -} - -#[test] -fn test_jwt_security_get_jti_extraction_safe() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let token = create_access_token(did, &key_bytes).expect("create token"); - let jti = get_jti_from_token(&token).expect("extract jti"); - assert!(!jti.is_empty()); - assert!(get_jti_from_token("invalid").is_err()); - assert!(get_jti_from_token("a.b").is_err()); - let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K"}"#); - let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"iss":"did:plc:test"}"#); - let fake_sig = URL_SAFE_NO_PAD.encode(&[0u8; 64]); - let no_jti_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); - assert!( - get_jti_from_token(&no_jti_token).is_err(), - "Missing jti should error" - ); -} - -#[test] -fn test_jwt_security_key_from_invalid_bytes_rejected() { - let invalid_keys: Vec<&[u8]> = vec![&[], &[0u8; 31], &[0u8; 33], &[0xFFu8; 32]]; - for key in invalid_keys { - let result = create_access_token("did:plc:test", key); - if result.is_ok() { - let token = result.unwrap(); - let verify_result = verify_access_token(&token, key); - if verify_result.is_err() { - continue; - } - } - } -} - -#[test] -fn test_jwt_security_boundary_exp_values() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let header = json!({ - "alg": "ES256K", - "typ": TOKEN_TYPE_ACCESS - }); - let now = Utc::now().timestamp(); - let just_expired = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": now - 10, - "exp": now - 1, - "jti": "just-expired", - "scope": SCOPE_ACCESS - }); - let token1 = create_custom_jwt(&header, &just_expired, &key_bytes); - assert!( - verify_access_token(&token1, &key_bytes).is_err(), - "Just expired token must be rejected" - ); - let expires_exactly_now = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": now - 10, - "exp": now, - "jti": "expires-now", - "scope": SCOPE_ACCESS - }); - let token2 = create_custom_jwt(&header, &expires_exactly_now, &key_bytes); - let result2 = verify_access_token(&token2, &key_bytes); - assert!( - result2.is_err() || result2.is_ok(), - "Token expiring exactly now is a boundary case - either behavior is acceptable" - ); -} - -#[test] -fn test_jwt_security_very_long_exp_handled() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let header = json!({ - "alg": "ES256K", - "typ": TOKEN_TYPE_ACCESS - }); - let claims = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": Utc::now().timestamp(), - "exp": i64::MAX, - "jti": "far-future", - "scope": SCOPE_ACCESS - }); - let token = create_custom_jwt(&header, &claims, &key_bytes); - let _result = verify_access_token(&token, &key_bytes); -} - -#[test] -fn test_jwt_security_negative_timestamps_handled() { - let key_bytes = generate_user_key(); - let did = "did:plc:test"; - let header = json!({ - "alg": "ES256K", - "typ": TOKEN_TYPE_ACCESS - }); - let claims = json!({ - "iss": did, - "sub": did, - "aud": "did:web:test.pds", - "iat": -1000000000i64, - "exp": Utc::now().timestamp() + 3600, - "jti": "negative-iat", - "scope": SCOPE_ACCESS - }); - let token = create_custom_jwt(&header, &claims, &key_bytes); - let _result = verify_access_token(&token, &key_bytes); + let almost_valid_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&almost_valid)); + let completely_invalid_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&[0xFFu8; 64])); + let _ = verify_access_token(&almost_valid_token, &key_bytes); + let _ = verify_access_token(&completely_invalid_token, &key_bytes); } #[tokio::test] -async fn test_jwt_security_server_rejects_forged_session_token() { +async fn test_server_rejects_invalid_tokens() { let url = base_url().await; let http_client = client(); + let key_bytes = generate_user_key(); - let did = "did:plc:fake-user"; - let forged_token = create_access_token(did, &key_bytes).expect("create forged token"); - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) + let forged_token = create_access_token("did:plc:fake-user", &key_bytes).unwrap(); + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) .header("Authorization", format!("Bearer {}", forged_token)) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::UNAUTHORIZED, - "Forged session token must be rejected" - ); -} + .send().await.unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged token must be rejected"); -#[tokio::test] -async fn test_jwt_security_server_rejects_expired_token() { - let url = base_url().await; - let http_client = client(); let (access_jwt, _did) = create_account_and_login(&http_client).await; let parts: Vec<&str> = access_jwt.split('.').collect(); let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); + payload["exp"] = json!(Utc::now().timestamp() - 3600); - let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); - let tampered_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) + let expired_token = format!("{}.{}.{}", parts[0], URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), parts[2]); + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .header("Authorization", format!("Bearer {}", expired_token)) + .send().await.unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + + let mut tampered_payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); + tampered_payload["sub"] = json!("did:plc:attacker"); + tampered_payload["iss"] = json!("did:plc:attacker"); + let tampered_token = format!("{}.{}.{}", parts[0], URL_SAFE_NO_PAD.encode(serde_json::to_string(&tampered_payload).unwrap()), parts[2]); + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) .header("Authorization", format!("Bearer {}", tampered_token)) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::UNAUTHORIZED, - "Tampered/expired token must be rejected" - ); + .send().await.unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); } #[tokio::test] -async fn test_jwt_security_server_rejects_tampered_did() { +async fn test_authorization_header_formats() { let url = base_url().await; let http_client = client(); let (access_jwt, _did) = create_account_and_login(&http_client).await; - let parts: Vec<&str> = access_jwt.split('.').collect(); - let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); - let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); - payload["sub"] = json!("did:plc:attacker"); - payload["iss"] = json!("did:plc:attacker"); - let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); - let tampered_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("Bearer {}", tampered_token)) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::UNAUTHORIZED, - "DID-tampered token must be rejected" - ); + + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .header("Authorization", format!("Bearer {}", access_jwt)) + .send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .header("Authorization", format!("bearer {}", access_jwt)) + .send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .header("Authorization", format!("Basic {}", access_jwt)) + .send().await.unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .header("Authorization", &access_jwt) + .send().await.unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .header("Authorization", "Bearer ") + .send().await.unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); } #[tokio::test] -async fn test_jwt_security_refresh_token_replay_protection() { +async fn test_session_lifecycle_security() { + let url = base_url().await; + let http_client = client(); + let (access_jwt, _did) = create_account_and_login(&http_client).await; + + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .header("Authorization", format!("Bearer {}", access_jwt)) + .send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let logout = http_client.post(format!("{}/xrpc/com.atproto.server.deleteSession", url)) + .header("Authorization", format!("Bearer {}", access_jwt)) + .send().await.unwrap(); + assert_eq!(logout.status(), StatusCode::OK); + + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .header("Authorization", format!("Bearer {}", access_jwt)) + .send().await.unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn test_deactivated_account_rejected() { + let url = base_url().await; + let http_client = client(); + let (access_jwt, _did) = create_account_and_login(&http_client).await; + + let deact = http_client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) + .header("Authorization", format!("Bearer {}", access_jwt)) + .json(&json!({})) + .send().await.unwrap(); + assert_eq!(deact.status(), StatusCode::OK); + + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .header("Authorization", format!("Bearer {}", access_jwt)) + .send().await.unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + let body: Value = res.json().await.unwrap(); + assert_eq!(body["error"], "AccountDeactivated"); +} + +#[tokio::test] +async fn test_refresh_token_replay_protection() { let url = base_url().await; let http_client = client(); let ts = Utc::now().timestamp_millis(); let handle = format!("rt-replay-jwt-{}", ts); let email = format!("rt-replay-jwt-{}@example.com", ts); - let password = "test-password-123"; - let create_res = http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); + + let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) + .json(&json!({ "handle": handle, "email": email, "password": "test-password-123" })) + .send().await.unwrap(); assert_eq!(create_res.status(), StatusCode::OK); let account: Value = create_res.json().await.unwrap(); let did = account["did"].as_str().unwrap(); - let conn_str = get_db_connection_string().await; + let pool = sqlx::postgres::PgPoolOptions::new() .max_connections(2) - .connect(&conn_str) - .await - .expect("Failed to connect to test database"); - let verification_code: String = sqlx::query_scalar!( + .connect(&get_db_connection_string().await) + .await.unwrap(); + let code: String = sqlx::query_scalar!( "SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE did = $1) AND channel = 'email'", did - ) - .fetch_one(&pool) - .await - .expect("Failed to get verification code"); - let confirm_res = http_client - .post(format!("{}/xrpc/com.atproto.server.confirmSignup", url)) - .json(&json!({ - "did": did, - "verificationCode": verification_code - })) - .send() - .await - .unwrap(); - assert_eq!(confirm_res.status(), StatusCode::OK); - let confirmed: Value = confirm_res.json().await.unwrap(); + ).fetch_one(&pool).await.unwrap(); + + let confirm = http_client.post(format!("{}/xrpc/com.atproto.server.confirmSignup", url)) + .json(&json!({ "did": did, "verificationCode": code })) + .send().await.unwrap(); + assert_eq!(confirm.status(), StatusCode::OK); + let confirmed: Value = confirm.json().await.unwrap(); let refresh_jwt = confirmed["refreshJwt"].as_str().unwrap().to_string(); - let first_refresh = http_client - .post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) + + let first = http_client.post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) .header("Authorization", format!("Bearer {}", refresh_jwt)) - .send() - .await - .unwrap(); - assert_eq!( - first_refresh.status(), - StatusCode::OK, - "First refresh should succeed" - ); - let replay_res = http_client - .post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) + .send().await.unwrap(); + assert_eq!(first.status(), StatusCode::OK); + + let replay = http_client.post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) .header("Authorization", format!("Bearer {}", refresh_jwt)) - .send() - .await - .unwrap(); - assert_eq!( - replay_res.status(), - StatusCode::UNAUTHORIZED, - "Refresh token replay must be rejected" - ); -} - -#[tokio::test] -async fn test_jwt_security_authorization_header_formats() { - let url = base_url().await; - let http_client = client(); - let (access_jwt, _did) = create_account_and_login(&http_client).await; - let valid_res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("Bearer {}", access_jwt)) - .send() - .await - .unwrap(); - assert_eq!( - valid_res.status(), - StatusCode::OK, - "Valid Bearer format should work" - ); - let lowercase_res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("bearer {}", access_jwt)) - .send() - .await - .unwrap(); - assert_eq!( - lowercase_res.status(), - StatusCode::OK, - "Lowercase 'bearer' should be accepted (RFC 7235 case-insensitivity)" - ); - let basic_res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("Basic {}", access_jwt)) - .send() - .await - .unwrap(); - assert_eq!( - basic_res.status(), - StatusCode::UNAUTHORIZED, - "Basic scheme must be rejected" - ); - let no_scheme_res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", &access_jwt) - .send() - .await - .unwrap(); - assert_eq!( - no_scheme_res.status(), - StatusCode::UNAUTHORIZED, - "Missing scheme must be rejected" - ); - let empty_token_res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", "Bearer ") - .send() - .await - .unwrap(); - assert_eq!( - empty_token_res.status(), - StatusCode::UNAUTHORIZED, - "Empty token must be rejected" - ); -} - -#[tokio::test] -async fn test_jwt_security_deleted_session_rejected() { - let url = base_url().await; - let http_client = client(); - let (access_jwt, _did) = create_account_and_login(&http_client).await; - let get_res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("Bearer {}", access_jwt)) - .send() - .await - .unwrap(); - assert_eq!( - get_res.status(), - StatusCode::OK, - "Token should work before logout" - ); - let logout_res = http_client - .post(format!("{}/xrpc/com.atproto.server.deleteSession", url)) - .header("Authorization", format!("Bearer {}", access_jwt)) - .send() - .await - .unwrap(); - assert_eq!(logout_res.status(), StatusCode::OK); - let after_logout_res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("Bearer {}", access_jwt)) - .send() - .await - .unwrap(); - assert_eq!( - after_logout_res.status(), - StatusCode::UNAUTHORIZED, - "Token must be rejected after logout" - ); -} - -#[tokio::test] -async fn test_jwt_security_deactivated_account_rejected() { - let url = base_url().await; - let http_client = client(); - let (access_jwt, _did) = create_account_and_login(&http_client).await; - let deact_res = http_client - .post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) - .header("Authorization", format!("Bearer {}", access_jwt)) - .json(&json!({})) - .send() - .await - .unwrap(); - assert_eq!(deact_res.status(), StatusCode::OK); - let get_res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("Bearer {}", access_jwt)) - .send() - .await - .unwrap(); - assert_eq!( - get_res.status(), - StatusCode::UNAUTHORIZED, - "Deactivated account token must be rejected" - ); - let body: Value = get_res.json().await.unwrap(); - assert_eq!(body["error"], "AccountDeactivated"); + .send().await.unwrap(); + assert_eq!(replay.status(), StatusCode::UNAUTHORIZED); } diff --git a/tests/lifecycle_record.rs b/tests/lifecycle_record.rs index 5e29f94..46cdde5 100644 --- a/tests/lifecycle_record.rs +++ b/tests/lifecycle_record.rs @@ -8,7 +8,7 @@ use serde_json::{Value, json}; use std::time::Duration; #[tokio::test] -async fn test_post_crud_lifecycle() { +async fn test_record_crud_lifecycle() { let client = client(); let (did, jwt) = setup_new_user("lifecycle-crud").await; let collection = "app.bsky.feed.post"; @@ -26,50 +26,24 @@ async fn test_post_crud_lifecycle() { } }); let create_res = client - .post(format!( - "{}/xrpc/com.atproto.repo.putRecord", - base_url().await - )) + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) .bearer_auth(&jwt) .json(&create_payload) .send() .await .expect("Failed to send create request"); - if create_res.status() != reqwest::StatusCode::OK { - let status = create_res.status(); - let body = create_res - .text() - .await - .unwrap_or_else(|_| "Could not get body".to_string()); - panic!( - "Failed to create record. Status: {}, Body: {}", - status, body - ); - } - let create_body: Value = create_res - .json() - .await - .expect("create response was not JSON"); + assert_eq!(create_res.status(), StatusCode::OK, "Failed to create record"); + let create_body: Value = create_res.json().await.expect("create response was not JSON"); let uri = create_body["uri"].as_str().unwrap(); - let params = [ - ("repo", did.as_str()), - ("collection", collection), - ("rkey", &rkey), - ]; + let initial_cid = create_body["cid"].as_str().unwrap().to_string(); + let params = [("repo", did.as_str()), ("collection", collection), ("rkey", &rkey)]; let get_res = client - .get(format!( - "{}/xrpc/com.atproto.repo.getRecord", - base_url().await - )) + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) .query(¶ms) .send() .await .expect("Failed to send get request"); - assert_eq!( - get_res.status(), - reqwest::StatusCode::OK, - "Failed to get record after create" - ); + assert_eq!(get_res.status(), StatusCode::OK, "Failed to get record after create"); let get_body: Value = get_res.json().await.expect("get response was not JSON"); assert_eq!(get_body["uri"], uri); assert_eq!(get_body["value"]["text"], original_text); @@ -78,416 +52,82 @@ async fn test_post_crud_lifecycle() { "repo": did, "collection": collection, "rkey": rkey, - "record": { - "$type": collection, - "text": updated_text, - "createdAt": now - } + "record": { "$type": collection, "text": updated_text, "createdAt": now }, + "swapRecord": initial_cid }); let update_res = client - .post(format!( - "{}/xrpc/com.atproto.repo.putRecord", - base_url().await - )) + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) .bearer_auth(&jwt) .json(&update_payload) .send() .await .expect("Failed to send update request"); - assert_eq!( - update_res.status(), - reqwest::StatusCode::OK, - "Failed to update record" - ); + assert_eq!(update_res.status(), StatusCode::OK, "Failed to update record"); + let update_body: Value = update_res.json().await.expect("update response was not JSON"); + let updated_cid = update_body["cid"].as_str().unwrap().to_string(); let get_updated_res = client - .get(format!( - "{}/xrpc/com.atproto.repo.getRecord", - base_url().await - )) + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) .query(¶ms) .send() .await .expect("Failed to send get-after-update request"); - assert_eq!( - get_updated_res.status(), - reqwest::StatusCode::OK, - "Failed to get record after update" - ); - let get_updated_body: Value = get_updated_res - .json() - .await - .expect("get-updated response was not JSON"); - assert_eq!( - get_updated_body["value"]["text"], updated_text, - "Text was not updated" - ); - let delete_payload = json!({ + let get_updated_body: Value = get_updated_res.json().await.expect("get-updated response was not JSON"); + assert_eq!(get_updated_body["value"]["text"], updated_text, "Text was not updated"); + let stale_update_payload = json!({ "repo": did, "collection": collection, - "rkey": rkey + "rkey": rkey, + "record": { "$type": collection, "text": "Stale update", "createdAt": now }, + "swapRecord": initial_cid }); + let stale_res = client + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + .bearer_auth(&jwt) + .json(&stale_update_payload) + .send() + .await + .expect("Failed to send stale update"); + assert_eq!(stale_res.status(), StatusCode::CONFLICT, "Stale update should cause 409"); + let good_update_payload = json!({ + "repo": did, + "collection": collection, + "rkey": rkey, + "record": { "$type": collection, "text": "Good update", "createdAt": now }, + "swapRecord": updated_cid + }); + let good_res = client + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + .bearer_auth(&jwt) + .json(&good_update_payload) + .send() + .await + .expect("Failed to send good update"); + assert_eq!(good_res.status(), StatusCode::OK, "Good update should succeed"); + let delete_payload = json!({ "repo": did, "collection": collection, "rkey": rkey }); let delete_res = client - .post(format!( - "{}/xrpc/com.atproto.repo.deleteRecord", - base_url().await - )) + .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await)) .bearer_auth(&jwt) .json(&delete_payload) .send() .await .expect("Failed to send delete request"); - assert_eq!( - delete_res.status(), - reqwest::StatusCode::OK, - "Failed to delete record" - ); + assert_eq!(delete_res.status(), StatusCode::OK, "Failed to delete record"); let get_deleted_res = client - .get(format!( - "{}/xrpc/com.atproto.repo.getRecord", - base_url().await - )) + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) .query(¶ms) .send() .await .expect("Failed to send get-after-delete request"); - assert_eq!( - get_deleted_res.status(), - reqwest::StatusCode::NOT_FOUND, - "Record was found, but it should be deleted" - ); + assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Record should be deleted"); } #[tokio::test] -async fn test_record_update_conflict_lifecycle() { +async fn test_profile_with_blob_lifecycle() { let client = client(); - let (user_did, user_jwt) = setup_new_user("user-conflict").await; - let profile_payload = json!({ - "repo": user_did, - "collection": "app.bsky.actor.profile", - "rkey": "self", - "record": { - "$type": "app.bsky.actor.profile", - "displayName": "Original Name" - } - }); - let create_res = client - .post(format!( - "{}/xrpc/com.atproto.repo.putRecord", - base_url().await - )) - .bearer_auth(&user_jwt) - .json(&profile_payload) - .send() - .await - .expect("create profile failed"); - if create_res.status() != reqwest::StatusCode::OK { - return; - } - let get_res = client - .get(format!( - "{}/xrpc/com.atproto.repo.getRecord", - base_url().await - )) - .query(&[ - ("repo", &user_did), - ("collection", &"app.bsky.actor.profile".to_string()), - ("rkey", &"self".to_string()), - ]) - .send() - .await - .expect("getRecord failed"); - let get_body: Value = get_res.json().await.expect("getRecord not json"); - let cid_v1 = get_body["cid"] - .as_str() - .expect("Profile v1 had no CID") - .to_string(); - let update_payload_v2 = json!({ - "repo": user_did, - "collection": "app.bsky.actor.profile", - "rkey": "self", - "record": { - "$type": "app.bsky.actor.profile", - "displayName": "Updated Name (v2)" - }, - "swapRecord": cid_v1 - }); - let update_res_v2 = client - .post(format!( - "{}/xrpc/com.atproto.repo.putRecord", - base_url().await - )) - .bearer_auth(&user_jwt) - .json(&update_payload_v2) - .send() - .await - .expect("putRecord v2 failed"); - assert_eq!( - update_res_v2.status(), - reqwest::StatusCode::OK, - "v2 update failed" - ); - let update_body_v2: Value = update_res_v2.json().await.expect("v2 body not json"); - let cid_v2 = update_body_v2["cid"] - .as_str() - .expect("v2 response had no CID") - .to_string(); - let update_payload_v3_stale = json!({ - "repo": user_did, - "collection": "app.bsky.actor.profile", - "rkey": "self", - "record": { - "$type": "app.bsky.actor.profile", - "displayName": "Stale Update (v3)" - }, - "swapRecord": cid_v1 - }); - let update_res_v3_stale = client - .post(format!( - "{}/xrpc/com.atproto.repo.putRecord", - base_url().await - )) - .bearer_auth(&user_jwt) - .json(&update_payload_v3_stale) - .send() - .await - .expect("putRecord v3 (stale) failed"); - assert_eq!( - update_res_v3_stale.status(), - reqwest::StatusCode::CONFLICT, - "Stale update did not cause a 409 Conflict" - ); - let update_payload_v3_good = json!({ - "repo": user_did, - "collection": "app.bsky.actor.profile", - "rkey": "self", - "record": { - "$type": "app.bsky.actor.profile", - "displayName": "Good Update (v3)" - }, - "swapRecord": cid_v2 - }); - let update_res_v3_good = client - .post(format!( - "{}/xrpc/com.atproto.repo.putRecord", - base_url().await - )) - .bearer_auth(&user_jwt) - .json(&update_payload_v3_good) - .send() - .await - .expect("putRecord v3 (good) failed"); - assert_eq!( - update_res_v3_good.status(), - reqwest::StatusCode::OK, - "v3 (good) update failed" - ); -} - -#[tokio::test] -async fn test_profile_lifecycle() { - let client = client(); - let (did, jwt) = setup_new_user("profile-lifecycle").await; - let profile_payload = json!({ - "repo": did, - "collection": "app.bsky.actor.profile", - "rkey": "self", - "record": { - "$type": "app.bsky.actor.profile", - "displayName": "Test User", - "description": "A test profile for lifecycle testing" - } - }); - let create_res = client - .post(format!( - "{}/xrpc/com.atproto.repo.putRecord", - base_url().await - )) - .bearer_auth(&jwt) - .json(&profile_payload) - .send() - .await - .expect("Failed to create profile"); - assert_eq!( - create_res.status(), - StatusCode::OK, - "Failed to create profile" - ); - let create_body: Value = create_res.json().await.unwrap(); - let initial_cid = create_body["cid"].as_str().unwrap().to_string(); - let get_res = client - .get(format!( - "{}/xrpc/com.atproto.repo.getRecord", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.actor.profile"), - ("rkey", "self"), - ]) - .send() - .await - .expect("Failed to get profile"); - assert_eq!(get_res.status(), StatusCode::OK); - let get_body: Value = get_res.json().await.unwrap(); - assert_eq!(get_body["value"]["displayName"], "Test User"); - assert_eq!( - get_body["value"]["description"], - "A test profile for lifecycle testing" - ); - let update_payload = json!({ - "repo": did, - "collection": "app.bsky.actor.profile", - "rkey": "self", - "record": { - "$type": "app.bsky.actor.profile", - "displayName": "Updated User", - "description": "Profile has been updated" - }, - "swapRecord": initial_cid - }); - let update_res = client - .post(format!( - "{}/xrpc/com.atproto.repo.putRecord", - base_url().await - )) - .bearer_auth(&jwt) - .json(&update_payload) - .send() - .await - .expect("Failed to update profile"); - assert_eq!( - update_res.status(), - StatusCode::OK, - "Failed to update profile" - ); - let get_updated_res = client - .get(format!( - "{}/xrpc/com.atproto.repo.getRecord", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.actor.profile"), - ("rkey", "self"), - ]) - .send() - .await - .expect("Failed to get updated profile"); - let updated_body: Value = get_updated_res.json().await.unwrap(); - assert_eq!(updated_body["value"]["displayName"], "Updated User"); -} - -#[tokio::test] -async fn test_reply_thread_lifecycle() { - let client = client(); - let (alice_did, alice_jwt) = setup_new_user("alice-thread").await; - let (bob_did, bob_jwt) = setup_new_user("bob-thread").await; - let (root_uri, root_cid) = - create_post(&client, &alice_did, &alice_jwt, "This is the root post").await; - tokio::time::sleep(Duration::from_millis(100)).await; - let reply_collection = "app.bsky.feed.post"; - let reply_rkey = format!("e2e_reply_{}", Utc::now().timestamp_millis()); - let now = Utc::now().to_rfc3339(); - let reply_payload = json!({ - "repo": bob_did, - "collection": reply_collection, - "rkey": reply_rkey, - "record": { - "$type": reply_collection, - "text": "This is Bob's reply to Alice", - "createdAt": now, - "reply": { - "root": { - "uri": root_uri, - "cid": root_cid - }, - "parent": { - "uri": root_uri, - "cid": root_cid - } - } - } - }); - let reply_res = client - .post(format!( - "{}/xrpc/com.atproto.repo.putRecord", - base_url().await - )) - .bearer_auth(&bob_jwt) - .json(&reply_payload) - .send() - .await - .expect("Failed to create reply"); - assert_eq!(reply_res.status(), StatusCode::OK, "Failed to create reply"); - let reply_body: Value = reply_res.json().await.unwrap(); - let reply_uri = reply_body["uri"].as_str().unwrap(); - let reply_cid = reply_body["cid"].as_str().unwrap(); - let get_reply_res = client - .get(format!( - "{}/xrpc/com.atproto.repo.getRecord", - base_url().await - )) - .query(&[ - ("repo", bob_did.as_str()), - ("collection", reply_collection), - ("rkey", reply_rkey.as_str()), - ]) - .send() - .await - .expect("Failed to get reply"); - assert_eq!(get_reply_res.status(), StatusCode::OK); - let reply_record: Value = get_reply_res.json().await.unwrap(); - assert_eq!(reply_record["value"]["reply"]["root"]["uri"], root_uri); - assert_eq!(reply_record["value"]["reply"]["parent"]["uri"], root_uri); - tokio::time::sleep(Duration::from_millis(100)).await; - let nested_reply_rkey = format!("e2e_nested_reply_{}", Utc::now().timestamp_millis()); - let nested_payload = json!({ - "repo": alice_did, - "collection": reply_collection, - "rkey": nested_reply_rkey, - "record": { - "$type": reply_collection, - "text": "Alice replies to Bob's reply", - "createdAt": Utc::now().to_rfc3339(), - "reply": { - "root": { - "uri": root_uri, - "cid": root_cid - }, - "parent": { - "uri": reply_uri, - "cid": reply_cid - } - } - } - }); - let nested_res = client - .post(format!( - "{}/xrpc/com.atproto.repo.putRecord", - base_url().await - )) - .bearer_auth(&alice_jwt) - .json(&nested_payload) - .send() - .await - .expect("Failed to create nested reply"); - assert_eq!( - nested_res.status(), - StatusCode::OK, - "Failed to create nested reply" - ); -} - -#[tokio::test] -async fn test_blob_in_record_lifecycle() { - let client = client(); - let (did, jwt) = setup_new_user("blob-record").await; + let (did, jwt) = setup_new_user("profile-blob").await; let blob_data = b"This is test blob data for a profile avatar"; let upload_res = client - .post(format!( - "{}/xrpc/com.atproto.repo.uploadBlob", - base_url().await - )) + .post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await)) .header(header::CONTENT_TYPE, "text/plain") .bearer_auth(&jwt) .body(blob_data.to_vec()) @@ -503,166 +143,181 @@ async fn test_blob_in_record_lifecycle() { "rkey": "self", "record": { "$type": "app.bsky.actor.profile", - "displayName": "User With Avatar", + "displayName": "Test User", + "description": "A test profile for lifecycle testing", "avatar": blob_ref } }); let create_res = client - .post(format!( - "{}/xrpc/com.atproto.repo.putRecord", - base_url().await - )) + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) .bearer_auth(&jwt) .json(&profile_payload) .send() .await - .expect("Failed to create profile with blob"); - assert_eq!( - create_res.status(), - StatusCode::OK, - "Failed to create profile with blob" - ); + .expect("Failed to create profile"); + assert_eq!(create_res.status(), StatusCode::OK, "Failed to create profile"); + let create_body: Value = create_res.json().await.unwrap(); + let initial_cid = create_body["cid"].as_str().unwrap().to_string(); let get_res = client - .get(format!( - "{}/xrpc/com.atproto.repo.getRecord", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.actor.profile"), - ("rkey", "self"), - ]) + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) .send() .await .expect("Failed to get profile"); assert_eq!(get_res.status(), StatusCode::OK); - let profile: Value = get_res.json().await.unwrap(); - assert!(profile["value"]["avatar"]["ref"]["$link"].is_string()); + let get_body: Value = get_res.json().await.unwrap(); + assert_eq!(get_body["value"]["displayName"], "Test User"); + assert!(get_body["value"]["avatar"]["ref"]["$link"].is_string()); + let update_payload = json!({ + "repo": did, + "collection": "app.bsky.actor.profile", + "rkey": "self", + "record": { "$type": "app.bsky.actor.profile", "displayName": "Updated User", "description": "Profile updated" }, + "swapRecord": initial_cid + }); + let update_res = client + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + .bearer_auth(&jwt) + .json(&update_payload) + .send() + .await + .expect("Failed to update profile"); + assert_eq!(update_res.status(), StatusCode::OK, "Failed to update profile"); + let get_updated_res = client + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) + .send() + .await + .expect("Failed to get updated profile"); + let updated_body: Value = get_updated_res.json().await.unwrap(); + assert_eq!(updated_body["value"]["displayName"], "Updated User"); } #[tokio::test] -async fn test_authorization_cannot_modify_other_repo() { +async fn test_reply_thread_lifecycle() { let client = client(); - let (alice_did, _alice_jwt) = setup_new_user("alice-auth").await; + let (alice_did, alice_jwt) = setup_new_user("alice-thread").await; + let (bob_did, bob_jwt) = setup_new_user("bob-thread").await; + let (root_uri, root_cid) = create_post(&client, &alice_did, &alice_jwt, "This is the root post").await; + tokio::time::sleep(Duration::from_millis(100)).await; + let reply_collection = "app.bsky.feed.post"; + let reply_rkey = format!("e2e_reply_{}", Utc::now().timestamp_millis()); + let reply_payload = json!({ + "repo": bob_did, + "collection": reply_collection, + "rkey": reply_rkey, + "record": { + "$type": reply_collection, + "text": "This is Bob's reply to Alice", + "createdAt": Utc::now().to_rfc3339(), + "reply": { + "root": { "uri": root_uri, "cid": root_cid }, + "parent": { "uri": root_uri, "cid": root_cid } + } + } + }); + let reply_res = client + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + .bearer_auth(&bob_jwt) + .json(&reply_payload) + .send() + .await + .expect("Failed to create reply"); + assert_eq!(reply_res.status(), StatusCode::OK, "Failed to create reply"); + let reply_body: Value = reply_res.json().await.unwrap(); + let reply_uri = reply_body["uri"].as_str().unwrap(); + let reply_cid = reply_body["cid"].as_str().unwrap(); + let get_reply_res = client + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + .query(&[("repo", bob_did.as_str()), ("collection", reply_collection), ("rkey", reply_rkey.as_str())]) + .send() + .await + .expect("Failed to get reply"); + assert_eq!(get_reply_res.status(), StatusCode::OK); + let reply_record: Value = get_reply_res.json().await.unwrap(); + assert_eq!(reply_record["value"]["reply"]["root"]["uri"], root_uri); + tokio::time::sleep(Duration::from_millis(100)).await; + let nested_reply_rkey = format!("e2e_nested_reply_{}", Utc::now().timestamp_millis()); + let nested_payload = json!({ + "repo": alice_did, + "collection": reply_collection, + "rkey": nested_reply_rkey, + "record": { + "$type": reply_collection, + "text": "Alice replies to Bob's reply", + "createdAt": Utc::now().to_rfc3339(), + "reply": { + "root": { "uri": root_uri, "cid": root_cid }, + "parent": { "uri": reply_uri, "cid": reply_cid } + } + } + }); + let nested_res = client + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + .bearer_auth(&alice_jwt) + .json(&nested_payload) + .send() + .await + .expect("Failed to create nested reply"); + assert_eq!(nested_res.status(), StatusCode::OK, "Failed to create nested reply"); +} + +#[tokio::test] +async fn test_authorization_protects_repos() { + let client = client(); + let (alice_did, alice_jwt) = setup_new_user("alice-auth").await; let (_bob_did, bob_jwt) = setup_new_user("bob-auth").await; + let (post_uri, _) = create_post(&client, &alice_did, &alice_jwt, "Alice's post").await; + let post_rkey = post_uri.split('/').last().unwrap(); let post_payload = json!({ "repo": alice_did, "collection": "app.bsky.feed.post", "rkey": "unauthorized-post", - "record": { - "$type": "app.bsky.feed.post", - "text": "Bob trying to post as Alice", - "createdAt": Utc::now().to_rfc3339() - } + "record": { "$type": "app.bsky.feed.post", "text": "Bob trying to post as Alice", "createdAt": Utc::now().to_rfc3339() } }); - let res = client - .post(format!( - "{}/xrpc/com.atproto.repo.putRecord", - base_url().await - )) + let write_res = client + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) .bearer_auth(&bob_jwt) .json(&post_payload) .send() .await .expect("Failed to send request"); - assert!( - res.status() == StatusCode::FORBIDDEN || res.status() == StatusCode::UNAUTHORIZED, - "Expected 403 or 401 when writing to another user's repo, got {}", - res.status() - ); -} - -#[tokio::test] -async fn test_authorization_cannot_delete_other_record() { - let client = client(); - let (alice_did, alice_jwt) = setup_new_user("alice-del-auth").await; - let (_bob_did, bob_jwt) = setup_new_user("bob-del-auth").await; - let (post_uri, _) = create_post(&client, &alice_did, &alice_jwt, "Alice's post").await; - let post_rkey = post_uri.split('/').last().unwrap(); - let delete_payload = json!({ - "repo": alice_did, - "collection": "app.bsky.feed.post", - "rkey": post_rkey - }); - let res = client - .post(format!( - "{}/xrpc/com.atproto.repo.deleteRecord", - base_url().await - )) + assert!(write_res.status() == StatusCode::FORBIDDEN || write_res.status() == StatusCode::UNAUTHORIZED, + "Expected 403/401 for writing to another user's repo, got {}", write_res.status()); + let delete_payload = json!({ "repo": alice_did, "collection": "app.bsky.feed.post", "rkey": post_rkey }); + let delete_res = client + .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await)) .bearer_auth(&bob_jwt) .json(&delete_payload) .send() .await .expect("Failed to send request"); - assert!( - res.status() == StatusCode::FORBIDDEN || res.status() == StatusCode::UNAUTHORIZED, - "Expected 403 or 401 when deleting another user's record, got {}", - res.status() - ); + assert!(delete_res.status() == StatusCode::FORBIDDEN || delete_res.status() == StatusCode::UNAUTHORIZED, + "Expected 403/401 for deleting another user's record, got {}", delete_res.status()); let get_res = client - .get(format!( - "{}/xrpc/com.atproto.repo.getRecord", - base_url().await - )) - .query(&[ - ("repo", alice_did.as_str()), - ("collection", "app.bsky.feed.post"), - ("rkey", post_rkey), - ]) + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + .query(&[("repo", alice_did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", post_rkey)]) .send() .await .expect("Failed to verify record exists"); - assert_eq!( - get_res.status(), - StatusCode::OK, - "Record should still exist" - ); + assert_eq!(get_res.status(), StatusCode::OK, "Record should still exist"); } #[tokio::test] -async fn test_apply_writes_batch_lifecycle() { +async fn test_apply_writes_batch() { let client = client(); let (did, jwt) = setup_new_user("apply-writes-batch").await; let now = Utc::now().to_rfc3339(); let writes_payload = json!({ "repo": did, "writes": [ - { - "$type": "com.atproto.repo.applyWrites#create", - "collection": "app.bsky.feed.post", - "rkey": "batch-post-1", - "value": { - "$type": "app.bsky.feed.post", - "text": "First batch post", - "createdAt": now - } - }, - { - "$type": "com.atproto.repo.applyWrites#create", - "collection": "app.bsky.feed.post", - "rkey": "batch-post-2", - "value": { - "$type": "app.bsky.feed.post", - "text": "Second batch post", - "createdAt": now - } - }, - { - "$type": "com.atproto.repo.applyWrites#create", - "collection": "app.bsky.actor.profile", - "rkey": "self", - "value": { - "$type": "app.bsky.actor.profile", - "displayName": "Batch User" - } - } + { "$type": "com.atproto.repo.applyWrites#create", "collection": "app.bsky.feed.post", "rkey": "batch-post-1", "value": { "$type": "app.bsky.feed.post", "text": "First batch post", "createdAt": now } }, + { "$type": "com.atproto.repo.applyWrites#create", "collection": "app.bsky.feed.post", "rkey": "batch-post-2", "value": { "$type": "app.bsky.feed.post", "text": "Second batch post", "createdAt": now } }, + { "$type": "com.atproto.repo.applyWrites#create", "collection": "app.bsky.actor.profile", "rkey": "self", "value": { "$type": "app.bsky.actor.profile", "displayName": "Batch User" } } ] }); let apply_res = client - .post(format!( - "{}/xrpc/com.atproto.repo.applyWrites", - base_url().await - )) + .post(format!("{}/xrpc/com.atproto.repo.applyWrites", base_url().await)) .bearer_auth(&jwt) .json(&writes_payload) .send() @@ -670,75 +325,32 @@ async fn test_apply_writes_batch_lifecycle() { .expect("Failed to apply writes"); assert_eq!(apply_res.status(), StatusCode::OK); let get_post1 = client - .get(format!( - "{}/xrpc/com.atproto.repo.getRecord", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("rkey", "batch-post-1"), - ]) - .send() - .await - .expect("Failed to get post 1"); + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-1")]) + .send().await.expect("Failed to get post 1"); assert_eq!(get_post1.status(), StatusCode::OK); let post1_body: Value = get_post1.json().await.unwrap(); assert_eq!(post1_body["value"]["text"], "First batch post"); let get_post2 = client - .get(format!( - "{}/xrpc/com.atproto.repo.getRecord", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("rkey", "batch-post-2"), - ]) - .send() - .await - .expect("Failed to get post 2"); + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-2")]) + .send().await.expect("Failed to get post 2"); assert_eq!(get_post2.status(), StatusCode::OK); let get_profile = client - .get(format!( - "{}/xrpc/com.atproto.repo.getRecord", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.actor.profile"), - ("rkey", "self"), - ]) - .send() - .await - .expect("Failed to get profile"); - assert_eq!(get_profile.status(), StatusCode::OK); + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) + .send().await.expect("Failed to get profile"); let profile_body: Value = get_profile.json().await.unwrap(); assert_eq!(profile_body["value"]["displayName"], "Batch User"); let update_writes = json!({ "repo": did, "writes": [ - { - "$type": "com.atproto.repo.applyWrites#update", - "collection": "app.bsky.actor.profile", - "rkey": "self", - "value": { - "$type": "app.bsky.actor.profile", - "displayName": "Updated Batch User" - } - }, - { - "$type": "com.atproto.repo.applyWrites#delete", - "collection": "app.bsky.feed.post", - "rkey": "batch-post-1" - } + { "$type": "com.atproto.repo.applyWrites#update", "collection": "app.bsky.actor.profile", "rkey": "self", "value": { "$type": "app.bsky.actor.profile", "displayName": "Updated Batch User" } }, + { "$type": "com.atproto.repo.applyWrites#delete", "collection": "app.bsky.feed.post", "rkey": "batch-post-1" } ] }); let update_res = client - .post(format!( - "{}/xrpc/com.atproto.repo.applyWrites", - base_url().await - )) + .post(format!("{}/xrpc/com.atproto.repo.applyWrites", base_url().await)) .bearer_auth(&jwt) .json(&update_writes) .send() @@ -746,65 +358,25 @@ async fn test_apply_writes_batch_lifecycle() { .expect("Failed to apply update writes"); assert_eq!(update_res.status(), StatusCode::OK); let get_updated_profile = client - .get(format!( - "{}/xrpc/com.atproto.repo.getRecord", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.actor.profile"), - ("rkey", "self"), - ]) - .send() - .await - .expect("Failed to get updated profile"); + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) + .send().await.expect("Failed to get updated profile"); let updated_profile: Value = get_updated_profile.json().await.unwrap(); - assert_eq!( - updated_profile["value"]["displayName"], - "Updated Batch User" - ); + assert_eq!(updated_profile["value"]["displayName"], "Updated Batch User"); let get_deleted_post = client - .get(format!( - "{}/xrpc/com.atproto.repo.getRecord", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("rkey", "batch-post-1"), - ]) - .send() - .await - .expect("Failed to check deleted post"); - assert_eq!( - get_deleted_post.status(), - StatusCode::NOT_FOUND, - "Batch-deleted post should be gone" - ); + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-1")]) + .send().await.expect("Failed to check deleted post"); + assert_eq!(get_deleted_post.status(), StatusCode::NOT_FOUND, "Batch-deleted post should be gone"); } -async fn create_post_with_rkey( - client: &reqwest::Client, - did: &str, - jwt: &str, - rkey: &str, - text: &str, -) -> (String, String) { +async fn create_post_with_rkey(client: &reqwest::Client, did: &str, jwt: &str, rkey: &str, text: &str) -> (String, String) { let payload = json!({ - "repo": did, - "collection": "app.bsky.feed.post", - "rkey": rkey, - "record": { - "$type": "app.bsky.feed.post", - "text": text, - "createdAt": Utc::now().to_rfc3339() - } + "repo": did, "collection": "app.bsky.feed.post", "rkey": rkey, + "record": { "$type": "app.bsky.feed.post", "text": text, "createdAt": Utc::now().to_rfc3339() } }); let res = client - .post(format!( - "{}/xrpc/com.atproto.repo.putRecord", - base_url().await - )) + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) .bearer_auth(jwt) .json(&payload) .send() @@ -812,522 +384,80 @@ async fn create_post_with_rkey( .expect("Failed to create record"); assert_eq!(res.status(), StatusCode::OK); let body: Value = res.json().await.unwrap(); - ( - body["uri"].as_str().unwrap().to_string(), - body["cid"].as_str().unwrap().to_string(), - ) + (body["uri"].as_str().unwrap().to_string(), body["cid"].as_str().unwrap().to_string()) } #[tokio::test] -async fn test_list_records_default_order() { +async fn test_list_records_comprehensive() { let client = client(); - let (did, jwt) = setup_new_user("list-default-order").await; - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await; - tokio::time::sleep(Duration::from_millis(50)).await; - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await; - tokio::time::sleep(Duration::from_millis(50)).await; - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await; - let res = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")]) - .send() - .await - .expect("Failed to list records"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - let records = body["records"].as_array().unwrap(); - assert_eq!(records.len(), 3); - let rkeys: Vec<&str> = records - .iter() - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) - .collect(); - assert_eq!( - rkeys, - vec!["cccc", "bbbb", "aaaa"], - "Default order should be DESC (newest first)" - ); -} - -#[tokio::test] -async fn test_list_records_reverse_true() { - let client = client(); - let (did, jwt) = setup_new_user("list-reverse").await; - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await; - tokio::time::sleep(Duration::from_millis(50)).await; - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await; - tokio::time::sleep(Duration::from_millis(50)).await; - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await; - let res = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("reverse", "true"), - ]) - .send() - .await - .expect("Failed to list records"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - let records = body["records"].as_array().unwrap(); - let rkeys: Vec<&str> = records - .iter() - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) - .collect(); - assert_eq!( - rkeys, - vec!["aaaa", "bbbb", "cccc"], - "reverse=true should give ASC order (oldest first)" - ); -} - -#[tokio::test] -async fn test_list_records_cursor_pagination() { - let client = client(); - let (did, jwt) = setup_new_user("list-cursor").await; + let (did, jwt) = setup_new_user("list-records-test").await; for i in 0..5 { - create_post_with_rkey( - &client, - &did, - &jwt, - &format!("post{:02}", i), - &format!("Post {}", i), - ) - .await; + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; tokio::time::sleep(Duration::from_millis(50)).await; } let res = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("limit", "2"), - ]) - .send() - .await - .expect("Failed to list records"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - let records = body["records"].as_array().unwrap(); - assert_eq!(records.len(), 2); - let cursor = body["cursor"] - .as_str() - .expect("Should have cursor with more records"); - let res2 = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("limit", "2"), - ("cursor", cursor), - ]) - .send() - .await - .expect("Failed to list records with cursor"); - assert_eq!(res2.status(), StatusCode::OK); - let body2: Value = res2.json().await.unwrap(); - let records2 = body2["records"].as_array().unwrap(); - assert_eq!(records2.len(), 2); - let all_uris: Vec<&str> = records - .iter() - .chain(records2.iter()) - .map(|r| r["uri"].as_str().unwrap()) - .collect(); - let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); - assert_eq!( - all_uris.len(), - unique_uris.len(), - "Cursor pagination should not repeat records" - ); -} - -#[tokio::test] -async fn test_list_records_rkey_start() { - let client = client(); - let (did, jwt) = setup_new_user("list-rkey-start").await; - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; - create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; - let res = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("rkeyStart", "bbbb"), - ("reverse", "true"), - ]) - .send() - .await - .expect("Failed to list records"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - let records = body["records"].as_array().unwrap(); - let rkeys: Vec<&str> = records - .iter() - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) - .collect(); - for rkey in &rkeys { - assert!(*rkey >= "bbbb", "rkeyStart should filter records >= start"); - } -} - -#[tokio::test] -async fn test_list_records_rkey_end() { - let client = client(); - let (did, jwt) = setup_new_user("list-rkey-end").await; - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; - create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; - let res = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("rkeyEnd", "cccc"), - ("reverse", "true"), - ]) - .send() - .await - .expect("Failed to list records"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - let records = body["records"].as_array().unwrap(); - let rkeys: Vec<&str> = records - .iter() - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) - .collect(); - for rkey in &rkeys { - assert!(*rkey <= "cccc", "rkeyEnd should filter records <= end"); - } -} - -#[tokio::test] -async fn test_list_records_rkey_range() { - let client = client(); - let (did, jwt) = setup_new_user("list-rkey-range").await; - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; - create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; - create_post_with_rkey(&client, &did, &jwt, "eeee", "Fifth").await; - let res = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("rkeyStart", "bbbb"), - ("rkeyEnd", "dddd"), - ("reverse", "true"), - ]) - .send() - .await - .expect("Failed to list records"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - let records = body["records"].as_array().unwrap(); - let rkeys: Vec<&str> = records - .iter() - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) - .collect(); - for rkey in &rkeys { - assert!( - *rkey >= "bbbb" && *rkey <= "dddd", - "Range should be inclusive, got {}", - rkey - ); - } - assert!( - !rkeys.is_empty(), - "Should have at least some records in range" - ); -} - -#[tokio::test] -async fn test_list_records_limit_clamping_max() { - let client = client(); - let (did, jwt) = setup_new_user("list-limit-max").await; - for i in 0..5 { - create_post_with_rkey( - &client, - &did, - &jwt, - &format!("post{:02}", i), - &format!("Post {}", i), - ) - .await; - } - let res = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("limit", "1000"), - ]) - .send() - .await - .expect("Failed to list records"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - let records = body["records"].as_array().unwrap(); - assert!(records.len() <= 100, "Limit should be clamped to max 100"); -} - -#[tokio::test] -async fn test_list_records_limit_clamping_min() { - let client = client(); - let (did, jwt) = setup_new_user("list-limit-min").await; - create_post_with_rkey(&client, &did, &jwt, "aaaa", "Post").await; - let res = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("limit", "0"), - ]) - .send() - .await - .expect("Failed to list records"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - let records = body["records"].as_array().unwrap(); - assert!(records.len() >= 1, "Limit should be clamped to min 1"); -} - -#[tokio::test] -async fn test_list_records_empty_collection() { - let client = client(); - let (did, _jwt) = setup_new_user("list-empty").await; - let res = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) + .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")]) - .send() - .await - .expect("Failed to list records"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - let records = body["records"].as_array().unwrap(); - assert!( - records.is_empty(), - "Empty collection should return empty array" - ); - assert!( - body["cursor"].is_null(), - "Empty collection should have no cursor" - ); -} - -#[tokio::test] -async fn test_list_records_exact_limit() { - let client = client(); - let (did, jwt) = setup_new_user("list-exact-limit").await; - for i in 0..10 { - create_post_with_rkey( - &client, - &did, - &jwt, - &format!("post{:02}", i), - &format!("Post {}", i), - ) - .await; - } - let res = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("limit", "5"), - ]) - .send() - .await - .expect("Failed to list records"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - let records = body["records"].as_array().unwrap(); - assert_eq!( - records.len(), - 5, - "Should return exactly 5 records when limit=5" - ); -} - -#[tokio::test] -async fn test_list_records_cursor_exhaustion() { - let client = client(); - let (did, jwt) = setup_new_user("list-cursor-exhaust").await; - for i in 0..3 { - create_post_with_rkey( - &client, - &did, - &jwt, - &format!("post{:02}", i), - &format!("Post {}", i), - ) - .await; - } - let res = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("limit", "10"), - ]) - .send() - .await - .expect("Failed to list records"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - let records = body["records"].as_array().unwrap(); - assert_eq!(records.len(), 3); -} - -#[tokio::test] -async fn test_list_records_repo_not_found() { - let client = client(); - let res = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[ - ("repo", "did:plc:nonexistent12345"), - ("collection", "app.bsky.feed.post"), - ]) - .send() - .await - .expect("Failed to list records"); - assert_eq!(res.status(), StatusCode::NOT_FOUND); -} - -#[tokio::test] -async fn test_list_records_includes_cid() { - let client = client(); - let (did, jwt) = setup_new_user("list-includes-cid").await; - create_post_with_rkey(&client, &did, &jwt, "test", "Test post").await; - let res = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")]) - .send() - .await - .expect("Failed to list records"); + .send().await.expect("Failed to list records"); assert_eq!(res.status(), StatusCode::OK); let body: Value = res.json().await.unwrap(); let records = body["records"].as_array().unwrap(); + assert_eq!(records.len(), 5); + let rkeys: Vec<&str> = records.iter().map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect(); + assert_eq!(rkeys, vec!["post04", "post03", "post02", "post01", "post00"], "Default order should be DESC"); for record in records { - assert!(record["uri"].is_string(), "Record should have uri"); - assert!(record["cid"].is_string(), "Record should have cid"); - assert!(record["value"].is_object(), "Record should have value"); - let cid = record["cid"].as_str().unwrap(); - assert!(cid.starts_with("bafy"), "CID should be valid"); - } -} - -#[tokio::test] -async fn test_list_records_cursor_with_reverse() { - let client = client(); - let (did, jwt) = setup_new_user("list-cursor-reverse").await; - for i in 0..5 { - create_post_with_rkey( - &client, - &did, - &jwt, - &format!("post{:02}", i), - &format!("Post {}", i), - ) - .await; - } - let res = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("limit", "2"), - ("reverse", "true"), - ]) - .send() - .await - .expect("Failed to list records"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); - let records = body["records"].as_array().unwrap(); - let first_rkeys: Vec<&str> = records - .iter() - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) - .collect(); - assert_eq!( - first_rkeys, - vec!["post00", "post01"], - "First page with reverse should start from oldest" - ); - if let Some(cursor) = body["cursor"].as_str() { - let res2 = client - .get(format!( - "{}/xrpc/com.atproto.repo.listRecords", - base_url().await - )) - .query(&[ - ("repo", did.as_str()), - ("collection", "app.bsky.feed.post"), - ("limit", "2"), - ("reverse", "true"), - ("cursor", cursor), - ]) - .send() - .await - .expect("Failed to list records with cursor"); - let body2: Value = res2.json().await.unwrap(); - let records2 = body2["records"].as_array().unwrap(); - let second_rkeys: Vec<&str> = records2 - .iter() - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) - .collect(); - assert_eq!( - second_rkeys, - vec!["post02", "post03"], - "Second page should continue in ASC order" - ); + assert!(record["uri"].is_string()); + assert!(record["cid"].is_string()); + assert!(record["cid"].as_str().unwrap().starts_with("bafy")); + assert!(record["value"].is_object()); } + let rev_res = client + .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("reverse", "true")]) + .send().await.expect("Failed to list records reverse"); + let rev_body: Value = rev_res.json().await.unwrap(); + let rev_rkeys: Vec<&str> = rev_body["records"].as_array().unwrap().iter() + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect(); + assert_eq!(rev_rkeys, vec!["post00", "post01", "post02", "post03", "post04"], "reverse=true should give ASC"); + let page1 = client + .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "2")]) + .send().await.expect("Failed to list page 1"); + let page1_body: Value = page1.json().await.unwrap(); + let page1_records = page1_body["records"].as_array().unwrap(); + assert_eq!(page1_records.len(), 2); + let cursor = page1_body["cursor"].as_str().expect("Should have cursor"); + let page2 = client + .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "2"), ("cursor", cursor)]) + .send().await.expect("Failed to list page 2"); + let page2_body: Value = page2.json().await.unwrap(); + let page2_records = page2_body["records"].as_array().unwrap(); + assert_eq!(page2_records.len(), 2); + let all_uris: Vec<&str> = page1_records.iter().chain(page2_records.iter()) + .map(|r| r["uri"].as_str().unwrap()).collect(); + let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); + assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records"); + let range_res = client + .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), + ("rkeyStart", "post01"), ("rkeyEnd", "post03"), ("reverse", "true")]) + .send().await.expect("Failed to list range"); + let range_body: Value = range_res.json().await.unwrap(); + let range_rkeys: Vec<&str> = range_body["records"].as_array().unwrap().iter() + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect(); + for rkey in &range_rkeys { + assert!(*rkey >= "post01" && *rkey <= "post03", "Range should be inclusive"); + } + let limit_res = client + .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "1000")]) + .send().await.expect("Failed with high limit"); + let limit_body: Value = limit_res.json().await.unwrap(); + assert!(limit_body["records"].as_array().unwrap().len() <= 100, "Limit should be clamped to max 100"); + let not_found_res = client + .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) + .query(&[("repo", "did:plc:nonexistent12345"), ("collection", "app.bsky.feed.post")]) + .send().await.expect("Failed with nonexistent repo"); + assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND); } diff --git a/tests/lifecycle_social.rs b/tests/lifecycle_social.rs index 0a5e198..43bc05a 100644 --- a/tests/lifecycle_social.rs +++ b/tests/lifecycle_social.rs @@ -4,114 +4,7 @@ use chrono::Utc; use common::*; use helpers::*; use reqwest::StatusCode; -use serde_json::{Value, json}; -use std::time::Duration; - -#[tokio::test] -async fn test_social_flow_lifecycle() { - let client = client(); - let (alice_did, alice_jwt) = setup_new_user("alice-social").await; - let (bob_did, bob_jwt) = setup_new_user("bob-social").await; - let (post1_uri, _) = create_post(&client, &alice_did, &alice_jwt, "Alice's first post!").await; - create_follow(&client, &bob_did, &bob_jwt, &alice_did).await; - tokio::time::sleep(Duration::from_secs(1)).await; - let timeline_res_1 = client - .get(format!( - "{}/xrpc/app.bsky.feed.getTimeline", - base_url().await - )) - .bearer_auth(&bob_jwt) - .send() - .await - .expect("Failed to get timeline (1)"); - assert_eq!( - timeline_res_1.status(), - reqwest::StatusCode::OK, - "Failed to get timeline (1)" - ); - let timeline_body_1: Value = timeline_res_1.json().await.expect("Timeline (1) not JSON"); - let feed_1 = timeline_body_1["feed"].as_array().unwrap(); - assert_eq!(feed_1.len(), 1, "Timeline should have 1 post"); - assert_eq!( - feed_1[0]["post"]["uri"], post1_uri, - "Post URI mismatch in timeline (1)" - ); - let (post2_uri, _) = create_post( - &client, - &alice_did, - &alice_jwt, - "Alice's second post, so exciting!", - ) - .await; - tokio::time::sleep(Duration::from_secs(1)).await; - let timeline_res_2 = client - .get(format!( - "{}/xrpc/app.bsky.feed.getTimeline", - base_url().await - )) - .bearer_auth(&bob_jwt) - .send() - .await - .expect("Failed to get timeline (2)"); - assert_eq!( - timeline_res_2.status(), - reqwest::StatusCode::OK, - "Failed to get timeline (2)" - ); - let timeline_body_2: Value = timeline_res_2.json().await.expect("Timeline (2) not JSON"); - let feed_2 = timeline_body_2["feed"].as_array().unwrap(); - assert_eq!(feed_2.len(), 2, "Timeline should have 2 posts"); - assert_eq!( - feed_2[0]["post"]["uri"], post2_uri, - "Post 2 should be first" - ); - assert_eq!( - feed_2[1]["post"]["uri"], post1_uri, - "Post 1 should be second" - ); - let delete_payload = json!({ - "repo": alice_did, - "collection": "app.bsky.feed.post", - "rkey": post1_uri.split('/').last().unwrap() - }); - let delete_res = client - .post(format!( - "{}/xrpc/com.atproto.repo.deleteRecord", - base_url().await - )) - .bearer_auth(&alice_jwt) - .json(&delete_payload) - .send() - .await - .expect("Failed to send delete request"); - assert_eq!( - delete_res.status(), - reqwest::StatusCode::OK, - "Failed to delete record" - ); - tokio::time::sleep(Duration::from_secs(1)).await; - let timeline_res_3 = client - .get(format!( - "{}/xrpc/app.bsky.feed.getTimeline", - base_url().await - )) - .bearer_auth(&bob_jwt) - .send() - .await - .expect("Failed to get timeline (3)"); - assert_eq!( - timeline_res_3.status(), - reqwest::StatusCode::OK, - "Failed to get timeline (3)" - ); - let timeline_body_3: Value = timeline_res_3.json().await.expect("Timeline (3) not JSON"); - let feed_3 = timeline_body_3["feed"].as_array().unwrap(); - assert_eq!(feed_3.len(), 1, "Timeline should have 1 post after delete"); - assert_eq!( - feed_3[0]["post"]["uri"], post2_uri, - "Only post 2 should remain" - ); -} +use serde_json::{json, Value}; #[tokio::test] async fn test_like_lifecycle() { @@ -277,97 +170,6 @@ async fn test_unfollow_lifecycle() { ); } -#[tokio::test] -async fn test_timeline_after_unfollow() { - let client = client(); - let (alice_did, alice_jwt) = setup_new_user("alice-tl-unfollow").await; - let (bob_did, bob_jwt) = setup_new_user("bob-tl-unfollow").await; - let (follow_uri, _) = create_follow(&client, &bob_did, &bob_jwt, &alice_did).await; - create_post(&client, &alice_did, &alice_jwt, "Post while following").await; - tokio::time::sleep(Duration::from_secs(1)).await; - let timeline_res = client - .get(format!( - "{}/xrpc/app.bsky.feed.getTimeline", - base_url().await - )) - .bearer_auth(&bob_jwt) - .send() - .await - .expect("Failed to get timeline"); - assert_eq!(timeline_res.status(), StatusCode::OK); - let timeline_body: Value = timeline_res.json().await.unwrap(); - let feed = timeline_body["feed"].as_array().unwrap(); - assert_eq!(feed.len(), 1, "Should see 1 post from Alice"); - let follow_rkey = follow_uri.split('/').last().unwrap(); - let unfollow_payload = json!({ - "repo": bob_did, - "collection": "app.bsky.graph.follow", - "rkey": follow_rkey - }); - client - .post(format!( - "{}/xrpc/com.atproto.repo.deleteRecord", - base_url().await - )) - .bearer_auth(&bob_jwt) - .json(&unfollow_payload) - .send() - .await - .expect("Failed to unfollow"); - tokio::time::sleep(Duration::from_secs(1)).await; - let timeline_after_res = client - .get(format!( - "{}/xrpc/app.bsky.feed.getTimeline", - base_url().await - )) - .bearer_auth(&bob_jwt) - .send() - .await - .expect("Failed to get timeline after unfollow"); - assert_eq!(timeline_after_res.status(), StatusCode::OK); - let timeline_after: Value = timeline_after_res.json().await.unwrap(); - let feed_after = timeline_after["feed"].as_array().unwrap(); - assert_eq!(feed_after.len(), 0, "Should see 0 posts after unfollowing"); -} - -#[tokio::test] -async fn test_mutual_follow_lifecycle() { - let client = client(); - let (alice_did, alice_jwt) = setup_new_user("alice-mutual").await; - let (bob_did, bob_jwt) = setup_new_user("bob-mutual").await; - create_follow(&client, &alice_did, &alice_jwt, &bob_did).await; - create_follow(&client, &bob_did, &bob_jwt, &alice_did).await; - create_post(&client, &alice_did, &alice_jwt, "Alice's post for mutual").await; - create_post(&client, &bob_did, &bob_jwt, "Bob's post for mutual").await; - tokio::time::sleep(Duration::from_secs(1)).await; - let alice_timeline_res = client - .get(format!( - "{}/xrpc/app.bsky.feed.getTimeline", - base_url().await - )) - .bearer_auth(&alice_jwt) - .send() - .await - .expect("Failed to get Alice's timeline"); - assert_eq!(alice_timeline_res.status(), StatusCode::OK); - let alice_tl: Value = alice_timeline_res.json().await.unwrap(); - let alice_feed = alice_tl["feed"].as_array().unwrap(); - assert_eq!(alice_feed.len(), 1, "Alice should see Bob's 1 post"); - let bob_timeline_res = client - .get(format!( - "{}/xrpc/app.bsky.feed.getTimeline", - base_url().await - )) - .bearer_auth(&bob_jwt) - .send() - .await - .expect("Failed to get Bob's timeline"); - assert_eq!(bob_timeline_res.status(), StatusCode::OK); - let bob_tl: Value = bob_timeline_res.json().await.unwrap(); - let bob_feed = bob_tl["feed"].as_array().unwrap(); - assert_eq!(bob_feed.len(), 1, "Bob should see Alice's 1 post"); -} - #[tokio::test] async fn test_account_to_post_full_lifecycle() { let client = client(); diff --git a/tests/oauth.rs b/tests/oauth.rs index af7d697..774c3b4 100644 --- a/tests/oauth.rs +++ b/tests/oauth.rs @@ -2,7 +2,7 @@ mod common; mod helpers; use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; use chrono::Utc; -use common::{base_url, client, create_account_and_login}; +use common::{base_url, client, create_account_and_login, get_db_connection_string}; use reqwest::{StatusCode, redirect}; use serde_json::{Value, json}; use sha2::{Digest, Sha256}; @@ -10,10 +10,7 @@ use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; fn no_redirect_client() -> reqwest::Client { - reqwest::Client::builder() - .redirect(redirect::Policy::none()) - .build() - .unwrap() + reqwest::Client::builder().redirect(redirect::Policy::none()).build().unwrap() } fn generate_pkce() -> (String, String) { @@ -21,8 +18,7 @@ fn generate_pkce() -> (String, String) { let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); let mut hasher = Sha256::new(); hasher.update(code_verifier.as_bytes()); - let hash = hasher.finalize(); - let code_challenge = URL_SAFE_NO_PAD.encode(&hash); + let code_challenge = URL_SAFE_NO_PAD.encode(&hasher.finalize()); (code_verifier, code_challenge) } @@ -45,136 +41,37 @@ async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { .await; mock_server } -#[allow(dead_code)] -async fn setup_mock_dpop_client(redirect_uri: &str) -> MockServer { - let mock_server = MockServer::start().await; - let client_id = mock_server.uri(); - let metadata = json!({ - "client_id": client_id, - "client_name": "DPoP Test Client", - "redirect_uris": [redirect_uri], - "grant_types": ["authorization_code", "refresh_token"], - "response_types": ["code"], - "token_endpoint_auth_method": "none", - "dpop_bound_access_tokens": true - }); - Mock::given(method("GET")) - .and(path("/")) - .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) - .mount(&mock_server) - .await; - mock_server -} + #[tokio::test] -async fn test_oauth_protected_resource_metadata() { +async fn test_oauth_metadata_endpoints() { let url = base_url().await; let client = client(); - let res = client - .get(format!("{}/.well-known/oauth-protected-resource", url)) - .send() - .await - .expect("Failed to fetch protected resource metadata"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.expect("Invalid JSON"); - assert!(body["resource"].is_string()); - assert!(body["authorization_servers"].is_array()); - assert!(body["bearer_methods_supported"].is_array()); - let bearer_methods = body["bearer_methods_supported"].as_array().unwrap(); - assert!(bearer_methods.contains(&json!("header"))); + let pr_res = client.get(format!("{}/.well-known/oauth-protected-resource", url)).send().await.unwrap(); + assert_eq!(pr_res.status(), StatusCode::OK); + let pr_body: Value = pr_res.json().await.unwrap(); + assert!(pr_body["resource"].is_string()); + assert!(pr_body["authorization_servers"].is_array()); + assert!(pr_body["bearer_methods_supported"].as_array().unwrap().contains(&json!("header"))); + let as_res = client.get(format!("{}/.well-known/oauth-authorization-server", url)).send().await.unwrap(); + assert_eq!(as_res.status(), StatusCode::OK); + let as_body: Value = as_res.json().await.unwrap(); + assert!(as_body["issuer"].is_string()); + assert!(as_body["authorization_endpoint"].is_string()); + assert!(as_body["token_endpoint"].is_string()); + assert!(as_body["jwks_uri"].is_string()); + assert!(as_body["response_types_supported"].as_array().unwrap().contains(&json!("code"))); + assert!(as_body["grant_types_supported"].as_array().unwrap().contains(&json!("authorization_code"))); + assert!(as_body["code_challenge_methods_supported"].as_array().unwrap().contains(&json!("S256"))); + assert_eq!(as_body["require_pushed_authorization_requests"], json!(true)); + assert!(as_body["dpop_signing_alg_values_supported"].as_array().unwrap().contains(&json!("ES256"))); + let jwks_res = client.get(format!("{}/oauth/jwks", url)).send().await.unwrap(); + assert_eq!(jwks_res.status(), StatusCode::OK); + let jwks_body: Value = jwks_res.json().await.unwrap(); + assert!(jwks_body["keys"].is_array()); } + #[tokio::test] -async fn test_oauth_authorization_server_metadata() { - let url = base_url().await; - let client = client(); - let res = client - .get(format!("{}/.well-known/oauth-authorization-server", url)) - .send() - .await - .expect("Failed to fetch authorization server metadata"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.expect("Invalid JSON"); - assert!(body["issuer"].is_string()); - assert!(body["authorization_endpoint"].is_string()); - assert!(body["token_endpoint"].is_string()); - assert!(body["jwks_uri"].is_string()); - let response_types = body["response_types_supported"].as_array().unwrap(); - assert!(response_types.contains(&json!("code"))); - let grant_types = body["grant_types_supported"].as_array().unwrap(); - assert!(grant_types.contains(&json!("authorization_code"))); - assert!(grant_types.contains(&json!("refresh_token"))); - let code_challenge_methods = body["code_challenge_methods_supported"].as_array().unwrap(); - assert!(code_challenge_methods.contains(&json!("S256"))); - assert_eq!(body["require_pushed_authorization_requests"], json!(true)); - let dpop_algs = body["dpop_signing_alg_values_supported"] - .as_array() - .unwrap(); - assert!(dpop_algs.contains(&json!("ES256"))); -} -#[tokio::test] -async fn test_oauth_jwks_endpoint() { - let url = base_url().await; - let client = client(); - let res = client - .get(format!("{}/oauth/jwks", url)) - .send() - .await - .expect("Failed to fetch JWKS"); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.expect("Invalid JSON"); - assert!(body["keys"].is_array()); -} -#[tokio::test] -async fn test_par_success() { - let url = base_url().await; - let client = client(); - let redirect_uri = "https://example.com/callback"; - let mock_client = setup_mock_client_metadata(redirect_uri).await; - let client_id = mock_client.uri(); - let (_code_verifier, code_challenge) = generate_pkce(); - let res = client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ("scope", "atproto"), - ("state", "test-state-123"), - ]) - .send() - .await - .expect("Failed to send PAR request"); - assert_eq!( - res.status(), - StatusCode::CREATED, - "PAR should succeed: {:?}", - res.text().await - ); - let body: Value = client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ("scope", "atproto"), - ("state", "test-state-123"), - ]) - .send() - .await - .unwrap() - .json() - .await - .expect("Invalid JSON"); - assert!(body["request_uri"].is_string()); - assert!(body["expires_in"].is_number()); - let request_uri = body["request_uri"].as_str().unwrap(); - assert!(request_uri.starts_with("urn:ietf:params:oauth:request_uri:")); -} -#[tokio::test] -async fn test_authorize_get_with_valid_request_uri() { +async fn test_par_and_authorize() { let url = base_url().await; let client = client(); let redirect_uri = "https://example.com/callback"; @@ -183,82 +80,47 @@ async fn test_authorize_get_with_valid_request_uri() { let (_, code_challenge) = generate_pkce(); let par_res = client .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ("scope", "atproto"), - ("state", "test-state"), - ]) - .send() - .await - .expect("PAR failed"); - let par_body: Value = par_res.json().await.expect("Invalid PAR JSON"); + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("scope", "atproto"), ("state", "test-state")]) + .send().await.unwrap(); + assert_eq!(par_res.status(), StatusCode::CREATED, "PAR should succeed"); + let par_body: Value = par_res.json().await.unwrap(); + assert!(par_body["request_uri"].is_string()); + assert!(par_body["expires_in"].is_number()); let request_uri = par_body["request_uri"].as_str().unwrap(); + assert!(request_uri.starts_with("urn:ietf:params:oauth:request_uri:")); let auth_res = client .get(format!("{}/oauth/authorize", url)) .header("Accept", "application/json") .query(&[("request_uri", request_uri)]) - .send() - .await - .expect("Authorize GET failed"); + .send().await.unwrap(); assert_eq!(auth_res.status(), StatusCode::OK); - let auth_body: Value = auth_res.json().await.expect("Invalid auth JSON"); + let auth_body: Value = auth_res.json().await.unwrap(); assert_eq!(auth_body["client_id"], client_id); assert_eq!(auth_body["redirect_uri"], redirect_uri); assert_eq!(auth_body["scope"], "atproto"); - assert_eq!(auth_body["state"], "test-state"); -} -#[tokio::test] -async fn test_authorize_rejects_invalid_request_uri() { - let url = base_url().await; - let client = client(); - let res = client + let invalid_res = client .get(format!("{}/oauth/authorize", url)) .header("Accept", "application/json") - .query(&[( - "request_uri", - "urn:ietf:params:oauth:request_uri:nonexistent", - )]) - .send() - .await - .expect("Request failed"); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - let body: Value = res.json().await.expect("Invalid JSON"); - assert_eq!(body["error"], "invalid_request"); + .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")]) + .send().await.unwrap(); + assert_eq!(invalid_res.status(), StatusCode::BAD_REQUEST); + let missing_res = client.get(format!("{}/oauth/authorize", url)).send().await.unwrap(); + assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST); } + #[tokio::test] -async fn test_authorize_requires_request_uri() { - let url = base_url().await; - let client = client(); - let res = client - .get(format!("{}/oauth/authorize", url)) - .send() - .await - .expect("Request failed"); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); -} -#[tokio::test] -async fn test_full_oauth_flow_without_dpop() { +async fn test_full_oauth_flow() { let url = base_url().await; let http_client = client(); - let (_, _user_did) = create_account_and_login(&http_client).await; let ts = Utc::now().timestamp_millis(); let handle = format!("oauth-test-{}", ts); let email = format!("oauth-test-{}@example.com", ts); let password = "oauth-test-password"; let create_res = http_client .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .expect("Account creation failed"); + .json(&json!({ "handle": handle, "email": email, "password": password })) + .send().await.unwrap(); assert_eq!(create_res.status(), StatusCode::OK); let account: Value = create_res.json().await.unwrap(); let user_did = account["did"].as_str().unwrap(); @@ -269,980 +131,187 @@ async fn test_full_oauth_flow_without_dpop() { let state = format!("state-{}", ts); let par_res = http_client .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ("scope", "atproto"), - ("state", &state), - ]) - .send() - .await - .expect("PAR failed"); - let par_status = par_res.status(); - let par_text = par_res.text().await.unwrap_or_default(); - if par_status != StatusCode::OK && par_status != StatusCode::CREATED { - panic!("PAR failed with status {}: {}", par_status, par_text); - } - let par_body: Value = serde_json::from_str(&par_text).unwrap(); + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("scope", "atproto"), ("state", &state)]) + .send().await.unwrap(); + let par_body: Value = par_res.json().await.unwrap(); let request_uri = par_body["request_uri"].as_str().unwrap(); let auth_client = no_redirect_client(); let auth_res = auth_client .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .expect("Authorize POST failed"); - let auth_status = auth_res.status(); - if auth_status != StatusCode::TEMPORARY_REDIRECT - && auth_status != StatusCode::SEE_OTHER - && auth_status != StatusCode::FOUND - { - let auth_text = auth_res.text().await.unwrap_or_default(); - panic!("Expected redirect, got {}: {}", auth_status, auth_text); - } - let location = auth_res - .headers() - .get("location") - .expect("No Location header") - .to_str() - .unwrap(); - assert!( - location.starts_with(redirect_uri), - "Redirect to wrong URI: {}", - location - ); - assert!( - location.contains("code="), - "No code in redirect: {}", - location - ); - assert!( - location.contains(&format!("state={}", state)), - "Wrong state in redirect" - ); - let code = location - .split("code=") - .nth(1) - .unwrap() - .split('&') - .next() - .unwrap(); + .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) + .send().await.unwrap(); + assert!(auth_res.status().is_redirection(), "Expected redirect, got {}", auth_res.status()); + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); + assert!(location.starts_with(redirect_uri), "Redirect to wrong URI"); + assert!(location.contains("code="), "No code in redirect"); + assert!(location.contains(&format!("state={}", state)), "Wrong state"); + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); let token_res = http_client .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "authorization_code"), - ("code", code), - ("redirect_uri", redirect_uri), - ("code_verifier", &code_verifier), - ("client_id", &client_id), - ]) - .send() - .await - .expect("Token request failed"); - let token_status = token_res.status(); - let token_text = token_res.text().await.unwrap_or_default(); - if token_status != StatusCode::OK { - panic!( - "Token request failed with status {}: {}", - token_status, token_text - ); - } - let token_body: Value = serde_json::from_str(&token_text).unwrap(); + .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), + ("code_verifier", &code_verifier), ("client_id", &client_id)]) + .send().await.unwrap(); + assert_eq!(token_res.status(), StatusCode::OK, "Token exchange failed"); + let token_body: Value = token_res.json().await.unwrap(); assert!(token_body["access_token"].is_string()); assert!(token_body["refresh_token"].is_string()); assert_eq!(token_body["token_type"], "Bearer"); assert!(token_body["expires_in"].is_number()); assert_eq!(token_body["sub"], user_did); -} -#[tokio::test] -async fn test_token_refresh_flow() { - let url = base_url().await; - let http_client = client(); - let ts = Utc::now().timestamp_millis(); - let handle = format!("refresh-test-{}", ts); - let email = format!("refresh-test-{}@example.com", ts); - let password = "refresh-test-password"; - http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .expect("Account creation failed"); - let redirect_uri = "https://example.com/refresh-callback"; - let mock_client = setup_mock_client_metadata(redirect_uri).await; - let client_id = mock_client.uri(); - let (code_verifier, code_challenge) = generate_pkce(); - let par_body: Value = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let request_uri = par_body["request_uri"].as_str().unwrap(); - let auth_client = no_redirect_client(); - let auth_res = auth_client - .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - let location = auth_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - let code = location - .split("code=") - .nth(1) - .unwrap() - .split('&') - .next() - .unwrap(); - let token_body: Value = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "authorization_code"), - ("code", code), - ("redirect_uri", redirect_uri), - ("code_verifier", &code_verifier), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); + let access_token = token_body["access_token"].as_str().unwrap(); let refresh_token = token_body["refresh_token"].as_str().unwrap(); - let original_access_token = token_body["access_token"].as_str().unwrap(); let refresh_res = http_client .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "refresh_token"), - ("refresh_token", refresh_token), - ("client_id", &client_id), - ]) - .send() - .await - .expect("Refresh request failed"); + .form(&[("grant_type", "refresh_token"), ("refresh_token", refresh_token), ("client_id", &client_id)]) + .send().await.unwrap(); assert_eq!(refresh_res.status(), StatusCode::OK); let refresh_body: Value = refresh_res.json().await.unwrap(); - assert!(refresh_body["access_token"].is_string()); - assert!(refresh_body["refresh_token"].is_string()); - let new_access_token = refresh_body["access_token"].as_str().unwrap(); - let new_refresh_token = refresh_body["refresh_token"].as_str().unwrap(); - assert_ne!( - new_access_token, original_access_token, - "Access token should rotate" - ); - assert_ne!( - new_refresh_token, refresh_token, - "Refresh token should rotate" - ); + assert_ne!(refresh_body["access_token"].as_str().unwrap(), access_token); + assert_ne!(refresh_body["refresh_token"].as_str().unwrap(), refresh_token); + let introspect_res = http_client + .post(format!("{}/oauth/introspect", url)) + .form(&[("token", refresh_body["access_token"].as_str().unwrap())]) + .send().await.unwrap(); + assert_eq!(introspect_res.status(), StatusCode::OK); + let introspect_body: Value = introspect_res.json().await.unwrap(); + assert_eq!(introspect_body["active"], true); + let revoke_res = http_client + .post(format!("{}/oauth/revoke", url)) + .form(&[("token", refresh_body["refresh_token"].as_str().unwrap())]) + .send().await.unwrap(); + assert_eq!(revoke_res.status(), StatusCode::OK); + let introspect_after = http_client + .post(format!("{}/oauth/introspect", url)) + .form(&[("token", refresh_body["access_token"].as_str().unwrap())]) + .send().await.unwrap(); + let after_body: Value = introspect_after.json().await.unwrap(); + assert_eq!(after_body["active"], false, "Revoked token should be inactive"); } + #[tokio::test] -async fn test_wrong_credentials_denied() { +async fn test_oauth_error_cases() { let url = base_url().await; let http_client = client(); let ts = Utc::now().timestamp_millis(); let handle = format!("wrong-creds-{}", ts); let email = format!("wrong-creds-{}@example.com", ts); - let password = "correct-password"; - http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); - let redirect_uri = "https://example.com/wrong-creds-callback"; + http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) + .json(&json!({ "handle": handle, "email": email, "password": "correct-password" })) + .send().await.unwrap(); + let redirect_uri = "https://example.com/callback"; let mock_client = setup_mock_client_metadata(redirect_uri).await; let client_id = mock_client.uri(); let (_, code_challenge) = generate_pkce(); let par_body: Value = http_client .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) + .send().await.unwrap().json().await.unwrap(); let request_uri = par_body["request_uri"].as_str().unwrap(); let auth_res = http_client .post(format!("{}/oauth/authorize", url)) .header("Accept", "application/json") - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", "wrong-password"), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); + .form(&[("request_uri", request_uri), ("username", &handle), ("password", "wrong-password"), ("remember_device", "false")]) + .send().await.unwrap(); assert_eq!(auth_res.status(), StatusCode::FORBIDDEN); let error_body: Value = auth_res.json().await.unwrap(); assert_eq!(error_body["error"], "access_denied"); -} -#[tokio::test] -async fn test_token_revocation() { - let url = base_url().await; - let http_client = client(); - let ts = Utc::now().timestamp_millis(); - let handle = format!("revoke-test-{}", ts); - let email = format!("revoke-test-{}@example.com", ts); - let password = "revoke-test-password"; - http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); - let redirect_uri = "https://example.com/revoke-callback"; - let mock_client = setup_mock_client_metadata(redirect_uri).await; - let client_id = mock_client.uri(); - let (code_verifier, code_challenge) = generate_pkce(); - let par_body: Value = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let request_uri = par_body["request_uri"].as_str().unwrap(); - let auth_client = no_redirect_client(); - let auth_res = auth_client - .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - let location = auth_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - let code = location - .split("code=") - .nth(1) - .unwrap() - .split('&') - .next() - .unwrap(); - let token_body: Value = http_client + let unsupported = http_client .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "authorization_code"), - ("code", code), - ("redirect_uri", redirect_uri), - ("code_verifier", &code_verifier), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let refresh_token = token_body["refresh_token"].as_str().unwrap(); - let revoke_res = http_client - .post(format!("{}/oauth/revoke", url)) - .form(&[("token", refresh_token)]) - .send() - .await - .unwrap(); - assert_eq!(revoke_res.status(), StatusCode::OK); - let refresh_after_revoke = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "refresh_token"), - ("refresh_token", refresh_token), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap(); - assert_eq!(refresh_after_revoke.status(), StatusCode::BAD_REQUEST); -} -#[tokio::test] -async fn test_unsupported_grant_type() { - let url = base_url().await; - let http_client = client(); - let res = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "client_credentials"), - ("client_id", "https://example.com"), - ]) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - let body: Value = res.json().await.unwrap(); + .form(&[("grant_type", "client_credentials"), ("client_id", "https://example.com")]) + .send().await.unwrap(); + assert_eq!(unsupported.status(), StatusCode::BAD_REQUEST); + let body: Value = unsupported.json().await.unwrap(); assert_eq!(body["error"], "unsupported_grant_type"); -} -#[tokio::test] -async fn test_invalid_refresh_token() { - let url = base_url().await; - let http_client = client(); - let res = http_client + let invalid_refresh = http_client .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "refresh_token"), - ("refresh_token", "invalid-refresh-token"), - ("client_id", "https://example.com"), - ]) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - let body: Value = res.json().await.unwrap(); + .form(&[("grant_type", "refresh_token"), ("refresh_token", "invalid-token"), ("client_id", "https://example.com")]) + .send().await.unwrap(); + assert_eq!(invalid_refresh.status(), StatusCode::BAD_REQUEST); + let body: Value = invalid_refresh.json().await.unwrap(); assert_eq!(body["error"], "invalid_grant"); -} -#[tokio::test] -async fn test_expired_authorization_request() { - let url = base_url().await; - let http_client = client(); - let res = http_client - .get(format!("{}/oauth/authorize", url)) - .header("Accept", "application/json") - .query(&[( - "request_uri", - "urn:ietf:params:oauth:request_uri:expired-or-nonexistent", - )]) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - let body: Value = res.json().await.unwrap(); - assert_eq!(body["error"], "invalid_request"); -} -#[tokio::test] -async fn test_token_introspection() { - let url = base_url().await; - let http_client = client(); - let ts = Utc::now().timestamp_millis(); - let handle = format!("introspect-{}", ts); - let email = format!("introspect-{}@example.com", ts); - let password = "introspect-password"; - http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); - let redirect_uri = "https://example.com/introspect-callback"; - let mock_client = setup_mock_client_metadata(redirect_uri).await; - let client_id = mock_client.uri(); - let (code_verifier, code_challenge) = generate_pkce(); - let par_body: Value = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let request_uri = par_body["request_uri"].as_str().unwrap(); - let auth_client = no_redirect_client(); - let auth_res = auth_client - .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - let location = auth_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - let code = location - .split("code=") - .nth(1) - .unwrap() - .split('&') - .next() - .unwrap(); - let token_body: Value = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "authorization_code"), - ("code", code), - ("redirect_uri", redirect_uri), - ("code_verifier", &code_verifier), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let access_token = token_body["access_token"].as_str().unwrap(); - let introspect_res = http_client - .post(format!("{}/oauth/introspect", url)) - .form(&[("token", access_token)]) - .send() - .await - .unwrap(); - assert_eq!(introspect_res.status(), StatusCode::OK); - let introspect_body: Value = introspect_res.json().await.unwrap(); - assert_eq!(introspect_body["active"], true); - assert!(introspect_body["client_id"].is_string()); - assert!(introspect_body["exp"].is_number()); -} -#[tokio::test] -async fn test_introspect_invalid_token() { - let url = base_url().await; - let http_client = client(); - let res = http_client + let invalid_introspect = http_client .post(format!("{}/oauth/introspect", url)) .form(&[("token", "invalid.token.here")]) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - let body: Value = res.json().await.unwrap(); + .send().await.unwrap(); + assert_eq!(invalid_introspect.status(), StatusCode::OK); + let body: Value = invalid_introspect.json().await.unwrap(); assert_eq!(body["active"], false); + let expired_res = http_client + .get(format!("{}/oauth/authorize", url)) + .header("Accept", "application/json") + .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired")]) + .send().await.unwrap(); + assert_eq!(expired_res.status(), StatusCode::BAD_REQUEST); } + #[tokio::test] -async fn test_introspect_revoked_token() { +async fn test_oauth_2fa_flow() { let url = base_url().await; let http_client = client(); let ts = Utc::now().timestamp_millis(); - let handle = format!("introspect-revoked-{}", ts); - let email = format!("introspect-revoked-{}@example.com", ts); - let password = "introspect-revoked-password"; - http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); - let redirect_uri = "https://example.com/introspect-revoked-callback"; - let mock_client = setup_mock_client_metadata(redirect_uri).await; - let client_id = mock_client.uri(); - let (code_verifier, code_challenge) = generate_pkce(); - let par_body: Value = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let request_uri = par_body["request_uri"].as_str().unwrap(); - let auth_client = no_redirect_client(); - let auth_res = auth_client - .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - let location = auth_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - let code = location - .split("code=") - .nth(1) - .unwrap() - .split('&') - .next() - .unwrap(); - let token_body: Value = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "authorization_code"), - ("code", code), - ("redirect_uri", redirect_uri), - ("code_verifier", &code_verifier), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let access_token = token_body["access_token"].as_str().unwrap(); - let refresh_token = token_body["refresh_token"].as_str().unwrap(); - http_client - .post(format!("{}/oauth/revoke", url)) - .form(&[("token", refresh_token)]) - .send() - .await - .unwrap(); - let introspect_res = http_client - .post(format!("{}/oauth/introspect", url)) - .form(&[("token", access_token)]) - .send() - .await - .unwrap(); - assert_eq!(introspect_res.status(), StatusCode::OK); - let body: Value = introspect_res.json().await.unwrap(); - assert_eq!(body["active"], false, "Revoked token should be inactive"); -} -#[tokio::test] -async fn test_state_with_special_chars() { - let url = base_url().await; - let http_client = client(); - let ts = Utc::now().timestamp_millis(); - let handle = format!("state-special-{}", ts); - let email = format!("state-special-{}@example.com", ts); - let password = "state-special-password"; - http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); - let redirect_uri = "https://example.com/state-special-callback"; - let mock_client = setup_mock_client_metadata(redirect_uri).await; - let client_id = mock_client.uri(); - let (_code_verifier, code_challenge) = generate_pkce(); - let special_state = "state=with&special=chars&plus+more"; - let par_body: Value = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ("state", special_state), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let request_uri = par_body["request_uri"].as_str().unwrap(); - let auth_client = no_redirect_client(); - let auth_res = auth_client - .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - assert!( - auth_res.status().is_redirection(), - "Should redirect even with special chars in state" - ); - let location = auth_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - assert!( - location.contains("state="), - "State should be in redirect URL" - ); - let encoded_state = urlencoding::encode(special_state); - assert!( - location.contains(&format!("state={}", encoded_state)), - "State should be URL-encoded. Got: {}", - location - ); -} -#[tokio::test] -async fn test_2fa_required_when_enabled() { - let url = base_url().await; - let http_client = client(); - let ts = Utc::now().timestamp_millis(); - let handle = format!("2fa-required-{}", ts); - let email = format!("2fa-required-{}@example.com", ts); + let handle = format!("2fa-test-{}", ts); + let email = format!("2fa-test-{}@example.com", ts); let password = "2fa-test-password"; let create_res = http_client .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); + .json(&json!({ "handle": handle, "email": email, "password": password })) + .send().await.unwrap(); assert_eq!(create_res.status(), StatusCode::OK); let account: Value = create_res.json().await.unwrap(); let user_did = account["did"].as_str().unwrap(); - let db_url = common::get_db_connection_string().await; - let pool = sqlx::postgres::PgPoolOptions::new() - .max_connections(1) - .connect(&db_url) - .await - .expect("Failed to connect to database"); + let db_url = get_db_connection_string().await; + let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap(); sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") - .bind(user_did) - .execute(&pool) - .await - .expect("Failed to enable 2FA"); + .bind(user_did).execute(&pool).await.unwrap(); let redirect_uri = "https://example.com/2fa-callback"; let mock_client = setup_mock_client_metadata(redirect_uri).await; let client_id = mock_client.uri(); - let (_, code_challenge) = generate_pkce(); - let par_body: Value = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let request_uri = par_body["request_uri"].as_str().unwrap(); - let auth_client = no_redirect_client(); - let auth_res = auth_client - .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - assert!( - auth_res.status().is_redirection(), - "Should redirect to 2FA page, got status: {}", - auth_res.status() - ); - let location = auth_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - assert!( - location.contains("/oauth/authorize/2fa"), - "Should redirect to 2FA page, got: {}", - location - ); - assert!( - location.contains("request_uri="), - "2FA redirect should include request_uri" - ); -} -#[tokio::test] -async fn test_2fa_invalid_code_rejected() { - let url = base_url().await; - let http_client = client(); - let ts = Utc::now().timestamp_millis(); - let handle = format!("2fa-invalid-{}", ts); - let email = format!("2fa-invalid-{}@example.com", ts); - let password = "2fa-test-password"; - let create_res = http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); - assert_eq!(create_res.status(), StatusCode::OK); - let account: Value = create_res.json().await.unwrap(); - let user_did = account["did"].as_str().unwrap(); - let db_url = common::get_db_connection_string().await; - let pool = sqlx::postgres::PgPoolOptions::new() - .max_connections(1) - .connect(&db_url) - .await - .expect("Failed to connect to database"); - sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") - .bind(user_did) - .execute(&pool) - .await - .expect("Failed to enable 2FA"); - let redirect_uri = "https://example.com/2fa-invalid-callback"; - let mock_client = setup_mock_client_metadata(redirect_uri).await; - let client_id = mock_client.uri(); - let (_, code_challenge) = generate_pkce(); - let par_body: Value = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let request_uri = par_body["request_uri"].as_str().unwrap(); - let auth_client = no_redirect_client(); - let auth_res = auth_client - .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - assert!(auth_res.status().is_redirection()); - let location = auth_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - assert!(location.contains("/oauth/authorize/2fa")); - let twofa_res = http_client - .post(format!("{}/oauth/authorize/2fa", url)) - .form(&[("request_uri", request_uri), ("code", "000000")]) - .send() - .await - .unwrap(); - assert_eq!(twofa_res.status(), StatusCode::OK); - let body = twofa_res.text().await.unwrap(); - assert!( - body.contains("Invalid verification code") || body.contains("invalid"), - "Should show error for invalid code" - ); -} -#[tokio::test] -async fn test_2fa_valid_code_completes_auth() { - let url = base_url().await; - let http_client = client(); - let ts = Utc::now().timestamp_millis(); - let handle = format!("2fa-valid-{}", ts); - let email = format!("2fa-valid-{}@example.com", ts); - let password = "2fa-test-password"; - let create_res = http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); - assert_eq!(create_res.status(), StatusCode::OK); - let account: Value = create_res.json().await.unwrap(); - let user_did = account["did"].as_str().unwrap(); - let db_url = common::get_db_connection_string().await; - let pool = sqlx::postgres::PgPoolOptions::new() - .max_connections(1) - .connect(&db_url) - .await - .expect("Failed to connect to database"); - sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") - .bind(user_did) - .execute(&pool) - .await - .expect("Failed to enable 2FA"); - let redirect_uri = "https://example.com/2fa-valid-callback"; - let mock_client = setup_mock_client_metadata(redirect_uri).await; - let client_id = mock_client.uri(); let (code_verifier, code_challenge) = generate_pkce(); let par_body: Value = http_client .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) + .send().await.unwrap().json().await.unwrap(); let request_uri = par_body["request_uri"].as_str().unwrap(); let auth_client = no_redirect_client(); let auth_res = auth_client .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - assert!(auth_res.status().is_redirection()); - let twofa_code: String = - sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") - .bind(request_uri) - .fetch_one(&pool) - .await - .expect("Failed to get 2FA code from database"); + .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) + .send().await.unwrap(); + assert!(auth_res.status().is_redirection(), "Should redirect to 2FA page"); + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); + assert!(location.contains("/oauth/authorize/2fa"), "Should redirect to 2FA page, got: {}", location); + let twofa_invalid = http_client + .post(format!("{}/oauth/authorize/2fa", url)) + .form(&[("request_uri", request_uri), ("code", "000000")]) + .send().await.unwrap(); + assert_eq!(twofa_invalid.status(), StatusCode::OK); + let body = twofa_invalid.text().await.unwrap(); + assert!(body.contains("Invalid verification code") || body.contains("invalid")); + let twofa_code: String = sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") + .bind(request_uri).fetch_one(&pool).await.unwrap(); let twofa_res = auth_client .post(format!("{}/oauth/authorize/2fa", url)) .form(&[("request_uri", request_uri), ("code", &twofa_code)]) - .send() - .await - .unwrap(); - assert!( - twofa_res.status().is_redirection(), - "Valid 2FA code should redirect to success, got status: {}", - twofa_res.status() - ); - let location = twofa_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - assert!( - location.starts_with(redirect_uri), - "Should redirect to client callback, got: {}", - location - ); - assert!( - location.contains("code="), - "Redirect should include authorization code" - ); - let auth_code = location - .split("code=") - .nth(1) - .unwrap() - .split('&') - .next() - .unwrap(); + .send().await.unwrap(); + assert!(twofa_res.status().is_redirection(), "Valid 2FA code should redirect"); + let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap(); + assert!(final_location.starts_with(redirect_uri) && final_location.contains("code=")); + let auth_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap(); let token_res = http_client .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "authorization_code"), - ("code", auth_code), - ("redirect_uri", redirect_uri), - ("code_verifier", &code_verifier), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap(); - assert_eq!( - token_res.status(), - StatusCode::OK, - "Token exchange should succeed" - ); + .form(&[("grant_type", "authorization_code"), ("code", auth_code), ("redirect_uri", redirect_uri), + ("code_verifier", &code_verifier), ("client_id", &client_id)]) + .send().await.unwrap(); + assert_eq!(token_res.status(), StatusCode::OK); let token_body: Value = token_res.json().await.unwrap(); - assert!(token_body["access_token"].is_string()); assert_eq!(token_body["sub"], user_did); } + #[tokio::test] -async fn test_2fa_lockout_after_max_attempts() { +async fn test_oauth_2fa_lockout() { let url = base_url().await; let http_client = client(); let ts = Utc::now().timestamp_millis(); @@ -1251,99 +320,49 @@ async fn test_2fa_lockout_after_max_attempts() { let password = "2fa-test-password"; let create_res = http_client .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); - assert_eq!(create_res.status(), StatusCode::OK); + .json(&json!({ "handle": handle, "email": email, "password": password })) + .send().await.unwrap(); let account: Value = create_res.json().await.unwrap(); let user_did = account["did"].as_str().unwrap(); - let db_url = common::get_db_connection_string().await; - let pool = sqlx::postgres::PgPoolOptions::new() - .max_connections(1) - .connect(&db_url) - .await - .expect("Failed to connect to database"); + let db_url = get_db_connection_string().await; + let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap(); sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") - .bind(user_did) - .execute(&pool) - .await - .expect("Failed to enable 2FA"); + .bind(user_did).execute(&pool).await.unwrap(); let redirect_uri = "https://example.com/2fa-lockout-callback"; let mock_client = setup_mock_client_metadata(redirect_uri).await; let client_id = mock_client.uri(); let (_, code_challenge) = generate_pkce(); let par_body: Value = http_client .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) + .send().await.unwrap().json().await.unwrap(); let request_uri = par_body["request_uri"].as_str().unwrap(); let auth_client = no_redirect_client(); let auth_res = auth_client .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); + .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) + .send().await.unwrap(); assert!(auth_res.status().is_redirection()); for i in 0..5 { let res = http_client .post(format!("{}/oauth/authorize/2fa", url)) .form(&[("request_uri", request_uri), ("code", "999999")]) - .send() - .await - .unwrap(); + .send().await.unwrap(); if i < 4 { - assert_eq!( - res.status(), - StatusCode::OK, - "Attempt {} should show error page", - i + 1 - ); - let body = res.text().await.unwrap(); - assert!( - body.contains("Invalid verification code"), - "Should show invalid code error on attempt {}", - i + 1 - ); + assert_eq!(res.status(), StatusCode::OK); } } let lockout_res = http_client .post(format!("{}/oauth/authorize/2fa", url)) .form(&[("request_uri", request_uri), ("code", "999999")]) - .send() - .await - .unwrap(); - assert_eq!(lockout_res.status(), StatusCode::OK); + .send().await.unwrap(); let body = lockout_res.text().await.unwrap(); - assert!( - body.contains("Too many failed attempts") || body.contains("No 2FA challenge found"), - "Should be locked out after max attempts. Body: {}", - &body[..body.len().min(500)] - ); + assert!(body.contains("Too many failed attempts") || body.contains("No 2FA challenge found")); } + #[tokio::test] -async fn test_account_selector_with_2fa_requires_verification() { +async fn test_account_selector_with_2fa() { let url = base_url().await; let http_client = client(); let ts = Utc::now().timestamp_millis(); @@ -1352,15 +371,8 @@ async fn test_account_selector_with_2fa_requires_verification() { let password = "selector-2fa-password"; let create_res = http_client .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); - assert_eq!(create_res.status(), StatusCode::OK); + .json(&json!({ "handle": handle, "email": email, "password": password })) + .send().await.unwrap(); let account: Value = create_res.json().await.unwrap(); let user_did = account["did"].as_str().unwrap().to_string(); let redirect_uri = "https://example.com/selector-2fa-callback"; @@ -1369,167 +381,99 @@ async fn test_account_selector_with_2fa_requires_verification() { let (code_verifier, code_challenge) = generate_pkce(); let par_body: Value = http_client .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) + .send().await.unwrap().json().await.unwrap(); let request_uri = par_body["request_uri"].as_str().unwrap(); let auth_client = no_redirect_client(); let auth_res = auth_client .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "true"), - ]) - .send() - .await - .unwrap(); + .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "true")]) + .send().await.unwrap(); assert!(auth_res.status().is_redirection()); - let device_cookie = auth_res - .headers() - .get("set-cookie") + let device_cookie = auth_res.headers().get("set-cookie") .and_then(|v| v.to_str().ok()) .map(|s| s.split(';').next().unwrap_or("").to_string()) - .expect("Should have received device cookie"); - let location = auth_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - assert!(location.contains("code="), "First auth should succeed"); - let code = location - .split("code=") - .nth(1) - .unwrap() - .split('&') - .next() - .unwrap(); - let _token_body: Value = http_client + .expect("Should have device cookie"); + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); + assert!(location.contains("code=")); + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); + let _ = http_client .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "authorization_code"), - ("code", code), - ("redirect_uri", redirect_uri), - ("code_verifier", &code_verifier), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let db_url = common::get_db_connection_string().await; - let pool = sqlx::postgres::PgPoolOptions::new() - .max_connections(1) - .connect(&db_url) - .await - .expect("Failed to connect to database"); + .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), + ("code_verifier", &code_verifier), ("client_id", &client_id)]) + .send().await.unwrap().json::().await.unwrap(); + let db_url = get_db_connection_string().await; + let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap(); sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") - .bind(&user_did) - .execute(&pool) - .await - .expect("Failed to enable 2FA"); + .bind(&user_did).execute(&pool).await.unwrap(); let (code_verifier2, code_challenge2) = generate_pkce(); let par_body2: Value = http_client .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge2), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge2), ("code_challenge_method", "S256")]) + .send().await.unwrap().json().await.unwrap(); let request_uri2 = par_body2["request_uri"].as_str().unwrap(); let select_res = auth_client .post(format!("{}/oauth/authorize/select", url)) .header("cookie", &device_cookie) .form(&[("request_uri", request_uri2), ("did", &user_did)]) - .send() - .await - .unwrap(); - assert!( - select_res.status().is_redirection(), - "Account selector should redirect, got status: {}", - select_res.status() - ); - let select_location = select_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - assert!( - select_location.contains("/oauth/authorize/2fa"), - "Account selector with 2FA enabled should redirect to 2FA page, got: {}", - select_location - ); - let twofa_code: String = - sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") - .bind(request_uri2) - .fetch_one(&pool) - .await - .expect("Failed to get 2FA code"); + .send().await.unwrap(); + assert!(select_res.status().is_redirection()); + let select_location = select_res.headers().get("location").unwrap().to_str().unwrap(); + assert!(select_location.contains("/oauth/authorize/2fa"), "Should redirect to 2FA page"); + let twofa_code: String = sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") + .bind(request_uri2).fetch_one(&pool).await.unwrap(); let twofa_res = auth_client .post(format!("{}/oauth/authorize/2fa", url)) .header("cookie", &device_cookie) .form(&[("request_uri", request_uri2), ("code", &twofa_code)]) - .send() - .await - .unwrap(); + .send().await.unwrap(); assert!(twofa_res.status().is_redirection()); - let final_location = twofa_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - assert!( - final_location.starts_with(redirect_uri) && final_location.contains("code="), - "After 2FA, should redirect to client with code, got: {}", - final_location - ); - let final_code = final_location - .split("code=") - .nth(1) - .unwrap() - .split('&') - .next() - .unwrap(); + let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap(); + assert!(final_location.starts_with(redirect_uri) && final_location.contains("code=")); + let final_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap(); let token_res = http_client .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "authorization_code"), - ("code", final_code), - ("redirect_uri", redirect_uri), - ("code_verifier", &code_verifier2), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap(); + .form(&[("grant_type", "authorization_code"), ("code", final_code), ("redirect_uri", redirect_uri), + ("code_verifier", &code_verifier2), ("client_id", &client_id)]) + .send().await.unwrap(); assert_eq!(token_res.status(), StatusCode::OK); let final_token: Value = token_res.json().await.unwrap(); - assert_eq!( - final_token["sub"], user_did, - "Token should be for the correct user" - ); + assert_eq!(final_token["sub"], user_did); +} + +#[tokio::test] +async fn test_oauth_state_encoding() { + let url = base_url().await; + let http_client = client(); + let ts = Utc::now().timestamp_millis(); + let handle = format!("state-special-{}", ts); + let email = format!("state-special-{}@example.com", ts); + let password = "state-special-password"; + http_client + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) + .json(&json!({ "handle": handle, "email": email, "password": password })) + .send().await.unwrap(); + let redirect_uri = "https://example.com/state-special-callback"; + let mock_client = setup_mock_client_metadata(redirect_uri).await; + let client_id = mock_client.uri(); + let (_, code_challenge) = generate_pkce(); + let special_state = "state=with&special=chars&plus+more"; + let par_body: Value = http_client + .post(format!("{}/oauth/par", url)) + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("state", special_state)]) + .send().await.unwrap().json().await.unwrap(); + let request_uri = par_body["request_uri"].as_str().unwrap(); + let auth_client = no_redirect_client(); + let auth_res = auth_client + .post(format!("{}/oauth/authorize", url)) + .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) + .send().await.unwrap(); + assert!(auth_res.status().is_redirection()); + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); + assert!(location.contains("state=")); + let encoded_state = urlencoding::encode(special_state); + assert!(location.contains(&format!("state={}", encoded_state)), "State should be URL-encoded. Got: {}", location); } diff --git a/tests/oauth_security.rs b/tests/oauth_security.rs index 060e38f..a83f872 100644 --- a/tests/oauth_security.rs +++ b/tests/oauth_security.rs @@ -1,5 +1,4 @@ #![allow(unused_imports)] -#![allow(unused_variables)] mod common; mod helpers; use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; @@ -14,10 +13,7 @@ use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; fn no_redirect_client() -> reqwest::Client { - reqwest::Client::builder() - .redirect(redirect::Policy::none()) - .build() - .unwrap() + reqwest::Client::builder().redirect(redirect::Policy::none()).build().unwrap() } fn generate_pkce() -> (String, String) { @@ -25,16 +21,14 @@ fn generate_pkce() -> (String, String) { let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); let mut hasher = Sha256::new(); hasher.update(code_verifier.as_bytes()); - let hash = hasher.finalize(); - let code_challenge = URL_SAFE_NO_PAD.encode(&hash); + let code_challenge = URL_SAFE_NO_PAD.encode(&hasher.finalize()); (code_verifier, code_challenge) } async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { let mock_server = MockServer::start().await; - let client_id = mock_server.uri(); let metadata = json!({ - "client_id": client_id, + "client_id": mock_server.uri(), "client_name": "Security Test Client", "redirect_uris": [redirect_uri], "grant_types": ["authorization_code", "refresh_token"], @@ -42,1754 +36,435 @@ async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { "token_endpoint_auth_method": "none", "dpop_bound_access_tokens": false }); - Mock::given(method("GET")) - .and(path("/")) + Mock::given(method("GET")).and(path("/")) .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) - .mount(&mock_server) - .await; + .mount(&mock_server).await; mock_server } async fn get_oauth_tokens(http_client: &reqwest::Client, url: &str) -> (String, String, String) { let ts = Utc::now().timestamp_millis(); let handle = format!("sec-test-{}", ts); - let email = format!("sec-test-{}@example.com", ts); - let password = "security-test-password"; - http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); + http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) + .json(&json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "security-test-password" })) + .send().await.unwrap(); let redirect_uri = "https://example.com/sec-callback"; let mock_client = setup_mock_client_metadata(redirect_uri).await; let client_id = mock_client.uri(); let (code_verifier, code_challenge) = generate_pkce(); - let par_body: Value = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); + let par_body: Value = http_client.post(format!("{}/oauth/par", url)) + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) + .send().await.unwrap().json().await.unwrap(); let request_uri = par_body["request_uri"].as_str().unwrap(); let auth_client = no_redirect_client(); - let auth_res = auth_client - .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - let location = auth_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - let code = location - .split("code=") - .nth(1) - .unwrap() - .split('&') - .next() - .unwrap(); - let token_body: Value = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "authorization_code"), - ("code", code), - ("redirect_uri", redirect_uri), - ("code_verifier", &code_verifier), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let access_token = token_body["access_token"].as_str().unwrap().to_string(); - let refresh_token = token_body["refresh_token"].as_str().unwrap().to_string(); - (access_token, refresh_token, client_id) + let auth_res = auth_client.post(format!("{}/oauth/authorize", url)) + .form(&[("request_uri", request_uri), ("username", &handle), ("password", "security-test-password"), ("remember_device", "false")]) + .send().await.unwrap(); + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); + let token_body: Value = http_client.post(format!("{}/oauth/token", url)) + .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), + ("code_verifier", &code_verifier), ("client_id", &client_id)]) + .send().await.unwrap().json().await.unwrap(); + (token_body["access_token"].as_str().unwrap().to_string(), + token_body["refresh_token"].as_str().unwrap().to_string(), client_id) } #[tokio::test] -async fn test_security_forged_token_signature_rejected() { - let url = base_url().await; - let http_client = client(); - let (access_token, _, _) = get_oauth_tokens(&http_client, url).await; - let parts: Vec<&str> = access_token.split('.').collect(); - assert_eq!(parts.len(), 3, "Token should have 3 parts"); - let forged_signature = URL_SAFE_NO_PAD.encode(&[0u8; 32]); - let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_signature); - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("Bearer {}", forged_token)) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::UNAUTHORIZED, - "Forged signature should be rejected" - ); -} - -#[tokio::test] -async fn test_security_modified_payload_rejected() { +async fn test_token_tampering_attacks() { let url = base_url().await; let http_client = client(); let (access_token, _, _) = get_oauth_tokens(&http_client, url).await; let parts: Vec<&str> = access_token.split('.').collect(); + assert_eq!(parts.len(), 3); + let forged_sig = URL_SAFE_NO_PAD.encode(&[0u8; 32]); + let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_sig); + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .bearer_auth(&forged_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Forged signature should be rejected"); let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); payload["sub"] = json!("did:plc:attacker"); let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("Bearer {}", modified_token)) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::UNAUTHORIZED, - "Modified payload should be rejected" - ); + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .bearer_auth(&modified_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Modified payload should be rejected"); + let none_header = json!({ "alg": "none", "typ": "at+jwt" }); + let none_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:attacker", "aud": "https://test.pds", + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "fake", "scope": "atproto" }); + let none_token = format!("{}.{}.", URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_header).unwrap()), + URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap())); + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .bearer_auth(&none_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "alg=none should be rejected"); + let rs256_header = json!({ "alg": "RS256", "typ": "at+jwt" }); + let rs256_token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&rs256_header).unwrap()), + URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 64])); + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .bearer_auth(&rs256_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Algorithm substitution should be rejected"); + let expired_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:test", "aud": "https://test.pds", + "iat": Utc::now().timestamp() - 7200, "exp": Utc::now().timestamp() - 3600, "jti": "expired" }); + let expired_token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&json!({"alg":"HS256","typ":"at+jwt"})).unwrap()), + URL_SAFE_NO_PAD.encode(serde_json::to_string(&expired_payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 32])); + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .bearer_auth(&expired_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Expired token should be rejected"); } #[tokio::test] -async fn test_security_algorithm_none_attack_rejected() { +async fn test_pkce_security() { let url = base_url().await; let http_client = client(); - let header = json!({ - "alg": "none", - "typ": "at+jwt" - }); - let payload = json!({ - "iss": "https://test.pds", - "sub": "did:plc:attacker", - "aud": "https://test.pds", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "jti": "fake-token-id", - "scope": "atproto" - }); - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); - let malicious_token = format!("{}.{}.", header_b64, payload_b64); - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("Bearer {}", malicious_token)) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::UNAUTHORIZED, - "Algorithm 'none' attack should be rejected" - ); -} - -#[tokio::test] -async fn test_security_algorithm_substitution_attack_rejected() { - let url = base_url().await; - let http_client = client(); - let header = json!({ - "alg": "RS256", - "typ": "at+jwt" - }); - let payload = json!({ - "iss": "https://test.pds", - "sub": "did:plc:attacker", - "aud": "https://test.pds", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "jti": "fake-token-id" - }); - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); - let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]); - let malicious_token = format!("{}.{}.{}", header_b64, payload_b64, fake_sig); - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("Bearer {}", malicious_token)) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::UNAUTHORIZED, - "Algorithm substitution attack should be rejected" - ); -} - -#[tokio::test] -async fn test_security_expired_token_rejected() { - let url = base_url().await; - let http_client = client(); - let header = json!({ - "alg": "HS256", - "typ": "at+jwt" - }); - let payload = json!({ - "iss": "https://test.pds", - "sub": "did:plc:test", - "aud": "https://test.pds", - "iat": Utc::now().timestamp() - 7200, - "exp": Utc::now().timestamp() - 3600, - "jti": "expired-token-id" - }); - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); - let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 32]); - let expired_token = format!("{}.{}.{}", header_b64, payload_b64, fake_sig); - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("Bearer {}", expired_token)) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::UNAUTHORIZED, - "Expired token should be rejected" - ); -} - -#[tokio::test] -async fn test_security_pkce_plain_method_rejected() { - let url = base_url().await; - let http_client = client(); - let redirect_uri = "https://example.com/pkce-plain-callback"; + let redirect_uri = "https://example.com/pkce-callback"; let mock_client = setup_mock_client_metadata(redirect_uri).await; let client_id = mock_client.uri(); - let res = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", "plain-text-challenge"), - ("code_challenge_method", "plain"), - ]) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::BAD_REQUEST, - "PKCE plain method should be rejected" - ); + let res = http_client.post(format!("{}/oauth/par", url)) + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), + ("code_challenge", "plain-text-challenge"), ("code_challenge_method", "plain")]) + .send().await.unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST, "PKCE plain method should be rejected"); let body: Value = res.json().await.unwrap(); - assert_eq!(body["error"], "invalid_request"); - assert!( - body["error_description"] - .as_str() - .unwrap() - .to_lowercase() - .contains("s256"), - "Error should mention S256 requirement" - ); -} - -#[tokio::test] -async fn test_security_pkce_missing_challenge_rejected() { - let url = base_url().await; - let http_client = client(); - let redirect_uri = "https://example.com/no-pkce-callback"; - let mock_client = setup_mock_client_metadata(redirect_uri).await; - let client_id = mock_client.uri(); - let res = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ]) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::BAD_REQUEST, - "Missing PKCE challenge should be rejected" - ); -} - -#[tokio::test] -async fn test_security_pkce_wrong_verifier_rejected() { - let url = base_url().await; - let http_client = client(); + assert!(body["error_description"].as_str().unwrap().to_lowercase().contains("s256")); + let res = http_client.post(format!("{}/oauth/par", url)) + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri)]) + .send().await.unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Missing PKCE challenge should be rejected"); let ts = Utc::now().timestamp_millis(); let handle = format!("pkce-attack-{}", ts); - let email = format!("pkce-attack-{}@example.com", ts); - let password = "pkce-attack-password"; - http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); - let redirect_uri = "https://example.com/pkce-attack-callback"; - let mock_client = setup_mock_client_metadata(redirect_uri).await; - let client_id = mock_client.uri(); + http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) + .json(&json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "pkce-password" })) + .send().await.unwrap(); let (_, code_challenge) = generate_pkce(); let (attacker_verifier, _) = generate_pkce(); - let par_body: Value = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); + let par_body: Value = http_client.post(format!("{}/oauth/par", url)) + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) + .send().await.unwrap().json().await.unwrap(); let request_uri = par_body["request_uri"].as_str().unwrap(); let auth_client = no_redirect_client(); - let auth_res = auth_client - .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - let location = auth_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - let code = location - .split("code=") - .nth(1) - .unwrap() - .split('&') - .next() - .unwrap(); - let token_res = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "authorization_code"), - ("code", code), - ("redirect_uri", redirect_uri), - ("code_verifier", &attacker_verifier), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap(); - assert_eq!( - token_res.status(), - StatusCode::BAD_REQUEST, - "Wrong PKCE verifier should be rejected" - ); - let body: Value = token_res.json().await.unwrap(); - assert_eq!(body["error"], "invalid_grant"); + let auth_res = auth_client.post(format!("{}/oauth/authorize", url)) + .form(&[("request_uri", request_uri), ("username", &handle), ("password", "pkce-password"), ("remember_device", "false")]) + .send().await.unwrap(); + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); + let token_res = http_client.post(format!("{}/oauth/token", url)) + .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), + ("code_verifier", &attacker_verifier), ("client_id", &client_id)]) + .send().await.unwrap(); + assert_eq!(token_res.status(), StatusCode::BAD_REQUEST, "Wrong PKCE verifier should be rejected"); } #[tokio::test] -async fn test_security_authorization_code_replay_attack() { +async fn test_replay_attacks() { let url = base_url().await; let http_client = client(); let ts = Utc::now().timestamp_millis(); - let handle = format!("code-replay-{}", ts); - let email = format!("code-replay-{}@example.com", ts); - let password = "code-replay-password"; - http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); - let redirect_uri = "https://example.com/code-replay-callback"; + let handle = format!("replay-{}", ts); + http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) + .json(&json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "replay-password" })) + .send().await.unwrap(); + let redirect_uri = "https://example.com/replay-callback"; let mock_client = setup_mock_client_metadata(redirect_uri).await; let client_id = mock_client.uri(); let (code_verifier, code_challenge) = generate_pkce(); - let par_body: Value = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); + let par_body: Value = http_client.post(format!("{}/oauth/par", url)) + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) + .send().await.unwrap().json().await.unwrap(); let request_uri = par_body["request_uri"].as_str().unwrap(); let auth_client = no_redirect_client(); - let auth_res = auth_client - .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - let location = auth_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - let code = location - .split("code=") - .nth(1) - .unwrap() - .split('&') - .next() - .unwrap(); - let stolen_code = code.to_string(); - let first_res = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "authorization_code"), - ("code", code), - ("redirect_uri", redirect_uri), - ("code_verifier", &code_verifier), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap(); - assert_eq!( - first_res.status(), - StatusCode::OK, - "First use should succeed" - ); - let replay_res = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "authorization_code"), - ("code", &stolen_code), - ("redirect_uri", redirect_uri), - ("code_verifier", &code_verifier), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap(); - assert_eq!( - replay_res.status(), - StatusCode::BAD_REQUEST, - "Replay attack should fail" - ); - let body: Value = replay_res.json().await.unwrap(); - assert_eq!(body["error"], "invalid_grant"); + let auth_res = auth_client.post(format!("{}/oauth/authorize", url)) + .form(&[("request_uri", request_uri), ("username", &handle), ("password", "replay-password"), ("remember_device", "false")]) + .send().await.unwrap(); + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap().to_string(); + let first = http_client.post(format!("{}/oauth/token", url)) + .form(&[("grant_type", "authorization_code"), ("code", &code), ("redirect_uri", redirect_uri), + ("code_verifier", &code_verifier), ("client_id", &client_id)]) + .send().await.unwrap(); + assert_eq!(first.status(), StatusCode::OK, "First use should succeed"); + let first_body: Value = first.json().await.unwrap(); + let replay = http_client.post(format!("{}/oauth/token", url)) + .form(&[("grant_type", "authorization_code"), ("code", &code), ("redirect_uri", redirect_uri), + ("code_verifier", &code_verifier), ("client_id", &client_id)]) + .send().await.unwrap(); + assert_eq!(replay.status(), StatusCode::BAD_REQUEST, "Auth code replay should fail"); + let stolen_rt = first_body["refresh_token"].as_str().unwrap().to_string(); + let first_refresh: Value = http_client.post(format!("{}/oauth/token", url)) + .form(&[("grant_type", "refresh_token"), ("refresh_token", &stolen_rt), ("client_id", &client_id)]) + .send().await.unwrap().json().await.unwrap(); + assert!(first_refresh["access_token"].is_string(), "First refresh should succeed"); + let new_rt = first_refresh["refresh_token"].as_str().unwrap(); + let rt_replay = http_client.post(format!("{}/oauth/token", url)) + .form(&[("grant_type", "refresh_token"), ("refresh_token", &stolen_rt), ("client_id", &client_id)]) + .send().await.unwrap(); + assert_eq!(rt_replay.status(), StatusCode::BAD_REQUEST, "Refresh token replay should fail"); + let body: Value = rt_replay.json().await.unwrap(); + assert!(body["error_description"].as_str().unwrap().to_lowercase().contains("reuse")); + let family_revoked = http_client.post(format!("{}/oauth/token", url)) + .form(&[("grant_type", "refresh_token"), ("refresh_token", new_rt), ("client_id", &client_id)]) + .send().await.unwrap(); + assert_eq!(family_revoked.status(), StatusCode::BAD_REQUEST, "Token family should be revoked"); } #[tokio::test] -async fn test_security_refresh_token_replay_attack() { - let url = base_url().await; - let http_client = client(); - let ts = Utc::now().timestamp_millis(); - let handle = format!("rt-replay-{}", ts); - let email = format!("rt-replay-{}@example.com", ts); - let password = "rt-replay-password"; - http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); - let redirect_uri = "https://example.com/rt-replay-callback"; - let mock_client = setup_mock_client_metadata(redirect_uri).await; - let client_id = mock_client.uri(); - let (code_verifier, code_challenge) = generate_pkce(); - let par_body: Value = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let request_uri = par_body["request_uri"].as_str().unwrap(); - let auth_client = no_redirect_client(); - let auth_res = auth_client - .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - let location = auth_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - let code = location - .split("code=") - .nth(1) - .unwrap() - .split('&') - .next() - .unwrap(); - let token_body: Value = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "authorization_code"), - ("code", code), - ("redirect_uri", redirect_uri), - ("code_verifier", &code_verifier), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let stolen_refresh_token = token_body["refresh_token"].as_str().unwrap().to_string(); - let first_refresh: Value = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "refresh_token"), - ("refresh_token", &stolen_refresh_token), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - assert!( - first_refresh["access_token"].is_string(), - "First refresh should succeed" - ); - let new_refresh_token = first_refresh["refresh_token"].as_str().unwrap(); - let replay_res = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "refresh_token"), - ("refresh_token", &stolen_refresh_token), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap(); - assert_eq!( - replay_res.status(), - StatusCode::BAD_REQUEST, - "Refresh token replay should fail" - ); - let body: Value = replay_res.json().await.unwrap(); - assert_eq!(body["error"], "invalid_grant"); - assert!( - body["error_description"] - .as_str() - .unwrap() - .to_lowercase() - .contains("reuse"), - "Error should mention token reuse" - ); - let family_revoked_res = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "refresh_token"), - ("refresh_token", new_refresh_token), - ("client_id", &client_id), - ]) - .send() - .await - .unwrap(); - assert_eq!( - family_revoked_res.status(), - StatusCode::BAD_REQUEST, - "Token family should be revoked after replay detection" - ); -} - -#[tokio::test] -async fn test_security_redirect_uri_manipulation() { +async fn test_oauth_security_boundaries() { let url = base_url().await; let http_client = client(); let registered_redirect = "https://legitimate-app.com/callback"; - let attacker_redirect = "https://attacker.com/steal"; let mock_client = setup_mock_client_metadata(registered_redirect).await; let client_id = mock_client.uri(); let (_, code_challenge) = generate_pkce(); - let res = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", attacker_redirect), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::BAD_REQUEST, - "Unregistered redirect_uri should be rejected" - ); -} - -#[tokio::test] -async fn test_security_deactivated_account_blocked() { - let url = base_url().await; - let http_client = client(); + let res = http_client.post(format!("{}/oauth/par", url)) + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", "https://attacker.com/steal"), + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) + .send().await.unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Unregistered redirect_uri should be rejected"); let ts = Utc::now().timestamp_millis(); - let handle = format!("deact-sec-{}", ts); - let email = format!("deact-sec-{}@example.com", ts); - let password = "deact-sec-password"; - let create_res = http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); - assert_eq!(create_res.status(), StatusCode::OK); + let handle = format!("deact-{}", ts); + let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) + .json(&json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "deact-password" })) + .send().await.unwrap(); let account: Value = create_res.json().await.unwrap(); - let did = account["did"].as_str().unwrap(); - let access_jwt = verify_new_account(&http_client, did).await; - let deact_res = http_client - .post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) - .header("Authorization", format!("Bearer {}", access_jwt)) - .json(&json!({})) - .send() - .await - .unwrap(); - assert_eq!(deact_res.status(), StatusCode::OK); - let redirect_uri = "https://example.com/deact-sec-callback"; - let mock_client = setup_mock_client_metadata(redirect_uri).await; - let client_id = mock_client.uri(); - let (_, code_challenge) = generate_pkce(); - let par_body: Value = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let request_uri = par_body["request_uri"].as_str().unwrap(); - let auth_res = http_client - .post(format!("{}/oauth/authorize", url)) + let access_jwt = verify_new_account(&http_client, account["did"].as_str().unwrap()).await; + http_client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) + .bearer_auth(&access_jwt).json(&json!({})).send().await.unwrap(); + let deact_par: Value = http_client.post(format!("{}/oauth/par", url)) + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", registered_redirect), + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) + .send().await.unwrap().json().await.unwrap(); + let auth_res = http_client.post(format!("{}/oauth/authorize", url)) .header("Accept", "application/json") - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - assert_eq!( - auth_res.status(), - StatusCode::FORBIDDEN, - "Deactivated account should be blocked from OAuth" - ); - let body: Value = auth_res.json().await.unwrap(); - assert_eq!(body["error"], "access_denied"); -} - -#[tokio::test] -async fn test_security_url_injection_in_state_parameter() { - let url = base_url().await; - let http_client = client(); - let ts = Utc::now().timestamp_millis(); - let handle = format!("inject-state-{}", ts); - let email = format!("inject-state-{}@example.com", ts); - let password = "inject-state-password"; - http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); - let redirect_uri = "https://example.com/inject-callback"; - let mock_client = setup_mock_client_metadata(redirect_uri).await; - let client_id = mock_client.uri(); - let (code_verifier, code_challenge) = generate_pkce(); - let malicious_state = "state&redirect_uri=https://attacker.com&extra="; - let par_body: Value = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ("state", malicious_state), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let request_uri = par_body["request_uri"].as_str().unwrap(); - let auth_client = no_redirect_client(); - let auth_res = auth_client - .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - assert!( - auth_res.status().is_redirection(), - "Should redirect successfully" - ); - let location = auth_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - assert!( - location.starts_with(redirect_uri), - "Redirect should go to registered URI, not attacker URI. Got: {}", - location - ); - let redirect_uri_count = location.matches("redirect_uri=").count(); - assert!( - redirect_uri_count <= 1, - "State injection should not add extra redirect_uri parameters" - ); - assert!( - location.contains(&urlencoding::encode(malicious_state).to_string()) - || location.contains("state=state%26redirect_uri"), - "State parameter should be properly URL-encoded. Got: {}", - location - ); -} - -#[tokio::test] -async fn test_security_cross_client_token_theft() { - let url = base_url().await; - let http_client = client(); - let ts = Utc::now().timestamp_millis(); - let handle = format!("cross-client-{}", ts); - let email = format!("cross-client-{}@example.com", ts); - let password = "cross-client-password"; - http_client - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) - .json(&json!({ - "handle": handle, - "email": email, - "password": password - })) - .send() - .await - .unwrap(); + .form(&[("request_uri", deact_par["request_uri"].as_str().unwrap()), ("username", &handle), ("password", "deact-password"), ("remember_device", "false")]) + .send().await.unwrap(); + assert_eq!(auth_res.status(), StatusCode::FORBIDDEN, "Deactivated account should be blocked"); let redirect_uri_a = "https://app-a.com/callback"; - let mock_client_a = setup_mock_client_metadata(redirect_uri_a).await; - let client_id_a = mock_client_a.uri(); - let redirect_uri_b = "https://app-b.com/callback"; - let mock_client_b = setup_mock_client_metadata(redirect_uri_b).await; - let client_id_b = mock_client_b.uri(); - let (code_verifier, code_challenge) = generate_pkce(); - let par_body: Value = http_client - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id_a), - ("redirect_uri", redirect_uri_a), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let request_uri = par_body["request_uri"].as_str().unwrap(); + let mock_a = setup_mock_client_metadata(redirect_uri_a).await; + let client_id_a = mock_a.uri(); + let mock_b = setup_mock_client_metadata("https://app-b.com/callback").await; + let client_id_b = mock_b.uri(); + let ts2 = Utc::now().timestamp_millis(); + let handle2 = format!("cross-{}", ts2); + http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) + .json(&json!({ "handle": handle2, "email": format!("{}@example.com", handle2), "password": "cross-password" })) + .send().await.unwrap(); + let (code_verifier2, code_challenge2) = generate_pkce(); + let par_a: Value = http_client.post(format!("{}/oauth/par", url)) + .form(&[("response_type", "code"), ("client_id", &client_id_a), ("redirect_uri", redirect_uri_a), + ("code_challenge", &code_challenge2), ("code_challenge_method", "S256")]) + .send().await.unwrap().json().await.unwrap(); let auth_client = no_redirect_client(); - let auth_res = auth_client - .post(format!("{}/oauth/authorize", url)) - .form(&[ - ("request_uri", request_uri), - ("username", &handle), - ("password", password), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - let location = auth_res - .headers() - .get("location") - .unwrap() - .to_str() - .unwrap(); - let code = location - .split("code=") - .nth(1) - .unwrap() - .split('&') - .next() - .unwrap(); - let token_res = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", "authorization_code"), - ("code", code), - ("redirect_uri", redirect_uri_a), - ("code_verifier", &code_verifier), - ("client_id", &client_id_b), - ]) - .send() - .await - .unwrap(); - assert_eq!( - token_res.status(), - StatusCode::BAD_REQUEST, - "Cross-client code exchange must be explicitly rejected (defense-in-depth)" - ); - let body: Value = token_res.json().await.unwrap(); - assert_eq!(body["error"], "invalid_grant"); - assert!( - body["error_description"] - .as_str() - .unwrap() - .contains("client_id"), - "Error should mention client_id mismatch" - ); + let auth_a = auth_client.post(format!("{}/oauth/authorize", url)) + .form(&[("request_uri", par_a["request_uri"].as_str().unwrap()), ("username", &handle2), ("password", "cross-password"), ("remember_device", "false")]) + .send().await.unwrap(); + let loc_a = auth_a.headers().get("location").unwrap().to_str().unwrap(); + let code_a = loc_a.split("code=").nth(1).unwrap().split('&').next().unwrap(); + let cross_client = http_client.post(format!("{}/oauth/token", url)) + .form(&[("grant_type", "authorization_code"), ("code", code_a), ("redirect_uri", redirect_uri_a), + ("code_verifier", &code_verifier2), ("client_id", &client_id_b)]) + .send().await.unwrap(); + assert_eq!(cross_client.status(), StatusCode::BAD_REQUEST, "Cross-client code exchange must be rejected"); } -#[test] -fn test_security_dpop_nonce_tamper_detection() { - let secret = b"test-dpop-secret-32-bytes-long!!"; - let verifier = DPoPVerifier::new(secret); - let nonce = verifier.generate_nonce(); - let nonce_bytes = URL_SAFE_NO_PAD.decode(&nonce).unwrap(); - let mut tampered = nonce_bytes.clone(); - if !tampered.is_empty() { - tampered[0] ^= 0xFF; +#[tokio::test] +async fn test_malformed_tokens_and_headers() { + let url = base_url().await; + let http_client = client(); + let malformed = vec!["", "not-a-token", "one.two", "one.two.three.four", "....", "eyJhbGciOiJIUzI1NiJ9", + "eyJhbGciOiJIUzI1NiJ9.", "eyJhbGciOiJIUzI1NiJ9..", ".eyJzdWIiOiJ0ZXN0In0.", "!!invalid!!.eyJ9.sig"]; + for token in &malformed { + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .bearer_auth(token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED); + } + let wrong_types = vec!["JWT", "jwt", "at+JWT", ""]; + for typ in wrong_types { + let header = json!({ "alg": "HS256", "typ": typ }); + let payload = json!({ "iss": "x", "sub": "did:plc:x", "aud": "x", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "x" }); + let token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()), + URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 32])); + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .bearer_auth(&token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "typ='{}' should be rejected", typ); + } + let (access_token, _, _) = get_oauth_tokens(&http_client, url).await; + let invalid_formats = vec![format!("Basic {}", access_token), format!("Digest {}", access_token), + access_token.clone(), format!("Bearer{}", access_token)]; + for auth in &invalid_formats { + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .header("Authorization", auth).send().await.unwrap().status(), StatusCode::UNAUTHORIZED); + } + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .send().await.unwrap().status(), StatusCode::UNAUTHORIZED); + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) + .header("Authorization", "").send().await.unwrap().status(), StatusCode::UNAUTHORIZED); + let grants = vec!["client_credentials", "password", "implicit", "", "AUTHORIZATION_CODE"]; + for grant in grants { + assert_eq!(http_client.post(format!("{}/oauth/token", url)) + .form(&[("grant_type", grant), ("client_id", "https://example.com")]) + .send().await.unwrap().status(), StatusCode::BAD_REQUEST, "Grant '{}' should be rejected", grant); } - let tampered_nonce = URL_SAFE_NO_PAD.encode(&tampered); - let result = verifier.validate_nonce(&tampered_nonce); - assert!(result.is_err(), "Tampered nonce should be rejected"); } -#[test] -fn test_security_dpop_nonce_cross_server_rejected() { - let secret1 = b"server-1-secret-32-bytes-long!!!"; - let secret2 = b"server-2-secret-32-bytes-long!!!"; - let verifier1 = DPoPVerifier::new(secret1); - let verifier2 = DPoPVerifier::new(secret2); - let nonce_from_server1 = verifier1.generate_nonce(); - let result = verifier2.validate_nonce(&nonce_from_server1); - assert!( - result.is_err(), - "Nonce from different server should be rejected" - ); +#[tokio::test] +async fn test_token_revocation() { + let url = base_url().await; + let http_client = client(); + let (access_token, refresh_token, _) = get_oauth_tokens(&http_client, url).await; + assert_eq!(http_client.post(format!("{}/oauth/revoke", url)) + .form(&[("token", &refresh_token)]).send().await.unwrap().status(), StatusCode::OK); + let introspect: Value = http_client.post(format!("{}/oauth/introspect", url)) + .form(&[("token", &access_token)]).send().await.unwrap().json().await.unwrap(); + assert_eq!(introspect["active"], false, "Revoked token should be inactive"); } -#[test] -fn test_security_dpop_proof_signature_tampering() { +fn create_dpop_proof(method: &str, uri: &str, _nonce: Option<&str>, ath: Option<&str>, iat_offset: i64) -> String { use p256::ecdsa::{Signature, SigningKey, signature::Signer}; use p256::elliptic_curve::sec1::ToEncodedPoint; - let secret = b"test-dpop-secret-32-bytes-long!!"; - let verifier = DPoPVerifier::new(secret); let signing_key = SigningKey::random(&mut rand::thread_rng()); - let verifying_key = signing_key.verifying_key(); - let point = verifying_key.to_encoded_point(false); + let point = signing_key.verifying_key().to_encoded_point(false); let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); - let header = json!({ - "typ": "dpop+jwt", - "alg": "ES256", - "jwk": { - "kty": "EC", - "crv": "P-256", - "x": x, - "y": y - } - }); - let payload = json!({ - "jti": format!("tamper-test-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), - "htm": "POST", - "htu": "https://example.com/token", - "iat": Utc::now().timestamp() - }); + let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } }); + let mut payload = json!({ "jti": format!("unique-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), + "htm": method, "htu": uri, "iat": Utc::now().timestamp() + iat_offset }); + if let Some(a) = ath { payload["ath"] = json!(a); } let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); let signing_input = format!("{}.{}", header_b64, payload_b64); let signature: Signature = signing_key.sign(signing_input.as_bytes()); - let mut sig_bytes = signature.to_bytes().to_vec(); - sig_bytes[0] ^= 0xFF; - let tampered_sig = URL_SAFE_NO_PAD.encode(&sig_bytes); - let tampered_proof = format!("{}.{}.{}", header_b64, payload_b64, tampered_sig); - let result = verifier.verify_proof(&tampered_proof, "POST", "https://example.com/token", None); - assert!( - result.is_err(), - "Tampered DPoP signature should be rejected" - ); + format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())) } #[test] -fn test_security_dpop_proof_key_substitution() { +fn test_dpop_nonce_security() { + let secret1 = b"test-dpop-secret-32-bytes-long!!"; + let secret2 = b"different-secret-32-bytes-long!!"; + let v1 = DPoPVerifier::new(secret1); + let v2 = DPoPVerifier::new(secret2); + let nonce = v1.generate_nonce(); + assert!(!nonce.is_empty()); + assert!(v1.validate_nonce(&nonce).is_ok(), "Valid nonce should pass"); + assert!(v2.validate_nonce(&nonce).is_err(), "Nonce from different secret should fail"); + let nonce_bytes = URL_SAFE_NO_PAD.decode(&nonce).unwrap(); + let mut tampered = nonce_bytes.clone(); + if !tampered.is_empty() { tampered[0] ^= 0xFF; } + assert!(v1.validate_nonce(&URL_SAFE_NO_PAD.encode(&tampered)).is_err(), "Tampered nonce should fail"); + assert!(v1.validate_nonce("invalid").is_err()); + assert!(v1.validate_nonce("").is_err()); + assert!(v1.validate_nonce("!!!not-base64!!!").is_err()); +} + +#[test] +fn test_dpop_proof_validation() { + let secret = b"test-dpop-secret-32-bytes-long!!"; + let verifier = DPoPVerifier::new(secret); + assert!(verifier.verify_proof("not.enough", "POST", "https://example.com", None).is_err()); + assert!(verifier.verify_proof("invalid", "POST", "https://example.com", None).is_err()); + let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); + assert!(verifier.verify_proof(&proof, "GET", "https://example.com/token", None).is_err(), "Method mismatch"); + assert!(verifier.verify_proof(&proof, "POST", "https://other.com/token", None).is_err(), "URI mismatch"); + assert!(verifier.verify_proof(&proof, "POST", "https://example.com/token?foo=bar", None).is_ok(), "Query params should be ignored"); + let old_proof = create_dpop_proof("POST", "https://example.com/token", None, None, -600); + assert!(verifier.verify_proof(&old_proof, "POST", "https://example.com/token", None).is_err(), "iat too old"); + let future_proof = create_dpop_proof("POST", "https://example.com/token", None, None, 600); + assert!(verifier.verify_proof(&future_proof, "POST", "https://example.com/token", None).is_err(), "iat in future"); + let ath_proof = create_dpop_proof("GET", "https://example.com/resource", None, Some("wrong"), 0); + assert!(verifier.verify_proof(&ath_proof, "GET", "https://example.com/resource", Some("correct")).is_err(), "ath mismatch"); + let no_ath_proof = create_dpop_proof("GET", "https://example.com/resource", None, None, 0); + assert!(verifier.verify_proof(&no_ath_proof, "GET", "https://example.com/resource", Some("expected")).is_err(), "Missing ath"); +} + +#[test] +fn test_dpop_proof_signature_attacks() { use p256::ecdsa::{Signature, SigningKey, signature::Signer}; use p256::elliptic_curve::sec1::ToEncodedPoint; let secret = b"test-dpop-secret-32-bytes-long!!"; let verifier = DPoPVerifier::new(secret); let signing_key = SigningKey::random(&mut rand::thread_rng()); let attacker_key = SigningKey::random(&mut rand::thread_rng()); - let attacker_verifying = attacker_key.verifying_key(); - let attacker_point = attacker_verifying.to_encoded_point(false); + let attacker_point = attacker_key.verifying_key().to_encoded_point(false); let x = URL_SAFE_NO_PAD.encode(attacker_point.x().unwrap()); let y = URL_SAFE_NO_PAD.encode(attacker_point.y().unwrap()); - let header = json!({ - "typ": "dpop+jwt", - "alg": "ES256", - "jwk": { - "kty": "EC", - "crv": "P-256", - "x": x, - "y": y - } - }); - let payload = json!({ - "jti": format!("key-sub-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), - "htm": "POST", - "htu": "https://example.com/token", - "iat": Utc::now().timestamp() - }); + let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } }); + let payload = json!({ "jti": format!("key-sub-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), + "htm": "POST", "htu": "https://example.com/token", "iat": Utc::now().timestamp() }); let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); let signing_input = format!("{}.{}", header_b64, payload_b64); let signature: Signature = signing_key.sign(signing_input.as_bytes()); - let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); - let mismatched_proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64); - let result = - verifier.verify_proof(&mismatched_proof, "POST", "https://example.com/token", None); - assert!( - result.is_err(), - "DPoP proof with mismatched key should be rejected" - ); + let mismatched = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())); + assert!(verifier.verify_proof(&mismatched, "POST", "https://example.com/token", None).is_err(), "Mismatched key should fail"); + let point = signing_key.verifying_key().to_encoded_point(false); + let good_header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", + "x": URL_SAFE_NO_PAD.encode(point.x().unwrap()), "y": URL_SAFE_NO_PAD.encode(point.y().unwrap()) } }); + let good_header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&good_header).unwrap()); + let good_input = format!("{}.{}", good_header_b64, payload_b64); + let good_sig: Signature = signing_key.sign(good_input.as_bytes()); + let mut sig_bytes = good_sig.to_bytes().to_vec(); + sig_bytes[0] ^= 0xFF; + let tampered = format!("{}.{}", good_input, URL_SAFE_NO_PAD.encode(&sig_bytes)); + assert!(verifier.verify_proof(&tampered, "POST", "https://example.com/token", None).is_err(), "Tampered sig should fail"); } #[test] -fn test_security_jwk_thumbprint_consistency() { - let jwk = DPoPJwk { - kty: "EC".to_string(), - crv: Some("P-256".to_string()), +fn test_jwk_thumbprint() { + let jwk = DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()), - y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()), - }; - let mut results = Vec::new(); - for _ in 0..100 { - results.push(compute_jwk_thumbprint(&jwk).unwrap()); - } - let first = &results[0]; - for (i, result) in results.iter().enumerate() { - assert_eq!( - first, result, - "Thumbprint should be deterministic, but iteration {} differs", - i - ); - } + y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()) }; + let tp1 = compute_jwk_thumbprint(&jwk).unwrap(); + let tp2 = compute_jwk_thumbprint(&jwk).unwrap(); + assert_eq!(tp1, tp2, "Thumbprint should be deterministic"); + assert!(!tp1.is_empty()); + assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("secp256k1".to_string()), + x: Some("x".to_string()), y: Some("y".to_string()) }).is_ok()); + assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "OKP".to_string(), crv: Some("Ed25519".to_string()), + x: Some("x".to_string()), y: None }).is_ok()); + assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: None, x: Some("x".to_string()), y: Some("y".to_string()) }).is_err()); + assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), x: None, y: Some("y".to_string()) }).is_err()); + assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), x: Some("x".to_string()), y: None }).is_err()); + assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "RSA".to_string(), crv: None, x: None, y: None }).is_err()); } #[test] -fn test_security_dpop_iat_clock_skew_limits() { +fn test_dpop_clock_skew() { use p256::ecdsa::{Signature, SigningKey, signature::Signer}; use p256::elliptic_curve::sec1::ToEncodedPoint; let secret = b"test-dpop-secret-32-bytes-long!!"; let verifier = DPoPVerifier::new(secret); - let test_offsets = vec![ - (-600, true), - (-301, true), - (-299, false), - (0, false), - (299, false), - (301, true), - (600, true), - ]; - for (offset_secs, should_fail) in test_offsets { + let test_cases = vec![(-600, true), (-301, true), (-299, false), (0, false), (299, false), (301, true), (600, true)]; + for (offset, should_fail) in test_cases { let signing_key = SigningKey::random(&mut rand::thread_rng()); - let verifying_key = signing_key.verifying_key(); - let point = verifying_key.to_encoded_point(false); + let point = signing_key.verifying_key().to_encoded_point(false); let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); - let header = json!({ - "typ": "dpop+jwt", - "alg": "ES256", - "jwk": { - "kty": "EC", - "crv": "P-256", - "x": x, - "y": y - } - }); - let payload = json!({ - "jti": format!("clock-{}-{}", offset_secs, Utc::now().timestamp_nanos_opt().unwrap_or(0)), - "htm": "POST", - "htu": "https://example.com/token", - "iat": Utc::now().timestamp() + offset_secs - }); + let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } }); + let payload = json!({ "jti": format!("clock-{}-{}", offset, Utc::now().timestamp_nanos_opt().unwrap_or(0)), + "htm": "POST", "htu": "https://example.com/token", "iat": Utc::now().timestamp() + offset }); let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); let signing_input = format!("{}.{}", header_b64, payload_b64); let signature: Signature = signing_key.sign(signing_input.as_bytes()); - let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); - let proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64); + let proof = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())); let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); - if should_fail { - assert!( - result.is_err(), - "iat offset {} should be rejected", - offset_secs - ); - } else { - assert!( - result.is_ok(), - "iat offset {} should be accepted", - offset_secs - ); - } + if should_fail { assert!(result.is_err(), "offset {} should fail", offset); } + else { assert!(result.is_ok(), "offset {} should pass", offset); } } } #[test] -fn test_security_dpop_method_case_insensitivity() { +fn test_dpop_http_method_case() { use p256::ecdsa::{Signature, SigningKey, signature::Signer}; use p256::elliptic_curve::sec1::ToEncodedPoint; let secret = b"test-dpop-secret-32-bytes-long!!"; let verifier = DPoPVerifier::new(secret); let signing_key = SigningKey::random(&mut rand::thread_rng()); - let verifying_key = signing_key.verifying_key(); - let point = verifying_key.to_encoded_point(false); + let point = signing_key.verifying_key().to_encoded_point(false); let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); - let header = json!({ - "typ": "dpop+jwt", - "alg": "ES256", - "jwk": { - "kty": "EC", - "crv": "P-256", - "x": x, - "y": y - } - }); - let payload = json!({ - "jti": format!("case-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), - "htm": "post", - "htu": "https://example.com/token", - "iat": Utc::now().timestamp() - }); + let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } }); + let payload = json!({ "jti": format!("case-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), + "htm": "post", "htu": "https://example.com/token", "iat": Utc::now().timestamp() }); let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); let signing_input = format!("{}.{}", header_b64, payload_b64); let signature: Signature = signing_key.sign(signing_input.as_bytes()); - let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); - let proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64); - let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); - assert!( - result.is_ok(), - "HTTP method comparison should be case-insensitive" - ); -} - -#[tokio::test] -async fn test_security_invalid_grant_type_rejected() { - let url = base_url().await; - let http_client = client(); - let grant_types = vec![ - "client_credentials", - "password", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - "urn:ietf:params:oauth:grant-type:device_code", - "", - "AUTHORIZATION_CODE", - "Authorization_Code", - ]; - for grant_type in grant_types { - let res = http_client - .post(format!("{}/oauth/token", url)) - .form(&[ - ("grant_type", grant_type), - ("client_id", "https://example.com"), - ]) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::BAD_REQUEST, - "Grant type '{}' should be rejected", - grant_type - ); - } -} - -#[tokio::test] -async fn test_security_token_with_wrong_typ_rejected() { - let url = base_url().await; - let http_client = client(); - let wrong_types = vec!["JWT", "jwt", "at+JWT", "access_token", ""]; - for typ in wrong_types { - let header = json!({ - "alg": "HS256", - "typ": typ - }); - let payload = json!({ - "iss": "https://test.pds", - "sub": "did:plc:test", - "aud": "https://test.pds", - "iat": Utc::now().timestamp(), - "exp": Utc::now().timestamp() + 3600, - "jti": "wrong-typ-token" - }); - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); - let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 32]); - let token = format!("{}.{}.{}", header_b64, payload_b64, fake_sig); - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("Bearer {}", token)) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::UNAUTHORIZED, - "Token with typ='{}' should be rejected", - typ - ); - } -} - -#[tokio::test] -async fn test_security_missing_required_claims_rejected() { - let url = base_url().await; - let http_client = client(); - let tokens_missing_claims = vec![ - (json!({"iss": "x", "sub": "x", "aud": "x", "iat": 0}), "exp"), - ( - json!({"iss": "x", "sub": "x", "aud": "x", "exp": 9999999999i64}), - "iat", - ), - ( - json!({"iss": "x", "aud": "x", "iat": 0, "exp": 9999999999i64}), - "sub", - ), - ]; - for (payload, missing_claim) in tokens_missing_claims { - let header = json!({ - "alg": "HS256", - "typ": "at+jwt" - }); - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); - let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 32]); - let token = format!("{}.{}.{}", header_b64, payload_b64, fake_sig); - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("Bearer {}", token)) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::UNAUTHORIZED, - "Token missing '{}' claim should be rejected", - missing_claim - ); - } -} - -#[tokio::test] -async fn test_security_malformed_tokens_rejected() { - let url = base_url().await; - let http_client = client(); - let malformed_tokens = vec![ - "", - "not-a-token", - "one.two", - "one.two.three.four", - "....", - "eyJhbGciOiJIUzI1NiJ9", - "eyJhbGciOiJIUzI1NiJ9.", - "eyJhbGciOiJIUzI1NiJ9..", - ".eyJzdWIiOiJ0ZXN0In0.", - "!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig", - "eyJhbGciOiJIUzI1NiJ9.!!invalid!!.sig", - ]; - for token in malformed_tokens { - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", format!("Bearer {}", token)) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::UNAUTHORIZED, - "Malformed token '{}' should be rejected", - if token.len() > 50 { - &token[..50] - } else { - token - } - ); - } -} - -#[tokio::test] -async fn test_security_authorization_header_formats() { - let url = base_url().await; - let http_client = client(); - let (access_token, _, _) = get_oauth_tokens(&http_client, url).await; - let valid_case_variants = vec![ - format!("bearer {}", access_token), - format!("BEARER {}", access_token), - format!("Bearer {}", access_token), - ]; - for auth_header in valid_case_variants { - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", &auth_header) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::OK, - "Auth header '{}...' should be accepted (RFC 7235 case-insensitivity)", - if auth_header.len() > 30 { - &auth_header[..30] - } else { - &auth_header - } - ); - } - let invalid_formats = vec![ - format!("Basic {}", access_token), - format!("Digest {}", access_token), - access_token.clone(), - format!("Bearer{}", access_token), - ]; - for auth_header in invalid_formats { - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", &auth_header) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::UNAUTHORIZED, - "Auth header '{}...' should be rejected", - if auth_header.len() > 30 { - &auth_header[..30] - } else { - &auth_header - } - ); - } -} - -#[tokio::test] -async fn test_security_no_authorization_header() { - let url = base_url().await; - let http_client = client(); - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::UNAUTHORIZED, - "Missing auth header should return 401" - ); -} - -#[tokio::test] -async fn test_security_empty_authorization_header() { - let url = base_url().await; - let http_client = client(); - let res = http_client - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) - .header("Authorization", "") - .send() - .await - .unwrap(); - assert_eq!( - res.status(), - StatusCode::UNAUTHORIZED, - "Empty auth header should return 401" - ); -} - -#[tokio::test] -async fn test_security_revoked_token_rejected() { - let url = base_url().await; - let http_client = client(); - let (access_token, refresh_token, _) = get_oauth_tokens(&http_client, url).await; - let revoke_res = http_client - .post(format!("{}/oauth/revoke", url)) - .form(&[("token", &refresh_token)]) - .send() - .await - .unwrap(); - assert_eq!(revoke_res.status(), StatusCode::OK); - let introspect_res = http_client - .post(format!("{}/oauth/introspect", url)) - .form(&[("token", &access_token)]) - .send() - .await - .unwrap(); - let introspect_body: Value = introspect_res.json().await.unwrap(); - assert_eq!( - introspect_body["active"], false, - "Revoked token should be inactive" - ); -} - -#[tokio::test] -#[ignore = "rate limiting is disabled in test environment"] -async fn test_security_oauth_authorize_rate_limiting() { - let url = base_url().await; - let http_client = no_redirect_client(); - let ts = Utc::now().timestamp_nanos_opt().unwrap_or(0); - let unique_ip = format!( - "10.{}.{}.{}", - (ts >> 16) & 0xFF, - (ts >> 8) & 0xFF, - ts & 0xFF - ); - let redirect_uri = "https://example.com/rate-limit-callback"; - let mock_client = setup_mock_client_metadata(redirect_uri).await; - let client_id = mock_client.uri(); - let (_, code_challenge) = generate_pkce(); - let client_for_par = client(); - let par_body: Value = client_for_par - .post(format!("{}/oauth/par", url)) - .form(&[ - ("response_type", "code"), - ("client_id", &client_id), - ("redirect_uri", redirect_uri), - ("code_challenge", &code_challenge), - ("code_challenge_method", "S256"), - ]) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); - let request_uri = par_body["request_uri"].as_str().unwrap(); - let mut rate_limited_count = 0; - let mut other_count = 0; - for _ in 0..15 { - let res = http_client - .post(format!("{}/oauth/authorize", url)) - .header("X-Forwarded-For", &unique_ip) - .form(&[ - ("request_uri", request_uri), - ("username", "nonexistent_user"), - ("password", "wrong_password"), - ("remember_device", "false"), - ]) - .send() - .await - .unwrap(); - match res.status() { - StatusCode::TOO_MANY_REQUESTS => rate_limited_count += 1, - _ => other_count += 1, - } - } - assert!( - rate_limited_count > 0, - "Expected at least one rate-limited response after 15 OAuth authorize attempts. Got {} other and {} rate limited.", - other_count, - rate_limited_count - ); -} - -fn create_dpop_proof( - method: &str, - uri: &str, - nonce: Option<&str>, - ath: Option<&str>, - iat_offset_secs: i64, -) -> String { - use p256::ecdsa::{Signature, SigningKey, signature::Signer}; - let signing_key = SigningKey::random(&mut rand::thread_rng()); - let verifying_key = signing_key.verifying_key(); - let point = verifying_key.to_encoded_point(false); - let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); - let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); - let jwk = json!({ - "kty": "EC", - "crv": "P-256", - "x": x, - "y": y - }); - let header = json!({ - "typ": "dpop+jwt", - "alg": "ES256", - "jwk": jwk - }); - let mut payload = json!({ - "jti": format!("unique-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), - "htm": method, - "htu": uri, - "iat": Utc::now().timestamp() + iat_offset_secs - }); - if let Some(n) = nonce { - payload["nonce"] = json!(n); - } - if let Some(a) = ath { - payload["ath"] = json!(a); - } - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); - let signing_input = format!("{}.{}", header_b64, payload_b64); - let signature: Signature = signing_key.sign(signing_input.as_bytes()); - let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); - format!("{}.{}", signing_input, signature_b64) -} - -#[test] -fn test_dpop_nonce_generation() { - let secret = b"test-dpop-secret-32-bytes-long!!"; - let verifier = DPoPVerifier::new(secret); - let nonce1 = verifier.generate_nonce(); - let nonce2 = verifier.generate_nonce(); - assert!(!nonce1.is_empty()); - assert!(!nonce2.is_empty()); -} - -#[test] -fn test_dpop_nonce_validation_success() { - let secret = b"test-dpop-secret-32-bytes-long!!"; - let verifier = DPoPVerifier::new(secret); - let nonce = verifier.generate_nonce(); - let result = verifier.validate_nonce(&nonce); - assert!(result.is_ok(), "Valid nonce should pass: {:?}", result); -} - -#[test] -fn test_dpop_nonce_wrong_secret() { - let secret1 = b"test-dpop-secret-32-bytes-long!!"; - let secret2 = b"different-secret-32-bytes-long!!"; - let verifier1 = DPoPVerifier::new(secret1); - let verifier2 = DPoPVerifier::new(secret2); - let nonce = verifier1.generate_nonce(); - let result = verifier2.validate_nonce(&nonce); - assert!(result.is_err(), "Nonce from different secret should fail"); -} - -#[test] -fn test_dpop_nonce_invalid_format() { - let secret = b"test-dpop-secret-32-bytes-long!!"; - let verifier = DPoPVerifier::new(secret); - assert!(verifier.validate_nonce("invalid").is_err()); - assert!(verifier.validate_nonce("").is_err()); - assert!(verifier.validate_nonce("!!!not-base64!!!").is_err()); -} - -#[test] -fn test_jwk_thumbprint_ec_p256() { - let jwk = DPoPJwk { - kty: "EC".to_string(), - crv: Some("P-256".to_string()), - x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()), - y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()), - }; - let thumbprint = compute_jwk_thumbprint(&jwk); - assert!(thumbprint.is_ok()); - let tp = thumbprint.unwrap(); - assert!(!tp.is_empty()); - assert!( - tp.chars() - .all(|c| c.is_alphanumeric() || c == '-' || c == '_') - ); -} - -#[test] -fn test_jwk_thumbprint_ec_secp256k1() { - let jwk = DPoPJwk { - kty: "EC".to_string(), - crv: Some("secp256k1".to_string()), - x: Some("some_x_value".to_string()), - y: Some("some_y_value".to_string()), - }; - let thumbprint = compute_jwk_thumbprint(&jwk); - assert!(thumbprint.is_ok()); -} - -#[test] -fn test_jwk_thumbprint_okp_ed25519() { - let jwk = DPoPJwk { - kty: "OKP".to_string(), - crv: Some("Ed25519".to_string()), - x: Some("some_x_value".to_string()), - y: None, - }; - let thumbprint = compute_jwk_thumbprint(&jwk); - assert!(thumbprint.is_ok()); -} - -#[test] -fn test_jwk_thumbprint_missing_crv() { - let jwk = DPoPJwk { - kty: "EC".to_string(), - crv: None, - x: Some("x".to_string()), - y: Some("y".to_string()), - }; - let thumbprint = compute_jwk_thumbprint(&jwk); - assert!(thumbprint.is_err()); -} - -#[test] -fn test_jwk_thumbprint_missing_x() { - let jwk = DPoPJwk { - kty: "EC".to_string(), - crv: Some("P-256".to_string()), - x: None, - y: Some("y".to_string()), - }; - let thumbprint = compute_jwk_thumbprint(&jwk); - assert!(thumbprint.is_err()); -} - -#[test] -fn test_jwk_thumbprint_missing_y_for_ec() { - let jwk = DPoPJwk { - kty: "EC".to_string(), - crv: Some("P-256".to_string()), - x: Some("x".to_string()), - y: None, - }; - let thumbprint = compute_jwk_thumbprint(&jwk); - assert!(thumbprint.is_err()); -} - -#[test] -fn test_jwk_thumbprint_unsupported_key_type() { - let jwk = DPoPJwk { - kty: "RSA".to_string(), - crv: None, - x: None, - y: None, - }; - let thumbprint = compute_jwk_thumbprint(&jwk); - assert!(thumbprint.is_err()); -} - -#[test] -fn test_jwk_thumbprint_deterministic() { - let jwk = DPoPJwk { - kty: "EC".to_string(), - crv: Some("P-256".to_string()), - x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()), - y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()), - }; - let tp1 = compute_jwk_thumbprint(&jwk).unwrap(); - let tp2 = compute_jwk_thumbprint(&jwk).unwrap(); - assert_eq!(tp1, tp2, "Thumbprint should be deterministic"); -} - -#[test] -fn test_dpop_proof_invalid_format() { - let secret = b"test-dpop-secret-32-bytes-long!!"; - let verifier = DPoPVerifier::new(secret); - let result = verifier.verify_proof("not.enough.parts", "POST", "https://example.com", None); - assert!(result.is_err()); - let result = verifier.verify_proof("invalid", "POST", "https://example.com", None); - assert!(result.is_err()); -} - -#[test] -fn test_dpop_proof_invalid_typ() { - let secret = b"test-dpop-secret-32-bytes-long!!"; - let verifier = DPoPVerifier::new(secret); - let header = json!({ - "typ": "JWT", - "alg": "ES256", - "jwk": { - "kty": "EC", - "crv": "P-256", - "x": "x", - "y": "y" - } - }); - let payload = json!({ - "jti": "unique", - "htm": "POST", - "htu": "https://example.com", - "iat": Utc::now().timestamp() - }); - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); - let proof = format!("{}.{}.sig", header_b64, payload_b64); - let result = verifier.verify_proof(&proof, "POST", "https://example.com", None); - assert!(result.is_err()); -} - -#[test] -fn test_dpop_proof_method_mismatch() { - let secret = b"test-dpop-secret-32-bytes-long!!"; - let verifier = DPoPVerifier::new(secret); - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); - let result = verifier.verify_proof(&proof, "GET", "https://example.com/token", None); - assert!(result.is_err()); -} - -#[test] -fn test_dpop_proof_uri_mismatch() { - let secret = b"test-dpop-secret-32-bytes-long!!"; - let verifier = DPoPVerifier::new(secret); - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); - let result = verifier.verify_proof(&proof, "POST", "https://other.com/token", None); - assert!(result.is_err()); -} - -#[test] -fn test_dpop_proof_iat_too_old() { - let secret = b"test-dpop-secret-32-bytes-long!!"; - let verifier = DPoPVerifier::new(secret); - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, -600); - let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); - assert!(result.is_err()); -} - -#[test] -fn test_dpop_proof_iat_future() { - let secret = b"test-dpop-secret-32-bytes-long!!"; - let verifier = DPoPVerifier::new(secret); - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 600); - let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); - assert!(result.is_err()); -} - -#[test] -fn test_dpop_proof_ath_mismatch() { - let secret = b"test-dpop-secret-32-bytes-long!!"; - let verifier = DPoPVerifier::new(secret); - let proof = create_dpop_proof( - "GET", - "https://example.com/resource", - None, - Some("wrong_hash"), - 0, - ); - let result = verifier.verify_proof( - &proof, - "GET", - "https://example.com/resource", - Some("correct_hash"), - ); - assert!(result.is_err()); -} - -#[test] -fn test_dpop_proof_missing_ath_when_required() { - let secret = b"test-dpop-secret-32-bytes-long!!"; - let verifier = DPoPVerifier::new(secret); - let proof = create_dpop_proof("GET", "https://example.com/resource", None, None, 0); - let result = verifier.verify_proof( - &proof, - "GET", - "https://example.com/resource", - Some("expected_hash"), - ); - assert!(result.is_err()); -} - -#[test] -fn test_dpop_proof_uri_ignores_query_params() { - let secret = b"test-dpop-secret-32-bytes-long!!"; - let verifier = DPoPVerifier::new(secret); - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); - let result = verifier.verify_proof(&proof, "POST", "https://example.com/token?foo=bar", None); - assert!( - result.is_ok(), - "Query params should be ignored: {:?}", - result - ); + let proof = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())); + assert!(verifier.verify_proof(&proof, "POST", "https://example.com/token", None).is_ok(), "HTTP method should be case-insensitive"); } diff --git a/tests/plc_operations.rs b/tests/plc_operations.rs index b526f62..2b5abbd 100644 --- a/tests/plc_operations.rs +++ b/tests/plc_operations.rs @@ -5,454 +5,127 @@ use serde_json::json; use sqlx::PgPool; #[tokio::test] -async fn test_request_plc_operation_signature_requires_auth() { +async fn test_plc_operation_auth() { let client = client(); - let res = client - .post(format!( - "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", - base_url().await - )) - .send() - .await - .expect("Request failed"); + let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) + .send().await.unwrap(); assert_eq!(res.status(), StatusCode::UNAUTHORIZED); -} - -#[tokio::test] -async fn test_request_plc_operation_signature_success() { - let client = client(); - let (token, _did) = create_account_and_login(&client).await; - let res = client - .post(format!( - "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", - base_url().await - )) - .bearer_auth(&token) - .send() - .await - .expect("Request failed"); + let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await)) + .json(&json!({})).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) + .json(&json!({ "operation": {} })).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + let (token, _) = create_account_and_login(&client).await; + let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) + .bearer_auth(&token).send().await.unwrap(); assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] -async fn test_sign_plc_operation_requires_auth() { +async fn test_sign_plc_operation_validation() { let client = client(); - let res = client - .post(format!( - "{}/xrpc/com.atproto.identity.signPlcOperation", - base_url().await - )) - .json(&json!({})) - .send() - .await - .expect("Request failed"); - assert_eq!(res.status(), StatusCode::UNAUTHORIZED); -} - -#[tokio::test] -async fn test_sign_plc_operation_requires_token() { - let client = client(); - let (token, _did) = create_account_and_login(&client).await; - let res = client - .post(format!( - "{}/xrpc/com.atproto.identity.signPlcOperation", - base_url().await - )) - .bearer_auth(&token) - .json(&json!({})) - .send() - .await - .expect("Request failed"); + let (token, _) = create_account_and_login(&client).await; + let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await)) + .bearer_auth(&token).json(&json!({})).send().await.unwrap(); assert_eq!(res.status(), StatusCode::BAD_REQUEST); let body: serde_json::Value = res.json().await.unwrap(); assert_eq!(body["error"], "InvalidRequest"); -} - -#[tokio::test] -async fn test_sign_plc_operation_invalid_token() { - let client = client(); - let (token, _did) = create_account_and_login(&client).await; - let res = client - .post(format!( - "{}/xrpc/com.atproto.identity.signPlcOperation", - base_url().await - )) - .bearer_auth(&token) - .json(&json!({ - "token": "invalid-token-12345" - })) - .send() - .await - .expect("Request failed"); + let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await)) + .bearer_auth(&token).json(&json!({ "token": "invalid-token-12345" })).send().await.unwrap(); assert_eq!(res.status(), StatusCode::BAD_REQUEST); let body: serde_json::Value = res.json().await.unwrap(); assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken"); } #[tokio::test] -async fn test_submit_plc_operation_requires_auth() { - let client = client(); - let res = client - .post(format!( - "{}/xrpc/com.atproto.identity.submitPlcOperation", - base_url().await - )) - .json(&json!({ - "operation": {} - })) - .send() - .await - .expect("Request failed"); - assert_eq!(res.status(), StatusCode::UNAUTHORIZED); -} - -#[tokio::test] -async fn test_submit_plc_operation_invalid_operation() { - let client = client(); - let (token, _did) = create_account_and_login(&client).await; - let res = client - .post(format!( - "{}/xrpc/com.atproto.identity.submitPlcOperation", - base_url().await - )) - .bearer_auth(&token) - .json(&json!({ - "operation": { - "type": "invalid_type" - } - })) - .send() - .await - .expect("Request failed"); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - let body: serde_json::Value = res.json().await.unwrap(); - assert_eq!(body["error"], "InvalidRequest"); -} - -#[tokio::test] -async fn test_submit_plc_operation_missing_sig() { - let client = client(); - let (token, _did) = create_account_and_login(&client).await; - let res = client - .post(format!( - "{}/xrpc/com.atproto.identity.submitPlcOperation", - base_url().await - )) - .bearer_auth(&token) - .json(&json!({ - "operation": { - "type": "plc_operation", - "rotationKeys": [], - "verificationMethods": {}, - "alsoKnownAs": [], - "services": {}, - "prev": null - } - })) - .send() - .await - .expect("Request failed"); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - let body: serde_json::Value = res.json().await.unwrap(); - assert_eq!(body["error"], "InvalidRequest"); -} - -#[tokio::test] -async fn test_submit_plc_operation_wrong_service_endpoint() { - let client = client(); - let (token, _did) = create_account_and_login(&client).await; - let res = client - .post(format!( - "{}/xrpc/com.atproto.identity.submitPlcOperation", - base_url().await - )) - .bearer_auth(&token) - .json(&json!({ - "operation": { - "type": "plc_operation", - "rotationKeys": ["did:key:z123"], - "verificationMethods": {"atproto": "did:key:z456"}, - "alsoKnownAs": ["at://wrong.handle"], - "services": { - "atproto_pds": { - "type": "AtprotoPersonalDataServer", - "endpoint": "https://wrong.example.com" - } - }, - "prev": null, - "sig": "fake_signature" - } - })) - .send() - .await - .expect("Request failed"); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); -} - -#[tokio::test] -async fn test_request_plc_operation_creates_token_in_db() { +async fn test_submit_plc_operation_validation() { let client = client(); let (token, did) = create_account_and_login(&client).await; - let res = client - .post(format!( - "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", - base_url().await - )) - .bearer_auth(&token) - .send() - .await - .expect("Request failed"); + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); + let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) + .bearer_auth(&token).json(&json!({ "operation": { "type": "invalid_type" } })).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + let body: serde_json::Value = res.json().await.unwrap(); + assert_eq!(body["error"], "InvalidRequest"); + let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) + .bearer_auth(&token).json(&json!({ + "operation": { "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, + "alsoKnownAs": [], "services": {}, "prev": null } + })).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + let handle = did.split(':').last().unwrap_or("user"); + let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) + .bearer_auth(&token).json(&json!({ + "operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"], + "verificationMethods": { "atproto": "did:key:z456" }, + "alsoKnownAs": [format!("at://{}", handle)], + "services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": "https://wrong.example.com" } }, + "prev": null, "sig": "fake_signature" } + })).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) + .bearer_auth(&token).json(&json!({ + "operation": { "type": "plc_operation", "rotationKeys": ["did:key:zWrongRotationKey123"], + "verificationMethods": { "atproto": "did:key:zWrongVerificationKey456" }, + "alsoKnownAs": [format!("at://{}", handle)], + "services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": format!("https://{}", hostname) } }, + "prev": null, "sig": "fake_signature" } + })).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + let body: serde_json::Value = res.json().await.unwrap(); + assert_eq!(body["error"], "InvalidRequest"); + assert!(body["message"].as_str().unwrap_or("").contains("signing key") || body["message"].as_str().unwrap_or("").contains("rotation")); + let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) + .bearer_auth(&token).json(&json!({ + "operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"], + "verificationMethods": { "atproto": "did:key:z456" }, + "alsoKnownAs": ["at://totally.wrong.handle"], + "services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": format!("https://{}", hostname) } }, + "prev": null, "sig": "fake_signature" } + })).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) + .bearer_auth(&token).json(&json!({ + "operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"], + "verificationMethods": { "atproto": "did:key:z456" }, + "alsoKnownAs": ["at://user"], + "services": { "atproto_pds": { "type": "WrongServiceType", "endpoint": format!("https://{}", hostname) } }, + "prev": null, "sig": "fake_signature" } + })).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); +} + +#[tokio::test] +async fn test_plc_token_lifecycle() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) + .bearer_auth(&token).send().await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let db_url = get_db_connection_string().await; - let pool = PgPool::connect(&db_url).await.expect("DB connect failed"); + let pool = PgPool::connect(&db_url).await.unwrap(); let row = sqlx::query!( - r#" - SELECT t.token, t.expires_at - FROM plc_operation_tokens t - JOIN users u ON t.user_id = u.id - WHERE u.did = $1 - "#, + "SELECT t.token, t.expires_at FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did - ) - .fetch_optional(&pool) - .await - .expect("Query failed"); + ).fetch_optional(&pool).await.unwrap(); assert!(row.is_some(), "PLC token should be created in database"); let row = row.unwrap(); - assert!( - row.token.len() == 11, - "Token should be in format xxxxx-xxxxx" - ); + assert_eq!(row.token.len(), 11, "Token should be in format xxxxx-xxxxx"); assert!(row.token.contains('-'), "Token should contain hyphen"); - assert!( - row.expires_at > chrono::Utc::now(), - "Token should not be expired" - ); -} - -#[tokio::test] -async fn test_request_plc_operation_replaces_existing_token() { - let client = client(); - let (token, did) = create_account_and_login(&client).await; - let res1 = client - .post(format!( - "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", - base_url().await - )) - .bearer_auth(&token) - .send() - .await - .expect("Request 1 failed"); - assert_eq!(res1.status(), StatusCode::OK); - let db_url = get_db_connection_string().await; - let pool = PgPool::connect(&db_url).await.expect("DB connect failed"); - let token1 = sqlx::query_scalar!( - r#" - SELECT t.token - FROM plc_operation_tokens t - JOIN users u ON t.user_id = u.id - WHERE u.did = $1 - "#, - did - ) - .fetch_one(&pool) - .await - .expect("Query failed"); - let res2 = client - .post(format!( - "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", - base_url().await - )) - .bearer_auth(&token) - .send() - .await - .expect("Request 2 failed"); - assert_eq!(res2.status(), StatusCode::OK); + assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired"); + let diff = row.expires_at - chrono::Utc::now(); + assert!(diff.num_minutes() >= 9 && diff.num_minutes() <= 11, "Token should expire in ~10 minutes"); + let token1 = row.token.clone(); + let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) + .bearer_auth(&token).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); let token2 = sqlx::query_scalar!( - r#" - SELECT t.token - FROM plc_operation_tokens t - JOIN users u ON t.user_id = u.id - WHERE u.did = $1 - "#, - did - ) - .fetch_one(&pool) - .await - .expect("Query failed"); + "SELECT t.token FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did + ).fetch_one(&pool).await.unwrap(); assert_ne!(token1, token2, "Second request should generate a new token"); let count: i64 = sqlx::query_scalar!( - r#" - SELECT COUNT(*) as "count!" - FROM plc_operation_tokens t - JOIN users u ON t.user_id = u.id - WHERE u.did = $1 - "#, - did - ) - .fetch_one(&pool) - .await - .expect("Count query failed"); + "SELECT COUNT(*) as \"count!\" FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did + ).fetch_one(&pool).await.unwrap(); assert_eq!(count, 1, "Should only have one token per user"); } - -#[tokio::test] -async fn test_submit_plc_operation_wrong_verification_method() { - let client = client(); - let (token, did) = create_account_and_login(&client).await; - let hostname = - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); - let handle = did.split(':').last().unwrap_or("user"); - let res = client - .post(format!( - "{}/xrpc/com.atproto.identity.submitPlcOperation", - base_url().await - )) - .bearer_auth(&token) - .json(&json!({ - "operation": { - "type": "plc_operation", - "rotationKeys": ["did:key:zWrongRotationKey123"], - "verificationMethods": {"atproto": "did:key:zWrongVerificationKey456"}, - "alsoKnownAs": [format!("at://{}", handle)], - "services": { - "atproto_pds": { - "type": "AtprotoPersonalDataServer", - "endpoint": format!("https://{}", hostname) - } - }, - "prev": null, - "sig": "fake_signature" - } - })) - .send() - .await - .expect("Request failed"); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - let body: serde_json::Value = res.json().await.unwrap(); - assert_eq!(body["error"], "InvalidRequest"); - assert!( - body["message"] - .as_str() - .unwrap_or("") - .contains("signing key") - || body["message"].as_str().unwrap_or("").contains("rotation"), - "Error should mention key mismatch: {:?}", - body - ); -} - -#[tokio::test] -async fn test_submit_plc_operation_wrong_handle() { - let client = client(); - let (token, _did) = create_account_and_login(&client).await; - let hostname = - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); - let res = client - .post(format!( - "{}/xrpc/com.atproto.identity.submitPlcOperation", - base_url().await - )) - .bearer_auth(&token) - .json(&json!({ - "operation": { - "type": "plc_operation", - "rotationKeys": ["did:key:z123"], - "verificationMethods": {"atproto": "did:key:z456"}, - "alsoKnownAs": ["at://totally.wrong.handle"], - "services": { - "atproto_pds": { - "type": "AtprotoPersonalDataServer", - "endpoint": format!("https://{}", hostname) - } - }, - "prev": null, - "sig": "fake_signature" - } - })) - .send() - .await - .expect("Request failed"); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - let body: serde_json::Value = res.json().await.unwrap(); - assert_eq!(body["error"], "InvalidRequest"); -} - -#[tokio::test] -async fn test_submit_plc_operation_wrong_service_type() { - let client = client(); - let (token, _did) = create_account_and_login(&client).await; - let hostname = - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); - let res = client - .post(format!( - "{}/xrpc/com.atproto.identity.submitPlcOperation", - base_url().await - )) - .bearer_auth(&token) - .json(&json!({ - "operation": { - "type": "plc_operation", - "rotationKeys": ["did:key:z123"], - "verificationMethods": {"atproto": "did:key:z456"}, - "alsoKnownAs": ["at://user"], - "services": { - "atproto_pds": { - "type": "WrongServiceType", - "endpoint": format!("https://{}", hostname) - } - }, - "prev": null, - "sig": "fake_signature" - } - })) - .send() - .await - .expect("Request failed"); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - let body: serde_json::Value = res.json().await.unwrap(); - assert_eq!(body["error"], "InvalidRequest"); -} - -#[tokio::test] -async fn test_plc_token_expiry_format() { - let client = client(); - let (token, did) = create_account_and_login(&client).await; - let res = client - .post(format!( - "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", - base_url().await - )) - .bearer_auth(&token) - .send() - .await - .expect("Request failed"); - assert_eq!(res.status(), StatusCode::OK); - let db_url = get_db_connection_string().await; - let pool = PgPool::connect(&db_url).await.expect("DB connect failed"); - let row = sqlx::query!( - r#" - SELECT t.expires_at - FROM plc_operation_tokens t - JOIN users u ON t.user_id = u.id - WHERE u.did = $1 - "#, - did - ) - .fetch_one(&pool) - .await - .expect("Query failed"); - let now = chrono::Utc::now(); - let expires = row.expires_at; - let diff = expires - now; - assert!( - diff.num_minutes() >= 9, - "Token should expire in ~10 minutes, got {} minutes", - diff.num_minutes() - ); - assert!( - diff.num_minutes() <= 11, - "Token should expire in ~10 minutes, got {} minutes", - diff.num_minutes() - ); -} diff --git a/tests/plc_validation.rs b/tests/plc_validation.rs index ca3a93c..1335645 100644 --- a/tests/plc_validation.rs +++ b/tests/plc_validation.rs @@ -13,9 +13,7 @@ fn create_valid_operation() -> serde_json::Value { let op = json!({ "type": "plc_operation", "rotationKeys": [did_key.clone()], - "verificationMethods": { - "atproto": did_key.clone() - }, + "verificationMethods": { "atproto": did_key.clone() }, "alsoKnownAs": ["at://test.handle"], "services": { "atproto_pds": { @@ -29,444 +27,161 @@ fn create_valid_operation() -> serde_json::Value { } #[test] -fn test_validate_plc_operation_valid() { +fn test_plc_operation_basic_validation() { let op = create_valid_operation(); - let result = validate_plc_operation(&op); - assert!(result.is_ok()); + assert!(validate_plc_operation(&op).is_ok()); + + let missing_type = json!({ "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" }); + assert!(matches!(validate_plc_operation(&missing_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type"))); + + let invalid_type = json!({ "type": "invalid_type", "sig": "test" }); + assert!(matches!(validate_plc_operation(&invalid_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type"))); + + let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} }); + assert!(matches!(validate_plc_operation(&missing_sig), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig"))); + + let missing_rotation = json!({ "type": "plc_operation", "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" }); + assert!(matches!(validate_plc_operation(&missing_rotation), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys"))); + + let missing_verification = json!({ "type": "plc_operation", "rotationKeys": [], "alsoKnownAs": [], "services": {}, "sig": "test" }); + assert!(matches!(validate_plc_operation(&missing_verification), Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods"))); + + let missing_aka = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "services": {}, "sig": "test" }); + assert!(matches!(validate_plc_operation(&missing_aka), Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs"))); + + let missing_services = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "sig": "test" }); + assert!(matches!(validate_plc_operation(&missing_services), Err(PlcError::InvalidResponse(msg)) if msg.contains("services"))); + + assert!(matches!(validate_plc_operation(&json!("not an object")), Err(PlcError::InvalidResponse(_)))); } #[test] -fn test_validate_plc_operation_missing_type() { - let op = json!({ - "rotationKeys": [], - "verificationMethods": {}, - "alsoKnownAs": [], - "services": {}, - "sig": "test" - }); - let result = validate_plc_operation(&op); - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type"))); -} - -#[test] -fn test_validate_plc_operation_invalid_type() { - let op = json!({ - "type": "invalid_type", - "sig": "test" - }); - let result = validate_plc_operation(&op); - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type"))); -} - -#[test] -fn test_validate_plc_operation_missing_sig() { - let op = json!({ - "type": "plc_operation", - "rotationKeys": [], - "verificationMethods": {}, - "alsoKnownAs": [], - "services": {} - }); - let result = validate_plc_operation(&op); - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig"))); -} - -#[test] -fn test_validate_plc_operation_missing_rotation_keys() { - let op = json!({ - "type": "plc_operation", - "verificationMethods": {}, - "alsoKnownAs": [], - "services": {}, - "sig": "test" - }); - let result = validate_plc_operation(&op); - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys"))); -} - -#[test] -fn test_validate_plc_operation_missing_verification_methods() { - let op = json!({ - "type": "plc_operation", - "rotationKeys": [], - "alsoKnownAs": [], - "services": {}, - "sig": "test" - }); - let result = validate_plc_operation(&op); - assert!( - matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")) - ); -} - -#[test] -fn test_validate_plc_operation_missing_also_known_as() { - let op = json!({ - "type": "plc_operation", - "rotationKeys": [], - "verificationMethods": {}, - "services": {}, - "sig": "test" - }); - let result = validate_plc_operation(&op); - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs"))); -} - -#[test] -fn test_validate_plc_operation_missing_services() { - let op = json!({ - "type": "plc_operation", - "rotationKeys": [], - "verificationMethods": {}, - "alsoKnownAs": [], - "sig": "test" - }); - let result = validate_plc_operation(&op); - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("services"))); -} - -#[test] -fn test_validate_rotation_key_required() { +fn test_plc_submission_validation() { let key = SigningKey::random(&mut rand::thread_rng()); let did_key = signing_key_to_did_key(&key); let server_key = "did:key:zServer123"; - let op = json!({ + + let base_op = |rotation_key: &str, signing_key: &str, handle: &str, service_type: &str, endpoint: &str| json!({ "type": "plc_operation", - "rotationKeys": [did_key.clone()], - "verificationMethods": {"atproto": did_key.clone()}, - "alsoKnownAs": ["at://test.handle"], - "services": { - "atproto_pds": { - "type": "AtprotoPersonalDataServer", - "endpoint": "https://pds.example.com" - } - }, + "rotationKeys": [rotation_key], + "verificationMethods": {"atproto": signing_key}, + "alsoKnownAs": [format!("at://{}", handle)], + "services": { "atproto_pds": { "type": service_type, "endpoint": endpoint } }, "sig": "test" }); + let ctx = PlcValidationContext { server_rotation_key: server_key.to_string(), expected_signing_key: did_key.clone(), expected_handle: "test.handle".to_string(), expected_pds_endpoint: "https://pds.example.com".to_string(), }; - let result = validate_plc_operation_for_submission(&op, &ctx); - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key"))); -} -#[test] -fn test_validate_signing_key_match() { - let key = SigningKey::random(&mut rand::thread_rng()); - let did_key = signing_key_to_did_key(&key); - let wrong_key = "did:key:zWrongKey456"; - let op = json!({ - "type": "plc_operation", - "rotationKeys": [did_key.clone()], - "verificationMethods": {"atproto": wrong_key}, - "alsoKnownAs": ["at://test.handle"], - "services": { - "atproto_pds": { - "type": "AtprotoPersonalDataServer", - "endpoint": "https://pds.example.com" - } - }, - "sig": "test" - }); - let ctx = PlcValidationContext { + let op = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com"); + assert!(matches!(validate_plc_operation_for_submission(&op, &ctx), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key"))); + + let ctx_with_user_key = PlcValidationContext { server_rotation_key: did_key.clone(), expected_signing_key: did_key.clone(), expected_handle: "test.handle".to_string(), expected_pds_endpoint: "https://pds.example.com".to_string(), }; - let result = validate_plc_operation_for_submission(&op, &ctx); - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key"))); + + let wrong_signing = base_op(&did_key, "did:key:zWrongKey", "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com"); + assert!(matches!(validate_plc_operation_for_submission(&wrong_signing, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key"))); + + let wrong_handle = base_op(&did_key, &did_key, "wrong.handle", "AtprotoPersonalDataServer", "https://pds.example.com"); + assert!(matches!(validate_plc_operation_for_submission(&wrong_handle, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("handle"))); + + let wrong_service_type = base_op(&did_key, &did_key, "test.handle", "WrongServiceType", "https://pds.example.com"); + assert!(matches!(validate_plc_operation_for_submission(&wrong_service_type, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("type"))); + + let wrong_endpoint = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://wrong.endpoint.com"); + assert!(matches!(validate_plc_operation_for_submission(&wrong_endpoint, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint"))); } #[test] -fn test_validate_handle_match() { +fn test_signature_verification() { let key = SigningKey::random(&mut rand::thread_rng()); let did_key = signing_key_to_did_key(&key); let op = json!({ - "type": "plc_operation", - "rotationKeys": [did_key.clone()], - "verificationMethods": {"atproto": did_key.clone()}, - "alsoKnownAs": ["at://wrong.handle"], - "services": { - "atproto_pds": { - "type": "AtprotoPersonalDataServer", - "endpoint": "https://pds.example.com" - } - }, - "sig": "test" - }); - let ctx = PlcValidationContext { - server_rotation_key: did_key.clone(), - expected_signing_key: did_key.clone(), - expected_handle: "test.handle".to_string(), - expected_pds_endpoint: "https://pds.example.com".to_string(), - }; - let result = validate_plc_operation_for_submission(&op, &ctx); - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle"))); -} - -#[test] -fn test_validate_pds_service_type() { - let key = SigningKey::random(&mut rand::thread_rng()); - let did_key = signing_key_to_did_key(&key); - let op = json!({ - "type": "plc_operation", - "rotationKeys": [did_key.clone()], - "verificationMethods": {"atproto": did_key.clone()}, - "alsoKnownAs": ["at://test.handle"], - "services": { - "atproto_pds": { - "type": "WrongServiceType", - "endpoint": "https://pds.example.com" - } - }, - "sig": "test" - }); - let ctx = PlcValidationContext { - server_rotation_key: did_key.clone(), - expected_signing_key: did_key.clone(), - expected_handle: "test.handle".to_string(), - expected_pds_endpoint: "https://pds.example.com".to_string(), - }; - let result = validate_plc_operation_for_submission(&op, &ctx); - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type"))); -} - -#[test] -fn test_validate_pds_endpoint_match() { - let key = SigningKey::random(&mut rand::thread_rng()); - let did_key = signing_key_to_did_key(&key); - let op = json!({ - "type": "plc_operation", - "rotationKeys": [did_key.clone()], - "verificationMethods": {"atproto": did_key.clone()}, - "alsoKnownAs": ["at://test.handle"], - "services": { - "atproto_pds": { - "type": "AtprotoPersonalDataServer", - "endpoint": "https://wrong.endpoint.com" - } - }, - "sig": "test" - }); - let ctx = PlcValidationContext { - server_rotation_key: did_key.clone(), - expected_signing_key: did_key.clone(), - expected_handle: "test.handle".to_string(), - expected_pds_endpoint: "https://pds.example.com".to_string(), - }; - let result = validate_plc_operation_for_submission(&op, &ctx); - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint"))); -} - -#[test] -fn test_verify_signature_secp256k1() { - let key = SigningKey::random(&mut rand::thread_rng()); - let did_key = signing_key_to_did_key(&key); - let op = json!({ - "type": "plc_operation", - "rotationKeys": [did_key.clone()], - "verificationMethods": {}, - "alsoKnownAs": [], - "services": {}, - "prev": null + "type": "plc_operation", "rotationKeys": [did_key.clone()], + "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "prev": null }); let signed = sign_operation(&op, &key).unwrap(); - let rotation_keys = vec![did_key]; - let result = verify_operation_signature(&signed, &rotation_keys); - assert!(result.is_ok()); - assert!(result.unwrap()); -} + let result = verify_operation_signature(&signed, &[did_key.clone()]); + assert!(result.is_ok() && result.unwrap()); -#[test] -fn test_verify_signature_wrong_key() { - let key = SigningKey::random(&mut rand::thread_rng()); let other_key = SigningKey::random(&mut rand::thread_rng()); - let other_did_key = signing_key_to_did_key(&other_key); - let op = json!({ - "type": "plc_operation", - "rotationKeys": [], - "verificationMethods": {}, - "alsoKnownAs": [], - "services": {}, - "prev": null + let other_did = signing_key_to_did_key(&other_key); + let result = verify_operation_signature(&signed, &[other_did]); + assert!(result.is_ok() && !result.unwrap()); + + let result = verify_operation_signature(&signed, &["not-a-did-key".to_string()]); + assert!(result.is_ok() && !result.unwrap()); + + let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} }); + assert!(matches!(verify_operation_signature(&missing_sig, &[]), Err(PlcError::InvalidResponse(msg)) if msg.contains("sig"))); + + let invalid_base64 = json!({ + "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, + "alsoKnownAs": [], "services": {}, "sig": "not-valid-base64!!!" }); - let signed = sign_operation(&op, &key).unwrap(); - let wrong_rotation_keys = vec![other_did_key]; - let result = verify_operation_signature(&signed, &wrong_rotation_keys); - assert!(result.is_ok()); - assert!(!result.unwrap()); + assert!(matches!(verify_operation_signature(&invalid_base64, &[]), Err(PlcError::InvalidResponse(_)))); } #[test] -fn test_verify_signature_invalid_did_key_format() { - let key = SigningKey::random(&mut rand::thread_rng()); - let op = json!({ - "type": "plc_operation", - "rotationKeys": [], - "verificationMethods": {}, - "alsoKnownAs": [], - "services": {}, - "prev": null - }); - let signed = sign_operation(&op, &key).unwrap(); - let invalid_keys = vec!["not-a-did-key".to_string()]; - let result = verify_operation_signature(&signed, &invalid_keys); - assert!(result.is_ok()); - assert!(!result.unwrap()); -} - -#[test] -fn test_tombstone_validation() { - let op = json!({ - "type": "plc_tombstone", - "prev": "bafyreig6xxxxxyyyyyzzzzzz", - "sig": "test" - }); - let result = validate_plc_operation(&op); - assert!(result.is_ok()); -} - -#[test] -fn test_cid_for_cbor_deterministic() { - let value = json!({ - "alpha": 1, - "beta": 2 - }); +fn test_cid_and_key_utilities() { + let value = json!({ "alpha": 1, "beta": 2 }); let cid1 = cid_for_cbor(&value).unwrap(); let cid2 = cid_for_cbor(&value).unwrap(); - assert_eq!(cid1, cid2, "CID generation should be deterministic"); - assert!( - cid1.starts_with("bafyrei"), - "CID should start with bafyrei (dag-cbor + sha256)" - ); -} + assert_eq!(cid1, cid2, "CID should be deterministic"); + assert!(cid1.starts_with("bafyrei"), "CID should be dag-cbor + sha256"); -#[test] -fn test_cid_different_for_different_data() { - let value1 = json!({"data": 1}); - let value2 = json!({"data": 2}); - let cid1 = cid_for_cbor(&value1).unwrap(); - let cid2 = cid_for_cbor(&value2).unwrap(); - assert_ne!(cid1, cid2, "Different data should produce different CIDs"); -} + let value2 = json!({ "alpha": 999 }); + let cid3 = cid_for_cbor(&value2).unwrap(); + assert_ne!(cid1, cid3, "Different data should produce different CIDs"); -#[test] -fn test_signing_key_to_did_key_format() { let key = SigningKey::random(&mut rand::thread_rng()); - let did_key = signing_key_to_did_key(&key); - assert!( - did_key.starts_with("did:key:z"), - "Should start with did:key:z" - ); - assert!(did_key.len() > 50, "Did key should be reasonably long"); -} + let did = signing_key_to_did_key(&key); + assert!(did.starts_with("did:key:z") && did.len() > 50); + assert_eq!(did, signing_key_to_did_key(&key), "Same key should produce same did"); -#[test] -fn test_signing_key_to_did_key_unique() { - let key1 = SigningKey::random(&mut rand::thread_rng()); let key2 = SigningKey::random(&mut rand::thread_rng()); - let did1 = signing_key_to_did_key(&key1); - let did2 = signing_key_to_did_key(&key2); - assert_ne!( - did1, did2, - "Different keys should produce different did:keys" - ); + assert_ne!(did, signing_key_to_did_key(&key2), "Different keys should produce different dids"); } #[test] -fn test_signing_key_to_did_key_consistent() { - let key = SigningKey::random(&mut rand::thread_rng()); - let did1 = signing_key_to_did_key(&key); - let did2 = signing_key_to_did_key(&key); - assert_eq!(did1, did2, "Same key should produce same did:key"); -} +fn test_tombstone_operations() { + let tombstone = json!({ "type": "plc_tombstone", "prev": "bafyreig6xxxxxyyyyyzzzzzz", "sig": "test" }); + assert!(validate_plc_operation(&tombstone).is_ok()); -#[test] -fn test_sign_operation_removes_existing_sig() { - let key = SigningKey::random(&mut rand::thread_rng()); - let op = json!({ - "type": "plc_operation", - "rotationKeys": [], - "verificationMethods": {}, - "alsoKnownAs": [], - "services": {}, - "prev": null, - "sig": "old_signature" - }); - let signed = sign_operation(&op, &key).unwrap(); - let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap(); - assert_ne!(new_sig, "old_signature", "Should replace old signature"); -} - -#[test] -fn test_validate_plc_operation_not_object() { - let result = validate_plc_operation(&json!("not an object")); - assert!(matches!(result, Err(PlcError::InvalidResponse(_)))); -} - -#[test] -fn test_validate_for_submission_tombstone_passes() { let key = SigningKey::random(&mut rand::thread_rng()); let did_key = signing_key_to_did_key(&key); - let op = json!({ - "type": "plc_tombstone", - "prev": "bafyreig6xxxxxyyyyyzzzzzz", - "sig": "test" - }); let ctx = PlcValidationContext { server_rotation_key: did_key.clone(), expected_signing_key: did_key, expected_handle: "test.handle".to_string(), expected_pds_endpoint: "https://pds.example.com".to_string(), }; - let result = validate_plc_operation_for_submission(&op, &ctx); - assert!( - result.is_ok(), - "Tombstone should pass submission validation" - ); + assert!(validate_plc_operation_for_submission(&tombstone, &ctx).is_ok()); } #[test] -fn test_verify_signature_missing_sig() { +fn test_sign_operation_and_struct() { + let key = SigningKey::random(&mut rand::thread_rng()); let op = json!({ - "type": "plc_operation", - "rotationKeys": [], - "verificationMethods": {}, - "alsoKnownAs": [], - "services": {} + "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, + "alsoKnownAs": [], "services": {}, "prev": null, "sig": "old_signature" }); - let result = verify_operation_signature(&op, &[]); - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig"))); -} + let signed = sign_operation(&op, &key).unwrap(); + assert_ne!(signed.get("sig").and_then(|v| v.as_str()).unwrap(), "old_signature"); -#[test] -fn test_verify_signature_invalid_base64() { - let op = json!({ - "type": "plc_operation", - "rotationKeys": [], - "verificationMethods": {}, - "alsoKnownAs": [], - "services": {}, - "sig": "not-valid-base64!!!" - }); - let result = verify_operation_signature(&op, &[]); - assert!(matches!(result, Err(PlcError::InvalidResponse(_)))); -} - -#[test] -fn test_plc_operation_struct() { let mut services = HashMap::new(); - services.insert( - "atproto_pds".to_string(), - PlcService { - service_type: "AtprotoPersonalDataServer".to_string(), - endpoint: "https://pds.example.com".to_string(), - }, - ); + services.insert("atproto_pds".to_string(), PlcService { + service_type: "AtprotoPersonalDataServer".to_string(), + endpoint: "https://pds.example.com".to_string(), + }); let mut verification_methods = HashMap::new(); verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string()); let op = PlcOperation { diff --git a/tests/record_validation.rs b/tests/record_validation.rs index e3b0ace..d0519a8 100644 --- a/tests/record_validation.rs +++ b/tests/record_validation.rs @@ -9,187 +9,117 @@ fn now() -> String { } #[test] -fn test_validate_post_valid() { +fn test_post_record_validation() { let validator = RecordValidator::new(); - let post = json!({ + + let valid_post = json!({ "$type": "app.bsky.feed.post", "text": "Hello world!", "createdAt": now() }); - let result = validator.validate(&post, "app.bsky.feed.post"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} + assert_eq!(validator.validate(&valid_post, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); -#[test] -fn test_validate_post_missing_text() { - let validator = RecordValidator::new(); - let post = json!({ + let missing_text = json!({ "$type": "app.bsky.feed.post", "createdAt": now() }); - let result = validator.validate(&post, "app.bsky.feed.post"); - assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text")); -} + assert!(matches!(validator.validate(&missing_text, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "text")); -#[test] -fn test_validate_post_missing_created_at() { - let validator = RecordValidator::new(); - let post = json!({ + let missing_created_at = json!({ "$type": "app.bsky.feed.post", "text": "Hello" }); - let result = validator.validate(&post, "app.bsky.feed.post"); - assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt")); -} + assert!(matches!(validator.validate(&missing_created_at, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "createdAt")); -#[test] -fn test_validate_post_text_too_long() { - let validator = RecordValidator::new(); - let long_text = "a".repeat(3001); - let post = json!({ + let text_too_long = json!({ "$type": "app.bsky.feed.post", - "text": long_text, + "text": "a".repeat(3001), "createdAt": now() }); - let result = validator.validate(&post, "app.bsky.feed.post"); - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text")); -} + assert!(matches!(validator.validate(&text_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "text")); -#[test] -fn test_validate_post_text_at_limit() { - let validator = RecordValidator::new(); - let limit_text = "a".repeat(3000); - let post = json!({ + let text_at_limit = json!({ "$type": "app.bsky.feed.post", - "text": limit_text, + "text": "a".repeat(3000), "createdAt": now() }); - let result = validator.validate(&post, "app.bsky.feed.post"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} + assert_eq!(validator.validate(&text_at_limit, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); -#[test] -fn test_validate_post_too_many_langs() { - let validator = RecordValidator::new(); - let post = json!({ + let too_many_langs = json!({ "$type": "app.bsky.feed.post", "text": "Hello", "createdAt": now(), "langs": ["en", "fr", "de", "es"] }); - let result = validator.validate(&post, "app.bsky.feed.post"); - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "langs")); -} + assert!(matches!(validator.validate(&too_many_langs, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "langs")); -#[test] -fn test_validate_post_three_langs_ok() { - let validator = RecordValidator::new(); - let post = json!({ + let three_langs_ok = json!({ "$type": "app.bsky.feed.post", "text": "Hello", "createdAt": now(), "langs": ["en", "fr", "de"] }); - let result = validator.validate(&post, "app.bsky.feed.post"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} + assert_eq!(validator.validate(&three_langs_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); -#[test] -fn test_validate_post_too_many_tags() { - let validator = RecordValidator::new(); - let post = json!({ + let too_many_tags = json!({ "$type": "app.bsky.feed.post", "text": "Hello", "createdAt": now(), "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"] }); - let result = validator.validate(&post, "app.bsky.feed.post"); - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags")); -} + assert!(matches!(validator.validate(&too_many_tags, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "tags")); -#[test] -fn test_validate_post_eight_tags_ok() { - let validator = RecordValidator::new(); - let post = json!({ + let eight_tags_ok = json!({ "$type": "app.bsky.feed.post", "text": "Hello", "createdAt": now(), "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"] }); - let result = validator.validate(&post, "app.bsky.feed.post"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} + assert_eq!(validator.validate(&eight_tags_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); -#[test] -fn test_validate_post_tag_too_long() { - let validator = RecordValidator::new(); - let long_tag = "t".repeat(641); - let post = json!({ + let tag_too_long = json!({ "$type": "app.bsky.feed.post", "text": "Hello", "createdAt": now(), - "tags": [long_tag] + "tags": ["t".repeat(641)] }); - let result = validator.validate(&post, "app.bsky.feed.post"); - assert!( - matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")) - ); + assert!(matches!(validator.validate(&tag_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/"))); } #[test] -fn test_validate_profile_valid() { +fn test_profile_record_validation() { let validator = RecordValidator::new(); - let profile = json!({ + + let valid = json!({ "$type": "app.bsky.actor.profile", "displayName": "Test User", "description": "A test user profile" }); - let result = validator.validate(&profile, "app.bsky.actor.profile"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} + assert_eq!(validator.validate(&valid, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid); -#[test] -fn test_validate_profile_empty_ok() { - let validator = RecordValidator::new(); - let profile = json!({ + let empty_ok = json!({ "$type": "app.bsky.actor.profile" }); - let result = validator.validate(&profile, "app.bsky.actor.profile"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} + assert_eq!(validator.validate(&empty_ok, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid); -#[test] -fn test_validate_profile_displayname_too_long() { - let validator = RecordValidator::new(); - let long_name = "n".repeat(641); - let profile = json!({ + let displayname_too_long = json!({ "$type": "app.bsky.actor.profile", - "displayName": long_name + "displayName": "n".repeat(641) }); - let result = validator.validate(&profile, "app.bsky.actor.profile"); - assert!( - matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName") - ); -} + assert!(matches!(validator.validate(&displayname_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); -#[test] -fn test_validate_profile_description_too_long() { - let validator = RecordValidator::new(); - let long_desc = "d".repeat(2561); - let profile = json!({ + let description_too_long = json!({ "$type": "app.bsky.actor.profile", - "description": long_desc + "description": "d".repeat(2561) }); - let result = validator.validate(&profile, "app.bsky.actor.profile"); - assert!( - matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description") - ); + assert!(matches!(validator.validate(&description_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "description")); } #[test] -fn test_validate_like_valid() { +fn test_like_and_repost_validation() { let validator = RecordValidator::new(); - let like = json!({ + + let valid_like = json!({ "$type": "app.bsky.feed.like", "subject": { "uri": "at://did:plc:test/app.bsky.feed.post/123", @@ -197,39 +127,24 @@ fn test_validate_like_valid() { }, "createdAt": now() }); - let result = validator.validate(&like, "app.bsky.feed.like"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} + assert_eq!(validator.validate(&valid_like, "app.bsky.feed.like").unwrap(), ValidationStatus::Valid); -#[test] -fn test_validate_like_missing_subject() { - let validator = RecordValidator::new(); - let like = json!({ + let missing_subject = json!({ "$type": "app.bsky.feed.like", "createdAt": now() }); - let result = validator.validate(&like, "app.bsky.feed.like"); - assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); -} + assert!(matches!(validator.validate(&missing_subject, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f == "subject")); -#[test] -fn test_validate_like_missing_subject_uri() { - let validator = RecordValidator::new(); - let like = json!({ + let missing_subject_uri = json!({ "$type": "app.bsky.feed.like", "subject": { "cid": "bafyreig6xxxxxyyyyyzzzzzz" }, "createdAt": now() }); - let result = validator.validate(&like, "app.bsky.feed.like"); - assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri"))); -} + assert!(matches!(validator.validate(&missing_subject_uri, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f.contains("uri"))); -#[test] -fn test_validate_like_invalid_subject_uri() { - let validator = RecordValidator::new(); - let like = json!({ + let invalid_subject_uri = json!({ "$type": "app.bsky.feed.like", "subject": { "uri": "https://example.com/not-at-uri", @@ -237,16 +152,9 @@ fn test_validate_like_invalid_subject_uri() { }, "createdAt": now() }); - let result = validator.validate(&like, "app.bsky.feed.like"); - assert!( - matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")) - ); -} + assert!(matches!(validator.validate(&invalid_subject_uri, "app.bsky.feed.like"), Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))); -#[test] -fn test_validate_repost_valid() { - let validator = RecordValidator::new(); - let repost = json!({ + let valid_repost = json!({ "$type": "app.bsky.feed.repost", "subject": { "uri": "at://did:plc:test/app.bsky.feed.post/123", @@ -254,355 +162,220 @@ fn test_validate_repost_valid() { }, "createdAt": now() }); - let result = validator.validate(&repost, "app.bsky.feed.repost"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} + assert_eq!(validator.validate(&valid_repost, "app.bsky.feed.repost").unwrap(), ValidationStatus::Valid); -#[test] -fn test_validate_repost_missing_subject() { - let validator = RecordValidator::new(); - let repost = json!({ + let repost_missing_subject = json!({ "$type": "app.bsky.feed.repost", "createdAt": now() }); - let result = validator.validate(&repost, "app.bsky.feed.repost"); - assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); + assert!(matches!(validator.validate(&repost_missing_subject, "app.bsky.feed.repost"), Err(ValidationError::MissingField(f)) if f == "subject")); } #[test] -fn test_validate_follow_valid() { +fn test_follow_and_block_validation() { let validator = RecordValidator::new(); - let follow = json!({ + + let valid_follow = json!({ "$type": "app.bsky.graph.follow", "subject": "did:plc:test12345", "createdAt": now() }); - let result = validator.validate(&follow, "app.bsky.graph.follow"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} + assert_eq!(validator.validate(&valid_follow, "app.bsky.graph.follow").unwrap(), ValidationStatus::Valid); -#[test] -fn test_validate_follow_missing_subject() { - let validator = RecordValidator::new(); - let follow = json!({ + let missing_follow_subject = json!({ "$type": "app.bsky.graph.follow", "createdAt": now() }); - let result = validator.validate(&follow, "app.bsky.graph.follow"); - assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); -} + assert!(matches!(validator.validate(&missing_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::MissingField(f)) if f == "subject")); -#[test] -fn test_validate_follow_invalid_subject() { - let validator = RecordValidator::new(); - let follow = json!({ + let invalid_follow_subject = json!({ "$type": "app.bsky.graph.follow", "subject": "not-a-did", "createdAt": now() }); - let result = validator.validate(&follow, "app.bsky.graph.follow"); - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject")); -} + assert!(matches!(validator.validate(&invalid_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::InvalidField { path, .. }) if path == "subject")); -#[test] -fn test_validate_block_valid() { - let validator = RecordValidator::new(); - let block = json!({ + let valid_block = json!({ "$type": "app.bsky.graph.block", "subject": "did:plc:blocked123", "createdAt": now() }); - let result = validator.validate(&block, "app.bsky.graph.block"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} + assert_eq!(validator.validate(&valid_block, "app.bsky.graph.block").unwrap(), ValidationStatus::Valid); -#[test] -fn test_validate_block_invalid_subject() { - let validator = RecordValidator::new(); - let block = json!({ + let invalid_block_subject = json!({ "$type": "app.bsky.graph.block", "subject": "not-a-did", "createdAt": now() }); - let result = validator.validate(&block, "app.bsky.graph.block"); - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject")); + assert!(matches!(validator.validate(&invalid_block_subject, "app.bsky.graph.block"), Err(ValidationError::InvalidField { path, .. }) if path == "subject")); } #[test] -fn test_validate_list_valid() { +fn test_list_and_graph_records_validation() { let validator = RecordValidator::new(); - let list = json!({ + + let valid_list = json!({ "$type": "app.bsky.graph.list", "name": "My List", "purpose": "app.bsky.graph.defs#modlist", "createdAt": now() }); - let result = validator.validate(&list, "app.bsky.graph.list"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} + assert_eq!(validator.validate(&valid_list, "app.bsky.graph.list").unwrap(), ValidationStatus::Valid); -#[test] -fn test_validate_list_name_too_long() { - let validator = RecordValidator::new(); - let long_name = "n".repeat(65); - let list = json!({ + let list_name_too_long = json!({ "$type": "app.bsky.graph.list", - "name": long_name, + "name": "n".repeat(65), "purpose": "app.bsky.graph.defs#modlist", "createdAt": now() }); - let result = validator.validate(&list, "app.bsky.graph.list"); - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name")); -} + assert!(matches!(validator.validate(&list_name_too_long, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name")); -#[test] -fn test_validate_list_empty_name() { - let validator = RecordValidator::new(); - let list = json!({ + let list_empty_name = json!({ "$type": "app.bsky.graph.list", "name": "", "purpose": "app.bsky.graph.defs#modlist", "createdAt": now() }); - let result = validator.validate(&list, "app.bsky.graph.list"); - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name")); + assert!(matches!(validator.validate(&list_empty_name, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name")); + + let valid_list_item = json!({ + "$type": "app.bsky.graph.listitem", + "subject": "did:plc:test123", + "list": "at://did:plc:owner/app.bsky.graph.list/mylist", + "createdAt": now() + }); + assert_eq!(validator.validate(&valid_list_item, "app.bsky.graph.listitem").unwrap(), ValidationStatus::Valid); } #[test] -fn test_validate_feed_generator_valid() { +fn test_misc_record_types_validation() { let validator = RecordValidator::new(); - let generator = json!({ + + let valid_generator = json!({ "$type": "app.bsky.feed.generator", "did": "did:web:example.com", "displayName": "My Feed", "createdAt": now() }); - let result = validator.validate(&generator, "app.bsky.feed.generator"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} + assert_eq!(validator.validate(&valid_generator, "app.bsky.feed.generator").unwrap(), ValidationStatus::Valid); -#[test] -fn test_validate_feed_generator_displayname_too_long() { - let validator = RecordValidator::new(); - let long_name = "f".repeat(241); - let generator = json!({ + let generator_displayname_too_long = json!({ "$type": "app.bsky.feed.generator", "did": "did:web:example.com", - "displayName": long_name, + "displayName": "f".repeat(241), "createdAt": now() }); - let result = validator.validate(&generator, "app.bsky.feed.generator"); - assert!( - matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName") - ); -} + assert!(matches!(validator.validate(&generator_displayname_too_long, "app.bsky.feed.generator"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); -#[test] -fn test_validate_unknown_type_returns_unknown() { - let validator = RecordValidator::new(); - let custom = json!({ - "$type": "com.custom.record", - "data": "test" - }); - let result = validator.validate(&custom, "com.custom.record"); - assert_eq!(result.unwrap(), ValidationStatus::Unknown); -} - -#[test] -fn test_validate_unknown_type_strict_rejects() { - let validator = RecordValidator::new().require_lexicon(true); - let custom = json!({ - "$type": "com.custom.record", - "data": "test" - }); - let result = validator.validate(&custom, "com.custom.record"); - assert!(matches!(result, Err(ValidationError::UnknownType(_)))); -} - -#[test] -fn test_validate_type_mismatch() { - let validator = RecordValidator::new(); - let record = json!({ - "$type": "app.bsky.feed.like", - "subject": {"uri": "at://test", "cid": "bafytest"}, - "createdAt": now() - }); - let result = validator.validate(&record, "app.bsky.feed.post"); - assert!( - matches!(result, Err(ValidationError::TypeMismatch { expected, actual }) - if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like") - ); -} - -#[test] -fn test_validate_missing_type() { - let validator = RecordValidator::new(); - let record = json!({ - "text": "Hello" - }); - let result = validator.validate(&record, "app.bsky.feed.post"); - assert!(matches!(result, Err(ValidationError::MissingType))); -} - -#[test] -fn test_validate_not_object() { - let validator = RecordValidator::new(); - let record = json!("just a string"); - let result = validator.validate(&record, "app.bsky.feed.post"); - assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); -} - -#[test] -fn test_validate_datetime_format_valid() { - let validator = RecordValidator::new(); - let post = json!({ - "$type": "app.bsky.feed.post", - "text": "Test", - "createdAt": "2024-01-15T10:30:00.000Z" - }); - let result = validator.validate(&post, "app.bsky.feed.post"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} - -#[test] -fn test_validate_datetime_with_offset() { - let validator = RecordValidator::new(); - let post = json!({ - "$type": "app.bsky.feed.post", - "text": "Test", - "createdAt": "2024-01-15T10:30:00+05:30" - }); - let result = validator.validate(&post, "app.bsky.feed.post"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} - -#[test] -fn test_validate_datetime_invalid_format() { - let validator = RecordValidator::new(); - let post = json!({ - "$type": "app.bsky.feed.post", - "text": "Test", - "createdAt": "2024/01/15" - }); - let result = validator.validate(&post, "app.bsky.feed.post"); - assert!(matches!( - result, - Err(ValidationError::InvalidDatetime { .. }) - )); -} - -#[test] -fn test_validate_record_key_valid() { - assert!(validate_record_key("3k2n5j2").is_ok()); - assert!(validate_record_key("valid-key").is_ok()); - assert!(validate_record_key("valid_key").is_ok()); - assert!(validate_record_key("valid.key").is_ok()); - assert!(validate_record_key("valid~key").is_ok()); - assert!(validate_record_key("self").is_ok()); -} - -#[test] -fn test_validate_record_key_empty() { - let result = validate_record_key(""); - assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); -} - -#[test] -fn test_validate_record_key_dot() { - assert!(validate_record_key(".").is_err()); - assert!(validate_record_key("..").is_err()); -} - -#[test] -fn test_validate_record_key_invalid_chars() { - assert!(validate_record_key("invalid/key").is_err()); - assert!(validate_record_key("invalid key").is_err()); - assert!(validate_record_key("invalid@key").is_err()); - assert!(validate_record_key("invalid#key").is_err()); -} - -#[test] -fn test_validate_record_key_too_long() { - let long_key = "k".repeat(513); - let result = validate_record_key(&long_key); - assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); -} - -#[test] -fn test_validate_record_key_at_max_length() { - let max_key = "k".repeat(512); - assert!(validate_record_key(&max_key).is_ok()); -} - -#[test] -fn test_validate_collection_nsid_valid() { - assert!(validate_collection_nsid("app.bsky.feed.post").is_ok()); - assert!(validate_collection_nsid("com.atproto.repo.record").is_ok()); - assert!(validate_collection_nsid("a.b.c").is_ok()); - assert!(validate_collection_nsid("my-app.domain.record-type").is_ok()); -} - -#[test] -fn test_validate_collection_nsid_empty() { - let result = validate_collection_nsid(""); - assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); -} - -#[test] -fn test_validate_collection_nsid_too_few_segments() { - assert!(validate_collection_nsid("a").is_err()); - assert!(validate_collection_nsid("a.b").is_err()); -} - -#[test] -fn test_validate_collection_nsid_empty_segment() { - assert!(validate_collection_nsid("a..b.c").is_err()); - assert!(validate_collection_nsid(".a.b.c").is_err()); - assert!(validate_collection_nsid("a.b.c.").is_err()); -} - -#[test] -fn test_validate_collection_nsid_invalid_chars() { - assert!(validate_collection_nsid("a.b.c/d").is_err()); - assert!(validate_collection_nsid("a.b.c_d").is_err()); - assert!(validate_collection_nsid("a.b.c@d").is_err()); -} - -#[test] -fn test_validate_threadgate() { - let validator = RecordValidator::new(); - let gate = json!({ + let valid_threadgate = json!({ "$type": "app.bsky.feed.threadgate", "post": "at://did:plc:test/app.bsky.feed.post/123", "createdAt": now() }); - let result = validator.validate(&gate, "app.bsky.feed.threadgate"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); -} + assert_eq!(validator.validate(&valid_threadgate, "app.bsky.feed.threadgate").unwrap(), ValidationStatus::Valid); -#[test] -fn test_validate_labeler_service() { - let validator = RecordValidator::new(); - let labeler = json!({ + let valid_labeler = json!({ "$type": "app.bsky.labeler.service", "policies": { "labelValues": ["spam", "nsfw"] }, "createdAt": now() }); - let result = validator.validate(&labeler, "app.bsky.labeler.service"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); + assert_eq!(validator.validate(&valid_labeler, "app.bsky.labeler.service").unwrap(), ValidationStatus::Valid); } #[test] -fn test_validate_list_item() { +fn test_type_and_format_validation() { let validator = RecordValidator::new(); - let item = json!({ - "$type": "app.bsky.graph.listitem", - "subject": "did:plc:test123", - "list": "at://did:plc:owner/app.bsky.graph.list/mylist", + let strict_validator = RecordValidator::new().require_lexicon(true); + + let custom_record = json!({ + "$type": "com.custom.record", + "data": "test" + }); + assert_eq!(validator.validate(&custom_record, "com.custom.record").unwrap(), ValidationStatus::Unknown); + assert!(matches!(strict_validator.validate(&custom_record, "com.custom.record"), Err(ValidationError::UnknownType(_)))); + + let type_mismatch = json!({ + "$type": "app.bsky.feed.like", + "subject": {"uri": "at://test", "cid": "bafytest"}, "createdAt": now() }); - let result = validator.validate(&item, "app.bsky.graph.listitem"); - assert_eq!(result.unwrap(), ValidationStatus::Valid); + assert!(matches!( + validator.validate(&type_mismatch, "app.bsky.feed.post"), + Err(ValidationError::TypeMismatch { expected, actual }) if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like" + )); + + let missing_type = json!({ + "text": "Hello" + }); + assert!(matches!(validator.validate(&missing_type, "app.bsky.feed.post"), Err(ValidationError::MissingType))); + + let not_object = json!("just a string"); + assert!(matches!(validator.validate(¬_object, "app.bsky.feed.post"), Err(ValidationError::InvalidRecord(_)))); + + let valid_datetime = json!({ + "$type": "app.bsky.feed.post", + "text": "Test", + "createdAt": "2024-01-15T10:30:00.000Z" + }); + assert_eq!(validator.validate(&valid_datetime, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); + + let datetime_with_offset = json!({ + "$type": "app.bsky.feed.post", + "text": "Test", + "createdAt": "2024-01-15T10:30:00+05:30" + }); + assert_eq!(validator.validate(&datetime_with_offset, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); + + let invalid_datetime = json!({ + "$type": "app.bsky.feed.post", + "text": "Test", + "createdAt": "2024/01/15" + }); + assert!(matches!(validator.validate(&invalid_datetime, "app.bsky.feed.post"), Err(ValidationError::InvalidDatetime { .. }))); +} + +#[test] +fn test_record_key_validation() { + assert!(validate_record_key("3k2n5j2").is_ok()); + assert!(validate_record_key("valid-key").is_ok()); + assert!(validate_record_key("valid_key").is_ok()); + assert!(validate_record_key("valid.key").is_ok()); + assert!(validate_record_key("valid~key").is_ok()); + assert!(validate_record_key("self").is_ok()); + + assert!(matches!(validate_record_key(""), Err(ValidationError::InvalidRecord(_)))); + + assert!(validate_record_key(".").is_err()); + assert!(validate_record_key("..").is_err()); + + assert!(validate_record_key("invalid/key").is_err()); + assert!(validate_record_key("invalid key").is_err()); + assert!(validate_record_key("invalid@key").is_err()); + assert!(validate_record_key("invalid#key").is_err()); + + assert!(matches!(validate_record_key(&"k".repeat(513)), Err(ValidationError::InvalidRecord(_)))); + assert!(validate_record_key(&"k".repeat(512)).is_ok()); +} + +#[test] +fn test_collection_nsid_validation() { + assert!(validate_collection_nsid("app.bsky.feed.post").is_ok()); + assert!(validate_collection_nsid("com.atproto.repo.record").is_ok()); + assert!(validate_collection_nsid("a.b.c").is_ok()); + assert!(validate_collection_nsid("my-app.domain.record-type").is_ok()); + + assert!(matches!(validate_collection_nsid(""), Err(ValidationError::InvalidRecord(_)))); + + assert!(validate_collection_nsid("a").is_err()); + assert!(validate_collection_nsid("a.b").is_err()); + + assert!(validate_collection_nsid("a..b.c").is_err()); + assert!(validate_collection_nsid(".a.b.c").is_err()); + assert!(validate_collection_nsid("a.b.c.").is_err()); + + assert!(validate_collection_nsid("a.b.c/d").is_err()); + assert!(validate_collection_nsid("a.b.c_d").is_err()); + assert!(validate_collection_nsid("a.b.c@d").is_err()); } diff --git a/tests/security_fixes.rs b/tests/security_fixes.rs index 714c55d..7285081 100644 --- a/tests/security_fixes.rs +++ b/tests/security_fixes.rs @@ -4,145 +4,70 @@ use bspds::notifications::{SendError, is_valid_phone_number, sanitize_header_val use bspds::oauth::templates::{error_page, login_page, success_page}; #[test] -fn test_sanitize_header_value_removes_crlf() { +fn test_header_injection_sanitization() { let malicious = "Injected\r\nBcc: attacker@evil.com"; let sanitized = sanitize_header_value(malicious); - assert!(!sanitized.contains('\r'), "CR should be removed"); - assert!(!sanitized.contains('\n'), "LF should be removed"); - assert!( - sanitized.contains("Injected"), - "Original content should be preserved" - ); - assert!( - sanitized.contains("Bcc:"), - "Text after newline should be on same line (no header injection)" - ); -} + assert!(!sanitized.contains('\r') && !sanitized.contains('\n')); + assert!(sanitized.contains("Injected") && sanitized.contains("Bcc:")); -#[test] -fn test_sanitize_header_value_preserves_content() { let normal = "Normal Subject Line"; - let sanitized = sanitize_header_value(normal); - assert_eq!(sanitized, "Normal Subject Line"); -} + assert_eq!(sanitize_header_value(normal), "Normal Subject Line"); -#[test] -fn test_sanitize_header_value_trims_whitespace() { let padded = " Subject "; - let sanitized = sanitize_header_value(padded); - assert_eq!(sanitized, "Subject"); -} + assert_eq!(sanitize_header_value(padded), "Subject"); -#[test] -fn test_sanitize_header_value_handles_multiple_newlines() { - let input = "Line1\r\nLine2\nLine3\rLine4"; - let sanitized = sanitize_header_value(input); - assert!(!sanitized.contains('\r'), "CR should be removed"); - assert!(!sanitized.contains('\n'), "LF should be removed"); - assert!( - sanitized.contains("Line1"), - "Content before newlines preserved" - ); - assert!( - sanitized.contains("Line4"), - "Content after newlines preserved" - ); -} + let multi_newline = "Line1\r\nLine2\nLine3\rLine4"; + let sanitized = sanitize_header_value(multi_newline); + assert!(!sanitized.contains('\r') && !sanitized.contains('\n')); + assert!(sanitized.contains("Line1") && sanitized.contains("Line4")); -#[test] -fn test_email_header_injection_sanitization() { let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value"; let sanitized = sanitize_header_value(header_injection); - let lines: Vec<&str> = sanitized.split("\r\n").collect(); - assert_eq!(lines.len(), 1, "Should be a single line after sanitization"); - assert!( - sanitized.contains("Normal Subject"), - "Original content preserved" - ); - assert!( - sanitized.contains("Bcc:"), - "Content after CRLF preserved as same line text" - ); - assert!( - sanitized.contains("X-Injected:"), - "All content on same line" - ); + assert_eq!(sanitized.split("\r\n").count(), 1); + assert!(sanitized.contains("Normal Subject") && sanitized.contains("Bcc:") && sanitized.contains("X-Injected:")); + + let with_null = "client\0id"; + assert!(sanitize_header_value(with_null).contains("client")); + + let long_input = "x".repeat(10000); + assert!(!sanitize_header_value(&long_input).is_empty()); } #[test] -fn test_valid_phone_number_accepts_correct_format() { +fn test_phone_number_validation() { assert!(is_valid_phone_number("+1234567890")); assert!(is_valid_phone_number("+12025551234")); assert!(is_valid_phone_number("+442071234567")); assert!(is_valid_phone_number("+4915123456789")); assert!(is_valid_phone_number("+1")); -} -#[test] -fn test_valid_phone_number_rejects_missing_plus() { assert!(!is_valid_phone_number("1234567890")); assert!(!is_valid_phone_number("12025551234")); -} - -#[test] -fn test_valid_phone_number_rejects_empty() { assert!(!is_valid_phone_number("")); -} - -#[test] -fn test_valid_phone_number_rejects_just_plus() { assert!(!is_valid_phone_number("+")); -} - -#[test] -fn test_valid_phone_number_rejects_too_long() { assert!(!is_valid_phone_number("+12345678901234567890123")); -} -#[test] -fn test_valid_phone_number_rejects_letters() { assert!(!is_valid_phone_number("+abc123")); assert!(!is_valid_phone_number("+1234abc")); assert!(!is_valid_phone_number("+a")); -} -#[test] -fn test_valid_phone_number_rejects_spaces() { assert!(!is_valid_phone_number("+1234 5678")); assert!(!is_valid_phone_number("+ 1234567890")); assert!(!is_valid_phone_number("+1 ")); -} -#[test] -fn test_valid_phone_number_rejects_special_chars() { assert!(!is_valid_phone_number("+123-456-7890")); assert!(!is_valid_phone_number("+1(234)567890")); assert!(!is_valid_phone_number("+1.234.567.890")); -} -#[test] -fn test_signal_recipient_command_injection_blocked() { - let malicious_inputs = vec![ - "+123; rm -rf /", - "+123 && cat /etc/passwd", - "+123`id`", - "+123$(whoami)", - "+123|cat /etc/shadow", - "+123\n--help", - "+123\r\n--version", - "+123--help", - ]; - for input in malicious_inputs { - assert!( - !is_valid_phone_number(input), - "Malicious input '{}' should be rejected", - input - ); + for malicious in ["+123; rm -rf /", "+123 && cat /etc/passwd", "+123`id`", + "+123$(whoami)", "+123|cat /etc/shadow", "+123\n--help", + "+123\r\n--version", "+123--help"] { + assert!(!is_valid_phone_number(malicious), "Command injection '{}' should be rejected", malicious); } } #[test] -fn test_image_file_size_limit_enforced() { +fn test_image_file_size_limits() { let processor = ImageProcessor::new(); let oversized_data: Vec = vec![0u8; 11 * 1024 * 1024]; let result = processor.process(&oversized_data, "image/jpeg"); @@ -156,321 +81,109 @@ fn test_image_file_size_limit_enforced() { } Ok(_) => panic!("Should reject files over size limit"), } -} -#[test] -fn test_image_file_size_limit_configurable() { let processor = ImageProcessor::new().with_max_file_size(1024); let data: Vec = vec![0u8; 2048]; - let result = processor.process(&data, "image/jpeg"); - assert!(result.is_err(), "Should reject files over configured limit"); + assert!(processor.process(&data, "image/jpeg").is_err()); } #[test] -fn test_oauth_template_xss_escaping_client_id() { - let malicious_client_id = ""; - let html = login_page(malicious_client_id, None, None, "test-uri", None, None); - assert!(!html.contains("", None, None, "test-uri", None, None); + assert!(!html.contains(""), "test-uri", None, None); + assert!(!html.contains(""), None); + assert!(!html.contains("", Some("")); + assert!(!html.contains("")); + assert!(!html.contains(""; - let html = login_page( - "client123", - None, - Some(malicious_scope), - "test-uri", - None, - None, - ); - assert!( - !html.contains(""; - let html = login_page( - "client123", - None, - None, - "test-uri", - Some(malicious_error), - None, - ); - assert!( - !html.contains(""; - let malicious_desc = ""; - let html = error_page(malicious_error, Some(malicious_desc)); - assert!( - !html.contains(""; - let html = success_page(Some(malicious_name)); - assert!( - !html.contains("