diff --git a/src/api/identity.rs b/src/api/identity.rs deleted file mode 100644 index 1c990e7..0000000 --- a/src/api/identity.rs +++ /dev/null @@ -1,424 +0,0 @@ -use axum::{ - extract::{State, Path}, - Json, - response::{IntoResponse, Response}, - http::StatusCode, -}; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use crate::state::AppState; -use sqlx::Row; -use bcrypt::{hash, DEFAULT_COST}; -use tracing::{info, error}; -use jacquard_repo::{mst::Mst, commit::Commit, storage::BlockStore}; -use jacquard::types::{string::Tid, did::Did, integer::LimitedU32}; -use std::sync::Arc; -use k256::SecretKey; -use rand::rngs::OsRng; -use base64::Engine; -use reqwest; - -#[derive(Deserialize)] -pub struct CreateAccountInput { - pub handle: String, - pub email: String, - pub password: String, - #[serde(rename = "inviteCode")] - pub invite_code: Option, - pub did: Option, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -pub struct CreateAccountOutput { - pub access_jwt: String, - pub refresh_jwt: String, - pub handle: String, - pub did: String, -} - -pub async fn create_account( - State(state): State, - Json(input): Json, -) -> Response { - info!("create_account hit: {}", input.handle); - if input.handle.contains('!') || input.handle.contains('@') { - return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}))).into_response(); - } - - let did = if let Some(d) = &input.did { - if d.trim().is_empty() { - format!("did:plc:{}", uuid::Uuid::new_v4()) - } else { - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); - if let Err(e) = verify_did_web(d, &hostname, &input.handle).await { - return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidDid", "message": e}))).into_response(); - } - d.clone() - } - } else { - format!("did:plc:{}", uuid::Uuid::new_v4()) - }; - - let mut tx = match state.db.begin().await { - Ok(tx) => tx, - Err(e) => { - error!("Error starting transaction: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }; - - let exists_query = sqlx::query("SELECT 1 FROM users WHERE handle = $1") - .bind(&input.handle) - .fetch_optional(&mut *tx) - .await; - - match exists_query { - Ok(Some(_)) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "HandleTaken", "message": "Handle already taken"}))).into_response(), - Err(e) => { - error!("Error checking handle: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - Ok(None) => {} - } - - if let Some(code) = &input.invite_code { - let invite_query = sqlx::query("SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE") - .bind(code) - .fetch_optional(&mut *tx) - .await; - - match invite_query { - Ok(Some(row)) => { - let uses: i32 = row.get("available_uses"); - if uses <= 0 { - return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response(); - } - - let update_invite = sqlx::query("UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1") - .bind(code) - .execute(&mut *tx) - .await; - - if let Err(e) = update_invite { - error!("Error updating invite code: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }, - Ok(None) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"}))).into_response(), - Err(e) => { - error!("Error checking invite code: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - } - } - - let password_hash = match hash(&input.password, DEFAULT_COST) { - Ok(h) => h, - Err(e) => { - error!("Error hashing password: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }; - - let user_insert = sqlx::query("INSERT INTO users (handle, email, did, password_hash) VALUES ($1, $2, $3, $4) RETURNING id") - .bind(&input.handle) - .bind(&input.email) - .bind(&did) - .bind(&password_hash) - .fetch_one(&mut *tx) - .await; - - let user_id: uuid::Uuid = match user_insert { - Ok(row) => row.get("id"), - Err(e) => { - error!("Error inserting user: {:?}", e); - // TODO: Check for unique constraint violation on email/did specifically - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }; - - let secret_key = SecretKey::random(&mut OsRng); - let secret_key_bytes = secret_key.to_bytes(); - - let key_insert = sqlx::query("INSERT INTO user_keys (user_id, key_bytes) VALUES ($1, $2)") - .bind(user_id) - .bind(&secret_key_bytes[..]) - .execute(&mut *tx) - .await; - - if let Err(e) = key_insert { - error!("Error inserting user key: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - - let mst = Mst::new(Arc::new(state.block_store.clone())); - let mst_root = match mst.root().await { - Ok(c) => c, - Err(e) => { - error!("Error creating MST root: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }; - - let did_obj = match Did::new(&did) { - Ok(d) => d, - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(), - }; - - let rev = Tid::now(LimitedU32::MIN); - - let commit = Commit::new_unsigned( - did_obj, - mst_root, - rev, - None - ); - - let commit_bytes = match commit.to_cbor() { - Ok(b) => b, - Err(e) => { - error!("Error serializing genesis commit: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }; - - let commit_cid = match state.block_store.put(&commit_bytes).await { - Ok(c) => c, - Err(e) => { - error!("Error saving genesis commit: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }; - - let repo_insert = sqlx::query("INSERT INTO repos (user_id, repo_root_cid) VALUES ($1, $2)") - .bind(user_id) - .bind(commit_cid.to_string()) - .execute(&mut *tx) - .await; - - if let Err(e) = repo_insert { - error!("Error initializing repo: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - - if let Some(code) = &input.invite_code { - let use_insert = sqlx::query("INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)") - .bind(code) - .bind(user_id) - .execute(&mut *tx) - .await; - - if let Err(e) = use_insert { - error!("Error recording invite usage: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - } - - let access_jwt = crate::auth::create_access_token(&did, &secret_key_bytes[..]).map_err(|e| { - error!("Error creating access token: {:?}", e); - (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response() - }); - let access_jwt = match access_jwt { - Ok(t) => t, - Err(r) => return r, - }; - - let refresh_jwt = crate::auth::create_refresh_token(&did, &secret_key_bytes[..]).map_err(|e| { - error!("Error creating refresh token: {:?}", e); - (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response() - }); - let refresh_jwt = match refresh_jwt { - Ok(t) => t, - Err(r) => return r, - }; - - let session_insert = sqlx::query("INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)") - .bind(&access_jwt) - .bind(&refresh_jwt) - .bind(&did) - .execute(&mut *tx) - .await; - - if let Err(e) = session_insert { - error!("Error inserting session: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - - if let Err(e) = tx.commit().await { - error!("Error committing transaction: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - - (StatusCode::OK, Json(CreateAccountOutput { - access_jwt, - refresh_jwt, - handle: input.handle, - did, - })).into_response() -} - -fn get_jwk(key_bytes: &[u8]) -> serde_json::Value { - use k256::elliptic_curve::sec1::ToEncodedPoint; - - let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length"); - let public_key = secret_key.public_key(); - let encoded = public_key.to_encoded_point(false); - let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap()); - let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap()); - - json!({ - "kty": "EC", - "crv": "secp256k1", - "x": x, - "y": y - }) -} - -pub async fn well_known_did(State(_state): State) -> impl IntoResponse { - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); - // Kinda for local dev, encode hostname if it contains port - let did = if hostname.contains(':') { - format!("did:web:{}", hostname.replace(':', "%3A")) - } else { - format!("did:web:{}", hostname) - }; - - Json(json!({ - "@context": ["https://www.w3.org/ns/did/v1"], - "id": did, - "service": [{ - "id": "#atproto_pds", - "type": "AtprotoPersonalDataServer", - "serviceEndpoint": format!("https://{}", hostname) - }] - })) -} - -pub async fn user_did_doc( - State(state): State, - Path(handle): Path, -) -> Response { - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); - - let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1") - .bind(&handle) - .fetch_optional(&state.db) - .await; - - let (user_id, did) = match user { - Ok(Some(row)) => { - let id: uuid::Uuid = row.get("id"); - let d: String = row.get("did"); - (id, d) - }, - Ok(None) => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(), - Err(e) => { - error!("DB Error: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response() - }, - }; - - if !did.starts_with("did:web:") { - return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "User is not did:web"}))).into_response(); - } - - let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1") - .bind(user_id) - .fetch_optional(&state.db) - .await; - - let key_bytes: Vec = match key_row { - Ok(Some(row)) => row.get("key_bytes"), - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(), - }; - - let jwk = get_jwk(&key_bytes); - - Json(json!({ - "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], - "id": did, - "alsoKnownAs": [format!("at://{}", handle)], - "verificationMethod": [{ - "id": format!("{}#atproto", did), - "type": "JsonWebKey2020", - "controller": did, - "publicKeyJwk": jwk - }], - "service": [{ - "id": "#atproto_pds", - "type": "AtprotoPersonalDataServer", - "serviceEndpoint": format!("https://{}", hostname) - }] - })).into_response() -} - -async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { - let expected_prefix = if hostname.contains(':') { - format!("did:web:{}", hostname.replace(':', "%3A")) - } else { - format!("did:web:{}", hostname) - }; - - if did.starts_with(&expected_prefix) { - let suffix = &did[expected_prefix.len()..]; - let expected_suffix = format!(":u:{}", handle); - if suffix == expected_suffix { - Ok(()) - } else { - Err(format!("Invalid DID path for this PDS. Expected {}", expected_suffix)) - } - } else { - let parts: Vec<&str> = did.split(':').collect(); - if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { - return Err("Invalid did:web format".into()); - } - - let domain_segment = parts[2]; - let domain = domain_segment.replace("%3A", ":"); - - let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { - "http" - } else { - "https" - }; - - let url = if parts.len() == 3 { - format!("{}://{}/.well-known/did.json", scheme, domain) - } else { - let path = parts[3..].join("/"); - format!("{}://{}/{}/did.json", scheme, domain, path) - }; - - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(5)) - .build() - .map_err(|e| format!("Failed to create client: {}", e))?; - - let resp = client.get(&url).send().await - .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; - - if !resp.status().is_success() { - return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); - } - - let doc: serde_json::Value = resp.json().await - .map_err(|e| format!("Failed to parse DID doc: {}", e))?; - - let services = doc["service"].as_array() - .ok_or("No services found in DID doc")?; - - let pds_endpoint = format!("https://{}", hostname); - - let has_valid_service = services.iter().any(|s| { - s["type"] == "AtprotoPersonalDataServer" && - s["serviceEndpoint"] == pds_endpoint - }); - - if has_valid_service { - Ok(()) - } else { - Err(format!("DID document does not list this PDS ({}) as AtprotoPersonalDataServer", pds_endpoint)) - } - } -} diff --git a/src/api/identity/account.rs b/src/api/identity/account.rs new file mode 100644 index 0000000..ba9de90 --- /dev/null +++ b/src/api/identity/account.rs @@ -0,0 +1,355 @@ +use super::did::verify_did_web; +use crate::state::AppState; +use axum::{ + Json, + extract::State, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use bcrypt::{DEFAULT_COST, hash}; +use jacquard::types::{did::Did, integer::LimitedU32, string::Tid}; +use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; +use k256::SecretKey; +use rand::rngs::OsRng; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use sqlx::Row; +use std::sync::Arc; +use tracing::{error, info}; + +#[derive(Deserialize)] +pub struct CreateAccountInput { + pub handle: String, + pub email: String, + pub password: String, + #[serde(rename = "inviteCode")] + pub invite_code: Option, + pub did: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CreateAccountOutput { + pub access_jwt: String, + pub refresh_jwt: String, + pub handle: String, + pub did: String, +} + +pub async fn create_account( + State(state): State, + Json(input): Json, +) -> Response { + info!("create_account hit: {}", input.handle); + if input.handle.contains('!') || input.handle.contains('@') { + return ( + StatusCode::BAD_REQUEST, + Json( + json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}), + ), + ) + .into_response(); + } + + let did = if let Some(d) = &input.did { + if d.trim().is_empty() { + format!("did:plc:{}", uuid::Uuid::new_v4()) + } else { + let hostname = + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + if let Err(e) = verify_did_web(d, &hostname, &input.handle).await { + return ( + StatusCode::BAD_REQUEST, + Json(json!({"error": "InvalidDid", "message": e})), + ) + .into_response(); + } + d.clone() + } + } else { + format!("did:plc:{}", uuid::Uuid::new_v4()) + }; + + let mut tx = match state.db.begin().await { + Ok(tx) => tx, + Err(e) => { + error!("Error starting transaction: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let exists_query = sqlx::query("SELECT 1 FROM users WHERE handle = $1") + .bind(&input.handle) + .fetch_optional(&mut *tx) + .await; + + match exists_query { + Ok(Some(_)) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({"error": "HandleTaken", "message": "Handle already taken"})), + ) + .into_response(); + } + Err(e) => { + error!("Error checking handle: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + Ok(None) => {} + } + + if let Some(code) = &input.invite_code { + let invite_query = + sqlx::query("SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE") + .bind(code) + .fetch_optional(&mut *tx) + .await; + + match invite_query { + Ok(Some(row)) => { + let uses: i32 = row.get("available_uses"); + if uses <= 0 { + return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response(); + } + + let update_invite = sqlx::query( + "UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1", + ) + .bind(code) + .execute(&mut *tx) + .await; + + if let Err(e) = update_invite { + error!("Error updating invite code: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + } + Ok(None) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"})), + ) + .into_response(); + } + Err(e) => { + error!("Error checking invite code: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + } + } + + let password_hash = match hash(&input.password, DEFAULT_COST) { + Ok(h) => h, + Err(e) => { + error!("Error hashing password: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let user_insert = sqlx::query("INSERT INTO users (handle, email, did, password_hash) VALUES ($1, $2, $3, $4) RETURNING id") + .bind(&input.handle) + .bind(&input.email) + .bind(&did) + .bind(&password_hash) + .fetch_one(&mut *tx) + .await; + + let user_id: uuid::Uuid = match user_insert { + Ok(row) => row.get("id"), + Err(e) => { + error!("Error inserting user: {:?}", e); + // TODO: Check for unique constraint violation on email/did specifically + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let secret_key = SecretKey::random(&mut OsRng); + let secret_key_bytes = secret_key.to_bytes(); + + let key_insert = sqlx::query("INSERT INTO user_keys (user_id, key_bytes) VALUES ($1, $2)") + .bind(user_id) + .bind(&secret_key_bytes[..]) + .execute(&mut *tx) + .await; + + if let Err(e) = key_insert { + error!("Error inserting user key: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + + let mst = Mst::new(Arc::new(state.block_store.clone())); + let mst_root = match mst.root().await { + Ok(c) => c, + Err(e) => { + error!("Error creating MST root: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let did_obj = match Did::new(&did) { + Ok(d) => d, + Err(_) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Invalid DID"})), + ) + .into_response(); + } + }; + + let rev = Tid::now(LimitedU32::MIN); + + let commit = Commit::new_unsigned(did_obj, mst_root, rev, None); + + let commit_bytes = match commit.to_cbor() { + Ok(b) => b, + Err(e) => { + error!("Error serializing genesis commit: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let commit_cid = match state.block_store.put(&commit_bytes).await { + Ok(c) => c, + Err(e) => { + error!("Error saving genesis commit: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let repo_insert = sqlx::query("INSERT INTO repos (user_id, repo_root_cid) VALUES ($1, $2)") + .bind(user_id) + .bind(commit_cid.to_string()) + .execute(&mut *tx) + .await; + + if let Err(e) = repo_insert { + error!("Error initializing repo: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + + if let Some(code) = &input.invite_code { + let use_insert = + sqlx::query("INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)") + .bind(code) + .bind(user_id) + .execute(&mut *tx) + .await; + + if let Err(e) = use_insert { + error!("Error recording invite usage: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + } + + let access_jwt = crate::auth::create_access_token(&did, &secret_key_bytes[..]).map_err(|e| { + error!("Error creating access token: {:?}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response() + }); + let access_jwt = match access_jwt { + Ok(t) => t, + Err(r) => return r, + }; + + let refresh_jwt = crate::auth::create_refresh_token(&did, &secret_key_bytes[..]).map_err(|e| { + error!("Error creating refresh token: {:?}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response() + }); + let refresh_jwt = match refresh_jwt { + Ok(t) => t, + Err(r) => return r, + }; + + let session_insert = + sqlx::query("INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)") + .bind(&access_jwt) + .bind(&refresh_jwt) + .bind(&did) + .execute(&mut *tx) + .await; + + if let Err(e) = session_insert { + error!("Error inserting session: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + + if let Err(e) = tx.commit().await { + error!("Error committing transaction: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + + ( + StatusCode::OK, + Json(CreateAccountOutput { + access_jwt, + refresh_jwt, + handle: input.handle, + did, + }), + ) + .into_response() +} diff --git a/src/api/identity/did.rs b/src/api/identity/did.rs new file mode 100644 index 0000000..cddd9cf --- /dev/null +++ b/src/api/identity/did.rs @@ -0,0 +1,201 @@ +use crate::state::AppState; +use axum::{ + Json, + extract::{Path, State}, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use base64::Engine; +use k256::SecretKey; +use k256::elliptic_curve::sec1::ToEncodedPoint; +use reqwest; +use serde_json::json; +use sqlx::Row; +use tracing::error; + +pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value { + let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length"); + let public_key = secret_key.public_key(); + let encoded = public_key.to_encoded_point(false); + let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap()); + let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap()); + + json!({ + "kty": "EC", + "crv": "secp256k1", + "x": x, + "y": y + }) +} + +pub async fn well_known_did(State(_state): State) -> impl IntoResponse { + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + // Kinda for local dev, encode hostname if it contains port + let did = if hostname.contains(':') { + format!("did:web:{}", hostname.replace(':', "%3A")) + } else { + format!("did:web:{}", hostname) + }; + + Json(json!({ + "@context": ["https://www.w3.org/ns/did/v1"], + "id": did, + "service": [{ + "id": "#atproto_pds", + "type": "AtprotoPersonalDataServer", + "serviceEndpoint": format!("https://{}", hostname) + }] + })) +} + +pub async fn user_did_doc(State(state): State, Path(handle): Path) -> Response { + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); + + let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1") + .bind(&handle) + .fetch_optional(&state.db) + .await; + + let (user_id, did) = match user { + Ok(Some(row)) => { + let id: uuid::Uuid = row.get("id"); + let d: String = row.get("did"); + (id, d) + } + Ok(None) => { + return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(); + } + Err(e) => { + error!("DB Error: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + if !did.starts_with("did:web:") { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "NotFound", "message": "User is not did:web"})), + ) + .into_response(); + } + + let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1") + .bind(user_id) + .fetch_optional(&state.db) + .await; + + let key_bytes: Vec = match key_row { + Ok(Some(row)) => row.get("key_bytes"), + _ => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let jwk = get_jwk(&key_bytes); + + Json(json!({ + "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], + "id": did, + "alsoKnownAs": [format!("at://{}", handle)], + "verificationMethod": [{ + "id": format!("{}#atproto", did), + "type": "JsonWebKey2020", + "controller": did, + "publicKeyJwk": jwk + }], + "service": [{ + "id": "#atproto_pds", + "type": "AtprotoPersonalDataServer", + "serviceEndpoint": format!("https://{}", hostname) + }] + })).into_response() +} + +pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { + let expected_prefix = if hostname.contains(':') { + format!("did:web:{}", hostname.replace(':', "%3A")) + } else { + format!("did:web:{}", hostname) + }; + + if did.starts_with(&expected_prefix) { + let suffix = &did[expected_prefix.len()..]; + let expected_suffix = format!(":u:{}", handle); + if suffix == expected_suffix { + Ok(()) + } else { + Err(format!( + "Invalid DID path for this PDS. Expected {}", + expected_suffix + )) + } + } else { + let parts: Vec<&str> = did.split(':').collect(); + if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { + return Err("Invalid did:web format".into()); + } + + let domain_segment = parts[2]; + let domain = domain_segment.replace("%3A", ":"); + + let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { + "http" + } else { + "https" + }; + + let url = if parts.len() == 3 { + format!("{}://{}/.well-known/did.json", scheme, domain) + } else { + let path = parts[3..].join("/"); + format!("{}://{}/{}/did.json", scheme, domain, path) + }; + + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(5)) + .build() + .map_err(|e| format!("Failed to create client: {}", e))?; + + let resp = client + .get(&url) + .send() + .await + .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; + + if !resp.status().is_success() { + return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); + } + + let doc: serde_json::Value = resp + .json() + .await + .map_err(|e| format!("Failed to parse DID doc: {}", e))?; + + let services = doc["service"] + .as_array() + .ok_or("No services found in DID doc")?; + + let pds_endpoint = format!("https://{}", hostname); + + let has_valid_service = services.iter().any(|s| { + s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint + }); + + if has_valid_service { + Ok(()) + } else { + Err(format!( + "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", + pds_endpoint + )) + } + } +} diff --git a/src/api/identity/mod.rs b/src/api/identity/mod.rs new file mode 100644 index 0000000..229e636 --- /dev/null +++ b/src/api/identity/mod.rs @@ -0,0 +1,5 @@ +pub mod account; +pub mod did; + +pub use account::create_account; +pub use did::{user_did_doc, well_known_did}; diff --git a/src/api/mod.rs b/src/api/mod.rs index 049d123..f380623 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,4 +1,4 @@ -pub mod server; -pub mod repo; -pub mod proxy; pub mod identity; +pub mod proxy; +pub mod repo; +pub mod server; diff --git a/src/api/proxy.rs b/src/api/proxy.rs index e26937c..f3be5b2 100644 --- a/src/api/proxy.rs +++ b/src/api/proxy.rs @@ -1,14 +1,14 @@ +use crate::state::AppState; use axum::{ + body::Bytes, extract::{Path, Query, State}, http::{HeaderMap, Method, StatusCode}, response::{IntoResponse, Response}, - body::Bytes, }; use reqwest::Client; -use tracing::{info, error}; -use std::collections::HashMap; -use crate::state::AppState; use sqlx::Row; +use std::collections::HashMap; +use tracing::{error, info}; pub async fn proxy_handler( State(state): State, @@ -18,8 +18,8 @@ pub async fn proxy_handler( Query(params): Query>, body: Bytes, ) -> Response { - - let proxy_header = headers.get("atproto-proxy") + let proxy_header = headers + .get("atproto-proxy") .and_then(|h| h.to_str().ok()) .map(|s| s.to_string()); @@ -27,7 +27,9 @@ pub async fn proxy_handler( Some(url) => url.clone(), None => match std::env::var("APPVIEW_URL") { Ok(url) => url, - Err(_) => return (StatusCode::BAD_GATEWAY, "No upstream AppView configured").into_response(), + Err(_) => { + return (StatusCode::BAD_GATEWAY, "No upstream AppView configured").into_response(); + } }, }; @@ -37,9 +39,7 @@ pub async fn proxy_handler( let client = Client::new(); - let mut request_builder = client - .request(method_verb, &target_url) - .query(¶ms); + let mut request_builder = client.request(method_verb, &target_url).query(¶ms); let mut auth_header_val = headers.get("Authorization").map(|h| h.clone()); @@ -48,17 +48,21 @@ pub async fn proxy_handler( if let Ok(token) = auth_val.to_str() { let token = token.replace("Bearer ", ""); if let Ok(did) = crate::auth::get_did_from_token(&token) { - let key_row = sqlx::query("SELECT k.key_bytes FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1") + let key_row = sqlx::query("SELECT k.key_bytes FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1") .bind(&did) .fetch_optional(&state.db) .await; if let Ok(Some(row)) = key_row { let key_bytes: Vec = row.get("key_bytes"); - if let Ok(new_token) = crate::auth::create_service_token(&did, aud, &method, &key_bytes) { - if let Ok(val) = axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) { - auth_header_val = Some(val); - } + if let Ok(new_token) = + crate::auth::create_service_token(&did, aud, &method, &key_bytes) + { + if let Ok(val) = + axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) + { + auth_header_val = Some(val); + } } } } @@ -86,7 +90,8 @@ pub async fn proxy_handler( Ok(b) => b, Err(e) => { error!("Error reading proxy response body: {:?}", e); - return (StatusCode::BAD_GATEWAY, "Error reading upstream response").into_response(); + return (StatusCode::BAD_GATEWAY, "Error reading upstream response") + .into_response(); } }; @@ -99,11 +104,11 @@ pub async fn proxy_handler( match response_builder.body(axum::body::Body::from(body)) { Ok(r) => r, Err(e) => { - error!("Error building proxy response: {:?}", e); - (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response() + error!("Error building proxy response: {:?}", e); + (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response() } } - }, + } Err(e) => { error!("Error sending proxy request: {:?}", e); if e.is_timeout() { diff --git a/src/api/repo.rs b/src/api/repo.rs deleted file mode 100644 index eccbfa0..0000000 --- a/src/api/repo.rs +++ /dev/null @@ -1,889 +0,0 @@ -use axum::{ - extract::{State, Query}, - Json, - response::{IntoResponse, Response}, - http::StatusCode, -}; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use crate::state::AppState; -use chrono::Utc; -use sqlx::Row; -use cid::Cid; -use std::str::FromStr; -use jacquard_repo::{mst::Mst, commit::Commit, storage::BlockStore}; -use jacquard::types::{string::{Nsid, Tid}, did::Did, integer::LimitedU32}; -use tracing::error; -use std::sync::Arc; -use sha2::{Sha256, Digest}; -use multihash::Multihash; -use axum::body::Bytes; - -#[derive(Deserialize)] -#[allow(dead_code)] -pub struct CreateRecordInput { - pub repo: String, - pub collection: String, - pub rkey: Option, - pub validate: Option, - pub record: serde_json::Value, - #[serde(rename = "swapCommit")] - pub swap_commit: Option, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -pub struct CreateRecordOutput { - pub uri: String, - pub cid: String, -} - -pub async fn create_record( - State(state): State, - headers: axum::http::HeaderMap, - Json(input): Json, -) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response(); - } - let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", ""); - - let session = sqlx::query( - "SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1" - ) - .bind(&token) - .fetch_optional(&state.db) - .await - .unwrap_or(None); - - let (did, key_bytes) = match session { - Some(row) => (row.get::("did"), row.get::, _>("key_bytes")), - None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response(), - }; - - if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response(); - } - - if input.repo != did { - return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response(); - } - - let user_query = sqlx::query("SELECT id FROM users WHERE did = $1") - .bind(&did) - .fetch_optional(&state.db) - .await; - - let user_id: uuid::Uuid = match user_query { - Ok(Some(row)) => row.get("id"), - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "User not found"}))).into_response(), - }; - - let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1") - .bind(user_id) - .fetch_optional(&state.db) - .await; - - let current_root_cid = match repo_root_query { - Ok(Some(row)) => { - let cid_str: String = row.get("repo_root_cid"); - Cid::from_str(&cid_str).ok() - }, - _ => None, - }; - - if current_root_cid.is_none() { - error!("Repo root not found for user {}", did); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Repo root not found"}))).into_response(); - } - let current_root_cid = current_root_cid.unwrap(); - - let commit_bytes = match state.block_store.get(¤t_root_cid).await { - Ok(Some(b)) => b, - Ok(None) => { - error!("Commit block not found: {}", current_root_cid); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - }, - Err(e) => { - error!("Failed to load commit block: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }; - - let commit = match Commit::from_cbor(&commit_bytes) { - Ok(c) => c, - Err(e) => { - error!("Failed to parse commit: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }; - - let mst_root = commit.data; - let store = Arc::new(state.block_store.clone()); - let mst = Mst::load(store.clone(), mst_root, None); - - let collection_nsid = match input.collection.parse::() { - Ok(n) => n, - Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(), - }; - - let rkey = input.rkey.unwrap_or_else(|| { - Utc::now().format("%Y%m%d%H%M%S%f").to_string() - }); - - let mut record_bytes = Vec::new(); - if let Err(e) = serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record) { - error!("Error serializing record: {:?}", e); - return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); - } - - let record_cid = match state.block_store.put(&record_bytes).await { - Ok(c) => c, - Err(e) => { - error!("Failed to save record block: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }; - - let key = format!("{}/{}", collection_nsid, rkey); - if let Err(e) = mst.update(&key, record_cid).await { - error!("Failed to update MST: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - - let new_mst_root = match mst.root().await { - Ok(c) => c, - Err(e) => { - error!("Failed to get new MST root: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }; - - let did_obj = match Did::new(&did) { - Ok(d) => d, - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(), - }; - - let rev = Tid::now(LimitedU32::MIN); - - let new_commit = Commit::new_unsigned( - did_obj, - new_mst_root, - rev, - Some(current_root_cid) - ); - - let new_commit_bytes = match new_commit.to_cbor() { - Ok(b) => b, - Err(e) => { - error!("Failed to serialize new commit: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }; - - let new_root_cid = match state.block_store.put(&new_commit_bytes).await { - Ok(c) => c, - Err(e) => { - error!("Failed to save new commit: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }; - - let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2") - .bind(new_root_cid.to_string()) - .bind(user_id) - .execute(&state.db) - .await; - - if let Err(e) = update_repo { - error!("Failed to update repo root in DB: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - - let record_insert = sqlx::query( - "INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4) - ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()" - ) - .bind(user_id) - .bind(&input.collection) - .bind(&rkey) - .bind(record_cid.to_string()) - .execute(&state.db) - .await; - - if let Err(e) = record_insert { - error!("Error inserting record index: {:?}", e); - } - - let output = CreateRecordOutput { - uri: format!("at://{}/{}/{}", input.repo, input.collection, rkey), - cid: record_cid.to_string(), - }; - (StatusCode::OK, Json(output)).into_response() -} - -#[derive(Deserialize)] -#[allow(dead_code)] -pub struct PutRecordInput { - pub repo: String, - pub collection: String, - pub rkey: String, - pub validate: Option, - pub record: serde_json::Value, - #[serde(rename = "swapCommit")] - pub swap_commit: Option, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -pub struct PutRecordOutput { - pub uri: String, - pub cid: String, -} - -pub async fn put_record( - State(state): State, - headers: axum::http::HeaderMap, - Json(input): Json, -) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response(); - } - let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", ""); - - let session = sqlx::query( - "SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1" - ) - .bind(&token) - .fetch_optional(&state.db) - .await - .unwrap_or(None); - - let (did, key_bytes) = match session { - Some(row) => (row.get::("did"), row.get::, _>("key_bytes")), - None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response(), - }; - - if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response(); - } - - if input.repo != did { - return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response(); - } - - let user_query = sqlx::query("SELECT id FROM users WHERE did = $1") - .bind(&did) - .fetch_optional(&state.db) - .await; - - let user_id: uuid::Uuid = match user_query { - Ok(Some(row)) => row.get("id"), - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "User not found"}))).into_response(), - }; - - let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1") - .bind(user_id) - .fetch_optional(&state.db) - .await; - - let current_root_cid = match repo_root_query { - Ok(Some(row)) => { - let cid_str: String = row.get("repo_root_cid"); - Cid::from_str(&cid_str).ok() - }, - _ => None, - }; - - if current_root_cid.is_none() { - error!("Repo root not found for user {}", did); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Repo root not found"}))).into_response(); - } - let current_root_cid = current_root_cid.unwrap(); - - let commit_bytes = match state.block_store.get(¤t_root_cid).await { - Ok(Some(b)) => b, - Ok(None) => { - error!("Commit block not found: {}", current_root_cid); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response(); - }, - Err(e) => { - error!("Failed to load commit block: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to load commit block"}))).into_response(); - } - }; - - let commit = match Commit::from_cbor(&commit_bytes) { - Ok(c) => c, - Err(e) => { - error!("Failed to parse commit: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to parse commit"}))).into_response(); - } - }; - - let mst_root = commit.data; - let store = Arc::new(state.block_store.clone()); - let mst = Mst::load(store.clone(), mst_root, None); - - let collection_nsid = match input.collection.parse::() { - Ok(n) => n, - Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(), - }; - - let rkey = input.rkey.clone(); - - let mut record_bytes = Vec::new(); - if let Err(e) = serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record) { - error!("Error serializing record: {:?}", e); - return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); - } - - let record_cid = match state.block_store.put(&record_bytes).await { - Ok(c) => c, - Err(e) => { - error!("Failed to save record block: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to save record block"}))).into_response(); - } - }; - - let key = format!("{}/{}", collection_nsid, rkey); - if let Err(e) = mst.update(&key, record_cid).await { - error!("Failed to update MST: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to update MST: {:?}", e)}))).into_response(); - } - - let new_mst_root = match mst.root().await { - Ok(c) => c, - Err(e) => { - error!("Failed to get new MST root: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST root"}))).into_response(); - } - }; - - let did_obj = match Did::new(&did) { - Ok(d) => d, - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(), - }; - - let rev = Tid::now(LimitedU32::MIN); - - let new_commit = Commit::new_unsigned( - did_obj, - new_mst_root, - rev, - Some(current_root_cid) - ); - - let new_commit_bytes = match new_commit.to_cbor() { - Ok(b) => b, - Err(e) => { - error!("Failed to serialize new commit: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to serialize new commit"}))).into_response(); - } - }; - - let new_root_cid = match state.block_store.put(&new_commit_bytes).await { - Ok(c) => c, - Err(e) => { - error!("Failed to save new commit: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to save new commit"}))).into_response(); - } - }; - - let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2") - .bind(new_root_cid.to_string()) - .bind(user_id) - .execute(&state.db) - .await; - - if let Err(e) = update_repo { - error!("Failed to update repo root in DB: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update repo root in DB"}))).into_response(); - } - - let record_insert = sqlx::query( - "INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4) - ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()" - ) - .bind(user_id) - .bind(&input.collection) - .bind(&rkey) - .bind(record_cid.to_string()) - .execute(&state.db) - .await; - - if let Err(e) = record_insert { - error!("Error inserting record index: {:?}", e); - } - - let output = PutRecordOutput { - uri: format!("at://{}/{}/{}", input.repo, input.collection, rkey), - cid: record_cid.to_string(), - }; - (StatusCode::OK, Json(output)).into_response() -} - -#[derive(Deserialize)] -pub struct GetRecordInput { - pub repo: String, - pub collection: String, - pub rkey: String, - pub cid: Option, -} - -pub async fn get_record( - State(state): State, - Query(input): Query, -) -> Response { - let user_row = if input.repo.starts_with("did:") { - sqlx::query("SELECT id FROM users WHERE did = $1") - .bind(&input.repo) - .fetch_optional(&state.db) - .await - } else { - sqlx::query("SELECT id FROM users WHERE handle = $1") - .bind(&input.repo) - .fetch_optional(&state.db) - .await - }; - - let user_id: uuid::Uuid = match user_row { - Ok(Some(row)) => row.get("id"), - _ => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Repo not found"}))).into_response(), - }; - - let record_row = sqlx::query("SELECT record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3") - .bind(user_id) - .bind(&input.collection) - .bind(&input.rkey) - .fetch_optional(&state.db) - .await; - - let record_cid_str: String = match record_row { - Ok(Some(row)) => row.get("record_cid"), - _ => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Record not found"}))).into_response(), - }; - - if let Some(expected_cid) = &input.cid { - if &record_cid_str != expected_cid { - return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Record CID mismatch"}))).into_response(); - } - } - - let cid = match Cid::from_str(&record_cid_str) { - Ok(c) => c, - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid CID in DB"}))).into_response(), - }; - - let block = match state.block_store.get(&cid).await { - Ok(Some(b)) => b, - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Record block not found"}))).into_response(), - }; - - let value: serde_json::Value = match serde_ipld_dagcbor::from_slice(&block) { - Ok(v) => v, - Err(e) => { - error!("Failed to deserialize record: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }; - - Json(json!({ - "uri": format!("at://{}/{}/{}", input.repo, input.collection, input.rkey), - "cid": record_cid_str, - "value": value - })).into_response() -} - -#[derive(Deserialize)] -pub struct DeleteRecordInput { - pub repo: String, - pub collection: String, - pub rkey: String, - #[serde(rename = "swapRecord")] - pub swap_record: Option, - #[serde(rename = "swapCommit")] - pub swap_commit: Option, -} - -pub async fn delete_record( - State(state): State, - headers: axum::http::HeaderMap, - Json(input): Json, -) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response(); - } - let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", ""); - - let session = sqlx::query( - "SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1" - ) - .bind(&token) - .fetch_optional(&state.db) - .await - .unwrap_or(None); - - let (did, key_bytes) = match session { - Some(row) => (row.get::("did"), row.get::, _>("key_bytes")), - None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response(), - }; - - if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response(); - } - - if input.repo != did { - return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response(); - } - - let user_query = sqlx::query("SELECT id FROM users WHERE did = $1") - .bind(&did) - .fetch_optional(&state.db) - .await; - - let user_id: uuid::Uuid = match user_query { - Ok(Some(row)) => row.get("id"), - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "User not found"}))).into_response(), - }; - - let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1") - .bind(user_id) - .fetch_optional(&state.db) - .await; - - let current_root_cid = match repo_root_query { - Ok(Some(row)) => { - let cid_str: String = row.get("repo_root_cid"); - Cid::from_str(&cid_str).ok() - }, - _ => None, - }; - - if current_root_cid.is_none() { - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Repo root not found"}))).into_response(); - } - let current_root_cid = current_root_cid.unwrap(); - - let commit_bytes = match state.block_store.get(¤t_root_cid).await { - Ok(Some(b)) => b, - Ok(None) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response(), - Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to load commit block: {:?}", e)}))).into_response(), - }; - - let commit = match Commit::from_cbor(&commit_bytes) { - Ok(c) => c, - Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to parse commit: {:?}", e)}))).into_response(), - }; - - let mst_root = commit.data; - let store = Arc::new(state.block_store.clone()); - let mst = Mst::load(store.clone(), mst_root, None); - - let collection_nsid = match input.collection.parse::() { - Ok(n) => n, - Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(), - }; - - let key = format!("{}/{}", collection_nsid, input.rkey); - - // TODO: Check swapRecord if provided? Skipping for brevity/robustness - - if let Err(e) = mst.delete(&key).await { - error!("Failed to delete from MST: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to delete from MST: {:?}", e)}))).into_response(); - } - - let new_mst_root = match mst.root().await { - Ok(c) => c, - Err(_e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST root"}))).into_response(), - }; - - let did_obj = match Did::new(&did) { - Ok(d) => d, - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(), - }; - - let rev = Tid::now(LimitedU32::MIN); - - let new_commit = Commit::new_unsigned( - did_obj, - new_mst_root, - rev, - Some(current_root_cid) - ); - - let new_commit_bytes = match new_commit.to_cbor() { - Ok(b) => b, - Err(_e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to serialize new commit"}))).into_response(), - }; - - let new_root_cid = match state.block_store.put(&new_commit_bytes).await { - Ok(c) => c, - Err(_e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to save new commit"}))).into_response(), - }; - - let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2") - .bind(new_root_cid.to_string()) - .bind(user_id) - .execute(&state.db) - .await; - - if let Err(e) = update_repo { - error!("Failed to update repo root in DB: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update repo root in DB"}))).into_response(); - } - - let record_delete = sqlx::query("DELETE FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3") - .bind(user_id) - .bind(&input.collection) - .bind(&input.rkey) - .execute(&state.db) - .await; - - if let Err(e) = record_delete { - error!("Error deleting record index: {:?}", e); - } - - (StatusCode::OK, Json(json!({}))).into_response() -} - -#[derive(Deserialize)] -pub struct ListRecordsInput { - pub repo: String, - pub collection: String, - pub limit: Option, - pub cursor: Option, - #[serde(rename = "rkeyStart")] - pub rkey_start: Option, - #[serde(rename = "rkeyEnd")] - pub rkey_end: Option, - pub reverse: Option, -} - -#[derive(Serialize)] -pub struct ListRecordsOutput { - pub cursor: Option, - pub records: Vec, -} - -pub async fn list_records( - State(state): State, - Query(input): Query, -) -> Response { - let user_row = if input.repo.starts_with("did:") { - sqlx::query("SELECT id FROM users WHERE did = $1") - .bind(&input.repo) - .fetch_optional(&state.db) - .await - } else { - sqlx::query("SELECT id FROM users WHERE handle = $1") - .bind(&input.repo) - .fetch_optional(&state.db) - .await - }; - - let user_id: uuid::Uuid = match user_row { - Ok(Some(row)) => row.get("id"), - _ => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Repo not found"}))).into_response(), - }; - - let limit = input.limit.unwrap_or(50).clamp(1, 100); - let reverse = input.reverse.unwrap_or(false); - - // Simplistic query construction - no sophisticated cursor handling or rkey ranges for now, just basic pagination - // TODO: Implement rkeyStart/End and correct cursor logic - - let query_str = format!( - "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 {} ORDER BY rkey {} LIMIT {}", - if let Some(_c) = &input.cursor { - if reverse { "AND rkey < $3" } else { "AND rkey > $3" } - } else { - "" - }, - if reverse { "DESC" } else { "ASC" }, - limit - ); - - let mut query = sqlx::query(&query_str) - .bind(user_id) - .bind(&input.collection); - - if let Some(c) = &input.cursor { - query = query.bind(c); - } - - let rows = match query.fetch_all(&state.db).await { - Ok(r) => r, - Err(e) => { - error!("Error listing records: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - }; - - let mut records = Vec::new(); - let mut last_rkey = None; - - for row in rows { - let rkey: String = row.get("rkey"); - let cid_str: String = row.get("record_cid"); - last_rkey = Some(rkey.clone()); - - if let Ok(cid) = Cid::from_str(&cid_str) { - if let Ok(Some(block)) = state.block_store.get(&cid).await { - if let Ok(value) = serde_ipld_dagcbor::from_slice::(&block) { - records.push(json!({ - "uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey), - "cid": cid_str, - "value": value - })); - } - } - } - } - - Json(ListRecordsOutput { - cursor: last_rkey, - records, - }).into_response() -} - -#[derive(Deserialize)] -pub struct DescribeRepoInput { - pub repo: String, -} - -pub async fn describe_repo( - State(state): State, - Query(input): Query, -) -> Response { - let user_row = if input.repo.starts_with("did:") { - sqlx::query("SELECT id, handle, did FROM users WHERE did = $1") - .bind(&input.repo) - .fetch_optional(&state.db) - .await - } else { - sqlx::query("SELECT id, handle, did FROM users WHERE handle = $1") - .bind(&input.repo) - .fetch_optional(&state.db) - .await - }; - - let (user_id, handle, did) = match user_row { - Ok(Some(row)) => (row.get::("id"), row.get::("handle"), row.get::("did")), - _ => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Repo not found"}))).into_response(), - }; - - let collections_query = sqlx::query("SELECT DISTINCT collection FROM records WHERE repo_id = $1") - .bind(user_id) - .fetch_all(&state.db) - .await; - - let collections: Vec = match collections_query { - Ok(rows) => rows.iter().map(|r| r.get("collection")).collect(), - Err(_) => Vec::new(), - }; - - let did_doc = json!({ - "id": did, - "alsoKnownAs": [format!("at://{}", handle)] - }); - - Json(json!({ - "handle": handle, - "did": did, - "didDoc": did_doc, - "collections": collections, - "handleIsCorrect": true - })).into_response() -} - -pub async fn upload_blob( - State(state): State, - headers: axum::http::HeaderMap, - body: Bytes, -) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response(); - } - let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", ""); - - let session = sqlx::query( - "SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1" - ) - .bind(&token) - .fetch_optional(&state.db) - .await - .unwrap_or(None); - - let (did, key_bytes) = match session { - Some(row) => (row.get::("did"), row.get::, _>("key_bytes")), - None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response(), - }; - - if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response(); - } - - let mime_type = headers.get("content-type") - .and_then(|h| h.to_str().ok()) - .unwrap_or("application/octet-stream") - .to_string(); - - let size = body.len() as i64; - let data = body.to_vec(); - - let mut hasher = Sha256::new(); - hasher.update(&data); - let hash = hasher.finalize(); - let multihash = Multihash::wrap(0x12, &hash).unwrap(); - let cid = Cid::new_v1(0x55, multihash); - let cid_str = cid.to_string(); - - let storage_key = format!("blobs/{}", cid_str); - - if let Err(e) = state.blob_store.put(&storage_key, &data).await { - error!("Failed to upload blob to storage: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store blob"}))).into_response(); - } - - let user_query = sqlx::query("SELECT id FROM users WHERE did = $1") - .bind(&did) - .fetch_optional(&state.db) - .await; - - let user_id: uuid::Uuid = match user_query { - Ok(Some(row)) => row.get("id"), - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(), - }; - - let insert = sqlx::query( - "INSERT INTO blobs (cid, mime_type, size_bytes, created_by_user, storage_key) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (cid) DO NOTHING" - ) - .bind(&cid_str) - .bind(&mime_type) - .bind(size) - .bind(user_id) - .bind(&storage_key) - .execute(&state.db) - .await; - - if let Err(e) = insert { - error!("Failed to insert blob record: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - } - - Json(json!({ - "blob": { - "ref": { - "$link": cid_str - }, - "mimeType": mime_type, - "size": size - } - })).into_response() -} diff --git a/src/api/repo/blob.rs b/src/api/repo/blob.rs new file mode 100644 index 0000000..f42604b --- /dev/null +++ b/src/api/repo/blob.rs @@ -0,0 +1,138 @@ +use crate::state::AppState; +use axum::body::Bytes; +use axum::{ + Json, + extract::State, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use cid::Cid; +use multihash::Multihash; +use serde_json::json; +use sha2::{Digest, Sha256}; +use sqlx::Row; +use tracing::error; + +pub async fn upload_blob( + State(state): State, + headers: axum::http::HeaderMap, + body: Bytes, +) -> Response { + let auth_header = headers.get("Authorization"); + if auth_header.is_none() { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationRequired"})), + ) + .into_response(); + } + let token = auth_header + .unwrap() + .to_str() + .unwrap_or("") + .replace("Bearer ", ""); + + let session = sqlx::query( + "SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1" + ) + .bind(&token) + .fetch_optional(&state.db) + .await + .unwrap_or(None); + + let (did, key_bytes) = match session { + Some(row) => ( + row.get::("did"), + row.get::, _>("key_bytes"), + ), + None => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed"})), + ) + .into_response(); + } + }; + + if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})), + ) + .into_response(); + } + + let mime_type = headers + .get("content-type") + .and_then(|h| h.to_str().ok()) + .unwrap_or("application/octet-stream") + .to_string(); + + let size = body.len() as i64; + let data = body.to_vec(); + + let mut hasher = Sha256::new(); + hasher.update(&data); + let hash = hasher.finalize(); + let multihash = Multihash::wrap(0x12, &hash).unwrap(); + let cid = Cid::new_v1(0x55, multihash); + let cid_str = cid.to_string(); + + let storage_key = format!("blobs/{}", cid_str); + + if let Err(e) = state.blob_store.put(&storage_key, &data).await { + error!("Failed to upload blob to storage: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Failed to store blob"})), + ) + .into_response(); + } + + let user_query = sqlx::query("SELECT id FROM users WHERE did = $1") + .bind(&did) + .fetch_optional(&state.db) + .await; + + let user_id: uuid::Uuid = match user_query { + Ok(Some(row)) => row.get("id"), + _ => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let insert = sqlx::query( + "INSERT INTO blobs (cid, mime_type, size_bytes, created_by_user, storage_key) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (cid) DO NOTHING" + ) + .bind(&cid_str) + .bind(&mime_type) + .bind(size) + .bind(user_id) + .bind(&storage_key) + .execute(&state.db) + .await; + + if let Err(e) = insert { + error!("Failed to insert blob record: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + + Json(json!({ + "blob": { + "ref": { + "$link": cid_str + }, + "mimeType": mime_type, + "size": size + } + })) + .into_response() +} diff --git a/src/api/repo/meta.rs b/src/api/repo/meta.rs new file mode 100644 index 0000000..32b041a --- /dev/null +++ b/src/api/repo/meta.rs @@ -0,0 +1,72 @@ +use crate::state::AppState; +use axum::{ + Json, + extract::{Query, State}, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use serde::Deserialize; +use serde_json::json; +use sqlx::Row; + +#[derive(Deserialize)] +pub struct DescribeRepoInput { + pub repo: String, +} + +pub async fn describe_repo( + State(state): State, + Query(input): Query, +) -> Response { + let user_row = if input.repo.starts_with("did:") { + sqlx::query("SELECT id, handle, did FROM users WHERE did = $1") + .bind(&input.repo) + .fetch_optional(&state.db) + .await + } else { + sqlx::query("SELECT id, handle, did FROM users WHERE handle = $1") + .bind(&input.repo) + .fetch_optional(&state.db) + .await + }; + + let (user_id, handle, did) = match user_row { + Ok(Some(row)) => ( + row.get::("id"), + row.get::("handle"), + row.get::("did"), + ), + _ => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "NotFound", "message": "Repo not found"})), + ) + .into_response(); + } + }; + + let collections_query = + sqlx::query("SELECT DISTINCT collection FROM records WHERE repo_id = $1") + .bind(user_id) + .fetch_all(&state.db) + .await; + + let collections: Vec = match collections_query { + Ok(rows) => rows.iter().map(|r| r.get("collection")).collect(), + Err(_) => Vec::new(), + }; + + let did_doc = json!({ + "id": did, + "alsoKnownAs": [format!("at://{}", handle)] + }); + + Json(json!({ + "handle": handle, + "did": did, + "didDoc": did_doc, + "collections": collections, + "handleIsCorrect": true + })) + .into_response() +} diff --git a/src/api/repo/mod.rs b/src/api/repo/mod.rs new file mode 100644 index 0000000..a61f9bf --- /dev/null +++ b/src/api/repo/mod.rs @@ -0,0 +1,7 @@ +pub mod blob; +pub mod meta; +pub mod record; + +pub use blob::upload_blob; +pub use meta::describe_repo; +pub use record::{create_record, delete_record, get_record, list_records, put_record}; diff --git a/src/api/repo/record/delete.rs b/src/api/repo/record/delete.rs new file mode 100644 index 0000000..ffe555e --- /dev/null +++ b/src/api/repo/record/delete.rs @@ -0,0 +1,236 @@ +use crate::state::AppState; +use axum::{ + Json, + extract::State, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use cid::Cid; +use jacquard::types::{ + did::Did, + integer::LimitedU32, + string::{Nsid, Tid}, +}; +use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; +use serde::Deserialize; +use serde_json::json; +use sqlx::Row; +use std::str::FromStr; +use std::sync::Arc; +use tracing::error; + +#[derive(Deserialize)] +pub struct DeleteRecordInput { + pub repo: String, + pub collection: String, + pub rkey: String, + #[serde(rename = "swapRecord")] + pub swap_record: Option, + #[serde(rename = "swapCommit")] + pub swap_commit: Option, +} + +pub async fn delete_record( + State(state): State, + headers: axum::http::HeaderMap, + Json(input): Json, +) -> Response { + let auth_header = headers.get("Authorization"); + if auth_header.is_none() { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationRequired"})), + ) + .into_response(); + } + let token = auth_header + .unwrap() + .to_str() + .unwrap_or("") + .replace("Bearer ", ""); + + let session = sqlx::query( + "SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1" + ) + .bind(&token) + .fetch_optional(&state.db) + .await + .unwrap_or(None); + + let (did, key_bytes) = match session { + Some(row) => ( + row.get::("did"), + row.get::, _>("key_bytes"), + ), + None => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed"})), + ) + .into_response(); + } + }; + + if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})), + ) + .into_response(); + } + + if input.repo != did { + return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response(); + } + + let user_query = sqlx::query("SELECT id FROM users WHERE did = $1") + .bind(&did) + .fetch_optional(&state.db) + .await; + + let user_id: uuid::Uuid = match user_query { + Ok(Some(row)) => row.get("id"), + _ => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "User not found"})), + ) + .into_response(); + } + }; + + let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1") + .bind(user_id) + .fetch_optional(&state.db) + .await; + + let current_root_cid = match repo_root_query { + Ok(Some(row)) => { + let cid_str: String = row.get("repo_root_cid"); + Cid::from_str(&cid_str).ok() + } + _ => None, + }; + + if current_root_cid.is_none() { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Repo root not found"})), + ) + .into_response(); + } + let current_root_cid = current_root_cid.unwrap(); + + let commit_bytes = match state.block_store.get(¤t_root_cid).await { + Ok(Some(b)) => b, + Ok(None) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response(), + Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to load commit block: {:?}", e)}))).into_response(), + }; + + let commit = match Commit::from_cbor(&commit_bytes) { + Ok(c) => c, + Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to parse commit: {:?}", e)}))).into_response(), + }; + + let mst_root = commit.data; + let store = Arc::new(state.block_store.clone()); + let mst = Mst::load(store.clone(), mst_root, None); + + let collection_nsid = match input.collection.parse::() { + Ok(n) => n, + Err(_) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({"error": "InvalidCollection"})), + ) + .into_response(); + } + }; + + let key = format!("{}/{}", collection_nsid, input.rkey); + + // TODO: Check swapRecord if provided? Skipping for brevity/robustness + + if let Err(e) = mst.delete(&key).await { + error!("Failed to delete from MST: {:?}", e); + return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to delete from MST: {:?}", e)}))).into_response(); + } + + let new_mst_root = match mst.root().await { + Ok(c) => c, + Err(_e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Failed to get new MST root"})), + ) + .into_response(); + } + }; + + let did_obj = match Did::new(&did) { + Ok(d) => d, + Err(_) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Invalid DID"})), + ) + .into_response(); + } + }; + + let rev = Tid::now(LimitedU32::MIN); + + let new_commit = Commit::new_unsigned(did_obj, new_mst_root, rev, Some(current_root_cid)); + + let new_commit_bytes = + match new_commit.to_cbor() { + Ok(b) => b, + Err(_e) => return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json( + json!({"error": "InternalError", "message": "Failed to serialize new commit"}), + ), + ) + .into_response(), + }; + + let new_root_cid = match state.block_store.put(&new_commit_bytes).await { + Ok(c) => c, + Err(_e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Failed to save new commit"})), + ) + .into_response(); + } + }; + + let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2") + .bind(new_root_cid.to_string()) + .bind(user_id) + .execute(&state.db) + .await; + + if let Err(e) = update_repo { + error!("Failed to update repo root in DB: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Failed to update repo root in DB"})), + ) + .into_response(); + } + + let record_delete = + sqlx::query("DELETE FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3") + .bind(user_id) + .bind(&input.collection) + .bind(&input.rkey) + .execute(&state.db) + .await; + + if let Err(e) = record_delete { + error!("Error deleting record index: {:?}", e); + } + + (StatusCode::OK, Json(json!({}))).into_response() +} diff --git a/src/api/repo/record/mod.rs b/src/api/repo/record/mod.rs new file mode 100644 index 0000000..df71716 --- /dev/null +++ b/src/api/repo/record/mod.rs @@ -0,0 +1,10 @@ +pub mod delete; +pub mod read; +pub mod write; + +pub use delete::{DeleteRecordInput, delete_record}; +pub use read::{GetRecordInput, ListRecordsInput, ListRecordsOutput, get_record, list_records}; +pub use write::{ + CreateRecordInput, CreateRecordOutput, PutRecordInput, PutRecordOutput, create_record, + put_record, +}; diff --git a/src/api/repo/record/read.rs b/src/api/repo/record/read.rs new file mode 100644 index 0000000..7f3acb5 --- /dev/null +++ b/src/api/repo/record/read.rs @@ -0,0 +1,236 @@ +use crate::state::AppState; +use axum::{ + Json, + extract::{Query, State}, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use cid::Cid; +use jacquard_repo::storage::BlockStore; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use sqlx::Row; +use std::str::FromStr; +use tracing::error; + +#[derive(Deserialize)] +pub struct GetRecordInput { + pub repo: String, + pub collection: String, + pub rkey: String, + pub cid: Option, +} + +pub async fn get_record( + State(state): State, + Query(input): Query, +) -> Response { + let user_row = if input.repo.starts_with("did:") { + sqlx::query("SELECT id FROM users WHERE did = $1") + .bind(&input.repo) + .fetch_optional(&state.db) + .await + } else { + sqlx::query("SELECT id FROM users WHERE handle = $1") + .bind(&input.repo) + .fetch_optional(&state.db) + .await + }; + + let user_id: uuid::Uuid = match user_row { + Ok(Some(row)) => row.get("id"), + _ => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "NotFound", "message": "Repo not found"})), + ) + .into_response(); + } + }; + + let record_row = sqlx::query( + "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3", + ) + .bind(user_id) + .bind(&input.collection) + .bind(&input.rkey) + .fetch_optional(&state.db) + .await; + + let record_cid_str: String = match record_row { + Ok(Some(row)) => row.get("record_cid"), + _ => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "NotFound", "message": "Record not found"})), + ) + .into_response(); + } + }; + + if let Some(expected_cid) = &input.cid { + if &record_cid_str != expected_cid { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "NotFound", "message": "Record CID mismatch"})), + ) + .into_response(); + } + } + + let cid = match Cid::from_str(&record_cid_str) { + Ok(c) => c, + Err(_) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Invalid CID in DB"})), + ) + .into_response(); + } + }; + + let block = match state.block_store.get(&cid).await { + Ok(Some(b)) => b, + _ => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Record block not found"})), + ) + .into_response(); + } + }; + + let value: serde_json::Value = match serde_ipld_dagcbor::from_slice(&block) { + Ok(v) => v, + Err(e) => { + error!("Failed to deserialize record: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + Json(json!({ + "uri": format!("at://{}/{}/{}", input.repo, input.collection, input.rkey), + "cid": record_cid_str, + "value": value + })) + .into_response() +} + +#[derive(Deserialize)] +pub struct ListRecordsInput { + pub repo: String, + pub collection: String, + pub limit: Option, + pub cursor: Option, + #[serde(rename = "rkeyStart")] + pub rkey_start: Option, + #[serde(rename = "rkeyEnd")] + pub rkey_end: Option, + pub reverse: Option, +} + +#[derive(Serialize)] +pub struct ListRecordsOutput { + pub cursor: Option, + pub records: Vec, +} + +pub async fn list_records( + State(state): State, + Query(input): Query, +) -> Response { + let user_row = if input.repo.starts_with("did:") { + sqlx::query("SELECT id FROM users WHERE did = $1") + .bind(&input.repo) + .fetch_optional(&state.db) + .await + } else { + sqlx::query("SELECT id FROM users WHERE handle = $1") + .bind(&input.repo) + .fetch_optional(&state.db) + .await + }; + + let user_id: uuid::Uuid = match user_row { + Ok(Some(row)) => row.get("id"), + _ => { + return ( + StatusCode::NOT_FOUND, + Json(json!({"error": "NotFound", "message": "Repo not found"})), + ) + .into_response(); + } + }; + + let limit = input.limit.unwrap_or(50).clamp(1, 100); + let reverse = input.reverse.unwrap_or(false); + + // Simplistic query construction - no sophisticated cursor handling or rkey ranges for now, just basic pagination + // TODO: Implement rkeyStart/End and correct cursor logic + + let query_str = format!( + "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 {} ORDER BY rkey {} LIMIT {}", + if let Some(_c) = &input.cursor { + if reverse { + "AND rkey < $3" + } else { + "AND rkey > $3" + } + } else { + "" + }, + if reverse { "DESC" } else { "ASC" }, + limit + ); + + let mut query = sqlx::query(&query_str) + .bind(user_id) + .bind(&input.collection); + + if let Some(c) = &input.cursor { + query = query.bind(c); + } + + let rows = match query.fetch_all(&state.db).await { + Ok(r) => r, + Err(e) => { + error!("Error listing records: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let mut records = Vec::new(); + let mut last_rkey = None; + + for row in rows { + let rkey: String = row.get("rkey"); + let cid_str: String = row.get("record_cid"); + last_rkey = Some(rkey.clone()); + + if let Ok(cid) = Cid::from_str(&cid_str) { + if let Ok(Some(block)) = state.block_store.get(&cid).await { + if let Ok(value) = serde_ipld_dagcbor::from_slice::(&block) { + records.push(json!({ + "uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey), + "cid": cid_str, + "value": value + })); + } + } + } + } + + Json(ListRecordsOutput { + cursor: last_rkey, + records, + }) + .into_response() +} diff --git a/src/api/repo/record/write.rs b/src/api/repo/record/write.rs new file mode 100644 index 0000000..7c5117e --- /dev/null +++ b/src/api/repo/record/write.rs @@ -0,0 +1,591 @@ +use crate::state::AppState; +use axum::{ + Json, + extract::State, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use chrono::Utc; +use cid::Cid; +use jacquard::types::{ + did::Did, + integer::LimitedU32, + string::{Nsid, Tid}, +}; +use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use sqlx::Row; +use std::str::FromStr; +use std::sync::Arc; +use tracing::error; + +#[derive(Deserialize)] +#[allow(dead_code)] +pub struct CreateRecordInput { + pub repo: String, + pub collection: String, + pub rkey: Option, + pub validate: Option, + pub record: serde_json::Value, + #[serde(rename = "swapCommit")] + pub swap_commit: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CreateRecordOutput { + pub uri: String, + pub cid: String, +} + +pub async fn create_record( + State(state): State, + headers: axum::http::HeaderMap, + Json(input): Json, +) -> Response { + let auth_header = headers.get("Authorization"); + if auth_header.is_none() { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationRequired"})), + ) + .into_response(); + } + let token = auth_header + .unwrap() + .to_str() + .unwrap_or("") + .replace("Bearer ", ""); + + let session = sqlx::query( + "SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1" + ) + .bind(&token) + .fetch_optional(&state.db) + .await + .unwrap_or(None); + + let (did, key_bytes) = match session { + Some(row) => ( + row.get::("did"), + row.get::, _>("key_bytes"), + ), + None => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed"})), + ) + .into_response(); + } + }; + + if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})), + ) + .into_response(); + } + + if input.repo != did { + return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response(); + } + + let user_query = sqlx::query("SELECT id FROM users WHERE did = $1") + .bind(&did) + .fetch_optional(&state.db) + .await; + + let user_id: uuid::Uuid = match user_query { + Ok(Some(row)) => row.get("id"), + _ => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "User not found"})), + ) + .into_response(); + } + }; + + let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1") + .bind(user_id) + .fetch_optional(&state.db) + .await; + + let current_root_cid = match repo_root_query { + Ok(Some(row)) => { + let cid_str: String = row.get("repo_root_cid"); + Cid::from_str(&cid_str).ok() + } + _ => None, + }; + + if current_root_cid.is_none() { + error!("Repo root not found for user {}", did); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Repo root not found"})), + ) + .into_response(); + } + let current_root_cid = current_root_cid.unwrap(); + + let commit_bytes = match state.block_store.get(¤t_root_cid).await { + Ok(Some(b)) => b, + Ok(None) => { + error!("Commit block not found: {}", current_root_cid); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + Err(e) => { + error!("Failed to load commit block: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let commit = match Commit::from_cbor(&commit_bytes) { + Ok(c) => c, + Err(e) => { + error!("Failed to parse commit: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let mst_root = commit.data; + let store = Arc::new(state.block_store.clone()); + let mst = Mst::load(store.clone(), mst_root, None); + + let collection_nsid = match input.collection.parse::() { + Ok(n) => n, + Err(_) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({"error": "InvalidCollection"})), + ) + .into_response(); + } + }; + + let rkey = input + .rkey + .unwrap_or_else(|| Utc::now().format("%Y%m%d%H%M%S%f").to_string()); + + let mut record_bytes = Vec::new(); + if let Err(e) = serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record) { + error!("Error serializing record: {:?}", e); + return ( + StatusCode::BAD_REQUEST, + Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"})), + ) + .into_response(); + } + + let record_cid = match state.block_store.put(&record_bytes).await { + Ok(c) => c, + Err(e) => { + error!("Failed to save record block: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let key = format!("{}/{}", collection_nsid, rkey); + if let Err(e) = mst.update(&key, record_cid).await { + error!("Failed to update MST: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + + let new_mst_root = match mst.root().await { + Ok(c) => c, + Err(e) => { + error!("Failed to get new MST root: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let did_obj = match Did::new(&did) { + Ok(d) => d, + Err(_) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Invalid DID"})), + ) + .into_response(); + } + }; + + let rev = Tid::now(LimitedU32::MIN); + + let new_commit = Commit::new_unsigned(did_obj, new_mst_root, rev, Some(current_root_cid)); + + let new_commit_bytes = match new_commit.to_cbor() { + Ok(b) => b, + Err(e) => { + error!("Failed to serialize new commit: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let new_root_cid = match state.block_store.put(&new_commit_bytes).await { + Ok(c) => c, + Err(e) => { + error!("Failed to save new commit: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + }; + + let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2") + .bind(new_root_cid.to_string()) + .bind(user_id) + .execute(&state.db) + .await; + + if let Err(e) = update_repo { + error!("Failed to update repo root in DB: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } + + let record_insert = sqlx::query( + "INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4) + ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()", + ) + .bind(user_id) + .bind(&input.collection) + .bind(&rkey) + .bind(record_cid.to_string()) + .execute(&state.db) + .await; + + if let Err(e) = record_insert { + error!("Error inserting record index: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Failed to index record"})), + ) + .into_response(); + } + + let output = CreateRecordOutput { + uri: format!("at://{}/{}/{}", input.repo, input.collection, rkey), + cid: record_cid.to_string(), + }; + (StatusCode::OK, Json(output)).into_response() +} + +#[derive(Deserialize)] +#[allow(dead_code)] +pub struct PutRecordInput { + pub repo: String, + pub collection: String, + pub rkey: String, + pub validate: Option, + pub record: serde_json::Value, + #[serde(rename = "swapCommit")] + pub swap_commit: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PutRecordOutput { + pub uri: String, + pub cid: String, +} + +pub async fn put_record( + State(state): State, + headers: axum::http::HeaderMap, + Json(input): Json, +) -> Response { + let auth_header = headers.get("Authorization"); + if auth_header.is_none() { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationRequired"})), + ) + .into_response(); + } + let token = auth_header + .unwrap() + .to_str() + .unwrap_or("") + .replace("Bearer ", ""); + + let session = sqlx::query( + "SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1" + ) + .bind(&token) + .fetch_optional(&state.db) + .await + .unwrap_or(None); + + let (did, key_bytes) = match session { + Some(row) => ( + row.get::("did"), + row.get::, _>("key_bytes"), + ), + None => { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed"})), + ) + .into_response(); + } + }; + + if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})), + ) + .into_response(); + } + + if input.repo != did { + return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response(); + } + + let user_query = sqlx::query("SELECT id FROM users WHERE did = $1") + .bind(&did) + .fetch_optional(&state.db) + .await; + + let user_id: uuid::Uuid = match user_query { + Ok(Some(row)) => row.get("id"), + _ => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "User not found"})), + ) + .into_response(); + } + }; + + let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1") + .bind(user_id) + .fetch_optional(&state.db) + .await; + + let current_root_cid = match repo_root_query { + Ok(Some(row)) => { + let cid_str: String = row.get("repo_root_cid"); + Cid::from_str(&cid_str).ok() + } + _ => None, + }; + + if current_root_cid.is_none() { + error!("Repo root not found for user {}", did); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Repo root not found"})), + ) + .into_response(); + } + let current_root_cid = current_root_cid.unwrap(); + + let commit_bytes = match state.block_store.get(¤t_root_cid).await { + Ok(Some(b)) => b, + Ok(None) => { + error!("Commit block not found: {}", current_root_cid); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Commit block not found"})), + ) + .into_response(); + } + Err(e) => { + error!("Failed to load commit block: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Failed to load commit block"})), + ) + .into_response(); + } + }; + + let commit = match Commit::from_cbor(&commit_bytes) { + Ok(c) => c, + Err(e) => { + error!("Failed to parse commit: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Failed to parse commit"})), + ) + .into_response(); + } + }; + + let mst_root = commit.data; + let store = Arc::new(state.block_store.clone()); + let mst = Mst::load(store.clone(), mst_root, None); + + let collection_nsid = match input.collection.parse::() { + Ok(n) => n, + Err(_) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({"error": "InvalidCollection"})), + ) + .into_response(); + } + }; + + let rkey = input.rkey.clone(); + + let mut record_bytes = Vec::new(); + if let Err(e) = serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record) { + error!("Error serializing record: {:?}", e); + return ( + StatusCode::BAD_REQUEST, + Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"})), + ) + .into_response(); + } + + let record_cid = match state.block_store.put(&record_bytes).await { + Ok(c) => c, + Err(e) => { + error!("Failed to save record block: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Failed to save record block"})), + ) + .into_response(); + } + }; + + let key = format!("{}/{}", collection_nsid, rkey); + if let Err(e) = mst.update(&key, record_cid).await { + error!("Failed to update MST: {:?}", e); + return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to update MST: {:?}", e)}))).into_response(); + } + + let new_mst_root = match mst.root().await { + Ok(c) => c, + Err(e) => { + error!("Failed to get new MST root: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Failed to get new MST root"})), + ) + .into_response(); + } + }; + + let did_obj = match Did::new(&did) { + Ok(d) => d, + Err(_) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Invalid DID"})), + ) + .into_response(); + } + }; + + let rev = Tid::now(LimitedU32::MIN); + + let new_commit = Commit::new_unsigned(did_obj, new_mst_root, rev, Some(current_root_cid)); + + let new_commit_bytes = match new_commit.to_cbor() { + Ok(b) => b, + Err(e) => { + error!("Failed to serialize new commit: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json( + json!({"error": "InternalError", "message": "Failed to serialize new commit"}), + ), + ) + .into_response(); + } + }; + + let new_root_cid = match state.block_store.put(&new_commit_bytes).await { + Ok(c) => c, + Err(e) => { + error!("Failed to save new commit: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Failed to save new commit"})), + ) + .into_response(); + } + }; + + let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2") + .bind(new_root_cid.to_string()) + .bind(user_id) + .execute(&state.db) + .await; + + if let Err(e) = update_repo { + error!("Failed to update repo root in DB: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Failed to update repo root in DB"})), + ) + .into_response(); + } + + let record_insert = sqlx::query( + "INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4) + ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()", + ) + .bind(user_id) + .bind(&input.collection) + .bind(&rkey) + .bind(record_cid.to_string()) + .execute(&state.db) + .await; + + if let Err(e) = record_insert { + error!("Error inserting record index: {:?}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError", "message": "Failed to index record"})), + ) + .into_response(); + } + + let output = PutRecordOutput { + uri: format!("at://{}/{}/{}", input.repo, input.collection, rkey), + cid: record_cid.to_string(), + }; + (StatusCode::OK, Json(output)).into_response() +} diff --git a/src/api/server/meta.rs b/src/api/server/meta.rs new file mode 100644 index 0000000..8c3c190 --- /dev/null +++ b/src/api/server/meta.rs @@ -0,0 +1,25 @@ +use crate::state::AppState; +use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; +use serde_json::json; + +use tracing::error; + +pub async fn describe_server() -> impl IntoResponse { + let domains_str = + std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| "example.com".to_string()); + let domains: Vec<&str> = domains_str.split(',').map(|s| s.trim()).collect(); + + Json(json!({ + "availableUserDomains": domains + })) +} + +pub async fn health(State(state): State) -> impl IntoResponse { + match sqlx::query("SELECT 1").execute(&state.db).await { + Ok(_) => (StatusCode::OK, "OK"), + Err(e) => { + error!("Health check failed: {:?}", e); + (StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable") + } + } +} diff --git a/src/api/server/mod.rs b/src/api/server/mod.rs new file mode 100644 index 0000000..24d909b --- /dev/null +++ b/src/api/server/mod.rs @@ -0,0 +1,5 @@ +pub mod meta; +pub mod session; + +pub use meta::{describe_server, health}; +pub use session::{create_session, delete_session, get_session, refresh_session}; diff --git a/src/api/server.rs b/src/api/server/session.rs similarity index 50% rename from src/api/server.rs rename to src/api/server/session.rs index a7234cd..5f9488f 100644 --- a/src/api/server.rs +++ b/src/api/server/session.rs @@ -1,34 +1,15 @@ +use crate::state::AppState; use axum::{ - extract::State, Json, - response::{IntoResponse, Response}, + extract::State, http::StatusCode, + response::{IntoResponse, Response}, }; +use bcrypt::verify; use serde::{Deserialize, Serialize}; use serde_json::json; -use crate::state::AppState; use sqlx::Row; -use bcrypt::verify; -use tracing::{info, error, warn}; - -pub async fn describe_server() -> impl IntoResponse { - let domains_str = std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| "example.com".to_string()); - let domains: Vec<&str> = domains_str.split(',').map(|s| s.trim()).collect(); - - Json(json!({ - "availableUserDomains": domains - })) -} - -pub async fn health(State(state): State) -> impl IntoResponse { - match sqlx::query("SELECT 1").execute(&state.db).await { - Ok(_) => (StatusCode::OK, "OK"), - Err(e) => { - error!("Health check failed: {:?}", e); - (StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable") - } - } -} +use tracing::{error, info, warn}; #[derive(Deserialize)] pub struct CreateSessionInput { @@ -69,7 +50,11 @@ pub async fn create_session( Ok(t) => t, Err(e) => { error!("Failed to create access token: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); } }; @@ -77,45 +62,70 @@ pub async fn create_session( Ok(t) => t, Err(e) => { error!("Failed to create refresh token: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); } }; - let session_insert = sqlx::query("INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)") - .bind(&access_jwt) - .bind(&refresh_jwt) - .bind(&did) - .execute(&state.db) - .await; + let session_insert = sqlx::query( + "INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)", + ) + .bind(&access_jwt) + .bind(&refresh_jwt) + .bind(&did) + .execute(&state.db) + .await; match session_insert { Ok(_) => { - return (StatusCode::OK, Json(CreateSessionOutput { - access_jwt, - refresh_jwt, - handle, - did, - })).into_response(); - }, + return ( + StatusCode::OK, + Json(CreateSessionOutput { + access_jwt, + refresh_jwt, + handle, + did, + }), + ) + .into_response(); + } Err(e) => { error!("Failed to insert session: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); } } } else { - warn!("Password verification failed for identifier: {}", input.identifier); + warn!( + "Password verification failed for identifier: {}", + input.identifier + ); } - }, + } Ok(None) => { warn!("User not found for identifier: {}", input.identifier); - }, + } Err(e) => { error!("Database error fetching user: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); } } - (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid identifier or password"}))).into_response() + ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed", "message": "Invalid identifier or password"})), + ) + .into_response() } pub async fn get_session( @@ -124,10 +134,18 @@ pub async fn get_session( ) -> Response { let auth_header = headers.get("Authorization"); if auth_header.is_none() { - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response(); + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationRequired"})), + ) + .into_response(); } - let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", ""); + let token = auth_header + .unwrap() + .to_str() + .unwrap_or("") + .replace("Bearer ", ""); let result = sqlx::query( r#" @@ -136,7 +154,7 @@ pub async fn get_session( JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1 - "# + "#, ) .bind(&token) .fetch_optional(&state.db) @@ -150,22 +168,34 @@ pub async fn get_session( let key_bytes: Vec = row.get("key_bytes"); if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response(); + return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response(); } - return (StatusCode::OK, Json(json!({ - "handle": handle, - "did": did, - "email": email, - "didDoc": {} - }))).into_response(); - }, + return ( + StatusCode::OK, + Json(json!({ + "handle": handle, + "did": did, + "email": email, + "didDoc": {} + })), + ) + .into_response(); + } Ok(None) => { - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response(); - }, + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed"})), + ) + .into_response(); + } Err(e) => { error!("Database error in get_session: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); } } } @@ -176,10 +206,18 @@ pub async fn delete_session( ) -> Response { let auth_header = headers.get("Authorization"); if auth_header.is_none() { - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response(); + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationRequired"})), + ) + .into_response(); } - let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", ""); + let token = auth_header + .unwrap() + .to_str() + .unwrap_or("") + .replace("Bearer ", ""); let result = sqlx::query("DELETE FROM sessions WHERE access_jwt = $1") .bind(token) @@ -191,13 +229,17 @@ pub async fn delete_session( if res.rows_affected() > 0 { return (StatusCode::OK, Json(json!({}))).into_response(); } - }, + } Err(e) => { error!("Database error in delete_session: {:?}", e); } } - (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response() + ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed"})), + ) + .into_response() } pub async fn refresh_session( @@ -206,10 +248,18 @@ pub async fn refresh_session( ) -> Response { let auth_header = headers.get("Authorization"); if auth_header.is_none() { - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response(); + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationRequired"})), + ) + .into_response(); } - let refresh_token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", ""); + let refresh_token = auth_header + .unwrap() + .to_str() + .unwrap_or("") + .replace("Bearer ", ""); let session = sqlx::query( "SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.refresh_jwt = $1" @@ -231,27 +281,37 @@ pub async fn refresh_session( Ok(t) => t, Err(e) => { error!("Failed to create access token: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); } }; let new_refresh_jwt = match crate::auth::create_refresh_token(&did, &key_bytes) { Ok(t) => t, Err(e) => { error!("Failed to create refresh token: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); } }; - let update = sqlx::query("UPDATE sessions SET access_jwt = $1, refresh_jwt = $2 WHERE refresh_jwt = $3") - .bind(&new_access_jwt) - .bind(&new_refresh_jwt) - .bind(&refresh_token) - .execute(&state.db) - .await; + let update = sqlx::query( + "UPDATE sessions SET access_jwt = $1, refresh_jwt = $2 WHERE refresh_jwt = $3", + ) + .bind(&new_access_jwt) + .bind(&new_refresh_jwt) + .bind(&refresh_token) + .execute(&state.db) + .await; match update { Ok(_) => { - let user = sqlx::query("SELECT handle FROM users WHERE did = $1") + let user = sqlx::query("SELECT handle FROM users WHERE did = $1") .bind(&did) .fetch_optional(&state.db) .await; @@ -259,36 +319,59 @@ pub async fn refresh_session( match user { Ok(Some(u)) => { let handle: String = u.get("handle"); - return (StatusCode::OK, Json(json!({ - "accessJwt": new_access_jwt, - "refreshJwt": new_refresh_jwt, - "handle": handle, - "did": did - }))).into_response(); - }, + return ( + StatusCode::OK, + Json(json!({ + "accessJwt": new_access_jwt, + "refreshJwt": new_refresh_jwt, + "handle": handle, + "did": did + })), + ) + .into_response(); + } Ok(None) => { error!("User not found for existing session: {}", did); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); - }, + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); + } Err(e) => { error!("Database error fetching user: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); } } - }, + } Err(e) => { error!("Database error updating session: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); } } - }, + } Ok(None) => { - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid refresh token"}))).into_response(); - }, + return ( + StatusCode::UNAUTHORIZED, + Json(json!({"error": "AuthenticationFailed", "message": "Invalid refresh token"})), + ) + .into_response(); + } Err(e) => { error!("Database error fetching session: {:?}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "InternalError"})), + ) + .into_response(); } } } - diff --git a/src/auth.rs b/src/auth.rs deleted file mode 100644 index be9f8e4..0000000 --- a/src/auth.rs +++ /dev/null @@ -1,157 +0,0 @@ -use serde::{Deserialize, Serialize}; -use chrono::{Utc, Duration}; -use k256::ecdsa::{SigningKey, VerifyingKey, signature::Signer, signature::Verifier, Signature}; -use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; -use anyhow::{Context, Result, anyhow}; - -#[derive(Debug, Serialize, Deserialize)] -pub struct Claims { - pub iss: String, - pub sub: String, - pub aud: String, - pub exp: usize, - pub iat: usize, - #[serde(skip_serializing_if = "Option::is_none")] - pub scope: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub lxm: Option, - pub jti: String, -} - -#[derive(Debug, Serialize, Deserialize)] -struct Header { - alg: String, - typ: String, -} - -#[derive(Debug, Serialize, Deserialize)] -struct UnsafeClaims { - iss: String, - sub: Option, -} - -// fancy boy TokenData equivalent for compatibility/structure -pub struct TokenData { - pub claims: T, -} - -pub fn get_did_from_token(token: &str) -> Result { - let parts: Vec<&str> = token.split('.').collect(); - if parts.len() != 3 { - return Err("Invalid token format".to_string()); - } - - let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]) - .map_err(|e| format!("Base64 decode failed: {}", e))?; - - let claims: UnsafeClaims = serde_json::from_slice(&payload_bytes) - .map_err(|e| format!("JSON decode failed: {}", e))?; - - Ok(claims.sub.unwrap_or(claims.iss)) -} - -pub fn create_access_token(did: &str, key_bytes: &[u8]) -> Result { - create_signed_token(did, "access", key_bytes, Duration::minutes(15)) -} - -pub fn create_refresh_token(did: &str, key_bytes: &[u8]) -> Result { - create_signed_token(did, "refresh", key_bytes, Duration::days(7)) -} - -pub fn create_service_token(did: &str, aud: &str, lxm: &str, key_bytes: &[u8]) -> Result { - let signing_key = SigningKey::from_slice(key_bytes)?; - - let expiration = Utc::now() - .checked_add_signed(Duration::seconds(60)) - .expect("valid timestamp") - .timestamp(); - - let claims = Claims { - iss: did.to_owned(), - sub: did.to_owned(), - aud: aud.to_owned(), - exp: expiration as usize, - iat: Utc::now().timestamp() as usize, - scope: None, - lxm: Some(lxm.to_string()), - jti: uuid::Uuid::new_v4().to_string(), - }; - - sign_claims(claims, &signing_key) -} - -fn create_signed_token(did: &str, scope: &str, key_bytes: &[u8], duration: Duration) -> Result { - let signing_key = SigningKey::from_slice(key_bytes)?; - - let expiration = Utc::now() - .checked_add_signed(duration) - .expect("valid timestamp") - .timestamp(); - - let claims = Claims { - iss: did.to_owned(), - sub: did.to_owned(), - aud: format!("did:web:{}", std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())), - exp: expiration as usize, - iat: Utc::now().timestamp() as usize, - scope: Some(scope.to_string()), - lxm: None, - jti: uuid::Uuid::new_v4().to_string(), - }; - - sign_claims(claims, &signing_key) -} - -fn sign_claims(claims: Claims, key: &SigningKey) -> Result { - let header = Header { - alg: "ES256K".to_string(), - typ: "JWT".to_string(), - }; - - let header_json = serde_json::to_string(&header)?; - let claims_json = serde_json::to_string(&claims)?; - - let header_b64 = URL_SAFE_NO_PAD.encode(header_json); - let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json); - - let message = format!("{}.{}", header_b64, claims_b64); - let signature: Signature = key.sign(message.as_bytes()); - let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); - - Ok(format!("{}.{}", message, signature_b64)) -} - -pub fn verify_token(token: &str, key_bytes: &[u8]) -> Result, anyhow::Error> { - let parts: Vec<&str> = token.split('.').collect(); - if parts.len() != 3 { - return Err(anyhow!("Invalid token format")); - } - - let header_b64 = parts[0]; - let claims_b64 = parts[1]; - let signature_b64 = parts[2]; - - let signature_bytes = URL_SAFE_NO_PAD.decode(signature_b64) - .context("Base64 decode of signature failed")?; - let signature = Signature::from_slice(&signature_bytes) - .map_err(|e| anyhow!("Invalid signature format: {}", e))?; - - let signing_key = SigningKey::from_slice(key_bytes)?; - let verifying_key = VerifyingKey::from(&signing_key); - - let message = format!("{}.{}", header_b64, claims_b64); - verifying_key.verify(message.as_bytes(), &signature) - .map_err(|e| anyhow!("Signature verification failed: {}", e))?; - - let claims_bytes = URL_SAFE_NO_PAD.decode(claims_b64) - .context("Base64 decode of claims failed")?; - let claims: Claims = serde_json::from_slice(&claims_bytes) - .context("JSON decode of claims failed")?; - - let now = Utc::now().timestamp() as usize; - if claims.exp < now { - return Err(anyhow!("Token expired")); - } - - Ok(TokenData { claims }) -} diff --git a/src/auth/mod.rs b/src/auth/mod.rs new file mode 100644 index 0000000..790d81b --- /dev/null +++ b/src/auth/mod.rs @@ -0,0 +1,38 @@ +use serde::{Deserialize, Serialize}; + +pub mod token; +pub mod verify; + +pub use token::{create_access_token, create_refresh_token, create_service_token}; +pub use verify::{get_did_from_token, verify_token}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Claims { + pub iss: String, + pub sub: String, + pub aud: String, + pub exp: usize, + pub iat: usize, + #[serde(skip_serializing_if = "Option::is_none")] + pub scope: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub lxm: Option, + pub jti: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Header { + pub alg: String, + pub typ: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct UnsafeClaims { + pub iss: String, + pub sub: Option, +} + +// fancy boy TokenData equivalent for compatibility/structure +pub struct TokenData { + pub claims: T, +} diff --git a/src/auth/token.rs b/src/auth/token.rs new file mode 100644 index 0000000..9568918 --- /dev/null +++ b/src/auth/token.rs @@ -0,0 +1,86 @@ +use super::{Claims, Header}; +use anyhow::Result; +use base64::Engine as _; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use chrono::{Duration, Utc}; +use k256::ecdsa::{Signature, SigningKey, signature::Signer}; +use uuid; + +pub fn create_access_token(did: &str, key_bytes: &[u8]) -> Result { + create_signed_token(did, "access", key_bytes, Duration::minutes(15)) +} + +pub fn create_refresh_token(did: &str, key_bytes: &[u8]) -> Result { + create_signed_token(did, "refresh", key_bytes, Duration::days(7)) +} + +pub fn create_service_token(did: &str, aud: &str, lxm: &str, key_bytes: &[u8]) -> Result { + let signing_key = SigningKey::from_slice(key_bytes)?; + + let expiration = Utc::now() + .checked_add_signed(Duration::seconds(60)) + .expect("valid timestamp") + .timestamp(); + + let claims = Claims { + iss: did.to_owned(), + sub: did.to_owned(), + aud: aud.to_owned(), + exp: expiration as usize, + iat: Utc::now().timestamp() as usize, + scope: None, + lxm: Some(lxm.to_string()), + jti: uuid::Uuid::new_v4().to_string(), + }; + + sign_claims(claims, &signing_key) +} + +fn create_signed_token( + did: &str, + scope: &str, + key_bytes: &[u8], + duration: Duration, +) -> Result { + let signing_key = SigningKey::from_slice(key_bytes)?; + + let expiration = Utc::now() + .checked_add_signed(duration) + .expect("valid timestamp") + .timestamp(); + + let claims = Claims { + iss: did.to_owned(), + sub: did.to_owned(), + aud: format!( + "did:web:{}", + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) + ), + exp: expiration as usize, + iat: Utc::now().timestamp() as usize, + scope: Some(scope.to_string()), + lxm: None, + jti: uuid::Uuid::new_v4().to_string(), + }; + + sign_claims(claims, &signing_key) +} + +fn sign_claims(claims: Claims, key: &SigningKey) -> Result { + let header = Header { + alg: "ES256K".to_string(), + typ: "JWT".to_string(), + }; + + let header_json = serde_json::to_string(&header)?; + let claims_json = serde_json::to_string(&claims)?; + + let header_b64 = URL_SAFE_NO_PAD.encode(header_json); + let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json); + + let message = format!("{}.{}", header_b64, claims_b64); + let signature: Signature = key.sign(message.as_bytes()); + let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); + + Ok(format!("{}.{}", message, signature_b64)) +} diff --git a/src/auth/verify.rs b/src/auth/verify.rs new file mode 100644 index 0000000..5956984 --- /dev/null +++ b/src/auth/verify.rs @@ -0,0 +1,60 @@ +use super::{Claims, TokenData, UnsafeClaims}; +use anyhow::{Context, Result, anyhow}; +use base64::Engine as _; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use chrono::Utc; +use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier}; + +pub fn get_did_from_token(token: &str) -> Result { + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return Err("Invalid token format".to_string()); + } + + let payload_bytes = URL_SAFE_NO_PAD + .decode(parts[1]) + .map_err(|e| format!("Base64 decode failed: {}", e))?; + + let claims: UnsafeClaims = + serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; + + Ok(claims.sub.unwrap_or(claims.iss)) +} + +pub fn verify_token(token: &str, key_bytes: &[u8]) -> Result> { + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return Err(anyhow!("Invalid token format")); + } + + let header_b64 = parts[0]; + let claims_b64 = parts[1]; + let signature_b64 = parts[2]; + + let signature_bytes = URL_SAFE_NO_PAD + .decode(signature_b64) + .context("Base64 decode of signature failed")?; + let signature = Signature::from_slice(&signature_bytes) + .map_err(|e| anyhow!("Invalid signature format: {}", e))?; + + let signing_key = SigningKey::from_slice(key_bytes)?; + let verifying_key = VerifyingKey::from(&signing_key); + + let message = format!("{}.{}", header_b64, claims_b64); + verifying_key + .verify(message.as_bytes(), &signature) + .map_err(|e| anyhow!("Signature verification failed: {}", e))?; + + let claims_bytes = URL_SAFE_NO_PAD + .decode(claims_b64) + .context("Base64 decode of claims failed")?; + let claims: Claims = + serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?; + + let now = Utc::now().timestamp() as usize; + if claims.exp < now { + return Err(anyhow!("Token expired")); + } + + Ok(TokenData { claims }) +} diff --git a/src/lib.rs b/src/lib.rs index 6c953f0..8a3f3ee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,31 +1,70 @@ pub mod api; -pub mod state; pub mod auth; pub mod repo; +pub mod state; pub mod storage; use axum::{ - routing::{get, post, any}, Router, + routing::{any, get, post}, }; use state::AppState; pub fn app(state: AppState) -> Router { Router::new() .route("/health", get(api::server::health)) - .route("/xrpc/com.atproto.server.describeServer", get(api::server::describe_server)) - .route("/xrpc/com.atproto.server.createAccount", post(api::identity::create_account)) - .route("/xrpc/com.atproto.server.createSession", post(api::server::create_session)) - .route("/xrpc/com.atproto.server.getSession", get(api::server::get_session)) - .route("/xrpc/com.atproto.server.deleteSession", post(api::server::delete_session)) - .route("/xrpc/com.atproto.server.refreshSession", post(api::server::refresh_session)) - .route("/xrpc/com.atproto.repo.createRecord", post(api::repo::create_record)) - .route("/xrpc/com.atproto.repo.putRecord", post(api::repo::put_record)) - .route("/xrpc/com.atproto.repo.getRecord", get(api::repo::get_record)) - .route("/xrpc/com.atproto.repo.deleteRecord", post(api::repo::delete_record)) - .route("/xrpc/com.atproto.repo.listRecords", get(api::repo::list_records)) - .route("/xrpc/com.atproto.repo.describeRepo", get(api::repo::describe_repo)) - .route("/xrpc/com.atproto.repo.uploadBlob", post(api::repo::upload_blob)) + .route( + "/xrpc/com.atproto.server.describeServer", + get(api::server::describe_server), + ) + .route( + "/xrpc/com.atproto.server.createAccount", + post(api::identity::create_account), + ) + .route( + "/xrpc/com.atproto.server.createSession", + post(api::server::create_session), + ) + .route( + "/xrpc/com.atproto.server.getSession", + get(api::server::get_session), + ) + .route( + "/xrpc/com.atproto.server.deleteSession", + post(api::server::delete_session), + ) + .route( + "/xrpc/com.atproto.server.refreshSession", + post(api::server::refresh_session), + ) + .route( + "/xrpc/com.atproto.repo.createRecord", + post(api::repo::create_record), + ) + .route( + "/xrpc/com.atproto.repo.putRecord", + post(api::repo::put_record), + ) + .route( + "/xrpc/com.atproto.repo.getRecord", + get(api::repo::get_record), + ) + .route( + "/xrpc/com.atproto.repo.deleteRecord", + post(api::repo::delete_record), + ) + .route( + "/xrpc/com.atproto.repo.listRecords", + get(api::repo::list_records), + ) + .route( + "/xrpc/com.atproto.repo.describeRepo", + get(api::repo::describe_repo), + ) + .route( + "/xrpc/com.atproto.repo.uploadBlob", + post(api::repo::upload_blob), + ) .route("/.well-known/did.json", get(api::identity::well_known_did)) .route("/u/{handle}/did.json", get(api::identity::user_did_doc)) .route("/xrpc/{*method}", any(api::proxy::proxy_handler)) diff --git a/src/main.rs b/src/main.rs index ca696ae..7db7e6d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,5 @@ -use std::net::SocketAddr; use bspds::state::AppState; +use std::net::SocketAddr; use tracing::info; #[tokio::main] diff --git a/src/repo/mod.rs b/src/repo/mod.rs index e3b3c88..e27913b 100644 --- a/src/repo/mod.rs +++ b/src/repo/mod.rs @@ -1,11 +1,11 @@ -use jacquard_repo::storage::BlockStore; +use bytes::Bytes; +use cid::Cid; use jacquard_repo::error::RepoError; use jacquard_repo::repo::CommitData; -use cid::Cid; -use sqlx::{PgPool, Row}; -use bytes::Bytes; -use sha2::{Sha256, Digest}; +use jacquard_repo::storage::BlockStore; use multihash::Multihash; +use sha2::{Digest, Sha256}; +use sqlx::{PgPool, Row}; #[derive(Clone)] pub struct PostgresBlockStore { @@ -31,7 +31,7 @@ impl BlockStore for PostgresBlockStore { Some(row) => { let data: Vec = row.get("data"); Ok(Some(Bytes::from(data))) - }, + } None => Ok(None), } } @@ -65,16 +65,21 @@ impl BlockStore for PostgresBlockStore { Ok(row.is_some()) } - async fn put_many(&self, blocks: impl IntoIterator + Send) -> Result<(), RepoError> { + async fn put_many( + &self, + blocks: impl IntoIterator + Send, + ) -> Result<(), RepoError> { let blocks: Vec<_> = blocks.into_iter().collect(); for (cid, data) in blocks { let cid_bytes = cid.to_bytes(); - sqlx::query("INSERT INTO blocks (cid, data) VALUES ($1, $2) ON CONFLICT (cid) DO NOTHING") - .bind(cid_bytes) - .bind(data.as_ref()) - .execute(&self.pool) - .await - .map_err(|e| RepoError::storage(e))?; + sqlx::query( + "INSERT INTO blocks (cid, data) VALUES ($1, $2) ON CONFLICT (cid) DO NOTHING", + ) + .bind(cid_bytes) + .bind(data.as_ref()) + .execute(&self.pool) + .await + .map_err(|e| RepoError::storage(e))?; } Ok(()) } diff --git a/src/state.rs b/src/state.rs index 40bdab4..3bc2a46 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,6 +1,6 @@ -use sqlx::PgPool; use crate::repo::PostgresBlockStore; use crate::storage::{BlobStorage, S3BlobStorage}; +use sqlx::PgPool; use std::sync::Arc; #[derive(Clone)] @@ -14,6 +14,10 @@ impl AppState { pub async fn new(db: PgPool) -> Self { let block_store = PostgresBlockStore::new(db.clone()); let blob_store = S3BlobStorage::new().await; - Self { db, block_store, blob_store: Arc::new(blob_store) } + Self { + db, + block_store, + blob_store: Arc::new(blob_store), + } } } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index d83b25f..20fde38 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1,9 +1,9 @@ use async_trait::async_trait; -use thiserror::Error; +use aws_config::BehaviorVersion; +use aws_config::meta::region::RegionProviderChain; use aws_sdk_s3::Client; use aws_sdk_s3::primitives::ByteStream; -use aws_config::meta::region::RegionProviderChain; -use aws_config::BehaviorVersion; +use thiserror::Error; #[derive(Error, Debug)] pub enum StorageError { @@ -55,7 +55,8 @@ impl S3BlobStorage { #[async_trait] impl BlobStorage for S3BlobStorage { async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> { - self.client.put_object() + self.client + .put_object() .bucket(&self.bucket) .key(key) .body(ByteStream::from(data.to_vec())) @@ -66,14 +67,19 @@ impl BlobStorage for S3BlobStorage { } async fn get(&self, key: &str) -> Result, StorageError> { - let resp = self.client.get_object() + let resp = self + .client + .get_object() .bucket(&self.bucket) .key(key) .send() .await .map_err(|e| StorageError::S3(e.to_string()))?; - let data = resp.body.collect().await + let data = resp + .body + .collect() + .await .map_err(|e| StorageError::S3(e.to_string()))? .into_bytes(); @@ -81,7 +87,8 @@ impl BlobStorage for S3BlobStorage { } async fn delete(&self, key: &str) -> Result<(), StorageError> { - self.client.delete_object() + self.client + .delete_object() .bucket(&self.bucket) .key(key) .send() diff --git a/tests/auth.rs b/tests/auth.rs index 250436c..a6040d2 100644 --- a/tests/auth.rs +++ b/tests/auth.rs @@ -1,10 +1,10 @@ -use bspds::auth; -use k256::SecretKey; -use rand::rngs::OsRng; -use chrono::{Utc, Duration}; use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; -use serde_json::json; +use bspds::auth; +use chrono::{Duration, Utc}; +use k256::SecretKey; use k256::ecdsa::{SigningKey, signature::Signer}; +use rand::rngs::OsRng; +use serde_json::json; #[test] fn test_jwt_flow() { @@ -24,7 +24,8 @@ fn test_jwt_flow() { let aud = "did:web:service"; let lxm = "com.example.test"; - let s_token = auth::create_service_token(did, aud, lxm, &key_bytes).expect("create service token"); + let s_token = + auth::create_service_token(did, aud, lxm, &key_bytes).expect("create service token"); let s_data = auth::verify_token(&s_token, &key_bytes).expect("verify service token"); assert_eq!(s_data.claims.aud, aud); assert_eq!(s_data.claims.lxm, Some(lxm.to_string())); diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 51955e4..2cef84f 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,22 +1,22 @@ -use reqwest::{header, Client, StatusCode}; -use serde_json::{json, Value}; +use aws_config::BehaviorVersion; +use aws_sdk_s3::Client as S3Client; +use aws_sdk_s3::config::Credentials; +use bspds::state::AppState; use chrono::Utc; +use reqwest::{Client, StatusCode, header}; +use serde_json::{Value, json}; +use sqlx::postgres::PgPoolOptions; #[allow(unused_imports)] use std::collections::HashMap; +use std::sync::OnceLock; #[allow(unused_imports)] use std::time::Duration; -use std::sync::OnceLock; -use bspds::state::AppState; -use sqlx::postgres::PgPoolOptions; -use tokio::net::TcpListener; -use testcontainers::{runners::AsyncRunner, ContainerAsync, ImageExt, GenericImage}; use testcontainers::core::ContainerPort; +use testcontainers::{ContainerAsync, GenericImage, ImageExt, runners::AsyncRunner}; use testcontainers_modules::postgres::Postgres; -use aws_sdk_s3::Client as S3Client; -use aws_config::BehaviorVersion; -use aws_sdk_s3::config::Credentials; -use wiremock::{MockServer, Mock, ResponseTemplate}; +use tokio::net::TcpListener; use wiremock::matchers::{method, path}; +use wiremock::{Mock, MockServer, ResponseTemplate}; static SERVER_URL: OnceLock = OnceLock::new(); static DB_CONTAINER: OnceLock> = OnceLock::new(); @@ -46,7 +46,12 @@ pub async fn base_url() -> &'static str { if let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR") { let podman_sock = std::path::Path::new(&runtime_dir).join("podman/podman.sock"); if podman_sock.exists() { - unsafe { std::env::set_var("DOCKER_HOST", format!("unix://{}", podman_sock.display())); } + unsafe { + std::env::set_var( + "DOCKER_HOST", + format!("unix://{}", podman_sock.display()), + ); + } } } } @@ -62,7 +67,10 @@ pub async fn base_url() -> &'static str { .await .expect("Failed to start MinIO"); - let s3_port = s3_container.get_host_port_ipv4(9000).await.expect("Failed to get S3 port"); + let s3_port = s3_container + .get_host_port_ipv4(9000) + .await + .expect("Failed to get S3 port"); let s3_endpoint = format!("http://127.0.0.1:{}", s3_port); unsafe { @@ -76,7 +84,13 @@ pub async fn base_url() -> &'static str { let sdk_config = aws_config::defaults(BehaviorVersion::latest()) .region("us-east-1") .endpoint_url(&s3_endpoint) - .credentials_provider(Credentials::new("minioadmin", "minioadmin", None, None, "test")) + .credentials_provider(Credentials::new( + "minioadmin", + "minioadmin", + None, + None, + "test", + )) .load() .await; @@ -108,15 +122,24 @@ pub async fn base_url() -> &'static str { .mount(&mock_server) .await; - unsafe { std::env::set_var("APPVIEW_URL", mock_server.uri()); } + unsafe { + std::env::set_var("APPVIEW_URL", mock_server.uri()); + } MOCK_APPVIEW.set(mock_server).ok(); S3_CONTAINER.set(s3_container).ok(); - let container = Postgres::default().with_tag("18-alpine").start().await.expect("Failed to start Postgres"); + let container = Postgres::default() + .with_tag("18-alpine") + .start() + .await + .expect("Failed to start Postgres"); let connection_string = format!( "postgres://postgres:postgres@127.0.0.1:{}/postgres", - container.get_host_port_ipv4(5432).await.expect("Failed to get port") + container + .get_host_port_ipv4(5432) + .await + .expect("Failed to get port") ); DB_CONTAINER.set(container).ok(); @@ -157,7 +180,11 @@ async fn spawn_app(database_url: String) -> String { #[allow(dead_code)] pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'static str) -> Value { - let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.uploadBlob", + base_url().await + )) .header(header::CONTENT_TYPE, mime) .bearer_auth(AUTH_TOKEN) .body(data) @@ -170,12 +197,11 @@ pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'stati body["blob"].clone() } - #[allow(dead_code)] pub async fn create_test_post( client: &Client, text: &str, - reply_to: Option + reply_to: Option, ) -> (String, String, String) { let collection = "app.bsky.feed.post"; let mut record = json!({ @@ -194,7 +220,11 @@ pub async fn create_test_post( "record": record }); - let res = client.post(format!("{}/xrpc/com.atproto.repo.createRecord", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.createRecord", + base_url().await + )) .bearer_auth(AUTH_TOKEN) .json(&payload) .send() @@ -202,11 +232,24 @@ pub async fn create_test_post( .expect("Failed to send createRecord"); assert_eq!(res.status(), StatusCode::OK, "Failed to create post record"); - let body: Value = res.json().await.expect("createRecord response was not JSON"); + let body: Value = res + .json() + .await + .expect("createRecord response was not JSON"); - let uri = body["uri"].as_str().expect("Response had no URI").to_string(); - let cid = body["cid"].as_str().expect("Response had no CID").to_string(); - let rkey = uri.split('/').last().expect("URI was malformed").to_string(); + let uri = body["uri"] + .as_str() + .expect("Response had no URI") + .to_string(); + let cid = body["cid"] + .as_str() + .expect("Response had no CID") + .to_string(); + let rkey = uri + .split('/') + .last() + .expect("URI was malformed") + .to_string(); (uri, cid, rkey) } @@ -220,7 +263,11 @@ pub async fn create_account_and_login(client: &Client) -> (String, String) { "password": "password" }); - let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.server.createAccount", + base_url().await + )) .json(&payload) .send() .await @@ -231,7 +278,10 @@ pub async fn create_account_and_login(client: &Client) -> (String, String) { } let body: Value = res.json().await.expect("Invalid JSON"); - let access_jwt = body["accessJwt"].as_str().expect("No accessJwt").to_string(); + let access_jwt = body["accessJwt"] + .as_str() + .expect("No accessJwt") + .to_string(); let did = body["did"].as_str().expect("No did").to_string(); (access_jwt, did) } diff --git a/tests/identity.rs b/tests/identity.rs index 6b45c35..02b9771 100644 --- a/tests/identity.rs +++ b/tests/identity.rs @@ -1,9 +1,9 @@ mod common; use common::*; use reqwest::StatusCode; -use serde_json::{json, Value}; -use wiremock::{MockServer, Mock, ResponseTemplate}; +use serde_json::{Value, json}; use wiremock::matchers::{method, path}; +use wiremock::{Mock, MockServer, ResponseTemplate}; // #[tokio::test] // async fn test_resolve_handle() { @@ -23,7 +23,8 @@ use wiremock::matchers::{method, path}; #[tokio::test] async fn test_well_known_did() { let client = client(); - let res = client.get(format!("{}/.well-known/did.json", base_url().await)) + let res = client + .get(format!("{}/.well-known/did.json", base_url().await)) .send() .await .expect("Failed to send request"); @@ -71,7 +72,11 @@ async fn test_create_did_web_account_and_resolve() { "did": did }); - let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.server.createAccount", + base_url().await + )) .json(&payload) .send() .await @@ -79,13 +84,20 @@ async fn test_create_did_web_account_and_resolve() { if res.status() != StatusCode::OK { let status = res.status(); - let body: Value = res.json().await.unwrap_or(json!({"error": "could not parse body"})); + let body: Value = res + .json() + .await + .unwrap_or(json!({"error": "could not parse body"})); panic!("createAccount failed with status {}: {:?}", status, body); } - let body: Value = res.json().await.expect("createAccount response was not JSON"); + let body: Value = res + .json() + .await + .expect("createAccount response was not JSON"); assert_eq!(body["did"], did); - let res = client.get(format!("{}/u/{}/did.json", base_url().await, handle)) + let res = client + .get(format!("{}/u/{}/did.json", base_url().await, handle)) .send() .await .expect("Failed to fetch DID doc"); @@ -111,14 +123,22 @@ async fn test_create_account_duplicate_handle() { "password": "password" }); - let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.server.createAccount", + base_url().await + )) .json(&payload) .send() .await .expect("Failed to send request"); assert_eq!(res.status(), StatusCode::OK); - let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.server.createAccount", + base_url().await + )) .json(&payload) .send() .await @@ -143,7 +163,11 @@ async fn test_did_web_lifecycle() { "did": did }); - let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.server.createAccount", + base_url().await + )) .json(&create_payload) .send() .await @@ -162,7 +186,11 @@ async fn test_did_web_lifecycle() { "identifier": handle, "password": "password" }); - let res = client.post(format!("{}/xrpc/com.atproto.server.createSession", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.server.createSession", + base_url().await + )) .json(&login_payload) .send() .await diff --git a/tests/lifecycle.rs b/tests/lifecycle.rs index 5bb1406..f53cb44 100644 --- a/tests/lifecycle.rs +++ b/tests/lifecycle.rs @@ -1,10 +1,9 @@ mod common; use common::*; -use reqwest::{Client, StatusCode}; -use serde_json::{json, Value}; use chrono::Utc; -#[allow(unused_imports)] +use reqwest; +use serde_json::{Value, json}; use std::time::Duration; async fn setup_new_user(handle_prefix: &str) -> (String, String) { @@ -19,20 +18,36 @@ async fn setup_new_user(handle_prefix: &str) -> (String, String) { "email": email, "password": password }); - let create_res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await)) + let create_res = client + .post(format!( + "{}/xrpc/com.atproto.server.createAccount", + base_url().await + )) .json(&create_account_payload) .send() .await .expect("setup_new_user: Failed to send createAccount"); - if create_res.status() != StatusCode::OK { - panic!("setup_new_user: Failed to create account: {:?}", create_res.text().await); + if create_res.status() != reqwest::StatusCode::OK { + panic!( + "setup_new_user: Failed to create account: {:?}", + create_res.text().await + ); } - let create_body: Value = create_res.json().await.expect("setup_new_user: createAccount response was not JSON"); + let create_body: Value = create_res + .json() + .await + .expect("setup_new_user: createAccount response was not JSON"); - let new_did = create_body["did"].as_str().expect("setup_new_user: Response had no DID").to_string(); - let new_jwt = create_body["accessJwt"].as_str().expect("setup_new_user: Response had no accessJwt").to_string(); + let new_did = create_body["did"] + .as_str() + .expect("setup_new_user: Response had no DID") + .to_string(); + let new_jwt = create_body["accessJwt"] + .as_str() + .expect("setup_new_user: Response had no accessJwt") + .to_string(); (new_did, new_jwt) } @@ -59,35 +74,59 @@ async fn test_post_crud_lifecycle() { } }); - let create_res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + let create_res = client + .post(format!( + "{}/xrpc/com.atproto.repo.putRecord", + base_url().await + )) .bearer_auth(&jwt) .json(&create_payload) .send() .await .expect("Failed to send create request"); - assert_eq!(create_res.status(), StatusCode::OK, "Failed to create record"); - let create_body: Value = create_res.json().await.expect("create response was not JSON"); - let uri = create_body["uri"].as_str().unwrap(); + if create_res.status() != reqwest::StatusCode::OK { + let status = create_res.status(); + let body = create_res + .text() + .await + .unwrap_or_else(|_| "Could not get body".to_string()); + panic!( + "Failed to create record. Status: {}, Body: {}", + status, body + ); + } + let create_body: Value = create_res + .json() + .await + .expect("create response was not JSON"); + let uri = create_body["uri"].as_str().unwrap(); let params = [ ("repo", did.as_str()), ("collection", collection), ("rkey", &rkey), ]; - let get_res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + let get_res = client + .get(format!( + "{}/xrpc/com.atproto.repo.getRecord", + base_url().await + )) .query(¶ms) .send() .await .expect("Failed to send get request"); - assert_eq!(get_res.status(), StatusCode::OK, "Failed to get record after create"); + assert_eq!( + get_res.status(), + reqwest::StatusCode::OK, + "Failed to get record after create" + ); let get_body: Value = get_res.json().await.expect("get response was not JSON"); assert_eq!(get_body["uri"], uri); assert_eq!(get_body["value"]["text"], original_text); - let updated_text = "This post has been updated."; let update_payload = json!({ "repo": did, @@ -100,26 +139,46 @@ async fn test_post_crud_lifecycle() { } }); - let update_res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + let update_res = client + .post(format!( + "{}/xrpc/com.atproto.repo.putRecord", + base_url().await + )) .bearer_auth(&jwt) .json(&update_payload) .send() .await .expect("Failed to send update request"); - assert_eq!(update_res.status(), StatusCode::OK, "Failed to update record"); + assert_eq!( + update_res.status(), + reqwest::StatusCode::OK, + "Failed to update record" + ); - - let get_updated_res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + let get_updated_res = client + .get(format!( + "{}/xrpc/com.atproto.repo.getRecord", + base_url().await + )) .query(¶ms) .send() .await .expect("Failed to send get-after-update request"); - assert_eq!(get_updated_res.status(), StatusCode::OK, "Failed to get record after update"); - let get_updated_body: Value = get_updated_res.json().await.expect("get-updated response was not JSON"); - assert_eq!(get_updated_body["value"]["text"], updated_text, "Text was not updated"); - + assert_eq!( + get_updated_res.status(), + reqwest::StatusCode::OK, + "Failed to get record after update" + ); + let get_updated_body: Value = get_updated_res + .json() + .await + .expect("get-updated response was not JSON"); + assert_eq!( + get_updated_body["value"]["text"], updated_text, + "Text was not updated" + ); let delete_payload = json!({ "repo": did, @@ -127,23 +186,38 @@ async fn test_post_crud_lifecycle() { "rkey": rkey }); - let delete_res = client.post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await)) + let delete_res = client + .post(format!( + "{}/xrpc/com.atproto.repo.deleteRecord", + base_url().await + )) .bearer_auth(&jwt) .json(&delete_payload) .send() .await .expect("Failed to send delete request"); - assert_eq!(delete_res.status(), StatusCode::OK, "Failed to delete record"); + assert_eq!( + delete_res.status(), + reqwest::StatusCode::OK, + "Failed to delete record" + ); - - let get_deleted_res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + let get_deleted_res = client + .get(format!( + "{}/xrpc/com.atproto.repo.getRecord", + base_url().await + )) .query(¶ms) .send() .await .expect("Failed to send get-after-delete request"); - assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Record was found, but it should be deleted"); + assert_eq!( + get_deleted_res.status(), + reqwest::StatusCode::NOT_FOUND, + "Record was found, but it should be deleted" + ); } #[tokio::test] @@ -161,24 +235,39 @@ async fn test_record_update_conflict_lifecycle() { "displayName": "Original Name" } }); - let create_res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + let create_res = client + .post(format!( + "{}/xrpc/com.atproto.repo.putRecord", + base_url().await + )) .bearer_auth(&user_jwt) .json(&profile_payload) - .send().await.expect("create profile failed"); + .send() + .await + .expect("create profile failed"); - if create_res.status() != StatusCode::OK { + if create_res.status() != reqwest::StatusCode::OK { return; } - let get_res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + let get_res = client + .get(format!( + "{}/xrpc/com.atproto.repo.getRecord", + base_url().await + )) .query(&[ ("repo", &user_did), ("collection", &"app.bsky.actor.profile".to_string()), ("rkey", &"self".to_string()), ]) - .send().await.expect("getRecord failed"); + .send() + .await + .expect("getRecord failed"); let get_body: Value = get_res.json().await.expect("getRecord not json"); - let cid_v1 = get_body["cid"].as_str().expect("Profile v1 had no CID").to_string(); + let cid_v1 = get_body["cid"] + .as_str() + .expect("Profile v1 had no CID") + .to_string(); let update_payload_v2 = json!({ "repo": user_did, @@ -190,13 +279,26 @@ async fn test_record_update_conflict_lifecycle() { }, "swapCommit": cid_v1 // <-- Correctly point to v1 }); - let update_res_v2 = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + let update_res_v2 = client + .post(format!( + "{}/xrpc/com.atproto.repo.putRecord", + base_url().await + )) .bearer_auth(&user_jwt) .json(&update_payload_v2) - .send().await.expect("putRecord v2 failed"); - assert_eq!(update_res_v2.status(), StatusCode::OK, "v2 update failed"); + .send() + .await + .expect("putRecord v2 failed"); + assert_eq!( + update_res_v2.status(), + reqwest::StatusCode::OK, + "v2 update failed" + ); let update_body_v2: Value = update_res_v2.json().await.expect("v2 body not json"); - let cid_v2 = update_body_v2["cid"].as_str().expect("v2 response had no CID").to_string(); + let cid_v2 = update_body_v2["cid"] + .as_str() + .expect("v2 response had no CID") + .to_string(); let update_payload_v3_stale = json!({ "repo": user_did, @@ -208,14 +310,20 @@ async fn test_record_update_conflict_lifecycle() { }, "swapCommit": cid_v1 }); - let update_res_v3_stale = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + let update_res_v3_stale = client + .post(format!( + "{}/xrpc/com.atproto.repo.putRecord", + base_url().await + )) .bearer_auth(&user_jwt) .json(&update_payload_v3_stale) - .send().await.expect("putRecord v3 (stale) failed"); + .send() + .await + .expect("putRecord v3 (stale) failed"); assert_eq!( update_res_v3_stale.status(), - StatusCode::CONFLICT, + reqwest::StatusCode::CONFLICT, "Stale update did not cause a 409 Conflict" ); @@ -229,10 +337,233 @@ async fn test_record_update_conflict_lifecycle() { }, "swapCommit": cid_v2 // <-- Correct }); - let update_res_v3_good = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + let update_res_v3_good = client + .post(format!( + "{}/xrpc/com.atproto.repo.putRecord", + base_url().await + )) .bearer_auth(&user_jwt) .json(&update_payload_v3_good) - .send().await.expect("putRecord v3 (good) failed"); + .send() + .await + .expect("putRecord v3 (good) failed"); - assert_eq!(update_res_v3_good.status(), StatusCode::OK, "v3 (good) update failed"); + assert_eq!( + update_res_v3_good.status(), + reqwest::StatusCode::OK, + "v3 (good) update failed" + ); +} + +async fn create_post( + client: &reqwest::Client, + did: &str, + jwt: &str, + text: &str, +) -> (String, String) { + let collection = "app.bsky.feed.post"; + let rkey = format!("e2e_social_{}", Utc::now().timestamp_millis()); + let now = Utc::now().to_rfc3339(); + + let create_payload = json!({ + "repo": did, + "collection": collection, + "rkey": rkey, + "record": { + "$type": collection, + "text": text, + "createdAt": now + } + }); + + let create_res = client + .post(format!( + "{}/xrpc/com.atproto.repo.putRecord", + base_url().await + )) + .bearer_auth(jwt) + .json(&create_payload) + .send() + .await + .expect("Failed to send create post request"); + + assert_eq!( + create_res.status(), + reqwest::StatusCode::OK, + "Failed to create post record" + ); + let create_body: Value = create_res + .json() + .await + .expect("create post response was not JSON"); + let uri = create_body["uri"].as_str().unwrap().to_string(); + let cid = create_body["cid"].as_str().unwrap().to_string(); + (uri, cid) +} + +async fn create_follow( + client: &reqwest::Client, + follower_did: &str, + follower_jwt: &str, + followee_did: &str, +) -> (String, String) { + let collection = "app.bsky.graph.follow"; + let rkey = format!("e2e_follow_{}", Utc::now().timestamp_millis()); + let now = Utc::now().to_rfc3339(); + + let create_payload = json!({ + "repo": follower_did, + "collection": collection, + "rkey": rkey, + "record": { + "$type": collection, + "subject": followee_did, + "createdAt": now + } + }); + + let create_res = client + .post(format!( + "{}/xrpc/com.atproto.repo.putRecord", + base_url().await + )) + .bearer_auth(follower_jwt) + .json(&create_payload) + .send() + .await + .expect("Failed to send create follow request"); + + assert_eq!( + create_res.status(), + reqwest::StatusCode::OK, + "Failed to create follow record" + ); + let create_body: Value = create_res + .json() + .await + .expect("create follow response was not JSON"); + let uri = create_body["uri"].as_str().unwrap().to_string(); + let cid = create_body["cid"].as_str().unwrap().to_string(); + (uri, cid) +} + +#[tokio::test] +#[ignore] +async fn test_social_flow_lifecycle() { + let client = client(); + + let (alice_did, alice_jwt) = setup_new_user("alice-social").await; + let (bob_did, bob_jwt) = setup_new_user("bob-social").await; + + let (post1_uri, _) = create_post(&client, &alice_did, &alice_jwt, "Alice's first post!").await; + + create_follow(&client, &bob_did, &bob_jwt, &alice_did).await; + + tokio::time::sleep(Duration::from_secs(1)).await; + + let timeline_res_1 = client + .get(format!( + "{}/xrpc/app.bsky.feed.getTimeline", + base_url().await + )) + .bearer_auth(&bob_jwt) + .send() + .await + .expect("Failed to get timeline (1)"); + + assert_eq!( + timeline_res_1.status(), + reqwest::StatusCode::OK, + "Failed to get timeline (1)" + ); + let timeline_body_1: Value = timeline_res_1.json().await.expect("Timeline (1) not JSON"); + let feed_1 = timeline_body_1["feed"].as_array().unwrap(); + assert_eq!(feed_1.len(), 1, "Timeline should have 1 post"); + assert_eq!( + feed_1[0]["post"]["uri"], post1_uri, + "Post URI mismatch in timeline (1)" + ); + + let (post2_uri, _) = create_post( + &client, + &alice_did, + &alice_jwt, + "Alice's second post, so exciting!", + ) + .await; + + tokio::time::sleep(Duration::from_secs(1)).await; + + let timeline_res_2 = client + .get(format!( + "{}/xrpc/app.bsky.feed.getTimeline", + base_url().await + )) + .bearer_auth(&bob_jwt) + .send() + .await + .expect("Failed to get timeline (2)"); + + assert_eq!( + timeline_res_2.status(), + reqwest::StatusCode::OK, + "Failed to get timeline (2)" + ); + let timeline_body_2: Value = timeline_res_2.json().await.expect("Timeline (2) not JSON"); + let feed_2 = timeline_body_2["feed"].as_array().unwrap(); + assert_eq!(feed_2.len(), 2, "Timeline should have 2 posts"); + assert_eq!( + feed_2[0]["post"]["uri"], post2_uri, + "Post 2 should be first" + ); + assert_eq!( + feed_2[1]["post"]["uri"], post1_uri, + "Post 1 should be second" + ); + + let delete_payload = json!({ + "repo": alice_did, + "collection": "app.bsky.feed.post", + "rkey": post1_uri.split('/').last().unwrap() + }); + let delete_res = client + .post(format!( + "{}/xrpc/com.atproto.repo.deleteRecord", + base_url().await + )) + .bearer_auth(&alice_jwt) + .json(&delete_payload) + .send() + .await + .expect("Failed to send delete request"); + assert_eq!( + delete_res.status(), + reqwest::StatusCode::OK, + "Failed to delete record" + ); + + tokio::time::sleep(Duration::from_secs(1)).await; + + let timeline_res_3 = client + .get(format!( + "{}/xrpc/app.bsky.feed.getTimeline", + base_url().await + )) + .bearer_auth(&bob_jwt) + .send() + .await + .expect("Failed to get timeline (3)"); + + assert_eq!( + timeline_res_3.status(), + reqwest::StatusCode::OK, + "Failed to get timeline (3)" + ); + let timeline_body_3: Value = timeline_res_3.json().await.expect("Timeline (3) not JSON"); + let feed_3 = timeline_body_3["feed"].as_array().unwrap(); + assert_eq!(feed_3.len(), 1, "Timeline should have 1 post after delete"); + assert_eq!( + feed_3[0]["post"]["uri"], post2_uri, + "Only post 2 should remain" + ); } diff --git a/tests/proxy.rs b/tests/proxy.rs index 37596a0..02d4c87 100644 --- a/tests/proxy.rs +++ b/tests/proxy.rs @@ -1,17 +1,15 @@ mod common; -use axum::{ - routing::any, - Router, - extract::Request, - http::StatusCode, -}; -use tokio::net::TcpListener; +use axum::{Router, extract::Request, http::StatusCode, routing::any}; +use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; use reqwest::Client; use std::sync::Arc; -use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; +use tokio::net::TcpListener; -async fn spawn_mock_upstream() -> (String, tokio::sync::mpsc::Receiver<(String, String, Option)>) { +async fn spawn_mock_upstream() -> ( + String, + tokio::sync::mpsc::Receiver<(String, String, Option)>, +) { let (tx, rx) = tokio::sync::mpsc::channel(10); let tx = Arc::new(tx); @@ -20,7 +18,9 @@ async fn spawn_mock_upstream() -> (String, tokio::sync::mpsc::Receiver<(String, async move { let method = req.method().to_string(); let uri = req.uri().to_string(); - let auth = req.headers().get("Authorization") + let auth = req + .headers() + .get("Authorization") .and_then(|h| h.to_str().ok()) .map(|s| s.to_string()); @@ -45,7 +45,8 @@ async fn test_proxy_via_header() { let (upstream_url, mut rx) = spawn_mock_upstream().await; let client = Client::new(); - let res = client.get(format!("{}/xrpc/com.example.test", app_url)) + let res = client + .get(format!("{}/xrpc/com.example.test", app_url)) .header("atproto-proxy", &upstream_url) .header("Authorization", "Bearer test-token") .send() @@ -65,12 +66,15 @@ async fn test_proxy_via_header() { async fn test_proxy_via_env_var() { let (upstream_url, mut rx) = spawn_mock_upstream().await; - unsafe { std::env::set_var("APPVIEW_URL", &upstream_url); } + unsafe { + std::env::set_var("APPVIEW_URL", &upstream_url); + } let app_url = common::base_url().await; let client = Client::new(); - let res = client.get(format!("{}/xrpc/com.example.envtest", app_url)) + let res = client + .get(format!("{}/xrpc/com.example.envtest", app_url)) .send() .await .unwrap(); @@ -85,12 +89,15 @@ async fn test_proxy_via_env_var() { #[tokio::test] #[ignore] async fn test_proxy_missing_config() { - unsafe { std::env::remove_var("APPVIEW_URL"); } + unsafe { + std::env::remove_var("APPVIEW_URL"); + } let app_url = common::base_url().await; let client = Client::new(); - let res = client.get(format!("{}/xrpc/com.example.fail", app_url)) + let res = client + .get(format!("{}/xrpc/com.example.fail", app_url)) .send() .await .unwrap(); @@ -106,7 +113,8 @@ async fn test_proxy_auth_signing() { let (access_jwt, did) = common::create_account_and_login(&client).await; - let res = client.get(format!("{}/xrpc/com.example.signed", app_url)) + let res = client + .get(format!("{}/xrpc/com.example.signed", app_url)) .header("atproto-proxy", &upstream_url) .header("Authorization", format!("Bearer {}", access_jwt)) .send() diff --git a/tests/repo.rs b/tests/repo.rs index df945a2..ead484b 100644 --- a/tests/repo.rs +++ b/tests/repo.rs @@ -1,9 +1,9 @@ mod common; use common::*; -use reqwest::{header, StatusCode}; -use serde_json::{json, Value}; use chrono::Utc; +use reqwest::{StatusCode, header}; +use serde_json::{Value, json}; #[tokio::test] #[ignore] @@ -15,7 +15,11 @@ async fn test_get_record() { ("rkey", "self"), ]; - let res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.getRecord", + base_url().await + )) .query(¶ms) .send() .await @@ -36,7 +40,11 @@ async fn test_get_record_not_found() { ("rkey", "nonexistent"), ]; - let res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.getRecord", + base_url().await + )) .query(¶ms) .send() .await @@ -50,7 +58,11 @@ async fn test_get_record_not_found() { #[tokio::test] async fn test_upload_blob_no_auth() { let client = client(); - let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.uploadBlob", + base_url().await + )) .header(header::CONTENT_TYPE, "text/plain") .body("no auth") .send() @@ -66,7 +78,11 @@ async fn test_upload_blob_no_auth() { async fn test_upload_blob_success() { let client = client(); let (token, _) = create_account_and_login(&client).await; - let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.uploadBlob", + base_url().await + )) .header(header::CONTENT_TYPE, "text/plain") .bearer_auth(token) .body("This is our blob data") @@ -90,7 +106,11 @@ async fn test_put_record_no_auth() { "record": {} }); - let res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.putRecord", + base_url().await + )) .json(&payload) .send() .await @@ -118,7 +138,11 @@ async fn test_put_record_success() { } }); - let res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.putRecord", + base_url().await + )) .bearer_auth(token) .json(&payload) .send() @@ -135,23 +159,33 @@ async fn test_put_record_success() { #[ignore] async fn test_get_record_missing_params() { let client = client(); - let params = [ - ("repo", "did:plc:12345"), - ]; + let params = [("repo", "did:plc:12345")]; - let res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.getRecord", + base_url().await + )) .query(¶ms) .send() .await .expect("Failed to send request"); - assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Expected 400 for missing params"); + assert_eq!( + res.status(), + StatusCode::BAD_REQUEST, + "Expected 400 for missing params" + ); } #[tokio::test] async fn test_upload_blob_bad_token() { let client = client(); - let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.uploadBlob", + base_url().await + )) .header(header::CONTENT_TYPE, "text/plain") .bearer_auth(BAD_AUTH_TOKEN) .body("This is our blob data") @@ -181,14 +215,22 @@ async fn test_put_record_mismatched_repo() { } }); - let res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.putRecord", + base_url().await + )) .bearer_auth(token) .json(&payload) .send() .await .expect("Failed to send request"); - assert_eq!(res.status(), StatusCode::FORBIDDEN, "Expected 403 for mismatched repo and auth"); + assert_eq!( + res.status(), + StatusCode::FORBIDDEN, + "Expected 403 for mismatched repo and auth" + ); } #[tokio::test] @@ -207,21 +249,33 @@ async fn test_put_record_invalid_schema() { } }); - let res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.putRecord", + base_url().await + )) .bearer_auth(token) .json(&payload) .send() .await .expect("Failed to send request"); - assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Expected 400 for invalid record schema"); + assert_eq!( + res.status(), + StatusCode::BAD_REQUEST, + "Expected 400 for invalid record schema" + ); } #[tokio::test] async fn test_upload_blob_unsupported_mime_type() { let client = client(); let (token, _) = create_account_and_login(&client).await; - let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.uploadBlob", + base_url().await + )) .header(header::CONTENT_TYPE, "application/xml") .bearer_auth(token) .body("not an image") @@ -242,7 +296,11 @@ async fn test_list_records() { ("collection", "app.bsky.feed.post"), ("limit", "10"), ]; - let res = client.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) .query(¶ms) .send() .await @@ -255,10 +313,12 @@ async fn test_list_records() { async fn test_describe_repo() { let client = client(); let (_, did) = create_account_and_login(&client).await; - let params = [ - ("repo", did.as_str()), - ]; - let res = client.get(format!("{}/xrpc/com.atproto.repo.describeRepo", base_url().await)) + let params = [("repo", did.as_str())]; + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.describeRepo", + base_url().await + )) .query(¶ms) .send() .await @@ -282,7 +342,11 @@ async fn test_create_record_success_with_generated_rkey() { } }); - let res = client.post(format!("{}/xrpc/com.atproto.repo.createRecord", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.createRecord", + base_url().await + )) .json(&payload) .bearer_auth(token) .send() @@ -313,7 +377,11 @@ async fn test_create_record_success_with_provided_rkey() { } }); - let res = client.post(format!("{}/xrpc/com.atproto.repo.createRecord", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.createRecord", + base_url().await + )) .json(&payload) .bearer_auth(token) .send() @@ -322,7 +390,10 @@ async fn test_create_record_success_with_provided_rkey() { assert_eq!(res.status(), StatusCode::OK); let body: Value = res.json().await.expect("Response was not valid JSON"); - assert_eq!(body["uri"], format!("at://{}/app.bsky.feed.post/{}", did, rkey)); + assert_eq!( + body["uri"], + format!("at://{}/app.bsky.feed.post/{}", did, rkey) + ); // assert_eq!(body["cid"], "bafyreihy"); } @@ -336,7 +407,11 @@ async fn test_delete_record() { "collection": "app.bsky.feed.post", "rkey": "some_post_to_delete" }); - let res = client.post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.deleteRecord", + base_url().await + )) .bearer_auth(token) .json(&payload) .send() diff --git a/tests/server.rs b/tests/server.rs index e7fbca3..d36bac3 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -2,12 +2,13 @@ mod common; use common::*; use reqwest::StatusCode; -use serde_json::{json, Value}; +use serde_json::{Value, json}; #[tokio::test] async fn test_health() { let client = client(); - let res = client.get(format!("{}/health", base_url().await)) + let res = client + .get(format!("{}/health", base_url().await)) .send() .await .expect("Failed to send request"); @@ -19,7 +20,11 @@ async fn test_health() { #[tokio::test] async fn test_describe_server() { let client = client(); - let res = client.get(format!("{}/xrpc/com.atproto.server.describeServer", base_url().await)) + let res = client + .get(format!( + "{}/xrpc/com.atproto.server.describeServer", + base_url().await + )) .send() .await .expect("Failed to send request"); @@ -39,7 +44,11 @@ async fn test_create_session() { "email": format!("{}@example.com", handle), "password": "password" }); - let _ = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await)) + let _ = client + .post(format!( + "{}/xrpc/com.atproto.server.createAccount", + base_url().await + )) .json(&payload) .send() .await; @@ -49,7 +58,11 @@ async fn test_create_session() { "password": "password" }); - let res = client.post(format!("{}/xrpc/com.atproto.server.createSession", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.server.createSession", + base_url().await + )) .json(&payload) .send() .await @@ -67,14 +80,21 @@ async fn test_create_session_missing_identifier() { "password": "password" }); - let res = client.post(format!("{}/xrpc/com.atproto.server.createSession", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.server.createSession", + base_url().await + )) .json(&payload) .send() .await .expect("Failed to send request"); - assert!(res.status() == StatusCode::BAD_REQUEST || res.status() == StatusCode::UNPROCESSABLE_ENTITY, - "Expected 400 or 422 for missing identifier, got {}", res.status()); + assert!( + res.status() == StatusCode::BAD_REQUEST || res.status() == StatusCode::UNPROCESSABLE_ENTITY, + "Expected 400 or 422 for missing identifier, got {}", + res.status() + ); } #[tokio::test] @@ -86,19 +106,31 @@ async fn test_create_account_invalid_handle() { "password": "password" }); - let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.server.createAccount", + base_url().await + )) .json(&payload) .send() .await .expect("Failed to send request"); - assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Expected 400 for invalid handle chars"); + assert_eq!( + res.status(), + StatusCode::BAD_REQUEST, + "Expected 400 for invalid handle chars" + ); } #[tokio::test] async fn test_get_session() { let client = client(); - let res = client.get(format!("{}/xrpc/com.atproto.server.getSession", base_url().await)) + let res = client + .get(format!( + "{}/xrpc/com.atproto.server.getSession", + base_url().await + )) .bearer_auth(AUTH_TOKEN) .send() .await @@ -117,7 +149,11 @@ async fn test_refresh_session() { "email": format!("{}@example.com", handle), "password": "password" }); - let _ = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await)) + let _ = client + .post(format!( + "{}/xrpc/com.atproto.server.createAccount", + base_url().await + )) .json(&payload) .send() .await; @@ -126,7 +162,11 @@ async fn test_refresh_session() { "identifier": handle, "password": "password" }); - let res = client.post(format!("{}/xrpc/com.atproto.server.createSession", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.server.createSession", + base_url().await + )) .json(&login_payload) .send() .await @@ -134,10 +174,20 @@ async fn test_refresh_session() { assert_eq!(res.status(), StatusCode::OK); let body: Value = res.json().await.expect("Invalid JSON"); - let refresh_jwt = body["refreshJwt"].as_str().expect("No refreshJwt").to_string(); - let access_jwt = body["accessJwt"].as_str().expect("No accessJwt").to_string(); + let refresh_jwt = body["refreshJwt"] + .as_str() + .expect("No refreshJwt") + .to_string(); + let access_jwt = body["accessJwt"] + .as_str() + .expect("No accessJwt") + .to_string(); - let res = client.post(format!("{}/xrpc/com.atproto.server.refreshSession", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.server.refreshSession", + base_url().await + )) .bearer_auth(&refresh_jwt) .send() .await @@ -154,7 +204,11 @@ async fn test_refresh_session() { #[tokio::test] async fn test_delete_session() { let client = client(); - let res = client.post(format!("{}/xrpc/com.atproto.server.deleteSession", base_url().await)) + let res = client + .post(format!( + "{}/xrpc/com.atproto.server.deleteSession", + base_url().await + )) .bearer_auth(AUTH_TOKEN) .send() .await diff --git a/tests/sync.rs b/tests/sync.rs index 1a334ec..a58c125 100644 --- a/tests/sync.rs +++ b/tests/sync.rs @@ -6,10 +6,12 @@ use reqwest::StatusCode; #[ignore] async fn test_get_repo() { let client = client(); - let params = [ - ("did", AUTH_DID), - ]; - let res = client.get(format!("{}/xrpc/com.atproto.sync.getRepo", base_url().await)) + let params = [("did", AUTH_DID)]; + let res = client + .get(format!( + "{}/xrpc/com.atproto.sync.getRepo", + base_url().await + )) .query(¶ms) .send() .await @@ -26,7 +28,11 @@ async fn test_get_blocks() { ("did", AUTH_DID), // "cids" would be a list of CIDs ]; - let res = client.get(format!("{}/xrpc/com.atproto.sync.getBlocks", base_url().await)) + let res = client + .get(format!( + "{}/xrpc/com.atproto.sync.getBlocks", + base_url().await + )) .query(¶ms) .send() .await