mirror of
https://tangled.org/tranquil.farm/tranquil-pds
synced 2026-02-09 05:40:09 +00:00
Format and split big files into smaller ones
This commit is contained in:
@@ -1,424 +0,0 @@
|
||||
use axum::{
|
||||
extract::{State, Path},
|
||||
Json,
|
||||
response::{IntoResponse, Response},
|
||||
http::StatusCode,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use crate::state::AppState;
|
||||
use sqlx::Row;
|
||||
use bcrypt::{hash, DEFAULT_COST};
|
||||
use tracing::{info, error};
|
||||
use jacquard_repo::{mst::Mst, commit::Commit, storage::BlockStore};
|
||||
use jacquard::types::{string::Tid, did::Did, integer::LimitedU32};
|
||||
use std::sync::Arc;
|
||||
use k256::SecretKey;
|
||||
use rand::rngs::OsRng;
|
||||
use base64::Engine;
|
||||
use reqwest;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct CreateAccountInput {
|
||||
pub handle: String,
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
#[serde(rename = "inviteCode")]
|
||||
pub invite_code: Option<String>,
|
||||
pub did: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CreateAccountOutput {
|
||||
pub access_jwt: String,
|
||||
pub refresh_jwt: String,
|
||||
pub handle: String,
|
||||
pub did: String,
|
||||
}
|
||||
|
||||
pub async fn create_account(
|
||||
State(state): State<AppState>,
|
||||
Json(input): Json<CreateAccountInput>,
|
||||
) -> Response {
|
||||
info!("create_account hit: {}", input.handle);
|
||||
if input.handle.contains('!') || input.handle.contains('@') {
|
||||
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}))).into_response();
|
||||
}
|
||||
|
||||
let did = if let Some(d) = &input.did {
|
||||
if d.trim().is_empty() {
|
||||
format!("did:plc:{}", uuid::Uuid::new_v4())
|
||||
} else {
|
||||
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
|
||||
if let Err(e) = verify_did_web(d, &hostname, &input.handle).await {
|
||||
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidDid", "message": e}))).into_response();
|
||||
}
|
||||
d.clone()
|
||||
}
|
||||
} else {
|
||||
format!("did:plc:{}", uuid::Uuid::new_v4())
|
||||
};
|
||||
|
||||
let mut tx = match state.db.begin().await {
|
||||
Ok(tx) => tx,
|
||||
Err(e) => {
|
||||
error!("Error starting transaction: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let exists_query = sqlx::query("SELECT 1 FROM users WHERE handle = $1")
|
||||
.bind(&input.handle)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await;
|
||||
|
||||
match exists_query {
|
||||
Ok(Some(_)) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "HandleTaken", "message": "Handle already taken"}))).into_response(),
|
||||
Err(e) => {
|
||||
error!("Error checking handle: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
Ok(None) => {}
|
||||
}
|
||||
|
||||
if let Some(code) = &input.invite_code {
|
||||
let invite_query = sqlx::query("SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE")
|
||||
.bind(code)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await;
|
||||
|
||||
match invite_query {
|
||||
Ok(Some(row)) => {
|
||||
let uses: i32 = row.get("available_uses");
|
||||
if uses <= 0 {
|
||||
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response();
|
||||
}
|
||||
|
||||
let update_invite = sqlx::query("UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1")
|
||||
.bind(code)
|
||||
.execute(&mut *tx)
|
||||
.await;
|
||||
|
||||
if let Err(e) = update_invite {
|
||||
error!("Error updating invite code: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
},
|
||||
Ok(None) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"}))).into_response(),
|
||||
Err(e) => {
|
||||
error!("Error checking invite code: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let password_hash = match hash(&input.password, DEFAULT_COST) {
|
||||
Ok(h) => h,
|
||||
Err(e) => {
|
||||
error!("Error hashing password: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let user_insert = sqlx::query("INSERT INTO users (handle, email, did, password_hash) VALUES ($1, $2, $3, $4) RETURNING id")
|
||||
.bind(&input.handle)
|
||||
.bind(&input.email)
|
||||
.bind(&did)
|
||||
.bind(&password_hash)
|
||||
.fetch_one(&mut *tx)
|
||||
.await;
|
||||
|
||||
let user_id: uuid::Uuid = match user_insert {
|
||||
Ok(row) => row.get("id"),
|
||||
Err(e) => {
|
||||
error!("Error inserting user: {:?}", e);
|
||||
// TODO: Check for unique constraint violation on email/did specifically
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let secret_key = SecretKey::random(&mut OsRng);
|
||||
let secret_key_bytes = secret_key.to_bytes();
|
||||
|
||||
let key_insert = sqlx::query("INSERT INTO user_keys (user_id, key_bytes) VALUES ($1, $2)")
|
||||
.bind(user_id)
|
||||
.bind(&secret_key_bytes[..])
|
||||
.execute(&mut *tx)
|
||||
.await;
|
||||
|
||||
if let Err(e) = key_insert {
|
||||
error!("Error inserting user key: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
|
||||
let mst = Mst::new(Arc::new(state.block_store.clone()));
|
||||
let mst_root = match mst.root().await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Error creating MST root: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let did_obj = match Did::new(&did) {
|
||||
Ok(d) => d,
|
||||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(),
|
||||
};
|
||||
|
||||
let rev = Tid::now(LimitedU32::MIN);
|
||||
|
||||
let commit = Commit::new_unsigned(
|
||||
did_obj,
|
||||
mst_root,
|
||||
rev,
|
||||
None
|
||||
);
|
||||
|
||||
let commit_bytes = match commit.to_cbor() {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
error!("Error serializing genesis commit: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let commit_cid = match state.block_store.put(&commit_bytes).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Error saving genesis commit: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let repo_insert = sqlx::query("INSERT INTO repos (user_id, repo_root_cid) VALUES ($1, $2)")
|
||||
.bind(user_id)
|
||||
.bind(commit_cid.to_string())
|
||||
.execute(&mut *tx)
|
||||
.await;
|
||||
|
||||
if let Err(e) = repo_insert {
|
||||
error!("Error initializing repo: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
|
||||
if let Some(code) = &input.invite_code {
|
||||
let use_insert = sqlx::query("INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)")
|
||||
.bind(code)
|
||||
.bind(user_id)
|
||||
.execute(&mut *tx)
|
||||
.await;
|
||||
|
||||
if let Err(e) = use_insert {
|
||||
error!("Error recording invite usage: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
}
|
||||
|
||||
let access_jwt = crate::auth::create_access_token(&did, &secret_key_bytes[..]).map_err(|e| {
|
||||
error!("Error creating access token: {:?}", e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()
|
||||
});
|
||||
let access_jwt = match access_jwt {
|
||||
Ok(t) => t,
|
||||
Err(r) => return r,
|
||||
};
|
||||
|
||||
let refresh_jwt = crate::auth::create_refresh_token(&did, &secret_key_bytes[..]).map_err(|e| {
|
||||
error!("Error creating refresh token: {:?}", e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()
|
||||
});
|
||||
let refresh_jwt = match refresh_jwt {
|
||||
Ok(t) => t,
|
||||
Err(r) => return r,
|
||||
};
|
||||
|
||||
let session_insert = sqlx::query("INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)")
|
||||
.bind(&access_jwt)
|
||||
.bind(&refresh_jwt)
|
||||
.bind(&did)
|
||||
.execute(&mut *tx)
|
||||
.await;
|
||||
|
||||
if let Err(e) = session_insert {
|
||||
error!("Error inserting session: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
|
||||
if let Err(e) = tx.commit().await {
|
||||
error!("Error committing transaction: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
|
||||
(StatusCode::OK, Json(CreateAccountOutput {
|
||||
access_jwt,
|
||||
refresh_jwt,
|
||||
handle: input.handle,
|
||||
did,
|
||||
})).into_response()
|
||||
}
|
||||
|
||||
fn get_jwk(key_bytes: &[u8]) -> serde_json::Value {
|
||||
use k256::elliptic_curve::sec1::ToEncodedPoint;
|
||||
|
||||
let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length");
|
||||
let public_key = secret_key.public_key();
|
||||
let encoded = public_key.to_encoded_point(false);
|
||||
let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap());
|
||||
let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap());
|
||||
|
||||
json!({
|
||||
"kty": "EC",
|
||||
"crv": "secp256k1",
|
||||
"x": x,
|
||||
"y": y
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
|
||||
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
|
||||
// Kinda for local dev, encode hostname if it contains port
|
||||
let did = if hostname.contains(':') {
|
||||
format!("did:web:{}", hostname.replace(':', "%3A"))
|
||||
} else {
|
||||
format!("did:web:{}", hostname)
|
||||
};
|
||||
|
||||
Json(json!({
|
||||
"@context": ["https://www.w3.org/ns/did/v1"],
|
||||
"id": did,
|
||||
"service": [{
|
||||
"id": "#atproto_pds",
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"serviceEndpoint": format!("https://{}", hostname)
|
||||
}]
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn user_did_doc(
|
||||
State(state): State<AppState>,
|
||||
Path(handle): Path<String>,
|
||||
) -> Response {
|
||||
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
|
||||
|
||||
let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1")
|
||||
.bind(&handle)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let (user_id, did) = match user {
|
||||
Ok(Some(row)) => {
|
||||
let id: uuid::Uuid = row.get("id");
|
||||
let d: String = row.get("did");
|
||||
(id, d)
|
||||
},
|
||||
Ok(None) => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(),
|
||||
Err(e) => {
|
||||
error!("DB Error: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()
|
||||
},
|
||||
};
|
||||
|
||||
if !did.starts_with("did:web:") {
|
||||
return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "User is not did:web"}))).into_response();
|
||||
}
|
||||
|
||||
let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let key_bytes: Vec<u8> = match key_row {
|
||||
Ok(Some(row)) => row.get("key_bytes"),
|
||||
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(),
|
||||
};
|
||||
|
||||
let jwk = get_jwk(&key_bytes);
|
||||
|
||||
Json(json!({
|
||||
"@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
|
||||
"id": did,
|
||||
"alsoKnownAs": [format!("at://{}", handle)],
|
||||
"verificationMethod": [{
|
||||
"id": format!("{}#atproto", did),
|
||||
"type": "JsonWebKey2020",
|
||||
"controller": did,
|
||||
"publicKeyJwk": jwk
|
||||
}],
|
||||
"service": [{
|
||||
"id": "#atproto_pds",
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"serviceEndpoint": format!("https://{}", hostname)
|
||||
}]
|
||||
})).into_response()
|
||||
}
|
||||
|
||||
async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
|
||||
let expected_prefix = if hostname.contains(':') {
|
||||
format!("did:web:{}", hostname.replace(':', "%3A"))
|
||||
} else {
|
||||
format!("did:web:{}", hostname)
|
||||
};
|
||||
|
||||
if did.starts_with(&expected_prefix) {
|
||||
let suffix = &did[expected_prefix.len()..];
|
||||
let expected_suffix = format!(":u:{}", handle);
|
||||
if suffix == expected_suffix {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Invalid DID path for this PDS. Expected {}", expected_suffix))
|
||||
}
|
||||
} else {
|
||||
let parts: Vec<&str> = did.split(':').collect();
|
||||
if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
|
||||
return Err("Invalid did:web format".into());
|
||||
}
|
||||
|
||||
let domain_segment = parts[2];
|
||||
let domain = domain_segment.replace("%3A", ":");
|
||||
|
||||
let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
|
||||
"http"
|
||||
} else {
|
||||
"https"
|
||||
};
|
||||
|
||||
let url = if parts.len() == 3 {
|
||||
format!("{}://{}/.well-known/did.json", scheme, domain)
|
||||
} else {
|
||||
let path = parts[3..].join("/");
|
||||
format!("{}://{}/{}/did.json", scheme, domain, path)
|
||||
};
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(5))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create client: {}", e))?;
|
||||
|
||||
let resp = client.get(&url).send().await
|
||||
.map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
|
||||
}
|
||||
|
||||
let doc: serde_json::Value = resp.json().await
|
||||
.map_err(|e| format!("Failed to parse DID doc: {}", e))?;
|
||||
|
||||
let services = doc["service"].as_array()
|
||||
.ok_or("No services found in DID doc")?;
|
||||
|
||||
let pds_endpoint = format!("https://{}", hostname);
|
||||
|
||||
let has_valid_service = services.iter().any(|s| {
|
||||
s["type"] == "AtprotoPersonalDataServer" &&
|
||||
s["serviceEndpoint"] == pds_endpoint
|
||||
});
|
||||
|
||||
if has_valid_service {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("DID document does not list this PDS ({}) as AtprotoPersonalDataServer", pds_endpoint))
|
||||
}
|
||||
}
|
||||
}
|
||||
355
src/api/identity/account.rs
Normal file
355
src/api/identity/account.rs
Normal file
@@ -0,0 +1,355 @@
|
||||
use super::did::verify_did_web;
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use bcrypt::{DEFAULT_COST, hash};
|
||||
use jacquard::types::{did::Did, integer::LimitedU32, string::Tid};
|
||||
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
|
||||
use k256::SecretKey;
|
||||
use rand::rngs::OsRng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use sqlx::Row;
|
||||
use std::sync::Arc;
|
||||
use tracing::{error, info};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct CreateAccountInput {
|
||||
pub handle: String,
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
#[serde(rename = "inviteCode")]
|
||||
pub invite_code: Option<String>,
|
||||
pub did: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CreateAccountOutput {
|
||||
pub access_jwt: String,
|
||||
pub refresh_jwt: String,
|
||||
pub handle: String,
|
||||
pub did: String,
|
||||
}
|
||||
|
||||
pub async fn create_account(
|
||||
State(state): State<AppState>,
|
||||
Json(input): Json<CreateAccountInput>,
|
||||
) -> Response {
|
||||
info!("create_account hit: {}", input.handle);
|
||||
if input.handle.contains('!') || input.handle.contains('@') {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}),
|
||||
),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let did = if let Some(d) = &input.did {
|
||||
if d.trim().is_empty() {
|
||||
format!("did:plc:{}", uuid::Uuid::new_v4())
|
||||
} else {
|
||||
let hostname =
|
||||
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
|
||||
if let Err(e) = verify_did_web(d, &hostname, &input.handle).await {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": "InvalidDid", "message": e})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
d.clone()
|
||||
}
|
||||
} else {
|
||||
format!("did:plc:{}", uuid::Uuid::new_v4())
|
||||
};
|
||||
|
||||
let mut tx = match state.db.begin().await {
|
||||
Ok(tx) => tx,
|
||||
Err(e) => {
|
||||
error!("Error starting transaction: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let exists_query = sqlx::query("SELECT 1 FROM users WHERE handle = $1")
|
||||
.bind(&input.handle)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await;
|
||||
|
||||
match exists_query {
|
||||
Ok(Some(_)) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": "HandleTaken", "message": "Handle already taken"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error checking handle: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Ok(None) => {}
|
||||
}
|
||||
|
||||
if let Some(code) = &input.invite_code {
|
||||
let invite_query =
|
||||
sqlx::query("SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE")
|
||||
.bind(code)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await;
|
||||
|
||||
match invite_query {
|
||||
Ok(Some(row)) => {
|
||||
let uses: i32 = row.get("available_uses");
|
||||
if uses <= 0 {
|
||||
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response();
|
||||
}
|
||||
|
||||
let update_invite = sqlx::query(
|
||||
"UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1",
|
||||
)
|
||||
.bind(code)
|
||||
.execute(&mut *tx)
|
||||
.await;
|
||||
|
||||
if let Err(e) = update_invite {
|
||||
error!("Error updating invite code: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error checking invite code: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let password_hash = match hash(&input.password, DEFAULT_COST) {
|
||||
Ok(h) => h,
|
||||
Err(e) => {
|
||||
error!("Error hashing password: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let user_insert = sqlx::query("INSERT INTO users (handle, email, did, password_hash) VALUES ($1, $2, $3, $4) RETURNING id")
|
||||
.bind(&input.handle)
|
||||
.bind(&input.email)
|
||||
.bind(&did)
|
||||
.bind(&password_hash)
|
||||
.fetch_one(&mut *tx)
|
||||
.await;
|
||||
|
||||
let user_id: uuid::Uuid = match user_insert {
|
||||
Ok(row) => row.get("id"),
|
||||
Err(e) => {
|
||||
error!("Error inserting user: {:?}", e);
|
||||
// TODO: Check for unique constraint violation on email/did specifically
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let secret_key = SecretKey::random(&mut OsRng);
|
||||
let secret_key_bytes = secret_key.to_bytes();
|
||||
|
||||
let key_insert = sqlx::query("INSERT INTO user_keys (user_id, key_bytes) VALUES ($1, $2)")
|
||||
.bind(user_id)
|
||||
.bind(&secret_key_bytes[..])
|
||||
.execute(&mut *tx)
|
||||
.await;
|
||||
|
||||
if let Err(e) = key_insert {
|
||||
error!("Error inserting user key: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let mst = Mst::new(Arc::new(state.block_store.clone()));
|
||||
let mst_root = match mst.root().await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Error creating MST root: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let did_obj = match Did::new(&did) {
|
||||
Ok(d) => d,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Invalid DID"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let rev = Tid::now(LimitedU32::MIN);
|
||||
|
||||
let commit = Commit::new_unsigned(did_obj, mst_root, rev, None);
|
||||
|
||||
let commit_bytes = match commit.to_cbor() {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
error!("Error serializing genesis commit: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let commit_cid = match state.block_store.put(&commit_bytes).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Error saving genesis commit: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let repo_insert = sqlx::query("INSERT INTO repos (user_id, repo_root_cid) VALUES ($1, $2)")
|
||||
.bind(user_id)
|
||||
.bind(commit_cid.to_string())
|
||||
.execute(&mut *tx)
|
||||
.await;
|
||||
|
||||
if let Err(e) = repo_insert {
|
||||
error!("Error initializing repo: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
if let Some(code) = &input.invite_code {
|
||||
let use_insert =
|
||||
sqlx::query("INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)")
|
||||
.bind(code)
|
||||
.bind(user_id)
|
||||
.execute(&mut *tx)
|
||||
.await;
|
||||
|
||||
if let Err(e) = use_insert {
|
||||
error!("Error recording invite usage: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
let access_jwt = crate::auth::create_access_token(&did, &secret_key_bytes[..]).map_err(|e| {
|
||||
error!("Error creating access token: {:?}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response()
|
||||
});
|
||||
let access_jwt = match access_jwt {
|
||||
Ok(t) => t,
|
||||
Err(r) => return r,
|
||||
};
|
||||
|
||||
let refresh_jwt = crate::auth::create_refresh_token(&did, &secret_key_bytes[..]).map_err(|e| {
|
||||
error!("Error creating refresh token: {:?}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response()
|
||||
});
|
||||
let refresh_jwt = match refresh_jwt {
|
||||
Ok(t) => t,
|
||||
Err(r) => return r,
|
||||
};
|
||||
|
||||
let session_insert =
|
||||
sqlx::query("INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)")
|
||||
.bind(&access_jwt)
|
||||
.bind(&refresh_jwt)
|
||||
.bind(&did)
|
||||
.execute(&mut *tx)
|
||||
.await;
|
||||
|
||||
if let Err(e) = session_insert {
|
||||
error!("Error inserting session: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
if let Err(e) = tx.commit().await {
|
||||
error!("Error committing transaction: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(CreateAccountOutput {
|
||||
access_jwt,
|
||||
refresh_jwt,
|
||||
handle: input.handle,
|
||||
did,
|
||||
}),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
201
src/api/identity/did.rs
Normal file
201
src/api/identity/did.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use base64::Engine;
|
||||
use k256::SecretKey;
|
||||
use k256::elliptic_curve::sec1::ToEncodedPoint;
|
||||
use reqwest;
|
||||
use serde_json::json;
|
||||
use sqlx::Row;
|
||||
use tracing::error;
|
||||
|
||||
pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value {
|
||||
let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length");
|
||||
let public_key = secret_key.public_key();
|
||||
let encoded = public_key.to_encoded_point(false);
|
||||
let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap());
|
||||
let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap());
|
||||
|
||||
json!({
|
||||
"kty": "EC",
|
||||
"crv": "secp256k1",
|
||||
"x": x,
|
||||
"y": y
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
|
||||
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
|
||||
// Kinda for local dev, encode hostname if it contains port
|
||||
let did = if hostname.contains(':') {
|
||||
format!("did:web:{}", hostname.replace(':', "%3A"))
|
||||
} else {
|
||||
format!("did:web:{}", hostname)
|
||||
};
|
||||
|
||||
Json(json!({
|
||||
"@context": ["https://www.w3.org/ns/did/v1"],
|
||||
"id": did,
|
||||
"service": [{
|
||||
"id": "#atproto_pds",
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"serviceEndpoint": format!("https://{}", hostname)
|
||||
}]
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
|
||||
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
|
||||
|
||||
let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1")
|
||||
.bind(&handle)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let (user_id, did) = match user {
|
||||
Ok(Some(row)) => {
|
||||
let id: uuid::Uuid = row.get("id");
|
||||
let d: String = row.get("did");
|
||||
(id, d)
|
||||
}
|
||||
Ok(None) => {
|
||||
return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
error!("DB Error: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if !did.starts_with("did:web:") {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "User is not did:web"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let key_bytes: Vec<u8> = match key_row {
|
||||
Ok(Some(row)) => row.get("key_bytes"),
|
||||
_ => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let jwk = get_jwk(&key_bytes);
|
||||
|
||||
Json(json!({
|
||||
"@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
|
||||
"id": did,
|
||||
"alsoKnownAs": [format!("at://{}", handle)],
|
||||
"verificationMethod": [{
|
||||
"id": format!("{}#atproto", did),
|
||||
"type": "JsonWebKey2020",
|
||||
"controller": did,
|
||||
"publicKeyJwk": jwk
|
||||
}],
|
||||
"service": [{
|
||||
"id": "#atproto_pds",
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"serviceEndpoint": format!("https://{}", hostname)
|
||||
}]
|
||||
})).into_response()
|
||||
}
|
||||
|
||||
pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
|
||||
let expected_prefix = if hostname.contains(':') {
|
||||
format!("did:web:{}", hostname.replace(':', "%3A"))
|
||||
} else {
|
||||
format!("did:web:{}", hostname)
|
||||
};
|
||||
|
||||
if did.starts_with(&expected_prefix) {
|
||||
let suffix = &did[expected_prefix.len()..];
|
||||
let expected_suffix = format!(":u:{}", handle);
|
||||
if suffix == expected_suffix {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!(
|
||||
"Invalid DID path for this PDS. Expected {}",
|
||||
expected_suffix
|
||||
))
|
||||
}
|
||||
} else {
|
||||
let parts: Vec<&str> = did.split(':').collect();
|
||||
if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
|
||||
return Err("Invalid did:web format".into());
|
||||
}
|
||||
|
||||
let domain_segment = parts[2];
|
||||
let domain = domain_segment.replace("%3A", ":");
|
||||
|
||||
let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
|
||||
"http"
|
||||
} else {
|
||||
"https"
|
||||
};
|
||||
|
||||
let url = if parts.len() == 3 {
|
||||
format!("{}://{}/.well-known/did.json", scheme, domain)
|
||||
} else {
|
||||
let path = parts[3..].join("/");
|
||||
format!("{}://{}/{}/did.json", scheme, domain, path)
|
||||
};
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(5))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create client: {}", e))?;
|
||||
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
|
||||
}
|
||||
|
||||
let doc: serde_json::Value = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse DID doc: {}", e))?;
|
||||
|
||||
let services = doc["service"]
|
||||
.as_array()
|
||||
.ok_or("No services found in DID doc")?;
|
||||
|
||||
let pds_endpoint = format!("https://{}", hostname);
|
||||
|
||||
let has_valid_service = services.iter().any(|s| {
|
||||
s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
|
||||
});
|
||||
|
||||
if has_valid_service {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!(
|
||||
"DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
|
||||
pds_endpoint
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
5
src/api/identity/mod.rs
Normal file
5
src/api/identity/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod account;
|
||||
pub mod did;
|
||||
|
||||
pub use account::create_account;
|
||||
pub use did::{user_did_doc, well_known_did};
|
||||
@@ -1,4 +1,4 @@
|
||||
pub mod server;
|
||||
pub mod repo;
|
||||
pub mod proxy;
|
||||
pub mod identity;
|
||||
pub mod proxy;
|
||||
pub mod repo;
|
||||
pub mod server;
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
body::Bytes,
|
||||
extract::{Path, Query, State},
|
||||
http::{HeaderMap, Method, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
body::Bytes,
|
||||
};
|
||||
use reqwest::Client;
|
||||
use tracing::{info, error};
|
||||
use std::collections::HashMap;
|
||||
use crate::state::AppState;
|
||||
use sqlx::Row;
|
||||
use std::collections::HashMap;
|
||||
use tracing::{error, info};
|
||||
|
||||
pub async fn proxy_handler(
|
||||
State(state): State<AppState>,
|
||||
@@ -18,8 +18,8 @@ pub async fn proxy_handler(
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
|
||||
let proxy_header = headers.get("atproto-proxy")
|
||||
let proxy_header = headers
|
||||
.get("atproto-proxy")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
@@ -27,7 +27,9 @@ pub async fn proxy_handler(
|
||||
Some(url) => url.clone(),
|
||||
None => match std::env::var("APPVIEW_URL") {
|
||||
Ok(url) => url,
|
||||
Err(_) => return (StatusCode::BAD_GATEWAY, "No upstream AppView configured").into_response(),
|
||||
Err(_) => {
|
||||
return (StatusCode::BAD_GATEWAY, "No upstream AppView configured").into_response();
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
@@ -37,9 +39,7 @@ pub async fn proxy_handler(
|
||||
|
||||
let client = Client::new();
|
||||
|
||||
let mut request_builder = client
|
||||
.request(method_verb, &target_url)
|
||||
.query(¶ms);
|
||||
let mut request_builder = client.request(method_verb, &target_url).query(¶ms);
|
||||
|
||||
let mut auth_header_val = headers.get("Authorization").map(|h| h.clone());
|
||||
|
||||
@@ -48,17 +48,21 @@ pub async fn proxy_handler(
|
||||
if let Ok(token) = auth_val.to_str() {
|
||||
let token = token.replace("Bearer ", "");
|
||||
if let Ok(did) = crate::auth::get_did_from_token(&token) {
|
||||
let key_row = sqlx::query("SELECT k.key_bytes FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1")
|
||||
let key_row = sqlx::query("SELECT k.key_bytes FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1")
|
||||
.bind(&did)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
if let Ok(Some(row)) = key_row {
|
||||
let key_bytes: Vec<u8> = row.get("key_bytes");
|
||||
if let Ok(new_token) = crate::auth::create_service_token(&did, aud, &method, &key_bytes) {
|
||||
if let Ok(val) = axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) {
|
||||
auth_header_val = Some(val);
|
||||
}
|
||||
if let Ok(new_token) =
|
||||
crate::auth::create_service_token(&did, aud, &method, &key_bytes)
|
||||
{
|
||||
if let Ok(val) =
|
||||
axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
|
||||
{
|
||||
auth_header_val = Some(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -86,7 +90,8 @@ pub async fn proxy_handler(
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
error!("Error reading proxy response body: {:?}", e);
|
||||
return (StatusCode::BAD_GATEWAY, "Error reading upstream response").into_response();
|
||||
return (StatusCode::BAD_GATEWAY, "Error reading upstream response")
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -99,11 +104,11 @@ pub async fn proxy_handler(
|
||||
match response_builder.body(axum::body::Body::from(body)) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!("Error building proxy response: {:?}", e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response()
|
||||
error!("Error building proxy response: {:?}", e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response()
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error sending proxy request: {:?}", e);
|
||||
if e.is_timeout() {
|
||||
|
||||
889
src/api/repo.rs
889
src/api/repo.rs
@@ -1,889 +0,0 @@
|
||||
use axum::{
|
||||
extract::{State, Query},
|
||||
Json,
|
||||
response::{IntoResponse, Response},
|
||||
http::StatusCode,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use crate::state::AppState;
|
||||
use chrono::Utc;
|
||||
use sqlx::Row;
|
||||
use cid::Cid;
|
||||
use std::str::FromStr;
|
||||
use jacquard_repo::{mst::Mst, commit::Commit, storage::BlockStore};
|
||||
use jacquard::types::{string::{Nsid, Tid}, did::Did, integer::LimitedU32};
|
||||
use tracing::error;
|
||||
use std::sync::Arc;
|
||||
use sha2::{Sha256, Digest};
|
||||
use multihash::Multihash;
|
||||
use axum::body::Bytes;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CreateRecordInput {
|
||||
pub repo: String,
|
||||
pub collection: String,
|
||||
pub rkey: Option<String>,
|
||||
pub validate: Option<bool>,
|
||||
pub record: serde_json::Value,
|
||||
#[serde(rename = "swapCommit")]
|
||||
pub swap_commit: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CreateRecordOutput {
|
||||
pub uri: String,
|
||||
pub cid: String,
|
||||
}
|
||||
|
||||
pub async fn create_record(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(input): Json<CreateRecordInput>,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization");
|
||||
if auth_header.is_none() {
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response();
|
||||
}
|
||||
let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", "");
|
||||
|
||||
let session = sqlx::query(
|
||||
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
|
||||
)
|
||||
.bind(&token)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let (did, key_bytes) = match session {
|
||||
Some(row) => (row.get::<String, _>("did"), row.get::<Vec<u8>, _>("key_bytes")),
|
||||
None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response(),
|
||||
};
|
||||
|
||||
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response();
|
||||
}
|
||||
|
||||
if input.repo != did {
|
||||
return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response();
|
||||
}
|
||||
|
||||
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
|
||||
.bind(&did)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let user_id: uuid::Uuid = match user_query {
|
||||
Ok(Some(row)) => row.get("id"),
|
||||
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "User not found"}))).into_response(),
|
||||
};
|
||||
|
||||
let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let current_root_cid = match repo_root_query {
|
||||
Ok(Some(row)) => {
|
||||
let cid_str: String = row.get("repo_root_cid");
|
||||
Cid::from_str(&cid_str).ok()
|
||||
},
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if current_root_cid.is_none() {
|
||||
error!("Repo root not found for user {}", did);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Repo root not found"}))).into_response();
|
||||
}
|
||||
let current_root_cid = current_root_cid.unwrap();
|
||||
|
||||
let commit_bytes = match state.block_store.get(¤t_root_cid).await {
|
||||
Ok(Some(b)) => b,
|
||||
Ok(None) => {
|
||||
error!("Commit block not found: {}", current_root_cid);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
},
|
||||
Err(e) => {
|
||||
error!("Failed to load commit block: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let commit = match Commit::from_cbor(&commit_bytes) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to parse commit: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let mst_root = commit.data;
|
||||
let store = Arc::new(state.block_store.clone());
|
||||
let mst = Mst::load(store.clone(), mst_root, None);
|
||||
|
||||
let collection_nsid = match input.collection.parse::<Nsid>() {
|
||||
Ok(n) => n,
|
||||
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(),
|
||||
};
|
||||
|
||||
let rkey = input.rkey.unwrap_or_else(|| {
|
||||
Utc::now().format("%Y%m%d%H%M%S%f").to_string()
|
||||
});
|
||||
|
||||
let mut record_bytes = Vec::new();
|
||||
if let Err(e) = serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record) {
|
||||
error!("Error serializing record: {:?}", e);
|
||||
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
|
||||
}
|
||||
|
||||
let record_cid = match state.block_store.put(&record_bytes).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to save record block: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let key = format!("{}/{}", collection_nsid, rkey);
|
||||
if let Err(e) = mst.update(&key, record_cid).await {
|
||||
error!("Failed to update MST: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
|
||||
let new_mst_root = match mst.root().await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to get new MST root: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let did_obj = match Did::new(&did) {
|
||||
Ok(d) => d,
|
||||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(),
|
||||
};
|
||||
|
||||
let rev = Tid::now(LimitedU32::MIN);
|
||||
|
||||
let new_commit = Commit::new_unsigned(
|
||||
did_obj,
|
||||
new_mst_root,
|
||||
rev,
|
||||
Some(current_root_cid)
|
||||
);
|
||||
|
||||
let new_commit_bytes = match new_commit.to_cbor() {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
error!("Failed to serialize new commit: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let new_root_cid = match state.block_store.put(&new_commit_bytes).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to save new commit: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2")
|
||||
.bind(new_root_cid.to_string())
|
||||
.bind(user_id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = update_repo {
|
||||
error!("Failed to update repo root in DB: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
|
||||
let record_insert = sqlx::query(
|
||||
"INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()"
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(&input.collection)
|
||||
.bind(&rkey)
|
||||
.bind(record_cid.to_string())
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = record_insert {
|
||||
error!("Error inserting record index: {:?}", e);
|
||||
}
|
||||
|
||||
let output = CreateRecordOutput {
|
||||
uri: format!("at://{}/{}/{}", input.repo, input.collection, rkey),
|
||||
cid: record_cid.to_string(),
|
||||
};
|
||||
(StatusCode::OK, Json(output)).into_response()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct PutRecordInput {
|
||||
pub repo: String,
|
||||
pub collection: String,
|
||||
pub rkey: String,
|
||||
pub validate: Option<bool>,
|
||||
pub record: serde_json::Value,
|
||||
#[serde(rename = "swapCommit")]
|
||||
pub swap_commit: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PutRecordOutput {
|
||||
pub uri: String,
|
||||
pub cid: String,
|
||||
}
|
||||
|
||||
pub async fn put_record(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(input): Json<PutRecordInput>,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization");
|
||||
if auth_header.is_none() {
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response();
|
||||
}
|
||||
let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", "");
|
||||
|
||||
let session = sqlx::query(
|
||||
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
|
||||
)
|
||||
.bind(&token)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let (did, key_bytes) = match session {
|
||||
Some(row) => (row.get::<String, _>("did"), row.get::<Vec<u8>, _>("key_bytes")),
|
||||
None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response(),
|
||||
};
|
||||
|
||||
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response();
|
||||
}
|
||||
|
||||
if input.repo != did {
|
||||
return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response();
|
||||
}
|
||||
|
||||
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
|
||||
.bind(&did)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let user_id: uuid::Uuid = match user_query {
|
||||
Ok(Some(row)) => row.get("id"),
|
||||
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "User not found"}))).into_response(),
|
||||
};
|
||||
|
||||
let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let current_root_cid = match repo_root_query {
|
||||
Ok(Some(row)) => {
|
||||
let cid_str: String = row.get("repo_root_cid");
|
||||
Cid::from_str(&cid_str).ok()
|
||||
},
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if current_root_cid.is_none() {
|
||||
error!("Repo root not found for user {}", did);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Repo root not found"}))).into_response();
|
||||
}
|
||||
let current_root_cid = current_root_cid.unwrap();
|
||||
|
||||
let commit_bytes = match state.block_store.get(¤t_root_cid).await {
|
||||
Ok(Some(b)) => b,
|
||||
Ok(None) => {
|
||||
error!("Commit block not found: {}", current_root_cid);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response();
|
||||
},
|
||||
Err(e) => {
|
||||
error!("Failed to load commit block: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to load commit block"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let commit = match Commit::from_cbor(&commit_bytes) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to parse commit: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to parse commit"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let mst_root = commit.data;
|
||||
let store = Arc::new(state.block_store.clone());
|
||||
let mst = Mst::load(store.clone(), mst_root, None);
|
||||
|
||||
let collection_nsid = match input.collection.parse::<Nsid>() {
|
||||
Ok(n) => n,
|
||||
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(),
|
||||
};
|
||||
|
||||
let rkey = input.rkey.clone();
|
||||
|
||||
let mut record_bytes = Vec::new();
|
||||
if let Err(e) = serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record) {
|
||||
error!("Error serializing record: {:?}", e);
|
||||
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
|
||||
}
|
||||
|
||||
let record_cid = match state.block_store.put(&record_bytes).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to save record block: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to save record block"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let key = format!("{}/{}", collection_nsid, rkey);
|
||||
if let Err(e) = mst.update(&key, record_cid).await {
|
||||
error!("Failed to update MST: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to update MST: {:?}", e)}))).into_response();
|
||||
}
|
||||
|
||||
let new_mst_root = match mst.root().await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to get new MST root: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST root"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let did_obj = match Did::new(&did) {
|
||||
Ok(d) => d,
|
||||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(),
|
||||
};
|
||||
|
||||
let rev = Tid::now(LimitedU32::MIN);
|
||||
|
||||
let new_commit = Commit::new_unsigned(
|
||||
did_obj,
|
||||
new_mst_root,
|
||||
rev,
|
||||
Some(current_root_cid)
|
||||
);
|
||||
|
||||
let new_commit_bytes = match new_commit.to_cbor() {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
error!("Failed to serialize new commit: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to serialize new commit"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let new_root_cid = match state.block_store.put(&new_commit_bytes).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to save new commit: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to save new commit"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2")
|
||||
.bind(new_root_cid.to_string())
|
||||
.bind(user_id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = update_repo {
|
||||
error!("Failed to update repo root in DB: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update repo root in DB"}))).into_response();
|
||||
}
|
||||
|
||||
let record_insert = sqlx::query(
|
||||
"INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()"
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(&input.collection)
|
||||
.bind(&rkey)
|
||||
.bind(record_cid.to_string())
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = record_insert {
|
||||
error!("Error inserting record index: {:?}", e);
|
||||
}
|
||||
|
||||
let output = PutRecordOutput {
|
||||
uri: format!("at://{}/{}/{}", input.repo, input.collection, rkey),
|
||||
cid: record_cid.to_string(),
|
||||
};
|
||||
(StatusCode::OK, Json(output)).into_response()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetRecordInput {
|
||||
pub repo: String,
|
||||
pub collection: String,
|
||||
pub rkey: String,
|
||||
pub cid: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn get_record(
|
||||
State(state): State<AppState>,
|
||||
Query(input): Query<GetRecordInput>,
|
||||
) -> Response {
|
||||
let user_row = if input.repo.starts_with("did:") {
|
||||
sqlx::query("SELECT id FROM users WHERE did = $1")
|
||||
.bind(&input.repo)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query("SELECT id FROM users WHERE handle = $1")
|
||||
.bind(&input.repo)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
};
|
||||
|
||||
let user_id: uuid::Uuid = match user_row {
|
||||
Ok(Some(row)) => row.get("id"),
|
||||
_ => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Repo not found"}))).into_response(),
|
||||
};
|
||||
|
||||
let record_row = sqlx::query("SELECT record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3")
|
||||
.bind(user_id)
|
||||
.bind(&input.collection)
|
||||
.bind(&input.rkey)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let record_cid_str: String = match record_row {
|
||||
Ok(Some(row)) => row.get("record_cid"),
|
||||
_ => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Record not found"}))).into_response(),
|
||||
};
|
||||
|
||||
if let Some(expected_cid) = &input.cid {
|
||||
if &record_cid_str != expected_cid {
|
||||
return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Record CID mismatch"}))).into_response();
|
||||
}
|
||||
}
|
||||
|
||||
let cid = match Cid::from_str(&record_cid_str) {
|
||||
Ok(c) => c,
|
||||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid CID in DB"}))).into_response(),
|
||||
};
|
||||
|
||||
let block = match state.block_store.get(&cid).await {
|
||||
Ok(Some(b)) => b,
|
||||
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Record block not found"}))).into_response(),
|
||||
};
|
||||
|
||||
let value: serde_json::Value = match serde_ipld_dagcbor::from_slice(&block) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
error!("Failed to deserialize record: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
Json(json!({
|
||||
"uri": format!("at://{}/{}/{}", input.repo, input.collection, input.rkey),
|
||||
"cid": record_cid_str,
|
||||
"value": value
|
||||
})).into_response()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct DeleteRecordInput {
|
||||
pub repo: String,
|
||||
pub collection: String,
|
||||
pub rkey: String,
|
||||
#[serde(rename = "swapRecord")]
|
||||
pub swap_record: Option<String>,
|
||||
#[serde(rename = "swapCommit")]
|
||||
pub swap_commit: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn delete_record(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(input): Json<DeleteRecordInput>,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization");
|
||||
if auth_header.is_none() {
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response();
|
||||
}
|
||||
let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", "");
|
||||
|
||||
let session = sqlx::query(
|
||||
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
|
||||
)
|
||||
.bind(&token)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let (did, key_bytes) = match session {
|
||||
Some(row) => (row.get::<String, _>("did"), row.get::<Vec<u8>, _>("key_bytes")),
|
||||
None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response(),
|
||||
};
|
||||
|
||||
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response();
|
||||
}
|
||||
|
||||
if input.repo != did {
|
||||
return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response();
|
||||
}
|
||||
|
||||
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
|
||||
.bind(&did)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let user_id: uuid::Uuid = match user_query {
|
||||
Ok(Some(row)) => row.get("id"),
|
||||
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "User not found"}))).into_response(),
|
||||
};
|
||||
|
||||
let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let current_root_cid = match repo_root_query {
|
||||
Ok(Some(row)) => {
|
||||
let cid_str: String = row.get("repo_root_cid");
|
||||
Cid::from_str(&cid_str).ok()
|
||||
},
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if current_root_cid.is_none() {
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Repo root not found"}))).into_response();
|
||||
}
|
||||
let current_root_cid = current_root_cid.unwrap();
|
||||
|
||||
let commit_bytes = match state.block_store.get(¤t_root_cid).await {
|
||||
Ok(Some(b)) => b,
|
||||
Ok(None) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response(),
|
||||
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to load commit block: {:?}", e)}))).into_response(),
|
||||
};
|
||||
|
||||
let commit = match Commit::from_cbor(&commit_bytes) {
|
||||
Ok(c) => c,
|
||||
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to parse commit: {:?}", e)}))).into_response(),
|
||||
};
|
||||
|
||||
let mst_root = commit.data;
|
||||
let store = Arc::new(state.block_store.clone());
|
||||
let mst = Mst::load(store.clone(), mst_root, None);
|
||||
|
||||
let collection_nsid = match input.collection.parse::<Nsid>() {
|
||||
Ok(n) => n,
|
||||
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(),
|
||||
};
|
||||
|
||||
let key = format!("{}/{}", collection_nsid, input.rkey);
|
||||
|
||||
// TODO: Check swapRecord if provided? Skipping for brevity/robustness
|
||||
|
||||
if let Err(e) = mst.delete(&key).await {
|
||||
error!("Failed to delete from MST: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to delete from MST: {:?}", e)}))).into_response();
|
||||
}
|
||||
|
||||
let new_mst_root = match mst.root().await {
|
||||
Ok(c) => c,
|
||||
Err(_e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST root"}))).into_response(),
|
||||
};
|
||||
|
||||
let did_obj = match Did::new(&did) {
|
||||
Ok(d) => d,
|
||||
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(),
|
||||
};
|
||||
|
||||
let rev = Tid::now(LimitedU32::MIN);
|
||||
|
||||
let new_commit = Commit::new_unsigned(
|
||||
did_obj,
|
||||
new_mst_root,
|
||||
rev,
|
||||
Some(current_root_cid)
|
||||
);
|
||||
|
||||
let new_commit_bytes = match new_commit.to_cbor() {
|
||||
Ok(b) => b,
|
||||
Err(_e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to serialize new commit"}))).into_response(),
|
||||
};
|
||||
|
||||
let new_root_cid = match state.block_store.put(&new_commit_bytes).await {
|
||||
Ok(c) => c,
|
||||
Err(_e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to save new commit"}))).into_response(),
|
||||
};
|
||||
|
||||
let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2")
|
||||
.bind(new_root_cid.to_string())
|
||||
.bind(user_id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = update_repo {
|
||||
error!("Failed to update repo root in DB: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update repo root in DB"}))).into_response();
|
||||
}
|
||||
|
||||
let record_delete = sqlx::query("DELETE FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3")
|
||||
.bind(user_id)
|
||||
.bind(&input.collection)
|
||||
.bind(&input.rkey)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = record_delete {
|
||||
error!("Error deleting record index: {:?}", e);
|
||||
}
|
||||
|
||||
(StatusCode::OK, Json(json!({}))).into_response()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ListRecordsInput {
|
||||
pub repo: String,
|
||||
pub collection: String,
|
||||
pub limit: Option<i32>,
|
||||
pub cursor: Option<String>,
|
||||
#[serde(rename = "rkeyStart")]
|
||||
pub rkey_start: Option<String>,
|
||||
#[serde(rename = "rkeyEnd")]
|
||||
pub rkey_end: Option<String>,
|
||||
pub reverse: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ListRecordsOutput {
|
||||
pub cursor: Option<String>,
|
||||
pub records: Vec<serde_json::Value>,
|
||||
}
|
||||
|
||||
pub async fn list_records(
|
||||
State(state): State<AppState>,
|
||||
Query(input): Query<ListRecordsInput>,
|
||||
) -> Response {
|
||||
let user_row = if input.repo.starts_with("did:") {
|
||||
sqlx::query("SELECT id FROM users WHERE did = $1")
|
||||
.bind(&input.repo)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query("SELECT id FROM users WHERE handle = $1")
|
||||
.bind(&input.repo)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
};
|
||||
|
||||
let user_id: uuid::Uuid = match user_row {
|
||||
Ok(Some(row)) => row.get("id"),
|
||||
_ => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Repo not found"}))).into_response(),
|
||||
};
|
||||
|
||||
let limit = input.limit.unwrap_or(50).clamp(1, 100);
|
||||
let reverse = input.reverse.unwrap_or(false);
|
||||
|
||||
// Simplistic query construction - no sophisticated cursor handling or rkey ranges for now, just basic pagination
|
||||
// TODO: Implement rkeyStart/End and correct cursor logic
|
||||
|
||||
let query_str = format!(
|
||||
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 {} ORDER BY rkey {} LIMIT {}",
|
||||
if let Some(_c) = &input.cursor {
|
||||
if reverse { "AND rkey < $3" } else { "AND rkey > $3" }
|
||||
} else {
|
||||
""
|
||||
},
|
||||
if reverse { "DESC" } else { "ASC" },
|
||||
limit
|
||||
);
|
||||
|
||||
let mut query = sqlx::query(&query_str)
|
||||
.bind(user_id)
|
||||
.bind(&input.collection);
|
||||
|
||||
if let Some(c) = &input.cursor {
|
||||
query = query.bind(c);
|
||||
}
|
||||
|
||||
let rows = match query.fetch_all(&state.db).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!("Error listing records: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let mut records = Vec::new();
|
||||
let mut last_rkey = None;
|
||||
|
||||
for row in rows {
|
||||
let rkey: String = row.get("rkey");
|
||||
let cid_str: String = row.get("record_cid");
|
||||
last_rkey = Some(rkey.clone());
|
||||
|
||||
if let Ok(cid) = Cid::from_str(&cid_str) {
|
||||
if let Ok(Some(block)) = state.block_store.get(&cid).await {
|
||||
if let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) {
|
||||
records.push(json!({
|
||||
"uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey),
|
||||
"cid": cid_str,
|
||||
"value": value
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Json(ListRecordsOutput {
|
||||
cursor: last_rkey,
|
||||
records,
|
||||
}).into_response()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct DescribeRepoInput {
|
||||
pub repo: String,
|
||||
}
|
||||
|
||||
pub async fn describe_repo(
|
||||
State(state): State<AppState>,
|
||||
Query(input): Query<DescribeRepoInput>,
|
||||
) -> Response {
|
||||
let user_row = if input.repo.starts_with("did:") {
|
||||
sqlx::query("SELECT id, handle, did FROM users WHERE did = $1")
|
||||
.bind(&input.repo)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query("SELECT id, handle, did FROM users WHERE handle = $1")
|
||||
.bind(&input.repo)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
};
|
||||
|
||||
let (user_id, handle, did) = match user_row {
|
||||
Ok(Some(row)) => (row.get::<uuid::Uuid, _>("id"), row.get::<String, _>("handle"), row.get::<String, _>("did")),
|
||||
_ => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "Repo not found"}))).into_response(),
|
||||
};
|
||||
|
||||
let collections_query = sqlx::query("SELECT DISTINCT collection FROM records WHERE repo_id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
let collections: Vec<String> = match collections_query {
|
||||
Ok(rows) => rows.iter().map(|r| r.get("collection")).collect(),
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
let did_doc = json!({
|
||||
"id": did,
|
||||
"alsoKnownAs": [format!("at://{}", handle)]
|
||||
});
|
||||
|
||||
Json(json!({
|
||||
"handle": handle,
|
||||
"did": did,
|
||||
"didDoc": did_doc,
|
||||
"collections": collections,
|
||||
"handleIsCorrect": true
|
||||
})).into_response()
|
||||
}
|
||||
|
||||
pub async fn upload_blob(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization");
|
||||
if auth_header.is_none() {
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response();
|
||||
}
|
||||
let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", "");
|
||||
|
||||
let session = sqlx::query(
|
||||
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
|
||||
)
|
||||
.bind(&token)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let (did, key_bytes) = match session {
|
||||
Some(row) => (row.get::<String, _>("did"), row.get::<Vec<u8>, _>("key_bytes")),
|
||||
None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response(),
|
||||
};
|
||||
|
||||
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response();
|
||||
}
|
||||
|
||||
let mime_type = headers.get("content-type")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("application/octet-stream")
|
||||
.to_string();
|
||||
|
||||
let size = body.len() as i64;
|
||||
let data = body.to_vec();
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(&data);
|
||||
let hash = hasher.finalize();
|
||||
let multihash = Multihash::wrap(0x12, &hash).unwrap();
|
||||
let cid = Cid::new_v1(0x55, multihash);
|
||||
let cid_str = cid.to_string();
|
||||
|
||||
let storage_key = format!("blobs/{}", cid_str);
|
||||
|
||||
if let Err(e) = state.blob_store.put(&storage_key, &data).await {
|
||||
error!("Failed to upload blob to storage: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store blob"}))).into_response();
|
||||
}
|
||||
|
||||
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
|
||||
.bind(&did)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let user_id: uuid::Uuid = match user_query {
|
||||
Ok(Some(row)) => row.get("id"),
|
||||
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(),
|
||||
};
|
||||
|
||||
let insert = sqlx::query(
|
||||
"INSERT INTO blobs (cid, mime_type, size_bytes, created_by_user, storage_key) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (cid) DO NOTHING"
|
||||
)
|
||||
.bind(&cid_str)
|
||||
.bind(&mime_type)
|
||||
.bind(size)
|
||||
.bind(user_id)
|
||||
.bind(&storage_key)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = insert {
|
||||
error!("Failed to insert blob record: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
}
|
||||
|
||||
Json(json!({
|
||||
"blob": {
|
||||
"ref": {
|
||||
"$link": cid_str
|
||||
},
|
||||
"mimeType": mime_type,
|
||||
"size": size
|
||||
}
|
||||
})).into_response()
|
||||
}
|
||||
138
src/api/repo/blob.rs
Normal file
138
src/api/repo/blob.rs
Normal file
@@ -0,0 +1,138 @@
|
||||
use crate::state::AppState;
|
||||
use axum::body::Bytes;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use cid::Cid;
|
||||
use multihash::Multihash;
|
||||
use serde_json::json;
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx::Row;
|
||||
use tracing::error;
|
||||
|
||||
pub async fn upload_blob(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization");
|
||||
if auth_header.is_none() {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationRequired"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
let token = auth_header
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap_or("")
|
||||
.replace("Bearer ", "");
|
||||
|
||||
let session = sqlx::query(
|
||||
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
|
||||
)
|
||||
.bind(&token)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let (did, key_bytes) = match session {
|
||||
Some(row) => (
|
||||
row.get::<String, _>("did"),
|
||||
row.get::<Vec<u8>, _>("key_bytes"),
|
||||
),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationFailed"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let mime_type = headers
|
||||
.get("content-type")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("application/octet-stream")
|
||||
.to_string();
|
||||
|
||||
let size = body.len() as i64;
|
||||
let data = body.to_vec();
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(&data);
|
||||
let hash = hasher.finalize();
|
||||
let multihash = Multihash::wrap(0x12, &hash).unwrap();
|
||||
let cid = Cid::new_v1(0x55, multihash);
|
||||
let cid_str = cid.to_string();
|
||||
|
||||
let storage_key = format!("blobs/{}", cid_str);
|
||||
|
||||
if let Err(e) = state.blob_store.put(&storage_key, &data).await {
|
||||
error!("Failed to upload blob to storage: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Failed to store blob"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
|
||||
.bind(&did)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let user_id: uuid::Uuid = match user_query {
|
||||
Ok(Some(row)) => row.get("id"),
|
||||
_ => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let insert = sqlx::query(
|
||||
"INSERT INTO blobs (cid, mime_type, size_bytes, created_by_user, storage_key) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (cid) DO NOTHING"
|
||||
)
|
||||
.bind(&cid_str)
|
||||
.bind(&mime_type)
|
||||
.bind(size)
|
||||
.bind(user_id)
|
||||
.bind(&storage_key)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = insert {
|
||||
error!("Failed to insert blob record: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
Json(json!({
|
||||
"blob": {
|
||||
"ref": {
|
||||
"$link": cid_str
|
||||
},
|
||||
"mimeType": mime_type,
|
||||
"size": size
|
||||
}
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
72
src/api/repo/meta.rs
Normal file
72
src/api/repo/meta.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use sqlx::Row;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct DescribeRepoInput {
|
||||
pub repo: String,
|
||||
}
|
||||
|
||||
pub async fn describe_repo(
|
||||
State(state): State<AppState>,
|
||||
Query(input): Query<DescribeRepoInput>,
|
||||
) -> Response {
|
||||
let user_row = if input.repo.starts_with("did:") {
|
||||
sqlx::query("SELECT id, handle, did FROM users WHERE did = $1")
|
||||
.bind(&input.repo)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query("SELECT id, handle, did FROM users WHERE handle = $1")
|
||||
.bind(&input.repo)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
};
|
||||
|
||||
let (user_id, handle, did) = match user_row {
|
||||
Ok(Some(row)) => (
|
||||
row.get::<uuid::Uuid, _>("id"),
|
||||
row.get::<String, _>("handle"),
|
||||
row.get::<String, _>("did"),
|
||||
),
|
||||
_ => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Repo not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let collections_query =
|
||||
sqlx::query("SELECT DISTINCT collection FROM records WHERE repo_id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
let collections: Vec<String> = match collections_query {
|
||||
Ok(rows) => rows.iter().map(|r| r.get("collection")).collect(),
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
let did_doc = json!({
|
||||
"id": did,
|
||||
"alsoKnownAs": [format!("at://{}", handle)]
|
||||
});
|
||||
|
||||
Json(json!({
|
||||
"handle": handle,
|
||||
"did": did,
|
||||
"didDoc": did_doc,
|
||||
"collections": collections,
|
||||
"handleIsCorrect": true
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
7
src/api/repo/mod.rs
Normal file
7
src/api/repo/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
pub mod blob;
|
||||
pub mod meta;
|
||||
pub mod record;
|
||||
|
||||
pub use blob::upload_blob;
|
||||
pub use meta::describe_repo;
|
||||
pub use record::{create_record, delete_record, get_record, list_records, put_record};
|
||||
236
src/api/repo/record/delete.rs
Normal file
236
src/api/repo/record/delete.rs
Normal file
@@ -0,0 +1,236 @@
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use cid::Cid;
|
||||
use jacquard::types::{
|
||||
did::Did,
|
||||
integer::LimitedU32,
|
||||
string::{Nsid, Tid},
|
||||
};
|
||||
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use sqlx::Row;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tracing::error;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct DeleteRecordInput {
|
||||
pub repo: String,
|
||||
pub collection: String,
|
||||
pub rkey: String,
|
||||
#[serde(rename = "swapRecord")]
|
||||
pub swap_record: Option<String>,
|
||||
#[serde(rename = "swapCommit")]
|
||||
pub swap_commit: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn delete_record(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(input): Json<DeleteRecordInput>,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization");
|
||||
if auth_header.is_none() {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationRequired"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
let token = auth_header
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap_or("")
|
||||
.replace("Bearer ", "");
|
||||
|
||||
let session = sqlx::query(
|
||||
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
|
||||
)
|
||||
.bind(&token)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let (did, key_bytes) = match session {
|
||||
Some(row) => (
|
||||
row.get::<String, _>("did"),
|
||||
row.get::<Vec<u8>, _>("key_bytes"),
|
||||
),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationFailed"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
if input.repo != did {
|
||||
return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response();
|
||||
}
|
||||
|
||||
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
|
||||
.bind(&did)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let user_id: uuid::Uuid = match user_query {
|
||||
Ok(Some(row)) => row.get("id"),
|
||||
_ => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "User not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let current_root_cid = match repo_root_query {
|
||||
Ok(Some(row)) => {
|
||||
let cid_str: String = row.get("repo_root_cid");
|
||||
Cid::from_str(&cid_str).ok()
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if current_root_cid.is_none() {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Repo root not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
let current_root_cid = current_root_cid.unwrap();
|
||||
|
||||
let commit_bytes = match state.block_store.get(¤t_root_cid).await {
|
||||
Ok(Some(b)) => b,
|
||||
Ok(None) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response(),
|
||||
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to load commit block: {:?}", e)}))).into_response(),
|
||||
};
|
||||
|
||||
let commit = match Commit::from_cbor(&commit_bytes) {
|
||||
Ok(c) => c,
|
||||
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to parse commit: {:?}", e)}))).into_response(),
|
||||
};
|
||||
|
||||
let mst_root = commit.data;
|
||||
let store = Arc::new(state.block_store.clone());
|
||||
let mst = Mst::load(store.clone(), mst_root, None);
|
||||
|
||||
let collection_nsid = match input.collection.parse::<Nsid>() {
|
||||
Ok(n) => n,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": "InvalidCollection"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let key = format!("{}/{}", collection_nsid, input.rkey);
|
||||
|
||||
// TODO: Check swapRecord if provided? Skipping for brevity/robustness
|
||||
|
||||
if let Err(e) = mst.delete(&key).await {
|
||||
error!("Failed to delete from MST: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to delete from MST: {:?}", e)}))).into_response();
|
||||
}
|
||||
|
||||
let new_mst_root = match mst.root().await {
|
||||
Ok(c) => c,
|
||||
Err(_e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Failed to get new MST root"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let did_obj = match Did::new(&did) {
|
||||
Ok(d) => d,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Invalid DID"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let rev = Tid::now(LimitedU32::MIN);
|
||||
|
||||
let new_commit = Commit::new_unsigned(did_obj, new_mst_root, rev, Some(current_root_cid));
|
||||
|
||||
let new_commit_bytes =
|
||||
match new_commit.to_cbor() {
|
||||
Ok(b) => b,
|
||||
Err(_e) => return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(
|
||||
json!({"error": "InternalError", "message": "Failed to serialize new commit"}),
|
||||
),
|
||||
)
|
||||
.into_response(),
|
||||
};
|
||||
|
||||
let new_root_cid = match state.block_store.put(&new_commit_bytes).await {
|
||||
Ok(c) => c,
|
||||
Err(_e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Failed to save new commit"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2")
|
||||
.bind(new_root_cid.to_string())
|
||||
.bind(user_id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = update_repo {
|
||||
error!("Failed to update repo root in DB: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Failed to update repo root in DB"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let record_delete =
|
||||
sqlx::query("DELETE FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3")
|
||||
.bind(user_id)
|
||||
.bind(&input.collection)
|
||||
.bind(&input.rkey)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = record_delete {
|
||||
error!("Error deleting record index: {:?}", e);
|
||||
}
|
||||
|
||||
(StatusCode::OK, Json(json!({}))).into_response()
|
||||
}
|
||||
10
src/api/repo/record/mod.rs
Normal file
10
src/api/repo/record/mod.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
pub mod delete;
|
||||
pub mod read;
|
||||
pub mod write;
|
||||
|
||||
pub use delete::{DeleteRecordInput, delete_record};
|
||||
pub use read::{GetRecordInput, ListRecordsInput, ListRecordsOutput, get_record, list_records};
|
||||
pub use write::{
|
||||
CreateRecordInput, CreateRecordOutput, PutRecordInput, PutRecordOutput, create_record,
|
||||
put_record,
|
||||
};
|
||||
236
src/api/repo/record/read.rs
Normal file
236
src/api/repo/record/read.rs
Normal file
@@ -0,0 +1,236 @@
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use cid::Cid;
|
||||
use jacquard_repo::storage::BlockStore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use sqlx::Row;
|
||||
use std::str::FromStr;
|
||||
use tracing::error;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetRecordInput {
|
||||
pub repo: String,
|
||||
pub collection: String,
|
||||
pub rkey: String,
|
||||
pub cid: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn get_record(
|
||||
State(state): State<AppState>,
|
||||
Query(input): Query<GetRecordInput>,
|
||||
) -> Response {
|
||||
let user_row = if input.repo.starts_with("did:") {
|
||||
sqlx::query("SELECT id FROM users WHERE did = $1")
|
||||
.bind(&input.repo)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query("SELECT id FROM users WHERE handle = $1")
|
||||
.bind(&input.repo)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
};
|
||||
|
||||
let user_id: uuid::Uuid = match user_row {
|
||||
Ok(Some(row)) => row.get("id"),
|
||||
_ => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Repo not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let record_row = sqlx::query(
|
||||
"SELECT record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3",
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(&input.collection)
|
||||
.bind(&input.rkey)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let record_cid_str: String = match record_row {
|
||||
Ok(Some(row)) => row.get("record_cid"),
|
||||
_ => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Record not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(expected_cid) = &input.cid {
|
||||
if &record_cid_str != expected_cid {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Record CID mismatch"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
let cid = match Cid::from_str(&record_cid_str) {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Invalid CID in DB"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let block = match state.block_store.get(&cid).await {
|
||||
Ok(Some(b)) => b,
|
||||
_ => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Record block not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let value: serde_json::Value = match serde_ipld_dagcbor::from_slice(&block) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
error!("Failed to deserialize record: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
Json(json!({
|
||||
"uri": format!("at://{}/{}/{}", input.repo, input.collection, input.rkey),
|
||||
"cid": record_cid_str,
|
||||
"value": value
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ListRecordsInput {
|
||||
pub repo: String,
|
||||
pub collection: String,
|
||||
pub limit: Option<i32>,
|
||||
pub cursor: Option<String>,
|
||||
#[serde(rename = "rkeyStart")]
|
||||
pub rkey_start: Option<String>,
|
||||
#[serde(rename = "rkeyEnd")]
|
||||
pub rkey_end: Option<String>,
|
||||
pub reverse: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ListRecordsOutput {
|
||||
pub cursor: Option<String>,
|
||||
pub records: Vec<serde_json::Value>,
|
||||
}
|
||||
|
||||
pub async fn list_records(
|
||||
State(state): State<AppState>,
|
||||
Query(input): Query<ListRecordsInput>,
|
||||
) -> Response {
|
||||
let user_row = if input.repo.starts_with("did:") {
|
||||
sqlx::query("SELECT id FROM users WHERE did = $1")
|
||||
.bind(&input.repo)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query("SELECT id FROM users WHERE handle = $1")
|
||||
.bind(&input.repo)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
};
|
||||
|
||||
let user_id: uuid::Uuid = match user_row {
|
||||
Ok(Some(row)) => row.get("id"),
|
||||
_ => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "NotFound", "message": "Repo not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let limit = input.limit.unwrap_or(50).clamp(1, 100);
|
||||
let reverse = input.reverse.unwrap_or(false);
|
||||
|
||||
// Simplistic query construction - no sophisticated cursor handling or rkey ranges for now, just basic pagination
|
||||
// TODO: Implement rkeyStart/End and correct cursor logic
|
||||
|
||||
let query_str = format!(
|
||||
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 {} ORDER BY rkey {} LIMIT {}",
|
||||
if let Some(_c) = &input.cursor {
|
||||
if reverse {
|
||||
"AND rkey < $3"
|
||||
} else {
|
||||
"AND rkey > $3"
|
||||
}
|
||||
} else {
|
||||
""
|
||||
},
|
||||
if reverse { "DESC" } else { "ASC" },
|
||||
limit
|
||||
);
|
||||
|
||||
let mut query = sqlx::query(&query_str)
|
||||
.bind(user_id)
|
||||
.bind(&input.collection);
|
||||
|
||||
if let Some(c) = &input.cursor {
|
||||
query = query.bind(c);
|
||||
}
|
||||
|
||||
let rows = match query.fetch_all(&state.db).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!("Error listing records: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let mut records = Vec::new();
|
||||
let mut last_rkey = None;
|
||||
|
||||
for row in rows {
|
||||
let rkey: String = row.get("rkey");
|
||||
let cid_str: String = row.get("record_cid");
|
||||
last_rkey = Some(rkey.clone());
|
||||
|
||||
if let Ok(cid) = Cid::from_str(&cid_str) {
|
||||
if let Ok(Some(block)) = state.block_store.get(&cid).await {
|
||||
if let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) {
|
||||
records.push(json!({
|
||||
"uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey),
|
||||
"cid": cid_str,
|
||||
"value": value
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Json(ListRecordsOutput {
|
||||
cursor: last_rkey,
|
||||
records,
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
591
src/api/repo/record/write.rs
Normal file
591
src/api/repo/record/write.rs
Normal file
@@ -0,0 +1,591 @@
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use chrono::Utc;
|
||||
use cid::Cid;
|
||||
use jacquard::types::{
|
||||
did::Did,
|
||||
integer::LimitedU32,
|
||||
string::{Nsid, Tid},
|
||||
};
|
||||
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use sqlx::Row;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tracing::error;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CreateRecordInput {
|
||||
pub repo: String,
|
||||
pub collection: String,
|
||||
pub rkey: Option<String>,
|
||||
pub validate: Option<bool>,
|
||||
pub record: serde_json::Value,
|
||||
#[serde(rename = "swapCommit")]
|
||||
pub swap_commit: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CreateRecordOutput {
|
||||
pub uri: String,
|
||||
pub cid: String,
|
||||
}
|
||||
|
||||
pub async fn create_record(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(input): Json<CreateRecordInput>,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization");
|
||||
if auth_header.is_none() {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationRequired"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
let token = auth_header
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap_or("")
|
||||
.replace("Bearer ", "");
|
||||
|
||||
let session = sqlx::query(
|
||||
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
|
||||
)
|
||||
.bind(&token)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let (did, key_bytes) = match session {
|
||||
Some(row) => (
|
||||
row.get::<String, _>("did"),
|
||||
row.get::<Vec<u8>, _>("key_bytes"),
|
||||
),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationFailed"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
if input.repo != did {
|
||||
return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response();
|
||||
}
|
||||
|
||||
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
|
||||
.bind(&did)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let user_id: uuid::Uuid = match user_query {
|
||||
Ok(Some(row)) => row.get("id"),
|
||||
_ => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "User not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let current_root_cid = match repo_root_query {
|
||||
Ok(Some(row)) => {
|
||||
let cid_str: String = row.get("repo_root_cid");
|
||||
Cid::from_str(&cid_str).ok()
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if current_root_cid.is_none() {
|
||||
error!("Repo root not found for user {}", did);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Repo root not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
let current_root_cid = current_root_cid.unwrap();
|
||||
|
||||
let commit_bytes = match state.block_store.get(¤t_root_cid).await {
|
||||
Ok(Some(b)) => b,
|
||||
Ok(None) => {
|
||||
error!("Commit block not found: {}", current_root_cid);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to load commit block: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let commit = match Commit::from_cbor(&commit_bytes) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to parse commit: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let mst_root = commit.data;
|
||||
let store = Arc::new(state.block_store.clone());
|
||||
let mst = Mst::load(store.clone(), mst_root, None);
|
||||
|
||||
let collection_nsid = match input.collection.parse::<Nsid>() {
|
||||
Ok(n) => n,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": "InvalidCollection"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let rkey = input
|
||||
.rkey
|
||||
.unwrap_or_else(|| Utc::now().format("%Y%m%d%H%M%S%f").to_string());
|
||||
|
||||
let mut record_bytes = Vec::new();
|
||||
if let Err(e) = serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record) {
|
||||
error!("Error serializing record: {:?}", e);
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let record_cid = match state.block_store.put(&record_bytes).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to save record block: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let key = format!("{}/{}", collection_nsid, rkey);
|
||||
if let Err(e) = mst.update(&key, record_cid).await {
|
||||
error!("Failed to update MST: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let new_mst_root = match mst.root().await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to get new MST root: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let did_obj = match Did::new(&did) {
|
||||
Ok(d) => d,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Invalid DID"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let rev = Tid::now(LimitedU32::MIN);
|
||||
|
||||
let new_commit = Commit::new_unsigned(did_obj, new_mst_root, rev, Some(current_root_cid));
|
||||
|
||||
let new_commit_bytes = match new_commit.to_cbor() {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
error!("Failed to serialize new commit: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let new_root_cid = match state.block_store.put(&new_commit_bytes).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to save new commit: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2")
|
||||
.bind(new_root_cid.to_string())
|
||||
.bind(user_id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = update_repo {
|
||||
error!("Failed to update repo root in DB: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let record_insert = sqlx::query(
|
||||
"INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()",
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(&input.collection)
|
||||
.bind(&rkey)
|
||||
.bind(record_cid.to_string())
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = record_insert {
|
||||
error!("Error inserting record index: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Failed to index record"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let output = CreateRecordOutput {
|
||||
uri: format!("at://{}/{}/{}", input.repo, input.collection, rkey),
|
||||
cid: record_cid.to_string(),
|
||||
};
|
||||
(StatusCode::OK, Json(output)).into_response()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct PutRecordInput {
|
||||
pub repo: String,
|
||||
pub collection: String,
|
||||
pub rkey: String,
|
||||
pub validate: Option<bool>,
|
||||
pub record: serde_json::Value,
|
||||
#[serde(rename = "swapCommit")]
|
||||
pub swap_commit: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PutRecordOutput {
|
||||
pub uri: String,
|
||||
pub cid: String,
|
||||
}
|
||||
|
||||
pub async fn put_record(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(input): Json<PutRecordInput>,
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization");
|
||||
if auth_header.is_none() {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationRequired"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
let token = auth_header
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap_or("")
|
||||
.replace("Bearer ", "");
|
||||
|
||||
let session = sqlx::query(
|
||||
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1"
|
||||
)
|
||||
.bind(&token)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let (did, key_bytes) = match session {
|
||||
Some(row) => (
|
||||
row.get::<String, _>("did"),
|
||||
row.get::<Vec<u8>, _>("key_bytes"),
|
||||
),
|
||||
None => {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationFailed"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
if input.repo != did {
|
||||
return (StatusCode::FORBIDDEN, Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"}))).into_response();
|
||||
}
|
||||
|
||||
let user_query = sqlx::query("SELECT id FROM users WHERE did = $1")
|
||||
.bind(&did)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let user_id: uuid::Uuid = match user_query {
|
||||
Ok(Some(row)) => row.get("id"),
|
||||
_ => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "User not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let repo_root_query = sqlx::query("SELECT repo_root_cid FROM repos WHERE user_id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let current_root_cid = match repo_root_query {
|
||||
Ok(Some(row)) => {
|
||||
let cid_str: String = row.get("repo_root_cid");
|
||||
Cid::from_str(&cid_str).ok()
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if current_root_cid.is_none() {
|
||||
error!("Repo root not found for user {}", did);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Repo root not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
let current_root_cid = current_root_cid.unwrap();
|
||||
|
||||
let commit_bytes = match state.block_store.get(¤t_root_cid).await {
|
||||
Ok(Some(b)) => b,
|
||||
Ok(None) => {
|
||||
error!("Commit block not found: {}", current_root_cid);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Commit block not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to load commit block: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Failed to load commit block"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let commit = match Commit::from_cbor(&commit_bytes) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to parse commit: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Failed to parse commit"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let mst_root = commit.data;
|
||||
let store = Arc::new(state.block_store.clone());
|
||||
let mst = Mst::load(store.clone(), mst_root, None);
|
||||
|
||||
let collection_nsid = match input.collection.parse::<Nsid>() {
|
||||
Ok(n) => n,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": "InvalidCollection"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let rkey = input.rkey.clone();
|
||||
|
||||
let mut record_bytes = Vec::new();
|
||||
if let Err(e) = serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record) {
|
||||
error!("Error serializing record: {:?}", e);
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let record_cid = match state.block_store.put(&record_bytes).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to save record block: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Failed to save record block"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let key = format!("{}/{}", collection_nsid, rkey);
|
||||
if let Err(e) = mst.update(&key, record_cid).await {
|
||||
error!("Failed to update MST: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": format!("Failed to update MST: {:?}", e)}))).into_response();
|
||||
}
|
||||
|
||||
let new_mst_root = match mst.root().await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to get new MST root: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Failed to get new MST root"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let did_obj = match Did::new(&did) {
|
||||
Ok(d) => d,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Invalid DID"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let rev = Tid::now(LimitedU32::MIN);
|
||||
|
||||
let new_commit = Commit::new_unsigned(did_obj, new_mst_root, rev, Some(current_root_cid));
|
||||
|
||||
let new_commit_bytes = match new_commit.to_cbor() {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
error!("Failed to serialize new commit: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(
|
||||
json!({"error": "InternalError", "message": "Failed to serialize new commit"}),
|
||||
),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let new_root_cid = match state.block_store.put(&new_commit_bytes).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to save new commit: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Failed to save new commit"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let update_repo = sqlx::query("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2")
|
||||
.bind(new_root_cid.to_string())
|
||||
.bind(user_id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = update_repo {
|
||||
error!("Failed to update repo root in DB: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Failed to update repo root in DB"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let record_insert = sqlx::query(
|
||||
"INSERT INTO records (repo_id, collection, rkey, record_cid) VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, created_at = NOW()",
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(&input.collection)
|
||||
.bind(&rkey)
|
||||
.bind(record_cid.to_string())
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = record_insert {
|
||||
error!("Error inserting record index: {:?}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Failed to index record"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let output = PutRecordOutput {
|
||||
uri: format!("at://{}/{}/{}", input.repo, input.collection, rkey),
|
||||
cid: record_cid.to_string(),
|
||||
};
|
||||
(StatusCode::OK, Json(output)).into_response()
|
||||
}
|
||||
25
src/api/server/meta.rs
Normal file
25
src/api/server/meta.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use crate::state::AppState;
|
||||
use axum::{Json, extract::State, http::StatusCode, response::IntoResponse};
|
||||
use serde_json::json;
|
||||
|
||||
use tracing::error;
|
||||
|
||||
pub async fn describe_server() -> impl IntoResponse {
|
||||
let domains_str =
|
||||
std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| "example.com".to_string());
|
||||
let domains: Vec<&str> = domains_str.split(',').map(|s| s.trim()).collect();
|
||||
|
||||
Json(json!({
|
||||
"availableUserDomains": domains
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn health(State(state): State<AppState>) -> impl IntoResponse {
|
||||
match sqlx::query("SELECT 1").execute(&state.db).await {
|
||||
Ok(_) => (StatusCode::OK, "OK"),
|
||||
Err(e) => {
|
||||
error!("Health check failed: {:?}", e);
|
||||
(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable")
|
||||
}
|
||||
}
|
||||
}
|
||||
5
src/api/server/mod.rs
Normal file
5
src/api/server/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod meta;
|
||||
pub mod session;
|
||||
|
||||
pub use meta::{describe_server, health};
|
||||
pub use session::{create_session, delete_session, get_session, refresh_session};
|
||||
@@ -1,34 +1,15 @@
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
extract::State,
|
||||
Json,
|
||||
response::{IntoResponse, Response},
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use bcrypt::verify;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use crate::state::AppState;
|
||||
use sqlx::Row;
|
||||
use bcrypt::verify;
|
||||
use tracing::{info, error, warn};
|
||||
|
||||
pub async fn describe_server() -> impl IntoResponse {
|
||||
let domains_str = std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| "example.com".to_string());
|
||||
let domains: Vec<&str> = domains_str.split(',').map(|s| s.trim()).collect();
|
||||
|
||||
Json(json!({
|
||||
"availableUserDomains": domains
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn health(State(state): State<AppState>) -> impl IntoResponse {
|
||||
match sqlx::query("SELECT 1").execute(&state.db).await {
|
||||
Ok(_) => (StatusCode::OK, "OK"),
|
||||
Err(e) => {
|
||||
error!("Health check failed: {:?}", e);
|
||||
(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable")
|
||||
}
|
||||
}
|
||||
}
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct CreateSessionInput {
|
||||
@@ -69,7 +50,11 @@ pub async fn create_session(
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
error!("Failed to create access token: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -77,45 +62,70 @@ pub async fn create_session(
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
error!("Failed to create refresh token: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let session_insert = sqlx::query("INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)")
|
||||
.bind(&access_jwt)
|
||||
.bind(&refresh_jwt)
|
||||
.bind(&did)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
let session_insert = sqlx::query(
|
||||
"INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)",
|
||||
)
|
||||
.bind(&access_jwt)
|
||||
.bind(&refresh_jwt)
|
||||
.bind(&did)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match session_insert {
|
||||
Ok(_) => {
|
||||
return (StatusCode::OK, Json(CreateSessionOutput {
|
||||
access_jwt,
|
||||
refresh_jwt,
|
||||
handle,
|
||||
did,
|
||||
})).into_response();
|
||||
},
|
||||
return (
|
||||
StatusCode::OK,
|
||||
Json(CreateSessionOutput {
|
||||
access_jwt,
|
||||
refresh_jwt,
|
||||
handle,
|
||||
did,
|
||||
}),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to insert session: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("Password verification failed for identifier: {}", input.identifier);
|
||||
warn!(
|
||||
"Password verification failed for identifier: {}",
|
||||
input.identifier
|
||||
);
|
||||
}
|
||||
},
|
||||
}
|
||||
Ok(None) => {
|
||||
warn!("User not found for identifier: {}", input.identifier);
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Database error fetching user: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
(StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid identifier or password"}))).into_response()
|
||||
(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationFailed", "message": "Invalid identifier or password"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
pub async fn get_session(
|
||||
@@ -124,10 +134,18 @@ pub async fn get_session(
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization");
|
||||
if auth_header.is_none() {
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response();
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationRequired"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", "");
|
||||
let token = auth_header
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap_or("")
|
||||
.replace("Bearer ", "");
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
@@ -136,7 +154,7 @@ pub async fn get_session(
|
||||
JOIN users u ON s.did = u.did
|
||||
JOIN user_keys k ON u.id = k.user_id
|
||||
WHERE s.access_jwt = $1
|
||||
"#
|
||||
"#,
|
||||
)
|
||||
.bind(&token)
|
||||
.fetch_optional(&state.db)
|
||||
@@ -150,22 +168,34 @@ pub async fn get_session(
|
||||
let key_bytes: Vec<u8> = row.get("key_bytes");
|
||||
|
||||
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response();
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"}))).into_response();
|
||||
}
|
||||
|
||||
return (StatusCode::OK, Json(json!({
|
||||
"handle": handle,
|
||||
"did": did,
|
||||
"email": email,
|
||||
"didDoc": {}
|
||||
}))).into_response();
|
||||
},
|
||||
return (
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"handle": handle,
|
||||
"did": did,
|
||||
"email": email,
|
||||
"didDoc": {}
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Ok(None) => {
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response();
|
||||
},
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationFailed"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Database error in get_session: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -176,10 +206,18 @@ pub async fn delete_session(
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization");
|
||||
if auth_header.is_none() {
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response();
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationRequired"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", "");
|
||||
let token = auth_header
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap_or("")
|
||||
.replace("Bearer ", "");
|
||||
|
||||
let result = sqlx::query("DELETE FROM sessions WHERE access_jwt = $1")
|
||||
.bind(token)
|
||||
@@ -191,13 +229,17 @@ pub async fn delete_session(
|
||||
if res.rows_affected() > 0 {
|
||||
return (StatusCode::OK, Json(json!({}))).into_response();
|
||||
}
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Database error in delete_session: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
(StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed"}))).into_response()
|
||||
(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationFailed"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
pub async fn refresh_session(
|
||||
@@ -206,10 +248,18 @@ pub async fn refresh_session(
|
||||
) -> Response {
|
||||
let auth_header = headers.get("Authorization");
|
||||
if auth_header.is_none() {
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationRequired"}))).into_response();
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationRequired"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let refresh_token = auth_header.unwrap().to_str().unwrap_or("").replace("Bearer ", "");
|
||||
let refresh_token = auth_header
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap_or("")
|
||||
.replace("Bearer ", "");
|
||||
|
||||
let session = sqlx::query(
|
||||
"SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.refresh_jwt = $1"
|
||||
@@ -231,27 +281,37 @@ pub async fn refresh_session(
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
error!("Failed to create access token: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
let new_refresh_jwt = match crate::auth::create_refresh_token(&did, &key_bytes) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
error!("Failed to create refresh token: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let update = sqlx::query("UPDATE sessions SET access_jwt = $1, refresh_jwt = $2 WHERE refresh_jwt = $3")
|
||||
.bind(&new_access_jwt)
|
||||
.bind(&new_refresh_jwt)
|
||||
.bind(&refresh_token)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
let update = sqlx::query(
|
||||
"UPDATE sessions SET access_jwt = $1, refresh_jwt = $2 WHERE refresh_jwt = $3",
|
||||
)
|
||||
.bind(&new_access_jwt)
|
||||
.bind(&new_refresh_jwt)
|
||||
.bind(&refresh_token)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match update {
|
||||
Ok(_) => {
|
||||
let user = sqlx::query("SELECT handle FROM users WHERE did = $1")
|
||||
let user = sqlx::query("SELECT handle FROM users WHERE did = $1")
|
||||
.bind(&did)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
@@ -259,36 +319,59 @@ pub async fn refresh_session(
|
||||
match user {
|
||||
Ok(Some(u)) => {
|
||||
let handle: String = u.get("handle");
|
||||
return (StatusCode::OK, Json(json!({
|
||||
"accessJwt": new_access_jwt,
|
||||
"refreshJwt": new_refresh_jwt,
|
||||
"handle": handle,
|
||||
"did": did
|
||||
}))).into_response();
|
||||
},
|
||||
return (
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"accessJwt": new_access_jwt,
|
||||
"refreshJwt": new_refresh_jwt,
|
||||
"handle": handle,
|
||||
"did": did
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Ok(None) => {
|
||||
error!("User not found for existing session: {}", did);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
},
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Database error fetching user: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Database error updating session: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
Ok(None) => {
|
||||
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid refresh token"}))).into_response();
|
||||
},
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({"error": "AuthenticationFailed", "message": "Invalid refresh token"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Database error fetching session: {:?}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
157
src/auth.rs
157
src/auth.rs
@@ -1,157 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use chrono::{Utc, Duration};
|
||||
use k256::ecdsa::{SigningKey, VerifyingKey, signature::Signer, signature::Verifier, Signature};
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
pub iss: String,
|
||||
pub sub: String,
|
||||
pub aud: String,
|
||||
pub exp: usize,
|
||||
pub iat: usize,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub scope: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub lxm: Option<String>,
|
||||
pub jti: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Header {
|
||||
alg: String,
|
||||
typ: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct UnsafeClaims {
|
||||
iss: String,
|
||||
sub: Option<String>,
|
||||
}
|
||||
|
||||
// fancy boy TokenData equivalent for compatibility/structure
|
||||
pub struct TokenData<T> {
|
||||
pub claims: T,
|
||||
}
|
||||
|
||||
pub fn get_did_from_token(token: &str) -> Result<String, String> {
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() != 3 {
|
||||
return Err("Invalid token format".to_string());
|
||||
}
|
||||
|
||||
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1])
|
||||
.map_err(|e| format!("Base64 decode failed: {}", e))?;
|
||||
|
||||
let claims: UnsafeClaims = serde_json::from_slice(&payload_bytes)
|
||||
.map_err(|e| format!("JSON decode failed: {}", e))?;
|
||||
|
||||
Ok(claims.sub.unwrap_or(claims.iss))
|
||||
}
|
||||
|
||||
pub fn create_access_token(did: &str, key_bytes: &[u8]) -> Result<String, anyhow::Error> {
|
||||
create_signed_token(did, "access", key_bytes, Duration::minutes(15))
|
||||
}
|
||||
|
||||
pub fn create_refresh_token(did: &str, key_bytes: &[u8]) -> Result<String, anyhow::Error> {
|
||||
create_signed_token(did, "refresh", key_bytes, Duration::days(7))
|
||||
}
|
||||
|
||||
pub fn create_service_token(did: &str, aud: &str, lxm: &str, key_bytes: &[u8]) -> Result<String, anyhow::Error> {
|
||||
let signing_key = SigningKey::from_slice(key_bytes)?;
|
||||
|
||||
let expiration = Utc::now()
|
||||
.checked_add_signed(Duration::seconds(60))
|
||||
.expect("valid timestamp")
|
||||
.timestamp();
|
||||
|
||||
let claims = Claims {
|
||||
iss: did.to_owned(),
|
||||
sub: did.to_owned(),
|
||||
aud: aud.to_owned(),
|
||||
exp: expiration as usize,
|
||||
iat: Utc::now().timestamp() as usize,
|
||||
scope: None,
|
||||
lxm: Some(lxm.to_string()),
|
||||
jti: uuid::Uuid::new_v4().to_string(),
|
||||
};
|
||||
|
||||
sign_claims(claims, &signing_key)
|
||||
}
|
||||
|
||||
fn create_signed_token(did: &str, scope: &str, key_bytes: &[u8], duration: Duration) -> Result<String, anyhow::Error> {
|
||||
let signing_key = SigningKey::from_slice(key_bytes)?;
|
||||
|
||||
let expiration = Utc::now()
|
||||
.checked_add_signed(duration)
|
||||
.expect("valid timestamp")
|
||||
.timestamp();
|
||||
|
||||
let claims = Claims {
|
||||
iss: did.to_owned(),
|
||||
sub: did.to_owned(),
|
||||
aud: format!("did:web:{}", std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())),
|
||||
exp: expiration as usize,
|
||||
iat: Utc::now().timestamp() as usize,
|
||||
scope: Some(scope.to_string()),
|
||||
lxm: None,
|
||||
jti: uuid::Uuid::new_v4().to_string(),
|
||||
};
|
||||
|
||||
sign_claims(claims, &signing_key)
|
||||
}
|
||||
|
||||
fn sign_claims(claims: Claims, key: &SigningKey) -> Result<String, anyhow::Error> {
|
||||
let header = Header {
|
||||
alg: "ES256K".to_string(),
|
||||
typ: "JWT".to_string(),
|
||||
};
|
||||
|
||||
let header_json = serde_json::to_string(&header)?;
|
||||
let claims_json = serde_json::to_string(&claims)?;
|
||||
|
||||
let header_b64 = URL_SAFE_NO_PAD.encode(header_json);
|
||||
let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json);
|
||||
|
||||
let message = format!("{}.{}", header_b64, claims_b64);
|
||||
let signature: Signature = key.sign(message.as_bytes());
|
||||
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
|
||||
|
||||
Ok(format!("{}.{}", message, signature_b64))
|
||||
}
|
||||
|
||||
pub fn verify_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>, anyhow::Error> {
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() != 3 {
|
||||
return Err(anyhow!("Invalid token format"));
|
||||
}
|
||||
|
||||
let header_b64 = parts[0];
|
||||
let claims_b64 = parts[1];
|
||||
let signature_b64 = parts[2];
|
||||
|
||||
let signature_bytes = URL_SAFE_NO_PAD.decode(signature_b64)
|
||||
.context("Base64 decode of signature failed")?;
|
||||
let signature = Signature::from_slice(&signature_bytes)
|
||||
.map_err(|e| anyhow!("Invalid signature format: {}", e))?;
|
||||
|
||||
let signing_key = SigningKey::from_slice(key_bytes)?;
|
||||
let verifying_key = VerifyingKey::from(&signing_key);
|
||||
|
||||
let message = format!("{}.{}", header_b64, claims_b64);
|
||||
verifying_key.verify(message.as_bytes(), &signature)
|
||||
.map_err(|e| anyhow!("Signature verification failed: {}", e))?;
|
||||
|
||||
let claims_bytes = URL_SAFE_NO_PAD.decode(claims_b64)
|
||||
.context("Base64 decode of claims failed")?;
|
||||
let claims: Claims = serde_json::from_slice(&claims_bytes)
|
||||
.context("JSON decode of claims failed")?;
|
||||
|
||||
let now = Utc::now().timestamp() as usize;
|
||||
if claims.exp < now {
|
||||
return Err(anyhow!("Token expired"));
|
||||
}
|
||||
|
||||
Ok(TokenData { claims })
|
||||
}
|
||||
38
src/auth/mod.rs
Normal file
38
src/auth/mod.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub mod token;
|
||||
pub mod verify;
|
||||
|
||||
pub use token::{create_access_token, create_refresh_token, create_service_token};
|
||||
pub use verify::{get_did_from_token, verify_token};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
pub iss: String,
|
||||
pub sub: String,
|
||||
pub aud: String,
|
||||
pub exp: usize,
|
||||
pub iat: usize,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub scope: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub lxm: Option<String>,
|
||||
pub jti: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Header {
|
||||
pub alg: String,
|
||||
pub typ: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct UnsafeClaims {
|
||||
pub iss: String,
|
||||
pub sub: Option<String>,
|
||||
}
|
||||
|
||||
// fancy boy TokenData equivalent for compatibility/structure
|
||||
pub struct TokenData<T> {
|
||||
pub claims: T,
|
||||
}
|
||||
86
src/auth/token.rs
Normal file
86
src/auth/token.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use super::{Claims, Header};
|
||||
use anyhow::Result;
|
||||
use base64::Engine as _;
|
||||
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
|
||||
use chrono::{Duration, Utc};
|
||||
use k256::ecdsa::{Signature, SigningKey, signature::Signer};
|
||||
use uuid;
|
||||
|
||||
pub fn create_access_token(did: &str, key_bytes: &[u8]) -> Result<String> {
|
||||
create_signed_token(did, "access", key_bytes, Duration::minutes(15))
|
||||
}
|
||||
|
||||
pub fn create_refresh_token(did: &str, key_bytes: &[u8]) -> Result<String> {
|
||||
create_signed_token(did, "refresh", key_bytes, Duration::days(7))
|
||||
}
|
||||
|
||||
pub fn create_service_token(did: &str, aud: &str, lxm: &str, key_bytes: &[u8]) -> Result<String> {
|
||||
let signing_key = SigningKey::from_slice(key_bytes)?;
|
||||
|
||||
let expiration = Utc::now()
|
||||
.checked_add_signed(Duration::seconds(60))
|
||||
.expect("valid timestamp")
|
||||
.timestamp();
|
||||
|
||||
let claims = Claims {
|
||||
iss: did.to_owned(),
|
||||
sub: did.to_owned(),
|
||||
aud: aud.to_owned(),
|
||||
exp: expiration as usize,
|
||||
iat: Utc::now().timestamp() as usize,
|
||||
scope: None,
|
||||
lxm: Some(lxm.to_string()),
|
||||
jti: uuid::Uuid::new_v4().to_string(),
|
||||
};
|
||||
|
||||
sign_claims(claims, &signing_key)
|
||||
}
|
||||
|
||||
fn create_signed_token(
|
||||
did: &str,
|
||||
scope: &str,
|
||||
key_bytes: &[u8],
|
||||
duration: Duration,
|
||||
) -> Result<String> {
|
||||
let signing_key = SigningKey::from_slice(key_bytes)?;
|
||||
|
||||
let expiration = Utc::now()
|
||||
.checked_add_signed(duration)
|
||||
.expect("valid timestamp")
|
||||
.timestamp();
|
||||
|
||||
let claims = Claims {
|
||||
iss: did.to_owned(),
|
||||
sub: did.to_owned(),
|
||||
aud: format!(
|
||||
"did:web:{}",
|
||||
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
|
||||
),
|
||||
exp: expiration as usize,
|
||||
iat: Utc::now().timestamp() as usize,
|
||||
scope: Some(scope.to_string()),
|
||||
lxm: None,
|
||||
jti: uuid::Uuid::new_v4().to_string(),
|
||||
};
|
||||
|
||||
sign_claims(claims, &signing_key)
|
||||
}
|
||||
|
||||
fn sign_claims(claims: Claims, key: &SigningKey) -> Result<String> {
|
||||
let header = Header {
|
||||
alg: "ES256K".to_string(),
|
||||
typ: "JWT".to_string(),
|
||||
};
|
||||
|
||||
let header_json = serde_json::to_string(&header)?;
|
||||
let claims_json = serde_json::to_string(&claims)?;
|
||||
|
||||
let header_b64 = URL_SAFE_NO_PAD.encode(header_json);
|
||||
let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json);
|
||||
|
||||
let message = format!("{}.{}", header_b64, claims_b64);
|
||||
let signature: Signature = key.sign(message.as_bytes());
|
||||
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
|
||||
|
||||
Ok(format!("{}.{}", message, signature_b64))
|
||||
}
|
||||
60
src/auth/verify.rs
Normal file
60
src/auth/verify.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
use super::{Claims, TokenData, UnsafeClaims};
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use base64::Engine as _;
|
||||
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
|
||||
use chrono::Utc;
|
||||
use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier};
|
||||
|
||||
pub fn get_did_from_token(token: &str) -> Result<String, String> {
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() != 3 {
|
||||
return Err("Invalid token format".to_string());
|
||||
}
|
||||
|
||||
let payload_bytes = URL_SAFE_NO_PAD
|
||||
.decode(parts[1])
|
||||
.map_err(|e| format!("Base64 decode failed: {}", e))?;
|
||||
|
||||
let claims: UnsafeClaims =
|
||||
serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
|
||||
|
||||
Ok(claims.sub.unwrap_or(claims.iss))
|
||||
}
|
||||
|
||||
pub fn verify_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<Claims>> {
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() != 3 {
|
||||
return Err(anyhow!("Invalid token format"));
|
||||
}
|
||||
|
||||
let header_b64 = parts[0];
|
||||
let claims_b64 = parts[1];
|
||||
let signature_b64 = parts[2];
|
||||
|
||||
let signature_bytes = URL_SAFE_NO_PAD
|
||||
.decode(signature_b64)
|
||||
.context("Base64 decode of signature failed")?;
|
||||
let signature = Signature::from_slice(&signature_bytes)
|
||||
.map_err(|e| anyhow!("Invalid signature format: {}", e))?;
|
||||
|
||||
let signing_key = SigningKey::from_slice(key_bytes)?;
|
||||
let verifying_key = VerifyingKey::from(&signing_key);
|
||||
|
||||
let message = format!("{}.{}", header_b64, claims_b64);
|
||||
verifying_key
|
||||
.verify(message.as_bytes(), &signature)
|
||||
.map_err(|e| anyhow!("Signature verification failed: {}", e))?;
|
||||
|
||||
let claims_bytes = URL_SAFE_NO_PAD
|
||||
.decode(claims_b64)
|
||||
.context("Base64 decode of claims failed")?;
|
||||
let claims: Claims =
|
||||
serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?;
|
||||
|
||||
let now = Utc::now().timestamp() as usize;
|
||||
if claims.exp < now {
|
||||
return Err(anyhow!("Token expired"));
|
||||
}
|
||||
|
||||
Ok(TokenData { claims })
|
||||
}
|
||||
69
src/lib.rs
69
src/lib.rs
@@ -1,31 +1,70 @@
|
||||
pub mod api;
|
||||
pub mod state;
|
||||
pub mod auth;
|
||||
pub mod repo;
|
||||
pub mod state;
|
||||
pub mod storage;
|
||||
|
||||
use axum::{
|
||||
routing::{get, post, any},
|
||||
Router,
|
||||
routing::{any, get, post},
|
||||
};
|
||||
use state::AppState;
|
||||
|
||||
pub fn app(state: AppState) -> Router {
|
||||
Router::new()
|
||||
.route("/health", get(api::server::health))
|
||||
.route("/xrpc/com.atproto.server.describeServer", get(api::server::describe_server))
|
||||
.route("/xrpc/com.atproto.server.createAccount", post(api::identity::create_account))
|
||||
.route("/xrpc/com.atproto.server.createSession", post(api::server::create_session))
|
||||
.route("/xrpc/com.atproto.server.getSession", get(api::server::get_session))
|
||||
.route("/xrpc/com.atproto.server.deleteSession", post(api::server::delete_session))
|
||||
.route("/xrpc/com.atproto.server.refreshSession", post(api::server::refresh_session))
|
||||
.route("/xrpc/com.atproto.repo.createRecord", post(api::repo::create_record))
|
||||
.route("/xrpc/com.atproto.repo.putRecord", post(api::repo::put_record))
|
||||
.route("/xrpc/com.atproto.repo.getRecord", get(api::repo::get_record))
|
||||
.route("/xrpc/com.atproto.repo.deleteRecord", post(api::repo::delete_record))
|
||||
.route("/xrpc/com.atproto.repo.listRecords", get(api::repo::list_records))
|
||||
.route("/xrpc/com.atproto.repo.describeRepo", get(api::repo::describe_repo))
|
||||
.route("/xrpc/com.atproto.repo.uploadBlob", post(api::repo::upload_blob))
|
||||
.route(
|
||||
"/xrpc/com.atproto.server.describeServer",
|
||||
get(api::server::describe_server),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.server.createAccount",
|
||||
post(api::identity::create_account),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.server.createSession",
|
||||
post(api::server::create_session),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.server.getSession",
|
||||
get(api::server::get_session),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.server.deleteSession",
|
||||
post(api::server::delete_session),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.server.refreshSession",
|
||||
post(api::server::refresh_session),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.repo.createRecord",
|
||||
post(api::repo::create_record),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.repo.putRecord",
|
||||
post(api::repo::put_record),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.repo.getRecord",
|
||||
get(api::repo::get_record),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.repo.deleteRecord",
|
||||
post(api::repo::delete_record),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.repo.listRecords",
|
||||
get(api::repo::list_records),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.repo.describeRepo",
|
||||
get(api::repo::describe_repo),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.repo.uploadBlob",
|
||||
post(api::repo::upload_blob),
|
||||
)
|
||||
.route("/.well-known/did.json", get(api::identity::well_known_did))
|
||||
.route("/u/{handle}/did.json", get(api::identity::user_did_doc))
|
||||
.route("/xrpc/{*method}", any(api::proxy::proxy_handler))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::net::SocketAddr;
|
||||
use bspds::state::AppState;
|
||||
use std::net::SocketAddr;
|
||||
use tracing::info;
|
||||
|
||||
#[tokio::main]
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use jacquard_repo::storage::BlockStore;
|
||||
use bytes::Bytes;
|
||||
use cid::Cid;
|
||||
use jacquard_repo::error::RepoError;
|
||||
use jacquard_repo::repo::CommitData;
|
||||
use cid::Cid;
|
||||
use sqlx::{PgPool, Row};
|
||||
use bytes::Bytes;
|
||||
use sha2::{Sha256, Digest};
|
||||
use jacquard_repo::storage::BlockStore;
|
||||
use multihash::Multihash;
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx::{PgPool, Row};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PostgresBlockStore {
|
||||
@@ -31,7 +31,7 @@ impl BlockStore for PostgresBlockStore {
|
||||
Some(row) => {
|
||||
let data: Vec<u8> = row.get("data");
|
||||
Ok(Some(Bytes::from(data)))
|
||||
},
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
@@ -65,16 +65,21 @@ impl BlockStore for PostgresBlockStore {
|
||||
Ok(row.is_some())
|
||||
}
|
||||
|
||||
async fn put_many(&self, blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send) -> Result<(), RepoError> {
|
||||
async fn put_many(
|
||||
&self,
|
||||
blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send,
|
||||
) -> Result<(), RepoError> {
|
||||
let blocks: Vec<_> = blocks.into_iter().collect();
|
||||
for (cid, data) in blocks {
|
||||
let cid_bytes = cid.to_bytes();
|
||||
sqlx::query("INSERT INTO blocks (cid, data) VALUES ($1, $2) ON CONFLICT (cid) DO NOTHING")
|
||||
.bind(cid_bytes)
|
||||
.bind(data.as_ref())
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| RepoError::storage(e))?;
|
||||
sqlx::query(
|
||||
"INSERT INTO blocks (cid, data) VALUES ($1, $2) ON CONFLICT (cid) DO NOTHING",
|
||||
)
|
||||
.bind(cid_bytes)
|
||||
.bind(data.as_ref())
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| RepoError::storage(e))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use sqlx::PgPool;
|
||||
use crate::repo::PostgresBlockStore;
|
||||
use crate::storage::{BlobStorage, S3BlobStorage};
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -14,6 +14,10 @@ impl AppState {
|
||||
pub async fn new(db: PgPool) -> Self {
|
||||
let block_store = PostgresBlockStore::new(db.clone());
|
||||
let blob_store = S3BlobStorage::new().await;
|
||||
Self { db, block_store, blob_store: Arc::new(blob_store) }
|
||||
Self {
|
||||
db,
|
||||
block_store,
|
||||
blob_store: Arc::new(blob_store),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use async_trait::async_trait;
|
||||
use thiserror::Error;
|
||||
use aws_config::BehaviorVersion;
|
||||
use aws_config::meta::region::RegionProviderChain;
|
||||
use aws_sdk_s3::Client;
|
||||
use aws_sdk_s3::primitives::ByteStream;
|
||||
use aws_config::meta::region::RegionProviderChain;
|
||||
use aws_config::BehaviorVersion;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum StorageError {
|
||||
@@ -55,7 +55,8 @@ impl S3BlobStorage {
|
||||
#[async_trait]
|
||||
impl BlobStorage for S3BlobStorage {
|
||||
async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> {
|
||||
self.client.put_object()
|
||||
self.client
|
||||
.put_object()
|
||||
.bucket(&self.bucket)
|
||||
.key(key)
|
||||
.body(ByteStream::from(data.to_vec()))
|
||||
@@ -66,14 +67,19 @@ impl BlobStorage for S3BlobStorage {
|
||||
}
|
||||
|
||||
async fn get(&self, key: &str) -> Result<Vec<u8>, StorageError> {
|
||||
let resp = self.client.get_object()
|
||||
let resp = self
|
||||
.client
|
||||
.get_object()
|
||||
.bucket(&self.bucket)
|
||||
.key(key)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| StorageError::S3(e.to_string()))?;
|
||||
|
||||
let data = resp.body.collect().await
|
||||
let data = resp
|
||||
.body
|
||||
.collect()
|
||||
.await
|
||||
.map_err(|e| StorageError::S3(e.to_string()))?
|
||||
.into_bytes();
|
||||
|
||||
@@ -81,7 +87,8 @@ impl BlobStorage for S3BlobStorage {
|
||||
}
|
||||
|
||||
async fn delete(&self, key: &str) -> Result<(), StorageError> {
|
||||
self.client.delete_object()
|
||||
self.client
|
||||
.delete_object()
|
||||
.bucket(&self.bucket)
|
||||
.key(key)
|
||||
.send()
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use bspds::auth;
|
||||
use k256::SecretKey;
|
||||
use rand::rngs::OsRng;
|
||||
use chrono::{Utc, Duration};
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
|
||||
use serde_json::json;
|
||||
use bspds::auth;
|
||||
use chrono::{Duration, Utc};
|
||||
use k256::SecretKey;
|
||||
use k256::ecdsa::{SigningKey, signature::Signer};
|
||||
use rand::rngs::OsRng;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_jwt_flow() {
|
||||
@@ -24,7 +24,8 @@ fn test_jwt_flow() {
|
||||
|
||||
let aud = "did:web:service";
|
||||
let lxm = "com.example.test";
|
||||
let s_token = auth::create_service_token(did, aud, lxm, &key_bytes).expect("create service token");
|
||||
let s_token =
|
||||
auth::create_service_token(did, aud, lxm, &key_bytes).expect("create service token");
|
||||
let s_data = auth::verify_token(&s_token, &key_bytes).expect("verify service token");
|
||||
assert_eq!(s_data.claims.aud, aud);
|
||||
assert_eq!(s_data.claims.lxm, Some(lxm.to_string()));
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
use reqwest::{header, Client, StatusCode};
|
||||
use serde_json::{json, Value};
|
||||
use aws_config::BehaviorVersion;
|
||||
use aws_sdk_s3::Client as S3Client;
|
||||
use aws_sdk_s3::config::Credentials;
|
||||
use bspds::state::AppState;
|
||||
use chrono::Utc;
|
||||
use reqwest::{Client, StatusCode, header};
|
||||
use serde_json::{Value, json};
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
#[allow(unused_imports)]
|
||||
use std::collections::HashMap;
|
||||
use std::sync::OnceLock;
|
||||
#[allow(unused_imports)]
|
||||
use std::time::Duration;
|
||||
use std::sync::OnceLock;
|
||||
use bspds::state::AppState;
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
use tokio::net::TcpListener;
|
||||
use testcontainers::{runners::AsyncRunner, ContainerAsync, ImageExt, GenericImage};
|
||||
use testcontainers::core::ContainerPort;
|
||||
use testcontainers::{ContainerAsync, GenericImage, ImageExt, runners::AsyncRunner};
|
||||
use testcontainers_modules::postgres::Postgres;
|
||||
use aws_sdk_s3::Client as S3Client;
|
||||
use aws_config::BehaviorVersion;
|
||||
use aws_sdk_s3::config::Credentials;
|
||||
use wiremock::{MockServer, Mock, ResponseTemplate};
|
||||
use tokio::net::TcpListener;
|
||||
use wiremock::matchers::{method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
|
||||
static SERVER_URL: OnceLock<String> = OnceLock::new();
|
||||
static DB_CONTAINER: OnceLock<ContainerAsync<Postgres>> = OnceLock::new();
|
||||
@@ -46,7 +46,12 @@ pub async fn base_url() -> &'static str {
|
||||
if let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR") {
|
||||
let podman_sock = std::path::Path::new(&runtime_dir).join("podman/podman.sock");
|
||||
if podman_sock.exists() {
|
||||
unsafe { std::env::set_var("DOCKER_HOST", format!("unix://{}", podman_sock.display())); }
|
||||
unsafe {
|
||||
std::env::set_var(
|
||||
"DOCKER_HOST",
|
||||
format!("unix://{}", podman_sock.display()),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -62,7 +67,10 @@ pub async fn base_url() -> &'static str {
|
||||
.await
|
||||
.expect("Failed to start MinIO");
|
||||
|
||||
let s3_port = s3_container.get_host_port_ipv4(9000).await.expect("Failed to get S3 port");
|
||||
let s3_port = s3_container
|
||||
.get_host_port_ipv4(9000)
|
||||
.await
|
||||
.expect("Failed to get S3 port");
|
||||
let s3_endpoint = format!("http://127.0.0.1:{}", s3_port);
|
||||
|
||||
unsafe {
|
||||
@@ -76,7 +84,13 @@ pub async fn base_url() -> &'static str {
|
||||
let sdk_config = aws_config::defaults(BehaviorVersion::latest())
|
||||
.region("us-east-1")
|
||||
.endpoint_url(&s3_endpoint)
|
||||
.credentials_provider(Credentials::new("minioadmin", "minioadmin", None, None, "test"))
|
||||
.credentials_provider(Credentials::new(
|
||||
"minioadmin",
|
||||
"minioadmin",
|
||||
None,
|
||||
None,
|
||||
"test",
|
||||
))
|
||||
.load()
|
||||
.await;
|
||||
|
||||
@@ -108,15 +122,24 @@ pub async fn base_url() -> &'static str {
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
unsafe { std::env::set_var("APPVIEW_URL", mock_server.uri()); }
|
||||
unsafe {
|
||||
std::env::set_var("APPVIEW_URL", mock_server.uri());
|
||||
}
|
||||
MOCK_APPVIEW.set(mock_server).ok();
|
||||
|
||||
S3_CONTAINER.set(s3_container).ok();
|
||||
|
||||
let container = Postgres::default().with_tag("18-alpine").start().await.expect("Failed to start Postgres");
|
||||
let container = Postgres::default()
|
||||
.with_tag("18-alpine")
|
||||
.start()
|
||||
.await
|
||||
.expect("Failed to start Postgres");
|
||||
let connection_string = format!(
|
||||
"postgres://postgres:postgres@127.0.0.1:{}/postgres",
|
||||
container.get_host_port_ipv4(5432).await.expect("Failed to get port")
|
||||
container
|
||||
.get_host_port_ipv4(5432)
|
||||
.await
|
||||
.expect("Failed to get port")
|
||||
);
|
||||
|
||||
DB_CONTAINER.set(container).ok();
|
||||
@@ -157,7 +180,11 @@ async fn spawn_app(database_url: String) -> String {
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'static str) -> Value {
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.uploadBlob",
|
||||
base_url().await
|
||||
))
|
||||
.header(header::CONTENT_TYPE, mime)
|
||||
.bearer_auth(AUTH_TOKEN)
|
||||
.body(data)
|
||||
@@ -170,12 +197,11 @@ pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'stati
|
||||
body["blob"].clone()
|
||||
}
|
||||
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn create_test_post(
|
||||
client: &Client,
|
||||
text: &str,
|
||||
reply_to: Option<Value>
|
||||
reply_to: Option<Value>,
|
||||
) -> (String, String, String) {
|
||||
let collection = "app.bsky.feed.post";
|
||||
let mut record = json!({
|
||||
@@ -194,7 +220,11 @@ pub async fn create_test_post(
|
||||
"record": record
|
||||
});
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.repo.createRecord", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.createRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(AUTH_TOKEN)
|
||||
.json(&payload)
|
||||
.send()
|
||||
@@ -202,11 +232,24 @@ pub async fn create_test_post(
|
||||
.expect("Failed to send createRecord");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK, "Failed to create post record");
|
||||
let body: Value = res.json().await.expect("createRecord response was not JSON");
|
||||
let body: Value = res
|
||||
.json()
|
||||
.await
|
||||
.expect("createRecord response was not JSON");
|
||||
|
||||
let uri = body["uri"].as_str().expect("Response had no URI").to_string();
|
||||
let cid = body["cid"].as_str().expect("Response had no CID").to_string();
|
||||
let rkey = uri.split('/').last().expect("URI was malformed").to_string();
|
||||
let uri = body["uri"]
|
||||
.as_str()
|
||||
.expect("Response had no URI")
|
||||
.to_string();
|
||||
let cid = body["cid"]
|
||||
.as_str()
|
||||
.expect("Response had no CID")
|
||||
.to_string();
|
||||
let rkey = uri
|
||||
.split('/')
|
||||
.last()
|
||||
.expect("URI was malformed")
|
||||
.to_string();
|
||||
|
||||
(uri, cid, rkey)
|
||||
}
|
||||
@@ -220,7 +263,11 @@ pub async fn create_account_and_login(client: &Client) -> (String, String) {
|
||||
"password": "password"
|
||||
});
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createAccount",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
@@ -231,7 +278,10 @@ pub async fn create_account_and_login(client: &Client) -> (String, String) {
|
||||
}
|
||||
|
||||
let body: Value = res.json().await.expect("Invalid JSON");
|
||||
let access_jwt = body["accessJwt"].as_str().expect("No accessJwt").to_string();
|
||||
let access_jwt = body["accessJwt"]
|
||||
.as_str()
|
||||
.expect("No accessJwt")
|
||||
.to_string();
|
||||
let did = body["did"].as_str().expect("No did").to_string();
|
||||
(access_jwt, did)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
mod common;
|
||||
use common::*;
|
||||
use reqwest::StatusCode;
|
||||
use serde_json::{json, Value};
|
||||
use wiremock::{MockServer, Mock, ResponseTemplate};
|
||||
use serde_json::{Value, json};
|
||||
use wiremock::matchers::{method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
|
||||
// #[tokio::test]
|
||||
// async fn test_resolve_handle() {
|
||||
@@ -23,7 +23,8 @@ use wiremock::matchers::{method, path};
|
||||
#[tokio::test]
|
||||
async fn test_well_known_did() {
|
||||
let client = client();
|
||||
let res = client.get(format!("{}/.well-known/did.json", base_url().await))
|
||||
let res = client
|
||||
.get(format!("{}/.well-known/did.json", base_url().await))
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
@@ -71,7 +72,11 @@ async fn test_create_did_web_account_and_resolve() {
|
||||
"did": did
|
||||
});
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createAccount",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
@@ -79,13 +84,20 @@ async fn test_create_did_web_account_and_resolve() {
|
||||
|
||||
if res.status() != StatusCode::OK {
|
||||
let status = res.status();
|
||||
let body: Value = res.json().await.unwrap_or(json!({"error": "could not parse body"}));
|
||||
let body: Value = res
|
||||
.json()
|
||||
.await
|
||||
.unwrap_or(json!({"error": "could not parse body"}));
|
||||
panic!("createAccount failed with status {}: {:?}", status, body);
|
||||
}
|
||||
let body: Value = res.json().await.expect("createAccount response was not JSON");
|
||||
let body: Value = res
|
||||
.json()
|
||||
.await
|
||||
.expect("createAccount response was not JSON");
|
||||
assert_eq!(body["did"], did);
|
||||
|
||||
let res = client.get(format!("{}/u/{}/did.json", base_url().await, handle))
|
||||
let res = client
|
||||
.get(format!("{}/u/{}/did.json", base_url().await, handle))
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to fetch DID doc");
|
||||
@@ -111,14 +123,22 @@ async fn test_create_account_duplicate_handle() {
|
||||
"password": "password"
|
||||
});
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createAccount",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createAccount",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
@@ -143,7 +163,11 @@ async fn test_did_web_lifecycle() {
|
||||
"did": did
|
||||
});
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createAccount",
|
||||
base_url().await
|
||||
))
|
||||
.json(&create_payload)
|
||||
.send()
|
||||
.await
|
||||
@@ -162,7 +186,11 @@ async fn test_did_web_lifecycle() {
|
||||
"identifier": handle,
|
||||
"password": "password"
|
||||
});
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.server.createSession", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createSession",
|
||||
base_url().await
|
||||
))
|
||||
.json(&login_payload)
|
||||
.send()
|
||||
.await
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
mod common;
|
||||
use common::*;
|
||||
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde_json::{json, Value};
|
||||
use chrono::Utc;
|
||||
#[allow(unused_imports)]
|
||||
use reqwest;
|
||||
use serde_json::{Value, json};
|
||||
use std::time::Duration;
|
||||
|
||||
async fn setup_new_user(handle_prefix: &str) -> (String, String) {
|
||||
@@ -19,20 +18,36 @@ async fn setup_new_user(handle_prefix: &str) -> (String, String) {
|
||||
"email": email,
|
||||
"password": password
|
||||
});
|
||||
let create_res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
|
||||
let create_res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createAccount",
|
||||
base_url().await
|
||||
))
|
||||
.json(&create_account_payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("setup_new_user: Failed to send createAccount");
|
||||
|
||||
if create_res.status() != StatusCode::OK {
|
||||
panic!("setup_new_user: Failed to create account: {:?}", create_res.text().await);
|
||||
if create_res.status() != reqwest::StatusCode::OK {
|
||||
panic!(
|
||||
"setup_new_user: Failed to create account: {:?}",
|
||||
create_res.text().await
|
||||
);
|
||||
}
|
||||
|
||||
let create_body: Value = create_res.json().await.expect("setup_new_user: createAccount response was not JSON");
|
||||
let create_body: Value = create_res
|
||||
.json()
|
||||
.await
|
||||
.expect("setup_new_user: createAccount response was not JSON");
|
||||
|
||||
let new_did = create_body["did"].as_str().expect("setup_new_user: Response had no DID").to_string();
|
||||
let new_jwt = create_body["accessJwt"].as_str().expect("setup_new_user: Response had no accessJwt").to_string();
|
||||
let new_did = create_body["did"]
|
||||
.as_str()
|
||||
.expect("setup_new_user: Response had no DID")
|
||||
.to_string();
|
||||
let new_jwt = create_body["accessJwt"]
|
||||
.as_str()
|
||||
.expect("setup_new_user: Response had no accessJwt")
|
||||
.to_string();
|
||||
|
||||
(new_did, new_jwt)
|
||||
}
|
||||
@@ -59,35 +74,59 @@ async fn test_post_crud_lifecycle() {
|
||||
}
|
||||
});
|
||||
|
||||
let create_res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
|
||||
let create_res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.putRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&jwt)
|
||||
.json(&create_payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send create request");
|
||||
|
||||
assert_eq!(create_res.status(), StatusCode::OK, "Failed to create record");
|
||||
let create_body: Value = create_res.json().await.expect("create response was not JSON");
|
||||
let uri = create_body["uri"].as_str().unwrap();
|
||||
if create_res.status() != reqwest::StatusCode::OK {
|
||||
let status = create_res.status();
|
||||
let body = create_res
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Could not get body".to_string());
|
||||
panic!(
|
||||
"Failed to create record. Status: {}, Body: {}",
|
||||
status, body
|
||||
);
|
||||
}
|
||||
|
||||
let create_body: Value = create_res
|
||||
.json()
|
||||
.await
|
||||
.expect("create response was not JSON");
|
||||
let uri = create_body["uri"].as_str().unwrap();
|
||||
|
||||
let params = [
|
||||
("repo", did.as_str()),
|
||||
("collection", collection),
|
||||
("rkey", &rkey),
|
||||
];
|
||||
let get_res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
|
||||
let get_res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.getRecord",
|
||||
base_url().await
|
||||
))
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send get request");
|
||||
|
||||
assert_eq!(get_res.status(), StatusCode::OK, "Failed to get record after create");
|
||||
assert_eq!(
|
||||
get_res.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Failed to get record after create"
|
||||
);
|
||||
let get_body: Value = get_res.json().await.expect("get response was not JSON");
|
||||
assert_eq!(get_body["uri"], uri);
|
||||
assert_eq!(get_body["value"]["text"], original_text);
|
||||
|
||||
|
||||
let updated_text = "This post has been updated.";
|
||||
let update_payload = json!({
|
||||
"repo": did,
|
||||
@@ -100,26 +139,46 @@ async fn test_post_crud_lifecycle() {
|
||||
}
|
||||
});
|
||||
|
||||
let update_res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
|
||||
let update_res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.putRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&jwt)
|
||||
.json(&update_payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send update request");
|
||||
|
||||
assert_eq!(update_res.status(), StatusCode::OK, "Failed to update record");
|
||||
assert_eq!(
|
||||
update_res.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Failed to update record"
|
||||
);
|
||||
|
||||
|
||||
let get_updated_res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
|
||||
let get_updated_res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.getRecord",
|
||||
base_url().await
|
||||
))
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send get-after-update request");
|
||||
|
||||
assert_eq!(get_updated_res.status(), StatusCode::OK, "Failed to get record after update");
|
||||
let get_updated_body: Value = get_updated_res.json().await.expect("get-updated response was not JSON");
|
||||
assert_eq!(get_updated_body["value"]["text"], updated_text, "Text was not updated");
|
||||
|
||||
assert_eq!(
|
||||
get_updated_res.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Failed to get record after update"
|
||||
);
|
||||
let get_updated_body: Value = get_updated_res
|
||||
.json()
|
||||
.await
|
||||
.expect("get-updated response was not JSON");
|
||||
assert_eq!(
|
||||
get_updated_body["value"]["text"], updated_text,
|
||||
"Text was not updated"
|
||||
);
|
||||
|
||||
let delete_payload = json!({
|
||||
"repo": did,
|
||||
@@ -127,23 +186,38 @@ async fn test_post_crud_lifecycle() {
|
||||
"rkey": rkey
|
||||
});
|
||||
|
||||
let delete_res = client.post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await))
|
||||
let delete_res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.deleteRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&jwt)
|
||||
.json(&delete_payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send delete request");
|
||||
|
||||
assert_eq!(delete_res.status(), StatusCode::OK, "Failed to delete record");
|
||||
assert_eq!(
|
||||
delete_res.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Failed to delete record"
|
||||
);
|
||||
|
||||
|
||||
let get_deleted_res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
|
||||
let get_deleted_res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.getRecord",
|
||||
base_url().await
|
||||
))
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send get-after-delete request");
|
||||
|
||||
assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Record was found, but it should be deleted");
|
||||
assert_eq!(
|
||||
get_deleted_res.status(),
|
||||
reqwest::StatusCode::NOT_FOUND,
|
||||
"Record was found, but it should be deleted"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -161,24 +235,39 @@ async fn test_record_update_conflict_lifecycle() {
|
||||
"displayName": "Original Name"
|
||||
}
|
||||
});
|
||||
let create_res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
|
||||
let create_res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.putRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&user_jwt)
|
||||
.json(&profile_payload)
|
||||
.send().await.expect("create profile failed");
|
||||
.send()
|
||||
.await
|
||||
.expect("create profile failed");
|
||||
|
||||
if create_res.status() != StatusCode::OK {
|
||||
if create_res.status() != reqwest::StatusCode::OK {
|
||||
return;
|
||||
}
|
||||
|
||||
let get_res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
|
||||
let get_res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.getRecord",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", &user_did),
|
||||
("collection", &"app.bsky.actor.profile".to_string()),
|
||||
("rkey", &"self".to_string()),
|
||||
])
|
||||
.send().await.expect("getRecord failed");
|
||||
.send()
|
||||
.await
|
||||
.expect("getRecord failed");
|
||||
let get_body: Value = get_res.json().await.expect("getRecord not json");
|
||||
let cid_v1 = get_body["cid"].as_str().expect("Profile v1 had no CID").to_string();
|
||||
let cid_v1 = get_body["cid"]
|
||||
.as_str()
|
||||
.expect("Profile v1 had no CID")
|
||||
.to_string();
|
||||
|
||||
let update_payload_v2 = json!({
|
||||
"repo": user_did,
|
||||
@@ -190,13 +279,26 @@ async fn test_record_update_conflict_lifecycle() {
|
||||
},
|
||||
"swapCommit": cid_v1 // <-- Correctly point to v1
|
||||
});
|
||||
let update_res_v2 = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
|
||||
let update_res_v2 = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.putRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&user_jwt)
|
||||
.json(&update_payload_v2)
|
||||
.send().await.expect("putRecord v2 failed");
|
||||
assert_eq!(update_res_v2.status(), StatusCode::OK, "v2 update failed");
|
||||
.send()
|
||||
.await
|
||||
.expect("putRecord v2 failed");
|
||||
assert_eq!(
|
||||
update_res_v2.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"v2 update failed"
|
||||
);
|
||||
let update_body_v2: Value = update_res_v2.json().await.expect("v2 body not json");
|
||||
let cid_v2 = update_body_v2["cid"].as_str().expect("v2 response had no CID").to_string();
|
||||
let cid_v2 = update_body_v2["cid"]
|
||||
.as_str()
|
||||
.expect("v2 response had no CID")
|
||||
.to_string();
|
||||
|
||||
let update_payload_v3_stale = json!({
|
||||
"repo": user_did,
|
||||
@@ -208,14 +310,20 @@ async fn test_record_update_conflict_lifecycle() {
|
||||
},
|
||||
"swapCommit": cid_v1
|
||||
});
|
||||
let update_res_v3_stale = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
|
||||
let update_res_v3_stale = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.putRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&user_jwt)
|
||||
.json(&update_payload_v3_stale)
|
||||
.send().await.expect("putRecord v3 (stale) failed");
|
||||
.send()
|
||||
.await
|
||||
.expect("putRecord v3 (stale) failed");
|
||||
|
||||
assert_eq!(
|
||||
update_res_v3_stale.status(),
|
||||
StatusCode::CONFLICT,
|
||||
reqwest::StatusCode::CONFLICT,
|
||||
"Stale update did not cause a 409 Conflict"
|
||||
);
|
||||
|
||||
@@ -229,10 +337,233 @@ async fn test_record_update_conflict_lifecycle() {
|
||||
},
|
||||
"swapCommit": cid_v2 // <-- Correct
|
||||
});
|
||||
let update_res_v3_good = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
|
||||
let update_res_v3_good = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.putRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&user_jwt)
|
||||
.json(&update_payload_v3_good)
|
||||
.send().await.expect("putRecord v3 (good) failed");
|
||||
.send()
|
||||
.await
|
||||
.expect("putRecord v3 (good) failed");
|
||||
|
||||
assert_eq!(update_res_v3_good.status(), StatusCode::OK, "v3 (good) update failed");
|
||||
assert_eq!(
|
||||
update_res_v3_good.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"v3 (good) update failed"
|
||||
);
|
||||
}
|
||||
|
||||
async fn create_post(
|
||||
client: &reqwest::Client,
|
||||
did: &str,
|
||||
jwt: &str,
|
||||
text: &str,
|
||||
) -> (String, String) {
|
||||
let collection = "app.bsky.feed.post";
|
||||
let rkey = format!("e2e_social_{}", Utc::now().timestamp_millis());
|
||||
let now = Utc::now().to_rfc3339();
|
||||
|
||||
let create_payload = json!({
|
||||
"repo": did,
|
||||
"collection": collection,
|
||||
"rkey": rkey,
|
||||
"record": {
|
||||
"$type": collection,
|
||||
"text": text,
|
||||
"createdAt": now
|
||||
}
|
||||
});
|
||||
|
||||
let create_res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.putRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(jwt)
|
||||
.json(&create_payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send create post request");
|
||||
|
||||
assert_eq!(
|
||||
create_res.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Failed to create post record"
|
||||
);
|
||||
let create_body: Value = create_res
|
||||
.json()
|
||||
.await
|
||||
.expect("create post response was not JSON");
|
||||
let uri = create_body["uri"].as_str().unwrap().to_string();
|
||||
let cid = create_body["cid"].as_str().unwrap().to_string();
|
||||
(uri, cid)
|
||||
}
|
||||
|
||||
async fn create_follow(
|
||||
client: &reqwest::Client,
|
||||
follower_did: &str,
|
||||
follower_jwt: &str,
|
||||
followee_did: &str,
|
||||
) -> (String, String) {
|
||||
let collection = "app.bsky.graph.follow";
|
||||
let rkey = format!("e2e_follow_{}", Utc::now().timestamp_millis());
|
||||
let now = Utc::now().to_rfc3339();
|
||||
|
||||
let create_payload = json!({
|
||||
"repo": follower_did,
|
||||
"collection": collection,
|
||||
"rkey": rkey,
|
||||
"record": {
|
||||
"$type": collection,
|
||||
"subject": followee_did,
|
||||
"createdAt": now
|
||||
}
|
||||
});
|
||||
|
||||
let create_res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.putRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(follower_jwt)
|
||||
.json(&create_payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send create follow request");
|
||||
|
||||
assert_eq!(
|
||||
create_res.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Failed to create follow record"
|
||||
);
|
||||
let create_body: Value = create_res
|
||||
.json()
|
||||
.await
|
||||
.expect("create follow response was not JSON");
|
||||
let uri = create_body["uri"].as_str().unwrap().to_string();
|
||||
let cid = create_body["cid"].as_str().unwrap().to_string();
|
||||
(uri, cid)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_social_flow_lifecycle() {
|
||||
let client = client();
|
||||
|
||||
let (alice_did, alice_jwt) = setup_new_user("alice-social").await;
|
||||
let (bob_did, bob_jwt) = setup_new_user("bob-social").await;
|
||||
|
||||
let (post1_uri, _) = create_post(&client, &alice_did, &alice_jwt, "Alice's first post!").await;
|
||||
|
||||
create_follow(&client, &bob_did, &bob_jwt, &alice_did).await;
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
|
||||
let timeline_res_1 = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getTimeline",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&bob_jwt)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get timeline (1)");
|
||||
|
||||
assert_eq!(
|
||||
timeline_res_1.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Failed to get timeline (1)"
|
||||
);
|
||||
let timeline_body_1: Value = timeline_res_1.json().await.expect("Timeline (1) not JSON");
|
||||
let feed_1 = timeline_body_1["feed"].as_array().unwrap();
|
||||
assert_eq!(feed_1.len(), 1, "Timeline should have 1 post");
|
||||
assert_eq!(
|
||||
feed_1[0]["post"]["uri"], post1_uri,
|
||||
"Post URI mismatch in timeline (1)"
|
||||
);
|
||||
|
||||
let (post2_uri, _) = create_post(
|
||||
&client,
|
||||
&alice_did,
|
||||
&alice_jwt,
|
||||
"Alice's second post, so exciting!",
|
||||
)
|
||||
.await;
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
|
||||
let timeline_res_2 = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getTimeline",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&bob_jwt)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get timeline (2)");
|
||||
|
||||
assert_eq!(
|
||||
timeline_res_2.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Failed to get timeline (2)"
|
||||
);
|
||||
let timeline_body_2: Value = timeline_res_2.json().await.expect("Timeline (2) not JSON");
|
||||
let feed_2 = timeline_body_2["feed"].as_array().unwrap();
|
||||
assert_eq!(feed_2.len(), 2, "Timeline should have 2 posts");
|
||||
assert_eq!(
|
||||
feed_2[0]["post"]["uri"], post2_uri,
|
||||
"Post 2 should be first"
|
||||
);
|
||||
assert_eq!(
|
||||
feed_2[1]["post"]["uri"], post1_uri,
|
||||
"Post 1 should be second"
|
||||
);
|
||||
|
||||
let delete_payload = json!({
|
||||
"repo": alice_did,
|
||||
"collection": "app.bsky.feed.post",
|
||||
"rkey": post1_uri.split('/').last().unwrap()
|
||||
});
|
||||
let delete_res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.deleteRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&alice_jwt)
|
||||
.json(&delete_payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send delete request");
|
||||
assert_eq!(
|
||||
delete_res.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Failed to delete record"
|
||||
);
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
|
||||
let timeline_res_3 = client
|
||||
.get(format!(
|
||||
"{}/xrpc/app.bsky.feed.getTimeline",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&bob_jwt)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get timeline (3)");
|
||||
|
||||
assert_eq!(
|
||||
timeline_res_3.status(),
|
||||
reqwest::StatusCode::OK,
|
||||
"Failed to get timeline (3)"
|
||||
);
|
||||
let timeline_body_3: Value = timeline_res_3.json().await.expect("Timeline (3) not JSON");
|
||||
let feed_3 = timeline_body_3["feed"].as_array().unwrap();
|
||||
assert_eq!(feed_3.len(), 1, "Timeline should have 1 post after delete");
|
||||
assert_eq!(
|
||||
feed_3[0]["post"]["uri"], post2_uri,
|
||||
"Only post 2 should remain"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
mod common;
|
||||
|
||||
use axum::{
|
||||
routing::any,
|
||||
Router,
|
||||
extract::Request,
|
||||
http::StatusCode,
|
||||
};
|
||||
use tokio::net::TcpListener;
|
||||
use axum::{Router, extract::Request, http::StatusCode, routing::any};
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
|
||||
use reqwest::Client;
|
||||
use std::sync::Arc;
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
async fn spawn_mock_upstream() -> (String, tokio::sync::mpsc::Receiver<(String, String, Option<String>)>) {
|
||||
async fn spawn_mock_upstream() -> (
|
||||
String,
|
||||
tokio::sync::mpsc::Receiver<(String, String, Option<String>)>,
|
||||
) {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||
let tx = Arc::new(tx);
|
||||
|
||||
@@ -20,7 +18,9 @@ async fn spawn_mock_upstream() -> (String, tokio::sync::mpsc::Receiver<(String,
|
||||
async move {
|
||||
let method = req.method().to_string();
|
||||
let uri = req.uri().to_string();
|
||||
let auth = req.headers().get("Authorization")
|
||||
let auth = req
|
||||
.headers()
|
||||
.get("Authorization")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
@@ -45,7 +45,8 @@ async fn test_proxy_via_header() {
|
||||
let (upstream_url, mut rx) = spawn_mock_upstream().await;
|
||||
let client = Client::new();
|
||||
|
||||
let res = client.get(format!("{}/xrpc/com.example.test", app_url))
|
||||
let res = client
|
||||
.get(format!("{}/xrpc/com.example.test", app_url))
|
||||
.header("atproto-proxy", &upstream_url)
|
||||
.header("Authorization", "Bearer test-token")
|
||||
.send()
|
||||
@@ -65,12 +66,15 @@ async fn test_proxy_via_header() {
|
||||
async fn test_proxy_via_env_var() {
|
||||
let (upstream_url, mut rx) = spawn_mock_upstream().await;
|
||||
|
||||
unsafe { std::env::set_var("APPVIEW_URL", &upstream_url); }
|
||||
unsafe {
|
||||
std::env::set_var("APPVIEW_URL", &upstream_url);
|
||||
}
|
||||
|
||||
let app_url = common::base_url().await;
|
||||
let client = Client::new();
|
||||
|
||||
let res = client.get(format!("{}/xrpc/com.example.envtest", app_url))
|
||||
let res = client
|
||||
.get(format!("{}/xrpc/com.example.envtest", app_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -85,12 +89,15 @@ async fn test_proxy_via_env_var() {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_proxy_missing_config() {
|
||||
unsafe { std::env::remove_var("APPVIEW_URL"); }
|
||||
unsafe {
|
||||
std::env::remove_var("APPVIEW_URL");
|
||||
}
|
||||
|
||||
let app_url = common::base_url().await;
|
||||
let client = Client::new();
|
||||
|
||||
let res = client.get(format!("{}/xrpc/com.example.fail", app_url))
|
||||
let res = client
|
||||
.get(format!("{}/xrpc/com.example.fail", app_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -106,7 +113,8 @@ async fn test_proxy_auth_signing() {
|
||||
|
||||
let (access_jwt, did) = common::create_account_and_login(&client).await;
|
||||
|
||||
let res = client.get(format!("{}/xrpc/com.example.signed", app_url))
|
||||
let res = client
|
||||
.get(format!("{}/xrpc/com.example.signed", app_url))
|
||||
.header("atproto-proxy", &upstream_url)
|
||||
.header("Authorization", format!("Bearer {}", access_jwt))
|
||||
.send()
|
||||
|
||||
131
tests/repo.rs
131
tests/repo.rs
@@ -1,9 +1,9 @@
|
||||
mod common;
|
||||
use common::*;
|
||||
|
||||
use reqwest::{header, StatusCode};
|
||||
use serde_json::{json, Value};
|
||||
use chrono::Utc;
|
||||
use reqwest::{StatusCode, header};
|
||||
use serde_json::{Value, json};
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
@@ -15,7 +15,11 @@ async fn test_get_record() {
|
||||
("rkey", "self"),
|
||||
];
|
||||
|
||||
let res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.getRecord",
|
||||
base_url().await
|
||||
))
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await
|
||||
@@ -36,7 +40,11 @@ async fn test_get_record_not_found() {
|
||||
("rkey", "nonexistent"),
|
||||
];
|
||||
|
||||
let res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.getRecord",
|
||||
base_url().await
|
||||
))
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await
|
||||
@@ -50,7 +58,11 @@ async fn test_get_record_not_found() {
|
||||
#[tokio::test]
|
||||
async fn test_upload_blob_no_auth() {
|
||||
let client = client();
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.uploadBlob",
|
||||
base_url().await
|
||||
))
|
||||
.header(header::CONTENT_TYPE, "text/plain")
|
||||
.body("no auth")
|
||||
.send()
|
||||
@@ -66,7 +78,11 @@ async fn test_upload_blob_no_auth() {
|
||||
async fn test_upload_blob_success() {
|
||||
let client = client();
|
||||
let (token, _) = create_account_and_login(&client).await;
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.uploadBlob",
|
||||
base_url().await
|
||||
))
|
||||
.header(header::CONTENT_TYPE, "text/plain")
|
||||
.bearer_auth(token)
|
||||
.body("This is our blob data")
|
||||
@@ -90,7 +106,11 @@ async fn test_put_record_no_auth() {
|
||||
"record": {}
|
||||
});
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.putRecord",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
@@ -118,7 +138,11 @@ async fn test_put_record_success() {
|
||||
}
|
||||
});
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.putRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(token)
|
||||
.json(&payload)
|
||||
.send()
|
||||
@@ -135,23 +159,33 @@ async fn test_put_record_success() {
|
||||
#[ignore]
|
||||
async fn test_get_record_missing_params() {
|
||||
let client = client();
|
||||
let params = [
|
||||
("repo", "did:plc:12345"),
|
||||
];
|
||||
let params = [("repo", "did:plc:12345")];
|
||||
|
||||
let res = client.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.getRecord",
|
||||
base_url().await
|
||||
))
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Expected 400 for missing params");
|
||||
assert_eq!(
|
||||
res.status(),
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Expected 400 for missing params"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_upload_blob_bad_token() {
|
||||
let client = client();
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.uploadBlob",
|
||||
base_url().await
|
||||
))
|
||||
.header(header::CONTENT_TYPE, "text/plain")
|
||||
.bearer_auth(BAD_AUTH_TOKEN)
|
||||
.body("This is our blob data")
|
||||
@@ -181,14 +215,22 @@ async fn test_put_record_mismatched_repo() {
|
||||
}
|
||||
});
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.putRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(token)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::FORBIDDEN, "Expected 403 for mismatched repo and auth");
|
||||
assert_eq!(
|
||||
res.status(),
|
||||
StatusCode::FORBIDDEN,
|
||||
"Expected 403 for mismatched repo and auth"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -207,21 +249,33 @@ async fn test_put_record_invalid_schema() {
|
||||
}
|
||||
});
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.putRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(token)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Expected 400 for invalid record schema");
|
||||
assert_eq!(
|
||||
res.status(),
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Expected 400 for invalid record schema"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_upload_blob_unsupported_mime_type() {
|
||||
let client = client();
|
||||
let (token, _) = create_account_and_login(&client).await;
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.uploadBlob",
|
||||
base_url().await
|
||||
))
|
||||
.header(header::CONTENT_TYPE, "application/xml")
|
||||
.bearer_auth(token)
|
||||
.body("<xml>not an image</xml>")
|
||||
@@ -242,7 +296,11 @@ async fn test_list_records() {
|
||||
("collection", "app.bsky.feed.post"),
|
||||
("limit", "10"),
|
||||
];
|
||||
let res = client.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await
|
||||
@@ -255,10 +313,12 @@ async fn test_list_records() {
|
||||
async fn test_describe_repo() {
|
||||
let client = client();
|
||||
let (_, did) = create_account_and_login(&client).await;
|
||||
let params = [
|
||||
("repo", did.as_str()),
|
||||
];
|
||||
let res = client.get(format!("{}/xrpc/com.atproto.repo.describeRepo", base_url().await))
|
||||
let params = [("repo", did.as_str())];
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.describeRepo",
|
||||
base_url().await
|
||||
))
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await
|
||||
@@ -282,7 +342,11 @@ async fn test_create_record_success_with_generated_rkey() {
|
||||
}
|
||||
});
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.repo.createRecord", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.createRecord",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
@@ -313,7 +377,11 @@ async fn test_create_record_success_with_provided_rkey() {
|
||||
}
|
||||
});
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.repo.createRecord", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.createRecord",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
@@ -322,7 +390,10 @@ async fn test_create_record_success_with_provided_rkey() {
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
assert_eq!(body["uri"], format!("at://{}/app.bsky.feed.post/{}", did, rkey));
|
||||
assert_eq!(
|
||||
body["uri"],
|
||||
format!("at://{}/app.bsky.feed.post/{}", did, rkey)
|
||||
);
|
||||
// assert_eq!(body["cid"], "bafyreihy");
|
||||
}
|
||||
|
||||
@@ -336,7 +407,11 @@ async fn test_delete_record() {
|
||||
"collection": "app.bsky.feed.post",
|
||||
"rkey": "some_post_to_delete"
|
||||
});
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.deleteRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(token)
|
||||
.json(&payload)
|
||||
.send()
|
||||
|
||||
@@ -2,12 +2,13 @@ mod common;
|
||||
use common::*;
|
||||
|
||||
use reqwest::StatusCode;
|
||||
use serde_json::{json, Value};
|
||||
use serde_json::{Value, json};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health() {
|
||||
let client = client();
|
||||
let res = client.get(format!("{}/health", base_url().await))
|
||||
let res = client
|
||||
.get(format!("{}/health", base_url().await))
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
@@ -19,7 +20,11 @@ async fn test_health() {
|
||||
#[tokio::test]
|
||||
async fn test_describe_server() {
|
||||
let client = client();
|
||||
let res = client.get(format!("{}/xrpc/com.atproto.server.describeServer", base_url().await))
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.server.describeServer",
|
||||
base_url().await
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
@@ -39,7 +44,11 @@ async fn test_create_session() {
|
||||
"email": format!("{}@example.com", handle),
|
||||
"password": "password"
|
||||
});
|
||||
let _ = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
|
||||
let _ = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createAccount",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await;
|
||||
@@ -49,7 +58,11 @@ async fn test_create_session() {
|
||||
"password": "password"
|
||||
});
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.server.createSession", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createSession",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
@@ -67,14 +80,21 @@ async fn test_create_session_missing_identifier() {
|
||||
"password": "password"
|
||||
});
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.server.createSession", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createSession",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert!(res.status() == StatusCode::BAD_REQUEST || res.status() == StatusCode::UNPROCESSABLE_ENTITY,
|
||||
"Expected 400 or 422 for missing identifier, got {}", res.status());
|
||||
assert!(
|
||||
res.status() == StatusCode::BAD_REQUEST || res.status() == StatusCode::UNPROCESSABLE_ENTITY,
|
||||
"Expected 400 or 422 for missing identifier, got {}",
|
||||
res.status()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -86,19 +106,31 @@ async fn test_create_account_invalid_handle() {
|
||||
"password": "password"
|
||||
});
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createAccount",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Expected 400 for invalid handle chars");
|
||||
assert_eq!(
|
||||
res.status(),
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Expected 400 for invalid handle chars"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_session() {
|
||||
let client = client();
|
||||
let res = client.get(format!("{}/xrpc/com.atproto.server.getSession", base_url().await))
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.server.getSession",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(AUTH_TOKEN)
|
||||
.send()
|
||||
.await
|
||||
@@ -117,7 +149,11 @@ async fn test_refresh_session() {
|
||||
"email": format!("{}@example.com", handle),
|
||||
"password": "password"
|
||||
});
|
||||
let _ = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url().await))
|
||||
let _ = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createAccount",
|
||||
base_url().await
|
||||
))
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await;
|
||||
@@ -126,7 +162,11 @@ async fn test_refresh_session() {
|
||||
"identifier": handle,
|
||||
"password": "password"
|
||||
});
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.server.createSession", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.createSession",
|
||||
base_url().await
|
||||
))
|
||||
.json(&login_payload)
|
||||
.send()
|
||||
.await
|
||||
@@ -134,10 +174,20 @@ async fn test_refresh_session() {
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.expect("Invalid JSON");
|
||||
let refresh_jwt = body["refreshJwt"].as_str().expect("No refreshJwt").to_string();
|
||||
let access_jwt = body["accessJwt"].as_str().expect("No accessJwt").to_string();
|
||||
let refresh_jwt = body["refreshJwt"]
|
||||
.as_str()
|
||||
.expect("No refreshJwt")
|
||||
.to_string();
|
||||
let access_jwt = body["accessJwt"]
|
||||
.as_str()
|
||||
.expect("No accessJwt")
|
||||
.to_string();
|
||||
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.server.refreshSession", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.refreshSession",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(&refresh_jwt)
|
||||
.send()
|
||||
.await
|
||||
@@ -154,7 +204,11 @@ async fn test_refresh_session() {
|
||||
#[tokio::test]
|
||||
async fn test_delete_session() {
|
||||
let client = client();
|
||||
let res = client.post(format!("{}/xrpc/com.atproto.server.deleteSession", base_url().await))
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.server.deleteSession",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(AUTH_TOKEN)
|
||||
.send()
|
||||
.await
|
||||
|
||||
@@ -6,10 +6,12 @@ use reqwest::StatusCode;
|
||||
#[ignore]
|
||||
async fn test_get_repo() {
|
||||
let client = client();
|
||||
let params = [
|
||||
("did", AUTH_DID),
|
||||
];
|
||||
let res = client.get(format!("{}/xrpc/com.atproto.sync.getRepo", base_url().await))
|
||||
let params = [("did", AUTH_DID)];
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getRepo",
|
||||
base_url().await
|
||||
))
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await
|
||||
@@ -26,7 +28,11 @@ async fn test_get_blocks() {
|
||||
("did", AUTH_DID),
|
||||
// "cids" would be a list of CIDs
|
||||
];
|
||||
let res = client.get(format!("{}/xrpc/com.atproto.sync.getBlocks", base_url().await))
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getBlocks",
|
||||
base_url().await
|
||||
))
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await
|
||||
|
||||
Reference in New Issue
Block a user