JWT token refresh good error

This commit is contained in:
lewis
2025-12-25 18:23:19 +02:00
parent bbc9d14216
commit d1dcf02c00
4 changed files with 174 additions and 46 deletions

View File

@@ -192,6 +192,7 @@ impl From<crate::auth::TokenValidationError> for ApiError {
crate::auth::TokenValidationError::AccountTakedown => Self::AccountTakedown,
crate::auth::TokenValidationError::KeyDecryptionFailed => Self::InternalError,
crate::auth::TokenValidationError::AuthenticationFailed => Self::AuthenticationFailed,
crate::auth::TokenValidationError::TokenExpired => Self::ExpiredToken,
}
}
}

View File

@@ -19,6 +19,7 @@ pub enum AuthError {
MissingToken,
InvalidFormat,
AuthenticationFailed,
TokenExpired,
AccountDeactivated,
AccountTakedown,
AdminRequired,
@@ -39,8 +40,13 @@ impl IntoResponse for AuthError {
),
AuthError::AuthenticationFailed => (
StatusCode::UNAUTHORIZED,
"AuthenticationFailed",
"Invalid or expired token",
"InvalidToken",
"Token could not be verified",
),
AuthError::TokenExpired => (
StatusCode::UNAUTHORIZED,
"ExpiredToken",
"Token has expired",
),
AuthError::AccountDeactivated => (
StatusCode::UNAUTHORIZED,
@@ -174,6 +180,7 @@ impl FromRequestParts<AppState> for BearerAuth {
Ok(user) => Ok(BearerAuth(user)),
Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
Err(_) => Err(AuthError::AuthenticationFailed),
}
} else {
@@ -181,6 +188,7 @@ impl FromRequestParts<AppState> for BearerAuth {
Ok(user) => Ok(BearerAuth(user)),
Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
Err(_) => Err(AuthError::AuthenticationFailed),
}
}
@@ -224,6 +232,7 @@ impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
{
Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
Err(_) => Err(AuthError::AuthenticationFailed),
}
} else {
@@ -236,6 +245,7 @@ impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
{
Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
Err(_) => Err(AuthError::AuthenticationFailed),
}
}
@@ -284,6 +294,9 @@ impl FromRequestParts<AppState> for BearerAuthAdmin {
Err(TokenValidationError::AccountTakedown) => {
return Err(AuthError::AccountTakedown);
}
Err(TokenValidationError::TokenExpired) => {
return Err(AuthError::TokenExpired);
}
Err(_) => return Err(AuthError::AuthenticationFailed),
}
} else {
@@ -295,6 +308,9 @@ impl FromRequestParts<AppState> for BearerAuthAdmin {
Err(TokenValidationError::AccountTakedown) => {
return Err(AuthError::AccountTakedown);
}
Err(TokenValidationError::TokenExpired) => {
return Err(AuthError::TokenExpired);
}
Err(_) => return Err(AuthError::AuthenticationFailed),
}
};

View File

@@ -28,7 +28,8 @@ pub use token::{
create_service_token,
};
pub use verify::{
get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token,
TokenVerifyError, get_did_from_token, get_jti_from_token, verify_access_token,
verify_access_token_typed, verify_refresh_token, verify_token,
};
const KEY_CACHE_TTL_SECS: u64 = 300;
@@ -40,6 +41,7 @@ pub enum TokenValidationError {
AccountTakedown,
KeyDecryptionFailed,
AuthenticationFailed,
TokenExpired,
}
impl fmt::Display for TokenValidationError {
@@ -49,6 +51,7 @@ impl fmt::Display for TokenValidationError {
Self::AccountTakedown => write!(f, "AccountTakedown"),
Self::KeyDecryptionFailed => write!(f, "KeyDecryptionFailed"),
Self::AuthenticationFailed => write!(f, "AuthenticationFailed"),
Self::TokenExpired => write!(f, "ExpiredToken"),
}
}
}
@@ -193,53 +196,59 @@ async fn validate_bearer_token_with_options_internal(
return Err(TokenValidationError::AccountTakedown);
}
if let Ok(token_data) = verify_access_token(token, &decrypted_key) {
let jti = &token_data.claims.jti;
let session_cache_key = format!("auth:session:{}:{}", did, jti);
let mut session_valid = false;
match verify_access_token_typed(token, &decrypted_key) {
Ok(token_data) => {
let jti = &token_data.claims.jti;
let session_cache_key = format!("auth:session:{}:{}", did, jti);
let mut session_valid = false;
if let Some(c) = cache {
if let Some(cached_value) = c.get(&session_cache_key).await {
session_valid = cached_value == "1";
crate::metrics::record_auth_cache_hit("session");
} else {
crate::metrics::record_auth_cache_miss("session");
if let Some(c) = cache {
if let Some(cached_value) = c.get(&session_cache_key).await {
session_valid = cached_value == "1";
crate::metrics::record_auth_cache_hit("session");
} else {
crate::metrics::record_auth_cache_miss("session");
}
}
if !session_valid {
let session_exists = sqlx::query_scalar!(
"SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()",
did,
jti
)
.fetch_optional(db)
.await
.ok()
.flatten();
session_valid = session_exists.is_some();
if session_valid && let Some(c) = cache {
let _ = c
.set(
&session_cache_key,
"1",
Duration::from_secs(SESSION_CACHE_TTL_SECS),
)
.await;
}
}
if session_valid {
return Ok(AuthenticatedUser {
did: did.clone(),
key_bytes: Some(decrypted_key),
is_oauth: false,
is_admin,
scope: None,
});
}
}
if !session_valid {
let session_exists = sqlx::query_scalar!(
"SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()",
did,
jti
)
.fetch_optional(db)
.await
.ok()
.flatten();
session_valid = session_exists.is_some();
if session_valid && let Some(c) = cache {
let _ = c
.set(
&session_cache_key,
"1",
Duration::from_secs(SESSION_CACHE_TTL_SECS),
)
.await;
}
}
if session_valid {
return Ok(AuthenticatedUser {
did: did.clone(),
key_bytes: Some(decrypted_key),
is_oauth: false,
is_admin,
scope: None,
});
Err(verify::TokenVerifyError::Expired) => {
return Err(TokenValidationError::TokenExpired);
}
Err(verify::TokenVerifyError::Invalid) => {}
}
}
}
@@ -283,6 +292,8 @@ async fn validate_bearer_token_with_options_internal(
is_admin: oauth_token.is_admin,
scope: oauth_info.scope,
});
} else {
return Err(TokenValidationError::TokenExpired);
}
}

View File

@@ -10,10 +10,28 @@ use chrono::Utc;
use hmac::{Hmac, Mac};
use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier};
use sha2::Sha256;
use std::fmt;
use subtle::ConstantTimeEq;
type HmacSha256 = Hmac<Sha256>;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenVerifyError {
Expired,
Invalid,
}
impl fmt::Display for TokenVerifyError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Expired => write!(f, "Token expired"),
Self::Invalid => write!(f, "Token invalid"),
}
}
}
impl std::error::Error for TokenVerifyError {}
pub fn get_did_from_token(token: &str) -> Result<String, String> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
@@ -234,6 +252,88 @@ fn verify_token_hs256_internal(
Ok(TokenData { claims })
}
pub fn verify_access_token_typed(
token: &str,
key_bytes: &[u8],
) -> Result<TokenData<Claims>, TokenVerifyError> {
verify_token_typed_internal(
token,
key_bytes,
Some(TOKEN_TYPE_ACCESS),
Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]),
)
}
fn verify_token_typed_internal(
token: &str,
key_bytes: &[u8],
expected_typ: Option<&str>,
allowed_scopes: Option<&[&str]>,
) -> Result<TokenData<Claims>, TokenVerifyError> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err(TokenVerifyError::Invalid);
}
let header_b64 = parts[0];
let claims_b64 = parts[1];
let signature_b64 = parts[2];
let Ok(header_bytes) = URL_SAFE_NO_PAD.decode(header_b64) else {
return Err(TokenVerifyError::Invalid);
};
let Ok(header) = serde_json::from_slice::<Header>(&header_bytes) else {
return Err(TokenVerifyError::Invalid);
};
if let Some(expected) = expected_typ
&& header.typ != expected
{
return Err(TokenVerifyError::Invalid);
}
let Ok(signature_bytes) = URL_SAFE_NO_PAD.decode(signature_b64) else {
return Err(TokenVerifyError::Invalid);
};
let Ok(signature) = Signature::from_slice(&signature_bytes) else {
return Err(TokenVerifyError::Invalid);
};
let Ok(signing_key) = SigningKey::from_slice(key_bytes) else {
return Err(TokenVerifyError::Invalid);
};
let verifying_key = VerifyingKey::from(&signing_key);
let message = format!("{}.{}", header_b64, claims_b64);
if verifying_key.verify(message.as_bytes(), &signature).is_err() {
return Err(TokenVerifyError::Invalid);
}
let Ok(claims_bytes) = URL_SAFE_NO_PAD.decode(claims_b64) else {
return Err(TokenVerifyError::Invalid);
};
let Ok(claims) = serde_json::from_slice::<Claims>(&claims_bytes) else {
return Err(TokenVerifyError::Invalid);
};
let now = Utc::now().timestamp() as usize;
if claims.exp < now {
return Err(TokenVerifyError::Expired);
}
if let Some(scopes) = allowed_scopes {
let token_scope = claims.scope.as_deref().unwrap_or("");
if !scopes.contains(&token_scope) {
return Err(TokenVerifyError::Invalid);
}
}
Ok(TokenData { claims })
}
pub fn get_algorithm_from_token(token: &str) -> Result<String, String> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {