Files
tranquil-pds/src/auth/extractor.rs
2025-12-25 18:23:19 +02:00

344 lines
11 KiB
Rust

use axum::{
Json,
extract::FromRequestParts,
http::{StatusCode, header::AUTHORIZATION, request::Parts},
response::{IntoResponse, Response},
};
use serde_json::json;
use super::{
AuthenticatedUser, TokenValidationError, validate_bearer_token_cached,
validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop,
};
use crate::state::AppState;
/// Axum extractor that authenticates a request from its `Authorization` header
/// (`Bearer` or `DPoP` scheme) and yields the validated [`AuthenticatedUser`].
/// Deactivated and taken-down accounts are rejected with the matching
/// [`AuthError`] variant.
pub struct BearerAuth(pub AuthenticatedUser);
/// Reasons an authentication extractor in this module can reject a request.
///
/// The `IntoResponse` impl renders every variant as a JSON body of the form
/// `{"error": ..., "message": ...}`. All variants use `401 Unauthorized` except
/// [`AuthError::AdminRequired`], which uses `403 Forbidden`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// No `Authorization` header was sent.
    MissingToken,
    /// The `Authorization` header could not be read as text, did not use a
    /// supported scheme, or carried an empty token.
    InvalidFormat,
    /// The token could not be verified.
    AuthenticationFailed,
    /// The token has expired.
    TokenExpired,
    /// The account is deactivated.
    AccountDeactivated,
    /// The account has been taken down.
    AccountTakedown,
    /// The authenticated user lacks admin privileges.
    AdminRequired,
}
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
let (status, error, message) = match self {
AuthError::MissingToken => (
StatusCode::UNAUTHORIZED,
"AuthenticationRequired",
"Authorization header is required",
),
AuthError::InvalidFormat => (
StatusCode::UNAUTHORIZED,
"InvalidToken",
"Invalid authorization header format",
),
AuthError::AuthenticationFailed => (
StatusCode::UNAUTHORIZED,
"InvalidToken",
"Token could not be verified",
),
AuthError::TokenExpired => (
StatusCode::UNAUTHORIZED,
"ExpiredToken",
"Token has expired",
),
AuthError::AccountDeactivated => (
StatusCode::UNAUTHORIZED,
"AccountDeactivated",
"Account is deactivated",
),
AuthError::AccountTakedown => (
StatusCode::UNAUTHORIZED,
"AccountTakedown",
"Account has been taken down",
),
AuthError::AdminRequired => (
StatusCode::FORBIDDEN,
"AdminRequired",
"This action requires admin privileges",
),
};
(status, Json(json!({ "error": error, "message": message }))).into_response()
}
}
/// Test-only strict parser for an `Authorization: Bearer <token>` header value.
///
/// The scheme is matched case-insensitively, and whitespace around both the
/// header and the token is trimmed.
///
/// # Errors
///
/// Returns [`AuthError::InvalidFormat`] when the value is shorter than 8 bytes,
/// does not start with `bearer `, or carries an empty token. Non-ASCII input is
/// rejected rather than causing a panic.
#[cfg(test)]
fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
    let auth_header = auth_header.trim();
    if auth_header.len() < 8 {
        return Err(AuthError::InvalidFormat);
    }
    // `str::get` yields `None` instead of panicking when byte 7 falls inside a
    // multi-byte UTF-8 character (e.g. "Bearerétoken"); such input cannot carry
    // the ASCII `bearer ` prefix anyway.
    let has_bearer_scheme = auth_header
        .get(..7)
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case("bearer "));
    if !has_bearer_scheme {
        return Err(AuthError::InvalidFormat);
    }
    // The `get(..7)` above succeeded, so byte 7 is a char boundary.
    let token = auth_header[7..].trim();
    if token.is_empty() {
        return Err(AuthError::InvalidFormat);
    }
    Ok(token)
}
/// Extracts the token from an `Authorization` header value that uses the
/// `Bearer` scheme (matched case-insensitively).
///
/// Whitespace around both the header and the token is trimmed. Returns `None`
/// when the header is absent, does not start with `bearer `, or carries an
/// empty token. `DPoP` headers are not accepted here. Non-ASCII input is
/// rejected rather than causing a panic.
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
    let header = auth_header?.trim();
    // `str::get` returns `None` (instead of panicking) both when the header is
    // shorter than 7 bytes and when byte 7 falls inside a multi-byte character.
    let prefix = header.get(..7)?;
    if !prefix.eq_ignore_ascii_case("bearer ") {
        return None;
    }
    // `get(..7)` succeeded, so byte 7 is a char boundary and this slice is safe.
    let token = header[7..].trim();
    (!token.is_empty()).then(|| token.to_string())
}
/// An access token taken from an `Authorization` header, together with the
/// scheme it was presented under.
pub struct ExtractedToken {
    /// The token text with surrounding whitespace removed; never empty when
    /// produced by [`extract_auth_token_from_header`].
    pub token: String,
    /// `true` when the header used the `DPoP` scheme, `false` for `Bearer`.
    pub is_dpop: bool,
}
/// Parses an `Authorization` header value of the form `Bearer <token>` or
/// `DPoP <token>`, matching the scheme case-insensitively.
///
/// Whitespace around both the header and the token is trimmed. Returns `None`
/// when the header is absent, uses any other scheme, or carries an empty token.
/// Non-ASCII input is rejected rather than causing a panic.
pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> {
    // Returns the text following `scheme` (an ASCII prefix including its
    // trailing space) when `header` starts with it, ignoring ASCII case.
    fn strip_scheme<'a>(header: &'a str, scheme: &str) -> Option<&'a str> {
        // `str::get` yields `None` instead of panicking when `header` is too
        // short or the prefix length falls inside a multi-byte character.
        let prefix = header.get(..scheme.len())?;
        if prefix.eq_ignore_ascii_case(scheme) {
            // The `get` above proved this offset is a char boundary.
            Some(&header[scheme.len()..])
        } else {
            None
        }
    }

    let header = auth_header?.trim();
    let (rest, is_dpop) = match strip_scheme(header, "bearer ") {
        Some(rest) => (rest, false),
        None => (strip_scheme(header, "dpop ")?, true),
    };
    let token = rest.trim();
    if token.is_empty() {
        return None;
    }
    Some(ExtractedToken {
        token: token.to_string(),
        is_dpop,
    })
}
impl FromRequestParts<AppState> for BearerAuth {
type Rejection = AuthError;
async fn from_request_parts(
parts: &mut Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let auth_header = parts
.headers
.get(AUTHORIZATION)
.ok_or(AuthError::MissingToken)?
.to_str()
.map_err(|_| AuthError::InvalidFormat)?;
let extracted =
extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
if extracted.is_dpop {
let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
let method = parts.method.as_str();
let uri = parts.uri.to_string();
match validate_token_with_dpop(
&state.db,
&extracted.token,
true,
dpop_proof,
method,
&uri,
false,
)
.await
{
Ok(user) => Ok(BearerAuth(user)),
Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
Err(_) => Err(AuthError::AuthenticationFailed),
}
} else {
match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await {
Ok(user) => Ok(BearerAuth(user)),
Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
Err(_) => Err(AuthError::AuthenticationFailed),
}
}
}
}
pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
type Rejection = AuthError;
async fn from_request_parts(
parts: &mut Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let auth_header = parts
.headers
.get(AUTHORIZATION)
.ok_or(AuthError::MissingToken)?
.to_str()
.map_err(|_| AuthError::InvalidFormat)?;
let extracted =
extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
if extracted.is_dpop {
let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
let method = parts.method.as_str();
let uri = parts.uri.to_string();
match validate_token_with_dpop(
&state.db,
&extracted.token,
true,
dpop_proof,
method,
&uri,
true,
)
.await
{
Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
Err(_) => Err(AuthError::AuthenticationFailed),
}
} else {
match validate_bearer_token_cached_allow_deactivated(
&state.db,
&state.cache,
&extracted.token,
)
.await
{
Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
Err(_) => Err(AuthError::AuthenticationFailed),
}
}
}
}
pub struct BearerAuthAdmin(pub AuthenticatedUser);
impl FromRequestParts<AppState> for BearerAuthAdmin {
type Rejection = AuthError;
async fn from_request_parts(
parts: &mut Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let auth_header = parts
.headers
.get(AUTHORIZATION)
.ok_or(AuthError::MissingToken)?
.to_str()
.map_err(|_| AuthError::InvalidFormat)?;
let extracted =
extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
let user = if extracted.is_dpop {
let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
let method = parts.method.as_str();
let uri = parts.uri.to_string();
match validate_token_with_dpop(
&state.db,
&extracted.token,
true,
dpop_proof,
method,
&uri,
false,
)
.await
{
Ok(user) => user,
Err(TokenValidationError::AccountDeactivated) => {
return Err(AuthError::AccountDeactivated);
}
Err(TokenValidationError::AccountTakedown) => {
return Err(AuthError::AccountTakedown);
}
Err(TokenValidationError::TokenExpired) => {
return Err(AuthError::TokenExpired);
}
Err(_) => return Err(AuthError::AuthenticationFailed),
}
} else {
match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await {
Ok(user) => user,
Err(TokenValidationError::AccountDeactivated) => {
return Err(AuthError::AccountDeactivated);
}
Err(TokenValidationError::AccountTakedown) => {
return Err(AuthError::AccountTakedown);
}
Err(TokenValidationError::TokenExpired) => {
return Err(AuthError::TokenExpired);
}
Err(_) => return Err(AuthError::AuthenticationFailed),
}
};
if !user.is_admin {
return Err(AuthError::AdminRequired);
}
Ok(BearerAuthAdmin(user))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_extract_bearer_token() {
        assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123");
        // Extra spaces between scheme and token are tolerated.
        assert_eq!(extract_bearer_token("Bearer   abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123");
        assert!(extract_bearer_token("Basic abc123").is_err());
        assert!(extract_bearer_token("Bearer").is_err());
        assert!(extract_bearer_token("Bearer ").is_err());
        assert!(extract_bearer_token("abc123").is_err());
        assert!(extract_bearer_token("").is_err());
    }
    #[test]
    fn test_extract_bearer_token_from_header() {
        assert_eq!(
            extract_bearer_token_from_header(Some("Bearer abc123")),
            Some("abc123".to_string())
        );
        assert_eq!(
            extract_bearer_token_from_header(Some("  bearer   abc123  ")),
            Some("abc123".to_string())
        );
        assert_eq!(extract_bearer_token_from_header(None), None);
        assert_eq!(extract_bearer_token_from_header(Some("Bearer ")), None);
        assert_eq!(extract_bearer_token_from_header(Some("Bearer")), None);
        assert_eq!(extract_bearer_token_from_header(Some("DPoP abc123")), None);
        assert_eq!(extract_bearer_token_from_header(Some("Basic abc123")), None);
    }
    #[test]
    fn test_extract_auth_token_from_header() {
        let bearer = extract_auth_token_from_header(Some("Bearer abc123")).unwrap();
        assert_eq!(bearer.token, "abc123");
        assert!(!bearer.is_dpop);
        let dpop = extract_auth_token_from_header(Some("DPoP xyz")).unwrap();
        assert_eq!(dpop.token, "xyz");
        assert!(dpop.is_dpop);
        let lower = extract_auth_token_from_header(Some("  dpop   xyz  ")).unwrap();
        assert_eq!(lower.token, "xyz");
        assert!(lower.is_dpop);
        assert!(extract_auth_token_from_header(None).is_none());
        assert!(extract_auth_token_from_header(Some("Bearer ")).is_none());
        assert!(extract_auth_token_from_header(Some("DPoP ")).is_none());
        assert!(extract_auth_token_from_header(Some("Basic abc123")).is_none());
    }
}