diff --git a/src/api/actor/preferences.rs b/src/api/actor/preferences.rs index 5b96495..3737706 100644 --- a/src/api/actor/preferences.rs +++ b/src/api/actor/preferences.rs @@ -1,4 +1,5 @@ use crate::api::error::ApiError; +use crate::auth::BearerAuthAllowDeactivated; use crate::state::AppState; use axum::{ Json, @@ -33,23 +34,9 @@ pub struct GetPreferencesOutput { } pub async fn get_preferences( State(state): State, - headers: axum::http::HeaderMap, + auth: BearerAuthAllowDeactivated, ) -> Response { - let token = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => { - return ApiError::AuthenticationRequired.into_response(); - } - }; - let auth_user = - match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { - Ok(user) => user, - Err(_) => { - return ApiError::AuthenticationFailed(None).into_response(); - } - }; + let auth_user = auth.0; let has_full_access = auth_user.permissions().has_full_access(); let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", &*auth_user.did) @@ -117,24 +104,10 @@ pub struct PutPreferencesInput { } pub async fn put_preferences( State(state): State, - headers: axum::http::HeaderMap, + auth: BearerAuthAllowDeactivated, Json(input): Json, ) -> Response { - let token = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => { - return ApiError::AuthenticationRequired.into_response(); - } - }; - let auth_user = - match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { - Ok(user) => user, - Err(_) => { - return ApiError::AuthenticationFailed(None).into_response(); - } - }; + let auth_user = auth.0; let has_full_access = auth_user.permissions().has_full_access(); let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", &*auth_user.did) diff --git a/src/api/age_assurance.rs b/src/api/age_assurance.rs index 613ae00..a6b9766 100644 --- a/src/api/age_assurance.rs +++ b/src/api/age_assurance.rs @@ -1,4 +1,4 @@ -use crate::auth::{extract_bearer_token_from_header, validate_bearer_token}; +use crate::auth::{extract_auth_token_from_header, validate_token_with_dpop}; use crate::state::AppState; use axum::{ Json, @@ -36,10 +36,24 @@ async fn get_account_created_at(state: &AppState, headers: &HeaderMap) -> Option let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); tracing::debug!(?auth_header, "age assurance: extracting token"); - let token = extract_bearer_token_from_header(auth_header)?; + let extracted = extract_auth_token_from_header(auth_header)?; tracing::debug!("age assurance: got token, validating"); - let auth_user = match validate_bearer_token(&state.db, &token).await { + let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); + let http_uri = "/"; + + let auth_user = match validate_token_with_dpop( + &state.db, + &extracted.token, + extracted.is_dpop, + dpop_proof, + "GET", + http_uri, + false, + false, + ) + .await + { Ok(user) => { tracing::debug!(did = %user.did, "age assurance: validated user"); user diff --git a/src/api/identity/account.rs b/src/api/identity/account.rs index 9a97a30..c7aee84 100644 --- a/src/api/identity/account.rs +++ b/src/api/identity/account.rs @@ -1,7 +1,7 @@ use super::did::verify_did_web; use crate::api::error::ApiError; use crate::api::repo::record::utils::create_signed_commit; -use crate::auth::{ServiceTokenVerifier, extract_bearer_token_from_header, is_service_token}; +use crate::auth::{ServiceTokenVerifier, is_service_token}; use crate::plc::{PlcClient, create_genesis_operation, signing_key_to_did_key}; use crate::state::{AppState, RateLimitKind}; use crate::types::{Did, Handle, PlainPassword}; @@ -96,9 +96,10 @@ pub async fn create_account( .into_response(); } - let migration_auth = if let Some(token) = - extract_bearer_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok())) - { + let migration_auth = if let Some(extracted) = crate::auth::extract_auth_token_from_header( + headers.get("Authorization").and_then(|h| h.to_str().ok()), + ) { + let token = extracted.token; if is_service_token(&token) { let verifier = ServiceTokenVerifier::new(); match verifier diff --git a/src/api/identity/did.rs b/src/api/identity/did.rs index d924a10..83077fd 100644 --- a/src/api/identity/did.rs +++ b/src/api/identity/did.rs @@ -1,4 +1,5 @@ use crate::api::{ApiError, DidResponse, EmptyResponse}; +use crate::auth::BearerAuthAllowDeactivated; use crate::plc::signing_key_to_did_key; use crate::state::AppState; use axum::{ @@ -522,21 +523,9 @@ pub struct AtprotoPds { pub async fn get_recommended_did_credentials( State(state): State, - headers: axum::http::HeaderMap, + auth: BearerAuthAllowDeactivated, ) -> Response { - let token = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => { - return ApiError::AuthenticationRequired.into_response(); - } - }; - let auth_user = - match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { - Ok(user) => user, - Err(e) => return ApiError::from(e).into_response(), - }; + let auth_user = auth.0; let user = match sqlx::query!( "SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", &auth_user.did @@ -601,20 +590,10 @@ pub struct UpdateHandleInput { pub async fn update_handle( State(state): State, - headers: axum::http::HeaderMap, + auth: BearerAuthAllowDeactivated, Json(input): Json, ) -> Response { - let token = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => return ApiError::AuthenticationRequired.into_response(), - }; - let auth_user = - match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { - Ok(user) => user, - Err(e) => return ApiError::from(e).into_response(), - }; + let auth_user = auth.0; if let Err(e) = crate::auth::scope_check::check_identity_scope( auth_user.is_oauth, auth_user.scope.as_deref(), diff --git a/src/api/identity/plc/request.rs b/src/api/identity/plc/request.rs index 62c4e54..96e9737 100644 --- a/src/api/identity/plc/request.rs +++ b/src/api/identity/plc/request.rs @@ -1,5 +1,6 @@ use crate::api::EmptyResponse; use crate::api::error::ApiError; +use crate::auth::BearerAuthAllowDeactivated; use crate::state::AppState; use axum::{ extract::State, @@ -14,19 +15,9 @@ fn generate_plc_token() -> String { pub async fn request_plc_operation_signature( State(state): State, - headers: axum::http::HeaderMap, + auth: BearerAuthAllowDeactivated, ) -> Response { - let token = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => return ApiError::AuthenticationRequired.into_response(), - }; - let auth_user = - match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { - Ok(user) => user, - Err(e) => return ApiError::from(e).into_response(), - }; + let auth_user = auth.0; if let Err(e) = crate::auth::scope_check::check_identity_scope( auth_user.is_oauth, auth_user.scope.as_deref(), diff --git a/src/api/identity/plc/sign.rs b/src/api/identity/plc/sign.rs index f056f4c..fcb9782 100644 --- a/src/api/identity/plc/sign.rs +++ b/src/api/identity/plc/sign.rs @@ -1,4 +1,5 @@ use crate::api::ApiError; +use crate::auth::BearerAuthAllowDeactivated; use crate::circuit_breaker::with_circuit_breaker; use crate::plc::{PlcClient, PlcError, PlcService, create_update_op, sign_operation}; use crate::state::AppState; @@ -39,20 +40,10 @@ pub struct SignPlcOperationOutput { pub async fn sign_plc_operation( State(state): State, - headers: axum::http::HeaderMap, + auth: BearerAuthAllowDeactivated, Json(input): Json, ) -> Response { - let bearer = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => return ApiError::AuthenticationRequired.into_response(), - }; - let auth_user = - match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await { - Ok(user) => user, - Err(e) => return ApiError::from(e).into_response(), - }; + let auth_user = auth.0; if let Err(e) = crate::auth::scope_check::check_identity_scope( auth_user.is_oauth, auth_user.scope.as_deref(), diff --git a/src/api/identity/plc/submit.rs b/src/api/identity/plc/submit.rs index 151b4d3..2a772d5 100644 --- a/src/api/identity/plc/submit.rs +++ b/src/api/identity/plc/submit.rs @@ -1,4 +1,5 @@ use crate::api::{ApiError, EmptyResponse}; +use crate::auth::BearerAuthAllowDeactivated; use crate::circuit_breaker::with_circuit_breaker; use crate::plc::{PlcClient, signing_key_to_did_key, validate_plc_operation}; use crate::state::AppState; @@ -19,24 +20,10 @@ pub struct SubmitPlcOperationInput { pub async fn submit_plc_operation( State(state): State, - headers: axum::http::HeaderMap, + auth: BearerAuthAllowDeactivated, Json(input): Json, ) -> Response { - let bearer = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => { - return ApiError::AuthenticationRequired.into_response(); - } - }; - let auth_user = - match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &bearer).await { - Ok(user) => user, - Err(e) => { - return ApiError::from(e).into_response(); - } - }; + let auth_user = auth.0; if let Err(e) = crate::auth::scope_check::check_identity_scope( auth_user.is_oauth, auth_user.scope.as_deref(), diff --git a/src/api/moderation/mod.rs b/src/api/moderation/mod.rs index 4217283..70396e3 100644 --- a/src/api/moderation/mod.rs +++ b/src/api/moderation/mod.rs @@ -1,5 +1,6 @@ use crate::api::ApiError; use crate::api::proxy_client::{is_ssrf_safe, proxy_client}; +use crate::auth::extractor::BearerAuthAllowTakendown; use crate::state::AppState; use axum::{ Json, @@ -41,22 +42,10 @@ fn get_report_service_config() -> Option<(String, String)> { pub async fn create_report( State(state): State, - headers: axum::http::HeaderMap, + auth: BearerAuthAllowTakendown, Json(input): Json, ) -> Response { - let token = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => return ApiError::AuthenticationRequired.into_response(), - }; - - let auth_user = - match crate::auth::validate_bearer_token_allow_takendown(&state.db, &token).await { - Ok(user) => user, - Err(e) => return ApiError::from(e).into_response(), - }; - + let auth_user = auth.0; let did = &auth_user.did; if let Some((service_url, service_did)) = get_report_service_config() { diff --git a/src/api/notification_prefs.rs b/src/api/notification_prefs.rs index 633418d..1e9854b 100644 --- a/src/api/notification_prefs.rs +++ b/src/api/notification_prefs.rs @@ -1,10 +1,9 @@ use crate::api::error::ApiError; -use crate::auth::validate_bearer_token; +use crate::auth::BearerAuth; use crate::state::AppState; use axum::{ Json, extract::State, - http::HeaderMap, response::{IntoResponse, Response}, }; use serde::{Deserialize, Serialize}; @@ -25,19 +24,8 @@ pub struct NotificationPrefsResponse { pub signal_verified: bool, } -pub async fn get_notification_prefs(State(state): State, headers: HeaderMap) -> Response { - let token = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => return ApiError::AuthenticationRequired.into_response(), - }; - let user = match validate_bearer_token(&state.db, &token).await { - Ok(u) => u, - Err(_) => { - return ApiError::AuthenticationFailed(None).into_response(); - } - }; +pub async fn get_notification_prefs(State(state): State, auth: BearerAuth) -> Response { + let user = auth.0; let row = match sqlx::query( r#" SELECT @@ -100,22 +88,8 @@ pub struct GetNotificationHistoryResponse { pub notifications: Vec, } -pub async fn get_notification_history( - State(state): State, - headers: HeaderMap, -) -> Response { - let token = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => return ApiError::AuthenticationRequired.into_response(), - }; - let user = match validate_bearer_token(&state.db, &token).await { - Ok(u) => u, - Err(_) => { - return ApiError::AuthenticationFailed(None).into_response(); - } - }; +pub async fn get_notification_history(State(state): State, auth: BearerAuth) -> Response { + let user = auth.0; let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", &user.did) @@ -253,21 +227,10 @@ pub async fn request_channel_verification( pub async fn update_notification_prefs( State(state): State, - headers: HeaderMap, + auth: BearerAuth, Json(input): Json, ) -> Response { - let token = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => return ApiError::AuthenticationRequired.into_response(), - }; - let user = match validate_bearer_token(&state.db, &token).await { - Ok(u) => u, - Err(_) => { - return ApiError::AuthenticationFailed(None).into_response(); - } - }; + let user = auth.0; let user_row = match sqlx::query!( "SELECT id, handle, email FROM users WHERE did = $1", diff --git a/src/api/proxy.rs b/src/api/proxy.rs index 78bd123..def90c5 100644 --- a/src/api/proxy.rs +++ b/src/api/proxy.rs @@ -214,13 +214,28 @@ async fn proxy_handler( info!("Proxying {} request to {}", method_verb, target_url); let client = proxy_client(); - let mut request_builder = client.request(method_verb, &target_url); + let mut request_builder = client.request(method_verb.clone(), &target_url); let mut auth_header_val = headers.get("Authorization").cloned(); - if let Some(token) = crate::auth::extract_bearer_token_from_header( + if let Some(extracted) = crate::auth::extract_auth_token_from_header( headers.get("Authorization").and_then(|h| h.to_str().ok()), ) { - match crate::auth::validate_bearer_token(&state.db, &token).await { + let token = extracted.token; + let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); + let http_uri = uri.to_string(); + + match crate::auth::validate_token_with_dpop( + &state.db, + &token, + extracted.is_dpop, + dpop_proof, + method_verb.as_str(), + &http_uri, + false, + false, + ) + .await + { Ok(auth_user) => { if let Err(e) = crate::auth::scope_check::check_rpc_scope( auth_user.is_oauth, @@ -254,14 +269,7 @@ async fn proxy_handler( Err(e) => { warn!("Token validation failed: {:?}", e); if matches!(e, crate::auth::TokenValidationError::TokenExpired) { - let auth_header_str = headers - .get("Authorization") - .and_then(|h| h.to_str().ok()) - .unwrap_or(""); - let is_dpop = auth_header_str - .trim() - .get(..5) - .is_some_and(|s| s.eq_ignore_ascii_case("dpop ")); + let is_dpop = extracted.is_dpop; let scheme = if is_dpop { "DPoP" } else { "Bearer" }; let www_auth = format!( "{} error=\"invalid_token\", error_description=\"Token has expired\"", diff --git a/src/api/repo/blob.rs b/src/api/repo/blob.rs index f51cda1..fed3d0a 100644 --- a/src/api/repo/blob.rs +++ b/src/api/repo/blob.rs @@ -1,5 +1,5 @@ use crate::api::error::ApiError; -use crate::auth::{ServiceTokenVerifier, is_service_token}; +use crate::auth::{BearerAuthAllowDeactivated, ServiceTokenVerifier, is_service_token}; use crate::delegation::{self, DelegationActionType}; use crate::state::AppState; use crate::util::get_max_blob_size; @@ -45,11 +45,13 @@ pub async fn upload_blob( headers: axum::http::HeaderMap, body: Body, ) -> Response { - let Some(token) = crate::auth::extract_bearer_token_from_header( + let extracted = match crate::auth::extract_auth_token_from_header( headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) else { - return ApiError::AuthenticationRequired.into_response(); + ) { + Some(t) => t, + None => return ApiError::AuthenticationRequired.into_response(), }; + let token = extracted.token; let is_service_auth = is_service_token(&token); @@ -74,7 +76,23 @@ pub async fn upload_blob( } } } else { - match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { + let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); + let http_uri = format!( + "https://{}/xrpc/com.atproto.repo.uploadBlob", + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) + ); + match crate::auth::validate_token_with_dpop( + &state.db, + &token, + extracted.is_dpop, + dpop_proof, + "POST", + &http_uri, + true, + false, + ) + .await + { Ok(user) => { let mime_type_for_check = headers .get("content-type") @@ -283,21 +301,10 @@ pub struct ListMissingBlobsOutput { pub async fn list_missing_blobs( State(state): State, - headers: axum::http::HeaderMap, + auth: BearerAuthAllowDeactivated, Query(params): Query, ) -> Response { - let Some(token) = crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) else { - return ApiError::AuthenticationRequired.into_response(); - }; - let auth_user = - match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { - Ok(user) => user, - Err(_) => { - return ApiError::AuthenticationFailed(None).into_response(); - } - }; + let auth_user = auth.0; let did = auth_user.did; let user_query = sqlx::query!("SELECT id FROM users WHERE did = $1", did.as_str()) .fetch_optional(&state.db) diff --git a/src/api/repo/import.rs b/src/api/repo/import.rs index e02e018..240b908 100644 --- a/src/api/repo/import.rs +++ b/src/api/repo/import.rs @@ -1,6 +1,7 @@ use crate::api::EmptyResponse; use crate::api::error::ApiError; use crate::api::repo::record::create_signed_commit; +use crate::auth::BearerAuthAllowDeactivated; use crate::state::AppState; use crate::sync::import::{ImportError, apply_import, parse_car}; use crate::sync::verify::CarVerifier; @@ -20,7 +21,7 @@ const DEFAULT_MAX_BLOCKS: usize = 500000; pub async fn import_repo( State(state): State, - headers: axum::http::HeaderMap, + auth: BearerAuthAllowDeactivated, body: Bytes, ) -> Response { let accepting_imports = std::env::var("ACCEPTING_REPO_IMPORTS") @@ -41,17 +42,7 @@ pub async fn import_repo( )) .into_response(); } - let token = match crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) { - Some(t) => t, - None => return ApiError::AuthenticationRequired.into_response(), - }; - let auth_user = - match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { - Ok(user) => user, - Err(e) => return ApiError::from(e).into_response(), - }; + let auth_user = auth.0; let did = &auth_user.did; let user = match sqlx::query!( "SELECT id, handle, deactivated_at, takedown_ref FROM users WHERE did = $1", diff --git a/src/api/repo/record/batch.rs b/src/api/repo/record/batch.rs index 8437d27..bde399a 100644 --- a/src/api/repo/record/batch.rs +++ b/src/api/repo/record/batch.rs @@ -2,6 +2,7 @@ use super::validation::validate_record_with_status; use super::write::has_verified_comms_channel; use crate::api::error::ApiError; use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids}; +use crate::auth::BearerAuth; use crate::delegation::{self, DelegationActionType}; use crate::repo::tracking::TrackingBlockStore; use crate::state::AppState; @@ -85,7 +86,7 @@ pub struct CommitInfo { pub async fn apply_writes( State(state): State, - headers: axum::http::HeaderMap, + auth: BearerAuth, Json(input): Json, ) -> Response { info!( @@ -93,15 +94,7 @@ pub async fn apply_writes( input.repo, input.writes.len() ); - let Some(token) = crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) else { - return ApiError::AuthenticationRequired.into_response(); - }; - let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { - Ok(user) => user, - Err(_) => return ApiError::AuthenticationFailed(None).into_response(), - }; + let auth_user = auth.0; let did = auth_user.did.clone(); let is_oauth = auth_user.is_oauth; let scope = auth_user.scope; diff --git a/src/api/repo/record/write.rs b/src/api/repo/record/write.rs index 87cce47..8549d74 100644 --- a/src/api/repo/record/write.rs +++ b/src/api/repo/record/write.rs @@ -77,6 +77,7 @@ pub async fn prepare_repo_write( http_method, http_uri, false, + false, ) .await .map_err(|e| { diff --git a/src/api/server/account_status.rs b/src/api/server/account_status.rs index 652daa2..41aa487 100644 --- a/src/api/server/account_status.rs +++ b/src/api/server/account_status.rs @@ -59,6 +59,7 @@ pub async fn check_account_status( "GET", &http_uri, true, + false, ) .await { @@ -370,6 +371,7 @@ pub async fn activate_account( "POST", &http_uri, true, + false, ) .await { @@ -561,6 +563,7 @@ pub async fn deactivate_account( "POST", &http_uri, false, + false, ) .await { @@ -646,6 +649,7 @@ pub async fn request_account_delete( "POST", &http_uri, true, + false, ) .await { diff --git a/src/api/server/email.rs b/src/api/server/email.rs index fbc4c71..b86d6d1 100644 --- a/src/api/server/email.rs +++ b/src/api/server/email.rs @@ -193,20 +193,10 @@ pub struct UpdateEmailInput { pub async fn update_email( State(state): State, - headers: axum::http::HeaderMap, + auth: BearerAuth, Json(input): Json, ) -> Response { - let Some(bearer_token) = crate::auth::extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) else { - return ApiError::AuthenticationRequired.into_response(); - }; - - let auth_result = crate::auth::validate_bearer_token(&state.db, &bearer_token).await; - let auth_user = match auth_result { - Ok(user) => user, - Err(e) => return ApiError::from(e).into_response(), - }; + let auth_user = auth.0; if let Err(e) = crate::auth::scope_check::check_account_scope( auth_user.is_oauth, diff --git a/src/api/server/migration.rs b/src/api/server/migration.rs index c61369b..97be0fb 100644 --- a/src/api/server/migration.rs +++ b/src/api/server/migration.rs @@ -58,6 +58,7 @@ pub async fn update_did_document( "POST", &http_uri, true, + false, ) .await { @@ -224,6 +225,7 @@ pub async fn get_did_document( "GET", &http_uri, true, + false, ) .await { diff --git a/src/api/server/passkey_account.rs b/src/api/server/passkey_account.rs index 3da2b7c..d3da564 100644 --- a/src/api/server/passkey_account.rs +++ b/src/api/server/passkey_account.rs @@ -18,7 +18,7 @@ use tracing::{debug, error, info, warn}; use uuid::Uuid; use crate::api::repo::record::utils::create_signed_commit; -use crate::auth::{ServiceTokenVerifier, extract_bearer_token_from_header, is_service_token}; +use crate::auth::{ServiceTokenVerifier, is_service_token}; use crate::state::{AppState, RateLimitKind}; use crate::types::{Did, Handle, PlainPassword}; use crate::validation::validate_password; @@ -108,9 +108,10 @@ pub async fn create_passkey_account( .into_response(); } - let byod_auth = if let Some(token) = - extract_bearer_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok())) - { + let byod_auth = if let Some(extracted) = crate::auth::extract_auth_token_from_header( + headers.get("Authorization").and_then(|h| h.to_str().ok()), + ) { + let token = extracted.token; if is_service_token(&token) { let verifier = ServiceTokenVerifier::new(); match verifier diff --git a/src/api/server/session.rs b/src/api/server/session.rs index e16cd70..1a056aa 100644 --- a/src/api/server/session.rs +++ b/src/api/server/session.rs @@ -365,18 +365,19 @@ pub async fn get_session( pub async fn delete_session( State(state): State, headers: axum::http::HeaderMap, + _auth: BearerAuth, ) -> Response { - let token = match crate::auth::extract_bearer_token_from_header( + let extracted = match crate::auth::extract_auth_token_from_header( headers.get("Authorization").and_then(|h| h.to_str().ok()), ) { Some(t) => t, None => return ApiError::AuthenticationRequired.into_response(), }; - let jti = match crate::auth::get_jti_from_token(&token) { + let jti = match crate::auth::get_jti_from_token(&extracted.token) { Ok(jti) => jti, Err(_) => return ApiError::AuthenticationFailed(None).into_response(), }; - let did = crate::auth::get_did_from_token(&token).ok(); + let did = crate::auth::get_did_from_token(&extracted.token).ok(); match sqlx::query!("DELETE FROM session_tokens WHERE access_jti = $1", jti) .execute(&state.db) .await @@ -408,12 +409,13 @@ pub async fn refresh_session( tracing::warn!(ip = %client_ip, "Refresh session rate limit exceeded"); return ApiError::RateLimitExceeded(None).into_response(); } - let refresh_token = match crate::auth::extract_bearer_token_from_header( + let extracted = match crate::auth::extract_auth_token_from_header( headers.get("Authorization").and_then(|h| h.to_str().ok()), ) { Some(t) => t, None => return ApiError::AuthenticationRequired.into_response(), }; + let refresh_token = extracted.token; let refresh_jti = match crate::auth::get_jti_from_token(&refresh_token) { Ok(jti) => jti, Err(_) => { @@ -1048,11 +1050,10 @@ pub async fn revoke_all_sessions( headers: HeaderMap, auth: BearerAuth, ) -> Response { - let current_jti = headers - .get("authorization") - .and_then(|v| v.to_str().ok()) - .and_then(|v| v.strip_prefix("Bearer ")) - .and_then(|token| crate::auth::get_jti_from_token(token).ok()); + let current_jti = crate::auth::extract_auth_token_from_header( + headers.get("authorization").and_then(|v| v.to_str().ok()), + ) + .and_then(|extracted| crate::auth::get_jti_from_token(&extracted.token).ok()); let Some(ref jti) = current_jti else { return ApiError::InvalidToken(None).into_response(); diff --git a/src/api/temp.rs b/src/api/temp.rs index cae0096..065cb2f 100644 --- a/src/api/temp.rs +++ b/src/api/temp.rs @@ -1,5 +1,5 @@ use crate::api::error::ApiError; -use crate::auth::{extract_bearer_token_from_header, validate_bearer_token}; +use crate::auth::{BearerAuth, extract_auth_token_from_header, validate_token_with_dpop}; use crate::state::AppState; use axum::{ Json, @@ -23,12 +23,25 @@ pub struct CheckSignupQueueOutput { } pub async fn check_signup_queue(State(state): State, headers: HeaderMap) -> Response { - if let Some(token) = - extract_bearer_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok())) - && let Ok(user) = validate_bearer_token(&state.db, &token).await - && user.is_oauth + if let Some(extracted) = + extract_auth_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok())) { - return ApiError::Forbidden.into_response(); + let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); + if let Ok(user) = validate_token_with_dpop( + &state.db, + &extracted.token, + extracted.is_dpop, + dpop_proof, + "GET", + "/", + false, + false, + ) + .await + && user.is_oauth + { + return ApiError::Forbidden.into_response(); + } } Json(CheckSignupQueueOutput { activated: true, @@ -52,18 +65,10 @@ pub struct DereferenceScopeOutput { pub async fn dereference_scope( State(state): State, - headers: HeaderMap, + auth: BearerAuth, Json(input): Json, ) -> Response { - let Some(token) = extract_bearer_token_from_header( - headers.get("Authorization").and_then(|h| h.to_str().ok()), - ) else { - return ApiError::AuthenticationRequired.into_response(); - }; - - if validate_bearer_token(&state.db, &token).await.is_err() { - return ApiError::AuthenticationFailed(None).into_response(); - } + let _ = auth; let scope_parts: Vec<&str> = input.scope.split_whitespace().collect(); let mut resolved_scopes: Vec = Vec::new(); diff --git a/src/auth/extractor.rs b/src/auth/extractor.rs index 1c2978f..5d59e20 100644 --- a/src/auth/extractor.rs +++ b/src/auth/extractor.rs @@ -5,8 +5,9 @@ use axum::{ }; use super::{ - AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, - validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop, + AuthenticatedUser, TokenValidationError, validate_bearer_token_allow_takendown, + validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated, + validate_token_with_dpop, }; use crate::api::error::ApiError; use crate::state::AppState; @@ -136,6 +137,7 @@ impl FromRequestParts for BearerAuth { method, &uri, false, + false, ) .await { @@ -191,6 +193,7 @@ impl FromRequestParts for BearerAuthAllowDeactivated { method, &uri, true, + false, ) .await { @@ -216,6 +219,58 @@ impl FromRequestParts for BearerAuthAllowDeactivated { } } +pub struct BearerAuthAllowTakendown(pub AuthenticatedUser); + +impl FromRequestParts for BearerAuthAllowTakendown { + type Rejection = AuthError; + + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + let auth_header = parts + .headers + .get(AUTHORIZATION) + .ok_or(AuthError::MissingToken)? + .to_str() + .map_err(|_| AuthError::InvalidFormat)?; + + let extracted = + extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; + + if extracted.is_dpop { + let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); + let method = parts.method.as_str(); + let uri = build_full_url(&parts.uri.to_string()); + + match validate_token_with_dpop( + &state.db, + &extracted.token, + true, + dpop_proof, + method, + &uri, + false, + true, + ) + .await + { + Ok(user) => Ok(BearerAuthAllowTakendown(user)), + Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), + Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), + Err(_) => Err(AuthError::AuthenticationFailed), + } + } else { + match validate_bearer_token_allow_takendown(&state.db, &extracted.token).await { + Ok(user) => Ok(BearerAuthAllowTakendown(user)), + Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), + Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), + Err(_) => Err(AuthError::AuthenticationFailed), + } + } + } +} + pub struct BearerAuthAdmin(pub AuthenticatedUser); impl FromRequestParts for BearerAuthAdmin { @@ -248,6 +303,7 @@ impl FromRequestParts for BearerAuthAdmin { method, &uri, false, + false, ) .await { diff --git a/src/auth/mod.rs b/src/auth/mod.rs index eba8382..1c50eb8 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -416,6 +416,7 @@ pub async fn invalidate_auth_cache(cache: &dyn Cache, did: &str) { let _ = cache.delete(&status_cache_key).await; } +#[allow(clippy::too_many_arguments)] pub async fn validate_token_with_dpop( db: &PgPool, token: &str, @@ -424,9 +425,12 @@ pub async fn validate_token_with_dpop( http_method: &str, http_uri: &str, allow_deactivated: bool, + allow_takendown: bool, ) -> Result { if !is_dpop_token { - if allow_deactivated { + if allow_takendown { + return validate_bearer_token_allow_takendown(db, token).await; + } else if allow_deactivated { return validate_bearer_token_allow_deactivated(db, token).await; } else { return validate_bearer_token(db, token).await; @@ -464,7 +468,7 @@ pub async fn validate_token_with_dpop( if !allow_deactivated && status.is_deactivated() { return Err(TokenValidationError::AccountDeactivated); } - if status.is_takendown() { + if !allow_takendown && status.is_takendown() { return Err(TokenValidationError::AccountTakedown); } let key_bytes = if let (Some(kb), Some(ev)) = diff --git a/src/oauth/client.rs b/src/oauth/client.rs index 4666f4b..796a6f1 100644 --- a/src/oauth/client.rs +++ b/src/oauth/client.rs @@ -82,7 +82,9 @@ impl ClientMetadataCache { .connect_timeout(std::time::Duration::from_secs(10)) .pool_max_idle_per_host(10) .pool_idle_timeout(std::time::Duration::from_secs(90)) - .user_agent("Tranquil-PDS/1.0 (ATProto; +https://tangled.org/lewis.moe/bspds-sandbox)") + .user_agent( + "Tranquil-PDS/1.0 (ATProto; +https://tangled.org/lewis.moe/bspds-sandbox)", + ) .build() .unwrap_or_else(|_| Client::new()), cache_ttl_secs, diff --git a/src/oauth/endpoints/authorize.rs b/src/oauth/endpoints/authorize.rs index e478afc..495940c 100644 --- a/src/oauth/endpoints/authorize.rs +++ b/src/oauth/endpoints/authorize.rs @@ -56,9 +56,14 @@ fn json_error(status: StatusCode, error: &str, description: &str) -> Response { } fn is_granular_scope(s: &str) -> bool { - s.starts_with("repo:") || s.starts_with("repo?") || s == "repo" - || s.starts_with("blob:") || s.starts_with("blob?") || s == "blob" - || s.starts_with("rpc:") || s.starts_with("rpc?") + s.starts_with("repo:") + || s.starts_with("repo?") + || s == "repo" + || s.starts_with("blob:") + || s.starts_with("blob?") + || s == "blob" + || s.starts_with("rpc:") + || s.starts_with("rpc?") || s.starts_with("account:") || s.starts_with("identity:") } diff --git a/src/oauth/scopes/permission_set.rs b/src/oauth/scopes/permission_set.rs index aaaedbd..b120de9 100644 --- a/src/oauth/scopes/permission_set.rs +++ b/src/oauth/scopes/permission_set.rs @@ -57,11 +57,11 @@ pub async fn expand_include_scopes(scope_string: &str) -> String { async fn expand_permission_set(nsid: &str) -> Result { { let cache = LEXICON_CACHE.read().await; - if let Some(cached) = cache.get(nsid) { - if cached.cached_at.elapsed().as_secs() < CACHE_TTL_SECS { - debug!(nsid, "Using cached permission set expansion"); - return Ok(cached.expanded_scope.clone()); - } + if let Some(cached) = cache.get(nsid) + && cached.cached_at.elapsed().as_secs() < CACHE_TTL_SECS + { + debug!(nsid, "Using cached permission set expansion"); + return Ok(cached.expanded_scope.clone()); } } @@ -156,8 +156,6 @@ async fn expand_permission_set(nsid: &str) -> Result { #[cfg(test)] mod tests { - use super::*; - #[test] fn test_nsid_to_url() { let nsid = "io.atcr.authFullApp"; diff --git a/src/sync/deprecated.rs b/src/sync/deprecated.rs index 81b9a22..0e51c88 100644 --- a/src/sync/deprecated.rs +++ b/src/sync/deprecated.rs @@ -1,5 +1,4 @@ use crate::api::error::ApiError; -use crate::auth::{extract_bearer_token_from_header, validate_bearer_token_allow_takendown}; use crate::state::AppState; use crate::sync::car::encode_car_header; use crate::sync::util::assert_repo_availability; @@ -19,13 +18,26 @@ use std::str::FromStr; const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000; async fn check_admin_or_self(state: &AppState, headers: &HeaderMap, did: &str) -> bool { - let token = match extract_bearer_token_from_header( + let extracted = match crate::auth::extract_auth_token_from_header( headers.get("Authorization").and_then(|h| h.to_str().ok()), ) { Some(t) => t, None => return false, }; - match validate_bearer_token_allow_takendown(&state.db, &token).await { + let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); + let http_uri = "/"; + match crate::auth::validate_token_with_dpop( + &state.db, + &extracted.token, + extracted.is_dpop, + dpop_proof, + "GET", + http_uri, + false, + true, + ) + .await + { Ok(auth_user) => auth_user.is_admin || auth_user.did == did, Err(_) => false, } diff --git a/tests/dpop_unit.rs b/tests/dpop_unit.rs index 506ea5f..4c76240 100644 --- a/tests/dpop_unit.rs +++ b/tests/dpop_unit.rs @@ -191,18 +191,18 @@ fn test_dpop_iat_clock_skew_beyond_bounds() { let verifier = DPoPVerifier::new(b"test-secret-32-bytes-long!!!!!!!"); let url = "https://pds.example/xrpc/foo"; - let (proof_301s_future, _) = create_dpop_proof("GET", url, 301, "ES256", None, None); + let (proof_301s_future, _) = create_dpop_proof("GET", url, 310, "ES256", None, None); let result = verifier.verify_proof(&proof_301s_future, "GET", url, None); assert!( result.is_err(), - "301s in future should exceed clock skew tolerance" + "310s in future should exceed clock skew tolerance" ); - let (proof_301s_past, _) = create_dpop_proof("GET", url, -301, "ES256", None, None); + let (proof_301s_past, _) = create_dpop_proof("GET", url, -310, "ES256", None, None); let result = verifier.verify_proof(&proof_301s_past, "GET", url, None); assert!( result.is_err(), - "301s in past should exceed clock skew tolerance" + "310s in past should exceed clock skew tolerance" ); }