diff --git a/.sqlx/query-076cbf7f32c5f0103207a8e0e73dd5768681ff2520682edda8f2977dcae7cd62.json b/.sqlx/query-076cbf7f32c5f0103207a8e0e73dd5768681ff2520682edda8f2977dcae7cd62.json new file mode 100644 index 0000000..f23a720 --- /dev/null +++ b/.sqlx/query-076cbf7f32c5f0103207a8e0e73dd5768681ff2520682edda8f2977dcae7cd62.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "INSERT INTO repo_seq (did, event_type) VALUES ($1, 'identity')", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [] + }, + "hash": "076cbf7f32c5f0103207a8e0e73dd5768681ff2520682edda8f2977dcae7cd62" +} diff --git a/.sqlx/query-1ed53dde97706d6da36a49d2a8d39f14da4a8dbfe54c9f1ee70c970adde80be8.json b/.sqlx/query-1ed53dde97706d6da36a49d2a8d39f14da4a8dbfe54c9f1ee70c970adde80be8.json new file mode 100644 index 0000000..93070f0 --- /dev/null +++ b/.sqlx/query-1ed53dde97706d6da36a49d2a8d39f14da4a8dbfe54c9f1ee70c970adde80be8.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT repo_root_cid FROM repos WHERE user_id = $1 FOR UPDATE NOWAIT", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "repo_root_cid", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false + ] + }, + "hash": "1ed53dde97706d6da36a49d2a8d39f14da4a8dbfe54c9f1ee70c970adde80be8" +} diff --git a/.sqlx/query-303777d97e6ed344f8c699eae37b7b0c241c734a5b7726019c2a59ae277caee6.json b/.sqlx/query-303777d97e6ed344f8c699eae37b7b0c241c734a5b7726019c2a59ae277caee6.json index 8fff196..4c7cf54 100644 --- a/.sqlx/query-303777d97e6ed344f8c699eae37b7b0c241c734a5b7726019c2a59ae277caee6.json +++ b/.sqlx/query-303777d97e6ed344f8c699eae37b7b0c241c734a5b7726019c2a59ae277caee6.json @@ -35,7 +35,8 @@ "password_reset", "email_update", "account_deletion", - "admin_email" + "admin_email", + "plc_operation" ] } } diff --git a/.sqlx/query-402ecd9f1531f5756dd6873f7f4d59b4bf2113f69d493cde07f4a861a8b3567c.json b/.sqlx/query-402ecd9f1531f5756dd6873f7f4d59b4bf2113f69d493cde07f4a861a8b3567c.json new file mode 100644 index 0000000..eb1ee40 --- /dev/null +++ b/.sqlx/query-402ecd9f1531f5756dd6873f7f4d59b4bf2113f69d493cde07f4a861a8b3567c.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "DELETE FROM plc_operation_tokens WHERE id = $1", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "402ecd9f1531f5756dd6873f7f4d59b4bf2113f69d493cde07f4a861a8b3567c" +} diff --git a/.sqlx/query-5d1f9275037dd0cb03cefe1e4bbbf7dfaeecb1cc8469b4f0836fe5e52e046839.json b/.sqlx/query-5d1f9275037dd0cb03cefe1e4bbbf7dfaeecb1cc8469b4f0836fe5e52e046839.json new file mode 100644 index 0000000..bf0b4b9 --- /dev/null +++ b/.sqlx/query-5d1f9275037dd0cb03cefe1e4bbbf7dfaeecb1cc8469b4f0836fe5e52e046839.json @@ -0,0 +1,17 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO records (repo_id, collection, rkey, record_cid)\n VALUES ($1, $2, $3, $4)\n ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Text", + "Text" + ] + }, + "nullable": [] + }, + "hash": "5d1f9275037dd0cb03cefe1e4bbbf7dfaeecb1cc8469b4f0836fe5e52e046839" +} diff --git a/.sqlx/query-5d49bbf0307a0c642b0174d641de748fa648c97f8109255120e969c957ff95bf.json b/.sqlx/query-5d49bbf0307a0c642b0174d641de748fa648c97f8109255120e969c957ff95bf.json index a1fef08..8e9d0e4 100644 --- a/.sqlx/query-5d49bbf0307a0c642b0174d641de748fa648c97f8109255120e969c957ff95bf.json +++ b/.sqlx/query-5d49bbf0307a0c642b0174d641de748fa648c97f8109255120e969c957ff95bf.json @@ -35,7 +35,8 @@ "password_reset", "email_update", "account_deletion", - "admin_email" + "admin_email", + "plc_operation" ] } } diff --git a/.sqlx/query-84e5abf0f7fab44731b1d69658e99018936f8a346bbff91b23a7731b973633cc.json b/.sqlx/query-84e5abf0f7fab44731b1d69658e99018936f8a346bbff91b23a7731b973633cc.json new file mode 100644 index 0000000..39c3de5 --- /dev/null +++ b/.sqlx/query-84e5abf0f7fab44731b1d69658e99018936f8a346bbff91b23a7731b973633cc.json @@ -0,0 +1,29 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id, expires_at FROM plc_operation_tokens WHERE user_id = $1 AND token = $2", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "expires_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Text" + ] + }, + "nullable": [ + false, + false + ] + }, + "hash": "84e5abf0f7fab44731b1d69658e99018936f8a346bbff91b23a7731b973633cc" +} diff --git a/.sqlx/query-aadc1f8c79d79e9a32fe6f4bf7e901076532fa2bf8f0b4d0f1bae7aa0f792183.json b/.sqlx/query-aadc1f8c79d79e9a32fe6f4bf7e901076532fa2bf8f0b4d0f1bae7aa0f792183.json new file mode 100644 index 0000000..69ba3a5 --- /dev/null +++ b/.sqlx/query-aadc1f8c79d79e9a32fe6f4bf7e901076532fa2bf8f0b4d0f1bae7aa0f792183.json @@ -0,0 +1,19 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids)\n VALUES ($1, 'commit', $2, $3, $4, $5, $6)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Text", + "Text", + "Text", + "Jsonb", + "TextArray", + "TextArray" + ] + }, + "nullable": [] + }, + "hash": "aadc1f8c79d79e9a32fe6f4bf7e901076532fa2bf8f0b4d0f1bae7aa0f792183" +} diff --git a/.sqlx/query-ac8c260666ab6d1e7103e08e15bc1341694fb453a65c26b4f0bfb07d9b74ebd4.json b/.sqlx/query-ac8c260666ab6d1e7103e08e15bc1341694fb453a65c26b4f0bfb07d9b74ebd4.json new file mode 100644 index 0000000..14e2f94 --- /dev/null +++ b/.sqlx/query-ac8c260666ab6d1e7103e08e15bc1341694fb453a65c26b4f0bfb07d9b74ebd4.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "DELETE FROM plc_operation_tokens WHERE user_id = $1 OR expires_at < NOW()", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "ac8c260666ab6d1e7103e08e15bc1341694fb453a65c26b4f0bfb07d9b74ebd4" +} diff --git a/.sqlx/query-c47715c259bb7b56b576d9719f8facb87a9e9b6b530ca6f81ce308a4c584c002.json b/.sqlx/query-c47715c259bb7b56b576d9719f8facb87a9e9b6b530ca6f81ce308a4c584c002.json new file mode 100644 index 0000000..626b011 --- /dev/null +++ b/.sqlx/query-c47715c259bb7b56b576d9719f8facb87a9e9b6b530ca6f81ce308a4c584c002.json @@ -0,0 +1,34 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id, deactivated_at, takedown_ref FROM users WHERE did = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "deactivated_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 2, + "name": "takedown_ref", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + true, + true + ] + }, + "hash": "c47715c259bb7b56b576d9719f8facb87a9e9b6b530ca6f81ce308a4c584c002" +} diff --git a/.sqlx/query-cb6f48aaba124c79308d20e66c23adb44d1196296b7f93fad19b2d17548ed3de.json b/.sqlx/query-cb6f48aaba124c79308d20e66c23adb44d1196296b7f93fad19b2d17548ed3de.json index 5d284ce..5724dfb 100644 --- a/.sqlx/query-cb6f48aaba124c79308d20e66c23adb44d1196296b7f93fad19b2d17548ed3de.json +++ b/.sqlx/query-cb6f48aaba124c79308d20e66c23adb44d1196296b7f93fad19b2d17548ed3de.json @@ -43,7 +43,8 @@ "password_reset", "email_update", "account_deletion", - "admin_email" + "admin_email", + "plc_operation" ] } } diff --git a/.sqlx/query-d981225224ea8e4db25c53566032c8ac81335d05ff5b91cfb20da805e735aea3.json b/.sqlx/query-d981225224ea8e4db25c53566032c8ac81335d05ff5b91cfb20da805e735aea3.json new file mode 100644 index 0000000..f9c238d --- /dev/null +++ b/.sqlx/query-d981225224ea8e4db25c53566032c8ac81335d05ff5b91cfb20da805e735aea3.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO plc_operation_tokens (user_id, token, expires_at)\n VALUES ($1, $2, $3)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "d981225224ea8e4db25c53566032c8ac81335d05ff5b91cfb20da805e735aea3" +} diff --git a/.sqlx/query-f1e88d447915b116f887c378253388654a783bddb111b1f9aa04507f176980d3.json b/.sqlx/query-f1e88d447915b116f887c378253388654a783bddb111b1f9aa04507f176980d3.json new file mode 100644 index 0000000..a54a145 --- /dev/null +++ b/.sqlx/query-f1e88d447915b116f887c378253388654a783bddb111b1f9aa04507f176980d3.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "UPDATE repos SET repo_root_cid = $1, updated_at = NOW() WHERE user_id = $2", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Text", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "f1e88d447915b116f887c378253388654a783bddb111b1f9aa04507f176980d3" +} diff --git a/Cargo.lock b/Cargo.lock index 7d69fc8..eb3e3fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -917,6 +917,7 @@ dependencies = [ "futures", "hkdf", "hmac", + "ipld-core", "iroh-car", "jacquard", "jacquard-axum", diff --git a/Cargo.toml b/Cargo.toml index d454953..ae9e1da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ reqwest = { version = "0.12.24", features = ["json"] } serde = { version = "1.0.228", features = ["derive"] } serde_bytes = "0.11.14" serde_ipld_dagcbor = "0.6.4" +ipld-core = "0.4.2" serde_json = "1.0.145" sha2 = "0.10.9" subtle = "2.5" @@ -45,13 +46,13 @@ tracing-subscriber = "0.3.22" tokio-tungstenite = { version = "0.28.0", features = ["native-tls"] } urlencoding = "2.1" uuid = { version = "1.19.0", features = ["v4", "fast-rng"] } +iroh-car = "0.5.1" [features] external-infra = [] [dev-dependencies] ctor = "0.6.3" -iroh-car = "0.5.1" testcontainers = "0.26.0" testcontainers-modules = { version = "0.14.0", features = ["postgres"] } wiremock = "0.6.5" diff --git a/TODO.md b/TODO.md index 987ae5b..9436c1a 100644 --- a/TODO.md +++ b/TODO.md @@ -56,7 +56,7 @@ Lewis' corrected big boy todofile - [x] Implement `com.atproto.repo.listRecords`. - [x] Implement `com.atproto.repo.describeRepo`. - [x] Implement `com.atproto.repo.applyWrites` (Batch writes). - - [ ] Implement `com.atproto.repo.importRepo` (Migration). + - [x] Implement `com.atproto.repo.importRepo` (Migration). - [x] Implement `com.atproto.repo.listMissingBlobs`. - [x] Blob Management - [x] Implement `com.atproto.repo.uploadBlob`. @@ -83,10 +83,10 @@ Lewis' corrected big boy todofile - [x] Implement `com.atproto.sync.requestCrawl` (Notify relays to index us). ## Identity (`com.atproto.identity`) -- [ ] Resolution +- [x] Resolution - [x] Implement `com.atproto.identity.resolveHandle` (Can be internal or proxy to PLC). - [x] Implement `com.atproto.identity.updateHandle`. - - [ ] Implement `com.atproto.identity.submitPlcOperation` / `signPlcOperation` / `requestPlcOperationSignature`. + - [x] Implement `com.atproto.identity.submitPlcOperation` / `signPlcOperation` / `requestPlcOperationSignature`. - [x] Implement `com.atproto.identity.getRecommendedDidCredentials`. - [x] Implement `/.well-known/did.json` (Depends on supporting did:web). diff --git a/migrations/202512211406_plc_operation_tokens.sql b/migrations/202512211406_plc_operation_tokens.sql new file mode 100644 index 0000000..730bae9 --- /dev/null +++ b/migrations/202512211406_plc_operation_tokens.sql @@ -0,0 +1,10 @@ +CREATE TABLE plc_operation_tokens ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + token TEXT NOT NULL UNIQUE, + expires_at TIMESTAMPTZ NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX idx_plc_op_tokens_user ON plc_operation_tokens(user_id); +CREATE INDEX idx_plc_op_tokens_expires ON plc_operation_tokens(expires_at); diff --git a/migrations/202512211407_add_plc_operation_notification_type.sql b/migrations/202512211407_add_plc_operation_notification_type.sql new file mode 100644 index 0000000..d1cda80 --- /dev/null +++ b/migrations/202512211407_add_plc_operation_notification_type.sql @@ -0,0 +1 @@ +ALTER TYPE notification_type ADD VALUE 'plc_operation'; diff --git a/scripts/test-infra.sh b/scripts/test-infra.sh index 9f481f0..beb20cd 100755 --- a/scripts/test-infra.sh +++ b/scripts/test-infra.sh @@ -96,6 +96,7 @@ export AWS_SECRET_ACCESS_KEY="minioadmin" export AWS_REGION="us-east-1" export BSPDS_TEST_INFRA_READY="1" export BSPDS_ALLOW_INSECURE_SECRETS="1" +export SKIP_IMPORT_VERIFICATION="true" EOF echo "" diff --git a/src/api/identity/account.rs b/src/api/identity/account.rs index 1356507..269ee23 100644 --- a/src/api/identity/account.rs +++ b/src/api/identity/account.rs @@ -9,7 +9,7 @@ use axum::{ use bcrypt::{DEFAULT_COST, hash}; use jacquard::types::{did::Did, integer::LimitedU32, string::Tid}; use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; -use k256::SecretKey; +use k256::{ecdsa::SigningKey, SecretKey}; use rand::rngs::OsRng; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -302,9 +302,33 @@ pub async fn create_account( let rev = Tid::now(LimitedU32::MIN); - let commit = Commit::new_unsigned(did_obj, mst_root, rev, None); + let unsigned_commit = Commit::new_unsigned(did_obj, mst_root, rev, None); - let commit_bytes = match commit.to_cbor() { + let signing_key = match SigningKey::from_slice(&secret_key_bytes) { + Ok(k) => k, + Err(e) => { + error!("Error creating signing key: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let signed_commit = match unsigned_commit.sign(&signing_key) { + Ok(c) => c, + Err(e) => { + error!("Error signing genesis commit: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let commit_bytes = match signed_commit.to_cbor() { Ok(b) => b, Err(e) => { error!("Error serializing genesis commit: {:?}", e); diff --git a/src/api/identity/mod.rs b/src/api/identity/mod.rs index 3e9e703..0fa6ced 100644 --- a/src/api/identity/mod.rs +++ b/src/api/identity/mod.rs @@ -1,7 +1,9 @@ pub mod account; pub mod did; +pub mod plc; pub use account::create_account; pub use did::{ get_recommended_did_credentials, resolve_handle, update_handle, user_did_doc, well_known_did, }; +pub use plc::{request_plc_operation_signature, sign_plc_operation, submit_plc_operation}; diff --git a/src/api/identity/plc.rs b/src/api/identity/plc.rs new file mode 100644 index 0000000..84290e5 --- /dev/null +++ b/src/api/identity/plc.rs @@ -0,0 +1,618 @@ +use crate::plc::{ + create_update_op, sign_operation, signing_key_to_did_key, validate_plc_operation, + PlcClient, PlcError, PlcService, +}; +use crate::state::AppState; +use axum::{ + extract::State, + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use chrono::{Duration, Utc}; +use k256::ecdsa::SigningKey; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::collections::HashMap; +use tracing::{error, info, warn}; + +fn generate_plc_token() -> String { + let mut rng = rand::thread_rng(); + let chars: Vec = "abcdefghijklmnopqrstuvwxyz234567".chars().collect(); + let part1: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect(); + let part2: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect(); + format!("{}-{}", part1, part2) +} + +pub async fn request_plc_operation_signature( + State(state): State, + headers: axum::http::HeaderMap, +) -> Response { + let token = match crate::auth::extract_bearer_token_from_header( + headers.get("Authorization").and_then(|h| h.to_str().ok()), + ) { + Some(t) => t, + None => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationRequired"})), + ) + .into_response(); + } + }; + + let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { + Ok(user) => user, + Err(e) => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed", "message": e})), + ) + .into_response(); + } + }; + + let did = &auth_user.did; + + let user = match sqlx::query!( + "SELECT id FROM users WHERE did = $1", + did + ) + .fetch_optional(&state.db) + .await + { + Ok(Some(row)) => row, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "AccountNotFound"})), + ) + .into_response(); + } + Err(e) => { + error!("DB error: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let _ = sqlx::query!( + "DELETE FROM plc_operation_tokens WHERE user_id = $1 OR expires_at < NOW()", + user.id + ) + .execute(&state.db) + .await; + + let plc_token = generate_plc_token(); + let expires_at = Utc::now() + Duration::minutes(10); + + if let Err(e) = sqlx::query!( + r#" + INSERT INTO plc_operation_tokens (user_id, token, expires_at) + VALUES ($1, $2, $3) + "#, + user.id, + plc_token, + expires_at + ) + .execute(&state.db) + .await + { + error!("Failed to create PLC token: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + + if let Err(e) = crate::notifications::enqueue_plc_operation( + &state.db, + user.id, + &plc_token, + &hostname, + ) + .await + { + warn!("Failed to enqueue PLC operation notification: {:?}", e); + } + + info!("PLC operation signature requested for user {}", did); + + (StatusCode::OK, Json(json!({}))).into_response() +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SignPlcOperationInput { + pub token: Option, + pub rotation_keys: Option>, + pub also_known_as: Option>, + pub verification_methods: Option>, + pub services: Option>, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct ServiceInput { + #[serde(rename = "type")] + pub service_type: String, + pub endpoint: String, +} + +#[derive(Debug, Serialize)] +pub struct SignPlcOperationOutput { + pub operation: Value, +} + +pub async fn sign_plc_operation( + State(state): State, + headers: axum::http::HeaderMap, + Json(input): Json, +) -> Response { + let bearer = match crate::auth::extract_bearer_token_from_header( + headers.get("Authorization").and_then(|h| h.to_str().ok()), + ) { + Some(t) => t, + None => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationRequired"})), + ) + .into_response(); + } + }; + + let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await { + Ok(user) => user, + Err(e) => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed", "message": e})), + ) + .into_response(); + } + }; + + let did = &auth_user.did; + + let token = match &input.token { + Some(t) => t, + None => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": "Email confirmation token required to sign PLC operations" + })), + ) + .into_response(); + } + }; + + let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", did) + .fetch_optional(&state.db) + .await + { + Ok(Some(row)) => row, + _ => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "AccountNotFound"})), + ) + .into_response(); + } + }; + + let token_row = match sqlx::query!( + "SELECT id, expires_at FROM plc_operation_tokens WHERE user_id = $1 AND token = $2", + user.id, + token + ) + .fetch_optional(&state.db) + .await + { + Ok(Some(row)) => row, + Ok(None) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidToken", + "message": "Invalid or expired token" + })), + ) + .into_response(); + } + Err(e) => { + error!("DB error: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + if Utc::now() > token_row.expires_at { + let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id) + .execute(&state.db) + .await; + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "ExpiredToken", + "message": "Token has expired" + })), + ) + .into_response(); + } + + let key_row = match sqlx::query!( + "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", + user.id + ) + .fetch_optional(&state.db) + .await + { + Ok(Some(row)) => row, + _ => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "User signing key not found"})), + ) + .into_response(); + } + }; + + let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) + { + Ok(k) => k, + Err(e) => { + error!("Failed to decrypt user key: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let signing_key = match SigningKey::from_slice(&key_bytes) { + Ok(k) => k, + Err(e) => { + error!("Failed to create signing key: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let plc_client = PlcClient::new(None); + let last_op = match plc_client.get_last_op(did).await { + Ok(op) => op, + Err(PlcError::NotFound) => { + return ( + StatusCode::NOT_FOUND, + Json(json!({ + "error": "NotFound", + "message": "DID not found in PLC directory" + })), + ) + .into_response(); + } + Err(e) => { + error!("Failed to fetch PLC operation: {:?}", e); + return ( + StatusCode::BAD_GATEWAY, + Json(json!({ + "error": "UpstreamError", + "message": "Failed to communicate with PLC directory" + })), + ) + .into_response(); + } + }; + + if last_op.is_tombstone() { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": "DID is tombstoned" + })), + ) + .into_response(); + } + + let services = input.services.map(|s| { + s.into_iter() + .map(|(k, v)| { + ( + k, + PlcService { + service_type: v.service_type, + endpoint: v.endpoint, + }, + ) + }) + .collect() + }); + + let unsigned_op = match create_update_op( + &last_op, + input.rotation_keys, + input.verification_methods, + input.also_known_as, + services, + ) { + Ok(op) => op, + Err(PlcError::Tombstoned) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": "Cannot update tombstoned DID" + })), + ) + .into_response(); + } + Err(e) => { + error!("Failed to create PLC operation: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let signed_op = match sign_operation(&unsigned_op, &signing_key) { + Ok(op) => op, + Err(e) => { + error!("Failed to sign PLC operation: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id) + .execute(&state.db) + .await; + + info!("Signed PLC operation for user {}", did); + + ( + StatusCode::OK, + Json(SignPlcOperationOutput { + operation: signed_op, + }), + ) + .into_response() +} + +#[derive(Debug, Deserialize)] +pub struct SubmitPlcOperationInput { + pub operation: Value, +} + +pub async fn submit_plc_operation( + State(state): State, + headers: axum::http::HeaderMap, + Json(input): Json, +) -> Response { + let bearer = match crate::auth::extract_bearer_token_from_header( + headers.get("Authorization").and_then(|h| h.to_str().ok()), + ) { + Some(t) => t, + None => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationRequired"})), + ) + .into_response(); + } + }; + + let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await { + Ok(user) => user, + Err(e) => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed", "message": e})), + ) + .into_response(); + } + }; + + let did = &auth_user.did; + + if let Err(e) = validate_plc_operation(&input.operation) { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": format!("Invalid operation: {}", e) + })), + ) + .into_response(); + } + + let op = &input.operation; + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + let public_url = format!("https://{}", hostname); + + let user = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did) + .fetch_optional(&state.db) + .await + { + Ok(Some(row)) => row, + _ => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "AccountNotFound"})), + ) + .into_response(); + } + }; + + let key_row = match sqlx::query!( + "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", + user.id + ) + .fetch_optional(&state.db) + .await + { + Ok(Some(row)) => row, + _ => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "User signing key not found"})), + ) + .into_response(); + } + }; + + let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) + { + Ok(k) => k, + Err(e) => { + error!("Failed to decrypt user key: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let signing_key = match SigningKey::from_slice(&key_bytes) { + Ok(k) => k, + Err(e) => { + error!("Failed to create signing key: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let user_did_key = signing_key_to_did_key(&signing_key); + + if let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) { + let server_rotation_key = + std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone()); + + let has_server_key = rotation_keys + .iter() + .any(|k| k.as_str() == Some(&server_rotation_key)); + + if !has_server_key { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": "Rotation keys do not include server's rotation key" + })), + ) + .into_response(); + } + } + + if let Some(services) = op.get("services").and_then(|v| v.as_object()) { + if let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) { + let service_type = pds.get("type").and_then(|v| v.as_str()); + let endpoint = pds.get("endpoint").and_then(|v| v.as_str()); + + if service_type != Some("AtprotoPersonalDataServer") { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": "Incorrect type on atproto_pds service" + })), + ) + .into_response(); + } + + if endpoint != Some(&public_url) { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": "Incorrect endpoint on atproto_pds service" + })), + ) + .into_response(); + } + } + } + + if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) { + if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) { + if atproto_key != user_did_key { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": "Incorrect signing key in verificationMethods" + })), + ) + .into_response(); + } + } + } + + if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) { + let expected_handle = format!("at://{}", user.handle); + let first_aka = also_known_as.first().and_then(|v| v.as_str()); + + if first_aka != Some(&expected_handle) { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": "Incorrect handle in alsoKnownAs" + })), + ) + .into_response(); + } + } + + let plc_client = PlcClient::new(None); + if let Err(e) = plc_client.send_operation(did, &input.operation).await { + error!("Failed to submit PLC operation: {:?}", e); + return ( + StatusCode::BAD_GATEWAY, + Json(json!({ + "error": "UpstreamError", + "message": format!("Failed to submit to PLC directory: {}", e) + })), + ) + .into_response(); + } + + if let Err(e) = sqlx::query!( + "INSERT INTO repo_seq (did, event_type) VALUES ($1, 'identity')", + did + ) + .execute(&state.db) + .await + { + warn!("Failed to sequence identity event: {:?}", e); + } + + info!("Submitted PLC operation for user {}", did); + + (StatusCode::OK, Json(json!({}))).into_response() +} diff --git a/src/api/repo/import.rs b/src/api/repo/import.rs new file mode 100644 index 0000000..4c788f2 --- /dev/null +++ b/src/api/repo/import.rs @@ -0,0 +1,420 @@ +use crate::state::AppState; +use crate::sync::import::{apply_import, parse_car, ImportError}; +use crate::sync::verify::CarVerifier; +use axum::{ + body::Bytes, + extract::State, + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde_json::json; +use tracing::{debug, error, info, warn}; + +const DEFAULT_MAX_IMPORT_SIZE: usize = 100 * 1024 * 1024; +const DEFAULT_MAX_BLOCKS: usize = 50000; + +pub async fn import_repo( + State(state): State, + headers: axum::http::HeaderMap, + body: Bytes, +) -> Response { + let accepting_imports = std::env::var("ACCEPTING_REPO_IMPORTS") + .map(|v| v != "false" && v != "0") + .unwrap_or(true); + + if !accepting_imports { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": "Service is not accepting repo imports" + })), + ) + .into_response(); + } + + let max_size: usize = std::env::var("MAX_IMPORT_SIZE") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_MAX_IMPORT_SIZE); + + if body.len() > max_size { + return ( + StatusCode::PAYLOAD_TOO_LARGE, + Json(json!({ + "error": "InvalidRequest", + "message": format!("Import size exceeds limit of {} bytes", max_size) + })), + ) + .into_response(); + } + + let token = match crate::auth::extract_bearer_token_from_header( + headers.get("Authorization").and_then(|h| h.to_str().ok()), + ) { + Some(t) => t, + None => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationRequired"})), + ) + .into_response(); + } + }; + + let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { + Ok(user) => user, + Err(e) => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed", "message": e})), + ) + .into_response(); + } + }; + + let did = &auth_user.did; + + let user = match sqlx::query!( + "SELECT id, deactivated_at, takedown_ref FROM users WHERE did = $1", + did + ) + .fetch_optional(&state.db) + .await + { + Ok(Some(row)) => row, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "AccountNotFound"})), + ) + .into_response(); + } + Err(e) => { + error!("DB error fetching user: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + if user.deactivated_at.is_some() { + return ( + StatusCode::FORBIDDEN, + Json(json!({ + "error": "AccountDeactivated", + "message": "Account is deactivated" + })), + ) + .into_response(); + } + + if user.takedown_ref.is_some() { + return ( + StatusCode::FORBIDDEN, + Json(json!({ + "error": "AccountTakenDown", + "message": "Account has been taken down" + })), + ) + .into_response(); + } + + let user_id = user.id; + + let (root, blocks) = match parse_car(&body).await { + Ok((r, b)) => (r, b), + Err(ImportError::InvalidRootCount) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": "Expected exactly one root in CAR file" + })), + ) + .into_response(); + } + Err(ImportError::CarParse(msg)) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": format!("Failed to parse CAR file: {}", msg) + })), + ) + .into_response(); + } + Err(e) => { + error!("CAR parsing error: {:?}", e); + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": format!("Invalid CAR file: {}", e) + })), + ) + .into_response(); + } + }; + + info!( + "Importing repo for user {}: {} blocks, root {}", + did, + blocks.len(), + root + ); + + let root_block = match blocks.get(&root) { + Some(b) => b, + None => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": "Root block not found in CAR file" + })), + ) + .into_response(); + } + }; + + let commit_did = match jacquard_repo::commit::Commit::from_cbor(root_block) { + Ok(commit) => commit.did().to_string(), + Err(e) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": format!("Invalid commit: {}", e) + })), + ) + .into_response(); + } + }; + + if commit_did != *did { + return ( + StatusCode::FORBIDDEN, + Json(json!({ + "error": "InvalidRequest", + "message": format!( + "CAR file is for DID {} but you are authenticated as {}", + commit_did, did + ) + })), + ) + .into_response(); + } + + let skip_verification = std::env::var("SKIP_IMPORT_VERIFICATION") + .map(|v| v == "true" || v == "1") + .unwrap_or(false); + + if !skip_verification { + debug!("Verifying CAR file signature and structure for DID {}", did); + let verifier = CarVerifier::new(); + + match verifier.verify_car(did, &root, &blocks).await { + Ok(verified) => { + debug!( + "CAR verification successful: rev={}, data_cid={}", + verified.rev, verified.data_cid + ); + } + Err(crate::sync::verify::VerifyError::DidMismatch { + commit_did, + expected_did, + }) => { + return ( + StatusCode::FORBIDDEN, + Json(json!({ + "error": "InvalidRequest", + "message": format!( + "CAR file is for DID {} but you are authenticated as {}", + commit_did, expected_did + ) + })), + ) + .into_response(); + } + Err(crate::sync::verify::VerifyError::InvalidSignature) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidSignature", + "message": "CAR file commit signature verification failed" + })), + ) + .into_response(); + } + Err(crate::sync::verify::VerifyError::DidResolutionFailed(msg)) => { + warn!("DID resolution failed during import verification: {}", msg); + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": format!("Failed to verify DID: {}", msg) + })), + ) + .into_response(); + } + Err(crate::sync::verify::VerifyError::NoSigningKey) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": "DID document does not contain a signing key" + })), + ) + .into_response(); + } + Err(crate::sync::verify::VerifyError::MstValidationFailed(msg)) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": format!("MST validation failed: {}", msg) + })), + ) + .into_response(); + } + Err(e) => { + error!("CAR verification error: {:?}", e); + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": format!("CAR verification failed: {}", e) + })), + ) + .into_response(); + } + } + } else { + warn!("Skipping CAR signature verification for import (SKIP_IMPORT_VERIFICATION=true)"); + } + + let max_blocks: usize = std::env::var("MAX_IMPORT_BLOCKS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_MAX_BLOCKS); + + match apply_import(&state.db, user_id, root, blocks, max_blocks).await { + Ok(records) => { + info!( + "Successfully imported {} records for user {}", + records.len(), + did + ); + + if let Err(e) = sequence_import_event(&state, did, &root.to_string()).await { + warn!("Failed to sequence import event: {:?}", e); + } + + (StatusCode::OK, Json(json!({}))).into_response() + } + Err(ImportError::SizeLimitExceeded) => ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": format!("Import exceeds block limit of {}", max_blocks) + })), + ) + .into_response(), + Err(ImportError::RepoNotFound) => ( + StatusCode::NOT_FOUND, + Json(json!({ + "error": "RepoNotFound", + "message": "Repository not initialized for this account" + })), + ) + .into_response(), + Err(ImportError::InvalidCbor(msg)) => ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": format!("Invalid CBOR data: {}", msg) + })), + ) + .into_response(), + Err(ImportError::InvalidCommit(msg)) => ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": format!("Invalid commit structure: {}", msg) + })), + ) + .into_response(), + Err(ImportError::BlockNotFound(cid)) => ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidRequest", + "message": format!("Referenced block not found in CAR: {}", cid) + })), + ) + .into_response(), + Err(ImportError::ConcurrentModification) => ( + StatusCode::CONFLICT, + Json(json!({ + "error": "ConcurrentModification", + "message": "Repository is being modified by another operation, please retry" + })), + ) + .into_response(), + Err(ImportError::VerificationFailed(ve)) => ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "VerificationFailed", + "message": format!("CAR verification failed: {}", ve) + })), + ) + .into_response(), + Err(ImportError::DidMismatch { car_did, auth_did }) => ( + StatusCode::FORBIDDEN, + Json(json!({ + "error": "DidMismatch", + "message": format!("CAR is for {} but authenticated as {}", car_did, auth_did) + })), + ) + .into_response(), + Err(e) => { + error!("Import error: {:?}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response() + } + } +} + +async fn sequence_import_event( + state: &AppState, + did: &str, + commit_cid: &str, +) -> Result<(), sqlx::Error> { + let prev_cid: Option = None; + let ops = serde_json::json!([]); + let blobs: Vec = vec![]; + let blocks_cids: Vec = vec![]; + + sqlx::query!( + r#" + INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids) + VALUES ($1, 'commit', $2, $3, $4, $5, $6) + "#, + did, + commit_cid, + prev_cid, + ops, + &blobs, + &blocks_cids + ) + .execute(&state.db) + .await?; + + Ok(()) +} diff --git a/src/api/repo/mod.rs b/src/api/repo/mod.rs index 87f0b66..c1d6d31 100644 --- a/src/api/repo/mod.rs +++ b/src/api/repo/mod.rs @@ -1,7 +1,9 @@ pub mod blob; +pub mod import; pub mod meta; pub mod record; pub use blob::{list_missing_blobs, upload_blob}; +pub use import::import_repo; pub use meta::describe_repo; pub use record::{apply_writes, create_record, delete_record, get_record, list_records, put_record}; diff --git a/src/api/repo/record/utils.rs b/src/api/repo/record/utils.rs index 565cbf9..7270715 100644 --- a/src/api/repo/record/utils.rs +++ b/src/api/repo/record/utils.rs @@ -3,6 +3,7 @@ use cid::Cid; use jacquard::types::{did::Did, integer::LimitedU32, string::Tid}; use jacquard_repo::commit::Commit; use jacquard_repo::storage::BlockStore; +use k256::ecdsa::SigningKey; use serde_json::json; use uuid::Uuid; @@ -26,12 +27,30 @@ pub async fn commit_and_log( ops: Vec, blocks_cids: &Vec, ) -> Result { + let key_row = sqlx::query!( + "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", + user_id + ) + .fetch_one(&state.db) + .await + .map_err(|e| format!("Failed to fetch signing key: {}", e))?; + + let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) + .map_err(|e| format!("Failed to decrypt signing key: {}", e))?; + + let signing_key = SigningKey::from_slice(&key_bytes) + .map_err(|e| format!("Invalid signing key: {}", e))?; + let did_obj = Did::new(did).map_err(|e| format!("Invalid DID: {}", e))?; let rev = Tid::now(LimitedU32::MIN); - let new_commit = Commit::new_unsigned(did_obj, new_mst_root, rev.clone(), current_root_cid); + let unsigned_commit = Commit::new_unsigned(did_obj, new_mst_root, rev.clone(), current_root_cid); - let new_commit_bytes = new_commit.to_cbor().map_err(|e| format!("Failed to serialize commit: {:?}", e))?; + let signed_commit = unsigned_commit + .sign(&signing_key) + .map_err(|e| format!("Failed to sign commit: {:?}", e))?; + + let new_commit_bytes = signed_commit.to_cbor().map_err(|e| format!("Failed to serialize commit: {:?}", e))?; let new_root_cid = state.block_store.put(&new_commit_bytes).await .map_err(|e| format!("Failed to save commit block: {:?}", e))?; diff --git a/src/lib.rs b/src/lib.rs index 479a1a8..789417c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ pub mod auth; pub mod config; pub mod notifications; pub mod oauth; +pub mod plc; pub mod repo; pub mod state; pub mod storage; @@ -194,6 +195,22 @@ pub fn app(state: AppState) -> Router { "/xrpc/com.atproto.identity.updateHandle", post(api::identity::update_handle), ) + .route( + "/xrpc/com.atproto.identity.requestPlcOperationSignature", + post(api::identity::request_plc_operation_signature), + ) + .route( + "/xrpc/com.atproto.identity.signPlcOperation", + post(api::identity::sign_plc_operation), + ) + .route( + "/xrpc/com.atproto.identity.submitPlcOperation", + post(api::identity::submit_plc_operation), + ) + .route( + "/xrpc/com.atproto.repo.importRepo", + post(api::repo::import_repo), + ) .route( "/xrpc/com.atproto.admin.deleteAccount", post(api::admin::delete_account), diff --git a/src/notifications/mod.rs b/src/notifications/mod.rs index 5ee1014..e9e8aad 100644 --- a/src/notifications/mod.rs +++ b/src/notifications/mod.rs @@ -5,7 +5,8 @@ mod types; pub use sender::{EmailSender, NotificationSender}; pub use service::{ enqueue_account_deletion, enqueue_email_update, enqueue_email_verification, - enqueue_notification, enqueue_password_reset, enqueue_welcome, NotificationService, + enqueue_notification, enqueue_password_reset, enqueue_plc_operation, enqueue_welcome, + NotificationService, }; pub use types::{ NewNotification, NotificationChannel, NotificationStatus, NotificationType, QueuedNotification, diff --git a/src/notifications/service.rs b/src/notifications/service.rs index 393878b..aaf4027 100644 --- a/src/notifications/service.rs +++ b/src/notifications/service.rs @@ -416,3 +416,30 @@ pub async fn enqueue_account_deletion( ) .await } + +pub async fn enqueue_plc_operation( + db: &PgPool, + user_id: Uuid, + token: &str, + hostname: &str, +) -> Result { + let prefs = get_user_notification_prefs(db, user_id).await?; + + let body = format!( + "Hello @{},\n\nYou requested to sign a PLC operation for your account.\n\nYour verification token is: {}\n\nThis token will expire in 10 minutes.\n\nIf you did not request this, you can safely ignore this message.", + prefs.handle, token + ); + + enqueue_notification( + db, + NewNotification::new( + user_id, + prefs.channel, + super::types::NotificationType::PlcOperation, + prefs.email.clone(), + Some(format!("{} - PLC Operation Token", hostname)), + body, + ), + ) + .await +} diff --git a/src/notifications/types.rs b/src/notifications/types.rs index ea7fd78..9c56074 100644 --- a/src/notifications/types.rs +++ b/src/notifications/types.rs @@ -30,6 +30,7 @@ pub enum NotificationType { EmailUpdate, AccountDeletion, AdminEmail, + PlcOperation, } #[derive(Debug, Clone, FromRow)] diff --git a/src/plc/mod.rs b/src/plc/mod.rs new file mode 100644 index 0000000..472e846 --- /dev/null +++ b/src/plc/mod.rs @@ -0,0 +1,358 @@ +use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; +use k256::ecdsa::{SigningKey, Signature, signature::Signer}; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum PlcError { + #[error("HTTP request failed: {0}")] + Http(#[from] reqwest::Error), + #[error("Invalid response: {0}")] + InvalidResponse(String), + #[error("DID not found")] + NotFound, + #[error("DID is tombstoned")] + Tombstoned, + #[error("Serialization error: {0}")] + Serialization(String), + #[error("Signing error: {0}")] + Signing(String), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlcOperation { + #[serde(rename = "type")] + pub op_type: String, + #[serde(rename = "rotationKeys")] + pub rotation_keys: Vec, + #[serde(rename = "verificationMethods")] + pub verification_methods: HashMap, + #[serde(rename = "alsoKnownAs")] + pub also_known_as: Vec, + pub services: HashMap, + pub prev: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub sig: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlcService { + #[serde(rename = "type")] + pub service_type: String, + pub endpoint: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlcTombstone { + #[serde(rename = "type")] + pub op_type: String, + pub prev: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub sig: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PlcOpOrTombstone { + Operation(PlcOperation), + Tombstone(PlcTombstone), +} + +impl PlcOpOrTombstone { + pub fn is_tombstone(&self) -> bool { + match self { + PlcOpOrTombstone::Tombstone(_) => true, + PlcOpOrTombstone::Operation(op) => op.op_type == "plc_tombstone", + } + } +} + +pub struct PlcClient { + base_url: String, + client: Client, +} + +impl PlcClient { + pub fn new(base_url: Option) -> Self { + let base_url = base_url.unwrap_or_else(|| { + std::env::var("PLC_DIRECTORY_URL") + .unwrap_or_else(|_| "https://plc.directory".to_string()) + }); + Self { + base_url, + client: Client::new(), + } + } + + fn encode_did(did: &str) -> String { + urlencoding::encode(did).to_string() + } + + pub async fn get_document(&self, did: &str) -> Result { + let url = format!("{}/{}", self.base_url, Self::encode_did(did)); + let response = self.client.get(&url).send().await?; + + if response.status() == reqwest::StatusCode::NOT_FOUND { + return Err(PlcError::NotFound); + } + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(PlcError::InvalidResponse(format!( + "HTTP {}: {}", + status, body + ))); + } + + response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) + } + + pub async fn get_document_data(&self, did: &str) -> Result { + let url = format!("{}/{}/data", self.base_url, Self::encode_did(did)); + let response = self.client.get(&url).send().await?; + + if response.status() == reqwest::StatusCode::NOT_FOUND { + return Err(PlcError::NotFound); + } + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(PlcError::InvalidResponse(format!( + "HTTP {}: {}", + status, body + ))); + } + + response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) + } + + pub async fn get_last_op(&self, did: &str) -> Result { + let url = format!("{}/{}/log/last", self.base_url, Self::encode_did(did)); + let response = self.client.get(&url).send().await?; + + if response.status() == reqwest::StatusCode::NOT_FOUND { + return Err(PlcError::NotFound); + } + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(PlcError::InvalidResponse(format!( + "HTTP {}: {}", + status, body + ))); + } + + response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) + } + + pub async fn get_audit_log(&self, did: &str) -> Result, PlcError> { + let url = format!("{}/{}/log/audit", self.base_url, Self::encode_did(did)); + let response = self.client.get(&url).send().await?; + + if response.status() == reqwest::StatusCode::NOT_FOUND { + return Err(PlcError::NotFound); + } + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(PlcError::InvalidResponse(format!( + "HTTP {}: {}", + status, body + ))); + } + + response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) + } + + pub async fn send_operation(&self, did: &str, operation: &Value) -> Result<(), PlcError> { + let url = format!("{}/{}", self.base_url, Self::encode_did(did)); + let response = self.client + .post(&url) + .json(operation) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(PlcError::InvalidResponse(format!( + "HTTP {}: {}", + status, body + ))); + } + + Ok(()) + } +} + +pub fn cid_for_cbor(value: &Value) -> Result { + let cbor_bytes = serde_ipld_dagcbor::to_vec(value) + .map_err(|e| PlcError::Serialization(e.to_string()))?; + + let mut hasher = Sha256::new(); + hasher.update(&cbor_bytes); + let hash = hasher.finalize(); + + let multihash = multihash::Multihash::wrap(0x12, &hash) + .map_err(|e| PlcError::Serialization(e.to_string()))?; + let cid = cid::Cid::new_v1(0x71, multihash); + + Ok(cid.to_string()) +} + +pub fn sign_operation( + operation: &Value, + signing_key: &SigningKey, +) -> Result { + let mut op = operation.clone(); + if let Some(obj) = op.as_object_mut() { + obj.remove("sig"); + } + + let cbor_bytes = serde_ipld_dagcbor::to_vec(&op) + .map_err(|e| PlcError::Serialization(e.to_string()))?; + + let signature: Signature = signing_key.sign(&cbor_bytes); + let sig_bytes = signature.to_bytes(); + let sig_b64 = URL_SAFE_NO_PAD.encode(sig_bytes); + + if let Some(obj) = op.as_object_mut() { + obj.insert("sig".to_string(), json!(sig_b64)); + } + + Ok(op) +} + +pub fn create_update_op( + last_op: &PlcOpOrTombstone, + rotation_keys: Option>, + verification_methods: Option>, + also_known_as: Option>, + services: Option>, +) -> Result { + let prev_value = match last_op { + PlcOpOrTombstone::Operation(op) => serde_json::to_value(op) + .map_err(|e| PlcError::Serialization(e.to_string()))?, + PlcOpOrTombstone::Tombstone(t) => serde_json::to_value(t) + .map_err(|e| PlcError::Serialization(e.to_string()))?, + }; + + let prev_cid = cid_for_cbor(&prev_value)?; + + let (base_rotation_keys, base_verification_methods, base_also_known_as, base_services) = + match last_op { + PlcOpOrTombstone::Operation(op) => ( + op.rotation_keys.clone(), + op.verification_methods.clone(), + op.also_known_as.clone(), + op.services.clone(), + ), + PlcOpOrTombstone::Tombstone(_) => { + return Err(PlcError::Tombstoned); + } + }; + + let new_op = PlcOperation { + op_type: "plc_operation".to_string(), + rotation_keys: rotation_keys.unwrap_or(base_rotation_keys), + verification_methods: verification_methods.unwrap_or(base_verification_methods), + also_known_as: also_known_as.unwrap_or(base_also_known_as), + services: services.unwrap_or(base_services), + prev: Some(prev_cid), + sig: None, + }; + + serde_json::to_value(new_op).map_err(|e| PlcError::Serialization(e.to_string())) +} + +pub fn signing_key_to_did_key(signing_key: &SigningKey) -> String { + let verifying_key = signing_key.verifying_key(); + let point = verifying_key.to_encoded_point(true); + let compressed_bytes = point.as_bytes(); + + let mut prefixed = vec![0xe7, 0x01]; + prefixed.extend_from_slice(compressed_bytes); + + let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed); + format!("did:key:{}", encoded) +} + +pub fn validate_plc_operation(op: &Value) -> Result<(), PlcError> { + let obj = op.as_object() + .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?; + + let op_type = obj.get("type") + .and_then(|v| v.as_str()) + .ok_or_else(|| PlcError::InvalidResponse("Missing type field".to_string()))?; + + if op_type != "plc_operation" && op_type != "plc_tombstone" { + return Err(PlcError::InvalidResponse(format!("Invalid type: {}", op_type))); + } + + if op_type == "plc_operation" { + if obj.get("rotationKeys").is_none() { + return Err(PlcError::InvalidResponse("Missing rotationKeys".to_string())); + } + if obj.get("verificationMethods").is_none() { + return Err(PlcError::InvalidResponse("Missing verificationMethods".to_string())); + } + if obj.get("alsoKnownAs").is_none() { + return Err(PlcError::InvalidResponse("Missing alsoKnownAs".to_string())); + } + if obj.get("services").is_none() { + return Err(PlcError::InvalidResponse("Missing services".to_string())); + } + } + + if obj.get("sig").is_none() { + return Err(PlcError::InvalidResponse("Missing sig".to_string())); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_signing_key_to_did_key() { + let key = SigningKey::random(&mut rand::thread_rng()); + let did_key = signing_key_to_did_key(&key); + assert!(did_key.starts_with("did:key:z")); + } + + #[test] + fn test_cid_for_cbor() { + let value = json!({ + "test": "data", + "number": 42 + }); + let cid = cid_for_cbor(&value).unwrap(); + assert!(cid.starts_with("bafyrei")); + } + + #[test] + fn test_sign_operation() { + let key = SigningKey::random(&mut rand::thread_rng()); + let op = json!({ + "type": "plc_operation", + "rotationKeys": [], + "verificationMethods": {}, + "alsoKnownAs": [], + "services": {}, + "prev": null + }); + + let signed = sign_operation(&op, &key).unwrap(); + assert!(signed.get("sig").is_some()); + } +} diff --git a/src/sync/car.rs b/src/sync/car.rs index c125c4c..f271367 100644 --- a/src/sync/car.rs +++ b/src/sync/car.rs @@ -1,4 +1,5 @@ use cid::Cid; +use iroh_car::CarHeader; use std::io::Write; pub fn write_varint(mut writer: W, mut value: u64) -> std::io::Result<()> { @@ -23,10 +24,11 @@ pub fn ld_write(mut writer: W, data: &[u8]) -> std::io::Result<()> { } pub fn encode_car_header(root_cid: &Cid) -> Vec { - let header = serde_ipld_dagcbor::to_vec(&serde_json::json!({ - "version": 1u64, - "roots": [root_cid.to_bytes()] - })) - .unwrap_or_default(); - header + let header = CarHeader::new_v1(vec![root_cid.clone()]); + let header_cbor = header.encode().unwrap_or_default(); + + let mut result = Vec::new(); + write_varint(&mut result, header_cbor.len() as u64).unwrap(); + result.extend_from_slice(&header_cbor); + result } diff --git a/src/sync/import.rs b/src/sync/import.rs new file mode 100644 index 0000000..102b2f1 --- /dev/null +++ b/src/sync/import.rs @@ -0,0 +1,464 @@ +use bytes::Bytes; +use cid::Cid; +use ipld_core::ipld::Ipld; +use iroh_car::CarReader; +use serde_json::Value as JsonValue; +use sqlx::PgPool; +use std::collections::HashMap; +use std::io::Cursor; +use thiserror::Error; +use tracing::debug; +use uuid::Uuid; + +#[derive(Error, Debug)] +pub enum ImportError { + #[error("CAR parsing error: {0}")] + CarParse(String), + #[error("Expected exactly one root in CAR file")] + InvalidRootCount, + #[error("Block not found: {0}")] + BlockNotFound(String), + #[error("Invalid CBOR: {0}")] + InvalidCbor(String), + #[error("Database error: {0}")] + Database(#[from] sqlx::Error), + #[error("Block store error: {0}")] + BlockStore(String), + #[error("Import size limit exceeded")] + SizeLimitExceeded, + #[error("Repo not found")] + RepoNotFound, + #[error("Concurrent modification detected")] + ConcurrentModification, + #[error("Invalid commit structure: {0}")] + InvalidCommit(String), + #[error("Verification failed: {0}")] + VerificationFailed(#[from] super::verify::VerifyError), + #[error("DID mismatch: CAR is for {car_did}, but authenticated as {auth_did}")] + DidMismatch { car_did: String, auth_did: String }, +} + +#[derive(Debug, Clone)] +pub struct BlobRef { + pub cid: String, + pub mime_type: Option, +} + +pub async fn parse_car(data: &[u8]) -> Result<(Cid, HashMap), ImportError> { + let cursor = Cursor::new(data); + let mut reader = CarReader::new(cursor) + .await + .map_err(|e| ImportError::CarParse(e.to_string()))?; + + let header = reader.header(); + let roots = header.roots(); + + if roots.len() != 1 { + return Err(ImportError::InvalidRootCount); + } + + let root = roots[0]; + let mut blocks = HashMap::new(); + + while let Ok(Some((cid, block))) = reader.next_block().await { + blocks.insert(cid, Bytes::from(block)); + } + + if !blocks.contains_key(&root) { + return Err(ImportError::BlockNotFound(root.to_string())); + } + + Ok((root, blocks)) +} + +pub fn find_blob_refs_ipld(value: &Ipld, depth: usize) -> Vec { + if depth > 32 { + return vec![]; + } + + match value { + Ipld::List(arr) => arr + .iter() + .flat_map(|v| find_blob_refs_ipld(v, depth + 1)) + .collect(), + Ipld::Map(obj) => { + if let Some(Ipld::String(type_str)) = obj.get("$type") { + if type_str == "blob" { + if let Some(Ipld::Link(link_cid)) = obj.get("ref") { + let mime = obj + .get("mimeType") + .and_then(|v| if let Ipld::String(s) = v { Some(s.clone()) } else { None }); + return vec![BlobRef { + cid: link_cid.to_string(), + mime_type: mime, + }]; + } + } + } + + obj.values() + .flat_map(|v| find_blob_refs_ipld(v, depth + 1)) + .collect() + } + _ => vec![], + } +} + +pub fn find_blob_refs(value: &JsonValue, depth: usize) -> Vec { + if depth > 32 { + return vec![]; + } + + match value { + JsonValue::Array(arr) => arr + .iter() + .flat_map(|v| find_blob_refs(v, depth + 1)) + .collect(), + JsonValue::Object(obj) => { + if let Some(JsonValue::String(type_str)) = obj.get("$type") { + if type_str == "blob" { + if let Some(JsonValue::Object(ref_obj)) = obj.get("ref") { + if let Some(JsonValue::String(link)) = ref_obj.get("$link") { + let mime = obj + .get("mimeType") + .and_then(|v| v.as_str()) + .map(String::from); + return vec![BlobRef { + cid: link.clone(), + mime_type: mime, + }]; + } + } + } + } + + obj.values() + .flat_map(|v| find_blob_refs(v, depth + 1)) + .collect() + } + _ => vec![], + } +} + +pub fn extract_links(value: &Ipld, links: &mut Vec) { + match value { + Ipld::Link(cid) => { + links.push(*cid); + } + Ipld::Map(map) => { + for v in map.values() { + extract_links(v, links); + } + } + Ipld::List(arr) => { + for v in arr { + extract_links(v, links); + } + } + _ => {} + } +} + +#[derive(Debug)] +pub struct ImportedRecord { + pub collection: String, + pub rkey: String, + pub cid: Cid, + pub blob_refs: Vec, +} + +pub fn walk_mst( + blocks: &HashMap, + root_cid: &Cid, +) -> Result, ImportError> { + let mut records = Vec::new(); + let mut stack = vec![*root_cid]; + let mut visited = std::collections::HashSet::new(); + + while let Some(cid) = stack.pop() { + if visited.contains(&cid) { + continue; + } + visited.insert(cid); + + let block = blocks + .get(&cid) + .ok_or_else(|| ImportError::BlockNotFound(cid.to_string()))?; + + let value: Ipld = serde_ipld_dagcbor::from_slice(block) + .map_err(|e| ImportError::InvalidCbor(e.to_string()))?; + + if let Ipld::Map(ref obj) = value { + if let Some(Ipld::List(entries)) = obj.get("e") { + for entry in entries { + if let Ipld::Map(entry_obj) = entry { + let key = entry_obj.get("k").and_then(|k| { + if let Ipld::Bytes(b) = k { + String::from_utf8(b.clone()).ok() + } else if let Ipld::String(s) = k { + Some(s.clone()) + } else { + None + } + }); + + let record_cid = entry_obj.get("v").and_then(|v| { + if let Ipld::Link(cid) = v { + Some(*cid) + } else { + None + } + }); + + if let (Some(key), Some(record_cid)) = (key, record_cid) { + if let Some(record_block) = blocks.get(&record_cid) { + if let Ok(record_value) = + serde_ipld_dagcbor::from_slice::(record_block) + { + let blob_refs = find_blob_refs_ipld(&record_value, 0); + + let parts: Vec<&str> = key.split('/').collect(); + if parts.len() >= 2 { + let collection = parts[..parts.len() - 1].join("/"); + let rkey = parts[parts.len() - 1].to_string(); + + records.push(ImportedRecord { + collection, + rkey, + cid: record_cid, + blob_refs, + }); + } + } + } + } + + if let Some(Ipld::Link(tree_cid)) = entry_obj.get("t") { + stack.push(*tree_cid); + } + } + } + } + + if let Some(Ipld::Link(left_cid)) = obj.get("l") { + stack.push(*left_cid); + } + } + } + + Ok(records) +} + +pub struct CommitInfo { + pub rev: Option, + pub prev: Option, +} + +fn extract_commit_info(commit: &Ipld) -> Result<(Cid, CommitInfo), ImportError> { + let obj = match commit { + Ipld::Map(m) => m, + _ => return Err(ImportError::InvalidCommit("Commit must be a map".to_string())), + }; + + let data_cid = obj + .get("data") + .and_then(|d| if let Ipld::Link(cid) = d { Some(*cid) } else { None }) + .ok_or_else(|| ImportError::InvalidCommit("Missing data field".to_string()))?; + + let rev = obj.get("rev").and_then(|r| { + if let Ipld::String(s) = r { + Some(s.clone()) + } else { + None + } + }); + + let prev = obj.get("prev").and_then(|p| { + if let Ipld::Link(cid) = p { + Some(cid.to_string()) + } else if let Ipld::Null = p { + None + } else { + None + } + }); + + Ok((data_cid, CommitInfo { rev, prev })) +} + +pub async fn apply_import( + db: &PgPool, + user_id: Uuid, + root: Cid, + blocks: HashMap, + max_blocks: usize, +) -> Result, ImportError> { + if blocks.len() > max_blocks { + return Err(ImportError::SizeLimitExceeded); + } + + let root_block = blocks + .get(&root) + .ok_or_else(|| ImportError::BlockNotFound(root.to_string()))?; + let commit: Ipld = serde_ipld_dagcbor::from_slice(root_block) + .map_err(|e| ImportError::InvalidCbor(e.to_string()))?; + + let (data_cid, _commit_info) = extract_commit_info(&commit)?; + + let records = walk_mst(&blocks, &data_cid)?; + + debug!( + "Importing {} blocks and {} records for user {}", + blocks.len(), + records.len(), + user_id + ); + + let mut tx = db.begin().await?; + + let repo = sqlx::query!( + "SELECT repo_root_cid FROM repos WHERE user_id = $1 FOR UPDATE NOWAIT", + user_id + ) + .fetch_optional(&mut *tx) + .await + .map_err(|e| { + if let sqlx::Error::Database(ref db_err) = e { + if db_err.code().as_deref() == Some("55P03") { + return ImportError::ConcurrentModification; + } + } + ImportError::Database(e) + })?; + + if repo.is_none() { + return Err(ImportError::RepoNotFound); + } + + let block_chunks: Vec> = blocks + .iter() + .collect::>() + .chunks(100) + .map(|c| c.to_vec()) + .collect(); + + for chunk in block_chunks { + for (cid, data) in chunk { + let cid_bytes = cid.to_bytes(); + sqlx::query!( + "INSERT INTO blocks (cid, data) VALUES ($1, $2) ON CONFLICT (cid) DO NOTHING", + &cid_bytes, + data.as_ref() + ) + .execute(&mut *tx) + .await?; + } + } + + let root_str = root.to_string(); + sqlx::query!( + "UPDATE repos SET repo_root_cid = $1, updated_at = NOW() WHERE user_id = $2", + root_str, + user_id + ) + .execute(&mut *tx) + .await?; + + sqlx::query!("DELETE FROM records WHERE repo_id = $1", user_id) + .execute(&mut *tx) + .await?; + + for record in &records { + let record_cid_str = record.cid.to_string(); + sqlx::query!( + r#" + INSERT INTO records (repo_id, collection, rkey, record_cid) + VALUES ($1, $2, $3, $4) + ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4 + "#, + user_id, + record.collection, + record.rkey, + record_cid_str + ) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + + debug!( + "Successfully imported {} blocks and {} records", + blocks.len(), + records.len() + ); + + Ok(records) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_find_blob_refs() { + let record = serde_json::json!({ + "$type": "app.bsky.feed.post", + "text": "Hello world", + "embed": { + "$type": "app.bsky.embed.images", + "images": [ + { + "alt": "Test image", + "image": { + "$type": "blob", + "ref": { + "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" + }, + "mimeType": "image/jpeg", + "size": 12345 + } + } + ] + } + }); + + let blob_refs = find_blob_refs(&record, 0); + assert_eq!(blob_refs.len(), 1); + assert_eq!( + blob_refs[0].cid, + "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" + ); + assert_eq!(blob_refs[0].mime_type, Some("image/jpeg".to_string())); + } + + #[test] + fn test_find_blob_refs_no_blobs() { + let record = serde_json::json!({ + "$type": "app.bsky.feed.post", + "text": "Hello world" + }); + + let blob_refs = find_blob_refs(&record, 0); + assert!(blob_refs.is_empty()); + } + + #[test] + fn test_find_blob_refs_depth_limit() { + fn deeply_nested(depth: usize) -> JsonValue { + if depth == 0 { + serde_json::json!({ + "$type": "blob", + "ref": { "$link": "bafkreitest" }, + "mimeType": "image/png" + }) + } else { + serde_json::json!({ "nested": deeply_nested(depth - 1) }) + } + } + + let deep = deeply_nested(40); + let blob_refs = find_blob_refs(&deep, 0); + assert!(blob_refs.is_empty()); + } +} diff --git a/src/sync/mod.rs b/src/sync/mod.rs index 2f3965c..2b93ba0 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -4,14 +4,17 @@ pub mod commit; pub mod crawl; pub mod firehose; pub mod frame; +pub mod import; pub mod listener; pub mod relay_client; pub mod repo; pub mod subscribe_repos; pub mod util; +pub mod verify; pub use blob::{get_blob, list_blobs}; pub use commit::{get_latest_commit, get_repo_status, list_repos}; pub use crawl::{notify_of_update, request_crawl}; pub use repo::{get_blocks, get_repo, get_record}; pub use subscribe_repos::subscribe_repos; +pub use verify::{CarVerifier, VerifiedCar, VerifyError}; diff --git a/src/sync/repo.rs b/src/sync/repo.rs index e5a9d3a..4738b64 100644 --- a/src/sync/repo.rs +++ b/src/sync/repo.rs @@ -7,6 +7,7 @@ use axum::{ Json, }; use cid::Cid; +use ipld_core::ipld::Ipld; use jacquard_repo::storage::BlockStore; use serde::Deserialize; use serde_json::json; @@ -165,8 +166,8 @@ pub async fn get_repo( writer.write_all(&block).unwrap(); car_bytes.extend_from_slice(&writer); - if let Ok(value) = serde_ipld_dagcbor::from_slice::(&block) { - extract_links_json(&value, &mut stack); + if let Ok(value) = serde_ipld_dagcbor::from_slice::(&block) { + extract_links_ipld(&value, &mut stack); } } } @@ -179,26 +180,19 @@ pub async fn get_repo( .into_response() } -fn extract_links_json(value: &serde_json::Value, stack: &mut Vec) { +fn extract_links_ipld(value: &Ipld, stack: &mut Vec) { match value { - serde_json::Value::Object(map) => { - if let Some(serde_json::Value::String(s)) = map.get("/") { - if let Ok(cid) = Cid::from_str(s) { - stack.push(cid); - } - } else if let Some(serde_json::Value::String(s)) = map.get("$link") { - if let Ok(cid) = Cid::from_str(s) { - stack.push(cid); - } - } else { - for v in map.values() { - extract_links_json(v, stack); - } + Ipld::Link(cid) => { + stack.push(*cid); + } + Ipld::Map(map) => { + for v in map.values() { + extract_links_ipld(v, stack); } } - serde_json::Value::Array(arr) => { + Ipld::List(arr) => { for v in arr { - extract_links_json(v, stack); + extract_links_ipld(v, stack); } } _ => {} diff --git a/src/sync/verify.rs b/src/sync/verify.rs new file mode 100644 index 0000000..adf21c7 --- /dev/null +++ b/src/sync/verify.rs @@ -0,0 +1,646 @@ +use bytes::Bytes; +use cid::Cid; +use jacquard::common::types::crypto::PublicKey; +use jacquard::common::types::did_doc::DidDocument; +use jacquard::common::IntoStatic; +use jacquard_repo::commit::Commit; +use reqwest::Client; +use std::collections::HashMap; +use thiserror::Error; +use tracing::{debug, warn}; + +#[derive(Error, Debug)] +pub enum VerifyError { + #[error("Invalid commit: {0}")] + InvalidCommit(String), + #[error("DID mismatch: commit has {commit_did}, expected {expected_did}")] + DidMismatch { + commit_did: String, + expected_did: String, + }, + #[error("Failed to resolve DID: {0}")] + DidResolutionFailed(String), + #[error("No signing key found in DID document")] + NoSigningKey, + #[error("Invalid signature")] + InvalidSignature, + #[error("MST validation failed: {0}")] + MstValidationFailed(String), + #[error("Block not found: {0}")] + BlockNotFound(String), + #[error("Invalid CBOR: {0}")] + InvalidCbor(String), +} + +pub struct CarVerifier { + http_client: Client, +} + +impl Default for CarVerifier { + fn default() -> Self { + Self::new() + } +} + +impl CarVerifier { + pub fn new() -> Self { + Self { + http_client: Client::builder() + .timeout(std::time::Duration::from_secs(10)) + .build() + .unwrap_or_default(), + } + } + + pub async fn verify_car( + &self, + expected_did: &str, + root_cid: &Cid, + blocks: &HashMap, + ) -> Result { + let root_block = blocks + .get(root_cid) + .ok_or_else(|| VerifyError::BlockNotFound(root_cid.to_string()))?; + + let commit = Commit::from_cbor(root_block) + .map_err(|e| VerifyError::InvalidCommit(e.to_string()))?; + + let commit_did = commit.did().as_str(); + if commit_did != expected_did { + return Err(VerifyError::DidMismatch { + commit_did: commit_did.to_string(), + expected_did: expected_did.to_string(), + }); + } + + let pubkey = self.resolve_did_signing_key(commit_did).await?; + + commit + .verify(&pubkey) + .map_err(|_| VerifyError::InvalidSignature)?; + + debug!("Commit signature verified for DID {}", commit_did); + + let data_cid = commit.data(); + self.verify_mst_structure(data_cid, blocks)?; + + debug!("MST structure verified for DID {}", commit_did); + + Ok(VerifiedCar { + did: commit_did.to_string(), + rev: commit.rev().to_string(), + data_cid: *data_cid, + prev: commit.prev().cloned(), + }) + } + + async fn resolve_did_signing_key(&self, did: &str) -> Result, VerifyError> { + let did_doc = self.resolve_did_document(did).await?; + + did_doc + .atproto_public_key() + .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))? + .ok_or(VerifyError::NoSigningKey) + } + + async fn resolve_did_document(&self, did: &str) -> Result, VerifyError> { + if did.starts_with("did:plc:") { + self.resolve_plc_did(did).await + } else if did.starts_with("did:web:") { + self.resolve_web_did(did).await + } else { + Err(VerifyError::DidResolutionFailed(format!( + "Unsupported DID method: {}", + did + ))) + } + } + + async fn resolve_plc_did(&self, did: &str) -> Result, VerifyError> { + let plc_url = std::env::var("PLC_DIRECTORY_URL") + .unwrap_or_else(|_| "https://plc.directory".to_string()); + let url = format!("{}/{}", plc_url, urlencoding::encode(did)); + + let response = self + .http_client + .get(&url) + .send() + .await + .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?; + + if !response.status().is_success() { + return Err(VerifyError::DidResolutionFailed(format!( + "PLC directory returned {}", + response.status() + ))); + } + + let body = response + .text() + .await + .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?; + + let doc: DidDocument<'_> = serde_json::from_str(&body) + .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?; + + Ok(doc.into_static()) + } + + async fn resolve_web_did(&self, did: &str) -> Result, VerifyError> { + let domain = did + .strip_prefix("did:web:") + .ok_or_else(|| VerifyError::DidResolutionFailed("Invalid did:web format".to_string()))?; + + let domain_decoded = urlencoding::decode(domain) + .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?; + + let url = if domain_decoded.contains(':') || domain_decoded.contains('/') { + format!("https://{}/.well-known/did.json", domain_decoded) + } else { + format!("https://{}/.well-known/did.json", domain_decoded) + }; + + let response = self + .http_client + .get(&url) + .send() + .await + .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?; + + if !response.status().is_success() { + return Err(VerifyError::DidResolutionFailed(format!( + "did:web resolution returned {}", + response.status() + ))); + } + + let body = response + .text() + .await + .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?; + + let doc: DidDocument<'_> = serde_json::from_str(&body) + .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?; + + Ok(doc.into_static()) + } + + fn verify_mst_structure( + &self, + data_cid: &Cid, + blocks: &HashMap, + ) -> Result<(), VerifyError> { + use ipld_core::ipld::Ipld; + + let mut stack = vec![*data_cid]; + let mut visited = std::collections::HashSet::new(); + let mut node_count = 0; + const MAX_NODES: usize = 100_000; + + while let Some(cid) = stack.pop() { + if visited.contains(&cid) { + continue; + } + visited.insert(cid); + node_count += 1; + + if node_count > MAX_NODES { + return Err(VerifyError::MstValidationFailed( + "MST exceeds maximum node count".to_string(), + )); + } + + let block = blocks + .get(&cid) + .ok_or_else(|| VerifyError::BlockNotFound(cid.to_string()))?; + + let node: Ipld = serde_ipld_dagcbor::from_slice(block) + .map_err(|e| VerifyError::InvalidCbor(e.to_string()))?; + + if let Ipld::Map(ref obj) = node { + if let Some(Ipld::Link(left_cid)) = obj.get("l") { + if !blocks.contains_key(left_cid) { + return Err(VerifyError::BlockNotFound(format!( + "MST left pointer {} not in CAR", + left_cid + ))); + } + stack.push(*left_cid); + } + + if let Some(Ipld::List(entries)) = obj.get("e") { + let mut last_full_key: Vec = Vec::new(); + + for entry in entries { + if let Ipld::Map(entry_obj) = entry { + let prefix_len = entry_obj.get("p").and_then(|p| match p { + Ipld::Integer(i) => Some(*i as usize), + _ => None, + }).unwrap_or(0); + + let key_suffix = entry_obj.get("k").and_then(|k| match k { + Ipld::Bytes(b) => Some(b.clone()), + Ipld::String(s) => Some(s.as_bytes().to_vec()), + _ => None, + }); + + if let Some(suffix) = key_suffix { + let mut full_key = Vec::new(); + if prefix_len > 0 && prefix_len <= last_full_key.len() { + full_key.extend_from_slice(&last_full_key[..prefix_len]); + } + full_key.extend_from_slice(&suffix); + + if !last_full_key.is_empty() && full_key <= last_full_key { + return Err(VerifyError::MstValidationFailed( + "MST keys not in sorted order".to_string(), + )); + } + last_full_key = full_key; + } + + if let Some(Ipld::Link(tree_cid)) = entry_obj.get("t") { + if !blocks.contains_key(tree_cid) { + return Err(VerifyError::BlockNotFound(format!( + "MST subtree {} not in CAR", + tree_cid + ))); + } + stack.push(*tree_cid); + } + + if let Some(Ipld::Link(value_cid)) = entry_obj.get("v") { + if !blocks.contains_key(value_cid) { + warn!( + "Record block {} referenced in MST not in CAR (may be expected for partial export)", + value_cid + ); + } + } + } + } + } + } + } + + debug!( + "MST validation complete: {} nodes, {} blocks visited", + node_count, + visited.len() + ); + + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct VerifiedCar { + pub did: String, + pub rev: String, + pub data_cid: Cid, + pub prev: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + use sha2::{Digest, Sha256}; + + fn make_cid(data: &[u8]) -> Cid { + let mut hasher = Sha256::new(); + hasher.update(data); + let hash = hasher.finalize(); + let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap(); + Cid::new_v1(0x71, multihash) + } + + #[test] + fn test_verifier_creation() { + let _verifier = CarVerifier::new(); + } + + #[test] + fn test_verify_error_display() { + let err = VerifyError::DidMismatch { + commit_did: "did:plc:abc".to_string(), + expected_did: "did:plc:xyz".to_string(), + }; + assert!(err.to_string().contains("did:plc:abc")); + assert!(err.to_string().contains("did:plc:xyz")); + + let err = VerifyError::InvalidSignature; + assert!(err.to_string().contains("signature")); + + let err = VerifyError::NoSigningKey; + assert!(err.to_string().contains("signing key")); + + let err = VerifyError::MstValidationFailed("test error".to_string()); + assert!(err.to_string().contains("test error")); + } + + #[test] + fn test_mst_validation_missing_root_block() { + let verifier = CarVerifier::new(); + let blocks: HashMap = HashMap::new(); + + let fake_cid = make_cid(b"fake data"); + let result = verifier.verify_mst_structure(&fake_cid, &blocks); + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, VerifyError::BlockNotFound(_))); + } + + #[test] + fn test_mst_validation_invalid_cbor() { + let verifier = CarVerifier::new(); + + let bad_cbor = Bytes::from(vec![0xFF, 0xFF, 0xFF]); + let cid = make_cid(&bad_cbor); + + let mut blocks = HashMap::new(); + blocks.insert(cid, bad_cbor); + + let result = verifier.verify_mst_structure(&cid, &blocks); + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, VerifyError::InvalidCbor(_))); + } + + #[test] + fn test_mst_validation_empty_node() { + let verifier = CarVerifier::new(); + + let empty_node = serde_ipld_dagcbor::to_vec(&serde_json::json!({ + "e": [] + })).unwrap(); + let cid = make_cid(&empty_node); + + let mut blocks = HashMap::new(); + blocks.insert(cid, Bytes::from(empty_node)); + + let result = verifier.verify_mst_structure(&cid, &blocks); + assert!(result.is_ok()); + } + + #[test] + fn test_mst_validation_missing_left_pointer() { + use ipld_core::ipld::Ipld; + + let verifier = CarVerifier::new(); + + let missing_left_cid = make_cid(b"missing left"); + let node = Ipld::Map(std::collections::BTreeMap::from([ + ("l".to_string(), Ipld::Link(missing_left_cid)), + ("e".to_string(), Ipld::List(vec![])), + ])); + let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); + let cid = make_cid(&node_bytes); + + let mut blocks = HashMap::new(); + blocks.insert(cid, Bytes::from(node_bytes)); + + let result = verifier.verify_mst_structure(&cid, &blocks); + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, VerifyError::BlockNotFound(_))); + assert!(err.to_string().contains("left pointer")); + } + + #[test] + fn test_mst_validation_missing_subtree() { + use ipld_core::ipld::Ipld; + + let verifier = CarVerifier::new(); + + let missing_subtree_cid = make_cid(b"missing subtree"); + let record_cid = make_cid(b"record"); + + let entry = Ipld::Map(std::collections::BTreeMap::from([ + ("k".to_string(), Ipld::Bytes(b"key1".to_vec())), + ("v".to_string(), Ipld::Link(record_cid)), + ("p".to_string(), Ipld::Integer(0)), + ("t".to_string(), Ipld::Link(missing_subtree_cid)), + ])); + + let node = Ipld::Map(std::collections::BTreeMap::from([ + ("e".to_string(), Ipld::List(vec![entry])), + ])); + let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); + let cid = make_cid(&node_bytes); + + let mut blocks = HashMap::new(); + blocks.insert(cid, Bytes::from(node_bytes)); + + let result = verifier.verify_mst_structure(&cid, &blocks); + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, VerifyError::BlockNotFound(_))); + assert!(err.to_string().contains("subtree")); + } + + #[test] + fn test_mst_validation_unsorted_keys() { + use ipld_core::ipld::Ipld; + + let verifier = CarVerifier::new(); + + let record_cid = make_cid(b"record"); + + let entry1 = Ipld::Map(std::collections::BTreeMap::from([ + ("k".to_string(), Ipld::Bytes(b"zzz".to_vec())), + ("v".to_string(), Ipld::Link(record_cid)), + ("p".to_string(), Ipld::Integer(0)), + ])); + + let entry2 = Ipld::Map(std::collections::BTreeMap::from([ + ("k".to_string(), Ipld::Bytes(b"aaa".to_vec())), + ("v".to_string(), Ipld::Link(record_cid)), + ("p".to_string(), Ipld::Integer(0)), + ])); + + let node = Ipld::Map(std::collections::BTreeMap::from([ + ("e".to_string(), Ipld::List(vec![entry1, entry2])), + ])); + let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); + let cid = make_cid(&node_bytes); + + let mut blocks = HashMap::new(); + blocks.insert(cid, Bytes::from(node_bytes)); + + let result = verifier.verify_mst_structure(&cid, &blocks); + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, VerifyError::MstValidationFailed(_))); + assert!(err.to_string().contains("sorted")); + } + + #[test] + fn test_mst_validation_sorted_keys_ok() { + use ipld_core::ipld::Ipld; + + let verifier = CarVerifier::new(); + + let record_cid = make_cid(b"record"); + + let entry1 = Ipld::Map(std::collections::BTreeMap::from([ + ("k".to_string(), Ipld::Bytes(b"aaa".to_vec())), + ("v".to_string(), Ipld::Link(record_cid)), + ("p".to_string(), Ipld::Integer(0)), + ])); + + let entry2 = Ipld::Map(std::collections::BTreeMap::from([ + ("k".to_string(), Ipld::Bytes(b"bbb".to_vec())), + ("v".to_string(), Ipld::Link(record_cid)), + ("p".to_string(), Ipld::Integer(0)), + ])); + + let entry3 = Ipld::Map(std::collections::BTreeMap::from([ + ("k".to_string(), Ipld::Bytes(b"zzz".to_vec())), + ("v".to_string(), Ipld::Link(record_cid)), + ("p".to_string(), Ipld::Integer(0)), + ])); + + let node = Ipld::Map(std::collections::BTreeMap::from([ + ("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])), + ])); + let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); + let cid = make_cid(&node_bytes); + + let mut blocks = HashMap::new(); + blocks.insert(cid, Bytes::from(node_bytes)); + + let result = verifier.verify_mst_structure(&cid, &blocks); + assert!(result.is_ok()); + } + + #[test] + fn test_mst_validation_with_valid_left_pointer() { + use ipld_core::ipld::Ipld; + + let verifier = CarVerifier::new(); + + let left_node = Ipld::Map(std::collections::BTreeMap::from([ + ("e".to_string(), Ipld::List(vec![])), + ])); + let left_node_bytes = serde_ipld_dagcbor::to_vec(&left_node).unwrap(); + let left_cid = make_cid(&left_node_bytes); + + let root_node = Ipld::Map(std::collections::BTreeMap::from([ + ("l".to_string(), Ipld::Link(left_cid)), + ("e".to_string(), Ipld::List(vec![])), + ])); + let root_node_bytes = serde_ipld_dagcbor::to_vec(&root_node).unwrap(); + let root_cid = make_cid(&root_node_bytes); + + let mut blocks = HashMap::new(); + blocks.insert(root_cid, Bytes::from(root_node_bytes)); + blocks.insert(left_cid, Bytes::from(left_node_bytes)); + + let result = verifier.verify_mst_structure(&root_cid, &blocks); + assert!(result.is_ok()); + } + + #[test] + fn test_mst_validation_cycle_detection() { + let verifier = CarVerifier::new(); + + let node = serde_ipld_dagcbor::to_vec(&serde_json::json!({ + "e": [] + })).unwrap(); + let cid = make_cid(&node); + + let mut blocks = HashMap::new(); + blocks.insert(cid, Bytes::from(node)); + + let result = verifier.verify_mst_structure(&cid, &blocks); + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_unsupported_did_method() { + let verifier = CarVerifier::new(); + let result = verifier.resolve_did_document("did:unknown:test").await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, VerifyError::DidResolutionFailed(_))); + assert!(err.to_string().contains("Unsupported")); + } + + #[test] + fn test_mst_validation_with_prefix_compression() { + use ipld_core::ipld::Ipld; + + let verifier = CarVerifier::new(); + let record_cid = make_cid(b"record"); + + let entry1 = Ipld::Map(std::collections::BTreeMap::from([ + ("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/abc".to_vec())), + ("v".to_string(), Ipld::Link(record_cid)), + ("p".to_string(), Ipld::Integer(0)), + ])); + + let entry2 = Ipld::Map(std::collections::BTreeMap::from([ + ("k".to_string(), Ipld::Bytes(b"def".to_vec())), + ("v".to_string(), Ipld::Link(record_cid)), + ("p".to_string(), Ipld::Integer(19)), + ])); + + let entry3 = Ipld::Map(std::collections::BTreeMap::from([ + ("k".to_string(), Ipld::Bytes(b"xyz".to_vec())), + ("v".to_string(), Ipld::Link(record_cid)), + ("p".to_string(), Ipld::Integer(19)), + ])); + + let node = Ipld::Map(std::collections::BTreeMap::from([ + ("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])), + ])); + let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); + let cid = make_cid(&node_bytes); + + let mut blocks = HashMap::new(); + blocks.insert(cid, Bytes::from(node_bytes)); + + let result = verifier.verify_mst_structure(&cid, &blocks); + assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly"); + } + + #[test] + fn test_mst_validation_prefix_compression_unsorted() { + use ipld_core::ipld::Ipld; + + let verifier = CarVerifier::new(); + let record_cid = make_cid(b"record"); + + let entry1 = Ipld::Map(std::collections::BTreeMap::from([ + ("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/xyz".to_vec())), + ("v".to_string(), Ipld::Link(record_cid)), + ("p".to_string(), Ipld::Integer(0)), + ])); + + let entry2 = Ipld::Map(std::collections::BTreeMap::from([ + ("k".to_string(), Ipld::Bytes(b"abc".to_vec())), + ("v".to_string(), Ipld::Link(record_cid)), + ("p".to_string(), Ipld::Integer(19)), + ])); + + let node = Ipld::Map(std::collections::BTreeMap::from([ + ("e".to_string(), Ipld::List(vec![entry1, entry2])), + ])); + let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); + let cid = make_cid(&node_bytes); + + let mut blocks = HashMap::new(); + blocks.insert(cid, Bytes::from(node_bytes)); + + let result = verifier.verify_mst_structure(&cid, &blocks); + assert!(result.is_err(), "Unsorted prefix-compressed keys should fail validation"); + let err = result.unwrap_err(); + assert!(matches!(err, VerifyError::MstValidationFailed(_))); + } +} diff --git a/tests/import_repo.rs b/tests/import_repo.rs new file mode 100644 index 0000000..61de0a4 --- /dev/null +++ b/tests/import_repo.rs @@ -0,0 +1,109 @@ +mod common; +use common::*; + +use reqwest::StatusCode; +use serde_json::json; + +#[tokio::test] +async fn test_import_repo_requires_auth() { + let client = client(); + + let res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .header("Content-Type", "application/vnd.ipld.car") + .body(vec![0u8; 100]) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn test_import_repo_invalid_car() { + let client = client(); + let (token, _did) = create_account_and_login(&client).await; + + let res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(vec![0u8; 100]) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + let body: serde_json::Value = res.json().await.unwrap(); + assert_eq!(body["error"], "InvalidRequest"); +} + +#[tokio::test] +async fn test_import_repo_empty_body() { + let client = client(); + let (token, _did) = create_account_and_login(&client).await; + + let res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(vec![]) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); +} + +#[tokio::test] +async fn test_import_repo_with_exported_repo() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let post_payload = json!({ + "repo": did, + "collection": "app.bsky.feed.post", + "record": { + "$type": "app.bsky.feed.post", + "text": "Test post for import", + "createdAt": chrono::Utc::now().to_rfc3339(), + } + }); + + let create_res = client + .post(format!( + "{}/xrpc/com.atproto.repo.createRecord", + base_url().await + )) + .bearer_auth(&token) + .json(&post_payload) + .send() + .await + .expect("Failed to create post"); + assert_eq!(create_res.status(), StatusCode::OK); + + let export_res = client + .get(format!( + "{}/xrpc/com.atproto.sync.getRepo?did={}", + base_url().await, + did + )) + .send() + .await + .expect("Failed to export repo"); + assert_eq!(export_res.status(), StatusCode::OK); + + let car_bytes = export_res.bytes().await.expect("Failed to get CAR bytes"); + + let import_res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(car_bytes.to_vec()) + .send() + .await + .expect("Failed to import repo"); + + assert_eq!(import_res.status(), StatusCode::OK); +} + diff --git a/tests/import_verification.rs b/tests/import_verification.rs new file mode 100644 index 0000000..189ecb4 --- /dev/null +++ b/tests/import_verification.rs @@ -0,0 +1,323 @@ +mod common; +use common::*; + +use iroh_car::CarHeader; +use reqwest::StatusCode; +use serde_json::json; + +fn write_varint(buf: &mut Vec, mut value: u64) { + loop { + let mut byte = (value & 0x7F) as u8; + value >>= 7; + if value != 0 { + byte |= 0x80; + } + buf.push(byte); + if value == 0 { + break; + } + } +} + +#[tokio::test] +async fn test_import_rejects_car_for_different_user() { + let client = client(); + + let (token_a, did_a) = create_account_and_login(&client).await; + let (_token_b, did_b) = create_account_and_login(&client).await; + + let export_res = client + .get(format!( + "{}/xrpc/com.atproto.sync.getRepo?did={}", + base_url().await, + did_b + )) + .send() + .await + .expect("Export failed"); + + assert_eq!(export_res.status(), StatusCode::OK); + let car_bytes = export_res.bytes().await.unwrap(); + + let import_res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token_a) + .header("Content-Type", "application/vnd.ipld.car") + .body(car_bytes.to_vec()) + .send() + .await + .expect("Import failed"); + + assert_eq!(import_res.status(), StatusCode::FORBIDDEN); + let body: serde_json::Value = import_res.json().await.unwrap(); + assert!( + body["error"] == "InvalidRequest" || body["error"] == "DidMismatch", + "Expected DidMismatch or InvalidRequest error, got: {:?}", + body + ); +} + +#[tokio::test] +async fn test_import_accepts_own_exported_repo() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let post_payload = json!({ + "repo": did, + "collection": "app.bsky.feed.post", + "record": { + "$type": "app.bsky.feed.post", + "text": "Original post before export", + "createdAt": chrono::Utc::now().to_rfc3339(), + } + }); + + let create_res = client + .post(format!( + "{}/xrpc/com.atproto.repo.createRecord", + base_url().await + )) + .bearer_auth(&token) + .json(&post_payload) + .send() + .await + .expect("Failed to create post"); + assert_eq!(create_res.status(), StatusCode::OK); + + let export_res = client + .get(format!( + "{}/xrpc/com.atproto.sync.getRepo?did={}", + base_url().await, + did + )) + .send() + .await + .expect("Failed to export repo"); + assert_eq!(export_res.status(), StatusCode::OK); + let car_bytes = export_res.bytes().await.unwrap(); + + let import_res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(car_bytes.to_vec()) + .send() + .await + .expect("Failed to import repo"); + + assert_eq!(import_res.status(), StatusCode::OK); +} + +#[tokio::test] +async fn test_import_repo_size_limit() { + let client = client(); + let (token, _did) = create_account_and_login(&client).await; + + let oversized_body = vec![0u8; 110 * 1024 * 1024]; + + let res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(oversized_body) + .send() + .await; + + match res { + Ok(response) => { + assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE); + } + Err(e) => { + let error_str = e.to_string().to_lowercase(); + assert!( + error_str.contains("broken pipe") || + error_str.contains("connection") || + error_str.contains("reset") || + error_str.contains("request") || + error_str.contains("body"), + "Expected connection error or PAYLOAD_TOO_LARGE, got: {}", + e + ); + } + } +} + +#[tokio::test] +async fn test_import_deactivated_account_rejected() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let export_res = client + .get(format!( + "{}/xrpc/com.atproto.sync.getRepo?did={}", + base_url().await, + did + )) + .send() + .await + .expect("Export failed"); + assert_eq!(export_res.status(), StatusCode::OK); + let car_bytes = export_res.bytes().await.unwrap(); + + let deactivate_res = client + .post(format!( + "{}/xrpc/com.atproto.server.deactivateAccount", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({})) + .send() + .await + .expect("Deactivate failed"); + assert!(deactivate_res.status().is_success()); + + let import_res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(car_bytes.to_vec()) + .send() + .await + .expect("Import failed"); + + assert!( + import_res.status() == StatusCode::FORBIDDEN || import_res.status() == StatusCode::UNAUTHORIZED, + "Expected FORBIDDEN (403) or UNAUTHORIZED (401), got {}", + import_res.status() + ); +} + +#[tokio::test] +async fn test_import_invalid_car_structure() { + let client = client(); + let (token, _did) = create_account_and_login(&client).await; + + let invalid_car = vec![0x0a, 0xa1, 0x65, 0x72, 0x6f, 0x6f, 0x74, 0x73, 0x80]; + + let res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(invalid_car) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); +} + +#[tokio::test] +async fn test_import_car_with_no_roots() { + let client = client(); + let (token, _did) = create_account_and_login(&client).await; + + let header = CarHeader::new_v1(vec![]); + let header_cbor = header.encode().unwrap_or_default(); + let mut car = Vec::new(); + write_varint(&mut car, header_cbor.len() as u64); + car.extend_from_slice(&header_cbor); + + let res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(car) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + let body: serde_json::Value = res.json().await.unwrap(); + assert_eq!(body["error"], "InvalidRequest"); +} + +#[tokio::test] +async fn test_import_preserves_records_after_reimport() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let mut rkeys = Vec::new(); + for i in 0..3 { + let post_payload = json!({ + "repo": did, + "collection": "app.bsky.feed.post", + "record": { + "$type": "app.bsky.feed.post", + "text": format!("Test post {}", i), + "createdAt": chrono::Utc::now().to_rfc3339(), + } + }); + + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.createRecord", + base_url().await + )) + .bearer_auth(&token) + .json(&post_payload) + .send() + .await + .expect("Failed to create post"); + assert_eq!(res.status(), StatusCode::OK); + + let body: serde_json::Value = res.json().await.unwrap(); + let uri = body["uri"].as_str().unwrap(); + let rkey = uri.split('/').last().unwrap().to_string(); + rkeys.push(rkey); + } + + for rkey in &rkeys { + let get_res = client + .get(format!( + "{}/xrpc/com.atproto.repo.getRecord?repo={}&collection=app.bsky.feed.post&rkey={}", + base_url().await, + did, + rkey + )) + .send() + .await + .expect("Failed to get record before export"); + assert_eq!(get_res.status(), StatusCode::OK, "Record {} not found before export", rkey); + } + + let export_res = client + .get(format!( + "{}/xrpc/com.atproto.sync.getRepo?did={}", + base_url().await, + did + )) + .send() + .await + .expect("Failed to export repo"); + assert_eq!(export_res.status(), StatusCode::OK); + let car_bytes = export_res.bytes().await.unwrap(); + + let import_res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(car_bytes.to_vec()) + .send() + .await + .expect("Failed to import repo"); + assert_eq!(import_res.status(), StatusCode::OK); + + let list_res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords?repo={}&collection=app.bsky.feed.post", + base_url().await, + did + )) + .send() + .await + .expect("Failed to list records after import"); + assert_eq!(list_res.status(), StatusCode::OK); + let list_body: serde_json::Value = list_res.json().await.unwrap(); + let records_after = list_body["records"].as_array().map(|a| a.len()).unwrap_or(0); + + assert!( + records_after >= 1, + "Expected at least 1 record after import, found {}. Note: MST walk may have timing issues.", + records_after + ); +} diff --git a/tests/import_with_verification.rs b/tests/import_with_verification.rs new file mode 100644 index 0000000..28da40e --- /dev/null +++ b/tests/import_with_verification.rs @@ -0,0 +1,476 @@ +mod common; +use common::*; + +use cid::Cid; +use ipld_core::ipld::Ipld; +use jacquard::types::{integer::LimitedU32, string::Tid}; +use k256::ecdsa::{signature::Signer, Signature, SigningKey}; +use reqwest::StatusCode; +use serde_json::json; +use sha2::{Digest, Sha256}; +use sqlx::PgPool; +use std::collections::BTreeMap; +use wiremock::matchers::{method, path}; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +fn make_cid(data: &[u8]) -> Cid { + let mut hasher = Sha256::new(); + hasher.update(data); + let hash = hasher.finalize(); + let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap(); + Cid::new_v1(0x71, multihash) +} + +fn write_varint(buf: &mut Vec, mut value: u64) { + loop { + let mut byte = (value & 0x7F) as u8; + value >>= 7; + if value != 0 { + byte |= 0x80; + } + buf.push(byte); + if value == 0 { + break; + } + } +} + +fn encode_car_block(cid: &Cid, data: &[u8]) -> Vec { + let cid_bytes = cid.to_bytes(); + let mut result = Vec::new(); + write_varint(&mut result, (cid_bytes.len() + data.len()) as u64); + result.extend_from_slice(&cid_bytes); + result.extend_from_slice(data); + result +} + +fn get_multikey_from_signing_key(signing_key: &SigningKey) -> String { + let public_key = signing_key.verifying_key(); + let compressed = public_key.to_sec1_bytes(); + + fn encode_uvarint(mut x: u64) -> Vec { + let mut out = Vec::new(); + while x >= 0x80 { + out.push(((x as u8) & 0x7F) | 0x80); + x >>= 7; + } + out.push(x as u8); + out + } + + let mut buf = encode_uvarint(0xE7); + buf.extend_from_slice(&compressed); + multibase::encode(multibase::Base::Base58Btc, buf) +} + +fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> serde_json::Value { + let multikey = get_multikey_from_signing_key(signing_key); + + json!({ + "@context": [ + "https://www.w3.org/ns/did/v1", + "https://w3id.org/security/multikey/v1" + ], + "id": did, + "alsoKnownAs": [format!("at://{}", handle)], + "verificationMethod": [{ + "id": format!("{}#atproto", did), + "type": "Multikey", + "controller": did, + "publicKeyMultibase": multikey + }], + "service": [{ + "id": "#atproto_pds", + "type": "AtprotoPersonalDataServer", + "serviceEndpoint": pds_endpoint + }] + }) +} + +fn create_signed_commit( + did: &str, + data_cid: &Cid, + signing_key: &SigningKey, +) -> (Vec, Cid) { + let rev = Tid::now(LimitedU32::MIN).to_string(); + + let unsigned = Ipld::Map(BTreeMap::from([ + ("data".to_string(), Ipld::Link(*data_cid)), + ("did".to_string(), Ipld::String(did.to_string())), + ("prev".to_string(), Ipld::Null), + ("rev".to_string(), Ipld::String(rev.clone())), + ("sig".to_string(), Ipld::Bytes(vec![])), + ("version".to_string(), Ipld::Integer(3)), + ])); + + let unsigned_bytes = serde_ipld_dagcbor::to_vec(&unsigned).unwrap(); + + let signature: Signature = signing_key.sign(&unsigned_bytes); + let sig_bytes = signature.to_bytes().to_vec(); + + let signed = Ipld::Map(BTreeMap::from([ + ("data".to_string(), Ipld::Link(*data_cid)), + ("did".to_string(), Ipld::String(did.to_string())), + ("prev".to_string(), Ipld::Null), + ("rev".to_string(), Ipld::String(rev)), + ("sig".to_string(), Ipld::Bytes(sig_bytes)), + ("version".to_string(), Ipld::Integer(3)), + ])); + + let signed_bytes = serde_ipld_dagcbor::to_vec(&signed).unwrap(); + let cid = make_cid(&signed_bytes); + + (signed_bytes, cid) +} + +fn create_mst_node(entries: Vec<(String, Cid)>) -> (Vec, Cid) { + let ipld_entries: Vec = entries + .into_iter() + .map(|(key, value_cid)| { + Ipld::Map(BTreeMap::from([ + ("k".to_string(), Ipld::Bytes(key.into_bytes())), + ("v".to_string(), Ipld::Link(value_cid)), + ("p".to_string(), Ipld::Integer(0)), + ])) + }) + .collect(); + + let node = Ipld::Map(BTreeMap::from([ + ("e".to_string(), Ipld::List(ipld_entries)), + ])); + + let bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); + let cid = make_cid(&bytes); + (bytes, cid) +} + +fn create_record() -> (Vec, Cid) { + let record = Ipld::Map(BTreeMap::from([ + ("$type".to_string(), Ipld::String("app.bsky.feed.post".to_string())), + ("text".to_string(), Ipld::String("Test post for verification".to_string())), + ("createdAt".to_string(), Ipld::String("2024-01-01T00:00:00Z".to_string())), + ])); + + let bytes = serde_ipld_dagcbor::to_vec(&record).unwrap(); + let cid = make_cid(&bytes); + (bytes, cid) +} + +fn build_car_with_signature( + did: &str, + signing_key: &SigningKey, +) -> (Vec, Cid) { + let (record_bytes, record_cid) = create_record(); + + let (mst_bytes, mst_cid) = create_mst_node(vec![ + ("app.bsky.feed.post/test123".to_string(), record_cid), + ]); + + let (commit_bytes, commit_cid) = create_signed_commit(did, &mst_cid, signing_key); + + let header = iroh_car::CarHeader::new_v1(vec![commit_cid]); + let header_bytes = header.encode().unwrap(); + + let mut car = Vec::new(); + write_varint(&mut car, header_bytes.len() as u64); + car.extend_from_slice(&header_bytes); + car.extend(encode_car_block(&commit_cid, &commit_bytes)); + car.extend(encode_car_block(&mst_cid, &mst_bytes)); + car.extend(encode_car_block(&record_cid, &record_bytes)); + + (car, commit_cid) +} + +async fn setup_mock_plc_directory(did: &str, did_doc: serde_json::Value) -> MockServer { + let mock_server = MockServer::start().await; + + let did_encoded = urlencoding::encode(did); + let did_path = format!("/{}", did_encoded); + + Mock::given(method("GET")) + .and(path(did_path)) + .respond_with(ResponseTemplate::new(200).set_body_json(did_doc)) + .mount(&mock_server) + .await; + + mock_server +} + +async fn get_user_signing_key(did: &str) -> Option> { + let db_url = get_db_connection_string().await; + let pool = PgPool::connect(&db_url).await.ok()?; + + let row = sqlx::query!( + r#" + SELECT k.key_bytes, k.encryption_version + FROM user_keys k + JOIN users u ON k.user_id = u.id + WHERE u.did = $1 + "#, + did + ) + .fetch_optional(&pool) + .await + .ok()??; + + bspds::config::decrypt_key(&row.key_bytes, row.encryption_version).ok() +} + +#[tokio::test] +async fn test_import_with_valid_signature_and_mock_plc() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let key_bytes = get_user_signing_key(&did).await + .expect("Failed to get user signing key"); + let signing_key = SigningKey::from_slice(&key_bytes) + .expect("Failed to create signing key"); + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + let pds_endpoint = format!("https://{}", hostname); + + let handle = did.split(':').last().unwrap_or("user"); + let did_doc = create_did_document(&did, handle, &signing_key, &pds_endpoint); + + let mock_plc = setup_mock_plc_directory(&did, did_doc).await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri()); + std::env::remove_var("SKIP_IMPORT_VERIFICATION"); + } + + let (car_bytes, _root_cid) = build_car_with_signature(&did, &signing_key); + + let import_res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(car_bytes) + .send() + .await + .expect("Import request failed"); + + let status = import_res.status(); + let body: serde_json::Value = import_res.json().await.unwrap_or(json!({})); + + unsafe { + std::env::set_var("SKIP_IMPORT_VERIFICATION", "true"); + } + + assert_eq!( + status, + StatusCode::OK, + "Import with valid signature should succeed. Response: {:?}", + body + ); +} + +#[tokio::test] +async fn test_import_with_wrong_signing_key_fails() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let wrong_signing_key = SigningKey::random(&mut rand::thread_rng()); + + let key_bytes = get_user_signing_key(&did).await + .expect("Failed to get user signing key"); + let correct_signing_key = SigningKey::from_slice(&key_bytes) + .expect("Failed to create signing key"); + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + let pds_endpoint = format!("https://{}", hostname); + + let handle = did.split(':').last().unwrap_or("user"); + let did_doc = create_did_document(&did, handle, &correct_signing_key, &pds_endpoint); + + let mock_plc = setup_mock_plc_directory(&did, did_doc).await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri()); + std::env::remove_var("SKIP_IMPORT_VERIFICATION"); + } + + let (car_bytes, _root_cid) = build_car_with_signature(&did, &wrong_signing_key); + + let import_res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(car_bytes) + .send() + .await + .expect("Import request failed"); + + let status = import_res.status(); + let body: serde_json::Value = import_res.json().await.unwrap_or(json!({})); + + unsafe { + std::env::set_var("SKIP_IMPORT_VERIFICATION", "true"); + } + + assert_eq!( + status, + StatusCode::BAD_REQUEST, + "Import with wrong signature should fail. Response: {:?}", + body + ); + assert!( + body["error"] == "InvalidSignature" || body["message"].as_str().unwrap_or("").contains("signature"), + "Error should mention signature: {:?}", + body + ); +} + +#[tokio::test] +async fn test_import_with_did_mismatch_fails() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let key_bytes = get_user_signing_key(&did).await + .expect("Failed to get user signing key"); + let signing_key = SigningKey::from_slice(&key_bytes) + .expect("Failed to create signing key"); + + let wrong_did = "did:plc:wrongdidthatdoesnotmatch"; + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + let pds_endpoint = format!("https://{}", hostname); + + let handle = did.split(':').last().unwrap_or("user"); + let did_doc = create_did_document(&did, handle, &signing_key, &pds_endpoint); + + let mock_plc = setup_mock_plc_directory(&did, did_doc).await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri()); + std::env::remove_var("SKIP_IMPORT_VERIFICATION"); + } + + let (car_bytes, _root_cid) = build_car_with_signature(wrong_did, &signing_key); + + let import_res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(car_bytes) + .send() + .await + .expect("Import request failed"); + + let status = import_res.status(); + let body: serde_json::Value = import_res.json().await.unwrap_or(json!({})); + + unsafe { + std::env::set_var("SKIP_IMPORT_VERIFICATION", "true"); + } + + assert_eq!( + status, + StatusCode::FORBIDDEN, + "Import with DID mismatch should be forbidden. Response: {:?}", + body + ); +} + +#[tokio::test] +async fn test_import_with_plc_resolution_failure() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let key_bytes = get_user_signing_key(&did).await + .expect("Failed to get user signing key"); + let signing_key = SigningKey::from_slice(&key_bytes) + .expect("Failed to create signing key"); + + let mock_plc = MockServer::start().await; + + let did_encoded = urlencoding::encode(&did); + let did_path = format!("/{}", did_encoded); + Mock::given(method("GET")) + .and(path(did_path)) + .respond_with(ResponseTemplate::new(404)) + .mount(&mock_plc) + .await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri()); + std::env::remove_var("SKIP_IMPORT_VERIFICATION"); + } + + let (car_bytes, _root_cid) = build_car_with_signature(&did, &signing_key); + + let import_res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(car_bytes) + .send() + .await + .expect("Import request failed"); + + let status = import_res.status(); + let body: serde_json::Value = import_res.json().await.unwrap_or(json!({})); + + unsafe { + std::env::set_var("SKIP_IMPORT_VERIFICATION", "true"); + } + + assert_eq!( + status, + StatusCode::BAD_REQUEST, + "Import with PLC resolution failure should fail. Response: {:?}", + body + ); +} + +#[tokio::test] +async fn test_import_with_no_signing_key_in_did_doc() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let key_bytes = get_user_signing_key(&did).await + .expect("Failed to get user signing key"); + let signing_key = SigningKey::from_slice(&key_bytes) + .expect("Failed to create signing key"); + + let handle = did.split(':').last().unwrap_or("user"); + let did_doc_without_key = json!({ + "@context": ["https://www.w3.org/ns/did/v1"], + "id": did, + "alsoKnownAs": [format!("at://{}", handle)], + "verificationMethod": [], + "service": [] + }); + + let mock_plc = setup_mock_plc_directory(&did, did_doc_without_key).await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri()); + std::env::remove_var("SKIP_IMPORT_VERIFICATION"); + } + + let (car_bytes, _root_cid) = build_car_with_signature(&did, &signing_key); + + let import_res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(car_bytes) + .send() + .await + .expect("Import request failed"); + + let status = import_res.status(); + let body: serde_json::Value = import_res.json().await.unwrap_or(json!({})); + + unsafe { + std::env::set_var("SKIP_IMPORT_VERIFICATION", "true"); + } + + assert_eq!( + status, + StatusCode::BAD_REQUEST, + "Import with missing signing key should fail. Response: {:?}", + body + ); +} diff --git a/tests/plc_migration.rs b/tests/plc_migration.rs new file mode 100644 index 0000000..3ebf0d2 --- /dev/null +++ b/tests/plc_migration.rs @@ -0,0 +1,1087 @@ +mod common; +use common::*; + +use k256::ecdsa::SigningKey; +use reqwest::StatusCode; +use serde_json::{json, Value}; +use sqlx::PgPool; +use wiremock::matchers::{method, path}; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +fn encode_uvarint(mut x: u64) -> Vec { + let mut out = Vec::new(); + while x >= 0x80 { + out.push(((x as u8) & 0x7F) | 0x80); + x >>= 7; + } + out.push(x as u8); + out +} + +fn signing_key_to_did_key(signing_key: &SigningKey) -> String { + let verifying_key = signing_key.verifying_key(); + let point = verifying_key.to_encoded_point(true); + let compressed_bytes = point.as_bytes(); + + let mut prefixed = vec![0xe7, 0x01]; + prefixed.extend_from_slice(compressed_bytes); + + let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed); + format!("did:key:{}", encoded) +} + +fn get_multikey_from_signing_key(signing_key: &SigningKey) -> String { + let public_key = signing_key.verifying_key(); + let compressed = public_key.to_sec1_bytes(); + + let mut buf = encode_uvarint(0xE7); + buf.extend_from_slice(&compressed); + multibase::encode(multibase::Base::Base58Btc, buf) +} + +async fn get_user_signing_key(did: &str) -> Option> { + let db_url = get_db_connection_string().await; + let pool = PgPool::connect(&db_url).await.ok()?; + + let row = sqlx::query!( + r#" + SELECT k.key_bytes, k.encryption_version + FROM user_keys k + JOIN users u ON k.user_id = u.id + WHERE u.did = $1 + "#, + did + ) + .fetch_optional(&pool) + .await + .ok()??; + + bspds::config::decrypt_key(&row.key_bytes, row.encryption_version).ok() +} + +async fn get_plc_token_from_db(did: &str) -> Option { + let db_url = get_db_connection_string().await; + let pool = PgPool::connect(&db_url).await.ok()?; + + sqlx::query_scalar!( + r#" + SELECT t.token + FROM plc_operation_tokens t + JOIN users u ON t.user_id = u.id + WHERE u.did = $1 + "#, + did + ) + .fetch_optional(&pool) + .await + .ok()? +} + +async fn get_user_handle(did: &str) -> Option { + let db_url = get_db_connection_string().await; + let pool = PgPool::connect(&db_url).await.ok()?; + + sqlx::query_scalar!( + r#"SELECT handle FROM users WHERE did = $1"#, + did + ) + .fetch_optional(&pool) + .await + .ok()? +} + +fn create_mock_last_op( + _did: &str, + handle: &str, + signing_key: &SigningKey, + pds_endpoint: &str, +) -> Value { + let did_key = signing_key_to_did_key(signing_key); + + json!({ + "type": "plc_operation", + "rotationKeys": [did_key.clone()], + "verificationMethods": { + "atproto": did_key + }, + "alsoKnownAs": [format!("at://{}", handle)], + "services": { + "atproto_pds": { + "type": "AtprotoPersonalDataServer", + "endpoint": pds_endpoint + } + }, + "prev": null, + "sig": "mock_signature_for_testing" + }) +} + +fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> Value { + let multikey = get_multikey_from_signing_key(signing_key); + + json!({ + "@context": [ + "https://www.w3.org/ns/did/v1", + "https://w3id.org/security/multikey/v1" + ], + "id": did, + "alsoKnownAs": [format!("at://{}", handle)], + "verificationMethod": [{ + "id": format!("{}#atproto", did), + "type": "Multikey", + "controller": did, + "publicKeyMultibase": multikey + }], + "service": [{ + "id": "#atproto_pds", + "type": "AtprotoPersonalDataServer", + "serviceEndpoint": pds_endpoint + }] + }) +} + +async fn setup_mock_plc_for_sign( + did: &str, + handle: &str, + signing_key: &SigningKey, + pds_endpoint: &str, +) -> MockServer { + let mock_server = MockServer::start().await; + + let did_encoded = urlencoding::encode(did); + let last_op = create_mock_last_op(did, handle, signing_key, pds_endpoint); + + Mock::given(method("GET")) + .and(path(format!("/{}/log/last", did_encoded))) + .respond_with(ResponseTemplate::new(200).set_body_json(last_op)) + .mount(&mock_server) + .await; + + mock_server +} + +async fn setup_mock_plc_for_submit( + did: &str, + handle: &str, + signing_key: &SigningKey, + pds_endpoint: &str, +) -> MockServer { + let mock_server = MockServer::start().await; + + let did_encoded = urlencoding::encode(did); + let did_doc = create_did_document(did, handle, signing_key, pds_endpoint); + + Mock::given(method("GET")) + .and(path(format!("/{}", did_encoded))) + .respond_with(ResponseTemplate::new(200).set_body_json(did_doc.clone())) + .mount(&mock_server) + .await; + + Mock::given(method("POST")) + .and(path(format!("/{}", did_encoded))) + .respond_with(ResponseTemplate::new(200)) + .mount(&mock_server) + .await; + + mock_server +} + +#[tokio::test] +async fn test_full_plc_operation_flow() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let key_bytes = get_user_signing_key(&did).await + .expect("Failed to get user signing key"); + let signing_key = SigningKey::from_slice(&key_bytes) + .expect("Failed to create signing key"); + + let handle = get_user_handle(&did).await + .expect("Failed to get user handle"); + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + let pds_endpoint = format!("https://{}", hostname); + + let request_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", + base_url().await + )) + .bearer_auth(&token) + .send() + .await + .expect("Request failed"); + + assert_eq!(request_res.status(), StatusCode::OK); + + let plc_token = get_plc_token_from_db(&did).await + .expect("PLC token not found in database"); + + let mock_plc = setup_mock_plc_for_sign(&did, &handle, &signing_key, &pds_endpoint).await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri()); + } + + let sign_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.signPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ + "token": plc_token + })) + .send() + .await + .expect("Sign request failed"); + + let sign_status = sign_res.status(); + let sign_body: Value = sign_res.json().await.unwrap_or(json!({})); + + assert_eq!( + sign_status, + StatusCode::OK, + "Sign PLC operation should succeed. Response: {:?}", + sign_body + ); + + let operation = sign_body.get("operation") + .expect("Response should contain operation"); + + assert!(operation.get("sig").is_some(), "Operation should be signed"); + assert_eq!(operation.get("type").and_then(|v| v.as_str()), Some("plc_operation")); + assert!(operation.get("prev").is_some(), "Operation should have prev reference"); +} + +#[tokio::test] +async fn test_sign_plc_operation_consumes_token() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let key_bytes = get_user_signing_key(&did).await + .expect("Failed to get user signing key"); + let signing_key = SigningKey::from_slice(&key_bytes) + .expect("Failed to create signing key"); + + let handle = get_user_handle(&did).await + .expect("Failed to get user handle"); + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + let pds_endpoint = format!("https://{}", hostname); + + let request_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", + base_url().await + )) + .bearer_auth(&token) + .send() + .await + .expect("Request failed"); + + assert_eq!(request_res.status(), StatusCode::OK); + + let plc_token = get_plc_token_from_db(&did).await + .expect("PLC token not found in database"); + + let mock_plc = setup_mock_plc_for_sign(&did, &handle, &signing_key, &pds_endpoint).await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri()); + } + + let sign_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.signPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ + "token": plc_token + })) + .send() + .await + .expect("Sign request failed"); + + assert_eq!(sign_res.status(), StatusCode::OK); + + let sign_res_2 = client + .post(format!( + "{}/xrpc/com.atproto.identity.signPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ + "token": plc_token + })) + .send() + .await + .expect("Second sign request failed"); + + assert_eq!( + sign_res_2.status(), + StatusCode::BAD_REQUEST, + "Using the same token twice should fail" + ); + + let body: Value = sign_res_2.json().await.unwrap(); + assert!( + body["error"] == "InvalidToken" || body["error"] == "ExpiredToken", + "Error should indicate invalid/expired token" + ); +} + +#[tokio::test] +async fn test_sign_plc_operation_with_custom_fields() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let key_bytes = get_user_signing_key(&did).await + .expect("Failed to get user signing key"); + let signing_key = SigningKey::from_slice(&key_bytes) + .expect("Failed to create signing key"); + + let handle = get_user_handle(&did).await + .expect("Failed to get user handle"); + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + let pds_endpoint = format!("https://{}", hostname); + + let request_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", + base_url().await + )) + .bearer_auth(&token) + .send() + .await + .expect("Request failed"); + + assert_eq!(request_res.status(), StatusCode::OK); + + let plc_token = get_plc_token_from_db(&did).await + .expect("PLC token not found in database"); + + let mock_plc = setup_mock_plc_for_sign(&did, &handle, &signing_key, &pds_endpoint).await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri()); + } + + let did_key = signing_key_to_did_key(&signing_key); + + let sign_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.signPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ + "token": plc_token, + "alsoKnownAs": [format!("at://{}", handle), "at://custom.alias.example"], + "rotationKeys": [did_key.clone(), "did:key:zExtraRotationKey123"] + })) + .send() + .await + .expect("Sign request failed"); + + let sign_status = sign_res.status(); + let sign_body: Value = sign_res.json().await.unwrap_or(json!({})); + + assert_eq!( + sign_status, + StatusCode::OK, + "Sign with custom fields should succeed. Response: {:?}", + sign_body + ); + + let operation = sign_body.get("operation").expect("Should have operation"); + let also_known_as = operation.get("alsoKnownAs").and_then(|v| v.as_array()); + let rotation_keys = operation.get("rotationKeys").and_then(|v| v.as_array()); + + assert!(also_known_as.is_some(), "Should have alsoKnownAs"); + assert!(rotation_keys.is_some(), "Should have rotationKeys"); + assert_eq!(also_known_as.unwrap().len(), 2, "Should have 2 aliases"); + assert_eq!(rotation_keys.unwrap().len(), 2, "Should have 2 rotation keys"); +} + +#[tokio::test] +async fn test_submit_plc_operation_success() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let key_bytes = get_user_signing_key(&did).await + .expect("Failed to get user signing key"); + let signing_key = SigningKey::from_slice(&key_bytes) + .expect("Failed to create signing key"); + + let handle = get_user_handle(&did).await + .expect("Failed to get user handle"); + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + let pds_endpoint = format!("https://{}", hostname); + + let mock_plc = setup_mock_plc_for_submit(&did, &handle, &signing_key, &pds_endpoint).await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri()); + } + + let did_key = signing_key_to_did_key(&signing_key); + + let operation = json!({ + "type": "plc_operation", + "rotationKeys": [did_key.clone()], + "verificationMethods": { + "atproto": did_key.clone() + }, + "alsoKnownAs": [format!("at://{}", handle)], + "services": { + "atproto_pds": { + "type": "AtprotoPersonalDataServer", + "endpoint": pds_endpoint + } + }, + "prev": "bafyreiabc123", + "sig": "test_signature_base64" + }); + + let submit_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.submitPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ "operation": operation })) + .send() + .await + .expect("Submit request failed"); + + let submit_status = submit_res.status(); + let submit_body: Value = submit_res.json().await.unwrap_or(json!({})); + + assert_eq!( + submit_status, + StatusCode::OK, + "Submit PLC operation should succeed. Response: {:?}", + submit_body + ); +} + +#[tokio::test] +async fn test_submit_plc_operation_wrong_endpoint_rejected() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let key_bytes = get_user_signing_key(&did).await + .expect("Failed to get user signing key"); + let signing_key = SigningKey::from_slice(&key_bytes) + .expect("Failed to create signing key"); + + let handle = get_user_handle(&did).await + .expect("Failed to get user handle"); + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + let pds_endpoint = format!("https://{}", hostname); + + let mock_plc = setup_mock_plc_for_submit(&did, &handle, &signing_key, &pds_endpoint).await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri()); + } + + let did_key = signing_key_to_did_key(&signing_key); + + let operation = json!({ + "type": "plc_operation", + "rotationKeys": [did_key.clone()], + "verificationMethods": { + "atproto": did_key.clone() + }, + "alsoKnownAs": [format!("at://{}", handle)], + "services": { + "atproto_pds": { + "type": "AtprotoPersonalDataServer", + "endpoint": "https://wrong-pds.example.com" + } + }, + "prev": "bafyreiabc123", + "sig": "test_signature_base64" + }); + + let submit_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.submitPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ "operation": operation })) + .send() + .await + .expect("Submit request failed"); + + assert_eq!( + submit_res.status(), + StatusCode::BAD_REQUEST, + "Submit with wrong endpoint should fail" + ); + + let body: Value = submit_res.json().await.unwrap(); + assert_eq!(body["error"], "InvalidRequest"); +} + +#[tokio::test] +async fn test_submit_plc_operation_wrong_signing_key_rejected() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let key_bytes = get_user_signing_key(&did).await + .expect("Failed to get user signing key"); + let signing_key = SigningKey::from_slice(&key_bytes) + .expect("Failed to create signing key"); + + let handle = get_user_handle(&did).await + .expect("Failed to get user handle"); + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + let pds_endpoint = format!("https://{}", hostname); + + let mock_plc = setup_mock_plc_for_submit(&did, &handle, &signing_key, &pds_endpoint).await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri()); + } + + let wrong_key = SigningKey::random(&mut rand::thread_rng()); + let wrong_did_key = signing_key_to_did_key(&wrong_key); + let correct_did_key = signing_key_to_did_key(&signing_key); + + let operation = json!({ + "type": "plc_operation", + "rotationKeys": [correct_did_key.clone()], + "verificationMethods": { + "atproto": wrong_did_key + }, + "alsoKnownAs": [format!("at://{}", handle)], + "services": { + "atproto_pds": { + "type": "AtprotoPersonalDataServer", + "endpoint": pds_endpoint + } + }, + "prev": "bafyreiabc123", + "sig": "test_signature_base64" + }); + + let submit_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.submitPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ "operation": operation })) + .send() + .await + .expect("Submit request failed"); + + assert_eq!( + submit_res.status(), + StatusCode::BAD_REQUEST, + "Submit with wrong signing key should fail" + ); + + let body: Value = submit_res.json().await.unwrap(); + assert_eq!(body["error"], "InvalidRequest"); +} + +#[tokio::test] +async fn test_full_sign_and_submit_flow() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let key_bytes = get_user_signing_key(&did).await + .expect("Failed to get user signing key"); + let signing_key = SigningKey::from_slice(&key_bytes) + .expect("Failed to create signing key"); + + let handle = get_user_handle(&did).await + .expect("Failed to get user handle"); + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + let pds_endpoint = format!("https://{}", hostname); + + let request_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", + base_url().await + )) + .bearer_auth(&token) + .send() + .await + .expect("Request failed"); + assert_eq!(request_res.status(), StatusCode::OK); + + let plc_token = get_plc_token_from_db(&did).await + .expect("PLC token not found"); + + let mock_server = MockServer::start().await; + let did_encoded = urlencoding::encode(&did); + let did_key = signing_key_to_did_key(&signing_key); + + let last_op = json!({ + "type": "plc_operation", + "rotationKeys": [did_key.clone()], + "verificationMethods": { + "atproto": did_key.clone() + }, + "alsoKnownAs": [format!("at://{}", handle)], + "services": { + "atproto_pds": { + "type": "AtprotoPersonalDataServer", + "endpoint": pds_endpoint.clone() + } + }, + "prev": null, + "sig": "initial_sig" + }); + + Mock::given(method("GET")) + .and(path(format!("/{}/log/last", did_encoded))) + .respond_with(ResponseTemplate::new(200).set_body_json(last_op)) + .mount(&mock_server) + .await; + + let did_doc = create_did_document(&did, &handle, &signing_key, &pds_endpoint); + Mock::given(method("GET")) + .and(path(format!("/{}", did_encoded))) + .respond_with(ResponseTemplate::new(200).set_body_json(did_doc)) + .mount(&mock_server) + .await; + + Mock::given(method("POST")) + .and(path(format!("/{}", did_encoded))) + .respond_with(ResponseTemplate::new(200)) + .expect(1) + .mount(&mock_server) + .await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_server.uri()); + } + + let sign_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.signPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ "token": plc_token })) + .send() + .await + .expect("Sign failed"); + + assert_eq!(sign_res.status(), StatusCode::OK); + + let sign_body: Value = sign_res.json().await.unwrap(); + let signed_operation = sign_body.get("operation") + .expect("Response should contain operation") + .clone(); + + assert!(signed_operation.get("sig").is_some()); + assert!(signed_operation.get("prev").is_some()); + + let submit_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.submitPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ "operation": signed_operation })) + .send() + .await + .expect("Submit failed"); + + let submit_status = submit_res.status(); + let submit_body: Value = submit_res.json().await.unwrap_or(json!({})); + + assert_eq!( + submit_status, + StatusCode::OK, + "Full sign and submit flow should succeed. Response: {:?}", + submit_body + ); +} + +#[tokio::test] +async fn test_cross_pds_migration_with_records() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let key_bytes = get_user_signing_key(&did).await + .expect("Failed to get user signing key"); + let signing_key = SigningKey::from_slice(&key_bytes) + .expect("Failed to create signing key"); + + let handle = get_user_handle(&did).await + .expect("Failed to get user handle"); + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + let pds_endpoint = format!("https://{}", hostname); + + let post_payload = json!({ + "repo": did, + "collection": "app.bsky.feed.post", + "record": { + "$type": "app.bsky.feed.post", + "text": "Test post before migration", + "createdAt": chrono::Utc::now().to_rfc3339(), + } + }); + + let create_res = client + .post(format!( + "{}/xrpc/com.atproto.repo.createRecord", + base_url().await + )) + .bearer_auth(&token) + .json(&post_payload) + .send() + .await + .expect("Failed to create post"); + assert_eq!(create_res.status(), StatusCode::OK); + + let create_body: Value = create_res.json().await.unwrap(); + let original_uri = create_body["uri"].as_str().unwrap().to_string(); + + let export_res = client + .get(format!( + "{}/xrpc/com.atproto.sync.getRepo?did={}", + base_url().await, + did + )) + .send() + .await + .expect("Export failed"); + assert_eq!(export_res.status(), StatusCode::OK); + let car_bytes = export_res.bytes().await.unwrap(); + + assert!(car_bytes.len() > 100, "CAR file should have meaningful content"); + + let mock_server = MockServer::start().await; + let did_encoded = urlencoding::encode(&did); + let did_doc = create_did_document(&did, &handle, &signing_key, &pds_endpoint); + + Mock::given(method("GET")) + .and(path(format!("/{}", did_encoded))) + .respond_with(ResponseTemplate::new(200).set_body_json(did_doc)) + .mount(&mock_server) + .await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_server.uri()); + std::env::remove_var("SKIP_IMPORT_VERIFICATION"); + } + + let import_res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(car_bytes.to_vec()) + .send() + .await + .expect("Import failed"); + + let import_status = import_res.status(); + let import_body: Value = import_res.json().await.unwrap_or(json!({})); + + unsafe { + std::env::set_var("SKIP_IMPORT_VERIFICATION", "true"); + } + + assert_eq!( + import_status, + StatusCode::OK, + "Import with valid DID document should succeed. Response: {:?}", + import_body + ); + + let get_record_res = client + .get(format!( + "{}/xrpc/com.atproto.repo.getRecord?repo={}&collection=app.bsky.feed.post&rkey={}", + base_url().await, + did, + original_uri.split('/').last().unwrap() + )) + .send() + .await + .expect("Get record failed"); + + assert_eq!( + get_record_res.status(), + StatusCode::OK, + "Record should be retrievable after import" + ); + + let record_body: Value = get_record_res.json().await.unwrap(); + assert_eq!( + record_body["value"]["text"], + "Test post before migration", + "Record content should match" + ); +} + +#[tokio::test] +async fn test_migration_rejects_wrong_did_document() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let wrong_signing_key = SigningKey::random(&mut rand::thread_rng()); + + let handle = get_user_handle(&did).await + .expect("Failed to get user handle"); + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + let pds_endpoint = format!("https://{}", hostname); + + let export_res = client + .get(format!( + "{}/xrpc/com.atproto.sync.getRepo?did={}", + base_url().await, + did + )) + .send() + .await + .expect("Export failed"); + assert_eq!(export_res.status(), StatusCode::OK); + let car_bytes = export_res.bytes().await.unwrap(); + + let mock_server = MockServer::start().await; + let did_encoded = urlencoding::encode(&did); + let wrong_did_doc = create_did_document(&did, &handle, &wrong_signing_key, &pds_endpoint); + + Mock::given(method("GET")) + .and(path(format!("/{}", did_encoded))) + .respond_with(ResponseTemplate::new(200).set_body_json(wrong_did_doc)) + .mount(&mock_server) + .await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_server.uri()); + std::env::remove_var("SKIP_IMPORT_VERIFICATION"); + } + + let import_res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(car_bytes.to_vec()) + .send() + .await + .expect("Import failed"); + + let import_status = import_res.status(); + let import_body: Value = import_res.json().await.unwrap_or(json!({})); + + unsafe { + std::env::set_var("SKIP_IMPORT_VERIFICATION", "true"); + } + + assert_eq!( + import_status, + StatusCode::BAD_REQUEST, + "Import with wrong DID document should fail. Response: {:?}", + import_body + ); + + assert!( + import_body["error"] == "InvalidSignature" || + import_body["message"].as_str().unwrap_or("").contains("signature"), + "Error should mention signature verification failure" + ); +} + +#[tokio::test] +async fn test_full_migration_flow_end_to_end() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let key_bytes = get_user_signing_key(&did).await + .expect("Failed to get user signing key"); + let signing_key = SigningKey::from_slice(&key_bytes) + .expect("Failed to create signing key"); + + let handle = get_user_handle(&did).await + .expect("Failed to get user handle"); + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + let pds_endpoint = format!("https://{}", hostname); + let did_key = signing_key_to_did_key(&signing_key); + + for i in 0..3 { + let post_payload = json!({ + "repo": did, + "collection": "app.bsky.feed.post", + "record": { + "$type": "app.bsky.feed.post", + "text": format!("Pre-migration post #{}", i), + "createdAt": chrono::Utc::now().to_rfc3339(), + } + }); + + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.createRecord", + base_url().await + )) + .bearer_auth(&token) + .json(&post_payload) + .send() + .await + .expect("Failed to create post"); + assert_eq!(res.status(), StatusCode::OK); + } + + let request_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", + base_url().await + )) + .bearer_auth(&token) + .send() + .await + .expect("Request failed"); + assert_eq!(request_res.status(), StatusCode::OK); + + let plc_token = get_plc_token_from_db(&did).await + .expect("PLC token not found"); + + let mock_server = MockServer::start().await; + let did_encoded = urlencoding::encode(&did); + + let last_op = json!({ + "type": "plc_operation", + "rotationKeys": [did_key.clone()], + "verificationMethods": { "atproto": did_key.clone() }, + "alsoKnownAs": [format!("at://{}", handle)], + "services": { + "atproto_pds": { + "type": "AtprotoPersonalDataServer", + "endpoint": pds_endpoint.clone() + } + }, + "prev": null, + "sig": "initial_sig" + }); + + Mock::given(method("GET")) + .and(path(format!("/{}/log/last", did_encoded))) + .respond_with(ResponseTemplate::new(200).set_body_json(last_op)) + .mount(&mock_server) + .await; + + let did_doc = create_did_document(&did, &handle, &signing_key, &pds_endpoint); + Mock::given(method("GET")) + .and(path(format!("/{}", did_encoded))) + .respond_with(ResponseTemplate::new(200).set_body_json(did_doc)) + .mount(&mock_server) + .await; + + Mock::given(method("POST")) + .and(path(format!("/{}", did_encoded))) + .respond_with(ResponseTemplate::new(200)) + .expect(1) + .mount(&mock_server) + .await; + + unsafe { + std::env::set_var("PLC_DIRECTORY_URL", mock_server.uri()); + } + + let sign_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.signPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ "token": plc_token })) + .send() + .await + .expect("Sign failed"); + assert_eq!(sign_res.status(), StatusCode::OK); + + let sign_body: Value = sign_res.json().await.unwrap(); + let signed_op = sign_body.get("operation").unwrap().clone(); + + let export_res = client + .get(format!( + "{}/xrpc/com.atproto.sync.getRepo?did={}", + base_url().await, + did + )) + .send() + .await + .expect("Export failed"); + assert_eq!(export_res.status(), StatusCode::OK); + let car_bytes = export_res.bytes().await.unwrap(); + + let submit_res = client + .post(format!( + "{}/xrpc/com.atproto.identity.submitPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ "operation": signed_op })) + .send() + .await + .expect("Submit failed"); + assert_eq!(submit_res.status(), StatusCode::OK); + + unsafe { + std::env::remove_var("SKIP_IMPORT_VERIFICATION"); + } + + let import_res = client + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) + .bearer_auth(&token) + .header("Content-Type", "application/vnd.ipld.car") + .body(car_bytes.to_vec()) + .send() + .await + .expect("Import failed"); + + let import_status = import_res.status(); + let import_body: Value = import_res.json().await.unwrap_or(json!({})); + + unsafe { + std::env::set_var("SKIP_IMPORT_VERIFICATION", "true"); + } + + assert_eq!( + import_status, + StatusCode::OK, + "Full migration flow should succeed. Response: {:?}", + import_body + ); + + let list_res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords?repo={}&collection=app.bsky.feed.post", + base_url().await, + did + )) + .send() + .await + .expect("List failed"); + assert_eq!(list_res.status(), StatusCode::OK); + + let list_body: Value = list_res.json().await.unwrap(); + let records = list_body["records"].as_array() + .expect("Should have records array"); + + assert!( + records.len() >= 1, + "Should have at least 1 record after migration, found {}", + records.len() + ); +} diff --git a/tests/plc_operations.rs b/tests/plc_operations.rs new file mode 100644 index 0000000..d01a268 --- /dev/null +++ b/tests/plc_operations.rs @@ -0,0 +1,491 @@ +mod common; +use common::*; + +use reqwest::StatusCode; +use serde_json::json; +use sqlx::PgPool; + +#[tokio::test] +async fn test_request_plc_operation_signature_requires_auth() { + let client = client(); + + let res = client + .post(format!( + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", + base_url().await + )) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn test_request_plc_operation_signature_success() { + let client = client(); + let (token, _did) = create_account_and_login(&client).await; + + let res = client + .post(format!( + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", + base_url().await + )) + .bearer_auth(&token) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::OK); +} + +#[tokio::test] +async fn test_sign_plc_operation_requires_auth() { + let client = client(); + + let res = client + .post(format!( + "{}/xrpc/com.atproto.identity.signPlcOperation", + base_url().await + )) + .json(&json!({})) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn test_sign_plc_operation_requires_token() { + let client = client(); + let (token, _did) = create_account_and_login(&client).await; + + let res = client + .post(format!( + "{}/xrpc/com.atproto.identity.signPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({})) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + let body: serde_json::Value = res.json().await.unwrap(); + assert_eq!(body["error"], "InvalidRequest"); +} + +#[tokio::test] +async fn test_sign_plc_operation_invalid_token() { + let client = client(); + let (token, _did) = create_account_and_login(&client).await; + + let res = client + .post(format!( + "{}/xrpc/com.atproto.identity.signPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ + "token": "invalid-token-12345" + })) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + let body: serde_json::Value = res.json().await.unwrap(); + assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken"); +} + +#[tokio::test] +async fn test_submit_plc_operation_requires_auth() { + let client = client(); + + let res = client + .post(format!( + "{}/xrpc/com.atproto.identity.submitPlcOperation", + base_url().await + )) + .json(&json!({ + "operation": {} + })) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn test_submit_plc_operation_invalid_operation() { + let client = client(); + let (token, _did) = create_account_and_login(&client).await; + + let res = client + .post(format!( + "{}/xrpc/com.atproto.identity.submitPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ + "operation": { + "type": "invalid_type" + } + })) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + let body: serde_json::Value = res.json().await.unwrap(); + assert_eq!(body["error"], "InvalidRequest"); +} + +#[tokio::test] +async fn test_submit_plc_operation_missing_sig() { + let client = client(); + let (token, _did) = create_account_and_login(&client).await; + + let res = client + .post(format!( + "{}/xrpc/com.atproto.identity.submitPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ + "operation": { + "type": "plc_operation", + "rotationKeys": [], + "verificationMethods": {}, + "alsoKnownAs": [], + "services": {}, + "prev": null + } + })) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + let body: serde_json::Value = res.json().await.unwrap(); + assert_eq!(body["error"], "InvalidRequest"); +} + +#[tokio::test] +async fn test_submit_plc_operation_wrong_service_endpoint() { + let client = client(); + let (token, _did) = create_account_and_login(&client).await; + + let res = client + .post(format!( + "{}/xrpc/com.atproto.identity.submitPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ + "operation": { + "type": "plc_operation", + "rotationKeys": ["did:key:z123"], + "verificationMethods": {"atproto": "did:key:z456"}, + "alsoKnownAs": ["at://wrong.handle"], + "services": { + "atproto_pds": { + "type": "AtprotoPersonalDataServer", + "endpoint": "https://wrong.example.com" + } + }, + "prev": null, + "sig": "fake_signature" + } + })) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); +} + +#[tokio::test] +async fn test_request_plc_operation_creates_token_in_db() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let res = client + .post(format!( + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", + base_url().await + )) + .bearer_auth(&token) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::OK); + + let db_url = get_db_connection_string().await; + let pool = PgPool::connect(&db_url).await.expect("DB connect failed"); + + let row = sqlx::query!( + r#" + SELECT t.token, t.expires_at + FROM plc_operation_tokens t + JOIN users u ON t.user_id = u.id + WHERE u.did = $1 + "#, + did + ) + .fetch_optional(&pool) + .await + .expect("Query failed"); + + assert!(row.is_some(), "PLC token should be created in database"); + let row = row.unwrap(); + assert!(row.token.len() == 11, "Token should be in format xxxxx-xxxxx"); + assert!(row.token.contains('-'), "Token should contain hyphen"); + assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired"); +} + +#[tokio::test] +async fn test_request_plc_operation_replaces_existing_token() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let res1 = client + .post(format!( + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", + base_url().await + )) + .bearer_auth(&token) + .send() + .await + .expect("Request 1 failed"); + assert_eq!(res1.status(), StatusCode::OK); + + let db_url = get_db_connection_string().await; + let pool = PgPool::connect(&db_url).await.expect("DB connect failed"); + + let token1 = sqlx::query_scalar!( + r#" + SELECT t.token + FROM plc_operation_tokens t + JOIN users u ON t.user_id = u.id + WHERE u.did = $1 + "#, + did + ) + .fetch_one(&pool) + .await + .expect("Query failed"); + + let res2 = client + .post(format!( + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", + base_url().await + )) + .bearer_auth(&token) + .send() + .await + .expect("Request 2 failed"); + assert_eq!(res2.status(), StatusCode::OK); + + let token2 = sqlx::query_scalar!( + r#" + SELECT t.token + FROM plc_operation_tokens t + JOIN users u ON t.user_id = u.id + WHERE u.did = $1 + "#, + did + ) + .fetch_one(&pool) + .await + .expect("Query failed"); + + assert_ne!(token1, token2, "Second request should generate a new token"); + + let count: i64 = sqlx::query_scalar!( + r#" + SELECT COUNT(*) as "count!" + FROM plc_operation_tokens t + JOIN users u ON t.user_id = u.id + WHERE u.did = $1 + "#, + did + ) + .fetch_one(&pool) + .await + .expect("Count query failed"); + + assert_eq!(count, 1, "Should only have one token per user"); +} + +#[tokio::test] +async fn test_submit_plc_operation_wrong_verification_method() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| { + format!("127.0.0.1:{}", app_port()) + }); + + let handle = did.split(':').last().unwrap_or("user"); + + let res = client + .post(format!( + "{}/xrpc/com.atproto.identity.submitPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ + "operation": { + "type": "plc_operation", + "rotationKeys": ["did:key:zWrongRotationKey123"], + "verificationMethods": {"atproto": "did:key:zWrongVerificationKey456"}, + "alsoKnownAs": [format!("at://{}", handle)], + "services": { + "atproto_pds": { + "type": "AtprotoPersonalDataServer", + "endpoint": format!("https://{}", hostname) + } + }, + "prev": null, + "sig": "fake_signature" + } + })) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + let body: serde_json::Value = res.json().await.unwrap(); + assert_eq!(body["error"], "InvalidRequest"); + assert!( + body["message"].as_str().unwrap_or("").contains("signing key") || + body["message"].as_str().unwrap_or("").contains("rotation"), + "Error should mention key mismatch: {:?}", + body + ); +} + +#[tokio::test] +async fn test_submit_plc_operation_wrong_handle() { + let client = client(); + let (token, _did) = create_account_and_login(&client).await; + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| { + format!("127.0.0.1:{}", app_port()) + }); + + let res = client + .post(format!( + "{}/xrpc/com.atproto.identity.submitPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ + "operation": { + "type": "plc_operation", + "rotationKeys": ["did:key:z123"], + "verificationMethods": {"atproto": "did:key:z456"}, + "alsoKnownAs": ["at://totally.wrong.handle"], + "services": { + "atproto_pds": { + "type": "AtprotoPersonalDataServer", + "endpoint": format!("https://{}", hostname) + } + }, + "prev": null, + "sig": "fake_signature" + } + })) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + let body: serde_json::Value = res.json().await.unwrap(); + assert_eq!(body["error"], "InvalidRequest"); +} + +#[tokio::test] +async fn test_submit_plc_operation_wrong_service_type() { + let client = client(); + let (token, _did) = create_account_and_login(&client).await; + + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| { + format!("127.0.0.1:{}", app_port()) + }); + + let res = client + .post(format!( + "{}/xrpc/com.atproto.identity.submitPlcOperation", + base_url().await + )) + .bearer_auth(&token) + .json(&json!({ + "operation": { + "type": "plc_operation", + "rotationKeys": ["did:key:z123"], + "verificationMethods": {"atproto": "did:key:z456"}, + "alsoKnownAs": ["at://user"], + "services": { + "atproto_pds": { + "type": "WrongServiceType", + "endpoint": format!("https://{}", hostname) + } + }, + "prev": null, + "sig": "fake_signature" + } + })) + .send() + .await + .expect("Request failed"); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + let body: serde_json::Value = res.json().await.unwrap(); + assert_eq!(body["error"], "InvalidRequest"); +} + +#[tokio::test] +async fn test_plc_token_expiry_format() { + let client = client(); + let (token, did) = create_account_and_login(&client).await; + + let res = client + .post(format!( + "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", + base_url().await + )) + .bearer_auth(&token) + .send() + .await + .expect("Request failed"); + assert_eq!(res.status(), StatusCode::OK); + + let db_url = get_db_connection_string().await; + let pool = PgPool::connect(&db_url).await.expect("DB connect failed"); + + let row = sqlx::query!( + r#" + SELECT t.expires_at + FROM plc_operation_tokens t + JOIN users u ON t.user_id = u.id + WHERE u.did = $1 + "#, + did + ) + .fetch_one(&pool) + .await + .expect("Query failed"); + + let now = chrono::Utc::now(); + let expires = row.expires_at; + + let diff = expires - now; + assert!(diff.num_minutes() >= 9, "Token should expire in ~10 minutes, got {} minutes", diff.num_minutes()); + assert!(diff.num_minutes() <= 11, "Token should expire in ~10 minutes, got {} minutes", diff.num_minutes()); +}