Format and split big files into smaller ones

This commit is contained in:
Lewis
2025-12-07 11:47:38 +02:00
parent 7b90694066
commit e2cc51f0b1
34 changed files with 3067 additions and 1776 deletions

View File

@@ -1,424 +0,0 @@
use axum::{
extract::{State, Path},
Json,
response::{IntoResponse, Response},
http::StatusCode,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use crate::state::AppState;
use sqlx::Row;
use bcrypt::{hash, DEFAULT_COST};
use tracing::{info, error};
use jacquard_repo::{mst::Mst, commit::Commit, storage::BlockStore};
use jacquard::types::{string::Tid, did::Did, integer::LimitedU32};
use std::sync::Arc;
use k256::SecretKey;
use rand::rngs::OsRng;
use base64::Engine;
use reqwest;
#[derive(Deserialize)]
pub struct CreateAccountInput {
pub handle: String,
pub email: String,
pub password: String,
#[serde(rename = "inviteCode")]
pub invite_code: Option<String>,
pub did: Option<String>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateAccountOutput {
pub access_jwt: String,
pub refresh_jwt: String,
pub handle: String,
pub did: String,
}
pub async fn create_account(
State(state): State<AppState>,
Json(input): Json<CreateAccountInput>,
) -> Response {
info!("create_account hit: {}", input.handle);
if input.handle.contains('!') || input.handle.contains('@') {
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}))).into_response();
}
let did = if let Some(d) = &input.did {
if d.trim().is_empty() {
format!("did:plc:{}", uuid::Uuid::new_v4())
} else {
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
if let Err(e) = verify_did_web(d, &hostname, &input.handle).await {
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidDid", "message": e}))).into_response();
}
d.clone()
}
} else {
format!("did:plc:{}", uuid::Uuid::new_v4())
};
let mut tx = match state.db.begin().await {
Ok(tx) => tx,
Err(e) => {
error!("Error starting transaction: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
};
let exists_query = sqlx::query("SELECT 1 FROM users WHERE handle = $1")
.bind(&input.handle)
.fetch_optional(&mut *tx)
.await;
match exists_query {
Ok(Some(_)) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "HandleTaken", "message": "Handle already taken"}))).into_response(),
Err(e) => {
error!("Error checking handle: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
Ok(None) => {}
}
if let Some(code) = &input.invite_code {
let invite_query = sqlx::query("SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE")
.bind(code)
.fetch_optional(&mut *tx)
.await;
match invite_query {
Ok(Some(row)) => {
let uses: i32 = row.get("available_uses");
if uses <= 0 {
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response();
}
let update_invite = sqlx::query("UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1")
.bind(code)
.execute(&mut *tx)
.await;
if let Err(e) = update_invite {
error!("Error updating invite code: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
},
Ok(None) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"}))).into_response(),
Err(e) => {
error!("Error checking invite code: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
}
}
let password_hash = match hash(&input.password, DEFAULT_COST) {
Ok(h) => h,
Err(e) => {
error!("Error hashing password: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
};
let user_insert = sqlx::query("INSERT INTO users (handle, email, did, password_hash) VALUES ($1, $2, $3, $4) RETURNING id")
.bind(&input.handle)
.bind(&input.email)
.bind(&did)
.bind(&password_hash)
.fetch_one(&mut *tx)
.await;
let user_id: uuid::Uuid = match user_insert {
Ok(row) => row.get("id"),
Err(e) => {
error!("Error inserting user: {:?}", e);
// TODO: Check for unique constraint violation on email/did specifically
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
};
let secret_key = SecretKey::random(&mut OsRng);
let secret_key_bytes = secret_key.to_bytes();
let key_insert = sqlx::query("INSERT INTO user_keys (user_id, key_bytes) VALUES ($1, $2)")
.bind(user_id)
.bind(&secret_key_bytes[..])
.execute(&mut *tx)
.await;
if let Err(e) = key_insert {
error!("Error inserting user key: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
let mst = Mst::new(Arc::new(state.block_store.clone()));
let mst_root = match mst.root().await {
Ok(c) => c,
Err(e) => {
error!("Error creating MST root: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
};
let did_obj = match Did::new(&did) {
Ok(d) => d,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(),
};
let rev = Tid::now(LimitedU32::MIN);
let commit = Commit::new_unsigned(
did_obj,
mst_root,
rev,
None
);
let commit_bytes = match commit.to_cbor() {
Ok(b) => b,
Err(e) => {
error!("Error serializing genesis commit: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
};
let commit_cid = match state.block_store.put(&commit_bytes).await {
Ok(c) => c,
Err(e) => {
error!("Error saving genesis commit: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
};
let repo_insert = sqlx::query("INSERT INTO repos (user_id, repo_root_cid) VALUES ($1, $2)")
.bind(user_id)
.bind(commit_cid.to_string())
.execute(&mut *tx)
.await;
if let Err(e) = repo_insert {
error!("Error initializing repo: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
if let Some(code) = &input.invite_code {
let use_insert = sqlx::query("INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)")
.bind(code)
.bind(user_id)
.execute(&mut *tx)
.await;
if let Err(e) = use_insert {
error!("Error recording invite usage: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
}
let access_jwt = crate::auth::create_access_token(&did, &secret_key_bytes[..]).map_err(|e| {
error!("Error creating access token: {:?}", e);
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()
});
let access_jwt = match access_jwt {
Ok(t) => t,
Err(r) => return r,
};
let refresh_jwt = crate::auth::create_refresh_token(&did, &secret_key_bytes[..]).map_err(|e| {
error!("Error creating refresh token: {:?}", e);
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()
});
let refresh_jwt = match refresh_jwt {
Ok(t) => t,
Err(r) => return r,
};
let session_insert = sqlx::query("INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)")
.bind(&access_jwt)
.bind(&refresh_jwt)
.bind(&did)
.execute(&mut *tx)
.await;
if let Err(e) = session_insert {
error!("Error inserting session: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
if let Err(e) = tx.commit().await {
error!("Error committing transaction: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
(StatusCode::OK, Json(CreateAccountOutput {
access_jwt,
refresh_jwt,
handle: input.handle,
did,
})).into_response()
}
fn get_jwk(key_bytes: &[u8]) -> serde_json::Value {
use k256::elliptic_curve::sec1::ToEncodedPoint;
let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length");
let public_key = secret_key.public_key();
let encoded = public_key.to_encoded_point(false);
let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap());
let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap());
json!({
"kty": "EC",
"crv": "secp256k1",
"x": x,
"y": y
})
}
pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
// Kinda for local dev, encode hostname if it contains port
let did = if hostname.contains(':') {
format!("did:web:{}", hostname.replace(':', "%3A"))
} else {
format!("did:web:{}", hostname)
};
Json(json!({
"@context": ["https://www.w3.org/ns/did/v1"],
"id": did,
"service": [{
"id": "#atproto_pds",
"type": "AtprotoPersonalDataServer",
"serviceEndpoint": format!("https://{}", hostname)
}]
}))
}
pub async fn user_did_doc(
State(state): State<AppState>,
Path(handle): Path<String>,
) -> Response {
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1")
.bind(&handle)
.fetch_optional(&state.db)
.await;
let (user_id, did) = match user {
Ok(Some(row)) => {
let id: uuid::Uuid = row.get("id");
let d: String = row.get("did");
(id, d)
},
Ok(None) => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(),
Err(e) => {
error!("DB Error: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()
},
};
if !did.starts_with("did:web:") {
return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "User is not did:web"}))).into_response();
}
let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1")
.bind(user_id)
.fetch_optional(&state.db)
.await;
let key_bytes: Vec<u8> = match key_row {
Ok(Some(row)) => row.get("key_bytes"),
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(),
};
let jwk = get_jwk(&key_bytes);
Json(json!({
"@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
"id": did,
"alsoKnownAs": [format!("at://{}", handle)],
"verificationMethod": [{
"id": format!("{}#atproto", did),
"type": "JsonWebKey2020",
"controller": did,
"publicKeyJwk": jwk
}],
"service": [{
"id": "#atproto_pds",
"type": "AtprotoPersonalDataServer",
"serviceEndpoint": format!("https://{}", hostname)
}]
})).into_response()
}
async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
let expected_prefix = if hostname.contains(':') {
format!("did:web:{}", hostname.replace(':', "%3A"))
} else {
format!("did:web:{}", hostname)
};
if did.starts_with(&expected_prefix) {
let suffix = &did[expected_prefix.len()..];
let expected_suffix = format!(":u:{}", handle);
if suffix == expected_suffix {
Ok(())
} else {
Err(format!("Invalid DID path for this PDS. Expected {}", expected_suffix))
}
} else {
let parts: Vec<&str> = did.split(':').collect();
if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
return Err("Invalid did:web format".into());
}
let domain_segment = parts[2];
let domain = domain_segment.replace("%3A", ":");
let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
"http"
} else {
"https"
};
let url = if parts.len() == 3 {
format!("{}://{}/.well-known/did.json", scheme, domain)
} else {
let path = parts[3..].join("/");
format!("{}://{}/{}/did.json", scheme, domain, path)
};
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(5))
.build()
.map_err(|e| format!("Failed to create client: {}", e))?;
let resp = client.get(&url).send().await
.map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
if !resp.status().is_success() {
return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
}
let doc: serde_json::Value = resp.json().await
.map_err(|e| format!("Failed to parse DID doc: {}", e))?;
let services = doc["service"].as_array()
.ok_or("No services found in DID doc")?;
let pds_endpoint = format!("https://{}", hostname);
let has_valid_service = services.iter().any(|s| {
s["type"] == "AtprotoPersonalDataServer" &&
s["serviceEndpoint"] == pds_endpoint
});
if has_valid_service {
Ok(())
} else {
Err(format!("DID document does not list this PDS ({}) as AtprotoPersonalDataServer", pds_endpoint))
}
}
}

355
src/api/identity/account.rs Normal file
View File

@@ -0,0 +1,355 @@
use super::did::verify_did_web;
use crate::state::AppState;
use axum::{
Json,
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
};
use bcrypt::{DEFAULT_COST, hash};
use jacquard::types::{did::Did, integer::LimitedU32, string::Tid};
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
use k256::SecretKey;
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};
use serde_json::json;
use sqlx::Row;
use std::sync::Arc;
use tracing::{error, info};
#[derive(Deserialize)]
pub struct CreateAccountInput {
pub handle: String,
pub email: String,
pub password: String,
#[serde(rename = "inviteCode")]
pub invite_code: Option<String>,
pub did: Option<String>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateAccountOutput {
pub access_jwt: String,
pub refresh_jwt: String,
pub handle: String,
pub did: String,
}
pub async fn create_account(
State(state): State<AppState>,
Json(input): Json<CreateAccountInput>,
) -> Response {
info!("create_account hit: {}", input.handle);
if input.handle.contains('!') || input.handle.contains('@') {
return (
StatusCode::BAD_REQUEST,
Json(
json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}),
),
)
.into_response();
}
let did = if let Some(d) = &input.did {
if d.trim().is_empty() {
format!("did:plc:{}", uuid::Uuid::new_v4())
} else {
let hostname =
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
if let Err(e) = verify_did_web(d, &hostname, &input.handle).await {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "InvalidDid", "message": e})),
)
.into_response();
}
d.clone()
}
} else {
format!("did:plc:{}", uuid::Uuid::new_v4())
};
let mut tx = match state.db.begin().await {
Ok(tx) => tx,
Err(e) => {
error!("Error starting transaction: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let exists_query = sqlx::query("SELECT 1 FROM users WHERE handle = $1")
.bind(&input.handle)
.fetch_optional(&mut *tx)
.await;
match exists_query {
Ok(Some(_)) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "HandleTaken", "message": "Handle already taken"})),
)
.into_response();
}
Err(e) => {
error!("Error checking handle: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
Ok(None) => {}
}
if let Some(code) = &input.invite_code {
let invite_query =
sqlx::query("SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE")
.bind(code)
.fetch_optional(&mut *tx)
.await;
match invite_query {
Ok(Some(row)) => {
let uses: i32 = row.get("available_uses");
if uses <= 0 {
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response();
}
let update_invite = sqlx::query(
"UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1",
)
.bind(code)
.execute(&mut *tx)
.await;
if let Err(e) = update_invite {
error!("Error updating invite code: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
}
Ok(None) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"})),
)
.into_response();
}
Err(e) => {
error!("Error checking invite code: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
}
}
let password_hash = match hash(&input.password, DEFAULT_COST) {
Ok(h) => h,
Err(e) => {
error!("Error hashing password: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let user_insert = sqlx::query("INSERT INTO users (handle, email, did, password_hash) VALUES ($1, $2, $3, $4) RETURNING id")
.bind(&input.handle)
.bind(&input.email)
.bind(&did)
.bind(&password_hash)
.fetch_one(&mut *tx)
.await;
let user_id: uuid::Uuid = match user_insert {
Ok(row) => row.get("id"),
Err(e) => {
error!("Error inserting user: {:?}", e);
// TODO: Check for unique constraint violation on email/did specifically
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let secret_key = SecretKey::random(&mut OsRng);
let secret_key_bytes = secret_key.to_bytes();
let key_insert = sqlx::query("INSERT INTO user_keys (user_id, key_bytes) VALUES ($1, $2)")
.bind(user_id)
.bind(&secret_key_bytes[..])
.execute(&mut *tx)
.await;
if let Err(e) = key_insert {
error!("Error inserting user key: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
let mst = Mst::new(Arc::new(state.block_store.clone()));
let mst_root = match mst.root().await {
Ok(c) => c,
Err(e) => {
error!("Error creating MST root: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let did_obj = match Did::new(&did) {
Ok(d) => d,
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Invalid DID"})),
)
.into_response();
}
};
let rev = Tid::now(LimitedU32::MIN);
let commit = Commit::new_unsigned(did_obj, mst_root, rev, None);
let commit_bytes = match commit.to_cbor() {
Ok(b) => b,
Err(e) => {
error!("Error serializing genesis commit: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let commit_cid = match state.block_store.put(&commit_bytes).await {
Ok(c) => c,
Err(e) => {
error!("Error saving genesis commit: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let repo_insert = sqlx::query("INSERT INTO repos (user_id, repo_root_cid) VALUES ($1, $2)")
.bind(user_id)
.bind(commit_cid.to_string())
.execute(&mut *tx)
.await;
if let Err(e) = repo_insert {
error!("Error initializing repo: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
if let Some(code) = &input.invite_code {
let use_insert =
sqlx::query("INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)")
.bind(code)
.bind(user_id)
.execute(&mut *tx)
.await;
if let Err(e) = use_insert {
error!("Error recording invite usage: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
}
let access_jwt = crate::auth::create_access_token(&did, &secret_key_bytes[..]).map_err(|e| {
error!("Error creating access token: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response()
});
let access_jwt = match access_jwt {
Ok(t) => t,
Err(r) => return r,
};
let refresh_jwt = crate::auth::create_refresh_token(&did, &secret_key_bytes[..]).map_err(|e| {
error!("Error creating refresh token: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response()
});
let refresh_jwt = match refresh_jwt {
Ok(t) => t,
Err(r) => return r,
};
let session_insert =
sqlx::query("INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)")
.bind(&access_jwt)
.bind(&refresh_jwt)
.bind(&did)
.execute(&mut *tx)
.await;
if let Err(e) = session_insert {
error!("Error inserting session: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
if let Err(e) = tx.commit().await {
error!("Error committing transaction: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
(
StatusCode::OK,
Json(CreateAccountOutput {
access_jwt,
refresh_jwt,
handle: input.handle,
did,
}),
)
.into_response()
}

201
src/api/identity/did.rs Normal file
View File

@@ -0,0 +1,201 @@
use crate::state::AppState;
use axum::{
Json,
extract::{Path, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use base64::Engine;
use k256::SecretKey;
use k256::elliptic_curve::sec1::ToEncodedPoint;
use reqwest;
use serde_json::json;
use sqlx::Row;
use tracing::error;
pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value {
let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length");
let public_key = secret_key.public_key();
let encoded = public_key.to_encoded_point(false);
let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap());
let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap());
json!({
"kty": "EC",
"crv": "secp256k1",
"x": x,
"y": y
})
}
pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
// Kinda for local dev, encode hostname if it contains port
let did = if hostname.contains(':') {
format!("did:web:{}", hostname.replace(':', "%3A"))
} else {
format!("did:web:{}", hostname)
};
Json(json!({
"@context": ["https://www.w3.org/ns/did/v1"],
"id": did,
"service": [{
"id": "#atproto_pds",
"type": "AtprotoPersonalDataServer",
"serviceEndpoint": format!("https://{}", hostname)
}]
}))
}
pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1")
.bind(&handle)
.fetch_optional(&state.db)
.await;
let (user_id, did) = match user {
Ok(Some(row)) => {
let id: uuid::Uuid = row.get("id");
let d: String = row.get("did");
(id, d)
}
Ok(None) => {
return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
}
Err(e) => {
error!("DB Error: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
if !did.starts_with("did:web:") {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "NotFound", "message": "User is not did:web"})),
)
.into_response();
}
let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1")
.bind(user_id)
.fetch_optional(&state.db)
.await;
let key_bytes: Vec<u8> = match key_row {
Ok(Some(row)) => row.get("key_bytes"),
_ => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let jwk = get_jwk(&key_bytes);
Json(json!({
"@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
"id": did,
"alsoKnownAs": [format!("at://{}", handle)],
"verificationMethod": [{
"id": format!("{}#atproto", did),
"type": "JsonWebKey2020",
"controller": did,
"publicKeyJwk": jwk
}],
"service": [{
"id": "#atproto_pds",
"type": "AtprotoPersonalDataServer",
"serviceEndpoint": format!("https://{}", hostname)
}]
})).into_response()
}
pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
let expected_prefix = if hostname.contains(':') {
format!("did:web:{}", hostname.replace(':', "%3A"))
} else {
format!("did:web:{}", hostname)
};
if did.starts_with(&expected_prefix) {
let suffix = &did[expected_prefix.len()..];
let expected_suffix = format!(":u:{}", handle);
if suffix == expected_suffix {
Ok(())
} else {
Err(format!(
"Invalid DID path for this PDS. Expected {}",
expected_suffix
))
}
} else {
let parts: Vec<&str> = did.split(':').collect();
if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
return Err("Invalid did:web format".into());
}
let domain_segment = parts[2];
let domain = domain_segment.replace("%3A", ":");
let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
"http"
} else {
"https"
};
let url = if parts.len() == 3 {
format!("{}://{}/.well-known/did.json", scheme, domain)
} else {
let path = parts[3..].join("/");
format!("{}://{}/{}/did.json", scheme, domain, path)
};
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(5))
.build()
.map_err(|e| format!("Failed to create client: {}", e))?;
let resp = client
.get(&url)
.send()
.await
.map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
if !resp.status().is_success() {
return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
}
let doc: serde_json::Value = resp
.json()
.await
.map_err(|e| format!("Failed to parse DID doc: {}", e))?;
let services = doc["service"]
.as_array()
.ok_or("No services found in DID doc")?;
let pds_endpoint = format!("https://{}", hostname);
let has_valid_service = services.iter().any(|s| {
s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
});
if has_valid_service {
Ok(())
} else {
Err(format!(
"DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
pds_endpoint
))
}
}
}

5
src/api/identity/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
pub mod account;
pub mod did;
pub use account::create_account;
pub use did::{user_did_doc, well_known_did};

View File

@@ -1,4 +1,4 @@
pub mod server;
pub mod repo;
pub mod proxy;
pub mod identity;
pub mod proxy;
pub mod repo;
pub mod server;

View File

@@ -1,14 +1,14 @@
use crate::state::AppState;
use axum::{
body::Bytes,
extract::{Path, Query, State},
http::{HeaderMap, Method, StatusCode},
response::{IntoResponse, Response},
body::Bytes,
};
use reqwest::Client;
use tracing::{info, error};
use std::collections::HashMap;
use crate::state::AppState;
use sqlx::Row;
use std::collections::HashMap;
use tracing::{error, info};
pub async fn proxy_handler(
State(state): State<AppState>,
@@ -18,8 +18,8 @@ pub async fn proxy_handler(
Query(params): Query<HashMap<String, String>>,
body: Bytes,
) -> Response {
let proxy_header = headers.get("atproto-proxy")
let proxy_header = headers
.get("atproto-proxy")
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string());
@@ -27,7 +27,9 @@ pub async fn proxy_handler(
Some(url) => url.clone(),
None => match std::env::var("APPVIEW_URL") {
Ok(url) => url,
Err(_) => return (StatusCode::BAD_GATEWAY, "No upstream AppView configured").into_response(),
Err(_) => {
return (StatusCode::BAD_GATEWAY, "No upstream AppView configured").into_response();
}
},
};
@@ -37,9 +39,7 @@ pub async fn proxy_handler(
let client = Client::new();
let mut request_builder = client
.request(method_verb, &target_url)
.query(&params);
let mut request_builder = client.request(method_verb, &target_url).query(&params);
let mut auth_header_val = headers.get("Authorization").map(|h| h.clone());
@@ -48,17 +48,21 @@ pub async fn proxy_handler(
if let Ok(token) = auth_val.to_str() {
let token = token.replace("Bearer ", "");
if let Ok(did) = crate::auth::get_did_from_token(&token) {
let key_row = sqlx::query("SELECT k.key_bytes FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1")
let key_row = sqlx::query("SELECT k.key_bytes FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1")
.bind(&did)
.fetch_optional(&state.db)
.await;
if let Ok(Some(row)) = key_row {
let key_bytes: Vec<u8> = row.get("key_bytes");
if let Ok(new_token) = crate::auth::create_service_token(&did, aud, &method, &key_bytes) {
if let Ok(val) = axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) {
auth_header_val = Some(val);
}
if let Ok(new_token) =
crate::auth::create_service_token(&did, aud, &method, &key_bytes)
{
if let Ok(val) =
axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
{
auth_header_val = Some(val);
}
}
}
}
@@ -86,7 +90,8 @@ pub async fn proxy_handler(
Ok(b) => b,
Err(e) => {
error!("Error reading proxy response body: {:?}", e);
return (StatusCode::BAD_GATEWAY, "Error reading upstream response").into_response();
return (StatusCode::BAD_GATEWAY, "Error reading upstream response")
.into_response();
}
};
@@ -99,11 +104,11 @@ pub async fn proxy_handler(
match response_builder.body(axum::body::Body::from(body)) {
Ok(r) => r,
Err(e) => {
error!("Error building proxy response: {:?}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response()
error!("Error building proxy response: {:?}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response()
}
}
},
}
Err(e) => {
error!("Error sending proxy request: {:?}", e);
if e.is_timeout() {

View File

@@ -1,889 +0,0 @@
use axum::{
extract::{State, Query},
Json,
response::{IntoResponse, Response},
http::StatusCode,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use crate::state::AppState;
use chrono::Utc;
use sqlx::Row;
use cid::Cid;
use std::str::FromStr;
use jacquard_repo::{mst::Mst, commit::Commit, storage::BlockStore};
use jacquard::types::{string::{Nsid, Tid}, did::Did, integer::LimitedU32};
use tracing::error;
use std::sync::Arc;
use sha2::{Sha256, Digest};
use multihash::Multihash;
use axum::body::Bytes;
#[derive(Deserialize)]
#[allow(dead_code)]
pub struct CreateRecordInput {
pub repo: String,
pub collection: String,
pub rkey: Option<String>,
pub validate: Option<bool>,
pub record: serde_json::Value,
#[serde(rename = "swapCommit")]
pub swap_commit: Option<String>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateRecordOutput {
pub uri: String,
pub cid: String,
}
pub async fn create_record(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
Json(input): Json<CreateRecordInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response();
}
let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", "");
let session = sqlx::query(
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
)
.bind(&token)
.fetch_optional(&state.db)
.await
.unwrap_or(None);
let (did, key_bytes) = match session {
Some(row) => (row.get::<String, _>("did"), row.get::<Vec<u8>, _>("key_bytes")),
None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response(),
};
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response();
}
if input.repo != did {
return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response();
}
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
.bind(&did)
.fetch_optional(&state.db)
.await;
let user_id: uuid::Uuid = match user_query {
Ok(Some(row)) => row.get("id"),
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "User not found"}))).into_response(),
};
let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1")
.bind(user_id)
.fetch_optional(&state.db)
.await;
let current_root_cid = match repo_root_query {
Ok(Some(row)) => {
let cid_str: String = row.get("repo_root_cid");
Cid::from_str(&cid_str).ok()
},
_ => None,
};
if current_root_cid.is_none() {
error!("Repo root not found for user {}", did);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Repo root not found"}))).into_response();
}
let current_root_cid = current_root_cid.unwrap();
let commit_bytes = match state.block_store.get(&current_root_cid).await {
Ok(Some(b)) => b,
Ok(None) => {
error!("Commit block not found: {}", current_root_cid);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
},
Err(e) => {
error!("Failed to load commit block: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
};
let commit = match Commit::from_cbor(&commit_bytes) {
Ok(c) => c,
Err(e) => {
error!("Failed to parse commit: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
};
let mst_root = commit.data;
let store = Arc::new(state.block_store.clone());
let mst = Mst::load(store.clone(), mst_root, None);
let collection_nsid = match input.collection.parse::<Nsid>() {
Ok(n) => n,
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(),
};
let rkey = input.rkey.unwrap_or_else(|| {
Utc::now().format("%Y%m%d%H%M%S%f").to_string()
});
let mut record_bytes = Vec::new();
if let Err(e) = serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record) {
error!("Error serializing record: {:?}", e);
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
}
let record_cid = match state.block_store.put(&record_bytes).await {
Ok(c) => c,
Err(e) => {
error!("Failed to save record block: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
};
let key = format!("{}/{}", collection_nsid, rkey);
if let Err(e) = mst.update(&key, record_cid).await {
error!("Failed to update MST: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
let new_mst_root = match mst.root().await {
Ok(c) => c,
Err(e) => {
error!("Failed to get new MST root: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
};
let did_obj = match Did::new(&did) {
Ok(d) => d,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(),
};
let rev = Tid::now(LimitedU32::MIN);
let new_commit = Commit::new_unsigned(
did_obj,
new_mst_root,
rev,
Some(current_root_cid)
);
let new_commit_bytes = match new_commit.to_cbor() {
Ok(b) => b,
Err(e) => {
error!("Failed to serialize new commit: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
};
let new_root_cid = match state.block_store.put(&new_commit_bytes).await {
Ok(c) => c,
Err(e) => {
error!("Failed to save new commit: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
};
let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2")
.bind(new_root_cid.to_string())
.bind(user_id)
.execute(&state.db)
.await;
if let Err(e) = update_repo {
error!("Failed to update repo root in DB: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
let record_insert = sqlx::query(
"INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4)
ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()"
)
.bind(user_id)
.bind(&input.collection)
.bind(&rkey)
.bind(record_cid.to_string())
.execute(&state.db)
.await;
if let Err(e) = record_insert {
error!("Error inserting record index: {:?}", e);
}
let output = CreateRecordOutput {
uri: format!("at://{}/{}/{}", input.repo, input.collection, rkey),
cid: record_cid.to_string(),
};
(StatusCode::OK, Json(output)).into_response()
}
#[derive(Deserialize)]
#[allow(dead_code)]
pub struct PutRecordInput {
pub repo: String,
pub collection: String,
pub rkey: String,
pub validate: Option<bool>,
pub record: serde_json::Value,
#[serde(rename = "swapCommit")]
pub swap_commit: Option<String>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PutRecordOutput {
pub uri: String,
pub cid: String,
}
pub async fn put_record(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
Json(input): Json<PutRecordInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response();
}
let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", "");
let session = sqlx::query(
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
)
.bind(&token)
.fetch_optional(&state.db)
.await
.unwrap_or(None);
let (did, key_bytes) = match session {
Some(row) => (row.get::<String, _>("did"), row.get::<Vec<u8>, _>("key_bytes")),
None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response(),
};
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response();
}
if input.repo != did {
return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response();
}
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
.bind(&did)
.fetch_optional(&state.db)
.await;
let user_id: uuid::Uuid = match user_query {
Ok(Some(row)) => row.get("id"),
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "User not found"}))).into_response(),
};
let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1")
.bind(user_id)
.fetch_optional(&state.db)
.await;
let current_root_cid = match repo_root_query {
Ok(Some(row)) => {
let cid_str: String = row.get("repo_root_cid");
Cid::from_str(&cid_str).ok()
},
_ => None,
};
if current_root_cid.is_none() {
error!("Repo root not found for user {}", did);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Repo root not found"}))).into_response();
}
let current_root_cid = current_root_cid.unwrap();
let commit_bytes = match state.block_store.get(&current_root_cid).await {
Ok(Some(b)) => b,
Ok(None) => {
error!("Commit block not found: {}", current_root_cid);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response();
},
Err(e) => {
error!("Failed to load commit block: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to load commit block"}))).into_response();
}
};
let commit = match Commit::from_cbor(&commit_bytes) {
Ok(c) => c,
Err(e) => {
error!("Failed to parse commit: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to parse commit"}))).into_response();
}
};
let mst_root = commit.data;
let store = Arc::new(state.block_store.clone());
let mst = Mst::load(store.clone(), mst_root, None);
let collection_nsid = match input.collection.parse::<Nsid>() {
Ok(n) => n,
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(),
};
let rkey = input.rkey.clone();
let mut record_bytes = Vec::new();
if let Err(e) = serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record) {
error!("Error serializing record: {:?}", e);
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
}
let record_cid = match state.block_store.put(&record_bytes).await {
Ok(c) => c,
Err(e) => {
error!("Failed to save record block: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to save record block"}))).into_response();
}
};
let key = format!("{}/{}", collection_nsid, rkey);
if let Err(e) = mst.update(&key, record_cid).await {
error!("Failed to update MST: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to update MST: {:?}", e)}))).into_response();
}
let new_mst_root = match mst.root().await {
Ok(c) => c,
Err(e) => {
error!("Failed to get new MST root: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST root"}))).into_response();
}
};
let did_obj = match Did::new(&did) {
Ok(d) => d,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(),
};
let rev = Tid::now(LimitedU32::MIN);
let new_commit = Commit::new_unsigned(
did_obj,
new_mst_root,
rev,
Some(current_root_cid)
);
let new_commit_bytes = match new_commit.to_cbor() {
Ok(b) => b,
Err(e) => {
error!("Failed to serialize new commit: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to serialize new commit"}))).into_response();
}
};
let new_root_cid = match state.block_store.put(&new_commit_bytes).await {
Ok(c) => c,
Err(e) => {
error!("Failed to save new commit: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to save new commit"}))).into_response();
}
};
let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2")
.bind(new_root_cid.to_string())
.bind(user_id)
.execute(&state.db)
.await;
if let Err(e) = update_repo {
error!("Failed to update repo root in DB: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update repo root in DB"}))).into_response();
}
let record_insert = sqlx::query(
"INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4)
ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()"
)
.bind(user_id)
.bind(&input.collection)
.bind(&rkey)
.bind(record_cid.to_string())
.execute(&state.db)
.await;
if let Err(e) = record_insert {
error!("Error inserting record index: {:?}", e);
}
let output = PutRecordOutput {
uri: format!("at://{}/{}/{}", input.repo, input.collection, rkey),
cid: record_cid.to_string(),
};
(StatusCode::OK, Json(output)).into_response()
}
#[derive(Deserialize)]
pub struct GetRecordInput {
pub repo: String,
pub collection: String,
pub rkey: String,
pub cid: Option<String>,
}
pub async fn get_record(
State(state): State<AppState>,
Query(input): Query<GetRecordInput>,
) -> Response {
let user_row = if input.repo.starts_with("did:") {
sqlx::query("SELECT id FROM users WHERE did = $1")
.bind(&input.repo)
.fetch_optional(&state.db)
.await
} else {
sqlx::query("SELECT id FROM users WHERE handle = $1")
.bind(&input.repo)
.fetch_optional(&state.db)
.await
};
let user_id: uuid::Uuid = match user_row {
Ok(Some(row)) => row.get("id"),
_ => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Repo not found"}))).into_response(),
};
let record_row = sqlx::query("SELECT record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3")
.bind(user_id)
.bind(&input.collection)
.bind(&input.rkey)
.fetch_optional(&state.db)
.await;
let record_cid_str: String = match record_row {
Ok(Some(row)) => row.get("record_cid"),
_ => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Record not found"}))).into_response(),
};
if let Some(expected_cid) = &input.cid {
if &record_cid_str != expected_cid {
return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Record CID mismatch"}))).into_response();
}
}
let cid = match Cid::from_str(&record_cid_str) {
Ok(c) => c,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid CID in DB"}))).into_response(),
};
let block = match state.block_store.get(&cid).await {
Ok(Some(b)) => b,
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Record block not found"}))).into_response(),
};
let value: serde_json::Value = match serde_ipld_dagcbor::from_slice(&block) {
Ok(v) => v,
Err(e) => {
error!("Failed to deserialize record: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
};
Json(json!({
"uri": format!("at://{}/{}/{}", input.repo, input.collection, input.rkey),
"cid": record_cid_str,
"value": value
})).into_response()
}
#[derive(Deserialize)]
pub struct DeleteRecordInput {
pub repo: String,
pub collection: String,
pub rkey: String,
#[serde(rename = "swapRecord")]
pub swap_record: Option<String>,
#[serde(rename = "swapCommit")]
pub swap_commit: Option<String>,
}
pub async fn delete_record(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
Json(input): Json<DeleteRecordInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response();
}
let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", "");
let session = sqlx::query(
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
)
.bind(&token)
.fetch_optional(&state.db)
.await
.unwrap_or(None);
let (did, key_bytes) = match session {
Some(row) => (row.get::<String, _>("did"), row.get::<Vec<u8>, _>("key_bytes")),
None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response(),
};
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response();
}
if input.repo != did {
return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response();
}
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
.bind(&did)
.fetch_optional(&state.db)
.await;
let user_id: uuid::Uuid = match user_query {
Ok(Some(row)) => row.get("id"),
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "User not found"}))).into_response(),
};
let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1")
.bind(user_id)
.fetch_optional(&state.db)
.await;
let current_root_cid = match repo_root_query {
Ok(Some(row)) => {
let cid_str: String = row.get("repo_root_cid");
Cid::from_str(&cid_str).ok()
},
_ => None,
};
if current_root_cid.is_none() {
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Repo root not found"}))).into_response();
}
let current_root_cid = current_root_cid.unwrap();
let commit_bytes = match state.block_store.get(&current_root_cid).await {
Ok(Some(b)) => b,
Ok(None) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response(),
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to load commit block: {:?}", e)}))).into_response(),
};
let commit = match Commit::from_cbor(&commit_bytes) {
Ok(c) => c,
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to parse commit: {:?}", e)}))).into_response(),
};
let mst_root = commit.data;
let store = Arc::new(state.block_store.clone());
let mst = Mst::load(store.clone(), mst_root, None);
let collection_nsid = match input.collection.parse::<Nsid>() {
Ok(n) => n,
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(),
};
let key = format!("{}/{}", collection_nsid, input.rkey);
// TODO: Check swapRecord if provided? Skipping for brevity/robustness
if let Err(e) = mst.delete(&key).await {
error!("Failed to delete from MST: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to delete from MST: {:?}", e)}))).into_response();
}
let new_mst_root = match mst.root().await {
Ok(c) => c,
Err(_e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST root"}))).into_response(),
};
let did_obj = match Did::new(&did) {
Ok(d) => d,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(),
};
let rev = Tid::now(LimitedU32::MIN);
let new_commit = Commit::new_unsigned(
did_obj,
new_mst_root,
rev,
Some(current_root_cid)
);
let new_commit_bytes = match new_commit.to_cbor() {
Ok(b) => b,
Err(_e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to serialize new commit"}))).into_response(),
};
let new_root_cid = match state.block_store.put(&new_commit_bytes).await {
Ok(c) => c,
Err(_e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to save new commit"}))).into_response(),
};
let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2")
.bind(new_root_cid.to_string())
.bind(user_id)
.execute(&state.db)
.await;
if let Err(e) = update_repo {
error!("Failed to update repo root in DB: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update repo root in DB"}))).into_response();
}
let record_delete = sqlx::query("DELETE FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3")
.bind(user_id)
.bind(&input.collection)
.bind(&input.rkey)
.execute(&state.db)
.await;
if let Err(e) = record_delete {
error!("Error deleting record index: {:?}", e);
}
(StatusCode::OK, Json(json!({}))).into_response()
}
#[derive(Deserialize)]
pub struct ListRecordsInput {
pub repo: String,
pub collection: String,
pub limit: Option<i32>,
pub cursor: Option<String>,
#[serde(rename = "rkeyStart")]
pub rkey_start: Option<String>,
#[serde(rename = "rkeyEnd")]
pub rkey_end: Option<String>,
pub reverse: Option<bool>,
}
#[derive(Serialize)]
pub struct ListRecordsOutput {
pub cursor: Option<String>,
pub records: Vec<serde_json::Value>,
}
pub async fn list_records(
State(state): State<AppState>,
Query(input): Query<ListRecordsInput>,
) -> Response {
let user_row = if input.repo.starts_with("did:") {
sqlx::query("SELECT id FROM users WHERE did = $1")
.bind(&input.repo)
.fetch_optional(&state.db)
.await
} else {
sqlx::query("SELECT id FROM users WHERE handle = $1")
.bind(&input.repo)
.fetch_optional(&state.db)
.await
};
let user_id: uuid::Uuid = match user_row {
Ok(Some(row)) => row.get("id"),
_ => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Repo not found"}))).into_response(),
};
let limit = input.limit.unwrap_or(50).clamp(1, 100);
let reverse = input.reverse.unwrap_or(false);
// Simplistic query construction - no sophisticated cursor handling or rkey ranges for now, just basic pagination
// TODO: Implement rkeyStart/End and correct cursor logic
let query_str = format!(
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 {} ORDER BY rkey {} LIMIT {}",
if let Some(_c) = &input.cursor {
if reverse { "AND rkey < $3" } else { "AND rkey > $3" }
} else {
""
},
if reverse { "DESC" } else { "ASC" },
limit
);
let mut query = sqlx::query(&query_str)
.bind(user_id)
.bind(&input.collection);
if let Some(c) = &input.cursor {
query = query.bind(c);
}
let rows = match query.fetch_all(&state.db).await {
Ok(r) => r,
Err(e) => {
error!("Error listing records: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
};
let mut records = Vec::new();
let mut last_rkey = None;
for row in rows {
let rkey: String = row.get("rkey");
let cid_str: String = row.get("record_cid");
last_rkey = Some(rkey.clone());
if let Ok(cid) = Cid::from_str(&cid_str) {
if let Ok(Some(block)) = state.block_store.get(&cid).await {
if let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) {
records.push(json!({
"uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey),
"cid": cid_str,
"value": value
}));
}
}
}
}
Json(ListRecordsOutput {
cursor: last_rkey,
records,
}).into_response()
}
#[derive(Deserialize)]
pub struct DescribeRepoInput {
pub repo: String,
}
pub async fn describe_repo(
State(state): State<AppState>,
Query(input): Query<DescribeRepoInput>,
) -> Response {
let user_row = if input.repo.starts_with("did:") {
sqlx::query("SELECT id, handle, did FROM users WHERE did = $1")
.bind(&input.repo)
.fetch_optional(&state.db)
.await
} else {
sqlx::query("SELECT id, handle, did FROM users WHERE handle = $1")
.bind(&input.repo)
.fetch_optional(&state.db)
.await
};
let (user_id, handle, did) = match user_row {
Ok(Some(row)) => (row.get::<uuid::Uuid, _>("id"), row.get::<String, _>("handle"), row.get::<String, _>("did")),
_ => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Repo not found"}))).into_response(),
};
let collections_query = sqlx::query("SELECT DISTINCT collection FROM records WHERE repo_id = $1")
.bind(user_id)
.fetch_all(&state.db)
.await;
let collections: Vec<String> = match collections_query {
Ok(rows) => rows.iter().map(|r| r.get("collection")).collect(),
Err(_) => Vec::new(),
};
let did_doc = json!({
"id": did,
"alsoKnownAs": [format!("at://{}", handle)]
});
Json(json!({
"handle": handle,
"did": did,
"didDoc": did_doc,
"collections": collections,
"handleIsCorrect": true
})).into_response()
}
pub async fn upload_blob(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
body: Bytes,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response();
}
let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", "");
let session = sqlx::query(
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
)
.bind(&token)
.fetch_optional(&state.db)
.await
.unwrap_or(None);
let (did, key_bytes) = match session {
Some(row) => (row.get::<String, _>("did"), row.get::<Vec<u8>, _>("key_bytes")),
None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response(),
};
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response();
}
let mime_type = headers.get("content-type")
.and_then(|h| h.to_str().ok())
.unwrap_or("application/octet-stream")
.to_string();
let size = body.len() as i64;
let data = body.to_vec();
let mut hasher = Sha256::new();
hasher.update(&data);
let hash = hasher.finalize();
let multihash = Multihash::wrap(0x12, &hash).unwrap();
let cid = Cid::new_v1(0x55, multihash);
let cid_str = cid.to_string();
let storage_key = format!("blobs/{}", cid_str);
if let Err(e) = state.blob_store.put(&storage_key, &data).await {
error!("Failed to upload blob to storage: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store blob"}))).into_response();
}
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
.bind(&did)
.fetch_optional(&state.db)
.await;
let user_id: uuid::Uuid = match user_query {
Ok(Some(row)) => row.get("id"),
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(),
};
let insert = sqlx::query(
"INSERT INTO blobs (cid, mime_type, size_bytes, created_by_user, storage_key) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (cid) DO NOTHING"
)
.bind(&cid_str)
.bind(&mime_type)
.bind(size)
.bind(user_id)
.bind(&storage_key)
.execute(&state.db)
.await;
if let Err(e) = insert {
error!("Failed to insert blob record: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
}
Json(json!({
"blob": {
"ref": {
"$link": cid_str
},
"mimeType": mime_type,
"size": size
}
})).into_response()
}

138
src/api/repo/blob.rs Normal file
View File

@@ -0,0 +1,138 @@
use crate::state::AppState;
use axum::body::Bytes;
use axum::{
Json,
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
};
use cid::Cid;
use multihash::Multihash;
use serde_json::json;
use sha2::{Digest, Sha256};
use sqlx::Row;
use tracing::error;
pub async fn upload_blob(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
body: Bytes,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let token = auth_header
.unwrap()
.to_str()
.unwrap_or("")
.replace("Bearer ", "");
let session = sqlx::query(
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
)
.bind(&token)
.fetch_optional(&state.db)
.await
.unwrap_or(None);
let (did, key_bytes) = match session {
Some(row) => (
row.get::<String, _>("did"),
row.get::<Vec<u8>, _>("key_bytes"),
),
None => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed"})),
)
.into_response();
}
};
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
)
.into_response();
}
let mime_type = headers
.get("content-type")
.and_then(|h| h.to_str().ok())
.unwrap_or("application/octet-stream")
.to_string();
let size = body.len() as i64;
let data = body.to_vec();
let mut hasher = Sha256::new();
hasher.update(&data);
let hash = hasher.finalize();
let multihash = Multihash::wrap(0x12, &hash).unwrap();
let cid = Cid::new_v1(0x55, multihash);
let cid_str = cid.to_string();
let storage_key = format!("blobs/{}", cid_str);
if let Err(e) = state.blob_store.put(&storage_key, &data).await {
error!("Failed to upload blob to storage: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Failed to store blob"})),
)
.into_response();
}
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
.bind(&did)
.fetch_optional(&state.db)
.await;
let user_id: uuid::Uuid = match user_query {
Ok(Some(row)) => row.get("id"),
_ => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let insert = sqlx::query(
"INSERT INTO blobs (cid, mime_type, size_bytes, created_by_user, storage_key) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (cid) DO NOTHING"
)
.bind(&cid_str)
.bind(&mime_type)
.bind(size)
.bind(user_id)
.bind(&storage_key)
.execute(&state.db)
.await;
if let Err(e) = insert {
error!("Failed to insert blob record: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
Json(json!({
"blob": {
"ref": {
"$link": cid_str
},
"mimeType": mime_type,
"size": size
}
}))
.into_response()
}

72
src/api/repo/meta.rs Normal file
View File

@@ -0,0 +1,72 @@
use crate::state::AppState;
use axum::{
Json,
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::Deserialize;
use serde_json::json;
use sqlx::Row;
#[derive(Deserialize)]
pub struct DescribeRepoInput {
pub repo: String,
}
pub async fn describe_repo(
State(state): State<AppState>,
Query(input): Query<DescribeRepoInput>,
) -> Response {
let user_row = if input.repo.starts_with("did:") {
sqlx::query("SELECT id, handle, did FROM users WHERE did = $1")
.bind(&input.repo)
.fetch_optional(&state.db)
.await
} else {
sqlx::query("SELECT id, handle, did FROM users WHERE handle = $1")
.bind(&input.repo)
.fetch_optional(&state.db)
.await
};
let (user_id, handle, did) = match user_row {
Ok(Some(row)) => (
row.get::<uuid::Uuid, _>("id"),
row.get::<String, _>("handle"),
row.get::<String, _>("did"),
),
_ => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "NotFound", "message": "Repo not found"})),
)
.into_response();
}
};
let collections_query =
sqlx::query("SELECT DISTINCT collection FROM records WHERE repo_id = $1")
.bind(user_id)
.fetch_all(&state.db)
.await;
let collections: Vec<String> = match collections_query {
Ok(rows) => rows.iter().map(|r| r.get("collection")).collect(),
Err(_) => Vec::new(),
};
let did_doc = json!({
"id": did,
"alsoKnownAs": [format!("at://{}", handle)]
});
Json(json!({
"handle": handle,
"did": did,
"didDoc": did_doc,
"collections": collections,
"handleIsCorrect": true
}))
.into_response()
}

7
src/api/repo/mod.rs Normal file
View File

@@ -0,0 +1,7 @@
pub mod blob;
pub mod meta;
pub mod record;
pub use blob::upload_blob;
pub use meta::describe_repo;
pub use record::{create_record, delete_record, get_record, list_records, put_record};

View File

@@ -0,0 +1,236 @@
use crate::state::AppState;
use axum::{
Json,
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
};
use cid::Cid;
use jacquard::types::{
did::Did,
integer::LimitedU32,
string::{Nsid, Tid},
};
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
use serde::Deserialize;
use serde_json::json;
use sqlx::Row;
use std::str::FromStr;
use std::sync::Arc;
use tracing::error;
#[derive(Deserialize)]
pub struct DeleteRecordInput {
pub repo: String,
pub collection: String,
pub rkey: String,
#[serde(rename = "swapRecord")]
pub swap_record: Option<String>,
#[serde(rename = "swapCommit")]
pub swap_commit: Option<String>,
}
pub async fn delete_record(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
Json(input): Json<DeleteRecordInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let token = auth_header
.unwrap()
.to_str()
.unwrap_or("")
.replace("Bearer ", "");
let session = sqlx::query(
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
)
.bind(&token)
.fetch_optional(&state.db)
.await
.unwrap_or(None);
let (did, key_bytes) = match session {
Some(row) => (
row.get::<String, _>("did"),
row.get::<Vec<u8>, _>("key_bytes"),
),
None => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed"})),
)
.into_response();
}
};
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
)
.into_response();
}
if input.repo != did {
return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response();
}
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
.bind(&did)
.fetch_optional(&state.db)
.await;
let user_id: uuid::Uuid = match user_query {
Ok(Some(row)) => row.get("id"),
_ => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "User not found"})),
)
.into_response();
}
};
let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1")
.bind(user_id)
.fetch_optional(&state.db)
.await;
let current_root_cid = match repo_root_query {
Ok(Some(row)) => {
let cid_str: String = row.get("repo_root_cid");
Cid::from_str(&cid_str).ok()
}
_ => None,
};
if current_root_cid.is_none() {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Repo root not found"})),
)
.into_response();
}
let current_root_cid = current_root_cid.unwrap();
let commit_bytes = match state.block_store.get(&current_root_cid).await {
Ok(Some(b)) => b,
Ok(None) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response(),
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to load commit block: {:?}", e)}))).into_response(),
};
let commit = match Commit::from_cbor(&commit_bytes) {
Ok(c) => c,
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to parse commit: {:?}", e)}))).into_response(),
};
let mst_root = commit.data;
let store = Arc::new(state.block_store.clone());
let mst = Mst::load(store.clone(), mst_root, None);
let collection_nsid = match input.collection.parse::<Nsid>() {
Ok(n) => n,
Err(_) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "InvalidCollection"})),
)
.into_response();
}
};
let key = format!("{}/{}", collection_nsid, input.rkey);
// TODO: Check swapRecord if provided? Skipping for brevity/robustness
if let Err(e) = mst.delete(&key).await {
error!("Failed to delete from MST: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to delete from MST: {:?}", e)}))).into_response();
}
let new_mst_root = match mst.root().await {
Ok(c) => c,
Err(_e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Failed to get new MST root"})),
)
.into_response();
}
};
let did_obj = match Did::new(&did) {
Ok(d) => d,
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Invalid DID"})),
)
.into_response();
}
};
let rev = Tid::now(LimitedU32::MIN);
let new_commit = Commit::new_unsigned(did_obj, new_mst_root, rev, Some(current_root_cid));
let new_commit_bytes =
match new_commit.to_cbor() {
Ok(b) => b,
Err(_e) => return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(
json!({"error": "InternalError", "message": "Failed to serialize new commit"}),
),
)
.into_response(),
};
let new_root_cid = match state.block_store.put(&new_commit_bytes).await {
Ok(c) => c,
Err(_e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Failed to save new commit"})),
)
.into_response();
}
};
let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2")
.bind(new_root_cid.to_string())
.bind(user_id)
.execute(&state.db)
.await;
if let Err(e) = update_repo {
error!("Failed to update repo root in DB: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Failed to update repo root in DB"})),
)
.into_response();
}
let record_delete =
sqlx::query("DELETE FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3")
.bind(user_id)
.bind(&input.collection)
.bind(&input.rkey)
.execute(&state.db)
.await;
if let Err(e) = record_delete {
error!("Error deleting record index: {:?}", e);
}
(StatusCode::OK, Json(json!({}))).into_response()
}

View File

@@ -0,0 +1,10 @@
pub mod delete;
pub mod read;
pub mod write;
pub use delete::{DeleteRecordInput, delete_record};
pub use read::{GetRecordInput, ListRecordsInput, ListRecordsOutput, get_record, list_records};
pub use write::{
CreateRecordInput, CreateRecordOutput, PutRecordInput, PutRecordOutput, create_record,
put_record,
};

236
src/api/repo/record/read.rs Normal file
View File

@@ -0,0 +1,236 @@
use crate::state::AppState;
use axum::{
Json,
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use cid::Cid;
use jacquard_repo::storage::BlockStore;
use serde::{Deserialize, Serialize};
use serde_json::json;
use sqlx::Row;
use std::str::FromStr;
use tracing::error;
#[derive(Deserialize)]
pub struct GetRecordInput {
pub repo: String,
pub collection: String,
pub rkey: String,
pub cid: Option<String>,
}
pub async fn get_record(
State(state): State<AppState>,
Query(input): Query<GetRecordInput>,
) -> Response {
let user_row = if input.repo.starts_with("did:") {
sqlx::query("SELECT id FROM users WHERE did = $1")
.bind(&input.repo)
.fetch_optional(&state.db)
.await
} else {
sqlx::query("SELECT id FROM users WHERE handle = $1")
.bind(&input.repo)
.fetch_optional(&state.db)
.await
};
let user_id: uuid::Uuid = match user_row {
Ok(Some(row)) => row.get("id"),
_ => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "NotFound", "message": "Repo not found"})),
)
.into_response();
}
};
let record_row = sqlx::query(
"SELECT record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3",
)
.bind(user_id)
.bind(&input.collection)
.bind(&input.rkey)
.fetch_optional(&state.db)
.await;
let record_cid_str: String = match record_row {
Ok(Some(row)) => row.get("record_cid"),
_ => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "NotFound", "message": "Record not found"})),
)
.into_response();
}
};
if let Some(expected_cid) = &input.cid {
if &record_cid_str != expected_cid {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "NotFound", "message": "Record CID mismatch"})),
)
.into_response();
}
}
let cid = match Cid::from_str(&record_cid_str) {
Ok(c) => c,
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Invalid CID in DB"})),
)
.into_response();
}
};
let block = match state.block_store.get(&cid).await {
Ok(Some(b)) => b,
_ => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Record block not found"})),
)
.into_response();
}
};
let value: serde_json::Value = match serde_ipld_dagcbor::from_slice(&block) {
Ok(v) => v,
Err(e) => {
error!("Failed to deserialize record: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
Json(json!({
"uri": format!("at://{}/{}/{}", input.repo, input.collection, input.rkey),
"cid": record_cid_str,
"value": value
}))
.into_response()
}
#[derive(Deserialize)]
pub struct ListRecordsInput {
pub repo: String,
pub collection: String,
pub limit: Option<i32>,
pub cursor: Option<String>,
#[serde(rename = "rkeyStart")]
pub rkey_start: Option<String>,
#[serde(rename = "rkeyEnd")]
pub rkey_end: Option<String>,
pub reverse: Option<bool>,
}
#[derive(Serialize)]
pub struct ListRecordsOutput {
pub cursor: Option<String>,
pub records: Vec<serde_json::Value>,
}
pub async fn list_records(
State(state): State<AppState>,
Query(input): Query<ListRecordsInput>,
) -> Response {
let user_row = if input.repo.starts_with("did:") {
sqlx::query("SELECT id FROM users WHERE did = $1")
.bind(&input.repo)
.fetch_optional(&state.db)
.await
} else {
sqlx::query("SELECT id FROM users WHERE handle = $1")
.bind(&input.repo)
.fetch_optional(&state.db)
.await
};
let user_id: uuid::Uuid = match user_row {
Ok(Some(row)) => row.get("id"),
_ => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "NotFound", "message": "Repo not found"})),
)
.into_response();
}
};
let limit = input.limit.unwrap_or(50).clamp(1, 100);
let reverse = input.reverse.unwrap_or(false);
// Simplistic query construction - no sophisticated cursor handling or rkey ranges for now, just basic pagination
// TODO: Implement rkeyStart/End and correct cursor logic
let query_str = format!(
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 {} ORDER BY rkey {} LIMIT {}",
if let Some(_c) = &input.cursor {
if reverse {
"AND rkey < $3"
} else {
"AND rkey > $3"
}
} else {
""
},
if reverse { "DESC" } else { "ASC" },
limit
);
let mut query = sqlx::query(&query_str)
.bind(user_id)
.bind(&input.collection);
if let Some(c) = &input.cursor {
query = query.bind(c);
}
let rows = match query.fetch_all(&state.db).await {
Ok(r) => r,
Err(e) => {
error!("Error listing records: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let mut records = Vec::new();
let mut last_rkey = None;
for row in rows {
let rkey: String = row.get("rkey");
let cid_str: String = row.get("record_cid");
last_rkey = Some(rkey.clone());
if let Ok(cid) = Cid::from_str(&cid_str) {
if let Ok(Some(block)) = state.block_store.get(&cid).await {
if let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) {
records.push(json!({
"uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey),
"cid": cid_str,
"value": value
}));
}
}
}
}
Json(ListRecordsOutput {
cursor: last_rkey,
records,
})
.into_response()
}

View File

@@ -0,0 +1,591 @@
use crate::state::AppState;
use axum::{
Json,
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
};
use chrono::Utc;
use cid::Cid;
use jacquard::types::{
did::Did,
integer::LimitedU32,
string::{Nsid, Tid},
};
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
use serde::{Deserialize, Serialize};
use serde_json::json;
use sqlx::Row;
use std::str::FromStr;
use std::sync::Arc;
use tracing::error;
#[derive(Deserialize)]
#[allow(dead_code)]
pub struct CreateRecordInput {
pub repo: String,
pub collection: String,
pub rkey: Option<String>,
pub validate: Option<bool>,
pub record: serde_json::Value,
#[serde(rename = "swapCommit")]
pub swap_commit: Option<String>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateRecordOutput {
pub uri: String,
pub cid: String,
}
pub async fn create_record(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
Json(input): Json<CreateRecordInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let token = auth_header
.unwrap()
.to_str()
.unwrap_or("")
.replace("Bearer ", "");
let session = sqlx::query(
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
)
.bind(&token)
.fetch_optional(&state.db)
.await
.unwrap_or(None);
let (did, key_bytes) = match session {
Some(row) => (
row.get::<String, _>("did"),
row.get::<Vec<u8>, _>("key_bytes"),
),
None => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed"})),
)
.into_response();
}
};
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
)
.into_response();
}
if input.repo != did {
return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response();
}
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
.bind(&did)
.fetch_optional(&state.db)
.await;
let user_id: uuid::Uuid = match user_query {
Ok(Some(row)) => row.get("id"),
_ => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "User not found"})),
)
.into_response();
}
};
let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1")
.bind(user_id)
.fetch_optional(&state.db)
.await;
let current_root_cid = match repo_root_query {
Ok(Some(row)) => {
let cid_str: String = row.get("repo_root_cid");
Cid::from_str(&cid_str).ok()
}
_ => None,
};
if current_root_cid.is_none() {
error!("Repo root not found for user {}", did);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Repo root not found"})),
)
.into_response();
}
let current_root_cid = current_root_cid.unwrap();
let commit_bytes = match state.block_store.get(&current_root_cid).await {
Ok(Some(b)) => b,
Ok(None) => {
error!("Commit block not found: {}", current_root_cid);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
Err(e) => {
error!("Failed to load commit block: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let commit = match Commit::from_cbor(&commit_bytes) {
Ok(c) => c,
Err(e) => {
error!("Failed to parse commit: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let mst_root = commit.data;
let store = Arc::new(state.block_store.clone());
let mst = Mst::load(store.clone(), mst_root, None);
let collection_nsid = match input.collection.parse::<Nsid>() {
Ok(n) => n,
Err(_) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "InvalidCollection"})),
)
.into_response();
}
};
let rkey = input
.rkey
.unwrap_or_else(|| Utc::now().format("%Y%m%d%H%M%S%f").to_string());
let mut record_bytes = Vec::new();
if let Err(e) = serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record) {
error!("Error serializing record: {:?}", e);
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"})),
)
.into_response();
}
let record_cid = match state.block_store.put(&record_bytes).await {
Ok(c) => c,
Err(e) => {
error!("Failed to save record block: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let key = format!("{}/{}", collection_nsid, rkey);
if let Err(e) = mst.update(&key, record_cid).await {
error!("Failed to update MST: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
let new_mst_root = match mst.root().await {
Ok(c) => c,
Err(e) => {
error!("Failed to get new MST root: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let did_obj = match Did::new(&did) {
Ok(d) => d,
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Invalid DID"})),
)
.into_response();
}
};
let rev = Tid::now(LimitedU32::MIN);
let new_commit = Commit::new_unsigned(did_obj, new_mst_root, rev, Some(current_root_cid));
let new_commit_bytes = match new_commit.to_cbor() {
Ok(b) => b,
Err(e) => {
error!("Failed to serialize new commit: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let new_root_cid = match state.block_store.put(&new_commit_bytes).await {
Ok(c) => c,
Err(e) => {
error!("Failed to save new commit: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2")
.bind(new_root_cid.to_string())
.bind(user_id)
.execute(&state.db)
.await;
if let Err(e) = update_repo {
error!("Failed to update repo root in DB: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
let record_insert = sqlx::query(
"INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4)
ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()",
)
.bind(user_id)
.bind(&input.collection)
.bind(&rkey)
.bind(record_cid.to_string())
.execute(&state.db)
.await;
if let Err(e) = record_insert {
error!("Error inserting record index: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Failed to index record"})),
)
.into_response();
}
let output = CreateRecordOutput {
uri: format!("at://{}/{}/{}", input.repo, input.collection, rkey),
cid: record_cid.to_string(),
};
(StatusCode::OK, Json(output)).into_response()
}
#[derive(Deserialize)]
#[allow(dead_code)]
pub struct PutRecordInput {
pub repo: String,
pub collection: String,
pub rkey: String,
pub validate: Option<bool>,
pub record: serde_json::Value,
#[serde(rename = "swapCommit")]
pub swap_commit: Option<String>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PutRecordOutput {
pub uri: String,
pub cid: String,
}
pub async fn put_record(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
Json(input): Json<PutRecordInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let token = auth_header
.unwrap()
.to_str()
.unwrap_or("")
.replace("Bearer ", "");
let session = sqlx::query(
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
)
.bind(&token)
.fetch_optional(&state.db)
.await
.unwrap_or(None);
let (did, key_bytes) = match session {
Some(row) => (
row.get::<String, _>("did"),
row.get::<Vec<u8>, _>("key_bytes"),
),
None => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed"})),
)
.into_response();
}
};
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
)
.into_response();
}
if input.repo != did {
return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response();
}
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
.bind(&did)
.fetch_optional(&state.db)
.await;
let user_id: uuid::Uuid = match user_query {
Ok(Some(row)) => row.get("id"),
_ => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "User not found"})),
)
.into_response();
}
};
let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1")
.bind(user_id)
.fetch_optional(&state.db)
.await;
let current_root_cid = match repo_root_query {
Ok(Some(row)) => {
let cid_str: String = row.get("repo_root_cid");
Cid::from_str(&cid_str).ok()
}
_ => None,
};
if current_root_cid.is_none() {
error!("Repo root not found for user {}", did);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Repo root not found"})),
)
.into_response();
}
let current_root_cid = current_root_cid.unwrap();
let commit_bytes = match state.block_store.get(&current_root_cid).await {
Ok(Some(b)) => b,
Ok(None) => {
error!("Commit block not found: {}", current_root_cid);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Commit block not found"})),
)
.into_response();
}
Err(e) => {
error!("Failed to load commit block: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Failed to load commit block"})),
)
.into_response();
}
};
let commit = match Commit::from_cbor(&commit_bytes) {
Ok(c) => c,
Err(e) => {
error!("Failed to parse commit: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Failed to parse commit"})),
)
.into_response();
}
};
let mst_root = commit.data;
let store = Arc::new(state.block_store.clone());
let mst = Mst::load(store.clone(), mst_root, None);
let collection_nsid = match input.collection.parse::<Nsid>() {
Ok(n) => n,
Err(_) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "InvalidCollection"})),
)
.into_response();
}
};
let rkey = input.rkey.clone();
let mut record_bytes = Vec::new();
if let Err(e) = serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record) {
error!("Error serializing record: {:?}", e);
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"})),
)
.into_response();
}
let record_cid = match state.block_store.put(&record_bytes).await {
Ok(c) => c,
Err(e) => {
error!("Failed to save record block: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Failed to save record block"})),
)
.into_response();
}
};
let key = format!("{}/{}", collection_nsid, rkey);
if let Err(e) = mst.update(&key, record_cid).await {
error!("Failed to update MST: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to update MST: {:?}", e)}))).into_response();
}
let new_mst_root = match mst.root().await {
Ok(c) => c,
Err(e) => {
error!("Failed to get new MST root: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Failed to get new MST root"})),
)
.into_response();
}
};
let did_obj = match Did::new(&did) {
Ok(d) => d,
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Invalid DID"})),
)
.into_response();
}
};
let rev = Tid::now(LimitedU32::MIN);
let new_commit = Commit::new_unsigned(did_obj, new_mst_root, rev, Some(current_root_cid));
let new_commit_bytes = match new_commit.to_cbor() {
Ok(b) => b,
Err(e) => {
error!("Failed to serialize new commit: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(
json!({"error": "InternalError", "message": "Failed to serialize new commit"}),
),
)
.into_response();
}
};
let new_root_cid = match state.block_store.put(&new_commit_bytes).await {
Ok(c) => c,
Err(e) => {
error!("Failed to save new commit: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Failed to save new commit"})),
)
.into_response();
}
};
let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2")
.bind(new_root_cid.to_string())
.bind(user_id)
.execute(&state.db)
.await;
if let Err(e) = update_repo {
error!("Failed to update repo root in DB: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Failed to update repo root in DB"})),
)
.into_response();
}
let record_insert = sqlx::query(
"INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4)
ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()",
)
.bind(user_id)
.bind(&input.collection)
.bind(&rkey)
.bind(record_cid.to_string())
.execute(&state.db)
.await;
if let Err(e) = record_insert {
error!("Error inserting record index: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Failed to index record"})),
)
.into_response();
}
let output = PutRecordOutput {
uri: format!("at://{}/{}/{}", input.repo, input.collection, rkey),
cid: record_cid.to_string(),
};
(StatusCode::OK, Json(output)).into_response()
}

25
src/api/server/meta.rs Normal file
View File

@@ -0,0 +1,25 @@
use crate::state::AppState;
use axum::{Json, extract::State, http::StatusCode, response::IntoResponse};
use serde_json::json;
use tracing::error;
pub async fn describe_server() -> impl IntoResponse {
let domains_str =
std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| "example.com".to_string());
let domains: Vec<&str> = domains_str.split(',').map(|s| s.trim()).collect();
Json(json!({
"availableUserDomains": domains
}))
}
pub async fn health(State(state): State<AppState>) -> impl IntoResponse {
match sqlx::query("SELECT 1").execute(&state.db).await {
Ok(_) => (StatusCode::OK, "OK"),
Err(e) => {
error!("Health check failed: {:?}", e);
(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable")
}
}
}

5
src/api/server/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
pub mod meta;
pub mod session;
pub use meta::{describe_server, health};
pub use session::{create_session, delete_session, get_session, refresh_session};

View File

@@ -1,34 +1,15 @@
use crate::state::AppState;
use axum::{
extract::State,
Json,
response::{IntoResponse, Response},
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
};
use bcrypt::verify;
use serde::{Deserialize, Serialize};
use serde_json::json;
use crate::state::AppState;
use sqlx::Row;
use bcrypt::verify;
use tracing::{info, error, warn};
pub async fn describe_server() -> impl IntoResponse {
let domains_str = std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| "example.com".to_string());
let domains: Vec<&str> = domains_str.split(',').map(|s| s.trim()).collect();
Json(json!({
"availableUserDomains": domains
}))
}
pub async fn health(State(state): State<AppState>) -> impl IntoResponse {
match sqlx::query("SELECT 1").execute(&state.db).await {
Ok(_) => (StatusCode::OK, "OK"),
Err(e) => {
error!("Health check failed: {:?}", e);
(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable")
}
}
}
use tracing::{error, info, warn};
#[derive(Deserialize)]
pub struct CreateSessionInput {
@@ -69,7 +50,11 @@ pub async fn create_session(
Ok(t) => t,
Err(e) => {
error!("Failed to create access token: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
@@ -77,45 +62,70 @@ pub async fn create_session(
Ok(t) => t,
Err(e) => {
error!("Failed to create refresh token: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let session_insert = sqlx::query("INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)")
.bind(&access_jwt)
.bind(&refresh_jwt)
.bind(&did)
.execute(&state.db)
.await;
let session_insert = sqlx::query(
"INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)",
)
.bind(&access_jwt)
.bind(&refresh_jwt)
.bind(&did)
.execute(&state.db)
.await;
match session_insert {
Ok(_) => {
return (StatusCode::OK, Json(CreateSessionOutput {
access_jwt,
refresh_jwt,
handle,
did,
})).into_response();
},
return (
StatusCode::OK,
Json(CreateSessionOutput {
access_jwt,
refresh_jwt,
handle,
did,
}),
)
.into_response();
}
Err(e) => {
error!("Failed to insert session: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
}
} else {
warn!("Password verification failed for identifier: {}", input.identifier);
warn!(
"Password verification failed for identifier: {}",
input.identifier
);
}
},
}
Ok(None) => {
warn!("User not found for identifier: {}", input.identifier);
},
}
Err(e) => {
error!("Database error fetching user: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
}
(StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid identifier or password"}))).into_response()
(
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed", "message": "Invalid identifier or password"})),
)
.into_response()
}
pub async fn get_session(
@@ -124,10 +134,18 @@ pub async fn get_session(
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response();
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", "");
let token = auth_header
.unwrap()
.to_str()
.unwrap_or("")
.replace("Bearer ", "");
let result = sqlx::query(
r#"
@@ -136,7 +154,7 @@ pub async fn get_session(
JOIN users u ON s.did = u.did
JOIN user_keys k ON u.id = k.user_id
WHERE s.access_jwt = $1
"#
"#,
)
.bind(&token)
.fetch_optional(&state.db)
@@ -150,22 +168,34 @@ pub async fn get_session(
let key_bytes: Vec<u8> = row.get("key_bytes");
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response();
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response();
}
return (StatusCode::OK, Json(json!({
"handle": handle,
"did": did,
"email": email,
"didDoc": {}
}))).into_response();
},
return (
StatusCode::OK,
Json(json!({
"handle": handle,
"did": did,
"email": email,
"didDoc": {}
})),
)
.into_response();
}
Ok(None) => {
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response();
},
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed"})),
)
.into_response();
}
Err(e) => {
error!("Database error in get_session: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
}
}
@@ -176,10 +206,18 @@ pub async fn delete_session(
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response();
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", "");
let token = auth_header
.unwrap()
.to_str()
.unwrap_or("")
.replace("Bearer ", "");
let result = sqlx::query("DELETE FROM sessions WHERE access_jwt = $1")
.bind(token)
@@ -191,13 +229,17 @@ pub async fn delete_session(
if res.rows_affected() > 0 {
return (StatusCode::OK, Json(json!({}))).into_response();
}
},
}
Err(e) => {
error!("Database error in delete_session: {:?}", e);
}
}
(StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response()
(
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed"})),
)
.into_response()
}
pub async fn refresh_session(
@@ -206,10 +248,18 @@ pub async fn refresh_session(
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response();
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let refresh_token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", "");
let refresh_token = auth_header
.unwrap()
.to_str()
.unwrap_or("")
.replace("Bearer ", "");
let session = sqlx::query(
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.refresh_jwt = $1"
@@ -231,27 +281,37 @@ pub async fn refresh_session(
Ok(t) => t,
Err(e) => {
error!("Failed to create access token: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let new_refresh_jwt = match crate::auth::create_refresh_token(&did, &key_bytes) {
Ok(t) => t,
Err(e) => {
error!("Failed to create refresh token: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let update = sqlx::query("UPDATE sessions SET access_jwt = $1, refresh_jwt = $2 WHERE refresh_jwt = $3")
.bind(&new_access_jwt)
.bind(&new_refresh_jwt)
.bind(&refresh_token)
.execute(&state.db)
.await;
let update = sqlx::query(
"UPDATE sessions SET access_jwt = $1, refresh_jwt = $2 WHERE refresh_jwt = $3",
)
.bind(&new_access_jwt)
.bind(&new_refresh_jwt)
.bind(&refresh_token)
.execute(&state.db)
.await;
match update {
Ok(_) => {
let user = sqlx::query("SELECT handle FROM users WHERE did = $1")
let user = sqlx::query("SELECT handle FROM users WHERE did = $1")
.bind(&did)
.fetch_optional(&state.db)
.await;
@@ -259,36 +319,59 @@ pub async fn refresh_session(
match user {
Ok(Some(u)) => {
let handle: String = u.get("handle");
return (StatusCode::OK, Json(json!({
"accessJwt": new_access_jwt,
"refreshJwt": new_refresh_jwt,
"handle": handle,
"did": did
}))).into_response();
},
return (
StatusCode::OK,
Json(json!({
"accessJwt": new_access_jwt,
"refreshJwt": new_refresh_jwt,
"handle": handle,
"did": did
})),
)
.into_response();
}
Ok(None) => {
error!("User not found for existing session: {}", did);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
},
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
Err(e) => {
error!("Database error fetching user: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
}
},
}
Err(e) => {
error!("Database error updating session: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
}
},
}
Ok(None) => {
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid refresh token"}))).into_response();
},
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed", "message": "Invalid refresh token"})),
)
.into_response();
}
Err(e) => {
error!("Database error fetching session: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
}
}

View File

@@ -1,157 +0,0 @@
use serde::{Deserialize, Serialize};
use chrono::{Utc, Duration};
use k256::ecdsa::{SigningKey, VerifyingKey, signature::Signer, signature::Verifier, Signature};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use anyhow::{Context, Result, anyhow};
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
pub iss: String,
pub sub: String,
pub aud: String,
pub exp: usize,
pub iat: usize,
#[serde(skip_serializing_if = "Option::is_none")]
pub scope: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub lxm: Option<String>,
pub jti: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct Header {
alg: String,
typ: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct UnsafeClaims {
iss: String,
sub: Option<String>,
}
// fancy boy TokenData equivalent for compatibility/structure
pub struct TokenData<T> {
pub claims: T,
}
pub fn get_did_from_token(token: &str) -> Result<String, String> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err("Invalid token format".to_string());
}
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1])
.map_err(|e| format!("Base64 decode failed: {}", e))?;
let claims: UnsafeClaims = serde_json::from_slice(&payload_bytes)
.map_err(|e| format!("JSON decode failed: {}", e))?;
Ok(claims.sub.unwrap_or(claims.iss))
}
pub fn create_access_token(did: &str, key_bytes: &[u8]) -> Result<String, anyhow::Error> {
create_signed_token(did, "access", key_bytes, Duration::minutes(15))
}
pub fn create_refresh_token(did: &str, key_bytes: &[u8]) -> Result<String, anyhow::Error> {
create_signed_token(did, "refresh", key_bytes, Duration::days(7))
}
pub fn create_service_token(did: &str, aud: &str, lxm: &str, key_bytes: &[u8]) -> Result<String, anyhow::Error> {
let signing_key = SigningKey::from_slice(key_bytes)?;
let expiration = Utc::now()
.checked_add_signed(Duration::seconds(60))
.expect("valid timestamp")
.timestamp();
let claims = Claims {
iss: did.to_owned(),
sub: did.to_owned(),
aud: aud.to_owned(),
exp: expiration as usize,
iat: Utc::now().timestamp() as usize,
scope: None,
lxm: Some(lxm.to_string()),
jti: uuid::Uuid::new_v4().to_string(),
};
sign_claims(claims, &signing_key)
}
fn create_signed_token(did: &str, scope: &str, key_bytes: &[u8], duration: Duration) -> Result<String, anyhow::Error> {
let signing_key = SigningKey::from_slice(key_bytes)?;
let expiration = Utc::now()
.checked_add_signed(duration)
.expect("valid timestamp")
.timestamp();
let claims = Claims {
iss: did.to_owned(),
sub: did.to_owned(),
aud: format!("did:web:{}", std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())),
exp: expiration as usize,
iat: Utc::now().timestamp() as usize,
scope: Some(scope.to_string()),
lxm: None,
jti: uuid::Uuid::new_v4().to_string(),
};
sign_claims(claims, &signing_key)
}
fn sign_claims(claims: Claims, key: &SigningKey) -> Result<String, anyhow::Error> {
let header = Header {
alg: "ES256K".to_string(),
typ: "JWT".to_string(),
};
let header_json = serde_json::to_string(&header)?;
let claims_json = serde_json::to_string(&claims)?;
let header_b64 = URL_SAFE_NO_PAD.encode(header_json);
let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json);
let message = format!("{}.{}", header_b64, claims_b64);
let signature: Signature = key.sign(message.as_bytes());
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
Ok(format!("{}.{}", message, signature_b64))
}
pub fn verify_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>, anyhow::Error> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err(anyhow!("Invalid token format"));
}
let header_b64 = parts[0];
let claims_b64 = parts[1];
let signature_b64 = parts[2];
let signature_bytes = URL_SAFE_NO_PAD.decode(signature_b64)
.context("Base64 decode of signature failed")?;
let signature = Signature::from_slice(&signature_bytes)
.map_err(|e| anyhow!("Invalid signature format: {}", e))?;
let signing_key = SigningKey::from_slice(key_bytes)?;
let verifying_key = VerifyingKey::from(&signing_key);
let message = format!("{}.{}", header_b64, claims_b64);
verifying_key.verify(message.as_bytes(), &signature)
.map_err(|e| anyhow!("Signature verification failed: {}", e))?;
let claims_bytes = URL_SAFE_NO_PAD.decode(claims_b64)
.context("Base64 decode of claims failed")?;
let claims: Claims = serde_json::from_slice(&claims_bytes)
.context("JSON decode of claims failed")?;
let now = Utc::now().timestamp() as usize;
if claims.exp < now {
return Err(anyhow!("Token expired"));
}
Ok(TokenData { claims })
}

38
src/auth/mod.rs Normal file
View File

@@ -0,0 +1,38 @@
use serde::{Deserialize, Serialize};
pub mod token;
pub mod verify;
pub use token::{create_access_token, create_refresh_token, create_service_token};
pub use verify::{get_did_from_token, verify_token};
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
pub iss: String,
pub sub: String,
pub aud: String,
pub exp: usize,
pub iat: usize,
#[serde(skip_serializing_if = "Option::is_none")]
pub scope: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub lxm: Option<String>,
pub jti: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Header {
pub alg: String,
pub typ: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct UnsafeClaims {
pub iss: String,
pub sub: Option<String>,
}
// fancy boy TokenData equivalent for compatibility/structure
pub struct TokenData<T> {
pub claims: T,
}

86
src/auth/token.rs Normal file
View File

@@ -0,0 +1,86 @@
use super::{Claims, Header};
use anyhow::Result;
use base64::Engine as _;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use chrono::{Duration, Utc};
use k256::ecdsa::{Signature, SigningKey, signature::Signer};
use uuid;
pub fn create_access_token(did: &str, key_bytes: &[u8]) -> Result<String> {
create_signed_token(did, "access", key_bytes, Duration::minutes(15))
}
pub fn create_refresh_token(did: &str, key_bytes: &[u8]) -> Result<String> {
create_signed_token(did, "refresh", key_bytes, Duration::days(7))
}
pub fn create_service_token(did: &str, aud: &str, lxm: &str, key_bytes: &[u8]) -> Result<String> {
let signing_key = SigningKey::from_slice(key_bytes)?;
let expiration = Utc::now()
.checked_add_signed(Duration::seconds(60))
.expect("valid timestamp")
.timestamp();
let claims = Claims {
iss: did.to_owned(),
sub: did.to_owned(),
aud: aud.to_owned(),
exp: expiration as usize,
iat: Utc::now().timestamp() as usize,
scope: None,
lxm: Some(lxm.to_string()),
jti: uuid::Uuid::new_v4().to_string(),
};
sign_claims(claims, &signing_key)
}
fn create_signed_token(
did: &str,
scope: &str,
key_bytes: &[u8],
duration: Duration,
) -> Result<String> {
let signing_key = SigningKey::from_slice(key_bytes)?;
let expiration = Utc::now()
.checked_add_signed(duration)
.expect("valid timestamp")
.timestamp();
let claims = Claims {
iss: did.to_owned(),
sub: did.to_owned(),
aud: format!(
"did:web:{}",
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
),
exp: expiration as usize,
iat: Utc::now().timestamp() as usize,
scope: Some(scope.to_string()),
lxm: None,
jti: uuid::Uuid::new_v4().to_string(),
};
sign_claims(claims, &signing_key)
}
fn sign_claims(claims: Claims, key: &SigningKey) -> Result<String> {
let header = Header {
alg: "ES256K".to_string(),
typ: "JWT".to_string(),
};
let header_json = serde_json::to_string(&header)?;
let claims_json = serde_json::to_string(&claims)?;
let header_b64 = URL_SAFE_NO_PAD.encode(header_json);
let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json);
let message = format!("{}.{}", header_b64, claims_b64);
let signature: Signature = key.sign(message.as_bytes());
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
Ok(format!("{}.{}", message, signature_b64))
}

60
src/auth/verify.rs Normal file
View File

@@ -0,0 +1,60 @@
use super::{Claims, TokenData, UnsafeClaims};
use anyhow::{Context, Result, anyhow};
use base64::Engine as _;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use chrono::Utc;
use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier};
pub fn get_did_from_token(token: &str) -> Result<String, String> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err("Invalid token format".to_string());
}
let payload_bytes = URL_SAFE_NO_PAD
.decode(parts[1])
.map_err(|e| format!("Base64 decode failed: {}", e))?;
let claims: UnsafeClaims =
serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
Ok(claims.sub.unwrap_or(claims.iss))
}
pub fn verify_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err(anyhow!("Invalid token format"));
}
let header_b64 = parts[0];
let claims_b64 = parts[1];
let signature_b64 = parts[2];
let signature_bytes = URL_SAFE_NO_PAD
.decode(signature_b64)
.context("Base64 decode of signature failed")?;
let signature = Signature::from_slice(&signature_bytes)
.map_err(|e| anyhow!("Invalid signature format: {}", e))?;
let signing_key = SigningKey::from_slice(key_bytes)?;
let verifying_key = VerifyingKey::from(&signing_key);
let message = format!("{}.{}", header_b64, claims_b64);
verifying_key
.verify(message.as_bytes(), &signature)
.map_err(|e| anyhow!("Signature verification failed: {}", e))?;
let claims_bytes = URL_SAFE_NO_PAD
.decode(claims_b64)
.context("Base64 decode of claims failed")?;
let claims: Claims =
serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?;
let now = Utc::now().timestamp() as usize;
if claims.exp < now {
return Err(anyhow!("Token expired"));
}
Ok(TokenData { claims })
}

View File

@@ -1,31 +1,70 @@
pub mod api;
pub mod state;
pub mod auth;
pub mod repo;
pub mod state;
pub mod storage;
use axum::{
routing::{get, post, any},
Router,
routing::{any, get, post},
};
use state::AppState;
pub fn app(state: AppState) -> Router {
Router::new()
.route("/health", get(api::server::health))
.route("/xrpc/com.atproto.server.describeServer", get(api::server::describe_server))
.route("/xrpc/com.atproto.server.createAccount", post(api::identity::create_account))
.route("/xrpc/com.atproto.server.createSession", post(api::server::create_session))
.route("/xrpc/com.atproto.server.getSession", get(api::server::get_session))
.route("/xrpc/com.atproto.server.deleteSession", post(api::server::delete_session))
.route("/xrpc/com.atproto.server.refreshSession", post(api::server::refresh_session))
.route("/xrpc/com.atproto.repo.createRecord", post(api::repo::create_record))
.route("/xrpc/com.atproto.repo.putRecord", post(api::repo::put_record))
.route("/xrpc/com.atproto.repo.getRecord", get(api::repo::get_record))
.route("/xrpc/com.atproto.repo.deleteRecord", post(api::repo::delete_record))
.route("/xrpc/com.atproto.repo.listRecords", get(api::repo::list_records))
.route("/xrpc/com.atproto.repo.describeRepo", get(api::repo::describe_repo))
.route("/xrpc/com.atproto.repo.uploadBlob", post(api::repo::upload_blob))
.route(
"/xrpc/com.atproto.server.describeServer",
get(api::server::describe_server),
)
.route(
"/xrpc/com.atproto.server.createAccount",
post(api::identity::create_account),
)
.route(
"/xrpc/com.atproto.server.createSession",
post(api::server::create_session),
)
.route(
"/xrpc/com.atproto.server.getSession",
get(api::server::get_session),
)
.route(
"/xrpc/com.atproto.server.deleteSession",
post(api::server::delete_session),
)
.route(
"/xrpc/com.atproto.server.refreshSession",
post(api::server::refresh_session),
)
.route(
"/xrpc/com.atproto.repo.createRecord",
post(api::repo::create_record),
)
.route(
"/xrpc/com.atproto.repo.putRecord",
post(api::repo::put_record),
)
.route(
"/xrpc/com.atproto.repo.getRecord",
get(api::repo::get_record),
)
.route(
"/xrpc/com.atproto.repo.deleteRecord",
post(api::repo::delete_record),
)
.route(
"/xrpc/com.atproto.repo.listRecords",
get(api::repo::list_records),
)
.route(
"/xrpc/com.atproto.repo.describeRepo",
get(api::repo::describe_repo),
)
.route(
"/xrpc/com.atproto.repo.uploadBlob",
post(api::repo::upload_blob),
)
.route("/.well-known/did.json", get(api::identity::well_known_did))
.route("/u/{handle}/did.json", get(api::identity::user_did_doc))
.route("/xrpc/{*method}", any(api::proxy::proxy_handler))

View File

@@ -1,5 +1,5 @@
use std::net::SocketAddr;
use bspds::state::AppState;
use std::net::SocketAddr;
use tracing::info;
#[tokio::main]

View File

@@ -1,11 +1,11 @@
use jacquard_repo::storage::BlockStore;
use bytes::Bytes;
use cid::Cid;
use jacquard_repo::error::RepoError;
use jacquard_repo::repo::CommitData;
use cid::Cid;
use sqlx::{PgPool, Row};
use bytes::Bytes;
use sha2::{Sha256, Digest};
use jacquard_repo::storage::BlockStore;
use multihash::Multihash;
use sha2::{Digest, Sha256};
use sqlx::{PgPool, Row};
#[derive(Clone)]
pub struct PostgresBlockStore {
@@ -31,7 +31,7 @@ impl BlockStore for PostgresBlockStore {
Some(row) => {
let data: Vec<u8> = row.get("data");
Ok(Some(Bytes::from(data)))
},
}
None => Ok(None),
}
}
@@ -65,16 +65,21 @@ impl BlockStore for PostgresBlockStore {
Ok(row.is_some())
}
async fn put_many(&self, blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send) -> Result<(), RepoError> {
async fn put_many(
&self,
blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send,
) -> Result<(), RepoError> {
let blocks: Vec<_> = blocks.into_iter().collect();
for (cid, data) in blocks {
let cid_bytes = cid.to_bytes();
sqlx::query("INSERT INTO blocks (cid, data) VALUES ($1, $2) ON CONFLICT (cid) DO NOTHING")
.bind(cid_bytes)
.bind(data.as_ref())
.execute(&self.pool)
.await
.map_err(|e| RepoError::storage(e))?;
sqlx::query(
"INSERT INTO blocks (cid, data) VALUES ($1, $2) ON CONFLICT (cid) DO NOTHING",
)
.bind(cid_bytes)
.bind(data.as_ref())
.execute(&self.pool)
.await
.map_err(|e| RepoError::storage(e))?;
}
Ok(())
}

View File

@@ -1,6 +1,6 @@
use sqlx::PgPool;
use crate::repo::PostgresBlockStore;
use crate::storage::{BlobStorage, S3BlobStorage};
use sqlx::PgPool;
use std::sync::Arc;
#[derive(Clone)]
@@ -14,6 +14,10 @@ impl AppState {
pub async fn new(db: PgPool) -> Self {
let block_store = PostgresBlockStore::new(db.clone());
let blob_store = S3BlobStorage::new().await;
Self { db, block_store, blob_store: Arc::new(blob_store) }
Self {
db,
block_store,
blob_store: Arc::new(blob_store),
}
}
}

View File

@@ -1,9 +1,9 @@
use async_trait::async_trait;
use thiserror::Error;
use aws_config::BehaviorVersion;
use aws_config::meta::region::RegionProviderChain;
use aws_sdk_s3::Client;
use aws_sdk_s3::primitives::ByteStream;
use aws_config::meta::region::RegionProviderChain;
use aws_config::BehaviorVersion;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum StorageError {
@@ -55,7 +55,8 @@ impl S3BlobStorage {
#[async_trait]
impl BlobStorage for S3BlobStorage {
async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> {
self.client.put_object()
self.client
.put_object()
.bucket(&self.bucket)
.key(key)
.body(ByteStream::from(data.to_vec()))
@@ -66,14 +67,19 @@ impl BlobStorage for S3BlobStorage {
}
async fn get(&self, key: &str) -> Result<Vec<u8>, StorageError> {
let resp = self.client.get_object()
let resp = self
.client
.get_object()
.bucket(&self.bucket)
.key(key)
.send()
.await
.map_err(|e| StorageError::S3(e.to_string()))?;
let data = resp.body.collect().await
let data = resp
.body
.collect()
.await
.map_err(|e| StorageError::S3(e.to_string()))?
.into_bytes();
@@ -81,7 +87,8 @@ impl BlobStorage for S3BlobStorage {
}
async fn delete(&self, key: &str) -> Result<(), StorageError> {
self.client.delete_object()
self.client
.delete_object()
.bucket(&self.bucket)
.key(key)
.send()

View File

@@ -1,10 +1,10 @@
use bspds::auth;
use k256::SecretKey;
use rand::rngs::OsRng;
use chrono::{Utc, Duration};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use serde_json::json;
use bspds::auth;
use chrono::{Duration, Utc};
use k256::SecretKey;
use k256::ecdsa::{SigningKey, signature::Signer};
use rand::rngs::OsRng;
use serde_json::json;
#[test]
fn test_jwt_flow() {
@@ -24,7 +24,8 @@ fn test_jwt_flow() {
let aud = "did:web:service";
let lxm = "com.example.test";
let s_token = auth::create_service_token(did, aud, lxm, &key_bytes).expect("create service token");
let s_token =
auth::create_service_token(did, aud, lxm, &key_bytes).expect("create service token");
let s_data = auth::verify_token(&s_token, &key_bytes).expect("verify service token");
assert_eq!(s_data.claims.aud, aud);
assert_eq!(s_data.claims.lxm, Some(lxm.to_string()));

View File

@@ -1,22 +1,22 @@
use reqwest::{header, Client, StatusCode};
use serde_json::{json, Value};
use aws_config::BehaviorVersion;
use aws_sdk_s3::Client as S3Client;
use aws_sdk_s3::config::Credentials;
use bspds::state::AppState;
use chrono::Utc;
use reqwest::{Client, StatusCode, header};
use serde_json::{Value, json};
use sqlx::postgres::PgPoolOptions;
#[allow(unused_imports)]
use std::collections::HashMap;
use std::sync::OnceLock;
#[allow(unused_imports)]
use std::time::Duration;
use std::sync::OnceLock;
use bspds::state::AppState;
use sqlx::postgres::PgPoolOptions;
use tokio::net::TcpListener;
use testcontainers::{runners::AsyncRunner, ContainerAsync, ImageExt, GenericImage};
use testcontainers::core::ContainerPort;
use testcontainers::{ContainerAsync, GenericImage, ImageExt, runners::AsyncRunner};
use testcontainers_modules::postgres::Postgres;
use aws_sdk_s3::Client as S3Client;
use aws_config::BehaviorVersion;
use aws_sdk_s3::config::Credentials;
use wiremock::{MockServer, Mock, ResponseTemplate};
use tokio::net::TcpListener;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
static SERVER_URL: OnceLock<String> = OnceLock::new();
static DB_CONTAINER: OnceLock<ContainerAsync<Postgres>> = OnceLock::new();
@@ -46,7 +46,12 @@ pub async fn base_url() -> &'static str {
if let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR") {
let podman_sock = std::path::Path::new(&runtime_dir).join("podman/podman.sock");
if podman_sock.exists() {
unsafe { std::env::set_var("DOCKER_HOST", format!("unix://{}", podman_sock.display())); }
unsafe {
std::env::set_var(
"DOCKER_HOST",
format!("unix://{}", podman_sock.display()),
);
}
}
}
}
@@ -62,7 +67,10 @@ pub async fn base_url() -> &'static str {
.await
.expect("Failed to start MinIO");
let s3_port = s3_container.get_host_port_ipv4(9000).await.expect("Failed to get S3 port");
let s3_port = s3_container
.get_host_port_ipv4(9000)
.await
.expect("Failed to get S3 port");
let s3_endpoint = format!("http://127.0.0.1:{}", s3_port);
unsafe {
@@ -76,7 +84,13 @@ pub async fn base_url() -> &'static str {
let sdk_config = aws_config::defaults(BehaviorVersion::latest())
.region("us-east-1")
.endpoint_url(&s3_endpoint)
.credentials_provider(Credentials::new("minioadmin", "minioadmin", None, None, "test"))
.credentials_provider(Credentials::new(
"minioadmin",
"minioadmin",
None,
None,
"test",
))
.load()
.await;
@@ -108,15 +122,24 @@ pub async fn base_url() -> &'static str {
.mount(&mock_server)
.await;
unsafe { std::env::set_var("APPVIEW_URL", mock_server.uri()); }
unsafe {
std::env::set_var("APPVIEW_URL", mock_server.uri());
}
MOCK_APPVIEW.set(mock_server).ok();
S3_CONTAINER.set(s3_container).ok();
let container = Postgres::default().with_tag("18-alpine").start().await.expect("Failed to start Postgres");
let container = Postgres::default()
.with_tag("18-alpine")
.start()
.await
.expect("Failed to start Postgres");
let connection_string = format!(
"postgres://postgres:postgres@127.0.0.1:{}/postgres",
container.get_host_port_ipv4(5432).await.expect("Failed to get port")
container
.get_host_port_ipv4(5432)
.await
.expect("Failed to get port")
);
DB_CONTAINER.set(container).ok();
@@ -157,7 +180,11 @@ async fn spawn_app(database_url: String) -> String {
#[allow(dead_code)]
pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'static str) -> Value {
let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.repo.uploadBlob",
base_url().await
))
.header(header::CONTENT_TYPE, mime)
.bearer_auth(AUTH_TOKEN)
.body(data)
@@ -170,12 +197,11 @@ pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'stati
body["blob"].clone()
}
#[allow(dead_code)]
pub async fn create_test_post(
client: &Client,
text: &str,
reply_to: Option<Value>
reply_to: Option<Value>,
) -> (String, String, String) {
let collection = "app.bsky.feed.post";
let mut record = json!({
@@ -194,7 +220,11 @@ pub async fn create_test_post(
"record": record
});
let res = client.post(format!("{}/xrpc/com.atproto.repo.createRecord", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.repo.createRecord",
base_url().await
))
.bearer_auth(AUTH_TOKEN)
.json(&payload)
.send()
@@ -202,11 +232,24 @@ pub async fn create_test_post(
.expect("Failed to send createRecord");
assert_eq!(res.status(), StatusCode::OK, "Failed to create post record");
let body: Value = res.json().await.expect("createRecord response was not JSON");
let body: Value = res
.json()
.await
.expect("createRecord response was not JSON");
let uri = body["uri"].as_str().expect("Response had no URI").to_string();
let cid = body["cid"].as_str().expect("Response had no CID").to_string();
let rkey = uri.split('/').last().expect("URI was malformed").to_string();
let uri = body["uri"]
.as_str()
.expect("Response had no URI")
.to_string();
let cid = body["cid"]
.as_str()
.expect("Response had no CID")
.to_string();
let rkey = uri
.split('/')
.last()
.expect("URI was malformed")
.to_string();
(uri, cid, rkey)
}
@@ -220,7 +263,11 @@ pub async fn create_account_and_login(client: &Client) -> (String, String) {
"password": "password"
});
let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.server.createAccount",
base_url().await
))
.json(&payload)
.send()
.await
@@ -231,7 +278,10 @@ pub async fn create_account_and_login(client: &Client) -> (String, String) {
}
let body: Value = res.json().await.expect("Invalid JSON");
let access_jwt = body["accessJwt"].as_str().expect("No accessJwt").to_string();
let access_jwt = body["accessJwt"]
.as_str()
.expect("No accessJwt")
.to_string();
let did = body["did"].as_str().expect("No did").to_string();
(access_jwt, did)
}

View File

@@ -1,9 +1,9 @@
mod common;
use common::*;
use reqwest::StatusCode;
use serde_json::{json, Value};
use wiremock::{MockServer, Mock, ResponseTemplate};
use serde_json::{Value, json};
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
// #[tokio::test]
// async fn test_resolve_handle() {
@@ -23,7 +23,8 @@ use wiremock::matchers::{method, path};
#[tokio::test]
async fn test_well_known_did() {
let client = client();
let res = client.get(format!("{}/.well-known/did.json", base_url().await))
let res = client
.get(format!("{}/.well-known/did.json", base_url().await))
.send()
.await
.expect("Failed to send request");
@@ -71,7 +72,11 @@ async fn test_create_did_web_account_and_resolve() {
"did": did
});
let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.server.createAccount",
base_url().await
))
.json(&payload)
.send()
.await
@@ -79,13 +84,20 @@ async fn test_create_did_web_account_and_resolve() {
if res.status() != StatusCode::OK {
let status = res.status();
let body: Value = res.json().await.unwrap_or(json!({"error": "could not parse body"}));
let body: Value = res
.json()
.await
.unwrap_or(json!({"error": "could not parse body"}));
panic!("createAccount failed with status {}: {:?}", status, body);
}
let body: Value = res.json().await.expect("createAccount response was not JSON");
let body: Value = res
.json()
.await
.expect("createAccount response was not JSON");
assert_eq!(body["did"], did);
let res = client.get(format!("{}/u/{}/did.json", base_url().await, handle))
let res = client
.get(format!("{}/u/{}/did.json", base_url().await, handle))
.send()
.await
.expect("Failed to fetch DID doc");
@@ -111,14 +123,22 @@ async fn test_create_account_duplicate_handle() {
"password": "password"
});
let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.server.createAccount",
base_url().await
))
.json(&payload)
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::OK);
let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.server.createAccount",
base_url().await
))
.json(&payload)
.send()
.await
@@ -143,7 +163,11 @@ async fn test_did_web_lifecycle() {
"did": did
});
let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.server.createAccount",
base_url().await
))
.json(&create_payload)
.send()
.await
@@ -162,7 +186,11 @@ async fn test_did_web_lifecycle() {
"identifier": handle,
"password": "password"
});
let res = client.post(format!("{}/xrpc/com.atproto.server.createSession", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.server.createSession",
base_url().await
))
.json(&login_payload)
.send()
.await

View File

@@ -1,10 +1,9 @@
mod common;
use common::*;
use reqwest::{Client, StatusCode};
use serde_json::{json, Value};
use chrono::Utc;
#[allow(unused_imports)]
use reqwest;
use serde_json::{Value, json};
use std::time::Duration;
async fn setup_new_user(handle_prefix: &str) -> (String, String) {
@@ -19,20 +18,36 @@ async fn setup_new_user(handle_prefix: &str) -> (String, String) {
"email": email,
"password": password
});
let create_res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
let create_res = client
.post(format!(
"{}/xrpc/com.atproto.server.createAccount",
base_url().await
))
.json(&create_account_payload)
.send()
.await
.expect("setup_new_user: Failed to send createAccount");
if create_res.status() != StatusCode::OK {
panic!("setup_new_user: Failed to create account: {:?}", create_res.text().await);
if create_res.status() != reqwest::StatusCode::OK {
panic!(
"setup_new_user: Failed to create account: {:?}",
create_res.text().await
);
}
let create_body: Value = create_res.json().await.expect("setup_new_user: createAccount response was not JSON");
let create_body: Value = create_res
.json()
.await
.expect("setup_new_user: createAccount response was not JSON");
let new_did = create_body["did"].as_str().expect("setup_new_user: Response had no DID").to_string();
let new_jwt = create_body["accessJwt"].as_str().expect("setup_new_user: Response had no accessJwt").to_string();
let new_did = create_body["did"]
.as_str()
.expect("setup_new_user: Response had no DID")
.to_string();
let new_jwt = create_body["accessJwt"]
.as_str()
.expect("setup_new_user: Response had no accessJwt")
.to_string();
(new_did, new_jwt)
}
@@ -59,35 +74,59 @@ async fn test_post_crud_lifecycle() {
}
});
let create_res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
let create_res = client
.post(format!(
"{}/xrpc/com.atproto.repo.putRecord",
base_url().await
))
.bearer_auth(&jwt)
.json(&create_payload)
.send()
.await
.expect("Failed to send create request");
assert_eq!(create_res.status(), StatusCode::OK, "Failed to create record");
let create_body: Value = create_res.json().await.expect("create response was not JSON");
let uri = create_body["uri"].as_str().unwrap();
if create_res.status() != reqwest::StatusCode::OK {
let status = create_res.status();
let body = create_res
.text()
.await
.unwrap_or_else(|_| "Could not get body".to_string());
panic!(
"Failed to create record. Status: {}, Body: {}",
status, body
);
}
let create_body: Value = create_res
.json()
.await
.expect("create response was not JSON");
let uri = create_body["uri"].as_str().unwrap();
let params = [
("repo", did.as_str()),
("collection", collection),
("rkey", &rkey),
];
let get_res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
let get_res = client
.get(format!(
"{}/xrpc/com.atproto.repo.getRecord",
base_url().await
))
.query(&params)
.send()
.await
.expect("Failed to send get request");
assert_eq!(get_res.status(), StatusCode::OK, "Failed to get record after create");
assert_eq!(
get_res.status(),
reqwest::StatusCode::OK,
"Failed to get record after create"
);
let get_body: Value = get_res.json().await.expect("get response was not JSON");
assert_eq!(get_body["uri"], uri);
assert_eq!(get_body["value"]["text"], original_text);
let updated_text = "This post has been updated.";
let update_payload = json!({
"repo": did,
@@ -100,26 +139,46 @@ async fn test_post_crud_lifecycle() {
}
});
let update_res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
let update_res = client
.post(format!(
"{}/xrpc/com.atproto.repo.putRecord",
base_url().await
))
.bearer_auth(&jwt)
.json(&update_payload)
.send()
.await
.expect("Failed to send update request");
assert_eq!(update_res.status(), StatusCode::OK, "Failed to update record");
assert_eq!(
update_res.status(),
reqwest::StatusCode::OK,
"Failed to update record"
);
let get_updated_res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
let get_updated_res = client
.get(format!(
"{}/xrpc/com.atproto.repo.getRecord",
base_url().await
))
.query(&params)
.send()
.await
.expect("Failed to send get-after-update request");
assert_eq!(get_updated_res.status(), StatusCode::OK, "Failed to get record after update");
let get_updated_body: Value = get_updated_res.json().await.expect("get-updated response was not JSON");
assert_eq!(get_updated_body["value"]["text"], updated_text, "Text was not updated");
assert_eq!(
get_updated_res.status(),
reqwest::StatusCode::OK,
"Failed to get record after update"
);
let get_updated_body: Value = get_updated_res
.json()
.await
.expect("get-updated response was not JSON");
assert_eq!(
get_updated_body["value"]["text"], updated_text,
"Text was not updated"
);
let delete_payload = json!({
"repo": did,
@@ -127,23 +186,38 @@ async fn test_post_crud_lifecycle() {
"rkey": rkey
});
let delete_res = client.post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await))
let delete_res = client
.post(format!(
"{}/xrpc/com.atproto.repo.deleteRecord",
base_url().await
))
.bearer_auth(&jwt)
.json(&delete_payload)
.send()
.await
.expect("Failed to send delete request");
assert_eq!(delete_res.status(), StatusCode::OK, "Failed to delete record");
assert_eq!(
delete_res.status(),
reqwest::StatusCode::OK,
"Failed to delete record"
);
let get_deleted_res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
let get_deleted_res = client
.get(format!(
"{}/xrpc/com.atproto.repo.getRecord",
base_url().await
))
.query(&params)
.send()
.await
.expect("Failed to send get-after-delete request");
assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Record was found, but it should be deleted");
assert_eq!(
get_deleted_res.status(),
reqwest::StatusCode::NOT_FOUND,
"Record was found, but it should be deleted"
);
}
#[tokio::test]
@@ -161,24 +235,39 @@ async fn test_record_update_conflict_lifecycle() {
"displayName": "Original Name"
}
});
let create_res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
let create_res = client
.post(format!(
"{}/xrpc/com.atproto.repo.putRecord",
base_url().await
))
.bearer_auth(&user_jwt)
.json(&profile_payload)
.send().await.expect("create profile failed");
.send()
.await
.expect("create profile failed");
if create_res.status() != StatusCode::OK {
if create_res.status() != reqwest::StatusCode::OK {
return;
}
let get_res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
let get_res = client
.get(format!(
"{}/xrpc/com.atproto.repo.getRecord",
base_url().await
))
.query(&[
("repo", &user_did),
("collection", &"app.bsky.actor.profile".to_string()),
("rkey", &"self".to_string()),
])
.send().await.expect("getRecord failed");
.send()
.await
.expect("getRecord failed");
let get_body: Value = get_res.json().await.expect("getRecord not json");
let cid_v1 = get_body["cid"].as_str().expect("Profile v1 had no CID").to_string();
let cid_v1 = get_body["cid"]
.as_str()
.expect("Profile v1 had no CID")
.to_string();
let update_payload_v2 = json!({
"repo": user_did,
@@ -190,13 +279,26 @@ async fn test_record_update_conflict_lifecycle() {
},
"swapCommit": cid_v1 // <-- Correctly point to v1
});
let update_res_v2 = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
let update_res_v2 = client
.post(format!(
"{}/xrpc/com.atproto.repo.putRecord",
base_url().await
))
.bearer_auth(&user_jwt)
.json(&update_payload_v2)
.send().await.expect("putRecord v2 failed");
assert_eq!(update_res_v2.status(), StatusCode::OK, "v2 update failed");
.send()
.await
.expect("putRecord v2 failed");
assert_eq!(
update_res_v2.status(),
reqwest::StatusCode::OK,
"v2 update failed"
);
let update_body_v2: Value = update_res_v2.json().await.expect("v2 body not json");
let cid_v2 = update_body_v2["cid"].as_str().expect("v2 response had no CID").to_string();
let cid_v2 = update_body_v2["cid"]
.as_str()
.expect("v2 response had no CID")
.to_string();
let update_payload_v3_stale = json!({
"repo": user_did,
@@ -208,14 +310,20 @@ async fn test_record_update_conflict_lifecycle() {
},
"swapCommit": cid_v1
});
let update_res_v3_stale = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
let update_res_v3_stale = client
.post(format!(
"{}/xrpc/com.atproto.repo.putRecord",
base_url().await
))
.bearer_auth(&user_jwt)
.json(&update_payload_v3_stale)
.send().await.expect("putRecord v3 (stale) failed");
.send()
.await
.expect("putRecord v3 (stale) failed");
assert_eq!(
update_res_v3_stale.status(),
StatusCode::CONFLICT,
reqwest::StatusCode::CONFLICT,
"Stale update did not cause a 409 Conflict"
);
@@ -229,10 +337,233 @@ async fn test_record_update_conflict_lifecycle() {
},
"swapCommit": cid_v2 // <-- Correct
});
let update_res_v3_good = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
let update_res_v3_good = client
.post(format!(
"{}/xrpc/com.atproto.repo.putRecord",
base_url().await
))
.bearer_auth(&user_jwt)
.json(&update_payload_v3_good)
.send().await.expect("putRecord v3 (good) failed");
.send()
.await
.expect("putRecord v3 (good) failed");
assert_eq!(update_res_v3_good.status(), StatusCode::OK, "v3 (good) update failed");
assert_eq!(
update_res_v3_good.status(),
reqwest::StatusCode::OK,
"v3 (good) update failed"
);
}
async fn create_post(
client: &reqwest::Client,
did: &str,
jwt: &str,
text: &str,
) -> (String, String) {
let collection = "app.bsky.feed.post";
let rkey = format!("e2e_social_{}", Utc::now().timestamp_millis());
let now = Utc::now().to_rfc3339();
let create_payload = json!({
"repo": did,
"collection": collection,
"rkey": rkey,
"record": {
"$type": collection,
"text": text,
"createdAt": now
}
});
let create_res = client
.post(format!(
"{}/xrpc/com.atproto.repo.putRecord",
base_url().await
))
.bearer_auth(jwt)
.json(&create_payload)
.send()
.await
.expect("Failed to send create post request");
assert_eq!(
create_res.status(),
reqwest::StatusCode::OK,
"Failed to create post record"
);
let create_body: Value = create_res
.json()
.await
.expect("create post response was not JSON");
let uri = create_body["uri"].as_str().unwrap().to_string();
let cid = create_body["cid"].as_str().unwrap().to_string();
(uri, cid)
}
async fn create_follow(
client: &reqwest::Client,
follower_did: &str,
follower_jwt: &str,
followee_did: &str,
) -> (String, String) {
let collection = "app.bsky.graph.follow";
let rkey = format!("e2e_follow_{}", Utc::now().timestamp_millis());
let now = Utc::now().to_rfc3339();
let create_payload = json!({
"repo": follower_did,
"collection": collection,
"rkey": rkey,
"record": {
"$type": collection,
"subject": followee_did,
"createdAt": now
}
});
let create_res = client
.post(format!(
"{}/xrpc/com.atproto.repo.putRecord",
base_url().await
))
.bearer_auth(follower_jwt)
.json(&create_payload)
.send()
.await
.expect("Failed to send create follow request");
assert_eq!(
create_res.status(),
reqwest::StatusCode::OK,
"Failed to create follow record"
);
let create_body: Value = create_res
.json()
.await
.expect("create follow response was not JSON");
let uri = create_body["uri"].as_str().unwrap().to_string();
let cid = create_body["cid"].as_str().unwrap().to_string();
(uri, cid)
}
#[tokio::test]
#[ignore]
async fn test_social_flow_lifecycle() {
let client = client();
let (alice_did, alice_jwt) = setup_new_user("alice-social").await;
let (bob_did, bob_jwt) = setup_new_user("bob-social").await;
let (post1_uri, _) = create_post(&client, &alice_did, &alice_jwt, "Alice's first post!").await;
create_follow(&client, &bob_did, &bob_jwt, &alice_did).await;
tokio::time::sleep(Duration::from_secs(1)).await;
let timeline_res_1 = client
.get(format!(
"{}/xrpc/app.bsky.feed.getTimeline",
base_url().await
))
.bearer_auth(&bob_jwt)
.send()
.await
.expect("Failed to get timeline (1)");
assert_eq!(
timeline_res_1.status(),
reqwest::StatusCode::OK,
"Failed to get timeline (1)"
);
let timeline_body_1: Value = timeline_res_1.json().await.expect("Timeline (1) not JSON");
let feed_1 = timeline_body_1["feed"].as_array().unwrap();
assert_eq!(feed_1.len(), 1, "Timeline should have 1 post");
assert_eq!(
feed_1[0]["post"]["uri"], post1_uri,
"Post URI mismatch in timeline (1)"
);
let (post2_uri, _) = create_post(
&client,
&alice_did,
&alice_jwt,
"Alice's second post, so exciting!",
)
.await;
tokio::time::sleep(Duration::from_secs(1)).await;
let timeline_res_2 = client
.get(format!(
"{}/xrpc/app.bsky.feed.getTimeline",
base_url().await
))
.bearer_auth(&bob_jwt)
.send()
.await
.expect("Failed to get timeline (2)");
assert_eq!(
timeline_res_2.status(),
reqwest::StatusCode::OK,
"Failed to get timeline (2)"
);
let timeline_body_2: Value = timeline_res_2.json().await.expect("Timeline (2) not JSON");
let feed_2 = timeline_body_2["feed"].as_array().unwrap();
assert_eq!(feed_2.len(), 2, "Timeline should have 2 posts");
assert_eq!(
feed_2[0]["post"]["uri"], post2_uri,
"Post 2 should be first"
);
assert_eq!(
feed_2[1]["post"]["uri"], post1_uri,
"Post 1 should be second"
);
let delete_payload = json!({
"repo": alice_did,
"collection": "app.bsky.feed.post",
"rkey": post1_uri.split('/').last().unwrap()
});
let delete_res = client
.post(format!(
"{}/xrpc/com.atproto.repo.deleteRecord",
base_url().await
))
.bearer_auth(&alice_jwt)
.json(&delete_payload)
.send()
.await
.expect("Failed to send delete request");
assert_eq!(
delete_res.status(),
reqwest::StatusCode::OK,
"Failed to delete record"
);
tokio::time::sleep(Duration::from_secs(1)).await;
let timeline_res_3 = client
.get(format!(
"{}/xrpc/app.bsky.feed.getTimeline",
base_url().await
))
.bearer_auth(&bob_jwt)
.send()
.await
.expect("Failed to get timeline (3)");
assert_eq!(
timeline_res_3.status(),
reqwest::StatusCode::OK,
"Failed to get timeline (3)"
);
let timeline_body_3: Value = timeline_res_3.json().await.expect("Timeline (3) not JSON");
let feed_3 = timeline_body_3["feed"].as_array().unwrap();
assert_eq!(feed_3.len(), 1, "Timeline should have 1 post after delete");
assert_eq!(
feed_3[0]["post"]["uri"], post2_uri,
"Only post 2 should remain"
);
}

View File

@@ -1,17 +1,15 @@
mod common;
use axum::{
routing::any,
Router,
extract::Request,
http::StatusCode,
};
use tokio::net::TcpListener;
use axum::{Router, extract::Request, http::StatusCode, routing::any};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use reqwest::Client;
use std::sync::Arc;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use tokio::net::TcpListener;
async fn spawn_mock_upstream() -> (String, tokio::sync::mpsc::Receiver<(String, String, Option<String>)>) {
async fn spawn_mock_upstream() -> (
String,
tokio::sync::mpsc::Receiver<(String, String, Option<String>)>,
) {
let (tx, rx) = tokio::sync::mpsc::channel(10);
let tx = Arc::new(tx);
@@ -20,7 +18,9 @@ async fn spawn_mock_upstream() -> (String, tokio::sync::mpsc::Receiver<(String,
async move {
let method = req.method().to_string();
let uri = req.uri().to_string();
let auth = req.headers().get("Authorization")
let auth = req
.headers()
.get("Authorization")
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string());
@@ -45,7 +45,8 @@ async fn test_proxy_via_header() {
let (upstream_url, mut rx) = spawn_mock_upstream().await;
let client = Client::new();
let res = client.get(format!("{}/xrpc/com.example.test", app_url))
let res = client
.get(format!("{}/xrpc/com.example.test", app_url))
.header("atproto-proxy", &upstream_url)
.header("Authorization", "Bearer test-token")
.send()
@@ -65,12 +66,15 @@ async fn test_proxy_via_header() {
async fn test_proxy_via_env_var() {
let (upstream_url, mut rx) = spawn_mock_upstream().await;
unsafe { std::env::set_var("APPVIEW_URL", &upstream_url); }
unsafe {
std::env::set_var("APPVIEW_URL", &upstream_url);
}
let app_url = common::base_url().await;
let client = Client::new();
let res = client.get(format!("{}/xrpc/com.example.envtest", app_url))
let res = client
.get(format!("{}/xrpc/com.example.envtest", app_url))
.send()
.await
.unwrap();
@@ -85,12 +89,15 @@ async fn test_proxy_via_env_var() {
#[tokio::test]
#[ignore]
async fn test_proxy_missing_config() {
unsafe { std::env::remove_var("APPVIEW_URL"); }
unsafe {
std::env::remove_var("APPVIEW_URL");
}
let app_url = common::base_url().await;
let client = Client::new();
let res = client.get(format!("{}/xrpc/com.example.fail", app_url))
let res = client
.get(format!("{}/xrpc/com.example.fail", app_url))
.send()
.await
.unwrap();
@@ -106,7 +113,8 @@ async fn test_proxy_auth_signing() {
let (access_jwt, did) = common::create_account_and_login(&client).await;
let res = client.get(format!("{}/xrpc/com.example.signed", app_url))
let res = client
.get(format!("{}/xrpc/com.example.signed", app_url))
.header("atproto-proxy", &upstream_url)
.header("Authorization", format!("Bearer {}", access_jwt))
.send()

View File

@@ -1,9 +1,9 @@
mod common;
use common::*;
use reqwest::{header, StatusCode};
use serde_json::{json, Value};
use chrono::Utc;
use reqwest::{StatusCode, header};
use serde_json::{Value, json};
#[tokio::test]
#[ignore]
@@ -15,7 +15,11 @@ async fn test_get_record() {
("rkey", "self"),
];
let res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
let res = client
.get(format!(
"{}/xrpc/com.atproto.repo.getRecord",
base_url().await
))
.query(&params)
.send()
.await
@@ -36,7 +40,11 @@ async fn test_get_record_not_found() {
("rkey", "nonexistent"),
];
let res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
let res = client
.get(format!(
"{}/xrpc/com.atproto.repo.getRecord",
base_url().await
))
.query(&params)
.send()
.await
@@ -50,7 +58,11 @@ async fn test_get_record_not_found() {
#[tokio::test]
async fn test_upload_blob_no_auth() {
let client = client();
let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.repo.uploadBlob",
base_url().await
))
.header(header::CONTENT_TYPE, "text/plain")
.body("no auth")
.send()
@@ -66,7 +78,11 @@ async fn test_upload_blob_no_auth() {
async fn test_upload_blob_success() {
let client = client();
let (token, _) = create_account_and_login(&client).await;
let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.repo.uploadBlob",
base_url().await
))
.header(header::CONTENT_TYPE, "text/plain")
.bearer_auth(token)
.body("This is our blob data")
@@ -90,7 +106,11 @@ async fn test_put_record_no_auth() {
"record": {}
});
let res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.repo.putRecord",
base_url().await
))
.json(&payload)
.send()
.await
@@ -118,7 +138,11 @@ async fn test_put_record_success() {
}
});
let res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.repo.putRecord",
base_url().await
))
.bearer_auth(token)
.json(&payload)
.send()
@@ -135,23 +159,33 @@ async fn test_put_record_success() {
#[ignore]
async fn test_get_record_missing_params() {
let client = client();
let params = [
("repo", "did:plc:12345"),
];
let params = [("repo", "did:plc:12345")];
let res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
let res = client
.get(format!(
"{}/xrpc/com.atproto.repo.getRecord",
base_url().await
))
.query(&params)
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Expected 400 for missing params");
assert_eq!(
res.status(),
StatusCode::BAD_REQUEST,
"Expected 400 for missing params"
);
}
#[tokio::test]
async fn test_upload_blob_bad_token() {
let client = client();
let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.repo.uploadBlob",
base_url().await
))
.header(header::CONTENT_TYPE, "text/plain")
.bearer_auth(BAD_AUTH_TOKEN)
.body("This is our blob data")
@@ -181,14 +215,22 @@ async fn test_put_record_mismatched_repo() {
}
});
let res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.repo.putRecord",
base_url().await
))
.bearer_auth(token)
.json(&payload)
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::FORBIDDEN, "Expected 403 for mismatched repo and auth");
assert_eq!(
res.status(),
StatusCode::FORBIDDEN,
"Expected 403 for mismatched repo and auth"
);
}
#[tokio::test]
@@ -207,21 +249,33 @@ async fn test_put_record_invalid_schema() {
}
});
let res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.repo.putRecord",
base_url().await
))
.bearer_auth(token)
.json(&payload)
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Expected 400 for invalid record schema");
assert_eq!(
res.status(),
StatusCode::BAD_REQUEST,
"Expected 400 for invalid record schema"
);
}
#[tokio::test]
async fn test_upload_blob_unsupported_mime_type() {
let client = client();
let (token, _) = create_account_and_login(&client).await;
let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.repo.uploadBlob",
base_url().await
))
.header(header::CONTENT_TYPE, "application/xml")
.bearer_auth(token)
.body("<xml>not an image</xml>")
@@ -242,7 +296,11 @@ async fn test_list_records() {
("collection", "app.bsky.feed.post"),
("limit", "10"),
];
let res = client.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
let res = client
.get(format!(
"{}/xrpc/com.atproto.repo.listRecords",
base_url().await
))
.query(&params)
.send()
.await
@@ -255,10 +313,12 @@ async fn test_list_records() {
async fn test_describe_repo() {
let client = client();
let (_, did) = create_account_and_login(&client).await;
let params = [
("repo", did.as_str()),
];
let res = client.get(format!("{}/xrpc/com.atproto.repo.describeRepo", base_url().await))
let params = [("repo", did.as_str())];
let res = client
.get(format!(
"{}/xrpc/com.atproto.repo.describeRepo",
base_url().await
))
.query(&params)
.send()
.await
@@ -282,7 +342,11 @@ async fn test_create_record_success_with_generated_rkey() {
}
});
let res = client.post(format!("{}/xrpc/com.atproto.repo.createRecord", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.repo.createRecord",
base_url().await
))
.json(&payload)
.bearer_auth(token)
.send()
@@ -313,7 +377,11 @@ async fn test_create_record_success_with_provided_rkey() {
}
});
let res = client.post(format!("{}/xrpc/com.atproto.repo.createRecord", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.repo.createRecord",
base_url().await
))
.json(&payload)
.bearer_auth(token)
.send()
@@ -322,7 +390,10 @@ async fn test_create_record_success_with_provided_rkey() {
assert_eq!(res.status(), StatusCode::OK);
let body: Value = res.json().await.expect("Response was not valid JSON");
assert_eq!(body["uri"], format!("at://{}/app.bsky.feed.post/{}", did, rkey));
assert_eq!(
body["uri"],
format!("at://{}/app.bsky.feed.post/{}", did, rkey)
);
// assert_eq!(body["cid"], "bafyreihy");
}
@@ -336,7 +407,11 @@ async fn test_delete_record() {
"collection": "app.bsky.feed.post",
"rkey": "some_post_to_delete"
});
let res = client.post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.repo.deleteRecord",
base_url().await
))
.bearer_auth(token)
.json(&payload)
.send()

View File

@@ -2,12 +2,13 @@ mod common;
use common::*;
use reqwest::StatusCode;
use serde_json::{json, Value};
use serde_json::{Value, json};
#[tokio::test]
async fn test_health() {
let client = client();
let res = client.get(format!("{}/health", base_url().await))
let res = client
.get(format!("{}/health", base_url().await))
.send()
.await
.expect("Failed to send request");
@@ -19,7 +20,11 @@ async fn test_health() {
#[tokio::test]
async fn test_describe_server() {
let client = client();
let res = client.get(format!("{}/xrpc/com.atproto.server.describeServer", base_url().await))
let res = client
.get(format!(
"{}/xrpc/com.atproto.server.describeServer",
base_url().await
))
.send()
.await
.expect("Failed to send request");
@@ -39,7 +44,11 @@ async fn test_create_session() {
"email": format!("{}@example.com", handle),
"password": "password"
});
let _ = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
let _ = client
.post(format!(
"{}/xrpc/com.atproto.server.createAccount",
base_url().await
))
.json(&payload)
.send()
.await;
@@ -49,7 +58,11 @@ async fn test_create_session() {
"password": "password"
});
let res = client.post(format!("{}/xrpc/com.atproto.server.createSession", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.server.createSession",
base_url().await
))
.json(&payload)
.send()
.await
@@ -67,14 +80,21 @@ async fn test_create_session_missing_identifier() {
"password": "password"
});
let res = client.post(format!("{}/xrpc/com.atproto.server.createSession", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.server.createSession",
base_url().await
))
.json(&payload)
.send()
.await
.expect("Failed to send request");
assert!(res.status() == StatusCode::BAD_REQUEST || res.status() == StatusCode::UNPROCESSABLE_ENTITY,
"Expected 400 or 422 for missing identifier, got {}", res.status());
assert!(
res.status() == StatusCode::BAD_REQUEST || res.status() == StatusCode::UNPROCESSABLE_ENTITY,
"Expected 400 or 422 for missing identifier, got {}",
res.status()
);
}
#[tokio::test]
@@ -86,19 +106,31 @@ async fn test_create_account_invalid_handle() {
"password": "password"
});
let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.server.createAccount",
base_url().await
))
.json(&payload)
.send()
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Expected 400 for invalid handle chars");
assert_eq!(
res.status(),
StatusCode::BAD_REQUEST,
"Expected 400 for invalid handle chars"
);
}
#[tokio::test]
async fn test_get_session() {
let client = client();
let res = client.get(format!("{}/xrpc/com.atproto.server.getSession", base_url().await))
let res = client
.get(format!(
"{}/xrpc/com.atproto.server.getSession",
base_url().await
))
.bearer_auth(AUTH_TOKEN)
.send()
.await
@@ -117,7 +149,11 @@ async fn test_refresh_session() {
"email": format!("{}@example.com", handle),
"password": "password"
});
let _ = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
let _ = client
.post(format!(
"{}/xrpc/com.atproto.server.createAccount",
base_url().await
))
.json(&payload)
.send()
.await;
@@ -126,7 +162,11 @@ async fn test_refresh_session() {
"identifier": handle,
"password": "password"
});
let res = client.post(format!("{}/xrpc/com.atproto.server.createSession", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.server.createSession",
base_url().await
))
.json(&login_payload)
.send()
.await
@@ -134,10 +174,20 @@ async fn test_refresh_session() {
assert_eq!(res.status(), StatusCode::OK);
let body: Value = res.json().await.expect("Invalid JSON");
let refresh_jwt = body["refreshJwt"].as_str().expect("No refreshJwt").to_string();
let access_jwt = body["accessJwt"].as_str().expect("No accessJwt").to_string();
let refresh_jwt = body["refreshJwt"]
.as_str()
.expect("No refreshJwt")
.to_string();
let access_jwt = body["accessJwt"]
.as_str()
.expect("No accessJwt")
.to_string();
let res = client.post(format!("{}/xrpc/com.atproto.server.refreshSession", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.server.refreshSession",
base_url().await
))
.bearer_auth(&refresh_jwt)
.send()
.await
@@ -154,7 +204,11 @@ async fn test_refresh_session() {
#[tokio::test]
async fn test_delete_session() {
let client = client();
let res = client.post(format!("{}/xrpc/com.atproto.server.deleteSession", base_url().await))
let res = client
.post(format!(
"{}/xrpc/com.atproto.server.deleteSession",
base_url().await
))
.bearer_auth(AUTH_TOKEN)
.send()
.await

View File

@@ -6,10 +6,12 @@ use reqwest::StatusCode;
#[ignore]
async fn test_get_repo() {
let client = client();
let params = [
("did", AUTH_DID),
];
let res = client.get(format!("{}/xrpc/com.atproto.sync.getRepo", base_url().await))
let params = [("did", AUTH_DID)];
let res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getRepo",
base_url().await
))
.query(&params)
.send()
.await
@@ -26,7 +28,11 @@ async fn test_get_blocks() {
("did", AUTH_DID),
// "cids" would be a list of CIDs
];
let res = client.get(format!("{}/xrpc/com.atproto.sync.getBlocks", base_url().await))
let res = client
.get(format!(
"{}/xrpc/com.atproto.sync.getBlocks",
base_url().await
))
.query(&params)
.send()
.await