Inbound migrations work

This commit is contained in:
lewis
2025-12-18 21:20:41 +02:00
parent d695135a4d
commit 95958bb119
17 changed files with 951 additions and 184 deletions

View File

@@ -0,0 +1,82 @@
{
"db_name": "PostgreSQL",
"query": "SELECT\n handle, email, email_verified, is_admin, deactivated_at,\n preferred_comms_channel as \"preferred_channel: crate::comms::CommsChannel\",\n discord_verified, telegram_verified, signal_verified\n FROM users WHERE did = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "handle",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "email",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "email_verified",
"type_info": "Bool"
},
{
"ordinal": 3,
"name": "is_admin",
"type_info": "Bool"
},
{
"ordinal": 4,
"name": "deactivated_at",
"type_info": "Timestamptz"
},
{
"ordinal": 5,
"name": "preferred_channel: crate::comms::CommsChannel",
"type_info": {
"Custom": {
"name": "comms_channel",
"kind": {
"Enum": [
"email",
"discord",
"telegram",
"signal"
]
}
}
}
},
{
"ordinal": 6,
"name": "discord_verified",
"type_info": "Bool"
},
{
"ordinal": 7,
"name": "telegram_verified",
"type_info": "Bool"
},
{
"ordinal": 8,
"name": "signal_verified",
"type_info": "Bool"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
true,
false,
false,
true,
false,
false,
false,
false
]
},
"hash": "17da8b6f6b46eae067bd8842a369a406699888f689122d2bae8bef13b532bcd2"
}

View File

@@ -0,0 +1,28 @@
{
"db_name": "PostgreSQL",
"query": "SELECT id, deactivated_at FROM users WHERE did = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "deactivated_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
true
]
},
"hash": "933f6585efdafedc82a8b6ac3c1513f25459bd9ab08e385ebc929469666d7747"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "SELECT\n u.id, u.did, u.handle, u.password_hash,\n u.email_verified, u.discord_verified, u.telegram_verified, u.signal_verified,\n k.key_bytes, k.encryption_version\n FROM users u\n JOIN user_keys k ON u.id = k.user_id\n WHERE u.handle = $1 OR u.email = $1",
"query": "SELECT\n u.id, u.did, u.handle, u.password_hash,\n u.email_verified, u.discord_verified, u.telegram_verified, u.signal_verified,\n k.key_bytes, k.encryption_version\n FROM users u\n JOIN user_keys k ON u.id = k.user_id\n WHERE u.handle = $1 OR u.email = $1 OR u.did = $1",
"describe": {
"columns": [
{
@@ -72,5 +72,5 @@
true
]
},
"hash": "d61c982dac3a508393b31a30bad50c0088ce6e117fe63c5a1062a97000dedf89"
"hash": "c60e77678da0c42399179015971f55f4f811a0d666237a93035cfece07445590"
}

View File

@@ -0,0 +1,34 @@
{
"db_name": "PostgreSQL",
"query": "SELECT id, handle, deactivated_at FROM users WHERE did = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "handle",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "deactivated_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
true
]
},
"hash": "e60550cc972a5b0dd7cbdbc20d6ae6439eae3811d488166dca1b41bcc11f81f7"
}

View File

@@ -32,7 +32,7 @@ pub async fn get_preferences(
.into_response();
}
};
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
Ok(user) => user,
Err(_) => {
return (
@@ -109,7 +109,7 @@ pub async fn put_preferences(
.into_response();
}
};
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
Ok(user) => user,
Err(_) => {
return (
@@ -119,12 +119,12 @@ pub async fn put_preferences(
.into_response();
}
};
let user_id: uuid::Uuid =
match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", auth_user.did)
let (user_id, is_migration): (uuid::Uuid, bool) =
match sqlx::query!("SELECT id, deactivated_at FROM users WHERE did = $1", auth_user.did)
.fetch_optional(&state.db)
.await
{
Ok(Some(id)) => id,
Ok(Some(row)) => (row.id, row.deactivated_at.is_some()),
_ => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
@@ -166,7 +166,7 @@ pub async fn put_preferences(
)
.into_response();
}
if pref_type == "app.bsky.actor.defs#declaredAgePref" {
if pref_type == "app.bsky.actor.defs#declaredAgePref" && !is_migration {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "InvalidRequest", "message": "declaredAgePref is read-only"})),

View File

@@ -1,4 +1,5 @@
use super::did::verify_did_web;
use crate::auth::{ServiceTokenVerifier, extract_bearer_token_from_header, is_service_token};
use crate::plc::{PlcClient, create_genesis_operation, signing_key_to_did_key};
use crate::state::{AppState, RateLimitKind};
use axum::{
@@ -15,7 +16,7 @@ use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::sync::Arc;
use tracing::{error, info, warn};
use tracing::{debug, error, info, warn};
fn extract_client_ip(headers: &HeaderMap) -> String {
if let Some(forwarded) = headers.get("x-forwarded-for")
@@ -50,6 +51,10 @@ pub struct CreateAccountInput {
pub struct CreateAccountOutput {
pub handle: String,
pub did: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub access_jwt: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub refresh_jwt: Option<String>,
pub verification_required: bool,
pub verification_channel: String,
}
@@ -75,6 +80,58 @@ pub async fn create_account(
)
.into_response();
}
let migration_auth = if let Some(token) =
extract_bearer_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok()))
{
if is_service_token(&token) {
let verifier = ServiceTokenVerifier::new();
match verifier
.verify_service_token(&token, Some("com.atproto.server.createAccount"))
.await
{
Ok(claims) => {
debug!("Service token verified for migration: iss={}", claims.iss);
Some(claims.iss)
}
Err(e) => {
error!("Service token verification failed: {:?}", e);
return (
StatusCode::UNAUTHORIZED,
Json(json!({
"error": "AuthenticationFailed",
"message": format!("Service token verification failed: {}", e)
})),
)
.into_response();
}
}
} else {
None
}
} else {
None
};
let is_migration = migration_auth.is_some()
&& input.did.as_ref().map(|d| d.starts_with("did:plc:")).unwrap_or(false);
if is_migration {
let migration_did = input.did.as_ref().unwrap();
let auth_did = migration_auth.as_ref().unwrap();
if migration_did != auth_did {
return (
StatusCode::FORBIDDEN,
Json(json!({
"error": "AuthorizationError",
"message": format!("Service token issuer {} does not match DID {}", auth_did, migration_did)
})),
)
.into_response();
}
info!(did = %migration_did, "Processing account migration");
}
if input.handle.contains('!') || input.handle.contains('@') {
return (
StatusCode::BAD_REQUEST,
@@ -99,46 +156,50 @@ pub async fn create_account(
}
let verification_channel = input.verification_channel.as_deref().unwrap_or("email");
let valid_channels = ["email", "discord", "telegram", "signal"];
if !valid_channels.contains(&verification_channel) {
if !valid_channels.contains(&verification_channel) && !is_migration {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "InvalidVerificationChannel", "message": "Invalid verification channel. Must be one of: email, discord, telegram, signal"})),
)
.into_response();
}
let verification_recipient = match verification_channel {
"email" => match &input.email {
Some(email) if !email.trim().is_empty() => email.trim().to_string(),
let verification_recipient = if is_migration {
None
} else {
Some(match verification_channel {
"email" => match &input.email {
Some(email) if !email.trim().is_empty() => email.trim().to_string(),
_ => return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "MissingEmail", "message": "Email is required when using email verification"})),
).into_response(),
},
"discord" => match &input.discord_id {
Some(id) if !id.trim().is_empty() => id.trim().to_string(),
_ => return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "MissingDiscordId", "message": "Discord ID is required when using Discord verification"})),
).into_response(),
},
"telegram" => match &input.telegram_username {
Some(username) if !username.trim().is_empty() => username.trim().to_string(),
_ => return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "MissingTelegramUsername", "message": "Telegram username is required when using Telegram verification"})),
).into_response(),
},
"signal" => match &input.signal_number {
Some(number) if !number.trim().is_empty() => number.trim().to_string(),
_ => return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "MissingSignalNumber", "message": "Signal phone number is required when using Signal verification"})),
).into_response(),
},
_ => return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "MissingEmail", "message": "Email is required when using email verification"})),
Json(json!({"error": "InvalidVerificationChannel", "message": "Invalid verification channel"})),
).into_response(),
},
"discord" => match &input.discord_id {
Some(id) if !id.trim().is_empty() => id.trim().to_string(),
_ => return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "MissingDiscordId", "message": "Discord ID is required when using Discord verification"})),
).into_response(),
},
"telegram" => match &input.telegram_username {
Some(username) if !username.trim().is_empty() => username.trim().to_string(),
_ => return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "MissingTelegramUsername", "message": "Telegram username is required when using Telegram verification"})),
).into_response(),
},
"signal" => match &input.signal_number {
Some(number) if !number.trim().is_empty() => number.trim().to_string(),
_ => return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "MissingSignalNumber", "message": "Signal phone number is required when using Signal verification"})),
).into_response(),
},
_ => return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "InvalidVerificationChannel", "message": "Invalid verification channel"})),
).into_response(),
})
};
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let pds_endpoint = format!("https://{}", hostname);
@@ -246,10 +307,12 @@ pub async fn create_account(
.into_response();
}
d.clone()
} else if d.starts_with("did:plc:") && is_migration {
d.clone()
} else {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "InvalidDid", "message": "Only did:web DIDs can be provided; leave empty for did:plc"})),
Json(json!({"error": "InvalidDid", "message": "Only did:web DIDs can be provided; leave empty for did:plc. For migration with existing did:plc, provide service auth."})),
)
.into_response();
}
@@ -396,13 +459,18 @@ pub async fn create_account(
.await
.map(|c| c.unwrap_or(0) == 0)
.unwrap_or(false);
let deactivated_at: Option<chrono::DateTime<chrono::Utc>> = if is_migration {
Some(chrono::Utc::now())
} else {
None
};
let user_insert: Result<(uuid::Uuid,), _> = sqlx::query_as(
r#"INSERT INTO users (
handle, email, did, password_hash,
preferred_comms_channel,
discord_id, telegram_username, signal_number,
is_admin
) VALUES ($1, $2, $3, $4, $5::comms_channel, $6, $7, $8, $9) RETURNING id"#,
is_admin, deactivated_at, email_verified
) VALUES ($1, $2, $3, $4, $5::comms_channel, $6, $7, $8, $9, $10, $11) RETURNING id"#,
)
.bind(short_handle)
.bind(&email)
@@ -431,6 +499,8 @@ pub async fn create_account(
.filter(|s| !s.is_empty()),
)
.bind(is_first_user)
.bind(deactivated_at)
.bind(is_migration)
.fetch_one(&mut *tx)
.await;
let user_id = match user_insert {
@@ -477,21 +547,23 @@ pub async fn create_account(
}
};
if let Err(e) = sqlx::query!(
"INSERT INTO channel_verifications (user_id, channel, code, pending_identifier, expires_at) VALUES ($1, 'email', $2, $3, $4)",
user_id,
verification_code,
email,
code_expires_at
)
.execute(&mut *tx)
.await {
error!("Error inserting verification code: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
if !is_migration {
if let Err(e) = sqlx::query!(
"INSERT INTO channel_verifications (user_id, channel, code, pending_identifier, expires_at) VALUES ($1, 'email', $2, $3, $4)",
user_id,
verification_code,
email,
code_expires_at
)
.into_response();
.execute(&mut *tx)
.await {
error!("Error inserting verification code: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
}
let encrypted_key_bytes = match crate::config::encrypt_key(&secret_key_bytes) {
Ok(enc) => enc,
@@ -636,50 +708,105 @@ pub async fn create_account(
)
.into_response();
}
if let Err(e) =
crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await
{
warn!("Failed to sequence identity event for {}: {}", did, e);
}
if let Err(e) = crate::api::repo::record::sequence_account_event(&state, &did, true, None).await
{
warn!("Failed to sequence account event for {}: {}", did, e);
}
let profile_record = json!({
"$type": "app.bsky.actor.profile",
"displayName": input.handle
});
if let Err(e) = crate::api::repo::record::create_record_internal(
&state,
&did,
"app.bsky.actor.profile",
"self",
&profile_record,
)
.await
{
warn!("Failed to create default profile for {}: {}", did, e);
}
if let Err(e) = crate::comms::enqueue_signup_verification(
&state.db,
user_id,
verification_channel,
&verification_recipient,
&verification_code,
)
.await
{
warn!(
"Failed to enqueue signup verification notification: {:?}",
e
);
if !is_migration {
if let Err(e) =
crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await
{
warn!("Failed to sequence identity event for {}: {}", did, e);
}
if let Err(e) = crate::api::repo::record::sequence_account_event(&state, &did, true, None).await
{
warn!("Failed to sequence account event for {}: {}", did, e);
}
let profile_record = json!({
"$type": "app.bsky.actor.profile",
"displayName": input.handle
});
if let Err(e) = crate::api::repo::record::create_record_internal(
&state,
&did,
"app.bsky.actor.profile",
"self",
&profile_record,
)
.await
{
warn!("Failed to create default profile for {}: {}", did, e);
}
if let Some(ref recipient) = verification_recipient {
if let Err(e) = crate::comms::enqueue_signup_verification(
&state.db,
user_id,
verification_channel,
recipient,
&verification_code,
)
.await
{
warn!(
"Failed to enqueue signup verification notification: {:?}",
e
);
}
}
}
let (access_jwt, refresh_jwt) = if is_migration {
let access_meta =
match crate::auth::create_access_token_with_metadata(&did, &secret_key_bytes) {
Ok(m) => m,
Err(e) => {
error!("Error creating access token for migration: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let refresh_meta =
match crate::auth::create_refresh_token_with_metadata(&did, &secret_key_bytes) {
Ok(m) => m,
Err(e) => {
error!("Error creating refresh token for migration: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
if let Err(e) = sqlx::query!(
"INSERT INTO session_tokens (did, access_jti, refresh_jti, access_expires_at, refresh_expires_at) VALUES ($1, $2, $3, $4, $5)",
did,
access_meta.jti,
refresh_meta.jti,
access_meta.expires_at,
refresh_meta.expires_at
)
.execute(&state.db)
.await
{
error!("Error creating session for migration: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
(Some(access_meta.token), Some(refresh_meta.token))
} else {
(None, None)
};
(
StatusCode::OK,
Json(CreateAccountOutput {
handle: short_handle.to_string(),
handle: full_handle.clone(),
did,
verification_required: true,
access_jwt,
refresh_jwt,
verification_required: !is_migration,
verification_channel: verification_channel.to_string(),
}),
)

View File

@@ -1,4 +1,5 @@
use crate::api::ApiError;
use crate::plc::signing_key_to_did_key;
use crate::state::AppState;
use axum::{
Json,
@@ -309,7 +310,7 @@ pub async fn get_recommended_did_credentials(
.into_response();
}
};
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
Ok(user) => user,
Err(e) => return ApiError::from(e).into_response(),
};
@@ -334,24 +335,21 @@ pub async fn get_recommended_did_credentials(
};
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let pds_endpoint = format!("https://{}", hostname);
let secret_key = match k256::SecretKey::from_slice(&key_bytes) {
let full_handle = if user.handle.contains('.') {
user.handle.clone()
} else {
format!("{}.{}", user.handle, hostname)
};
let signing_key = match k256::ecdsa::SigningKey::from_slice(&key_bytes) {
Ok(k) => k,
Err(_) => return ApiError::InternalError.into_response(),
};
let public_key = secret_key.public_key();
let encoded = public_key.to_encoded_point(true);
let did_key = format!(
"did:key:zQ3sh{}",
multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes())
.chars()
.skip(1)
.collect::<String>()
);
let did_key = signing_key_to_did_key(&signing_key);
(
StatusCode::OK,
Json(GetRecommendedDidCredentialsOutput {
rotation_keys: vec![did_key.clone()],
also_known_as: vec![format!("at://{}", user.handle)],
also_known_as: vec![format!("at://{}", full_handle)],
verification_methods: VerificationMethods { atproto: did_key },
services: Services {
atproto_pds: AtprotoPds {
@@ -380,7 +378,7 @@ pub async fn update_handle(
Some(t) => t,
None => return ApiError::AuthenticationRequired.into_response(),
};
let did = match crate::auth::validate_bearer_token(&state.db, &token).await {
let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
Ok(user) => user.did,
Err(e) => return ApiError::from(e).into_response(),
};

View File

@@ -24,7 +24,7 @@ pub async fn request_plc_operation_signature(
Some(t) => t,
None => return ApiError::AuthenticationRequired.into_response(),
};
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
Ok(user) => user,
Err(e) => return ApiError::from(e).into_response(),
};

View File

@@ -50,7 +50,7 @@ pub async fn sign_plc_operation(
Some(t) => t,
None => return ApiError::AuthenticationRequired.into_response(),
};
let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await {
let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await {
Ok(user) => user,
Err(e) => return ApiError::from(e).into_response(),
};

View File

@@ -29,7 +29,7 @@ pub async fn submit_plc_operation(
Some(t) => t,
None => return ApiError::AuthenticationRequired.into_response(),
};
let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await {
let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await {
Ok(user) => user,
Err(e) => return ApiError::from(e).into_response(),
};
@@ -40,7 +40,7 @@ pub async fn submit_plc_operation(
let op = &input.operation;
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let public_url = format!("https://{}", hostname);
let user = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did)
let user = match sqlx::query!("SELECT id, handle, deactivated_at FROM users WHERE did = $1", did)
.fetch_optional(&state.db)
.await
{
@@ -53,6 +53,7 @@ pub async fn submit_plc_operation(
.into_response();
}
};
let is_migration = user.deactivated_at.is_some();
let key_row = match sqlx::query!(
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
user.id
@@ -93,21 +94,23 @@ pub async fn submit_plc_operation(
}
};
let user_did_key = signing_key_to_did_key(&signing_key);
if let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) {
let server_rotation_key =
std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone());
let has_server_key = rotation_keys
.iter()
.any(|k| k.as_str() == Some(&server_rotation_key));
if !has_server_key {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Rotation keys do not include server's rotation key"
})),
)
.into_response();
if !is_migration {
if let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) {
let server_rotation_key =
std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone());
let has_server_key = rotation_keys
.iter()
.any(|k| k.as_str() == Some(&server_rotation_key));
if !has_server_key {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Rotation keys do not include server's rotation key"
})),
)
.into_response();
}
}
}
if let Some(services) = op.get("services").and_then(|v| v.as_object())
@@ -135,30 +138,32 @@ pub async fn submit_plc_operation(
.into_response();
}
}
if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object())
&& let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str())
&& atproto_key != user_did_key {
if !is_migration {
if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object())
&& let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str())
&& atproto_key != user_did_key {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Incorrect signing key in verificationMethods"
})),
)
.into_response();
}
if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) {
let expected_handle = format!("at://{}", user.handle);
let first_aka = also_known_as.first().and_then(|v| v.as_str());
if first_aka != Some(&expected_handle) {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Incorrect signing key in verificationMethods"
"message": "Incorrect handle in alsoKnownAs"
})),
)
.into_response();
}
if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) {
let expected_handle = format!("at://{}", user.handle);
let first_aka = also_known_as.first().and_then(|v| v.as_str());
if first_aka != Some(&expected_handle) {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Incorrect handle in alsoKnownAs"
})),
)
.into_response();
}
}
let plc_client = PlcClient::new(None);

View File

@@ -1,3 +1,4 @@
use crate::auth::{ServiceTokenVerifier, is_service_token};
use crate::state::AppState;
use axum::body::Bytes;
use axum::{
@@ -13,22 +14,16 @@ use serde::{Deserialize, Serialize};
use serde_json::json;
use sha2::{Digest, Sha256};
use std::str::FromStr;
use tracing::error;
use tracing::{debug, error};
const MAX_BLOB_SIZE: usize = 1_000_000;
const MAX_VIDEO_BLOB_SIZE: usize = 100_000_000;
pub async fn upload_blob(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
body: Bytes,
) -> Response {
if body.len() > MAX_BLOB_SIZE {
return (
StatusCode::PAYLOAD_TOO_LARGE,
Json(json!({"error": "BlobTooLarge", "message": format!("Blob size {} exceeds maximum of {} bytes", body.len(), MAX_BLOB_SIZE)})),
)
.into_response();
}
let token = match crate::auth::extract_bearer_token_from_header(
headers.get("Authorization").and_then(|h| h.to_str().ok()),
) {
@@ -41,17 +36,66 @@ pub async fn upload_blob(
.into_response();
}
};
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
Ok(user) => user,
Err(_) => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed"})),
)
.into_response();
let is_service_auth = is_service_token(&token);
let (did, is_migration) = if is_service_auth {
debug!("Verifying service token for blob upload");
let verifier = ServiceTokenVerifier::new();
match verifier
.verify_service_token(&token, Some("com.atproto.repo.uploadBlob"))
.await
{
Ok(claims) => {
debug!("Service token verified for DID: {}", claims.iss);
(claims.iss, false)
}
Err(e) => {
error!("Service token verification failed: {:?}", e);
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed", "message": format!("Service token verification failed: {}", e)})),
)
.into_response();
}
}
} else {
match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
Ok(user) => {
let deactivated = sqlx::query_scalar!(
"SELECT deactivated_at FROM users WHERE did = $1",
user.did
)
.fetch_optional(&state.db)
.await
.ok()
.flatten()
.flatten();
(user.did, deactivated.is_some())
}
Err(_) => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed"})),
)
.into_response();
}
}
};
let did = auth_user.did;
let max_size = if is_service_auth || is_migration {
MAX_VIDEO_BLOB_SIZE
} else {
MAX_BLOB_SIZE
};
if body.len() > max_size {
return (
StatusCode::PAYLOAD_TOO_LARGE,
Json(json!({"error": "BlobTooLarge", "message": format!("Blob size {} exceeds maximum of {} bytes", body.len(), max_size)})),
)
.into_response();
}
let mime_type = headers
.get("content-type")
.and_then(|h| h.to_str().ok())

View File

@@ -53,7 +53,7 @@ pub async fn import_repo(
Some(t) => t,
None => return ApiError::AuthenticationRequired.into_response(),
};
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
Ok(user) => user,
Err(e) => return ApiError::from(e).into_response(),
};
@@ -82,16 +82,6 @@ pub async fn import_repo(
.into_response();
}
};
if user.deactivated_at.is_some() {
return (
StatusCode::FORBIDDEN,
Json(json!({
"error": "AccountDeactivated",
"message": "Account is deactivated"
})),
)
.into_response();
}
if user.takedown_ref.is_some() {
return (
StatusCode::FORBIDDEN,
@@ -185,7 +175,58 @@ pub async fn import_repo(
let skip_verification = std::env::var("SKIP_IMPORT_VERIFICATION")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if !skip_verification {
let is_migration = user.deactivated_at.is_some();
if skip_verification {
warn!("Skipping all CAR verification for import (SKIP_IMPORT_VERIFICATION=true)");
} else if is_migration {
debug!("Verifying CAR file structure for migration (skipping signature verification)");
let verifier = CarVerifier::new();
match verifier.verify_car_structure_only(did, &root, &blocks) {
Ok(verified) => {
debug!(
"CAR structure verification successful: rev={}, data_cid={}",
verified.rev, verified.data_cid
);
}
Err(crate::sync::verify::VerifyError::DidMismatch {
commit_did,
expected_did,
}) => {
return (
StatusCode::FORBIDDEN,
Json(json!({
"error": "InvalidRequest",
"message": format!(
"CAR file is for DID {} but you are authenticated as {}",
commit_did, expected_did
)
})),
)
.into_response();
}
Err(crate::sync::verify::VerifyError::MstValidationFailed(msg)) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": format!("MST validation failed: {}", msg)
})),
)
.into_response();
}
Err(e) => {
error!("CAR structure verification error: {:?}", e);
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": format!("CAR verification failed: {}", e)
})),
)
.into_response();
}
}
} else {
debug!("Verifying CAR file signature and structure for DID {}", did);
let verifier = CarVerifier::new();
match verifier.verify_car(did, &root, &blocks).await {
@@ -264,8 +305,6 @@ pub async fn import_repo(
.into_response();
}
}
} else {
warn!("Skipping CAR signature verification for import (SKIP_IMPORT_VERIFICATION=true)");
}
let max_blocks: usize = std::env::var("MAX_IMPORT_BLOCKS")
.ok()

View File

@@ -1,5 +1,5 @@
use crate::api::ApiError;
use crate::auth::BearerAuth;
use crate::auth::{BearerAuth, BearerAuthAllowDeactivated};
use crate::state::{AppState, RateLimitKind};
use axum::{
Json,
@@ -88,7 +88,7 @@ pub async fn create_session(
k.key_bytes, k.encryption_version
FROM users u
JOIN user_keys k ON u.id = k.user_id
WHERE u.handle = $1 OR u.email = $1"#,
WHERE u.handle = $1 OR u.email = $1 OR u.did = $1"#,
normalized_identifier
)
.fetch_optional(&state.db)
@@ -189,11 +189,11 @@ pub async fn create_session(
pub async fn get_session(
State(state): State<AppState>,
BearerAuth(auth_user): BearerAuth,
BearerAuthAllowDeactivated(auth_user): BearerAuthAllowDeactivated,
) -> Response {
match sqlx::query!(
r#"SELECT
handle, email, email_verified, is_admin,
handle, email, email_verified, is_admin, deactivated_at,
preferred_comms_channel as "preferred_channel: crate::comms::CommsChannel",
discord_verified, telegram_verified, signal_verified
FROM users WHERE did = $1"#,
@@ -211,6 +211,7 @@ pub async fn get_session(
};
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let handle = full_handle(&row.handle, &pds_hostname);
let is_active = row.deactivated_at.is_none();
Json(json!({
"handle": handle,
"did": auth_user.did,
@@ -219,7 +220,8 @@ pub async fn get_session(
"preferredChannel": preferred_channel,
"preferredChannelVerified": preferred_channel_verified,
"isAdmin": row.is_admin,
"active": true,
"active": is_active,
"status": if is_active { "active" } else { "deactivated" },
"didDoc": {}
})).into_response()
}

View File

@@ -7,6 +7,7 @@ use std::time::Duration;
use crate::cache::Cache;
pub mod extractor;
pub mod service;
pub mod token;
pub mod verify;
@@ -23,6 +24,7 @@ pub use token::{
pub use verify::{
get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token,
};
pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token};
const KEY_CACHE_TTL_SECS: u64 = 300;
const SESSION_CACHE_TTL_SECS: u64 = 60;

375
src/auth/service.rs Normal file
View File

@@ -0,0 +1,375 @@
use anyhow::{Result, anyhow};
use base64::Engine as _;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use chrono::Utc;
use k256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::debug;
/// A resolved DID document, deserialized from directory/HTTPS JSON.
/// Arrays absent from the document default to empty via `#[serde(default)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FullDidDocument {
    /// The DID itself (the document's `id` field).
    pub id: String,
    /// `alsoKnownAs` entries (typically `at://` handle URIs).
    #[serde(default)]
    pub also_known_as: Vec<String>,
    /// `verificationMethod` entries; the `#atproto` one holds the signing key.
    #[serde(default)]
    pub verification_method: Vec<VerificationMethod>,
    /// `service` entries (e.g. the PDS endpoint).
    #[serde(default)]
    pub service: Vec<DidService>,
}
/// One `verificationMethod` entry of a DID document.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct VerificationMethod {
    /// Method id, e.g. `"<did>#atproto"`.
    pub id: String,
    /// JSON `type` field (renamed: `type` is a Rust keyword).
    #[serde(rename = "type")]
    pub method_type: String,
    /// Controller DID of this key.
    pub controller: String,
    /// Multibase-encoded public key, when present in the document.
    #[serde(default)]
    pub public_key_multibase: Option<String>,
}
/// One `service` entry of a DID document.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DidService {
    /// Service id, e.g. `"#atproto_pds"`.
    pub id: String,
    /// JSON `type` field (renamed: `type` is a Rust keyword).
    #[serde(rename = "type")]
    pub service_type: String,
    /// Endpoint URL for the service.
    pub service_endpoint: String,
}
/// Claims carried by an atproto inter-service JWT.
/// Field names intentionally match the raw JWT claim names (no `rename_all`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceTokenClaims {
    /// Issuer DID — the account whose key signed the token.
    pub iss: String,
    /// Optional subject DID; `subject()` falls back to `iss` when absent.
    #[serde(default)]
    pub sub: Option<String>,
    /// Audience — compared against this PDS's `did:web` during verification.
    pub aud: String,
    /// Expiry, seconds since the Unix epoch.
    pub exp: usize,
    /// Issued-at, seconds since the Unix epoch (not validated here).
    #[serde(default)]
    pub iat: Option<usize>,
    /// Lexicon method the token is scoped to; `"*"` matches any method.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lxm: Option<String>,
    /// Optional unique token id (not validated here).
    #[serde(default)]
    pub jti: Option<String>,
}
impl ServiceTokenClaims {
pub fn subject(&self) -> &str {
self.sub.as_deref().unwrap_or(&self.iss)
}
}
/// JOSE header of the incoming JWT. Only `alg` is inspected by the
/// verifier; `typ` is required for deserialization but its value is
/// not checked.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TokenHeader {
    /// Signature algorithm; must be `"ES256K"`.
    pub alg: String,
    /// Token type field (present but unvalidated).
    pub typ: String,
}
/// Verifies atproto inter-service tokens by resolving the issuer's
/// signing key from its DID document.
pub struct ServiceTokenVerifier {
    /// HTTP client used for DID resolution.
    client: Client,
    /// Base URL of the PLC directory (from `PLC_DIRECTORY_URL`).
    plc_directory_url: String,
    /// This server's `did:web` identity — the expected `aud` claim.
    pds_did: String,
}
impl ServiceTokenVerifier {
pub fn new() -> Self {
let plc_directory_url = std::env::var("PLC_DIRECTORY_URL")
.unwrap_or_else(|_| "https://plc.directory".to_string());
let pds_hostname =
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let pds_did = format!("did:web:{}", pds_hostname);
let client = Client::builder()
.timeout(Duration::from_secs(10))
.connect_timeout(Duration::from_secs(5))
.build()
.unwrap_or_else(|_| Client::new());
Self {
client,
plc_directory_url,
pds_did,
}
}
    /// Verifies an atproto inter-service JWT (ES256K).
    ///
    /// Checks, in order: 3-part compact JWS shape, `alg == "ES256K"`,
    /// expiry, audience (must equal this PDS's `did:web`), the optional
    /// `lxm` method binding, and finally the signature against the
    /// issuer's `#atproto` key resolved from its DID document.
    ///
    /// # Errors
    /// Returns an error naming the first check that failed.
    pub async fn verify_service_token(
        &self,
        token: &str,
        required_lxm: Option<&str>,
    ) -> Result<ServiceTokenClaims> {
        let parts: Vec<&str> = token.split('.').collect();
        if parts.len() != 3 {
            return Err(anyhow!("Invalid token format"));
        }
        // Decode the JOSE header; only ES256K (secp256k1) is supported.
        let header_bytes = URL_SAFE_NO_PAD
            .decode(parts[0])
            .map_err(|e| anyhow!("Base64 decode of header failed: {}", e))?;
        let header: TokenHeader = serde_json::from_slice(&header_bytes)
            .map_err(|e| anyhow!("JSON decode of header failed: {}", e))?;
        if header.alg != "ES256K" {
            return Err(anyhow!("Unsupported algorithm: {}", header.alg));
        }
        let claims_bytes = URL_SAFE_NO_PAD
            .decode(parts[1])
            .map_err(|e| anyhow!("Base64 decode of claims failed: {}", e))?;
        let claims: ServiceTokenClaims = serde_json::from_slice(&claims_bytes)
            .map_err(|e| anyhow!("JSON decode of claims failed: {}", e))?;
        // NOTE(review): exact-second expiry with no clock-skew allowance —
        // confirm this is intended for cross-server tokens.
        let now = Utc::now().timestamp() as usize;
        if claims.exp < now {
            return Err(anyhow!("Token expired"));
        }
        // The token must be addressed to this PDS.
        if claims.aud != self.pds_did {
            return Err(anyhow!(
                "Invalid audience: expected {}, got {}",
                self.pds_did,
                claims.aud
            ));
        }
        // When the caller requires a specific lexicon method, the token's
        // `lxm` must match it exactly or be the "*" wildcard.
        if let Some(required) = required_lxm {
            match &claims.lxm {
                Some(lxm) if lxm == "*" || lxm == required => {}
                Some(lxm) => {
                    return Err(anyhow!(
                        "Token lxm '{}' does not permit '{}'",
                        lxm,
                        required
                    ));
                }
                None => {
                    return Err(anyhow!("Token missing lxm claim"));
                }
            }
        }
        // Signature check last: resolve the issuer's signing key from its
        // DID document, then verify over "<header>.<payload>".
        let did = &claims.iss;
        let public_key = self.resolve_signing_key(did).await?;
        let signature_bytes = URL_SAFE_NO_PAD
            .decode(parts[2])
            .map_err(|e| anyhow!("Base64 decode of signature failed: {}", e))?;
        let signature = Signature::from_slice(&signature_bytes)
            .map_err(|e| anyhow!("Invalid signature format: {}", e))?;
        let message = format!("{}.{}", parts[0], parts[1]);
        public_key
            .verify(message.as_bytes(), &signature)
            .map_err(|e| anyhow!("Signature verification failed: {}", e))?;
        debug!("Service token verified for DID: {}", did);
        Ok(claims)
    }
async fn resolve_signing_key(&self, did: &str) -> Result<VerifyingKey> {
let did_doc = self.resolve_did_document(did).await?;
let atproto_key = did_doc
.verification_method
.iter()
.find(|vm| vm.id.ends_with("#atproto") || vm.id == format!("{}#atproto", did))
.ok_or_else(|| anyhow!("No atproto verification method found in DID document"))?;
let multibase = atproto_key
.public_key_multibase
.as_ref()
.ok_or_else(|| anyhow!("Verification method missing publicKeyMultibase"))?;
parse_did_key_multibase(multibase)
}
async fn resolve_did_document(&self, did: &str) -> Result<FullDidDocument> {
if did.starts_with("did:plc:") {
self.resolve_did_plc(did).await
} else if did.starts_with("did:web:") {
self.resolve_did_web(did).await
} else {
Err(anyhow!("Unsupported DID method: {}", did))
}
}
async fn resolve_did_plc(&self, did: &str) -> Result<FullDidDocument> {
let url = format!("{}/{}", self.plc_directory_url, urlencoding::encode(did));
debug!("Resolving did:plc {} via {}", did, url);
let resp = self
.client
.get(&url)
.send()
.await
.map_err(|e| anyhow!("HTTP request failed: {}", e))?;
if resp.status() == reqwest::StatusCode::NOT_FOUND {
return Err(anyhow!("DID not found: {}", did));
}
if !resp.status().is_success() {
return Err(anyhow!("HTTP {}", resp.status()));
}
resp.json::<FullDidDocument>()
.await
.map_err(|e| anyhow!("Failed to parse DID document: {}", e))
}
async fn resolve_did_web(&self, did: &str) -> Result<FullDidDocument> {
let host = did
.strip_prefix("did:web:")
.ok_or_else(|| anyhow!("Invalid did:web format"))?;
let decoded_host = host.replace("%3A", ":");
let (host_part, path_part) = if let Some(idx) = decoded_host.find('/') {
(&decoded_host[..idx], &decoded_host[idx..])
} else {
(decoded_host.as_str(), "")
};
let scheme = if host_part.starts_with("localhost")
|| host_part.starts_with("127.0.0.1")
|| host_part.contains(':')
{
"http"
} else {
"https"
};
let url = if path_part.is_empty() {
format!("{}://{}/.well-known/did.json", scheme, host_part)
} else {
format!("{}://{}{}/did.json", scheme, host_part, path_part)
};
debug!("Resolving did:web {} via {}", did, url);
let resp = self
.client
.get(&url)
.send()
.await
.map_err(|e| anyhow!("HTTP request failed: {}", e))?;
if !resp.status().is_success() {
return Err(anyhow!("HTTP {}", resp.status()));
}
resp.json::<FullDidDocument>()
.await
.map_err(|e| anyhow!("Failed to parse DID document: {}", e))
}
}
impl Default for ServiceTokenVerifier {
fn default() -> Self {
Self::new()
}
}
/// Parse a `publicKeyMultibase` value into a secp256k1 verifying key.
///
/// Expects base58btc multibase (leading `z`) wrapping multicodec-tagged key
/// bytes; only the secp256k1 public-key codec (varint `0xe7 0x01`) is
/// accepted, with the remaining bytes interpreted as a SEC1 public key.
///
/// # Errors
/// Returns an error for a non-`z` multibase prefix, undecodable input,
/// truncated data, a non-secp256k1 codec, or invalid SEC1 key bytes.
fn parse_did_key_multibase(multibase: &str) -> Result<VerifyingKey> {
    if !multibase.starts_with('z') {
        return Err(anyhow!("Expected base58btc multibase encoding (starts with 'z')"));
    }
    let (_, decoded) = multibase::decode(multibase)
        .map_err(|e| anyhow!("Failed to decode multibase: {}", e))?;
    // Need at least the two multicodec prefix bytes before any key material.
    if decoded.len() < 2 {
        return Err(anyhow!("Invalid multicodec data"));
    }
    // secp256k1-pub multicodec prefix is the varint 0xe7 0x01. (The original
    // bound `codec` in the success branch and re-checked it afterwards; that
    // second check was unreachable and has been removed.)
    if decoded[0] != 0xe7 || decoded[1] != 0x01 {
        return Err(anyhow!(
            "Unsupported key type. Expected secp256k1 (0xe701), got {:02x}{:02x}",
            decoded[0],
            decoded[1]
        ));
    }
    VerifyingKey::from_sec1_bytes(&decoded[2..])
        .map_err(|e| anyhow!("Invalid public key: {}", e))
}
/// Heuristically classify a JWT as an inter-service token.
///
/// A token counts as a service token when it has the three-segment JWT shape
/// and its (unverified) claims payload contains an `lxm` field. No signature
/// or claim validation is performed here.
pub fn is_service_token(token: &str) -> bool {
    let mut segments = token.split('.');
    let (Some(_header), Some(payload), Some(_signature), None) = (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) else {
        return false;
    };
    URL_SAFE_NO_PAD
        .decode(payload)
        .ok()
        .and_then(|bytes| serde_json::from_slice::<serde_json::Value>(&bytes).ok())
        .map(|claims| claims.get("lxm").is_some())
        .unwrap_or(false)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Assemble an unsigned three-segment token from a raw header literal
    /// and a claims JSON value, with a placeholder signature.
    fn build_token(header: &str, claims: &serde_json::Value) -> String {
        format!(
            "{}.{}.{}",
            URL_SAFE_NO_PAD.encode(header),
            URL_SAFE_NO_PAD.encode(claims.to_string()),
            URL_SAFE_NO_PAD.encode("fake-sig")
        )
    }

    #[test]
    fn test_is_service_token() {
        let with_lxm = serde_json::json!({
            "iss": "did:plc:test",
            "sub": "did:plc:test",
            "aud": "did:web:test.com",
            "exp": 9999999999i64,
            "iat": 1000000000i64,
            "lxm": "com.atproto.repo.uploadBlob",
            "jti": "test-jti"
        });
        let without_lxm = serde_json::json!({
            "iss": "did:plc:test",
            "sub": "did:plc:test",
            "aud": "did:web:test.com",
            "exp": 9999999999i64,
            "iat": 1000000000i64,
            "jti": "test-jti"
        });
        // Presence of `lxm` is the discriminator, independent of header `typ`.
        assert!(is_service_token(&build_token(
            r#"{"alg":"ES256K","typ":"jwt"}"#,
            &with_lxm
        )));
        assert!(!is_service_token(&build_token(
            r#"{"alg":"ES256K","typ":"at+jwt"}"#,
            &without_lxm
        )));
    }

    #[test]
    fn test_parse_did_key_multibase() {
        // A known-good multibase encoding of a secp256k1 public key.
        let encoded = "zQ3shcXtVCEBjUvAhzTW3r12DkpFdR2KmA3rHmuEMFx4GMBDB";
        assert!(
            parse_did_key_multibase(encoded).is_ok(),
            "Failed to parse valid multibase key"
        );
    }
}

View File

@@ -86,6 +86,38 @@ impl CarVerifier {
})
}
/// Verify a CAR's commit and MST structure without checking the commit
/// signature — used for inbound account migration, where signature
/// verification is intentionally skipped (see the debug message below).
///
/// Checks that the root block exists in `blocks`, decodes it as a commit,
/// requires the commit's DID to equal `expected_did`, and walks the MST
/// rooted at the commit's data CID for structural completeness.
///
/// # Errors
/// Returns `VerifyError` when the root block is missing, the commit fails to
/// decode, the DID does not match, or the MST structure check fails.
pub fn verify_car_structure_only(
    &self,
    expected_did: &str,
    root_cid: &Cid,
    blocks: &HashMap<Cid, Bytes>,
) -> Result<VerifiedCar, VerifyError> {
    let root_block = blocks
        .get(root_cid)
        .ok_or_else(|| VerifyError::BlockNotFound(root_cid.to_string()))?;
    let commit =
        Commit::from_cbor(root_block).map_err(|e| VerifyError::InvalidCommit(e.to_string()))?;
    let commit_did = commit.did().as_str();
    if commit_did != expected_did {
        return Err(VerifyError::DidMismatch {
            commit_did: commit_did.to_string(),
            expected_did: expected_did.to_string(),
        });
    }
    let data_cid = commit.data();
    // Structural walk only — no signature verification in this path.
    self.verify_mst_structure(data_cid, blocks)?;
    debug!(
        "MST structure verified for DID {} (signature verification skipped for migration)",
        commit_did
    );
    Ok(VerifiedCar {
        did: commit_did.to_string(),
        rev: commit.rev().to_string(),
        data_cid: *data_cid,
        prev: commit.prev().cloned(),
    })
}
async fn resolve_did_signing_key(&self, did: &str) -> Result<PublicKey<'static>, VerifyError> {
let did_doc = self.resolve_did_document(did).await?;
did_doc

View File

@@ -192,7 +192,7 @@ async fn test_import_repo_size_limit() {
}
#[tokio::test]
async fn test_import_deactivated_account_rejected() {
async fn test_import_deactivated_account_allowed_for_migration() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
let export_res = client
@@ -229,9 +229,8 @@ async fn test_import_deactivated_account_rejected() {
.await
.expect("Import failed");
assert!(
import_res.status() == StatusCode::FORBIDDEN
|| import_res.status() == StatusCode::UNAUTHORIZED,
"Expected FORBIDDEN (403) or UNAUTHORIZED (401), got {}",
import_res.status().is_success(),
"Deactivated accounts should allow import for migration, got {}",
import_res.status()
);
}