First version of pds migration

This commit is contained in:
lewis
2025-12-11 17:10:19 +02:00
parent 2eb67eb688
commit ea7837fcec
40 changed files with 5332 additions and 37 deletions

View File

@@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "INSERT INTO repo_seq (did, event_type) VALUES ($1, 'identity')",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Text"
]
},
"nullable": []
},
"hash": "076cbf7f32c5f0103207a8e0e73dd5768681ff2520682edda8f2977dcae7cd62"
}

View File

@@ -0,0 +1,22 @@
{
"db_name": "PostgreSQL",
"query": "SELECT repo_root_cid FROM repos WHERE user_id = $1 FOR UPDATE NOWAIT",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "repo_root_cid",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
false
]
},
"hash": "1ed53dde97706d6da36a49d2a8d39f14da4a8dbfe54c9f1ee70c970adde80be8"
}

View File

@@ -35,7 +35,8 @@
"password_reset",
"email_update",
"account_deletion",
"admin_email"
"admin_email",
"plc_operation"
]
}
}

View File

@@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "DELETE FROM plc_operation_tokens WHERE id = $1",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": []
},
"hash": "402ecd9f1531f5756dd6873f7f4d59b4bf2113f69d493cde07f4a861a8b3567c"
}

View File

@@ -0,0 +1,17 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO records (repo_id, collection, rkey, record_cid)\n VALUES ($1, $2, $3, $4)\n ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Text"
]
},
"nullable": []
},
"hash": "5d1f9275037dd0cb03cefe1e4bbbf7dfaeecb1cc8469b4f0836fe5e52e046839"
}

View File

@@ -35,7 +35,8 @@
"password_reset",
"email_update",
"account_deletion",
"admin_email"
"admin_email",
"plc_operation"
]
}
}

View File

@@ -0,0 +1,29 @@
{
"db_name": "PostgreSQL",
"query": "SELECT id, expires_at FROM plc_operation_tokens WHERE user_id = $1 AND token = $2",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "expires_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Uuid",
"Text"
]
},
"nullable": [
false,
false
]
},
"hash": "84e5abf0f7fab44731b1d69658e99018936f8a346bbff91b23a7731b973633cc"
}

View File

@@ -0,0 +1,19 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids)\n VALUES ($1, 'commit', $2, $3, $4, $5, $6)\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Text",
"Text",
"Text",
"Jsonb",
"TextArray",
"TextArray"
]
},
"nullable": []
},
"hash": "aadc1f8c79d79e9a32fe6f4bf7e901076532fa2bf8f0b4d0f1bae7aa0f792183"
}

View File

@@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "DELETE FROM plc_operation_tokens WHERE user_id = $1 OR expires_at < NOW()",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": []
},
"hash": "ac8c260666ab6d1e7103e08e15bc1341694fb453a65c26b4f0bfb07d9b74ebd4"
}

View File

@@ -0,0 +1,34 @@
{
"db_name": "PostgreSQL",
"query": "SELECT id, deactivated_at, takedown_ref FROM users WHERE did = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "deactivated_at",
"type_info": "Timestamptz"
},
{
"ordinal": 2,
"name": "takedown_ref",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
true,
true
]
},
"hash": "c47715c259bb7b56b576d9719f8facb87a9e9b6b530ca6f81ce308a4c584c002"
}

View File

@@ -43,7 +43,8 @@
"password_reset",
"email_update",
"account_deletion",
"admin_email"
"admin_email",
"plc_operation"
]
}
}

View File

@@ -0,0 +1,16 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO plc_operation_tokens (user_id, token, expires_at)\n VALUES ($1, $2, $3)\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Timestamptz"
]
},
"nullable": []
},
"hash": "d981225224ea8e4db25c53566032c8ac81335d05ff5b91cfb20da805e735aea3"
}

View File

@@ -0,0 +1,15 @@
{
"db_name": "PostgreSQL",
"query": "UPDATE repos SET repo_root_cid = $1, updated_at = NOW() WHERE user_id = $2",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Text",
"Uuid"
]
},
"nullable": []
},
"hash": "f1e88d447915b116f887c378253388654a783bddb111b1f9aa04507f176980d3"
}

1
Cargo.lock generated
View File

@@ -917,6 +917,7 @@ dependencies = [
"futures",
"hkdf",
"hmac",
"ipld-core",
"iroh-car",
"jacquard",
"jacquard-axum",

View File

@@ -31,6 +31,7 @@ reqwest = { version = "0.12.24", features = ["json"] }
serde = { version = "1.0.228", features = ["derive"] }
serde_bytes = "0.11.14"
serde_ipld_dagcbor = "0.6.4"
ipld-core = "0.4.2"
serde_json = "1.0.145"
sha2 = "0.10.9"
subtle = "2.5"
@@ -45,13 +46,13 @@ tracing-subscriber = "0.3.22"
tokio-tungstenite = { version = "0.28.0", features = ["native-tls"] }
urlencoding = "2.1"
uuid = { version = "1.19.0", features = ["v4", "fast-rng"] }
iroh-car = "0.5.1"
[features]
external-infra = []
[dev-dependencies]
ctor = "0.6.3"
iroh-car = "0.5.1"
testcontainers = "0.26.0"
testcontainers-modules = { version = "0.14.0", features = ["postgres"] }
wiremock = "0.6.5"

View File

@@ -56,7 +56,7 @@ Lewis' corrected big boy todofile
- [x] Implement `com.atproto.repo.listRecords`.
- [x] Implement `com.atproto.repo.describeRepo`.
- [x] Implement `com.atproto.repo.applyWrites` (Batch writes).
- [ ] Implement `com.atproto.repo.importRepo` (Migration).
- [x] Implement `com.atproto.repo.importRepo` (Migration).
- [x] Implement `com.atproto.repo.listMissingBlobs`.
- [x] Blob Management
- [x] Implement `com.atproto.repo.uploadBlob`.
@@ -83,10 +83,10 @@ Lewis' corrected big boy todofile
- [x] Implement `com.atproto.sync.requestCrawl` (Notify relays to index us).
## Identity (`com.atproto.identity`)
- [ ] Resolution
- [x] Resolution
- [x] Implement `com.atproto.identity.resolveHandle` (Can be internal or proxy to PLC).
- [x] Implement `com.atproto.identity.updateHandle`.
- [ ] Implement `com.atproto.identity.submitPlcOperation` / `signPlcOperation` / `requestPlcOperationSignature`.
- [x] Implement `com.atproto.identity.submitPlcOperation` / `signPlcOperation` / `requestPlcOperationSignature`.
- [x] Implement `com.atproto.identity.getRecommendedDidCredentials`.
- [x] Implement `/.well-known/did.json` (Depends on supporting did:web).

View File

@@ -0,0 +1,10 @@
CREATE TABLE plc_operation_tokens (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
token TEXT NOT NULL UNIQUE,
expires_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_plc_op_tokens_user ON plc_operation_tokens(user_id);
CREATE INDEX idx_plc_op_tokens_expires ON plc_operation_tokens(expires_at);

View File

@@ -0,0 +1 @@
ALTER TYPE notification_type ADD VALUE 'plc_operation';

View File

@@ -96,6 +96,7 @@ export AWS_SECRET_ACCESS_KEY="minioadmin"
export AWS_REGION="us-east-1"
export BSPDS_TEST_INFRA_READY="1"
export BSPDS_ALLOW_INSECURE_SECRETS="1"
export SKIP_IMPORT_VERIFICATION="true"
EOF
echo ""

View File

@@ -9,7 +9,7 @@ use axum::{
use bcrypt::{DEFAULT_COST, hash};
use jacquard::types::{did::Did, integer::LimitedU32, string::Tid};
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
use k256::SecretKey;
use k256::{ecdsa::SigningKey, SecretKey};
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};
use serde_json::json;
@@ -302,9 +302,33 @@ pub async fn create_account(
let rev = Tid::now(LimitedU32::MIN);
let commit = Commit::new_unsigned(did_obj, mst_root, rev, None);
let unsigned_commit = Commit::new_unsigned(did_obj, mst_root, rev, None);
let commit_bytes = match commit.to_cbor() {
let signing_key = match SigningKey::from_slice(&secret_key_bytes) {
Ok(k) => k,
Err(e) => {
error!("Error creating signing key: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let signed_commit = match unsigned_commit.sign(&signing_key) {
Ok(c) => c,
Err(e) => {
error!("Error signing genesis commit: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let commit_bytes = match signed_commit.to_cbor() {
Ok(b) => b,
Err(e) => {
error!("Error serializing genesis commit: {:?}", e);

View File

@@ -1,7 +1,9 @@
pub mod account;
pub mod did;
pub mod plc;
pub use account::create_account;
pub use did::{
get_recommended_did_credentials, resolve_handle, update_handle, user_did_doc, well_known_did,
};
pub use plc::{request_plc_operation_signature, sign_plc_operation, submit_plc_operation};

618
src/api/identity/plc.rs Normal file
View File

@@ -0,0 +1,618 @@
use crate::plc::{
create_update_op, sign_operation, signing_key_to_did_key, validate_plc_operation,
PlcClient, PlcError, PlcService,
};
use crate::state::AppState;
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use chrono::{Duration, Utc};
use k256::ecdsa::SigningKey;
use rand::Rng;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use tracing::{error, info, warn};
fn generate_plc_token() -> String {
let mut rng = rand::thread_rng();
let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect();
let part1: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
let part2: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
format!("{}-{}", part1, part2)
}
pub async fn request_plc_operation_signature(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
) -> Response {
let token = match crate::auth::extract_bearer_token_from_header(
headers.get("Authorization").and_then(|h| h.to_str().ok()),
) {
Some(t) => t,
None => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
};
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
Ok(user) => user,
Err(e) => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed", "message": e})),
)
.into_response();
}
};
let did = &auth_user.did;
let user = match sqlx::query!(
"SELECT id FROM users WHERE did = $1",
did
)
.fetch_optional(&state.db)
.await
{
Ok(Some(row)) => row,
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "AccountNotFound"})),
)
.into_response();
}
Err(e) => {
error!("DB error: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let _ = sqlx::query!(
"DELETE FROM plc_operation_tokens WHERE user_id = $1 OR expires_at < NOW()",
user.id
)
.execute(&state.db)
.await;
let plc_token = generate_plc_token();
let expires_at = Utc::now() + Duration::minutes(10);
if let Err(e) = sqlx::query!(
r#"
INSERT INTO plc_operation_tokens (user_id, token, expires_at)
VALUES ($1, $2, $3)
"#,
user.id,
plc_token,
expires_at
)
.execute(&state.db)
.await
{
error!("Failed to create PLC token: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
if let Err(e) = crate::notifications::enqueue_plc_operation(
&state.db,
user.id,
&plc_token,
&hostname,
)
.await
{
warn!("Failed to enqueue PLC operation notification: {:?}", e);
}
info!("PLC operation signature requested for user {}", did);
(StatusCode::OK, Json(json!({}))).into_response()
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SignPlcOperationInput {
pub token: Option<String>,
pub rotation_keys: Option<Vec<String>>,
pub also_known_as: Option<Vec<String>>,
pub verification_methods: Option<HashMap<String, String>>,
pub services: Option<HashMap<String, ServiceInput>>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct ServiceInput {
#[serde(rename = "type")]
pub service_type: String,
pub endpoint: String,
}
#[derive(Debug, Serialize)]
pub struct SignPlcOperationOutput {
pub operation: Value,
}
pub async fn sign_plc_operation(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
Json(input): Json<SignPlcOperationInput>,
) -> Response {
let bearer = match crate::auth::extract_bearer_token_from_header(
headers.get("Authorization").and_then(|h| h.to_str().ok()),
) {
Some(t) => t,
None => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
};
let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await {
Ok(user) => user,
Err(e) => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed", "message": e})),
)
.into_response();
}
};
let did = &auth_user.did;
let token = match &input.token {
Some(t) => t,
None => {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Email confirmation token required to sign PLC operations"
})),
)
.into_response();
}
};
let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", did)
.fetch_optional(&state.db)
.await
{
Ok(Some(row)) => row,
_ => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "AccountNotFound"})),
)
.into_response();
}
};
let token_row = match sqlx::query!(
"SELECT id, expires_at FROM plc_operation_tokens WHERE user_id = $1 AND token = $2",
user.id,
token
)
.fetch_optional(&state.db)
.await
{
Ok(Some(row)) => row,
Ok(None) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidToken",
"message": "Invalid or expired token"
})),
)
.into_response();
}
Err(e) => {
error!("DB error: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
if Utc::now() > token_row.expires_at {
let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id)
.execute(&state.db)
.await;
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "ExpiredToken",
"message": "Token has expired"
})),
)
.into_response();
}
let key_row = match sqlx::query!(
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
user.id
)
.fetch_optional(&state.db)
.await
{
Ok(Some(row)) => row,
_ => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "User signing key not found"})),
)
.into_response();
}
};
let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
{
Ok(k) => k,
Err(e) => {
error!("Failed to decrypt user key: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let signing_key = match SigningKey::from_slice(&key_bytes) {
Ok(k) => k,
Err(e) => {
error!("Failed to create signing key: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let plc_client = PlcClient::new(None);
let last_op = match plc_client.get_last_op(did).await {
Ok(op) => op,
Err(PlcError::NotFound) => {
return (
StatusCode::NOT_FOUND,
Json(json!({
"error": "NotFound",
"message": "DID not found in PLC directory"
})),
)
.into_response();
}
Err(e) => {
error!("Failed to fetch PLC operation: {:?}", e);
return (
StatusCode::BAD_GATEWAY,
Json(json!({
"error": "UpstreamError",
"message": "Failed to communicate with PLC directory"
})),
)
.into_response();
}
};
if last_op.is_tombstone() {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "DID is tombstoned"
})),
)
.into_response();
}
let services = input.services.map(|s| {
s.into_iter()
.map(|(k, v)| {
(
k,
PlcService {
service_type: v.service_type,
endpoint: v.endpoint,
},
)
})
.collect()
});
let unsigned_op = match create_update_op(
&last_op,
input.rotation_keys,
input.verification_methods,
input.also_known_as,
services,
) {
Ok(op) => op,
Err(PlcError::Tombstoned) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Cannot update tombstoned DID"
})),
)
.into_response();
}
Err(e) => {
error!("Failed to create PLC operation: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let signed_op = match sign_operation(&unsigned_op, &signing_key) {
Ok(op) => op,
Err(e) => {
error!("Failed to sign PLC operation: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id)
.execute(&state.db)
.await;
info!("Signed PLC operation for user {}", did);
(
StatusCode::OK,
Json(SignPlcOperationOutput {
operation: signed_op,
}),
)
.into_response()
}
#[derive(Debug, Deserialize)]
pub struct SubmitPlcOperationInput {
pub operation: Value,
}
pub async fn submit_plc_operation(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
Json(input): Json<SubmitPlcOperationInput>,
) -> Response {
let bearer = match crate::auth::extract_bearer_token_from_header(
headers.get("Authorization").and_then(|h| h.to_str().ok()),
) {
Some(t) => t,
None => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
};
let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await {
Ok(user) => user,
Err(e) => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed", "message": e})),
)
.into_response();
}
};
let did = &auth_user.did;
if let Err(e) = validate_plc_operation(&input.operation) {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": format!("Invalid operation: {}", e)
})),
)
.into_response();
}
let op = &input.operation;
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let public_url = format!("https://{}", hostname);
let user = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did)
.fetch_optional(&state.db)
.await
{
Ok(Some(row)) => row,
_ => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "AccountNotFound"})),
)
.into_response();
}
};
let key_row = match sqlx::query!(
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
user.id
)
.fetch_optional(&state.db)
.await
{
Ok(Some(row)) => row,
_ => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "User signing key not found"})),
)
.into_response();
}
};
let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
{
Ok(k) => k,
Err(e) => {
error!("Failed to decrypt user key: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let signing_key = match SigningKey::from_slice(&key_bytes) {
Ok(k) => k,
Err(e) => {
error!("Failed to create signing key: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let user_did_key = signing_key_to_did_key(&signing_key);
if let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) {
let server_rotation_key =
std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone());
let has_server_key = rotation_keys
.iter()
.any(|k| k.as_str() == Some(&server_rotation_key));
if !has_server_key {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Rotation keys do not include server's rotation key"
})),
)
.into_response();
}
}
if let Some(services) = op.get("services").and_then(|v| v.as_object()) {
if let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) {
let service_type = pds.get("type").and_then(|v| v.as_str());
let endpoint = pds.get("endpoint").and_then(|v| v.as_str());
if service_type != Some("AtprotoPersonalDataServer") {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Incorrect type on atproto_pds service"
})),
)
.into_response();
}
if endpoint != Some(&public_url) {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Incorrect endpoint on atproto_pds service"
})),
)
.into_response();
}
}
}
if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) {
if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) {
if atproto_key != user_did_key {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Incorrect signing key in verificationMethods"
})),
)
.into_response();
}
}
}
if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) {
let expected_handle = format!("at://{}", user.handle);
let first_aka = also_known_as.first().and_then(|v| v.as_str());
if first_aka != Some(&expected_handle) {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Incorrect handle in alsoKnownAs"
})),
)
.into_response();
}
}
let plc_client = PlcClient::new(None);
if let Err(e) = plc_client.send_operation(did, &input.operation).await {
error!("Failed to submit PLC operation: {:?}", e);
return (
StatusCode::BAD_GATEWAY,
Json(json!({
"error": "UpstreamError",
"message": format!("Failed to submit to PLC directory: {}", e)
})),
)
.into_response();
}
if let Err(e) = sqlx::query!(
"INSERT INTO repo_seq (did, event_type) VALUES ($1, 'identity')",
did
)
.execute(&state.db)
.await
{
warn!("Failed to sequence identity event: {:?}", e);
}
info!("Submitted PLC operation for user {}", did);
(StatusCode::OK, Json(json!({}))).into_response()
}

420
src/api/repo/import.rs Normal file
View File

@@ -0,0 +1,420 @@
use crate::state::AppState;
use crate::sync::import::{apply_import, parse_car, ImportError};
use crate::sync::verify::CarVerifier;
use axum::{
body::Bytes,
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde_json::json;
use tracing::{debug, error, info, warn};
const DEFAULT_MAX_IMPORT_SIZE: usize = 100 * 1024 * 1024;
const DEFAULT_MAX_BLOCKS: usize = 50000;
pub async fn import_repo(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
body: Bytes,
) -> Response {
let accepting_imports = std::env::var("ACCEPTING_REPO_IMPORTS")
.map(|v| v != "false" && v != "0")
.unwrap_or(true);
if !accepting_imports {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Service is not accepting repo imports"
})),
)
.into_response();
}
let max_size: usize = std::env::var("MAX_IMPORT_SIZE")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(DEFAULT_MAX_IMPORT_SIZE);
if body.len() > max_size {
return (
StatusCode::PAYLOAD_TOO_LARGE,
Json(json!({
"error": "InvalidRequest",
"message": format!("Import size exceeds limit of {} bytes", max_size)
})),
)
.into_response();
}
let token = match crate::auth::extract_bearer_token_from_header(
headers.get("Authorization").and_then(|h| h.to_str().ok()),
) {
Some(t) => t,
None => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
};
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
Ok(user) => user,
Err(e) => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed", "message": e})),
)
.into_response();
}
};
let did = &auth_user.did;
let user = match sqlx::query!(
"SELECT id, deactivated_at, takedown_ref FROM users WHERE did = $1",
did
)
.fetch_optional(&state.db)
.await
{
Ok(Some(row)) => row,
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "AccountNotFound"})),
)
.into_response();
}
Err(e) => {
error!("DB error fetching user: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
if user.deactivated_at.is_some() {
return (
StatusCode::FORBIDDEN,
Json(json!({
"error": "AccountDeactivated",
"message": "Account is deactivated"
})),
)
.into_response();
}
if user.takedown_ref.is_some() {
return (
StatusCode::FORBIDDEN,
Json(json!({
"error": "AccountTakenDown",
"message": "Account has been taken down"
})),
)
.into_response();
}
let user_id = user.id;
let (root, blocks) = match parse_car(&body).await {
Ok((r, b)) => (r, b),
Err(ImportError::InvalidRootCount) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Expected exactly one root in CAR file"
})),
)
.into_response();
}
Err(ImportError::CarParse(msg)) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": format!("Failed to parse CAR file: {}", msg)
})),
)
.into_response();
}
Err(e) => {
error!("CAR parsing error: {:?}", e);
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": format!("Invalid CAR file: {}", e)
})),
)
.into_response();
}
};
info!(
"Importing repo for user {}: {} blocks, root {}",
did,
blocks.len(),
root
);
let root_block = match blocks.get(&root) {
Some(b) => b,
None => {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Root block not found in CAR file"
})),
)
.into_response();
}
};
let commit_did = match jacquard_repo::commit::Commit::from_cbor(root_block) {
Ok(commit) => commit.did().to_string(),
Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": format!("Invalid commit: {}", e)
})),
)
.into_response();
}
};
if commit_did != *did {
return (
StatusCode::FORBIDDEN,
Json(json!({
"error": "InvalidRequest",
"message": format!(
"CAR file is for DID {} but you are authenticated as {}",
commit_did, did
)
})),
)
.into_response();
}
let skip_verification = std::env::var("SKIP_IMPORT_VERIFICATION")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if !skip_verification {
debug!("Verifying CAR file signature and structure for DID {}", did);
let verifier = CarVerifier::new();
match verifier.verify_car(did, &root, &blocks).await {
Ok(verified) => {
debug!(
"CAR verification successful: rev={}, data_cid={}",
verified.rev, verified.data_cid
);
}
Err(crate::sync::verify::VerifyError::DidMismatch {
commit_did,
expected_did,
}) => {
return (
StatusCode::FORBIDDEN,
Json(json!({
"error": "InvalidRequest",
"message": format!(
"CAR file is for DID {} but you are authenticated as {}",
commit_did, expected_did
)
})),
)
.into_response();
}
Err(crate::sync::verify::VerifyError::InvalidSignature) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidSignature",
"message": "CAR file commit signature verification failed"
})),
)
.into_response();
}
Err(crate::sync::verify::VerifyError::DidResolutionFailed(msg)) => {
warn!("DID resolution failed during import verification: {}", msg);
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": format!("Failed to verify DID: {}", msg)
})),
)
.into_response();
}
Err(crate::sync::verify::VerifyError::NoSigningKey) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "DID document does not contain a signing key"
})),
)
.into_response();
}
Err(crate::sync::verify::VerifyError::MstValidationFailed(msg)) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": format!("MST validation failed: {}", msg)
})),
)
.into_response();
}
Err(e) => {
error!("CAR verification error: {:?}", e);
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": format!("CAR verification failed: {}", e)
})),
)
.into_response();
}
}
} else {
warn!("Skipping CAR signature verification for import (SKIP_IMPORT_VERIFICATION=true)");
}
let max_blocks: usize = std::env::var("MAX_IMPORT_BLOCKS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(DEFAULT_MAX_BLOCKS);
match apply_import(&state.db, user_id, root, blocks, max_blocks).await {
Ok(records) => {
info!(
"Successfully imported {} records for user {}",
records.len(),
did
);
if let Err(e) = sequence_import_event(&state, did, &root.to_string()).await {
warn!("Failed to sequence import event: {:?}", e);
}
(StatusCode::OK, Json(json!({}))).into_response()
}
Err(ImportError::SizeLimitExceeded) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": format!("Import exceeds block limit of {}", max_blocks)
})),
)
.into_response(),
Err(ImportError::RepoNotFound) => (
StatusCode::NOT_FOUND,
Json(json!({
"error": "RepoNotFound",
"message": "Repository not initialized for this account"
})),
)
.into_response(),
Err(ImportError::InvalidCbor(msg)) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": format!("Invalid CBOR data: {}", msg)
})),
)
.into_response(),
Err(ImportError::InvalidCommit(msg)) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": format!("Invalid commit structure: {}", msg)
})),
)
.into_response(),
Err(ImportError::BlockNotFound(cid)) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": format!("Referenced block not found in CAR: {}", cid)
})),
)
.into_response(),
Err(ImportError::ConcurrentModification) => (
StatusCode::CONFLICT,
Json(json!({
"error": "ConcurrentModification",
"message": "Repository is being modified by another operation, please retry"
})),
)
.into_response(),
Err(ImportError::VerificationFailed(ve)) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "VerificationFailed",
"message": format!("CAR verification failed: {}", ve)
})),
)
.into_response(),
Err(ImportError::DidMismatch { car_did, auth_did }) => (
StatusCode::FORBIDDEN,
Json(json!({
"error": "DidMismatch",
"message": format!("CAR is for {} but authenticated as {}", car_did, auth_did)
})),
)
.into_response(),
Err(e) => {
error!("Import error: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response()
}
}
}
async fn sequence_import_event(
state: &AppState,
did: &str,
commit_cid: &str,
) -> Result<(), sqlx::Error> {
let prev_cid: Option<String> = None;
let ops = serde_json::json!([]);
let blobs: Vec<String> = vec![];
let blocks_cids: Vec<String> = vec![];
sqlx::query!(
r#"
INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids)
VALUES ($1, 'commit', $2, $3, $4, $5, $6)
"#,
did,
commit_cid,
prev_cid,
ops,
&blobs,
&blocks_cids
)
.execute(&state.db)
.await?;
Ok(())
}

View File

@@ -1,7 +1,9 @@
pub mod blob;
pub mod import;
pub mod meta;
pub mod record;
pub use blob::{list_missing_blobs, upload_blob};
pub use import::import_repo;
pub use meta::describe_repo;
pub use record::{apply_writes, create_record, delete_record, get_record, list_records, put_record};

View File

@@ -3,6 +3,7 @@ use cid::Cid;
use jacquard::types::{did::Did, integer::LimitedU32, string::Tid};
use jacquard_repo::commit::Commit;
use jacquard_repo::storage::BlockStore;
use k256::ecdsa::SigningKey;
use serde_json::json;
use uuid::Uuid;
@@ -26,12 +27,30 @@ pub async fn commit_and_log(
ops: Vec<RecordOp>,
blocks_cids: &Vec<String>,
) -> Result<CommitResult, String> {
let key_row = sqlx::query!(
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
user_id
)
.fetch_one(&state.db)
.await
.map_err(|e| format!("Failed to fetch signing key: {}", e))?;
let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
.map_err(|e| format!("Failed to decrypt signing key: {}", e))?;
let signing_key = SigningKey::from_slice(&key_bytes)
.map_err(|e| format!("Invalid signing key: {}", e))?;
let did_obj = Did::new(did).map_err(|e| format!("Invalid DID: {}", e))?;
let rev = Tid::now(LimitedU32::MIN);
let new_commit = Commit::new_unsigned(did_obj, new_mst_root, rev.clone(), current_root_cid);
let unsigned_commit = Commit::new_unsigned(did_obj, new_mst_root, rev.clone(), current_root_cid);
let new_commit_bytes = new_commit.to_cbor().map_err(|e| format!("Failed to serialize commit: {:?}", e))?;
let signed_commit = unsigned_commit
.sign(&signing_key)
.map_err(|e| format!("Failed to sign commit: {:?}", e))?;
let new_commit_bytes = signed_commit.to_cbor().map_err(|e| format!("Failed to serialize commit: {:?}", e))?;
let new_root_cid = state.block_store.put(&new_commit_bytes).await
.map_err(|e| format!("Failed to save commit block: {:?}", e))?;

View File

@@ -3,6 +3,7 @@ pub mod auth;
pub mod config;
pub mod notifications;
pub mod oauth;
pub mod plc;
pub mod repo;
pub mod state;
pub mod storage;
@@ -194,6 +195,22 @@ pub fn app(state: AppState) -> Router {
"/xrpc/com.atproto.identity.updateHandle",
post(api::identity::update_handle),
)
.route(
"/xrpc/com.atproto.identity.requestPlcOperationSignature",
post(api::identity::request_plc_operation_signature),
)
.route(
"/xrpc/com.atproto.identity.signPlcOperation",
post(api::identity::sign_plc_operation),
)
.route(
"/xrpc/com.atproto.identity.submitPlcOperation",
post(api::identity::submit_plc_operation),
)
.route(
"/xrpc/com.atproto.repo.importRepo",
post(api::repo::import_repo),
)
.route(
"/xrpc/com.atproto.admin.deleteAccount",
post(api::admin::delete_account),

View File

@@ -5,7 +5,8 @@ mod types;
pub use sender::{EmailSender, NotificationSender};
pub use service::{
enqueue_account_deletion, enqueue_email_update, enqueue_email_verification,
enqueue_notification, enqueue_password_reset, enqueue_welcome, NotificationService,
enqueue_notification, enqueue_password_reset, enqueue_plc_operation, enqueue_welcome,
NotificationService,
};
pub use types::{
NewNotification, NotificationChannel, NotificationStatus, NotificationType, QueuedNotification,

View File

@@ -416,3 +416,30 @@ pub async fn enqueue_account_deletion(
)
.await
}
pub async fn enqueue_plc_operation(
db: &PgPool,
user_id: Uuid,
token: &str,
hostname: &str,
) -> Result<Uuid, sqlx::Error> {
let prefs = get_user_notification_prefs(db, user_id).await?;
let body = format!(
"Hello @{},\n\nYou requested to sign a PLC operation for your account.\n\nYour verification token is: {}\n\nThis token will expire in 10 minutes.\n\nIf you did not request this, you can safely ignore this message.",
prefs.handle, token
);
enqueue_notification(
db,
NewNotification::new(
user_id,
prefs.channel,
super::types::NotificationType::PlcOperation,
prefs.email.clone(),
Some(format!("{} - PLC Operation Token", hostname)),
body,
),
)
.await
}

View File

@@ -30,6 +30,7 @@ pub enum NotificationType {
EmailUpdate,
AccountDeletion,
AdminEmail,
PlcOperation,
}
#[derive(Debug, Clone, FromRow)]

358
src/plc/mod.rs Normal file
View File

@@ -0,0 +1,358 @@
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use k256::ecdsa::{SigningKey, Signature, signature::Signer};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum PlcError {
#[error("HTTP request failed: {0}")]
Http(#[from] reqwest::Error),
#[error("Invalid response: {0}")]
InvalidResponse(String),
#[error("DID not found")]
NotFound,
#[error("DID is tombstoned")]
Tombstoned,
#[error("Serialization error: {0}")]
Serialization(String),
#[error("Signing error: {0}")]
Signing(String),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlcOperation {
#[serde(rename = "type")]
pub op_type: String,
#[serde(rename = "rotationKeys")]
pub rotation_keys: Vec<String>,
#[serde(rename = "verificationMethods")]
pub verification_methods: HashMap<String, String>,
#[serde(rename = "alsoKnownAs")]
pub also_known_as: Vec<String>,
pub services: HashMap<String, PlcService>,
pub prev: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sig: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlcService {
#[serde(rename = "type")]
pub service_type: String,
pub endpoint: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlcTombstone {
#[serde(rename = "type")]
pub op_type: String,
pub prev: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub sig: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PlcOpOrTombstone {
Operation(PlcOperation),
Tombstone(PlcTombstone),
}
impl PlcOpOrTombstone {
pub fn is_tombstone(&self) -> bool {
match self {
PlcOpOrTombstone::Tombstone(_) => true,
PlcOpOrTombstone::Operation(op) => op.op_type == "plc_tombstone",
}
}
}
pub struct PlcClient {
base_url: String,
client: Client,
}
impl PlcClient {
pub fn new(base_url: Option<String>) -> Self {
let base_url = base_url.unwrap_or_else(|| {
std::env::var("PLC_DIRECTORY_URL")
.unwrap_or_else(|_| "https://plc.directory".to_string())
});
Self {
base_url,
client: Client::new(),
}
}
fn encode_did(did: &str) -> String {
urlencoding::encode(did).to_string()
}
pub async fn get_document(&self, did: &str) -> Result<Value, PlcError> {
let url = format!("{}/{}", self.base_url, Self::encode_did(did));
let response = self.client.get(&url).send().await?;
if response.status() == reqwest::StatusCode::NOT_FOUND {
return Err(PlcError::NotFound);
}
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(PlcError::InvalidResponse(format!(
"HTTP {}: {}",
status, body
)));
}
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
}
pub async fn get_document_data(&self, did: &str) -> Result<Value, PlcError> {
let url = format!("{}/{}/data", self.base_url, Self::encode_did(did));
let response = self.client.get(&url).send().await?;
if response.status() == reqwest::StatusCode::NOT_FOUND {
return Err(PlcError::NotFound);
}
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(PlcError::InvalidResponse(format!(
"HTTP {}: {}",
status, body
)));
}
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
}
pub async fn get_last_op(&self, did: &str) -> Result<PlcOpOrTombstone, PlcError> {
let url = format!("{}/{}/log/last", self.base_url, Self::encode_did(did));
let response = self.client.get(&url).send().await?;
if response.status() == reqwest::StatusCode::NOT_FOUND {
return Err(PlcError::NotFound);
}
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(PlcError::InvalidResponse(format!(
"HTTP {}: {}",
status, body
)));
}
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
}
pub async fn get_audit_log(&self, did: &str) -> Result<Vec<Value>, PlcError> {
let url = format!("{}/{}/log/audit", self.base_url, Self::encode_did(did));
let response = self.client.get(&url).send().await?;
if response.status() == reqwest::StatusCode::NOT_FOUND {
return Err(PlcError::NotFound);
}
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(PlcError::InvalidResponse(format!(
"HTTP {}: {}",
status, body
)));
}
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
}
pub async fn send_operation(&self, did: &str, operation: &Value) -> Result<(), PlcError> {
let url = format!("{}/{}", self.base_url, Self::encode_did(did));
let response = self.client
.post(&url)
.json(operation)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(PlcError::InvalidResponse(format!(
"HTTP {}: {}",
status, body
)));
}
Ok(())
}
}
pub fn cid_for_cbor(value: &Value) -> Result<String, PlcError> {
let cbor_bytes = serde_ipld_dagcbor::to_vec(value)
.map_err(|e| PlcError::Serialization(e.to_string()))?;
let mut hasher = Sha256::new();
hasher.update(&cbor_bytes);
let hash = hasher.finalize();
let multihash = multihash::Multihash::wrap(0x12, &hash)
.map_err(|e| PlcError::Serialization(e.to_string()))?;
let cid = cid::Cid::new_v1(0x71, multihash);
Ok(cid.to_string())
}
pub fn sign_operation(
operation: &Value,
signing_key: &SigningKey,
) -> Result<Value, PlcError> {
let mut op = operation.clone();
if let Some(obj) = op.as_object_mut() {
obj.remove("sig");
}
let cbor_bytes = serde_ipld_dagcbor::to_vec(&op)
.map_err(|e| PlcError::Serialization(e.to_string()))?;
let signature: Signature = signing_key.sign(&cbor_bytes);
let sig_bytes = signature.to_bytes();
let sig_b64 = URL_SAFE_NO_PAD.encode(sig_bytes);
if let Some(obj) = op.as_object_mut() {
obj.insert("sig".to_string(), json!(sig_b64));
}
Ok(op)
}
pub fn create_update_op(
last_op: &PlcOpOrTombstone,
rotation_keys: Option<Vec<String>>,
verification_methods: Option<HashMap<String, String>>,
also_known_as: Option<Vec<String>>,
services: Option<HashMap<String, PlcService>>,
) -> Result<Value, PlcError> {
let prev_value = match last_op {
PlcOpOrTombstone::Operation(op) => serde_json::to_value(op)
.map_err(|e| PlcError::Serialization(e.to_string()))?,
PlcOpOrTombstone::Tombstone(t) => serde_json::to_value(t)
.map_err(|e| PlcError::Serialization(e.to_string()))?,
};
let prev_cid = cid_for_cbor(&prev_value)?;
let (base_rotation_keys, base_verification_methods, base_also_known_as, base_services) =
match last_op {
PlcOpOrTombstone::Operation(op) => (
op.rotation_keys.clone(),
op.verification_methods.clone(),
op.also_known_as.clone(),
op.services.clone(),
),
PlcOpOrTombstone::Tombstone(_) => {
return Err(PlcError::Tombstoned);
}
};
let new_op = PlcOperation {
op_type: "plc_operation".to_string(),
rotation_keys: rotation_keys.unwrap_or(base_rotation_keys),
verification_methods: verification_methods.unwrap_or(base_verification_methods),
also_known_as: also_known_as.unwrap_or(base_also_known_as),
services: services.unwrap_or(base_services),
prev: Some(prev_cid),
sig: None,
};
serde_json::to_value(new_op).map_err(|e| PlcError::Serialization(e.to_string()))
}
pub fn signing_key_to_did_key(signing_key: &SigningKey) -> String {
let verifying_key = signing_key.verifying_key();
let point = verifying_key.to_encoded_point(true);
let compressed_bytes = point.as_bytes();
let mut prefixed = vec![0xe7, 0x01];
prefixed.extend_from_slice(compressed_bytes);
let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed);
format!("did:key:{}", encoded)
}
pub fn validate_plc_operation(op: &Value) -> Result<(), PlcError> {
let obj = op.as_object()
.ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
let op_type = obj.get("type")
.and_then(|v| v.as_str())
.ok_or_else(|| PlcError::InvalidResponse("Missing type field".to_string()))?;
if op_type != "plc_operation" && op_type != "plc_tombstone" {
return Err(PlcError::InvalidResponse(format!("Invalid type: {}", op_type)));
}
if op_type == "plc_operation" {
if obj.get("rotationKeys").is_none() {
return Err(PlcError::InvalidResponse("Missing rotationKeys".to_string()));
}
if obj.get("verificationMethods").is_none() {
return Err(PlcError::InvalidResponse("Missing verificationMethods".to_string()));
}
if obj.get("alsoKnownAs").is_none() {
return Err(PlcError::InvalidResponse("Missing alsoKnownAs".to_string()));
}
if obj.get("services").is_none() {
return Err(PlcError::InvalidResponse("Missing services".to_string()));
}
}
if obj.get("sig").is_none() {
return Err(PlcError::InvalidResponse("Missing sig".to_string()));
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_signing_key_to_did_key() {
let key = SigningKey::random(&mut rand::thread_rng());
let did_key = signing_key_to_did_key(&key);
assert!(did_key.starts_with("did:key:z"));
}
#[test]
fn test_cid_for_cbor() {
let value = json!({
"test": "data",
"number": 42
});
let cid = cid_for_cbor(&value).unwrap();
assert!(cid.starts_with("bafyrei"));
}
#[test]
fn test_sign_operation() {
let key = SigningKey::random(&mut rand::thread_rng());
let op = json!({
"type": "plc_operation",
"rotationKeys": [],
"verificationMethods": {},
"alsoKnownAs": [],
"services": {},
"prev": null
});
let signed = sign_operation(&op, &key).unwrap();
assert!(signed.get("sig").is_some());
}
}

View File

@@ -1,4 +1,5 @@
use cid::Cid;
use iroh_car::CarHeader;
use std::io::Write;
pub fn write_varint<W: Write>(mut writer: W, mut value: u64) -> std::io::Result<()> {
@@ -23,10 +24,11 @@ pub fn ld_write<W: Write>(mut writer: W, data: &[u8]) -> std::io::Result<()> {
}
pub fn encode_car_header(root_cid: &Cid) -> Vec<u8> {
let header = serde_ipld_dagcbor::to_vec(&serde_json::json!({
"version": 1u64,
"roots": [root_cid.to_bytes()]
}))
.unwrap_or_default();
header
let header = CarHeader::new_v1(vec![root_cid.clone()]);
let header_cbor = header.encode().unwrap_or_default();
let mut result = Vec::new();
write_varint(&mut result, header_cbor.len() as u64).unwrap();
result.extend_from_slice(&header_cbor);
result
}

464
src/sync/import.rs Normal file
View File

@@ -0,0 +1,464 @@
use bytes::Bytes;
use cid::Cid;
use ipld_core::ipld::Ipld;
use iroh_car::CarReader;
use serde_json::Value as JsonValue;
use sqlx::PgPool;
use std::collections::HashMap;
use std::io::Cursor;
use thiserror::Error;
use tracing::debug;
use uuid::Uuid;
#[derive(Error, Debug)]
pub enum ImportError {
#[error("CAR parsing error: {0}")]
CarParse(String),
#[error("Expected exactly one root in CAR file")]
InvalidRootCount,
#[error("Block not found: {0}")]
BlockNotFound(String),
#[error("Invalid CBOR: {0}")]
InvalidCbor(String),
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Block store error: {0}")]
BlockStore(String),
#[error("Import size limit exceeded")]
SizeLimitExceeded,
#[error("Repo not found")]
RepoNotFound,
#[error("Concurrent modification detected")]
ConcurrentModification,
#[error("Invalid commit structure: {0}")]
InvalidCommit(String),
#[error("Verification failed: {0}")]
VerificationFailed(#[from] super::verify::VerifyError),
#[error("DID mismatch: CAR is for {car_did}, but authenticated as {auth_did}")]
DidMismatch { car_did: String, auth_did: String },
}
#[derive(Debug, Clone)]
pub struct BlobRef {
pub cid: String,
pub mime_type: Option<String>,
}
pub async fn parse_car(data: &[u8]) -> Result<(Cid, HashMap<Cid, Bytes>), ImportError> {
let cursor = Cursor::new(data);
let mut reader = CarReader::new(cursor)
.await
.map_err(|e| ImportError::CarParse(e.to_string()))?;
let header = reader.header();
let roots = header.roots();
if roots.len() != 1 {
return Err(ImportError::InvalidRootCount);
}
let root = roots[0];
let mut blocks = HashMap::new();
while let Ok(Some((cid, block))) = reader.next_block().await {
blocks.insert(cid, Bytes::from(block));
}
if !blocks.contains_key(&root) {
return Err(ImportError::BlockNotFound(root.to_string()));
}
Ok((root, blocks))
}
pub fn find_blob_refs_ipld(value: &Ipld, depth: usize) -> Vec<BlobRef> {
if depth > 32 {
return vec![];
}
match value {
Ipld::List(arr) => arr
.iter()
.flat_map(|v| find_blob_refs_ipld(v, depth + 1))
.collect(),
Ipld::Map(obj) => {
if let Some(Ipld::String(type_str)) = obj.get("$type") {
if type_str == "blob" {
if let Some(Ipld::Link(link_cid)) = obj.get("ref") {
let mime = obj
.get("mimeType")
.and_then(|v| if let Ipld::String(s) = v { Some(s.clone()) } else { None });
return vec![BlobRef {
cid: link_cid.to_string(),
mime_type: mime,
}];
}
}
}
obj.values()
.flat_map(|v| find_blob_refs_ipld(v, depth + 1))
.collect()
}
_ => vec![],
}
}
pub fn find_blob_refs(value: &JsonValue, depth: usize) -> Vec<BlobRef> {
if depth > 32 {
return vec![];
}
match value {
JsonValue::Array(arr) => arr
.iter()
.flat_map(|v| find_blob_refs(v, depth + 1))
.collect(),
JsonValue::Object(obj) => {
if let Some(JsonValue::String(type_str)) = obj.get("$type") {
if type_str == "blob" {
if let Some(JsonValue::Object(ref_obj)) = obj.get("ref") {
if let Some(JsonValue::String(link)) = ref_obj.get("$link") {
let mime = obj
.get("mimeType")
.and_then(|v| v.as_str())
.map(String::from);
return vec![BlobRef {
cid: link.clone(),
mime_type: mime,
}];
}
}
}
}
obj.values()
.flat_map(|v| find_blob_refs(v, depth + 1))
.collect()
}
_ => vec![],
}
}
pub fn extract_links(value: &Ipld, links: &mut Vec<Cid>) {
match value {
Ipld::Link(cid) => {
links.push(*cid);
}
Ipld::Map(map) => {
for v in map.values() {
extract_links(v, links);
}
}
Ipld::List(arr) => {
for v in arr {
extract_links(v, links);
}
}
_ => {}
}
}
#[derive(Debug)]
pub struct ImportedRecord {
pub collection: String,
pub rkey: String,
pub cid: Cid,
pub blob_refs: Vec<BlobRef>,
}
pub fn walk_mst(
blocks: &HashMap<Cid, Bytes>,
root_cid: &Cid,
) -> Result<Vec<ImportedRecord>, ImportError> {
let mut records = Vec::new();
let mut stack = vec![*root_cid];
let mut visited = std::collections::HashSet::new();
while let Some(cid) = stack.pop() {
if visited.contains(&cid) {
continue;
}
visited.insert(cid);
let block = blocks
.get(&cid)
.ok_or_else(|| ImportError::BlockNotFound(cid.to_string()))?;
let value: Ipld = serde_ipld_dagcbor::from_slice(block)
.map_err(|e| ImportError::InvalidCbor(e.to_string()))?;
if let Ipld::Map(ref obj) = value {
if let Some(Ipld::List(entries)) = obj.get("e") {
for entry in entries {
if let Ipld::Map(entry_obj) = entry {
let key = entry_obj.get("k").and_then(|k| {
if let Ipld::Bytes(b) = k {
String::from_utf8(b.clone()).ok()
} else if let Ipld::String(s) = k {
Some(s.clone())
} else {
None
}
});
let record_cid = entry_obj.get("v").and_then(|v| {
if let Ipld::Link(cid) = v {
Some(*cid)
} else {
None
}
});
if let (Some(key), Some(record_cid)) = (key, record_cid) {
if let Some(record_block) = blocks.get(&record_cid) {
if let Ok(record_value) =
serde_ipld_dagcbor::from_slice::<Ipld>(record_block)
{
let blob_refs = find_blob_refs_ipld(&record_value, 0);
let parts: Vec<&str> = key.split('/').collect();
if parts.len() >= 2 {
let collection = parts[..parts.len() - 1].join("/");
let rkey = parts[parts.len() - 1].to_string();
records.push(ImportedRecord {
collection,
rkey,
cid: record_cid,
blob_refs,
});
}
}
}
}
if let Some(Ipld::Link(tree_cid)) = entry_obj.get("t") {
stack.push(*tree_cid);
}
}
}
}
if let Some(Ipld::Link(left_cid)) = obj.get("l") {
stack.push(*left_cid);
}
}
}
Ok(records)
}
pub struct CommitInfo {
pub rev: Option<String>,
pub prev: Option<String>,
}
fn extract_commit_info(commit: &Ipld) -> Result<(Cid, CommitInfo), ImportError> {
let obj = match commit {
Ipld::Map(m) => m,
_ => return Err(ImportError::InvalidCommit("Commit must be a map".to_string())),
};
let data_cid = obj
.get("data")
.and_then(|d| if let Ipld::Link(cid) = d { Some(*cid) } else { None })
.ok_or_else(|| ImportError::InvalidCommit("Missing data field".to_string()))?;
let rev = obj.get("rev").and_then(|r| {
if let Ipld::String(s) = r {
Some(s.clone())
} else {
None
}
});
let prev = obj.get("prev").and_then(|p| {
if let Ipld::Link(cid) = p {
Some(cid.to_string())
} else if let Ipld::Null = p {
None
} else {
None
}
});
Ok((data_cid, CommitInfo { rev, prev }))
}
pub async fn apply_import(
db: &PgPool,
user_id: Uuid,
root: Cid,
blocks: HashMap<Cid, Bytes>,
max_blocks: usize,
) -> Result<Vec<ImportedRecord>, ImportError> {
if blocks.len() > max_blocks {
return Err(ImportError::SizeLimitExceeded);
}
let root_block = blocks
.get(&root)
.ok_or_else(|| ImportError::BlockNotFound(root.to_string()))?;
let commit: Ipld = serde_ipld_dagcbor::from_slice(root_block)
.map_err(|e| ImportError::InvalidCbor(e.to_string()))?;
let (data_cid, _commit_info) = extract_commit_info(&commit)?;
let records = walk_mst(&blocks, &data_cid)?;
debug!(
"Importing {} blocks and {} records for user {}",
blocks.len(),
records.len(),
user_id
);
let mut tx = db.begin().await?;
let repo = sqlx::query!(
"SELECT repo_root_cid FROM repos WHERE user_id = $1 FOR UPDATE NOWAIT",
user_id
)
.fetch_optional(&mut *tx)
.await
.map_err(|e| {
if let sqlx::Error::Database(ref db_err) = e {
if db_err.code().as_deref() == Some("55P03") {
return ImportError::ConcurrentModification;
}
}
ImportError::Database(e)
})?;
if repo.is_none() {
return Err(ImportError::RepoNotFound);
}
let block_chunks: Vec<Vec<(&Cid, &Bytes)>> = blocks
.iter()
.collect::<Vec<_>>()
.chunks(100)
.map(|c| c.to_vec())
.collect();
for chunk in block_chunks {
for (cid, data) in chunk {
let cid_bytes = cid.to_bytes();
sqlx::query!(
"INSERT INTO blocks (cid, data) VALUES ($1, $2) ON CONFLICT (cid) DO NOTHING",
&cid_bytes,
data.as_ref()
)
.execute(&mut *tx)
.await?;
}
}
let root_str = root.to_string();
sqlx::query!(
"UPDATE repos SET repo_root_cid = $1, updated_at = NOW() WHERE user_id = $2",
root_str,
user_id
)
.execute(&mut *tx)
.await?;
sqlx::query!("DELETE FROM records WHERE repo_id = $1", user_id)
.execute(&mut *tx)
.await?;
for record in &records {
let record_cid_str = record.cid.to_string();
sqlx::query!(
r#"
INSERT INTO records (repo_id, collection, rkey, record_cid)
VALUES ($1, $2, $3, $4)
ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4
"#,
user_id,
record.collection,
record.rkey,
record_cid_str
)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
debug!(
"Successfully imported {} blocks and {} records",
blocks.len(),
records.len()
);
Ok(records)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_find_blob_refs() {
let record = serde_json::json!({
"$type": "app.bsky.feed.post",
"text": "Hello world",
"embed": {
"$type": "app.bsky.embed.images",
"images": [
{
"alt": "Test image",
"image": {
"$type": "blob",
"ref": {
"$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
},
"mimeType": "image/jpeg",
"size": 12345
}
}
]
}
});
let blob_refs = find_blob_refs(&record, 0);
assert_eq!(blob_refs.len(), 1);
assert_eq!(
blob_refs[0].cid,
"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
);
assert_eq!(blob_refs[0].mime_type, Some("image/jpeg".to_string()));
}
#[test]
fn test_find_blob_refs_no_blobs() {
let record = serde_json::json!({
"$type": "app.bsky.feed.post",
"text": "Hello world"
});
let blob_refs = find_blob_refs(&record, 0);
assert!(blob_refs.is_empty());
}
#[test]
fn test_find_blob_refs_depth_limit() {
fn deeply_nested(depth: usize) -> JsonValue {
if depth == 0 {
serde_json::json!({
"$type": "blob",
"ref": { "$link": "bafkreitest" },
"mimeType": "image/png"
})
} else {
serde_json::json!({ "nested": deeply_nested(depth - 1) })
}
}
let deep = deeply_nested(40);
let blob_refs = find_blob_refs(&deep, 0);
assert!(blob_refs.is_empty());
}
}

View File

@@ -4,14 +4,17 @@ pub mod commit;
pub mod crawl;
pub mod firehose;
pub mod frame;
pub mod import;
pub mod listener;
pub mod relay_client;
pub mod repo;
pub mod subscribe_repos;
pub mod util;
pub mod verify;
pub use blob::{get_blob, list_blobs};
pub use commit::{get_latest_commit, get_repo_status, list_repos};
pub use crawl::{notify_of_update, request_crawl};
pub use repo::{get_blocks, get_repo, get_record};
pub use subscribe_repos::subscribe_repos;
pub use verify::{CarVerifier, VerifiedCar, VerifyError};

View File

@@ -7,6 +7,7 @@ use axum::{
Json,
};
use cid::Cid;
use ipld_core::ipld::Ipld;
use jacquard_repo::storage::BlockStore;
use serde::Deserialize;
use serde_json::json;
@@ -165,8 +166,8 @@ pub async fn get_repo(
writer.write_all(&block).unwrap();
car_bytes.extend_from_slice(&writer);
if let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) {
extract_links_json(&value, &mut stack);
if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) {
extract_links_ipld(&value, &mut stack);
}
}
}
@@ -179,26 +180,19 @@ pub async fn get_repo(
.into_response()
}
fn extract_links_json(value: &serde_json::Value, stack: &mut Vec<Cid>) {
fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) {
match value {
serde_json::Value::Object(map) => {
if let Some(serde_json::Value::String(s)) = map.get("/") {
if let Ok(cid) = Cid::from_str(s) {
stack.push(cid);
}
} else if let Some(serde_json::Value::String(s)) = map.get("$link") {
if let Ok(cid) = Cid::from_str(s) {
stack.push(cid);
}
} else {
for v in map.values() {
extract_links_json(v, stack);
}
Ipld::Link(cid) => {
stack.push(*cid);
}
Ipld::Map(map) => {
for v in map.values() {
extract_links_ipld(v, stack);
}
}
serde_json::Value::Array(arr) => {
Ipld::List(arr) => {
for v in arr {
extract_links_json(v, stack);
extract_links_ipld(v, stack);
}
}
_ => {}

646
src/sync/verify.rs Normal file
View File

@@ -0,0 +1,646 @@
use bytes::Bytes;
use cid::Cid;
use jacquard::common::types::crypto::PublicKey;
use jacquard::common::types::did_doc::DidDocument;
use jacquard::common::IntoStatic;
use jacquard_repo::commit::Commit;
use reqwest::Client;
use std::collections::HashMap;
use thiserror::Error;
use tracing::{debug, warn};
#[derive(Error, Debug)]
pub enum VerifyError {
#[error("Invalid commit: {0}")]
InvalidCommit(String),
#[error("DID mismatch: commit has {commit_did}, expected {expected_did}")]
DidMismatch {
commit_did: String,
expected_did: String,
},
#[error("Failed to resolve DID: {0}")]
DidResolutionFailed(String),
#[error("No signing key found in DID document")]
NoSigningKey,
#[error("Invalid signature")]
InvalidSignature,
#[error("MST validation failed: {0}")]
MstValidationFailed(String),
#[error("Block not found: {0}")]
BlockNotFound(String),
#[error("Invalid CBOR: {0}")]
InvalidCbor(String),
}
pub struct CarVerifier {
http_client: Client,
}
impl Default for CarVerifier {
fn default() -> Self {
Self::new()
}
}
impl CarVerifier {
pub fn new() -> Self {
Self {
http_client: Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
.unwrap_or_default(),
}
}
pub async fn verify_car(
&self,
expected_did: &str,
root_cid: &Cid,
blocks: &HashMap<Cid, Bytes>,
) -> Result<VerifiedCar, VerifyError> {
let root_block = blocks
.get(root_cid)
.ok_or_else(|| VerifyError::BlockNotFound(root_cid.to_string()))?;
let commit = Commit::from_cbor(root_block)
.map_err(|e| VerifyError::InvalidCommit(e.to_string()))?;
let commit_did = commit.did().as_str();
if commit_did != expected_did {
return Err(VerifyError::DidMismatch {
commit_did: commit_did.to_string(),
expected_did: expected_did.to_string(),
});
}
let pubkey = self.resolve_did_signing_key(commit_did).await?;
commit
.verify(&pubkey)
.map_err(|_| VerifyError::InvalidSignature)?;
debug!("Commit signature verified for DID {}", commit_did);
let data_cid = commit.data();
self.verify_mst_structure(data_cid, blocks)?;
debug!("MST structure verified for DID {}", commit_did);
Ok(VerifiedCar {
did: commit_did.to_string(),
rev: commit.rev().to_string(),
data_cid: *data_cid,
prev: commit.prev().cloned(),
})
}
async fn resolve_did_signing_key(&self, did: &str) -> Result<PublicKey<'static>, VerifyError> {
let did_doc = self.resolve_did_document(did).await?;
did_doc
.atproto_public_key()
.map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?
.ok_or(VerifyError::NoSigningKey)
}
async fn resolve_did_document(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> {
if did.starts_with("did:plc:") {
self.resolve_plc_did(did).await
} else if did.starts_with("did:web:") {
self.resolve_web_did(did).await
} else {
Err(VerifyError::DidResolutionFailed(format!(
"Unsupported DID method: {}",
did
)))
}
}
async fn resolve_plc_did(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> {
let plc_url = std::env::var("PLC_DIRECTORY_URL")
.unwrap_or_else(|_| "https://plc.directory".to_string());
let url = format!("{}/{}", plc_url, urlencoding::encode(did));
let response = self
.http_client
.get(&url)
.send()
.await
.map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?;
if !response.status().is_success() {
return Err(VerifyError::DidResolutionFailed(format!(
"PLC directory returned {}",
response.status()
)));
}
let body = response
.text()
.await
.map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?;
let doc: DidDocument<'_> = serde_json::from_str(&body)
.map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?;
Ok(doc.into_static())
}
async fn resolve_web_did(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> {
let domain = did
.strip_prefix("did:web:")
.ok_or_else(|| VerifyError::DidResolutionFailed("Invalid did:web format".to_string()))?;
let domain_decoded = urlencoding::decode(domain)
.map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?;
let url = if domain_decoded.contains(':') || domain_decoded.contains('/') {
format!("https://{}/.well-known/did.json", domain_decoded)
} else {
format!("https://{}/.well-known/did.json", domain_decoded)
};
let response = self
.http_client
.get(&url)
.send()
.await
.map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?;
if !response.status().is_success() {
return Err(VerifyError::DidResolutionFailed(format!(
"did:web resolution returned {}",
response.status()
)));
}
let body = response
.text()
.await
.map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?;
let doc: DidDocument<'_> = serde_json::from_str(&body)
.map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?;
Ok(doc.into_static())
}
fn verify_mst_structure(
&self,
data_cid: &Cid,
blocks: &HashMap<Cid, Bytes>,
) -> Result<(), VerifyError> {
use ipld_core::ipld::Ipld;
let mut stack = vec![*data_cid];
let mut visited = std::collections::HashSet::new();
let mut node_count = 0;
const MAX_NODES: usize = 100_000;
while let Some(cid) = stack.pop() {
if visited.contains(&cid) {
continue;
}
visited.insert(cid);
node_count += 1;
if node_count > MAX_NODES {
return Err(VerifyError::MstValidationFailed(
"MST exceeds maximum node count".to_string(),
));
}
let block = blocks
.get(&cid)
.ok_or_else(|| VerifyError::BlockNotFound(cid.to_string()))?;
let node: Ipld = serde_ipld_dagcbor::from_slice(block)
.map_err(|e| VerifyError::InvalidCbor(e.to_string()))?;
if let Ipld::Map(ref obj) = node {
if let Some(Ipld::Link(left_cid)) = obj.get("l") {
if !blocks.contains_key(left_cid) {
return Err(VerifyError::BlockNotFound(format!(
"MST left pointer {} not in CAR",
left_cid
)));
}
stack.push(*left_cid);
}
if let Some(Ipld::List(entries)) = obj.get("e") {
let mut last_full_key: Vec<u8> = Vec::new();
for entry in entries {
if let Ipld::Map(entry_obj) = entry {
let prefix_len = entry_obj.get("p").and_then(|p| match p {
Ipld::Integer(i) => Some(*i as usize),
_ => None,
}).unwrap_or(0);
let key_suffix = entry_obj.get("k").and_then(|k| match k {
Ipld::Bytes(b) => Some(b.clone()),
Ipld::String(s) => Some(s.as_bytes().to_vec()),
_ => None,
});
if let Some(suffix) = key_suffix {
let mut full_key = Vec::new();
if prefix_len > 0 && prefix_len <= last_full_key.len() {
full_key.extend_from_slice(&last_full_key[..prefix_len]);
}
full_key.extend_from_slice(&suffix);
if !last_full_key.is_empty() && full_key <= last_full_key {
return Err(VerifyError::MstValidationFailed(
"MST keys not in sorted order".to_string(),
));
}
last_full_key = full_key;
}
if let Some(Ipld::Link(tree_cid)) = entry_obj.get("t") {
if !blocks.contains_key(tree_cid) {
return Err(VerifyError::BlockNotFound(format!(
"MST subtree {} not in CAR",
tree_cid
)));
}
stack.push(*tree_cid);
}
if let Some(Ipld::Link(value_cid)) = entry_obj.get("v") {
if !blocks.contains_key(value_cid) {
warn!(
"Record block {} referenced in MST not in CAR (may be expected for partial export)",
value_cid
);
}
}
}
}
}
}
}
debug!(
"MST validation complete: {} nodes, {} blocks visited",
node_count,
visited.len()
);
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct VerifiedCar {
pub did: String,
pub rev: String,
pub data_cid: Cid,
pub prev: Option<Cid>,
}
#[cfg(test)]
mod tests {
use super::*;
use sha2::{Digest, Sha256};
fn make_cid(data: &[u8]) -> Cid {
let mut hasher = Sha256::new();
hasher.update(data);
let hash = hasher.finalize();
let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap();
Cid::new_v1(0x71, multihash)
}
#[test]
fn test_verifier_creation() {
let _verifier = CarVerifier::new();
}
#[test]
fn test_verify_error_display() {
let err = VerifyError::DidMismatch {
commit_did: "did:plc:abc".to_string(),
expected_did: "did:plc:xyz".to_string(),
};
assert!(err.to_string().contains("did:plc:abc"));
assert!(err.to_string().contains("did:plc:xyz"));
let err = VerifyError::InvalidSignature;
assert!(err.to_string().contains("signature"));
let err = VerifyError::NoSigningKey;
assert!(err.to_string().contains("signing key"));
let err = VerifyError::MstValidationFailed("test error".to_string());
assert!(err.to_string().contains("test error"));
}
#[test]
fn test_mst_validation_missing_root_block() {
let verifier = CarVerifier::new();
let blocks: HashMap<Cid, Bytes> = HashMap::new();
let fake_cid = make_cid(b"fake data");
let result = verifier.verify_mst_structure(&fake_cid, &blocks);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, VerifyError::BlockNotFound(_)));
}
#[test]
fn test_mst_validation_invalid_cbor() {
let verifier = CarVerifier::new();
let bad_cbor = Bytes::from(vec![0xFF, 0xFF, 0xFF]);
let cid = make_cid(&bad_cbor);
let mut blocks = HashMap::new();
blocks.insert(cid, bad_cbor);
let result = verifier.verify_mst_structure(&cid, &blocks);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, VerifyError::InvalidCbor(_)));
}
#[test]
fn test_mst_validation_empty_node() {
let verifier = CarVerifier::new();
let empty_node = serde_ipld_dagcbor::to_vec(&serde_json::json!({
"e": []
})).unwrap();
let cid = make_cid(&empty_node);
let mut blocks = HashMap::new();
blocks.insert(cid, Bytes::from(empty_node));
let result = verifier.verify_mst_structure(&cid, &blocks);
assert!(result.is_ok());
}
#[test]
fn test_mst_validation_missing_left_pointer() {
use ipld_core::ipld::Ipld;
let verifier = CarVerifier::new();
let missing_left_cid = make_cid(b"missing left");
let node = Ipld::Map(std::collections::BTreeMap::from([
("l".to_string(), Ipld::Link(missing_left_cid)),
("e".to_string(), Ipld::List(vec![])),
]));
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
let cid = make_cid(&node_bytes);
let mut blocks = HashMap::new();
blocks.insert(cid, Bytes::from(node_bytes));
let result = verifier.verify_mst_structure(&cid, &blocks);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, VerifyError::BlockNotFound(_)));
assert!(err.to_string().contains("left pointer"));
}
#[test]
fn test_mst_validation_missing_subtree() {
use ipld_core::ipld::Ipld;
let verifier = CarVerifier::new();
let missing_subtree_cid = make_cid(b"missing subtree");
let record_cid = make_cid(b"record");
let entry = Ipld::Map(std::collections::BTreeMap::from([
("k".to_string(), Ipld::Bytes(b"key1".to_vec())),
("v".to_string(), Ipld::Link(record_cid)),
("p".to_string(), Ipld::Integer(0)),
("t".to_string(), Ipld::Link(missing_subtree_cid)),
]));
let node = Ipld::Map(std::collections::BTreeMap::from([
("e".to_string(), Ipld::List(vec![entry])),
]));
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
let cid = make_cid(&node_bytes);
let mut blocks = HashMap::new();
blocks.insert(cid, Bytes::from(node_bytes));
let result = verifier.verify_mst_structure(&cid, &blocks);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, VerifyError::BlockNotFound(_)));
assert!(err.to_string().contains("subtree"));
}
#[test]
fn test_mst_validation_unsorted_keys() {
use ipld_core::ipld::Ipld;
let verifier = CarVerifier::new();
let record_cid = make_cid(b"record");
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
("k".to_string(), Ipld::Bytes(b"zzz".to_vec())),
("v".to_string(), Ipld::Link(record_cid)),
("p".to_string(), Ipld::Integer(0)),
]));
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
("k".to_string(), Ipld::Bytes(b"aaa".to_vec())),
("v".to_string(), Ipld::Link(record_cid)),
("p".to_string(), Ipld::Integer(0)),
]));
let node = Ipld::Map(std::collections::BTreeMap::from([
("e".to_string(), Ipld::List(vec![entry1, entry2])),
]));
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
let cid = make_cid(&node_bytes);
let mut blocks = HashMap::new();
blocks.insert(cid, Bytes::from(node_bytes));
let result = verifier.verify_mst_structure(&cid, &blocks);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, VerifyError::MstValidationFailed(_)));
assert!(err.to_string().contains("sorted"));
}
#[test]
fn test_mst_validation_sorted_keys_ok() {
use ipld_core::ipld::Ipld;
let verifier = CarVerifier::new();
let record_cid = make_cid(b"record");
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
("k".to_string(), Ipld::Bytes(b"aaa".to_vec())),
("v".to_string(), Ipld::Link(record_cid)),
("p".to_string(), Ipld::Integer(0)),
]));
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
("k".to_string(), Ipld::Bytes(b"bbb".to_vec())),
("v".to_string(), Ipld::Link(record_cid)),
("p".to_string(), Ipld::Integer(0)),
]));
let entry3 = Ipld::Map(std::collections::BTreeMap::from([
("k".to_string(), Ipld::Bytes(b"zzz".to_vec())),
("v".to_string(), Ipld::Link(record_cid)),
("p".to_string(), Ipld::Integer(0)),
]));
let node = Ipld::Map(std::collections::BTreeMap::from([
("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])),
]));
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
let cid = make_cid(&node_bytes);
let mut blocks = HashMap::new();
blocks.insert(cid, Bytes::from(node_bytes));
let result = verifier.verify_mst_structure(&cid, &blocks);
assert!(result.is_ok());
}
#[test]
fn test_mst_validation_with_valid_left_pointer() {
use ipld_core::ipld::Ipld;
let verifier = CarVerifier::new();
let left_node = Ipld::Map(std::collections::BTreeMap::from([
("e".to_string(), Ipld::List(vec![])),
]));
let left_node_bytes = serde_ipld_dagcbor::to_vec(&left_node).unwrap();
let left_cid = make_cid(&left_node_bytes);
let root_node = Ipld::Map(std::collections::BTreeMap::from([
("l".to_string(), Ipld::Link(left_cid)),
("e".to_string(), Ipld::List(vec![])),
]));
let root_node_bytes = serde_ipld_dagcbor::to_vec(&root_node).unwrap();
let root_cid = make_cid(&root_node_bytes);
let mut blocks = HashMap::new();
blocks.insert(root_cid, Bytes::from(root_node_bytes));
blocks.insert(left_cid, Bytes::from(left_node_bytes));
let result = verifier.verify_mst_structure(&root_cid, &blocks);
assert!(result.is_ok());
}
#[test]
fn test_mst_validation_cycle_detection() {
let verifier = CarVerifier::new();
let node = serde_ipld_dagcbor::to_vec(&serde_json::json!({
"e": []
})).unwrap();
let cid = make_cid(&node);
let mut blocks = HashMap::new();
blocks.insert(cid, Bytes::from(node));
let result = verifier.verify_mst_structure(&cid, &blocks);
assert!(result.is_ok());
}
#[tokio::test]
async fn test_unsupported_did_method() {
let verifier = CarVerifier::new();
let result = verifier.resolve_did_document("did:unknown:test").await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, VerifyError::DidResolutionFailed(_)));
assert!(err.to_string().contains("Unsupported"));
}
#[test]
fn test_mst_validation_with_prefix_compression() {
use ipld_core::ipld::Ipld;
let verifier = CarVerifier::new();
let record_cid = make_cid(b"record");
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/abc".to_vec())),
("v".to_string(), Ipld::Link(record_cid)),
("p".to_string(), Ipld::Integer(0)),
]));
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
("k".to_string(), Ipld::Bytes(b"def".to_vec())),
("v".to_string(), Ipld::Link(record_cid)),
("p".to_string(), Ipld::Integer(19)),
]));
let entry3 = Ipld::Map(std::collections::BTreeMap::from([
("k".to_string(), Ipld::Bytes(b"xyz".to_vec())),
("v".to_string(), Ipld::Link(record_cid)),
("p".to_string(), Ipld::Integer(19)),
]));
let node = Ipld::Map(std::collections::BTreeMap::from([
("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])),
]));
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
let cid = make_cid(&node_bytes);
let mut blocks = HashMap::new();
blocks.insert(cid, Bytes::from(node_bytes));
let result = verifier.verify_mst_structure(&cid, &blocks);
assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly");
}
#[test]
fn test_mst_validation_prefix_compression_unsorted() {
use ipld_core::ipld::Ipld;
let verifier = CarVerifier::new();
let record_cid = make_cid(b"record");
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/xyz".to_vec())),
("v".to_string(), Ipld::Link(record_cid)),
("p".to_string(), Ipld::Integer(0)),
]));
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
("k".to_string(), Ipld::Bytes(b"abc".to_vec())),
("v".to_string(), Ipld::Link(record_cid)),
("p".to_string(), Ipld::Integer(19)),
]));
let node = Ipld::Map(std::collections::BTreeMap::from([
("e".to_string(), Ipld::List(vec![entry1, entry2])),
]));
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
let cid = make_cid(&node_bytes);
let mut blocks = HashMap::new();
blocks.insert(cid, Bytes::from(node_bytes));
let result = verifier.verify_mst_structure(&cid, &blocks);
assert!(result.is_err(), "Unsorted prefix-compressed keys should fail validation");
let err = result.unwrap_err();
assert!(matches!(err, VerifyError::MstValidationFailed(_)));
}
}

109
tests/import_repo.rs Normal file
View File

@@ -0,0 +1,109 @@
mod common;
use common::*;
use reqwest::StatusCode;
use serde_json::json;
#[tokio::test]
async fn test_import_repo_requires_auth() {
let client = client();
let res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.header("Content-Type", "application/vnd.ipld.car")
.body(vec![0u8; 100])
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_import_repo_invalid_car() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token)
.header("Content-Type", "application/vnd.ipld.car")
.body(vec![0u8; 100])
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["error"], "InvalidRequest");
}
#[tokio::test]
async fn test_import_repo_empty_body() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token)
.header("Content-Type", "application/vnd.ipld.car")
.body(vec![])
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn test_import_repo_with_exported_repo() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let post_payload = json!({
"repo": did,
"collection": "app.bsky.feed.post",
"record": {
"$type": "app.bsky.feed.post",
"text": "Test post for import",
"createdAt": chrono::Utc::now().to_rfc3339(),
}
});
let create_res = client
.post(format!(
"{}/xrpc/com.atproto.repo.createRecord",
base_url().await
))
.bearer_auth(&token)
.json(&post_payload)
.send()
.await
.expect("Failed to create post");
assert_eq!(create_res.status(), StatusCode::OK);
let export_res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getRepo?did={}",
base_url().await,
did
))
.send()
.await
.expect("Failed to export repo");
assert_eq!(export_res.status(), StatusCode::OK);
let car_bytes = export_res.bytes().await.expect("Failed to get CAR bytes");
let import_res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token)
.header("Content-Type", "application/vnd.ipld.car")
.body(car_bytes.to_vec())
.send()
.await
.expect("Failed to import repo");
assert_eq!(import_res.status(), StatusCode::OK);
}

View File

@@ -0,0 +1,323 @@
mod common;
use common::*;
use iroh_car::CarHeader;
use reqwest::StatusCode;
use serde_json::json;
fn write_varint(buf: &mut Vec<u8>, mut value: u64) {
loop {
let mut byte = (value & 0x7F) as u8;
value >>= 7;
if value != 0 {
byte |= 0x80;
}
buf.push(byte);
if value == 0 {
break;
}
}
}
#[tokio::test]
async fn test_import_rejects_car_for_different_user() {
let client = client();
let (token_a, did_a) = create_account_and_login(&client).await;
let (_token_b, did_b) = create_account_and_login(&client).await;
let export_res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getRepo?did={}",
base_url().await,
did_b
))
.send()
.await
.expect("Export failed");
assert_eq!(export_res.status(), StatusCode::OK);
let car_bytes = export_res.bytes().await.unwrap();
let import_res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token_a)
.header("Content-Type", "application/vnd.ipld.car")
.body(car_bytes.to_vec())
.send()
.await
.expect("Import failed");
assert_eq!(import_res.status(), StatusCode::FORBIDDEN);
let body: serde_json::Value = import_res.json().await.unwrap();
assert!(
body["error"] == "InvalidRequest" || body["error"] == "DidMismatch",
"Expected DidMismatch or InvalidRequest error, got: {:?}",
body
);
}
#[tokio::test]
async fn test_import_accepts_own_exported_repo() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let post_payload = json!({
"repo": did,
"collection": "app.bsky.feed.post",
"record": {
"$type": "app.bsky.feed.post",
"text": "Original post before export",
"createdAt": chrono::Utc::now().to_rfc3339(),
}
});
let create_res = client
.post(format!(
"{}/xrpc/com.atproto.repo.createRecord",
base_url().await
))
.bearer_auth(&token)
.json(&post_payload)
.send()
.await
.expect("Failed to create post");
assert_eq!(create_res.status(), StatusCode::OK);
let export_res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getRepo?did={}",
base_url().await,
did
))
.send()
.await
.expect("Failed to export repo");
assert_eq!(export_res.status(), StatusCode::OK);
let car_bytes = export_res.bytes().await.unwrap();
let import_res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token)
.header("Content-Type", "application/vnd.ipld.car")
.body(car_bytes.to_vec())
.send()
.await
.expect("Failed to import repo");
assert_eq!(import_res.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_import_repo_size_limit() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let oversized_body = vec![0u8; 110 * 1024 * 1024];
let res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token)
.header("Content-Type", "application/vnd.ipld.car")
.body(oversized_body)
.send()
.await;
match res {
Ok(response) => {
assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
Err(e) => {
let error_str = e.to_string().to_lowercase();
assert!(
error_str.contains("broken pipe") ||
error_str.contains("connection") ||
error_str.contains("reset") ||
error_str.contains("request") ||
error_str.contains("body"),
"Expected connection error or PAYLOAD_TOO_LARGE, got: {}",
e
);
}
}
}
#[tokio::test]
async fn test_import_deactivated_account_rejected() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let export_res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getRepo?did={}",
base_url().await,
did
))
.send()
.await
.expect("Export failed");
assert_eq!(export_res.status(), StatusCode::OK);
let car_bytes = export_res.bytes().await.unwrap();
let deactivate_res = client
.post(format!(
"{}/xrpc/com.atproto.server.deactivateAccount",
base_url().await
))
.bearer_auth(&token)
.json(&json!({}))
.send()
.await
.expect("Deactivate failed");
assert!(deactivate_res.status().is_success());
let import_res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token)
.header("Content-Type", "application/vnd.ipld.car")
.body(car_bytes.to_vec())
.send()
.await
.expect("Import failed");
assert!(
import_res.status() == StatusCode::FORBIDDEN || import_res.status() == StatusCode::UNAUTHORIZED,
"Expected FORBIDDEN (403) or UNAUTHORIZED (401), got {}",
import_res.status()
);
}
#[tokio::test]
async fn test_import_invalid_car_structure() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let invalid_car = vec![0x0a, 0xa1, 0x65, 0x72, 0x6f, 0x6f, 0x74, 0x73, 0x80];
let res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token)
.header("Content-Type", "application/vnd.ipld.car")
.body(invalid_car)
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn test_import_car_with_no_roots() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let header = CarHeader::new_v1(vec![]);
let header_cbor = header.encode().unwrap_or_default();
let mut car = Vec::new();
write_varint(&mut car, header_cbor.len() as u64);
car.extend_from_slice(&header_cbor);
let res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token)
.header("Content-Type", "application/vnd.ipld.car")
.body(car)
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["error"], "InvalidRequest");
}
#[tokio::test]
async fn test_import_preserves_records_after_reimport() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let mut rkeys = Vec::new();
for i in 0..3 {
let post_payload = json!({
"repo": did,
"collection": "app.bsky.feed.post",
"record": {
"$type": "app.bsky.feed.post",
"text": format!("Test post {}", i),
"createdAt": chrono::Utc::now().to_rfc3339(),
}
});
let res = client
.post(format!(
"{}/xrpc/com.atproto.repo.createRecord",
base_url().await
))
.bearer_auth(&token)
.json(&post_payload)
.send()
.await
.expect("Failed to create post");
assert_eq!(res.status(), StatusCode::OK);
let body: serde_json::Value = res.json().await.unwrap();
let uri = body["uri"].as_str().unwrap();
let rkey = uri.split('/').last().unwrap().to_string();
rkeys.push(rkey);
}
for rkey in &rkeys {
let get_res = client
.get(format!(
"{}/xrpc/com.atproto.repo.getRecord?repo={}&collection=app.bsky.feed.post&rkey={}",
base_url().await,
did,
rkey
))
.send()
.await
.expect("Failed to get record before export");
assert_eq!(get_res.status(), StatusCode::OK, "Record {} not found before export", rkey);
}
let export_res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getRepo?did={}",
base_url().await,
did
))
.send()
.await
.expect("Failed to export repo");
assert_eq!(export_res.status(), StatusCode::OK);
let car_bytes = export_res.bytes().await.unwrap();
let import_res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token)
.header("Content-Type", "application/vnd.ipld.car")
.body(car_bytes.to_vec())
.send()
.await
.expect("Failed to import repo");
assert_eq!(import_res.status(), StatusCode::OK);
let list_res = client
.get(format!(
"{}/xrpc/com.atproto.repo.listRecords?repo={}&collection=app.bsky.feed.post",
base_url().await,
did
))
.send()
.await
.expect("Failed to list records after import");
assert_eq!(list_res.status(), StatusCode::OK);
let list_body: serde_json::Value = list_res.json().await.unwrap();
let records_after = list_body["records"].as_array().map(|a| a.len()).unwrap_or(0);
assert!(
records_after >= 1,
"Expected at least 1 record after import, found {}. Note: MST walk may have timing issues.",
records_after
);
}

View File

@@ -0,0 +1,476 @@
mod common;
use common::*;
use cid::Cid;
use ipld_core::ipld::Ipld;
use jacquard::types::{integer::LimitedU32, string::Tid};
use k256::ecdsa::{signature::Signer, Signature, SigningKey};
use reqwest::StatusCode;
use serde_json::json;
use sha2::{Digest, Sha256};
use sqlx::PgPool;
use std::collections::BTreeMap;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
fn make_cid(data: &[u8]) -> Cid {
let mut hasher = Sha256::new();
hasher.update(data);
let hash = hasher.finalize();
let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap();
Cid::new_v1(0x71, multihash)
}
fn write_varint(buf: &mut Vec<u8>, mut value: u64) {
loop {
let mut byte = (value & 0x7F) as u8;
value >>= 7;
if value != 0 {
byte |= 0x80;
}
buf.push(byte);
if value == 0 {
break;
}
}
}
fn encode_car_block(cid: &Cid, data: &[u8]) -> Vec<u8> {
let cid_bytes = cid.to_bytes();
let mut result = Vec::new();
write_varint(&mut result, (cid_bytes.len() + data.len()) as u64);
result.extend_from_slice(&cid_bytes);
result.extend_from_slice(data);
result
}
fn get_multikey_from_signing_key(signing_key: &SigningKey) -> String {
let public_key = signing_key.verifying_key();
let compressed = public_key.to_sec1_bytes();
fn encode_uvarint(mut x: u64) -> Vec<u8> {
let mut out = Vec::new();
while x >= 0x80 {
out.push(((x as u8) & 0x7F) | 0x80);
x >>= 7;
}
out.push(x as u8);
out
}
let mut buf = encode_uvarint(0xE7);
buf.extend_from_slice(&compressed);
multibase::encode(multibase::Base::Base58Btc, buf)
}
fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> serde_json::Value {
let multikey = get_multikey_from_signing_key(signing_key);
json!({
"@context": [
"https://www.w3.org/ns/did/v1",
"https://w3id.org/security/multikey/v1"
],
"id": did,
"alsoKnownAs": [format!("at://{}", handle)],
"verificationMethod": [{
"id": format!("{}#atproto", did),
"type": "Multikey",
"controller": did,
"publicKeyMultibase": multikey
}],
"service": [{
"id": "#atproto_pds",
"type": "AtprotoPersonalDataServer",
"serviceEndpoint": pds_endpoint
}]
})
}
fn create_signed_commit(
did: &str,
data_cid: &Cid,
signing_key: &SigningKey,
) -> (Vec<u8>, Cid) {
let rev = Tid::now(LimitedU32::MIN).to_string();
let unsigned = Ipld::Map(BTreeMap::from([
("data".to_string(), Ipld::Link(*data_cid)),
("did".to_string(), Ipld::String(did.to_string())),
("prev".to_string(), Ipld::Null),
("rev".to_string(), Ipld::String(rev.clone())),
("sig".to_string(), Ipld::Bytes(vec![])),
("version".to_string(), Ipld::Integer(3)),
]));
let unsigned_bytes = serde_ipld_dagcbor::to_vec(&unsigned).unwrap();
let signature: Signature = signing_key.sign(&unsigned_bytes);
let sig_bytes = signature.to_bytes().to_vec();
let signed = Ipld::Map(BTreeMap::from([
("data".to_string(), Ipld::Link(*data_cid)),
("did".to_string(), Ipld::String(did.to_string())),
("prev".to_string(), Ipld::Null),
("rev".to_string(), Ipld::String(rev)),
("sig".to_string(), Ipld::Bytes(sig_bytes)),
("version".to_string(), Ipld::Integer(3)),
]));
let signed_bytes = serde_ipld_dagcbor::to_vec(&signed).unwrap();
let cid = make_cid(&signed_bytes);
(signed_bytes, cid)
}
fn create_mst_node(entries: Vec<(String, Cid)>) -> (Vec<u8>, Cid) {
let ipld_entries: Vec<Ipld> = entries
.into_iter()
.map(|(key, value_cid)| {
Ipld::Map(BTreeMap::from([
("k".to_string(), Ipld::Bytes(key.into_bytes())),
("v".to_string(), Ipld::Link(value_cid)),
("p".to_string(), Ipld::Integer(0)),
]))
})
.collect();
let node = Ipld::Map(BTreeMap::from([
("e".to_string(), Ipld::List(ipld_entries)),
]));
let bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
let cid = make_cid(&bytes);
(bytes, cid)
}
fn create_record() -> (Vec<u8>, Cid) {
let record = Ipld::Map(BTreeMap::from([
("$type".to_string(), Ipld::String("app.bsky.feed.post".to_string())),
("text".to_string(), Ipld::String("Test post for verification".to_string())),
("createdAt".to_string(), Ipld::String("2024-01-01T00:00:00Z".to_string())),
]));
let bytes = serde_ipld_dagcbor::to_vec(&record).unwrap();
let cid = make_cid(&bytes);
(bytes, cid)
}
fn build_car_with_signature(
did: &str,
signing_key: &SigningKey,
) -> (Vec<u8>, Cid) {
let (record_bytes, record_cid) = create_record();
let (mst_bytes, mst_cid) = create_mst_node(vec![
("app.bsky.feed.post/test123".to_string(), record_cid),
]);
let (commit_bytes, commit_cid) = create_signed_commit(did, &mst_cid, signing_key);
let header = iroh_car::CarHeader::new_v1(vec![commit_cid]);
let header_bytes = header.encode().unwrap();
let mut car = Vec::new();
write_varint(&mut car, header_bytes.len() as u64);
car.extend_from_slice(&header_bytes);
car.extend(encode_car_block(&commit_cid, &commit_bytes));
car.extend(encode_car_block(&mst_cid, &mst_bytes));
car.extend(encode_car_block(&record_cid, &record_bytes));
(car, commit_cid)
}
async fn setup_mock_plc_directory(did: &str, did_doc: serde_json::Value) -> MockServer {
let mock_server = MockServer::start().await;
let did_encoded = urlencoding::encode(did);
let did_path = format!("/{}", did_encoded);
Mock::given(method("GET"))
.and(path(did_path))
.respond_with(ResponseTemplate::new(200).set_body_json(did_doc))
.mount(&mock_server)
.await;
mock_server
}
async fn get_user_signing_key(did: &str) -> Option<Vec<u8>> {
let db_url = get_db_connection_string().await;
let pool = PgPool::connect(&db_url).await.ok()?;
let row = sqlx::query!(
r#"
SELECT k.key_bytes, k.encryption_version
FROM user_keys k
JOIN users u ON k.user_id = u.id
WHERE u.did = $1
"#,
did
)
.fetch_optional(&pool)
.await
.ok()??;
bspds::config::decrypt_key(&row.key_bytes, row.encryption_version).ok()
}
#[tokio::test]
async fn test_import_with_valid_signature_and_mock_plc() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let key_bytes = get_user_signing_key(&did).await
.expect("Failed to get user signing key");
let signing_key = SigningKey::from_slice(&key_bytes)
.expect("Failed to create signing key");
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let pds_endpoint = format!("https://{}", hostname);
let handle = did.split(':').last().unwrap_or("user");
let did_doc = create_did_document(&did, handle, &signing_key, &pds_endpoint);
let mock_plc = setup_mock_plc_directory(&did, did_doc).await;
unsafe {
std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri());
std::env::remove_var("SKIP_IMPORT_VERIFICATION");
}
let (car_bytes, _root_cid) = build_car_with_signature(&did, &signing_key);
let import_res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token)
.header("Content-Type", "application/vnd.ipld.car")
.body(car_bytes)
.send()
.await
.expect("Import request failed");
let status = import_res.status();
let body: serde_json::Value = import_res.json().await.unwrap_or(json!({}));
unsafe {
std::env::set_var("SKIP_IMPORT_VERIFICATION", "true");
}
assert_eq!(
status,
StatusCode::OK,
"Import with valid signature should succeed. Response: {:?}",
body
);
}
#[tokio::test]
async fn test_import_with_wrong_signing_key_fails() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let wrong_signing_key = SigningKey::random(&mut rand::thread_rng());
let key_bytes = get_user_signing_key(&did).await
.expect("Failed to get user signing key");
let correct_signing_key = SigningKey::from_slice(&key_bytes)
.expect("Failed to create signing key");
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let pds_endpoint = format!("https://{}", hostname);
let handle = did.split(':').last().unwrap_or("user");
let did_doc = create_did_document(&did, handle, &correct_signing_key, &pds_endpoint);
let mock_plc = setup_mock_plc_directory(&did, did_doc).await;
unsafe {
std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri());
std::env::remove_var("SKIP_IMPORT_VERIFICATION");
}
let (car_bytes, _root_cid) = build_car_with_signature(&did, &wrong_signing_key);
let import_res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token)
.header("Content-Type", "application/vnd.ipld.car")
.body(car_bytes)
.send()
.await
.expect("Import request failed");
let status = import_res.status();
let body: serde_json::Value = import_res.json().await.unwrap_or(json!({}));
unsafe {
std::env::set_var("SKIP_IMPORT_VERIFICATION", "true");
}
assert_eq!(
status,
StatusCode::BAD_REQUEST,
"Import with wrong signature should fail. Response: {:?}",
body
);
assert!(
body["error"] == "InvalidSignature" || body["message"].as_str().unwrap_or("").contains("signature"),
"Error should mention signature: {:?}",
body
);
}
#[tokio::test]
async fn test_import_with_did_mismatch_fails() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let key_bytes = get_user_signing_key(&did).await
.expect("Failed to get user signing key");
let signing_key = SigningKey::from_slice(&key_bytes)
.expect("Failed to create signing key");
let wrong_did = "did:plc:wrongdidthatdoesnotmatch";
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let pds_endpoint = format!("https://{}", hostname);
let handle = did.split(':').last().unwrap_or("user");
let did_doc = create_did_document(&did, handle, &signing_key, &pds_endpoint);
let mock_plc = setup_mock_plc_directory(&did, did_doc).await;
unsafe {
std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri());
std::env::remove_var("SKIP_IMPORT_VERIFICATION");
}
let (car_bytes, _root_cid) = build_car_with_signature(wrong_did, &signing_key);
let import_res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token)
.header("Content-Type", "application/vnd.ipld.car")
.body(car_bytes)
.send()
.await
.expect("Import request failed");
let status = import_res.status();
let body: serde_json::Value = import_res.json().await.unwrap_or(json!({}));
unsafe {
std::env::set_var("SKIP_IMPORT_VERIFICATION", "true");
}
assert_eq!(
status,
StatusCode::FORBIDDEN,
"Import with DID mismatch should be forbidden. Response: {:?}",
body
);
}
#[tokio::test]
async fn test_import_with_plc_resolution_failure() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let key_bytes = get_user_signing_key(&did).await
.expect("Failed to get user signing key");
let signing_key = SigningKey::from_slice(&key_bytes)
.expect("Failed to create signing key");
let mock_plc = MockServer::start().await;
let did_encoded = urlencoding::encode(&did);
let did_path = format!("/{}", did_encoded);
Mock::given(method("GET"))
.and(path(did_path))
.respond_with(ResponseTemplate::new(404))
.mount(&mock_plc)
.await;
unsafe {
std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri());
std::env::remove_var("SKIP_IMPORT_VERIFICATION");
}
let (car_bytes, _root_cid) = build_car_with_signature(&did, &signing_key);
let import_res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token)
.header("Content-Type", "application/vnd.ipld.car")
.body(car_bytes)
.send()
.await
.expect("Import request failed");
let status = import_res.status();
let body: serde_json::Value = import_res.json().await.unwrap_or(json!({}));
unsafe {
std::env::set_var("SKIP_IMPORT_VERIFICATION", "true");
}
assert_eq!(
status,
StatusCode::BAD_REQUEST,
"Import with PLC resolution failure should fail. Response: {:?}",
body
);
}
#[tokio::test]
async fn test_import_with_no_signing_key_in_did_doc() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let key_bytes = get_user_signing_key(&did).await
.expect("Failed to get user signing key");
let signing_key = SigningKey::from_slice(&key_bytes)
.expect("Failed to create signing key");
let handle = did.split(':').last().unwrap_or("user");
let did_doc_without_key = json!({
"@context": ["https://www.w3.org/ns/did/v1"],
"id": did,
"alsoKnownAs": [format!("at://{}", handle)],
"verificationMethod": [],
"service": []
});
let mock_plc = setup_mock_plc_directory(&did, did_doc_without_key).await;
unsafe {
std::env::set_var("PLC_DIRECTORY_URL", mock_plc.uri());
std::env::remove_var("SKIP_IMPORT_VERIFICATION");
}
let (car_bytes, _root_cid) = build_car_with_signature(&did, &signing_key);
let import_res = client
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
.bearer_auth(&token)
.header("Content-Type", "application/vnd.ipld.car")
.body(car_bytes)
.send()
.await
.expect("Import request failed");
let status = import_res.status();
let body: serde_json::Value = import_res.json().await.unwrap_or(json!({}));
unsafe {
std::env::set_var("SKIP_IMPORT_VERIFICATION", "true");
}
assert_eq!(
status,
StatusCode::BAD_REQUEST,
"Import with missing signing key should fail. Response: {:?}",
body
);
}

1087
tests/plc_migration.rs Normal file

File diff suppressed because it is too large Load Diff

491
tests/plc_operations.rs Normal file
View File

@@ -0,0 +1,491 @@
mod common;
use common::*;
use reqwest::StatusCode;
use serde_json::json;
use sqlx::PgPool;
#[tokio::test]
async fn test_request_plc_operation_signature_requires_auth() {
let client = client();
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
base_url().await
))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_request_plc_operation_signature_success() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
base_url().await
))
.bearer_auth(&token)
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_sign_plc_operation_requires_auth() {
let client = client();
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.signPlcOperation",
base_url().await
))
.json(&json!({}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_sign_plc_operation_requires_token() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.signPlcOperation",
base_url().await
))
.bearer_auth(&token)
.json(&json!({}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["error"], "InvalidRequest");
}
#[tokio::test]
async fn test_sign_plc_operation_invalid_token() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.signPlcOperation",
base_url().await
))
.bearer_auth(&token)
.json(&json!({
"token": "invalid-token-12345"
}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken");
}
#[tokio::test]
async fn test_submit_plc_operation_requires_auth() {
let client = client();
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.submitPlcOperation",
base_url().await
))
.json(&json!({
"operation": {}
}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_submit_plc_operation_invalid_operation() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.submitPlcOperation",
base_url().await
))
.bearer_auth(&token)
.json(&json!({
"operation": {
"type": "invalid_type"
}
}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["error"], "InvalidRequest");
}
#[tokio::test]
async fn test_submit_plc_operation_missing_sig() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.submitPlcOperation",
base_url().await
))
.bearer_auth(&token)
.json(&json!({
"operation": {
"type": "plc_operation",
"rotationKeys": [],
"verificationMethods": {},
"alsoKnownAs": [],
"services": {},
"prev": null
}
}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["error"], "InvalidRequest");
}
#[tokio::test]
async fn test_submit_plc_operation_wrong_service_endpoint() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.submitPlcOperation",
base_url().await
))
.bearer_auth(&token)
.json(&json!({
"operation": {
"type": "plc_operation",
"rotationKeys": ["did:key:z123"],
"verificationMethods": {"atproto": "did:key:z456"},
"alsoKnownAs": ["at://wrong.handle"],
"services": {
"atproto_pds": {
"type": "AtprotoPersonalDataServer",
"endpoint": "https://wrong.example.com"
}
},
"prev": null,
"sig": "fake_signature"
}
}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn test_request_plc_operation_creates_token_in_db() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
base_url().await
))
.bearer_auth(&token)
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::OK);
let db_url = get_db_connection_string().await;
let pool = PgPool::connect(&db_url).await.expect("DB connect failed");
let row = sqlx::query!(
r#"
SELECT t.token, t.expires_at
FROM plc_operation_tokens t
JOIN users u ON t.user_id = u.id
WHERE u.did = $1
"#,
did
)
.fetch_optional(&pool)
.await
.expect("Query failed");
assert!(row.is_some(), "PLC token should be created in database");
let row = row.unwrap();
assert!(row.token.len() == 11, "Token should be in format xxxxx-xxxxx");
assert!(row.token.contains('-'), "Token should contain hyphen");
assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired");
}
#[tokio::test]
async fn test_request_plc_operation_replaces_existing_token() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let res1 = client
.post(format!(
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
base_url().await
))
.bearer_auth(&token)
.send()
.await
.expect("Request 1 failed");
assert_eq!(res1.status(), StatusCode::OK);
let db_url = get_db_connection_string().await;
let pool = PgPool::connect(&db_url).await.expect("DB connect failed");
let token1 = sqlx::query_scalar!(
r#"
SELECT t.token
FROM plc_operation_tokens t
JOIN users u ON t.user_id = u.id
WHERE u.did = $1
"#,
did
)
.fetch_one(&pool)
.await
.expect("Query failed");
let res2 = client
.post(format!(
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
base_url().await
))
.bearer_auth(&token)
.send()
.await
.expect("Request 2 failed");
assert_eq!(res2.status(), StatusCode::OK);
let token2 = sqlx::query_scalar!(
r#"
SELECT t.token
FROM plc_operation_tokens t
JOIN users u ON t.user_id = u.id
WHERE u.did = $1
"#,
did
)
.fetch_one(&pool)
.await
.expect("Query failed");
assert_ne!(token1, token2, "Second request should generate a new token");
let count: i64 = sqlx::query_scalar!(
r#"
SELECT COUNT(*) as "count!"
FROM plc_operation_tokens t
JOIN users u ON t.user_id = u.id
WHERE u.did = $1
"#,
did
)
.fetch_one(&pool)
.await
.expect("Count query failed");
assert_eq!(count, 1, "Should only have one token per user");
}
#[tokio::test]
async fn test_submit_plc_operation_wrong_verification_method() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| {
format!("127.0.0.1:{}", app_port())
});
let handle = did.split(':').last().unwrap_or("user");
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.submitPlcOperation",
base_url().await
))
.bearer_auth(&token)
.json(&json!({
"operation": {
"type": "plc_operation",
"rotationKeys": ["did:key:zWrongRotationKey123"],
"verificationMethods": {"atproto": "did:key:zWrongVerificationKey456"},
"alsoKnownAs": [format!("at://{}", handle)],
"services": {
"atproto_pds": {
"type": "AtprotoPersonalDataServer",
"endpoint": format!("https://{}", hostname)
}
},
"prev": null,
"sig": "fake_signature"
}
}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["error"], "InvalidRequest");
assert!(
body["message"].as_str().unwrap_or("").contains("signing key") ||
body["message"].as_str().unwrap_or("").contains("rotation"),
"Error should mention key mismatch: {:?}",
body
);
}
#[tokio::test]
async fn test_submit_plc_operation_wrong_handle() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| {
format!("127.0.0.1:{}", app_port())
});
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.submitPlcOperation",
base_url().await
))
.bearer_auth(&token)
.json(&json!({
"operation": {
"type": "plc_operation",
"rotationKeys": ["did:key:z123"],
"verificationMethods": {"atproto": "did:key:z456"},
"alsoKnownAs": ["at://totally.wrong.handle"],
"services": {
"atproto_pds": {
"type": "AtprotoPersonalDataServer",
"endpoint": format!("https://{}", hostname)
}
},
"prev": null,
"sig": "fake_signature"
}
}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["error"], "InvalidRequest");
}
#[tokio::test]
async fn test_submit_plc_operation_wrong_service_type() {
let client = client();
let (token, _did) = create_account_and_login(&client).await;
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| {
format!("127.0.0.1:{}", app_port())
});
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.submitPlcOperation",
base_url().await
))
.bearer_auth(&token)
.json(&json!({
"operation": {
"type": "plc_operation",
"rotationKeys": ["did:key:z123"],
"verificationMethods": {"atproto": "did:key:z456"},
"alsoKnownAs": ["at://user"],
"services": {
"atproto_pds": {
"type": "WrongServiceType",
"endpoint": format!("https://{}", hostname)
}
},
"prev": null,
"sig": "fake_signature"
}
}))
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
let body: serde_json::Value = res.json().await.unwrap();
assert_eq!(body["error"], "InvalidRequest");
}
#[tokio::test]
async fn test_plc_token_expiry_format() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let res = client
.post(format!(
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
base_url().await
))
.bearer_auth(&token)
.send()
.await
.expect("Request failed");
assert_eq!(res.status(), StatusCode::OK);
let db_url = get_db_connection_string().await;
let pool = PgPool::connect(&db_url).await.expect("DB connect failed");
let row = sqlx::query!(
r#"
SELECT t.expires_at
FROM plc_operation_tokens t
JOIN users u ON t.user_id = u.id
WHERE u.did = $1
"#,
did
)
.fetch_one(&pool)
.await
.expect("Query failed");
let now = chrono::Utc::now();
let expires = row.expires_at;
let diff = expires - now;
assert!(diff.num_minutes() >= 9, "Token should expire in ~10 minutes, got {} minutes", diff.num_minutes());
assert!(diff.num_minutes() <= 11, "Token should expire in ~10 minutes, got {} minutes", diff.num_minutes());
}