oauth error msg improvement, general code quality

This commit is contained in:
lewis
2026-01-11 22:33:41 +02:00
parent d1902506a5
commit 16fb4dbd03
27 changed files with 676 additions and 469 deletions

View File

@@ -0,0 +1,52 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT ic.code, ic.available_uses, ic.disabled, ic.for_account, ic.created_at, u.did as created_by\n FROM invite_codes ic\n JOIN users u ON ic.created_by_user = u.id\n WHERE ic.created_by_user = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "code",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "available_uses",
"type_info": "Int4"
},
{
"ordinal": 2,
"name": "disabled",
"type_info": "Bool"
},
{
"ordinal": 3,
"name": "for_account",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 5,
"name": "created_by",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
false,
false,
true,
false,
false,
false
]
},
"hash": "2c8868a59ae63dc65d996cf21fc1bec0c2c86d5d5f17d1454440c6fcd8d4d27a"
}

View File

@@ -0,0 +1,28 @@
{
"db_name": "PostgreSQL",
"query": "SELECT id, did FROM users WHERE id = ANY($1)",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "did",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"UuidArray"
]
},
"nullable": [
false,
false
]
},
"hash": "46ea5ceff2a8f3f2b72c9c6a1bb69ce28efe8594fda026b6f9b298ef0597b40e"
}

View File

@@ -1,22 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT code FROM invite_codes WHERE created_by_user = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "code",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
false
]
},
"hash": "5a98e015997942835800fcd326e69b4f54b9830d0490c4f8841f8435478c57d3"
}

View File

@@ -1,28 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT u.did, icu.used_at\n FROM invite_code_uses icu\n JOIN users u ON icu.used_by_user = u.id\n WHERE icu.code = $1\n ORDER BY icu.used_at DESC\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "did",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "used_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false
]
},
"hash": "5d5442136932d4088873a935c41cb3a683c4771e4fb8c151b3fd5119fb6c1068"
}

View File

@@ -1,14 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "UPDATE invite_codes SET disabled = TRUE WHERE code = $1",
"query": "UPDATE invite_codes SET disabled = TRUE WHERE code = ANY($1)",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Text"
"TextArray"
]
},
"nullable": []
},
"hash": "7b2d1d4ac06063e07a7c7a7d0fb434db08ce312eb2864405d7f96f4e985ed036"
"hash": "888f8724cfc2ed27391b661a5cfe423d28c77e1a368df7edc81708eb3038f600"
}

View File

@@ -0,0 +1,34 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT icu.code, u.did as used_by, icu.used_at\n FROM invite_code_uses icu\n JOIN users u ON icu.used_by_user = u.id\n WHERE icu.code = ANY($1)\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "code",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "used_by",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "used_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"TextArray"
]
},
"nullable": [
false,
false,
false
]
},
"hash": "ae6695ae53fc5e5293f17ddf8cc4532d549d1ad8a9835da4a5c001eee89db076"
}

View File

@@ -0,0 +1,34 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT icu.code, u.did, icu.used_at\n FROM invite_code_uses icu\n JOIN users u ON icu.used_by_user = u.id\n WHERE icu.code = ANY($1)\n ORDER BY icu.used_at DESC\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "code",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "did",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "used_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"TextArray"
]
},
"nullable": [
false,
false,
false
]
},
"hash": "ed1ccbaaed6e3f34982dc118ddd9bde7269879c0547ad43f30b78bfeeef5a920"
}

View File

@@ -1,14 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "UPDATE invite_codes SET disabled = TRUE WHERE created_by_user = $1",
"query": "UPDATE invite_codes SET disabled = TRUE WHERE created_by_user IN (SELECT id FROM users WHERE did = ANY($1))",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid"
"TextArray"
]
},
"nullable": []
},
"hash": "413c5b03501a399dca13f345fcae05770517091d73db93966853e944c68ee237"
"hash": "eec42a3a4b1440aa8dd580b5d0bbd1184b909f781d131aa2c69368ed021e87e4"
}

View File

@@ -92,7 +92,7 @@ tower-layer = "0.3"
tracing = "0.1"
tracing-subscriber = "0.3"
urlencoding = "2.1"
uuid = { version = "1.19", features = ["v4", "v5", "fast-rng"] }
uuid = { version = "1.19", features = ["v4", "v5", "v7", "fast-rng"] }
webauthn-rs = { version = "0.5", features = ["danger-allow-state-serialisation", "danger-user-presence-only-security-keys"] }
webauthn-rs-proto = "0.5"
zip = { version = "7.0", default-features = false, features = ["deflate"] }

View File

@@ -182,11 +182,10 @@ static STRINGS_FI: NotificationStrings = NotificationStrings {
};
pub fn format_message(template: &str, vars: &[(&str, &str)]) -> String {
let mut result = template.to_string();
for (key, value) in vars {
result = result.replace(&format!("{{{}}}", key), value);
}
result
vars.iter()
.fold(template.to_string(), |result, (key, value)| {
result.replace(&format!("{{{}}}", key), value)
})
}
#[cfg(test)]

View File

@@ -568,12 +568,21 @@ fn verify_es256(
.get("y")
.and_then(|v| v.as_str())
.ok_or_else(|| OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()))?;
let x_bytes = URL_SAFE_NO_PAD
let x_decoded = URL_SAFE_NO_PAD
.decode(x)
.map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
let y_bytes = URL_SAFE_NO_PAD
let y_decoded = URL_SAFE_NO_PAD
.decode(y)
.map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
if x_decoded.len() > 32 || y_decoded.len() > 32 {
return Err(OAuthError::InvalidClient(
"EC coordinate too long".to_string(),
));
}
let mut x_bytes = [0u8; 32];
let mut y_bytes = [0u8; 32];
x_bytes[32 - x_decoded.len()..].copy_from_slice(&x_decoded);
y_bytes[32 - y_decoded.len()..].copy_from_slice(&y_decoded);
let mut point_bytes = vec![0x04];
point_bytes.extend_from_slice(&x_bytes);
point_bytes.extend_from_slice(&y_bytes);
@@ -604,12 +613,21 @@ fn verify_es384(
.get("y")
.and_then(|v| v.as_str())
.ok_or_else(|| OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()))?;
let x_bytes = URL_SAFE_NO_PAD
let x_decoded = URL_SAFE_NO_PAD
.decode(x)
.map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
let y_bytes = URL_SAFE_NO_PAD
let y_decoded = URL_SAFE_NO_PAD
.decode(y)
.map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
if x_decoded.len() > 48 || y_decoded.len() > 48 {
return Err(OAuthError::InvalidClient(
"EC coordinate too long".to_string(),
));
}
let mut x_bytes = [0u8; 48];
let mut y_bytes = [0u8; 48];
x_bytes[48 - x_decoded.len()..].copy_from_slice(&x_decoded);
y_bytes[48 - y_decoded.len()..].copy_from_slice(&y_decoded);
let mut point_bytes = vec![0x04];
point_bytes.extend_from_slice(&x_bytes);
point_bytes.extend_from_slice(&y_bytes);

View File

@@ -218,25 +218,30 @@ fn verify_es256(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), O
crv
)));
}
let x_bytes = URL_SAFE_NO_PAD
let x_decoded = URL_SAFE_NO_PAD
.decode(
jwk.x
.as_ref()
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing x coordinate".to_string()))?,
)
.map_err(|_| OAuthError::InvalidDpopProof("Invalid x encoding".to_string()))?;
let y_bytes = URL_SAFE_NO_PAD
let y_decoded = URL_SAFE_NO_PAD
.decode(
jwk.y
.as_ref()
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing y coordinate".to_string()))?,
)
.map_err(|_| OAuthError::InvalidDpopProof("Invalid y encoding".to_string()))?;
let point = EncodedPoint::from_affine_coordinates(
x_bytes.as_slice().into(),
y_bytes.as_slice().into(),
false,
);
let mut x_bytes = [0u8; 32];
let mut y_bytes = [0u8; 32];
if x_decoded.len() > 32 || y_decoded.len() > 32 {
return Err(OAuthError::InvalidDpopProof(
"EC coordinate too long".to_string(),
));
}
x_bytes[32 - x_decoded.len()..].copy_from_slice(&x_decoded);
y_bytes[32 - y_decoded.len()..].copy_from_slice(&y_decoded);
let point = EncodedPoint::from_affine_coordinates((&x_bytes).into(), (&y_bytes).into(), false);
let affine_opt: Option<AffinePoint> = AffinePoint::from_encoded_point(&point).into();
let affine =
affine_opt.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?;
@@ -264,25 +269,30 @@ fn verify_es384(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), O
crv
)));
}
let x_bytes = URL_SAFE_NO_PAD
let x_decoded = URL_SAFE_NO_PAD
.decode(
jwk.x
.as_ref()
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing x coordinate".to_string()))?,
)
.map_err(|_| OAuthError::InvalidDpopProof("Invalid x encoding".to_string()))?;
let y_bytes = URL_SAFE_NO_PAD
let y_decoded = URL_SAFE_NO_PAD
.decode(
jwk.y
.as_ref()
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing y coordinate".to_string()))?,
)
.map_err(|_| OAuthError::InvalidDpopProof("Invalid y encoding".to_string()))?;
let point = EncodedPoint::from_affine_coordinates(
x_bytes.as_slice().into(),
y_bytes.as_slice().into(),
false,
);
let mut x_bytes = [0u8; 48];
let mut y_bytes = [0u8; 48];
if x_decoded.len() > 48 || y_decoded.len() > 48 {
return Err(OAuthError::InvalidDpopProof(
"EC coordinate too long".to_string(),
));
}
x_bytes[48 - x_decoded.len()..].copy_from_slice(&x_decoded);
y_bytes[48 - y_decoded.len()..].copy_from_slice(&y_decoded);
let point = EncodedPoint::from_affine_coordinates((&x_bytes).into(), (&y_bytes).into(), false);
let affine_opt: Option<AffinePoint> = AffinePoint::from_encoded_point(&point).into();
let affine =
affine_opt.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?;
@@ -398,4 +408,52 @@ mod tests {
let thumbprint = compute_jwk_thumbprint(&jwk).unwrap();
assert!(!thumbprint.is_empty());
}
#[test]
fn test_es256_short_coordinate_no_panic() {
let short_31_bytes = vec![0x42u8; 31];
let short_30_bytes = vec![0x42u8; 30];
let x_b64 = URL_SAFE_NO_PAD.encode(&short_31_bytes);
let y_b64 = URL_SAFE_NO_PAD.encode(&short_30_bytes);
let jwk = DPoPJwk {
kty: "EC".to_string(),
crv: Some("P-256".to_string()),
x: Some(x_b64),
y: Some(y_b64),
};
let result = verify_es256(&jwk, b"test", &[0u8; 64]);
assert!(result.is_err(), "Invalid coordinates should return error, not panic");
}
#[test]
fn test_es256_valid_key_with_trimmed_coordinates() {
use p256::ecdsa::{SigningKey, signature::Signer};
use p256::elliptic_curve::rand_core::OsRng;
let signing_key = SigningKey::random(&mut OsRng);
let verifying_key = signing_key.verifying_key();
let point = verifying_key.to_encoded_point(false);
let x_bytes = point.x().unwrap();
let y_bytes = point.y().unwrap();
let x_trimmed: Vec<u8> = x_bytes.iter().copied().skip_while(|&b| b == 0).collect();
let y_trimmed: Vec<u8> = y_bytes.iter().copied().skip_while(|&b| b == 0).collect();
let x_b64 = URL_SAFE_NO_PAD.encode(&x_trimmed);
let y_b64 = URL_SAFE_NO_PAD.encode(&y_trimmed);
let jwk = DPoPJwk {
kty: "EC".to_string(),
crv: Some("P-256".to_string()),
x: Some(x_b64),
y: Some(y_b64),
};
let message = b"test message for signature verification";
let signature: p256::ecdsa::Signature = signing_key.sign(message);
let result = verify_es256(&jwk, message, signature.to_bytes().as_slice());
assert!(
result.is_ok(),
"Should verify signature with trimmed coordinates (x={}, y={}): {:?}",
x_trimmed.len(),
y_trimmed.len(),
result
);
}
}

View File

@@ -130,24 +130,62 @@ async fn get_invites_for_user(
db: &sqlx::PgPool,
user_id: uuid::Uuid,
) -> Option<Vec<InviteCodeInfo>> {
let codes = sqlx::query_scalar!(
let invite_codes = sqlx::query!(
r#"
SELECT code FROM invite_codes WHERE created_by_user = $1
SELECT ic.code, ic.available_uses, ic.disabled, ic.for_account, ic.created_at, u.did as created_by
FROM invite_codes ic
JOIN users u ON ic.created_by_user = u.id
WHERE ic.created_by_user = $1
"#,
user_id
)
.fetch_all(db)
.await
.ok()?;
if codes.is_empty() {
if invite_codes.is_empty() {
return None;
}
let mut invites = Vec::new();
for code in codes {
if let Some(info) = get_invite_code_info(db, &code).await {
invites.push(info);
}
}
let code_strings: Vec<String> = invite_codes.iter().map(|ic| ic.code.clone()).collect();
let mut uses_by_code: std::collections::HashMap<String, Vec<InviteCodeUseInfo>> =
std::collections::HashMap::new();
sqlx::query!(
r#"
SELECT icu.code, u.did as used_by, icu.used_at
FROM invite_code_uses icu
JOIN users u ON icu.used_by_user = u.id
WHERE icu.code = ANY($1)
"#,
&code_strings
)
.fetch_all(db)
.await
.ok()?
.into_iter()
.for_each(|r| {
uses_by_code
.entry(r.code)
.or_default()
.push(InviteCodeUseInfo {
used_by: r.used_by.into(),
used_at: r.used_at.to_rfc3339(),
});
});
let invites: Vec<InviteCodeInfo> = invite_codes
.into_iter()
.map(|ic| InviteCodeInfo {
code: ic.code.clone(),
available: ic.available_uses,
disabled: ic.disabled.unwrap_or(false),
for_account: ic.for_account.into(),
created_by: ic.created_by.into(),
created_at: ic.created_at.to_rfc3339(),
uses: uses_by_code.get(&ic.code).cloned().unwrap_or_default(),
})
.collect();
if invites.is_empty() {
None
} else {
@@ -276,61 +314,62 @@ pub async fn get_account_infos(
.map(|r| (r.used_by_user, r.code))
.collect();
let mut uses_by_code: std::collections::HashMap<String, Vec<InviteCodeUseInfo>> =
std::collections::HashMap::new();
for u in all_invite_uses {
uses_by_code
.entry(u.code.clone())
.or_default()
.push(InviteCodeUseInfo {
used_by: u.used_by.into(),
used_at: u.used_at.to_rfc3339(),
let uses_by_code: std::collections::HashMap<String, Vec<InviteCodeUseInfo>> =
all_invite_uses
.into_iter()
.fold(std::collections::HashMap::new(), |mut acc, u| {
acc.entry(u.code.clone()).or_default().push(InviteCodeUseInfo {
used_by: u.used_by.into(),
used_at: u.used_at.to_rfc3339(),
});
acc
});
}
let mut codes_by_user: std::collections::HashMap<uuid::Uuid, Vec<InviteCodeInfo>> =
std::collections::HashMap::new();
let mut code_info_map: std::collections::HashMap<String, InviteCodeInfo> =
std::collections::HashMap::new();
for ic in all_invite_codes {
let info = InviteCodeInfo {
code: ic.code.clone(),
available: ic.available_uses,
disabled: ic.disabled.unwrap_or(false),
for_account: ic.for_account.into(),
created_by: ic.created_by.into(),
created_at: ic.created_at.to_rfc3339(),
uses: uses_by_code.get(&ic.code).cloned().unwrap_or_default(),
};
code_info_map.insert(ic.code.clone(), info.clone());
codes_by_user
.entry(ic.created_by_user)
.or_default()
.push(info);
}
let (codes_by_user, code_info_map): (
std::collections::HashMap<uuid::Uuid, Vec<InviteCodeInfo>>,
std::collections::HashMap<String, InviteCodeInfo>,
) = all_invite_codes.into_iter().fold(
(std::collections::HashMap::new(), std::collections::HashMap::new()),
|(mut by_user, mut by_code), ic| {
let info = InviteCodeInfo {
code: ic.code.clone(),
available: ic.available_uses,
disabled: ic.disabled.unwrap_or(false),
for_account: ic.for_account.into(),
created_by: ic.created_by.into(),
created_at: ic.created_at.to_rfc3339(),
uses: uses_by_code.get(&ic.code).cloned().unwrap_or_default(),
};
by_code.insert(ic.code.clone(), info.clone());
by_user.entry(ic.created_by_user).or_default().push(info);
(by_user, by_code)
},
);
let mut infos = Vec::with_capacity(users.len());
for row in users {
let invited_by = invited_by_map
.get(&row.id)
.and_then(|code| code_info_map.get(code).cloned());
let invites = codes_by_user.get(&row.id).cloned();
infos.push(AccountInfo {
did: row.did.into(),
handle: row.handle.into(),
email: row.email,
indexed_at: row.created_at.to_rfc3339(),
invite_note: None,
invites_disabled: row.invites_disabled.unwrap_or(false),
email_confirmed_at: if row.email_verified {
Some(row.created_at.to_rfc3339())
} else {
None
},
deactivated_at: row.deactivated_at.map(|dt| dt.to_rfc3339()),
invited_by,
invites,
});
}
let infos: Vec<AccountInfo> = users
.into_iter()
.map(|row| {
let invited_by = invited_by_map
.get(&row.id)
.and_then(|code| code_info_map.get(code).cloned());
let invites = codes_by_user.get(&row.id).cloned();
AccountInfo {
did: row.did.into(),
handle: row.handle.into(),
email: row.email,
indexed_at: row.created_at.to_rfc3339(),
invite_note: None,
invites_disabled: row.invites_disabled.unwrap_or(false),
email_confirmed_at: if row.email_verified {
Some(row.created_at.to_rfc3339())
} else {
None
},
deactivated_at: row.deactivated_at.map(|dt| dt.to_rfc3339()),
invited_by,
invites,
}
})
.collect();
(StatusCode::OK, Json(GetAccountInfosOutput { infos })).into_response()
}

View File

@@ -48,32 +48,19 @@ pub async fn get_server_config(
.fetch_all(&state.db)
.await?;
let mut server_name = "Tranquil PDS".to_string();
let mut primary_color = None;
let mut primary_color_dark = None;
let mut secondary_color = None;
let mut secondary_color_dark = None;
let mut logo_cid = None;
for (key, value) in rows {
match key.as_str() {
"server_name" => server_name = value,
"primary_color" => primary_color = Some(value),
"primary_color_dark" => primary_color_dark = Some(value),
"secondary_color" => secondary_color = Some(value),
"secondary_color_dark" => secondary_color_dark = Some(value),
"logo_cid" => logo_cid = Some(value),
_ => {}
}
}
let config_map: std::collections::HashMap<String, String> =
rows.into_iter().collect();
Ok(Json(ServerConfigResponse {
server_name,
primary_color,
primary_color_dark,
secondary_color,
secondary_color_dark,
logo_cid,
server_name: config_map
.get("server_name")
.cloned()
.unwrap_or_else(|| "Tranquil PDS".to_string()),
primary_color: config_map.get("primary_color").cloned(),
primary_color_dark: config_map.get("primary_color_dark").cloned(),
secondary_color: config_map.get("secondary_color").cloned(),
secondary_color_dark: config_map.get("secondary_color_dark").cloned(),
logo_cid: config_map.get("logo_cid").cloned(),
}))
}

View File

@@ -24,29 +24,20 @@ pub async fn disable_invite_codes(
Json(input): Json<DisableInviteCodesInput>,
) -> Response {
if let Some(codes) = &input.codes {
for code in codes {
let _ = sqlx::query!(
"UPDATE invite_codes SET disabled = TRUE WHERE code = $1",
code
)
.execute(&state.db)
.await;
}
let _ = sqlx::query!(
"UPDATE invite_codes SET disabled = TRUE WHERE code = ANY($1)",
codes as &[String]
)
.execute(&state.db)
.await;
}
if let Some(accounts) = &input.accounts {
for account in accounts {
let user = sqlx::query!("SELECT id FROM users WHERE did = $1", account)
.fetch_optional(&state.db)
.await;
if let Ok(Some(user_row)) = user {
let _ = sqlx::query!(
"UPDATE invite_codes SET disabled = TRUE WHERE created_by_user = $1",
user_row.id
)
.execute(&state.db)
.await;
}
}
let _ = sqlx::query!(
"UPDATE invite_codes SET disabled = TRUE WHERE created_by_user IN (SELECT id FROM users WHERE did = ANY($1))",
accounts as &[String]
)
.execute(&state.db)
.await;
}
EmptyResponse::ok().into_response()
}
@@ -70,7 +61,7 @@ pub struct InviteCodeInfo {
pub uses: Vec<InviteCodeUseInfo>,
}
#[derive(Serialize)]
#[derive(Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct InviteCodeUseInfo {
pub used_by: String,
@@ -149,47 +140,71 @@ pub async fn get_invite_codes(
return ApiError::InternalError(None).into_response();
}
};
let mut codes = Vec::new();
for (code, available_uses, disabled, created_by_user, created_at) in &codes_rows {
let creator_did =
sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", created_by_user)
.fetch_optional(&state.db)
.await
.ok()
.flatten()
.unwrap_or_else(|| "unknown".to_string());
let uses_result = sqlx::query!(
let user_ids: Vec<uuid::Uuid> = codes_rows.iter().map(|(_, _, _, uid, _)| *uid).collect();
let code_strings: Vec<String> = codes_rows.iter().map(|(c, _, _, _, _)| c.clone()).collect();
let mut creator_dids: std::collections::HashMap<uuid::Uuid, String> =
std::collections::HashMap::new();
sqlx::query!(
"SELECT id, did FROM users WHERE id = ANY($1)",
&user_ids
)
.fetch_all(&state.db)
.await
.unwrap_or_default()
.into_iter()
.for_each(|r| {
creator_dids.insert(r.id, r.did);
});
let mut uses_by_code: std::collections::HashMap<String, Vec<InviteCodeUseInfo>> =
std::collections::HashMap::new();
if !code_strings.is_empty() {
sqlx::query!(
r#"
SELECT u.did, icu.used_at
SELECT icu.code, u.did, icu.used_at
FROM invite_code_uses icu
JOIN users u ON icu.used_by_user = u.id
WHERE icu.code = $1
WHERE icu.code = ANY($1)
ORDER BY icu.used_at DESC
"#,
code
&code_strings
)
.fetch_all(&state.db)
.await;
let uses = match uses_result {
Ok(use_rows) => use_rows
.iter()
.map(|u| InviteCodeUseInfo {
used_by: u.did.clone(),
used_at: u.used_at.to_rfc3339(),
})
.collect(),
Err(_) => Vec::new(),
};
codes.push(InviteCodeInfo {
code: code.clone(),
available: *available_uses,
disabled: disabled.unwrap_or(false),
for_account: creator_did.clone(),
created_by: creator_did,
created_at: created_at.to_rfc3339(),
uses,
.await
.unwrap_or_default()
.into_iter()
.for_each(|r| {
uses_by_code
.entry(r.code)
.or_default()
.push(InviteCodeUseInfo {
used_by: r.did,
used_at: r.used_at.to_rfc3339(),
});
});
}
let codes: Vec<InviteCodeInfo> = codes_rows
.iter()
.map(|(code, available_uses, disabled, created_by_user, created_at)| {
let creator_did = creator_dids
.get(created_by_user)
.cloned()
.unwrap_or_else(|| "unknown".to_string());
InviteCodeInfo {
code: code.clone(),
available: *available_uses,
disabled: disabled.unwrap_or(false),
for_account: creator_did.clone(),
created_by: creator_did,
created_at: created_at.to_rfc3339(),
uses: uses_by_code.get(code).cloned().unwrap_or_default(),
}
})
.collect();
let next_cursor = if codes_rows.len() == limit as usize {
codes_rows.last().map(|(code, _, _, _, _)| code.clone())
} else {

View File

@@ -22,6 +22,7 @@ pub enum ApiError {
InvalidRequest(String),
InvalidToken(Option<String>),
ExpiredToken(Option<String>),
OAuthExpiredToken(Option<String>),
TokenRequired,
AccountDeactivated,
AccountTakedown,
@@ -127,7 +128,8 @@ impl ApiError {
| Self::InvalidCode(_)
| Self::InvalidPassword(_)
| Self::InvalidToken(_)
| Self::PasskeyCounterAnomaly => StatusCode::UNAUTHORIZED,
| Self::PasskeyCounterAnomaly
| Self::OAuthExpiredToken(_) => StatusCode::UNAUTHORIZED,
Self::ExpiredToken(_) => StatusCode::BAD_REQUEST,
Self::Forbidden
| Self::AdminRequired
@@ -216,7 +218,7 @@ impl ApiError {
Self::AuthenticationRequired => Cow::Borrowed("AuthenticationRequired"),
Self::AuthenticationFailed(_) => Cow::Borrowed("AuthenticationFailed"),
Self::InvalidToken(_) => Cow::Borrowed("InvalidToken"),
Self::ExpiredToken(_) => Cow::Borrowed("ExpiredToken"),
Self::ExpiredToken(_) | Self::OAuthExpiredToken(_) => Cow::Borrowed("ExpiredToken"),
Self::TokenRequired => Cow::Borrowed("TokenRequired"),
Self::AccountDeactivated => Cow::Borrowed("AccountDeactivated"),
Self::AccountTakedown => Cow::Borrowed("AccountTakedown"),
@@ -298,6 +300,7 @@ impl ApiError {
| Self::AuthenticationFailed(msg)
| Self::InvalidToken(msg)
| Self::ExpiredToken(msg)
| Self::OAuthExpiredToken(msg)
| Self::RepoNotFound(msg)
| Self::BlobNotFound(msg)
| Self::InvalidHandle(msg)
@@ -428,13 +431,24 @@ impl IntoResponse for ApiError {
message: self.message(),
};
let mut response = (self.status_code(), Json(body)).into_response();
if matches!(self, Self::ExpiredToken(_)) {
response.headers_mut().insert(
"WWW-Authenticate",
"Bearer error=\"invalid_token\", error_description=\"Token has expired\""
.parse()
.unwrap(),
);
match &self {
Self::ExpiredToken(_) => {
response.headers_mut().insert(
"WWW-Authenticate",
"Bearer error=\"invalid_token\", error_description=\"Token has expired\""
.parse()
.unwrap(),
);
}
Self::OAuthExpiredToken(_) => {
response.headers_mut().insert(
"WWW-Authenticate",
"DPoP error=\"invalid_token\", error_description=\"Token has expired\""
.parse()
.unwrap(),
);
}
_ => {}
}
response
}
@@ -457,6 +471,9 @@ impl From<crate::auth::TokenValidationError> for ApiError {
Self::AuthenticationFailed(None)
}
crate::auth::TokenValidationError::TokenExpired => Self::ExpiredToken(None),
crate::auth::TokenValidationError::OAuthTokenExpired => {
Self::OAuthExpiredToken(Some("Token has expired".to_string()))
}
}
}
}

View File

@@ -211,7 +211,7 @@ async fn create_report_locally(
}
let created_at = chrono::Utc::now();
let report_id = created_at.timestamp_millis();
let report_id = (uuid::Uuid::now_v7().as_u128() & 0x7FFF_FFFF_FFFF_FFFF) as i64;
let subject_json = json!(input.subject);
let insert = sqlx::query!(

View File

@@ -268,21 +268,8 @@ async fn proxy_handler(
}
Err(e) => {
warn!("Token validation failed: {:?}", e);
if matches!(e, crate::auth::TokenValidationError::TokenExpired) && extracted.is_dpop
{
let www_auth =
"DPoP error=\"invalid_token\", error_description=\"Token has expired\"";
let mut response =
ApiError::ExpiredToken(Some("Token has expired".into())).into_response();
*response.status_mut() = axum::http::StatusCode::UNAUTHORIZED;
response
.headers_mut()
.insert("WWW-Authenticate", www_auth.parse().unwrap());
let nonce = crate::oauth::verify::generate_dpop_nonce();
response
.headers_mut()
.insert("DPoP-Nonce", nonce.parse().unwrap());
return response;
if matches!(e, crate::auth::TokenValidationError::OAuthTokenExpired) {
return ApiError::from(e).into_response();
}
}
}
@@ -291,11 +278,12 @@ async fn proxy_handler(
if let Some(val) = auth_header_val {
request_builder = request_builder.header("Authorization", val);
}
for header_name in crate::api::proxy_client::HEADERS_TO_FORWARD {
if let Some(val) = headers.get(*header_name) {
request_builder = request_builder.header(*header_name, val);
}
}
request_builder = crate::api::proxy_client::HEADERS_TO_FORWARD
.iter()
.filter_map(|name| headers.get(*name).map(|val| (*name, val)))
.fold(request_builder, |builder, (name, val)| {
builder.header(name, val)
});
if !body.is_empty() {
request_builder = request_builder.body(body);
}
@@ -313,11 +301,12 @@ async fn proxy_handler(
}
};
let mut response_builder = Response::builder().status(status);
for header_name in crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD {
if let Some(val) = headers.get(*header_name) {
response_builder = response_builder.header(*header_name, val);
}
}
response_builder = crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD
.iter()
.filter_map(|name| headers.get(*name).map(|val| (*name, val)))
.fold(response_builder, |builder, (name, val)| {
builder.header(name, val)
});
match response_builder.body(axum::body::Body::from(body)) {
Ok(r) => r,
Err(e) => {

View File

@@ -88,15 +88,13 @@ pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> {
Ok(addrs) => addrs.collect(),
Err(_) => return Err(SsrfError::DnsResolutionFailed(host.to_string())),
};
for addr in &socket_addrs {
if !is_unicast_ip(&addr.ip()) {
warn!(
"DNS resolution for {} returned non-unicast IP: {}",
host,
addr.ip()
);
return Err(SsrfError::NonUnicastIp(addr.ip().to_string()));
}
if let Some(addr) = socket_addrs.iter().find(|addr| !is_unicast_ip(&addr.ip())) {
warn!(
"DNS resolution for {} returned non-unicast IP: {}",
host,
addr.ip()
);
return Err(SsrfError::NonUnicastIp(addr.ip().to_string()));
}
Ok(())
}

View File

@@ -82,19 +82,7 @@ pub async fn prepare_repo_write(
.await
.map_err(|e| {
tracing::warn!(error = ?e, is_dpop = extracted.is_dpop, "Token validation failed in prepare_repo_write");
let mut response = ApiError::from(e).into_response();
if matches!(e, crate::auth::TokenValidationError::TokenExpired) && extracted.is_dpop {
*response.status_mut() = axum::http::StatusCode::UNAUTHORIZED;
let www_auth =
"DPoP error=\"invalid_token\", error_description=\"Token has expired\"";
response.headers_mut().insert(
"WWW-Authenticate",
www_auth.parse().unwrap(),
);
let nonce = crate::oauth::verify::generate_dpop_nonce();
response.headers_mut().insert("DPoP-Nonce", nonce.parse().unwrap());
}
response
ApiError::from(e).into_response()
})?;
if repo.as_str() != auth_user.did.as_str() {
return Err(

View File

@@ -181,10 +181,11 @@ fn validate_email_detailed(email: &str) -> Result<(), EmailValidationError> {
if local.starts_with('.') || local.ends_with('.') || local.contains("..") {
return Err(EmailValidationError::InvalidLocalPart);
}
for c in local.chars() {
if !c.is_ascii_alphanumeric() && !EMAIL_LOCAL_SPECIAL_CHARS.contains(c) {
return Err(EmailValidationError::InvalidLocalPart);
}
if !local
.chars()
.all(|c| c.is_ascii_alphanumeric() || EMAIL_LOCAL_SPECIAL_CHARS.contains(c))
{
return Err(EmailValidationError::InvalidLocalPart);
}
if domain.is_empty() {
return Err(EmailValidationError::EmptyDomain);
@@ -195,18 +196,14 @@ fn validate_email_detailed(email: &str) -> Result<(), EmailValidationError> {
if !domain.contains('.') {
return Err(EmailValidationError::MissingDomainDot);
}
for label in domain.split('.') {
if label.is_empty() || label.len() > MAX_DOMAIN_LABEL_LENGTH {
return Err(EmailValidationError::InvalidDomainLabel);
}
if label.starts_with('-') || label.ends_with('-') {
return Err(EmailValidationError::InvalidDomainLabel);
}
for c in label.chars() {
if !c.is_ascii_alphanumeric() && c != '-' {
return Err(EmailValidationError::InvalidDomainLabel);
}
}
if !domain.split('.').all(|label| {
!label.is_empty()
&& label.len() <= MAX_DOMAIN_LABEL_LENGTH
&& !label.starts_with('-')
&& !label.ends_with('-')
&& label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-')
}) {
return Err(EmailValidationError::InvalidDomainLabel);
}
Ok(())
}
@@ -293,10 +290,11 @@ pub fn validate_service_handle(
return Err(HandleValidationError::EndsWithInvalidChar);
}
for c in handle.chars() {
if !c.is_ascii_alphanumeric() && c != '-' {
return Err(HandleValidationError::InvalidCharacters);
}
if !handle
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '-')
{
return Err(HandleValidationError::InvalidCharacters);
}
if crate::moderation::has_explicit_slur(handle) {
@@ -330,10 +328,11 @@ pub fn is_valid_email(email: &str) -> bool {
if local.contains("..") {
return false;
}
for c in local.chars() {
if !c.is_ascii_alphanumeric() && !EMAIL_LOCAL_SPECIAL_CHARS.contains(c) {
return false;
}
if !local
.chars()
.all(|c| c.is_ascii_alphanumeric() || EMAIL_LOCAL_SPECIAL_CHARS.contains(c))
{
return false;
}
if domain.is_empty() || domain.len() > MAX_DOMAIN_LENGTH {
return false;
@@ -341,20 +340,13 @@ pub fn is_valid_email(email: &str) -> bool {
if !domain.contains('.') {
return false;
}
for label in domain.split('.') {
if label.is_empty() || label.len() > MAX_DOMAIN_LABEL_LENGTH {
return false;
}
if label.starts_with('-') || label.ends_with('-') {
return false;
}
for c in label.chars() {
if !c.is_ascii_alphanumeric() && c != '-' {
return false;
}
}
}
true
domain.split('.').all(|label| {
!label.is_empty()
&& label.len() <= MAX_DOMAIN_LABEL_LENGTH
&& !label.starts_with('-')
&& !label.ends_with('-')
&& label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-')
})
}
#[cfg(test)]

View File

@@ -61,6 +61,7 @@ pub enum TokenValidationError {
KeyDecryptionFailed,
AuthenticationFailed,
TokenExpired,
OAuthTokenExpired,
}
impl fmt::Display for TokenValidationError {
@@ -70,7 +71,7 @@ impl fmt::Display for TokenValidationError {
Self::AccountTakedown => write!(f, "AccountTakedown"),
Self::KeyDecryptionFailed => write!(f, "KeyDecryptionFailed"),
Self::AuthenticationFailed => write!(f, "AuthenticationFailed"),
Self::TokenExpired => write!(f, "ExpiredToken"),
Self::TokenExpired | Self::OAuthTokenExpired => write!(f, "ExpiredToken"),
}
}
}
@@ -497,7 +498,9 @@ pub async fn validate_token_with_dpop(
controller_did: None,
})
}
Err(crate::oauth::OAuthError::ExpiredToken(_)) => Err(TokenValidationError::TokenExpired),
Err(crate::oauth::OAuthError::ExpiredToken(_)) => {
Err(TokenValidationError::OAuthTokenExpired)
}
Err(_) => Err(TokenValidationError::AuthenticationFailed),
}
}

View File

@@ -106,26 +106,29 @@ pub async fn is_account_migrated(db: &PgPool, did: &str) -> Result<bool, sqlx::E
pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> {
query
.map(|q| {
let mut values = Vec::new();
for pair in q.split('&') {
if let Some((k, v)) = pair.split_once('=')
&& k == key
&& let Ok(decoded) = urlencoding::decode(v)
{
let decoded = decoded.into_owned();
q.split('&')
.filter_map(|pair| {
pair.split_once('=')
.filter(|(k, _)| *k == key)
.and_then(|(_, v)| urlencoding::decode(v).ok())
.map(|decoded| decoded.into_owned())
})
.flat_map(|decoded| {
if decoded.contains(',') {
for part in decoded.split(',') {
let trimmed = part.trim();
if !trimmed.is_empty() {
values.push(trimmed.to_string());
}
}
} else if !decoded.is_empty() {
values.push(decoded);
decoded
.split(',')
.filter_map(|part| {
let trimmed = part.trim();
(!trimmed.is_empty()).then(|| trimmed.to_string())
})
.collect::<Vec<_>>()
} else if decoded.is_empty() {
vec![]
} else {
vec![decoded]
}
}
}
values
})
.collect()
})
.unwrap_or_default()
}

View File

@@ -437,7 +437,7 @@ async fn setup_mock_plc_directory() -> String {
async fn spawn_app(database_url: String) -> String {
use tranquil_pds::rate_limit::RateLimiters;
let pool = PgPoolOptions::new()
.max_connections(3)
.max_connections(10)
.acquire_timeout(std::time::Duration::from_secs(30))
.connect(&database_url)
.await

View File

@@ -4,12 +4,16 @@ use serde_json::{Value, json};
pub use crate::common::*;
fn unique_id() -> String {
uuid::Uuid::new_v4().simple().to_string()[..12].to_string()
}
#[allow(dead_code)]
pub async fn setup_new_user(handle_prefix: &str) -> (String, String) {
let client = client();
let ts = Utc::now().timestamp_millis();
let handle = format!("{}-{}.test", handle_prefix, ts);
let email = format!("{}-{}@test.com", handle_prefix, ts);
let uid = unique_id();
let handle = format!("{}-{}.test", handle_prefix, uid);
let email = format!("{}-{}@test.com", handle_prefix, uid);
let password = "E2epass123!";
let create_account_payload = json!({
"handle": handle,
@@ -51,7 +55,7 @@ pub async fn create_post(
text: &str,
) -> (String, String) {
let collection = "app.bsky.feed.post";
let rkey = format!("e2e_social_{}", Utc::now().timestamp_millis());
let rkey = format!("e2e_social_{}", unique_id());
let now = Utc::now().to_rfc3339();
let create_payload = json!({
"repo": did,
@@ -95,7 +99,7 @@ pub async fn create_follow(
followee_did: &str,
) -> (String, String) {
let collection = "app.bsky.graph.follow";
let rkey = format!("e2e_follow_{}", Utc::now().timestamp_millis());
let rkey = format!("e2e_follow_{}", unique_id());
let now = Utc::now().to_rfc3339();
let create_payload = json!({
"repo": follower_did,
@@ -140,7 +144,7 @@ pub async fn create_like(
subject_cid: &str,
) -> (String, String) {
let collection = "app.bsky.feed.like";
let rkey = format!("e2e_like_{}", Utc::now().timestamp_millis());
let rkey = format!("e2e_like_{}", unique_id());
let now = Utc::now().to_rfc3339();
let payload = json!({
"repo": liker_did,
@@ -182,7 +186,7 @@ pub async fn create_repost(
subject_cid: &str,
) -> (String, String) {
let collection = "app.bsky.feed.repost";
let rkey = format!("e2e_repost_{}", Utc::now().timestamp_millis());
let rkey = format!("e2e_repost_{}", unique_id());
let now = Utc::now().to_rfc3339();
let payload = json!({
"repo": reposter_did,

View File

@@ -55,18 +55,14 @@ impl BlobScope {
if self.accept.is_empty() || self.accept.contains("*/*") {
return true;
}
for pattern in &self.accept {
if pattern == mime {
return true;
}
if let Some(prefix) = pattern.strip_suffix("/*")
&& mime.starts_with(prefix)
&& mime.chars().nth(prefix.len()) == Some('/')
{
return true;
}
}
false
self.accept.iter().any(|pattern| {
pattern == mime
|| pattern
.strip_suffix("/*")
.is_some_and(|prefix| {
mime.starts_with(prefix) && mime.chars().nth(prefix.len()) == Some('/')
})
})
}
}
@@ -170,19 +166,20 @@ pub fn parse_scope(scope: &str) -> ParsedScope {
Some(rest.to_string())
};
let mut actions = HashSet::new();
if let Some(action_values) = params.get("action") {
for action_str in action_values {
if let Some(action) = RepoAction::parse_str(action_str) {
actions.insert(action);
}
}
}
if actions.is_empty() {
actions.insert(RepoAction::Create);
actions.insert(RepoAction::Update);
actions.insert(RepoAction::Delete);
}
let actions: HashSet<RepoAction> = params
.get("action")
.map(|action_values| {
action_values
.iter()
.filter_map(|s| RepoAction::parse_str(s))
.collect()
})
.filter(|set: &HashSet<RepoAction>| !set.is_empty())
.unwrap_or_else(|| {
[RepoAction::Create, RepoAction::Update, RepoAction::Delete]
.into_iter()
.collect()
});
return ParsedScope::Repo(RepoScope {
collection,
@@ -191,19 +188,20 @@ pub fn parse_scope(scope: &str) -> ParsedScope {
}
if base == "repo" {
let mut actions = HashSet::new();
if let Some(action_values) = params.get("action") {
for action_str in action_values {
if let Some(action) = RepoAction::parse_str(action_str) {
actions.insert(action);
}
}
}
if actions.is_empty() {
actions.insert(RepoAction::Create);
actions.insert(RepoAction::Update);
actions.insert(RepoAction::Delete);
}
let actions: HashSet<RepoAction> = params
.get("action")
.map(|action_values| {
action_values
.iter()
.filter_map(|s| RepoAction::parse_str(s))
.collect()
})
.filter(|set: &HashSet<RepoAction>| !set.is_empty())
.unwrap_or_else(|| {
[RepoAction::Create, RepoAction::Update, RepoAction::Delete]
.into_iter()
.collect()
});
return ParsedScope::Repo(RepoScope {
collection: None,
actions,
@@ -212,16 +210,17 @@ pub fn parse_scope(scope: &str) -> ParsedScope {
if base.starts_with("blob") {
let positional = base.strip_prefix("blob:").unwrap_or("");
let mut accept = HashSet::new();
if !positional.is_empty() {
accept.insert(positional.to_string());
}
if let Some(accept_values) = params.get("accept") {
for v in accept_values {
accept.insert(v.to_string());
}
}
let accept: HashSet<String> = std::iter::once(positional)
.filter(|s| !s.is_empty())
.map(String::from)
.chain(
params
.get("accept")
.into_iter()
.flatten()
.map(String::clone),
)
.collect();
return ParsedScope::Blob(BlobScope { accept });
}

View File

@@ -113,34 +113,32 @@ impl ScopePermissions {
return Ok(());
}
for repo_scope in self.find_repo_scopes() {
if !repo_scope.actions.contains(&action) {
continue;
}
match &repo_scope.collection {
None => return Ok(()),
Some(coll) if coll == collection => return Ok(()),
Some(coll) if coll.ends_with(".*") => {
let prefix = coll.strip_suffix(".*").unwrap();
if collection.starts_with(prefix)
&& collection.chars().nth(prefix.len()) == Some('.')
{
return Ok(());
let has_permission = self.find_repo_scopes().any(|repo_scope| {
repo_scope.actions.contains(&action)
&& match &repo_scope.collection {
None => true,
Some(coll) if coll == collection => true,
Some(coll) if coll.ends_with(".*") => {
let prefix = coll.strip_suffix(".*").unwrap();
collection.starts_with(prefix)
&& collection.chars().nth(prefix.len()) == Some('.')
}
_ => false,
}
_ => {}
}
}
});
Err(ScopeError::InsufficientScope {
required: format!("repo:{}?action={}", collection, action_str(action)),
message: format!(
"Insufficient scope to {} records in {}",
action_str(action),
collection
),
})
if has_permission {
Ok(())
} else {
Err(ScopeError::InsufficientScope {
required: format!("repo:{}?action={}", collection, action_str(action)),
message: format!(
"Insufficient scope to {} records in {}",
action_str(action),
collection
),
})
}
}
pub fn assert_blob(&self, mime: &str) -> Result<(), ScopeError> {
@@ -148,16 +146,14 @@ impl ScopePermissions {
return Ok(());
}
for blob_scope in self.find_blob_scopes() {
if blob_scope.matches_mime(mime) {
return Ok(());
}
if self.find_blob_scopes().any(|blob_scope| blob_scope.matches_mime(mime)) {
Ok(())
} else {
Err(ScopeError::InsufficientScope {
required: format!("blob:{}", mime),
message: format!("Insufficient scope to upload blob with mime type {}", mime),
})
}
Err(ScopeError::InsufficientScope {
required: format!("blob:{}", mime),
message: format!("Insufficient scope to upload blob with mime type {}", mime),
})
}
pub fn assert_rpc(&self, aud: &str, lxm: &str) -> Result<(), ScopeError> {
@@ -169,7 +165,7 @@ impl ScopePermissions {
return Ok(());
}
for rpc_scope in self.find_rpc_scopes() {
let has_permission = self.find_rpc_scopes().any(|rpc_scope| {
let lxm_matches = match &rpc_scope.lxm {
None => true,
Some(scope_lxm) if scope_lxm == lxm => true,
@@ -186,15 +182,17 @@ impl ScopePermissions {
Some(scope_aud) => scope_aud == aud,
};
if lxm_matches && aud_matches {
return Ok(());
}
}
lxm_matches && aud_matches
});
Err(ScopeError::InsufficientScope {
required: format!("rpc:{}?aud={}", lxm, aud),
message: format!("Insufficient scope to call {} on {}", lxm, aud),
})
if has_permission {
Ok(())
} else {
Err(ScopeError::InsufficientScope {
required: format!("rpc:{}?aud={}", lxm, aud),
message: format!("Insufficient scope to call {} on {}", lxm, aud),
})
}
}
pub fn assert_account(
@@ -211,27 +209,28 @@ impl ScopePermissions {
return Ok(());
}
for account_scope in self.find_account_scopes() {
if account_scope.attr == attr && account_scope.action == action {
return Ok(());
}
if account_scope.attr == attr && account_scope.action == AccountAction::Manage {
return Ok(());
}
}
let has_permission = self.find_account_scopes().any(|account_scope| {
account_scope.attr == attr
&& (account_scope.action == action
|| account_scope.action == AccountAction::Manage)
});
Err(ScopeError::InsufficientScope {
required: format!(
"account:{}?action={}",
attr_str(attr),
action_str_account(action)
),
message: format!(
"Insufficient scope to {} account {}",
action_str_account(action),
attr_str(attr)
),
})
if has_permission {
Ok(())
} else {
Err(ScopeError::InsufficientScope {
required: format!(
"account:{}?action={}",
attr_str(attr),
action_str_account(action)
),
message: format!(
"Insufficient scope to {} account {}",
action_str_account(action),
attr_str(attr)
),
})
}
}
pub fn allows_email_read(&self) -> bool {
@@ -264,22 +263,23 @@ impl ScopePermissions {
return Ok(());
}
for identity_scope in self.find_identity_scopes() {
if identity_scope.attr == IdentityAttr::Wildcard {
return Ok(());
}
if identity_scope.attr == attr {
return Ok(());
}
}
let has_permission = self
.find_identity_scopes()
.any(|identity_scope| {
identity_scope.attr == IdentityAttr::Wildcard || identity_scope.attr == attr
});
Err(ScopeError::InsufficientScope {
required: format!("identity:{}", identity_attr_str(attr)),
message: format!(
"Insufficient scope to modify identity {}",
identity_attr_str(attr)
),
})
if has_permission {
Ok(())
} else {
Err(ScopeError::InsufficientScope {
required: format!("identity:{}", identity_attr_str(attr)),
message: format!(
"Insufficient scope to modify identity {}",
identity_attr_str(attr)
),
})
}
}
pub fn allows_identity(&self, attr: IdentityAttr) -> bool {