diff --git a/src/api/error.rs b/src/api/error.rs index 5e41fda..7c70f2f 100644 --- a/src/api/error.rs +++ b/src/api/error.rs @@ -192,6 +192,7 @@ impl From for ApiError { crate::auth::TokenValidationError::AccountTakedown => Self::AccountTakedown, crate::auth::TokenValidationError::KeyDecryptionFailed => Self::InternalError, crate::auth::TokenValidationError::AuthenticationFailed => Self::AuthenticationFailed, + crate::auth::TokenValidationError::TokenExpired => Self::ExpiredToken, } } } diff --git a/src/auth/extractor.rs b/src/auth/extractor.rs index 43bdaff..f794f7f 100644 --- a/src/auth/extractor.rs +++ b/src/auth/extractor.rs @@ -19,6 +19,7 @@ pub enum AuthError { MissingToken, InvalidFormat, AuthenticationFailed, + TokenExpired, AccountDeactivated, AccountTakedown, AdminRequired, @@ -39,8 +40,13 @@ impl IntoResponse for AuthError { ), AuthError::AuthenticationFailed => ( StatusCode::UNAUTHORIZED, - "AuthenticationFailed", - "Invalid or expired token", + "InvalidToken", + "Token could not be verified", + ), + AuthError::TokenExpired => ( + StatusCode::UNAUTHORIZED, + "ExpiredToken", + "Token has expired", ), AuthError::AccountDeactivated => ( StatusCode::UNAUTHORIZED, @@ -174,6 +180,7 @@ impl FromRequestParts for BearerAuth { Ok(user) => Ok(BearerAuth(user)), Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), + Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), Err(_) => Err(AuthError::AuthenticationFailed), } } else { @@ -181,6 +188,7 @@ impl FromRequestParts for BearerAuth { Ok(user) => Ok(BearerAuth(user)), Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), + Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), Err(_) => Err(AuthError::AuthenticationFailed), } } @@ -224,6 +232,7 @@ impl FromRequestParts for BearerAuthAllowDeactivated { { Ok(user) => Ok(BearerAuthAllowDeactivated(user)), Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), + Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), Err(_) => Err(AuthError::AuthenticationFailed), } } else { @@ -236,6 +245,7 @@ impl FromRequestParts for BearerAuthAllowDeactivated { { Ok(user) => Ok(BearerAuthAllowDeactivated(user)), Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), + Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), Err(_) => Err(AuthError::AuthenticationFailed), } } @@ -284,6 +294,9 @@ impl FromRequestParts for BearerAuthAdmin { Err(TokenValidationError::AccountTakedown) => { return Err(AuthError::AccountTakedown); } + Err(TokenValidationError::TokenExpired) => { + return Err(AuthError::TokenExpired); + } Err(_) => return Err(AuthError::AuthenticationFailed), } } else { @@ -295,6 +308,9 @@ impl FromRequestParts for BearerAuthAdmin { Err(TokenValidationError::AccountTakedown) => { return Err(AuthError::AccountTakedown); } + Err(TokenValidationError::TokenExpired) => { + return Err(AuthError::TokenExpired); + } Err(_) => return Err(AuthError::AuthenticationFailed), } }; diff --git a/src/auth/mod.rs b/src/auth/mod.rs index cff08dc..8f195aa 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -28,7 +28,8 @@ pub use token::{ create_service_token, }; pub use verify::{ - get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token, + TokenVerifyError, get_did_from_token, get_jti_from_token, verify_access_token, + verify_access_token_typed, verify_refresh_token, verify_token, }; const KEY_CACHE_TTL_SECS: u64 = 300; @@ -40,6 +41,7 @@ pub enum TokenValidationError { AccountTakedown, KeyDecryptionFailed, AuthenticationFailed, + TokenExpired, } impl fmt::Display for TokenValidationError { @@ -49,6 +51,7 @@ impl fmt::Display for TokenValidationError { Self::AccountTakedown => write!(f, "AccountTakedown"), Self::KeyDecryptionFailed => write!(f, "KeyDecryptionFailed"), Self::AuthenticationFailed => write!(f, "AuthenticationFailed"), + Self::TokenExpired => write!(f, "ExpiredToken"), } } } @@ -193,53 +196,59 @@ async fn validate_bearer_token_with_options_internal( return Err(TokenValidationError::AccountTakedown); } - if let Ok(token_data) = verify_access_token(token, &decrypted_key) { - let jti = &token_data.claims.jti; - let session_cache_key = format!("auth:session:{}:{}", did, jti); - let mut session_valid = false; + match verify_access_token_typed(token, &decrypted_key) { + Ok(token_data) => { + let jti = &token_data.claims.jti; + let session_cache_key = format!("auth:session:{}:{}", did, jti); + let mut session_valid = false; - if let Some(c) = cache { - if let Some(cached_value) = c.get(&session_cache_key).await { - session_valid = cached_value == "1"; - crate::metrics::record_auth_cache_hit("session"); - } else { - crate::metrics::record_auth_cache_miss("session"); + if let Some(c) = cache { + if let Some(cached_value) = c.get(&session_cache_key).await { + session_valid = cached_value == "1"; + crate::metrics::record_auth_cache_hit("session"); + } else { + crate::metrics::record_auth_cache_miss("session"); + } + } + + if !session_valid { + let session_exists = sqlx::query_scalar!( + "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()", + did, + jti + ) + .fetch_optional(db) + .await + .ok() + .flatten(); + + session_valid = session_exists.is_some(); + + if session_valid && let Some(c) = cache { + let _ = c + .set( + &session_cache_key, + "1", + Duration::from_secs(SESSION_CACHE_TTL_SECS), + ) + .await; + } + } + + if session_valid { + return Ok(AuthenticatedUser { + did: did.clone(), + key_bytes: Some(decrypted_key), + is_oauth: false, + is_admin, + scope: None, + }); } } - - if !session_valid { - let session_exists = sqlx::query_scalar!( - "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()", - did, - jti - ) - .fetch_optional(db) - .await - .ok() - .flatten(); - - session_valid = session_exists.is_some(); - - if session_valid && let Some(c) = cache { - let _ = c - .set( - &session_cache_key, - "1", - Duration::from_secs(SESSION_CACHE_TTL_SECS), - ) - .await; - } - } - - if session_valid { - return Ok(AuthenticatedUser { - did: did.clone(), - key_bytes: Some(decrypted_key), - is_oauth: false, - is_admin, - scope: None, - }); + Err(verify::TokenVerifyError::Expired) => { + return Err(TokenValidationError::TokenExpired); } + Err(verify::TokenVerifyError::Invalid) => {} } } } @@ -283,6 +292,8 @@ async fn validate_bearer_token_with_options_internal( is_admin: oauth_token.is_admin, scope: oauth_info.scope, }); + } else { + return Err(TokenValidationError::TokenExpired); } } diff --git a/src/auth/verify.rs b/src/auth/verify.rs index a38f810..7535561 100644 --- a/src/auth/verify.rs +++ b/src/auth/verify.rs @@ -10,10 +10,28 @@ use chrono::Utc; use hmac::{Hmac, Mac}; use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier}; use sha2::Sha256; +use std::fmt; use subtle::ConstantTimeEq; type HmacSha256 = Hmac; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TokenVerifyError { + Expired, + Invalid, +} + +impl fmt::Display for TokenVerifyError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Expired => write!(f, "Token expired"), + Self::Invalid => write!(f, "Token invalid"), + } + } +} + +impl std::error::Error for TokenVerifyError {} + pub fn get_did_from_token(token: &str) -> Result { let parts: Vec<&str> = token.split('.').collect(); if parts.len() != 3 { @@ -234,6 +252,88 @@ fn verify_token_hs256_internal( Ok(TokenData { claims }) } +pub fn verify_access_token_typed( + token: &str, + key_bytes: &[u8], +) -> Result, TokenVerifyError> { + verify_token_typed_internal( + token, + key_bytes, + Some(TOKEN_TYPE_ACCESS), + Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]), + ) +} + +fn verify_token_typed_internal( + token: &str, + key_bytes: &[u8], + expected_typ: Option<&str>, + allowed_scopes: Option<&[&str]>, +) -> Result, TokenVerifyError> { + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return Err(TokenVerifyError::Invalid); + } + + let header_b64 = parts[0]; + let claims_b64 = parts[1]; + let signature_b64 = parts[2]; + + let Ok(header_bytes) = URL_SAFE_NO_PAD.decode(header_b64) else { + return Err(TokenVerifyError::Invalid); + }; + + let Ok(header) = serde_json::from_slice::
(&header_bytes) else { + return Err(TokenVerifyError::Invalid); + }; + + if let Some(expected) = expected_typ + && header.typ != expected + { + return Err(TokenVerifyError::Invalid); + } + + let Ok(signature_bytes) = URL_SAFE_NO_PAD.decode(signature_b64) else { + return Err(TokenVerifyError::Invalid); + }; + + let Ok(signature) = Signature::from_slice(&signature_bytes) else { + return Err(TokenVerifyError::Invalid); + }; + + let Ok(signing_key) = SigningKey::from_slice(key_bytes) else { + return Err(TokenVerifyError::Invalid); + }; + let verifying_key = VerifyingKey::from(&signing_key); + + let message = format!("{}.{}", header_b64, claims_b64); + if verifying_key.verify(message.as_bytes(), &signature).is_err() { + return Err(TokenVerifyError::Invalid); + } + + let Ok(claims_bytes) = URL_SAFE_NO_PAD.decode(claims_b64) else { + return Err(TokenVerifyError::Invalid); + }; + + let Ok(claims) = serde_json::from_slice::(&claims_bytes) else { + return Err(TokenVerifyError::Invalid); + }; + + let now = Utc::now().timestamp() as usize; + if claims.exp < now { + return Err(TokenVerifyError::Expired); + } + + if let Some(scopes) = allowed_scopes { + let token_scope = claims.scope.as_deref().unwrap_or(""); + if !scopes.contains(&token_scope) { + return Err(TokenVerifyError::Invalid); + } + } + + Ok(TokenData { claims }) +} + pub fn get_algorithm_from_token(token: &str) -> Result { let parts: Vec<&str> = token.split('.').collect(); if parts.len() != 3 {