From 16fb4dbd033fa55325e79fe910b59e6aa962a6df Mon Sep 17 00:00:00 2001 From: lewis Date: Sun, 11 Jan 2026 22:33:41 +0200 Subject: [PATCH] oauth error msg improvement, general code quality --- ...1bec0c2c86d5d5f17d1454440c6fcd8d4d27a.json | 52 ++++++ ...69ce28efe8594fda026b6f9b298ef0597b40e.json | 28 +++ ...69b4f54b9830d0490c4f8841f8435478c57d3.json | 22 --- ...cb3a683c4771e4fb8c151b3fd5119fb6c1068.json | 28 --- ...423d28c77e1a368df7edc81708eb3038f600.json} | 6 +- ...4532d549d1ad8a9835da4a5c001eee89db076.json | 34 ++++ ...9bde7269879c0547ad43f30b78bfeeef5a920.json | 34 ++++ ...d1184b909f781d131aa2c69368ed021e87e4.json} | 6 +- Cargo.toml | 2 +- crates/tranquil-comms/src/locale.rs | 9 +- crates/tranquil-oauth/src/client.rs | 26 ++- crates/tranquil-oauth/src/dpop.rs | 86 +++++++-- .../src/api/admin/account/info.rs | 163 +++++++++++------- crates/tranquil-pds/src/api/admin/config.rs | 35 ++-- crates/tranquil-pds/src/api/admin/invite.rs | 123 +++++++------ crates/tranquil-pds/src/api/error.rs | 35 +++- crates/tranquil-pds/src/api/moderation/mod.rs | 2 +- crates/tranquil-pds/src/api/proxy.rs | 39 ++--- crates/tranquil-pds/src/api/proxy_client.rs | 16 +- .../tranquil-pds/src/api/repo/record/write.rs | 14 +- crates/tranquil-pds/src/api/validation.rs | 68 ++++---- crates/tranquil-pds/src/auth/mod.rs | 7 +- crates/tranquil-pds/src/util.rs | 39 +++-- crates/tranquil-pds/tests/common/mod.rs | 2 +- crates/tranquil-pds/tests/helpers/mod.rs | 18 +- crates/tranquil-scopes/src/parser.rs | 95 +++++----- crates/tranquil-scopes/src/permissions.rs | 156 ++++++++--------- 27 files changed, 676 insertions(+), 469 deletions(-) create mode 100644 .sqlx/query-2c8868a59ae63dc65d996cf21fc1bec0c2c86d5d5f17d1454440c6fcd8d4d27a.json create mode 100644 .sqlx/query-46ea5ceff2a8f3f2b72c9c6a1bb69ce28efe8594fda026b6f9b298ef0597b40e.json delete mode 100644 .sqlx/query-5a98e015997942835800fcd326e69b4f54b9830d0490c4f8841f8435478c57d3.json delete mode 100644 .sqlx/query-5d5442136932d4088873a935c41cb3a683c4771e4fb8c151b3fd5119fb6c1068.json rename .sqlx/{query-7b2d1d4ac06063e07a7c7a7d0fb434db08ce312eb2864405d7f96f4e985ed036.json => query-888f8724cfc2ed27391b661a5cfe423d28c77e1a368df7edc81708eb3038f600.json} (65%) create mode 100644 .sqlx/query-ae6695ae53fc5e5293f17ddf8cc4532d549d1ad8a9835da4a5c001eee89db076.json create mode 100644 .sqlx/query-ed1ccbaaed6e3f34982dc118ddd9bde7269879c0547ad43f30b78bfeeef5a920.json rename .sqlx/{query-413c5b03501a399dca13f345fcae05770517091d73db93966853e944c68ee237.json => query-eec42a3a4b1440aa8dd580b5d0bbd1184b909f781d131aa2c69368ed021e87e4.json} (56%) diff --git a/.sqlx/query-2c8868a59ae63dc65d996cf21fc1bec0c2c86d5d5f17d1454440c6fcd8d4d27a.json b/.sqlx/query-2c8868a59ae63dc65d996cf21fc1bec0c2c86d5d5f17d1454440c6fcd8d4d27a.json new file mode 100644 index 0000000..5a2da40 --- /dev/null +++ b/.sqlx/query-2c8868a59ae63dc65d996cf21fc1bec0c2c86d5d5f17d1454440c6fcd8d4d27a.json @@ -0,0 +1,52 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT ic.code, ic.available_uses, ic.disabled, ic.for_account, ic.created_at, u.did as created_by\n FROM invite_codes ic\n JOIN users u ON ic.created_by_user = u.id\n WHERE ic.created_by_user = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "code", + "type_info": "Text" + }, + { + "ordinal": 1, + "name": "available_uses", + "type_info": "Int4" + }, + { + "ordinal": 2, + "name": "disabled", + "type_info": "Bool" + }, + { + "ordinal": 3, + "name": "for_account", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "created_by", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + false, + true, + false, + false, + false + ] + }, + "hash": "2c8868a59ae63dc65d996cf21fc1bec0c2c86d5d5f17d1454440c6fcd8d4d27a" +} diff --git a/.sqlx/query-46ea5ceff2a8f3f2b72c9c6a1bb69ce28efe8594fda026b6f9b298ef0597b40e.json b/.sqlx/query-46ea5ceff2a8f3f2b72c9c6a1bb69ce28efe8594fda026b6f9b298ef0597b40e.json new file mode 100644 index 0000000..b363f6c --- /dev/null +++ b/.sqlx/query-46ea5ceff2a8f3f2b72c9c6a1bb69ce28efe8594fda026b6f9b298ef0597b40e.json @@ -0,0 +1,28 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id, did FROM users WHERE id = ANY($1)", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "did", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "UuidArray" + ] + }, + "nullable": [ + false, + false + ] + }, + "hash": "46ea5ceff2a8f3f2b72c9c6a1bb69ce28efe8594fda026b6f9b298ef0597b40e" +} diff --git a/.sqlx/query-5a98e015997942835800fcd326e69b4f54b9830d0490c4f8841f8435478c57d3.json b/.sqlx/query-5a98e015997942835800fcd326e69b4f54b9830d0490c4f8841f8435478c57d3.json deleted file mode 100644 index 5b502fd..0000000 --- a/.sqlx/query-5a98e015997942835800fcd326e69b4f54b9830d0490c4f8841f8435478c57d3.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT code FROM invite_codes WHERE created_by_user = $1\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "code", - "type_info": "Text" - } - ], - "parameters": { - "Left": [ - "Uuid" - ] - }, - "nullable": [ - false - ] - }, - "hash": "5a98e015997942835800fcd326e69b4f54b9830d0490c4f8841f8435478c57d3" -} diff --git a/.sqlx/query-5d5442136932d4088873a935c41cb3a683c4771e4fb8c151b3fd5119fb6c1068.json b/.sqlx/query-5d5442136932d4088873a935c41cb3a683c4771e4fb8c151b3fd5119fb6c1068.json deleted file mode 100644 index 8c2e574..0000000 --- a/.sqlx/query-5d5442136932d4088873a935c41cb3a683c4771e4fb8c151b3fd5119fb6c1068.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT u.did, icu.used_at\n FROM invite_code_uses icu\n JOIN users u ON icu.used_by_user = u.id\n WHERE icu.code = $1\n ORDER BY icu.used_at DESC\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "did", - "type_info": "Text" - }, - { - "ordinal": 1, - "name": "used_at", - "type_info": "Timestamptz" - } - ], - "parameters": { - "Left": [ - "Text" - ] - }, - "nullable": [ - false, - false - ] - }, - "hash": "5d5442136932d4088873a935c41cb3a683c4771e4fb8c151b3fd5119fb6c1068" -} diff --git a/.sqlx/query-7b2d1d4ac06063e07a7c7a7d0fb434db08ce312eb2864405d7f96f4e985ed036.json b/.sqlx/query-888f8724cfc2ed27391b661a5cfe423d28c77e1a368df7edc81708eb3038f600.json similarity index 65% rename from .sqlx/query-7b2d1d4ac06063e07a7c7a7d0fb434db08ce312eb2864405d7f96f4e985ed036.json rename to .sqlx/query-888f8724cfc2ed27391b661a5cfe423d28c77e1a368df7edc81708eb3038f600.json index 8d1860c..b54a8d5 100644 --- a/.sqlx/query-7b2d1d4ac06063e07a7c7a7d0fb434db08ce312eb2864405d7f96f4e985ed036.json +++ b/.sqlx/query-888f8724cfc2ed27391b661a5cfe423d28c77e1a368df7edc81708eb3038f600.json @@ -1,14 +1,14 @@ { "db_name": "PostgreSQL", - "query": "UPDATE invite_codes SET disabled = TRUE WHERE code = $1", + "query": "UPDATE invite_codes SET disabled = TRUE WHERE code = ANY($1)", "describe": { "columns": [], "parameters": { "Left": [ - "Text" + "TextArray" ] }, "nullable": [] }, - "hash": "7b2d1d4ac06063e07a7c7a7d0fb434db08ce312eb2864405d7f96f4e985ed036" + "hash": "888f8724cfc2ed27391b661a5cfe423d28c77e1a368df7edc81708eb3038f600" } diff --git a/.sqlx/query-ae6695ae53fc5e5293f17ddf8cc4532d549d1ad8a9835da4a5c001eee89db076.json b/.sqlx/query-ae6695ae53fc5e5293f17ddf8cc4532d549d1ad8a9835da4a5c001eee89db076.json new file mode 100644 index 0000000..7aa3c08 --- /dev/null +++ b/.sqlx/query-ae6695ae53fc5e5293f17ddf8cc4532d549d1ad8a9835da4a5c001eee89db076.json @@ -0,0 +1,34 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT icu.code, u.did as used_by, icu.used_at\n FROM invite_code_uses icu\n JOIN users u ON icu.used_by_user = u.id\n WHERE icu.code = ANY($1)\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "code", + "type_info": "Text" + }, + { + "ordinal": 1, + "name": "used_by", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "used_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "TextArray" + ] + }, + "nullable": [ + false, + false, + false + ] + }, + "hash": "ae6695ae53fc5e5293f17ddf8cc4532d549d1ad8a9835da4a5c001eee89db076" +} diff --git a/.sqlx/query-ed1ccbaaed6e3f34982dc118ddd9bde7269879c0547ad43f30b78bfeeef5a920.json b/.sqlx/query-ed1ccbaaed6e3f34982dc118ddd9bde7269879c0547ad43f30b78bfeeef5a920.json new file mode 100644 index 0000000..d52ff9a --- /dev/null +++ b/.sqlx/query-ed1ccbaaed6e3f34982dc118ddd9bde7269879c0547ad43f30b78bfeeef5a920.json @@ -0,0 +1,34 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT icu.code, u.did, icu.used_at\n FROM invite_code_uses icu\n JOIN users u ON icu.used_by_user = u.id\n WHERE icu.code = ANY($1)\n ORDER BY icu.used_at DESC\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "code", + "type_info": "Text" + }, + { + "ordinal": 1, + "name": "did", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "used_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "TextArray" + ] + }, + "nullable": [ + false, + false, + false + ] + }, + "hash": "ed1ccbaaed6e3f34982dc118ddd9bde7269879c0547ad43f30b78bfeeef5a920" +} diff --git a/.sqlx/query-413c5b03501a399dca13f345fcae05770517091d73db93966853e944c68ee237.json b/.sqlx/query-eec42a3a4b1440aa8dd580b5d0bbd1184b909f781d131aa2c69368ed021e87e4.json similarity index 56% rename from .sqlx/query-413c5b03501a399dca13f345fcae05770517091d73db93966853e944c68ee237.json rename to .sqlx/query-eec42a3a4b1440aa8dd580b5d0bbd1184b909f781d131aa2c69368ed021e87e4.json index 17297f7..c6fc374 100644 --- a/.sqlx/query-413c5b03501a399dca13f345fcae05770517091d73db93966853e944c68ee237.json +++ b/.sqlx/query-eec42a3a4b1440aa8dd580b5d0bbd1184b909f781d131aa2c69368ed021e87e4.json @@ -1,14 +1,14 @@ { "db_name": "PostgreSQL", - "query": "UPDATE invite_codes SET disabled = TRUE WHERE created_by_user = $1", + "query": "UPDATE invite_codes SET disabled = TRUE WHERE created_by_user IN (SELECT id FROM users WHERE did = ANY($1))", "describe": { "columns": [], "parameters": { "Left": [ - "Uuid" + "TextArray" ] }, "nullable": [] }, - "hash": "413c5b03501a399dca13f345fcae05770517091d73db93966853e944c68ee237" + "hash": "eec42a3a4b1440aa8dd580b5d0bbd1184b909f781d131aa2c69368ed021e87e4" } diff --git a/Cargo.toml b/Cargo.toml index 019efc6..5e9740e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -92,7 +92,7 @@ tower-layer = "0.3" tracing = "0.1" tracing-subscriber = "0.3" urlencoding = "2.1" -uuid = { version = "1.19", features = ["v4", "v5", "fast-rng"] } +uuid = { version = "1.19", features = ["v4", "v5", "v7", "fast-rng"] } webauthn-rs = { version = "0.5", features = ["danger-allow-state-serialisation", "danger-user-presence-only-security-keys"] } webauthn-rs-proto = "0.5" zip = { version = "7.0", default-features = false, features = ["deflate"] } diff --git a/crates/tranquil-comms/src/locale.rs b/crates/tranquil-comms/src/locale.rs index 708f4ff..5900f13 100644 --- a/crates/tranquil-comms/src/locale.rs +++ b/crates/tranquil-comms/src/locale.rs @@ -182,11 +182,10 @@ static STRINGS_FI: NotificationStrings = NotificationStrings { }; pub fn format_message(template: &str, vars: &[(&str, &str)]) -> String { - let mut result = template.to_string(); - for (key, value) in vars { - result = result.replace(&format!("{{{}}}", key), value); - } - result + vars.iter() + .fold(template.to_string(), |result, (key, value)| { + result.replace(&format!("{{{}}}", key), value) + }) } #[cfg(test)] diff --git a/crates/tranquil-oauth/src/client.rs b/crates/tranquil-oauth/src/client.rs index 43787be..d75d626 100644 --- a/crates/tranquil-oauth/src/client.rs +++ b/crates/tranquil-oauth/src/client.rs @@ -568,12 +568,21 @@ fn verify_es256( .get("y") .and_then(|v| v.as_str()) .ok_or_else(|| OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()))?; - let x_bytes = URL_SAFE_NO_PAD + let x_decoded = URL_SAFE_NO_PAD .decode(x) .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?; - let y_bytes = URL_SAFE_NO_PAD + let y_decoded = URL_SAFE_NO_PAD .decode(y) .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?; + if x_decoded.len() > 32 || y_decoded.len() > 32 { + return Err(OAuthError::InvalidClient( + "EC coordinate too long".to_string(), + )); + } + let mut x_bytes = [0u8; 32]; + let mut y_bytes = [0u8; 32]; + x_bytes[32 - x_decoded.len()..].copy_from_slice(&x_decoded); + y_bytes[32 - y_decoded.len()..].copy_from_slice(&y_decoded); let mut point_bytes = vec![0x04]; point_bytes.extend_from_slice(&x_bytes); point_bytes.extend_from_slice(&y_bytes); @@ -604,12 +613,21 @@ fn verify_es384( .get("y") .and_then(|v| v.as_str()) .ok_or_else(|| OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()))?; - let x_bytes = URL_SAFE_NO_PAD + let x_decoded = URL_SAFE_NO_PAD .decode(x) .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?; - let y_bytes = URL_SAFE_NO_PAD + let y_decoded = URL_SAFE_NO_PAD .decode(y) .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?; + if x_decoded.len() > 48 || y_decoded.len() > 48 { + return Err(OAuthError::InvalidClient( + "EC coordinate too long".to_string(), + )); + } + let mut x_bytes = [0u8; 48]; + let mut y_bytes = [0u8; 48]; + x_bytes[48 - x_decoded.len()..].copy_from_slice(&x_decoded); + y_bytes[48 - y_decoded.len()..].copy_from_slice(&y_decoded); let mut point_bytes = vec![0x04]; point_bytes.extend_from_slice(&x_bytes); point_bytes.extend_from_slice(&y_bytes); diff --git a/crates/tranquil-oauth/src/dpop.rs b/crates/tranquil-oauth/src/dpop.rs index a547ddb..07d8ffa 100644 --- a/crates/tranquil-oauth/src/dpop.rs +++ b/crates/tranquil-oauth/src/dpop.rs @@ -218,25 +218,30 @@ fn verify_es256(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), O crv ))); } - let x_bytes = URL_SAFE_NO_PAD + let x_decoded = URL_SAFE_NO_PAD .decode( jwk.x .as_ref() .ok_or_else(|| OAuthError::InvalidDpopProof("Missing x coordinate".to_string()))?, ) .map_err(|_| OAuthError::InvalidDpopProof("Invalid x encoding".to_string()))?; - let y_bytes = URL_SAFE_NO_PAD + let y_decoded = URL_SAFE_NO_PAD .decode( jwk.y .as_ref() .ok_or_else(|| OAuthError::InvalidDpopProof("Missing y coordinate".to_string()))?, ) .map_err(|_| OAuthError::InvalidDpopProof("Invalid y encoding".to_string()))?; - let point = EncodedPoint::from_affine_coordinates( - x_bytes.as_slice().into(), - y_bytes.as_slice().into(), - false, - ); + let mut x_bytes = [0u8; 32]; + let mut y_bytes = [0u8; 32]; + if x_decoded.len() > 32 || y_decoded.len() > 32 { + return Err(OAuthError::InvalidDpopProof( + "EC coordinate too long".to_string(), + )); + } + x_bytes[32 - x_decoded.len()..].copy_from_slice(&x_decoded); + y_bytes[32 - y_decoded.len()..].copy_from_slice(&y_decoded); + let point = EncodedPoint::from_affine_coordinates((&x_bytes).into(), (&y_bytes).into(), false); let affine_opt: Option = AffinePoint::from_encoded_point(&point).into(); let affine = affine_opt.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?; @@ -264,25 +269,30 @@ fn verify_es384(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), O crv ))); } - let x_bytes = URL_SAFE_NO_PAD + let x_decoded = URL_SAFE_NO_PAD .decode( jwk.x .as_ref() .ok_or_else(|| OAuthError::InvalidDpopProof("Missing x coordinate".to_string()))?, ) .map_err(|_| OAuthError::InvalidDpopProof("Invalid x encoding".to_string()))?; - let y_bytes = URL_SAFE_NO_PAD + let y_decoded = URL_SAFE_NO_PAD .decode( jwk.y .as_ref() .ok_or_else(|| OAuthError::InvalidDpopProof("Missing y coordinate".to_string()))?, ) .map_err(|_| OAuthError::InvalidDpopProof("Invalid y encoding".to_string()))?; - let point = EncodedPoint::from_affine_coordinates( - x_bytes.as_slice().into(), - y_bytes.as_slice().into(), - false, - ); + let mut x_bytes = [0u8; 48]; + let mut y_bytes = [0u8; 48]; + if x_decoded.len() > 48 || y_decoded.len() > 48 { + return Err(OAuthError::InvalidDpopProof( + "EC coordinate too long".to_string(), + )); + } + x_bytes[48 - x_decoded.len()..].copy_from_slice(&x_decoded); + y_bytes[48 - y_decoded.len()..].copy_from_slice(&y_decoded); + let point = EncodedPoint::from_affine_coordinates((&x_bytes).into(), (&y_bytes).into(), false); let affine_opt: Option = AffinePoint::from_encoded_point(&point).into(); let affine = affine_opt.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?; @@ -398,4 +408,52 @@ mod tests { let thumbprint = compute_jwk_thumbprint(&jwk).unwrap(); assert!(!thumbprint.is_empty()); } + + #[test] + fn test_es256_short_coordinate_no_panic() { + let short_31_bytes = vec![0x42u8; 31]; + let short_30_bytes = vec![0x42u8; 30]; + let x_b64 = URL_SAFE_NO_PAD.encode(&short_31_bytes); + let y_b64 = URL_SAFE_NO_PAD.encode(&short_30_bytes); + let jwk = DPoPJwk { + kty: "EC".to_string(), + crv: Some("P-256".to_string()), + x: Some(x_b64), + y: Some(y_b64), + }; + let result = verify_es256(&jwk, b"test", &[0u8; 64]); + assert!(result.is_err(), "Invalid coordinates should return error, not panic"); + } + + #[test] + fn test_es256_valid_key_with_trimmed_coordinates() { + use p256::ecdsa::{SigningKey, signature::Signer}; + use p256::elliptic_curve::rand_core::OsRng; + + let signing_key = SigningKey::random(&mut OsRng); + let verifying_key = signing_key.verifying_key(); + let point = verifying_key.to_encoded_point(false); + let x_bytes = point.x().unwrap(); + let y_bytes = point.y().unwrap(); + let x_trimmed: Vec = x_bytes.iter().copied().skip_while(|&b| b == 0).collect(); + let y_trimmed: Vec = y_bytes.iter().copied().skip_while(|&b| b == 0).collect(); + let x_b64 = URL_SAFE_NO_PAD.encode(&x_trimmed); + let y_b64 = URL_SAFE_NO_PAD.encode(&y_trimmed); + let jwk = DPoPJwk { + kty: "EC".to_string(), + crv: Some("P-256".to_string()), + x: Some(x_b64), + y: Some(y_b64), + }; + let message = b"test message for signature verification"; + let signature: p256::ecdsa::Signature = signing_key.sign(message); + let result = verify_es256(&jwk, message, signature.to_bytes().as_slice()); + assert!( + result.is_ok(), + "Should verify signature with trimmed coordinates (x={}, y={}): {:?}", + x_trimmed.len(), + y_trimmed.len(), + result + ); + } } diff --git a/crates/tranquil-pds/src/api/admin/account/info.rs b/crates/tranquil-pds/src/api/admin/account/info.rs index c255bf2..59d80fc 100644 --- a/crates/tranquil-pds/src/api/admin/account/info.rs +++ b/crates/tranquil-pds/src/api/admin/account/info.rs @@ -130,24 +130,62 @@ async fn get_invites_for_user( db: &sqlx::PgPool, user_id: uuid::Uuid, ) -> Option> { - let codes = sqlx::query_scalar!( + let invite_codes = sqlx::query!( r#" - SELECT code FROM invite_codes WHERE created_by_user = $1 + SELECT ic.code, ic.available_uses, ic.disabled, ic.for_account, ic.created_at, u.did as created_by + FROM invite_codes ic + JOIN users u ON ic.created_by_user = u.id + WHERE ic.created_by_user = $1 "#, user_id ) .fetch_all(db) .await .ok()?; - if codes.is_empty() { + + if invite_codes.is_empty() { return None; } - let mut invites = Vec::new(); - for code in codes { - if let Some(info) = get_invite_code_info(db, &code).await { - invites.push(info); - } - } + + let code_strings: Vec = invite_codes.iter().map(|ic| ic.code.clone()).collect(); + let mut uses_by_code: std::collections::HashMap> = + std::collections::HashMap::new(); + sqlx::query!( + r#" + SELECT icu.code, u.did as used_by, icu.used_at + FROM invite_code_uses icu + JOIN users u ON icu.used_by_user = u.id + WHERE icu.code = ANY($1) + "#, + &code_strings + ) + .fetch_all(db) + .await + .ok()? + .into_iter() + .for_each(|r| { + uses_by_code + .entry(r.code) + .or_default() + .push(InviteCodeUseInfo { + used_by: r.used_by.into(), + used_at: r.used_at.to_rfc3339(), + }); + }); + + let invites: Vec = invite_codes + .into_iter() + .map(|ic| InviteCodeInfo { + code: ic.code.clone(), + available: ic.available_uses, + disabled: ic.disabled.unwrap_or(false), + for_account: ic.for_account.into(), + created_by: ic.created_by.into(), + created_at: ic.created_at.to_rfc3339(), + uses: uses_by_code.get(&ic.code).cloned().unwrap_or_default(), + }) + .collect(); + if invites.is_empty() { None } else { @@ -276,61 +314,62 @@ pub async fn get_account_infos( .map(|r| (r.used_by_user, r.code)) .collect(); - let mut uses_by_code: std::collections::HashMap> = - std::collections::HashMap::new(); - for u in all_invite_uses { - uses_by_code - .entry(u.code.clone()) - .or_default() - .push(InviteCodeUseInfo { - used_by: u.used_by.into(), - used_at: u.used_at.to_rfc3339(), + let uses_by_code: std::collections::HashMap> = + all_invite_uses + .into_iter() + .fold(std::collections::HashMap::new(), |mut acc, u| { + acc.entry(u.code.clone()).or_default().push(InviteCodeUseInfo { + used_by: u.used_by.into(), + used_at: u.used_at.to_rfc3339(), + }); + acc }); - } - let mut codes_by_user: std::collections::HashMap> = - std::collections::HashMap::new(); - let mut code_info_map: std::collections::HashMap = - std::collections::HashMap::new(); - for ic in all_invite_codes { - let info = InviteCodeInfo { - code: ic.code.clone(), - available: ic.available_uses, - disabled: ic.disabled.unwrap_or(false), - for_account: ic.for_account.into(), - created_by: ic.created_by.into(), - created_at: ic.created_at.to_rfc3339(), - uses: uses_by_code.get(&ic.code).cloned().unwrap_or_default(), - }; - code_info_map.insert(ic.code.clone(), info.clone()); - codes_by_user - .entry(ic.created_by_user) - .or_default() - .push(info); - } + let (codes_by_user, code_info_map): ( + std::collections::HashMap>, + std::collections::HashMap, + ) = all_invite_codes.into_iter().fold( + (std::collections::HashMap::new(), std::collections::HashMap::new()), + |(mut by_user, mut by_code), ic| { + let info = InviteCodeInfo { + code: ic.code.clone(), + available: ic.available_uses, + disabled: ic.disabled.unwrap_or(false), + for_account: ic.for_account.into(), + created_by: ic.created_by.into(), + created_at: ic.created_at.to_rfc3339(), + uses: uses_by_code.get(&ic.code).cloned().unwrap_or_default(), + }; + by_code.insert(ic.code.clone(), info.clone()); + by_user.entry(ic.created_by_user).or_default().push(info); + (by_user, by_code) + }, + ); - let mut infos = Vec::with_capacity(users.len()); - for row in users { - let invited_by = invited_by_map - .get(&row.id) - .and_then(|code| code_info_map.get(code).cloned()); - let invites = codes_by_user.get(&row.id).cloned(); - infos.push(AccountInfo { - did: row.did.into(), - handle: row.handle.into(), - email: row.email, - indexed_at: row.created_at.to_rfc3339(), - invite_note: None, - invites_disabled: row.invites_disabled.unwrap_or(false), - email_confirmed_at: if row.email_verified { - Some(row.created_at.to_rfc3339()) - } else { - None - }, - deactivated_at: row.deactivated_at.map(|dt| dt.to_rfc3339()), - invited_by, - invites, - }); - } + let infos: Vec = users + .into_iter() + .map(|row| { + let invited_by = invited_by_map + .get(&row.id) + .and_then(|code| code_info_map.get(code).cloned()); + let invites = codes_by_user.get(&row.id).cloned(); + AccountInfo { + did: row.did.into(), + handle: row.handle.into(), + email: row.email, + indexed_at: row.created_at.to_rfc3339(), + invite_note: None, + invites_disabled: row.invites_disabled.unwrap_or(false), + email_confirmed_at: if row.email_verified { + Some(row.created_at.to_rfc3339()) + } else { + None + }, + deactivated_at: row.deactivated_at.map(|dt| dt.to_rfc3339()), + invited_by, + invites, + } + }) + .collect(); (StatusCode::OK, Json(GetAccountInfosOutput { infos })).into_response() } diff --git a/crates/tranquil-pds/src/api/admin/config.rs b/crates/tranquil-pds/src/api/admin/config.rs index 4e23f64..a281b3f 100644 --- a/crates/tranquil-pds/src/api/admin/config.rs +++ b/crates/tranquil-pds/src/api/admin/config.rs @@ -48,32 +48,19 @@ pub async fn get_server_config( .fetch_all(&state.db) .await?; - let mut server_name = "Tranquil PDS".to_string(); - let mut primary_color = None; - let mut primary_color_dark = None; - let mut secondary_color = None; - let mut secondary_color_dark = None; - let mut logo_cid = None; - - for (key, value) in rows { - match key.as_str() { - "server_name" => server_name = value, - "primary_color" => primary_color = Some(value), - "primary_color_dark" => primary_color_dark = Some(value), - "secondary_color" => secondary_color = Some(value), - "secondary_color_dark" => secondary_color_dark = Some(value), - "logo_cid" => logo_cid = Some(value), - _ => {} - } - } + let config_map: std::collections::HashMap = + rows.into_iter().collect(); Ok(Json(ServerConfigResponse { - server_name, - primary_color, - primary_color_dark, - secondary_color, - secondary_color_dark, - logo_cid, + server_name: config_map + .get("server_name") + .cloned() + .unwrap_or_else(|| "Tranquil PDS".to_string()), + primary_color: config_map.get("primary_color").cloned(), + primary_color_dark: config_map.get("primary_color_dark").cloned(), + secondary_color: config_map.get("secondary_color").cloned(), + secondary_color_dark: config_map.get("secondary_color_dark").cloned(), + logo_cid: config_map.get("logo_cid").cloned(), })) } diff --git a/crates/tranquil-pds/src/api/admin/invite.rs b/crates/tranquil-pds/src/api/admin/invite.rs index 7c5c710..a87b1d1 100644 --- a/crates/tranquil-pds/src/api/admin/invite.rs +++ b/crates/tranquil-pds/src/api/admin/invite.rs @@ -24,29 +24,20 @@ pub async fn disable_invite_codes( Json(input): Json, ) -> Response { if let Some(codes) = &input.codes { - for code in codes { - let _ = sqlx::query!( - "UPDATE invite_codes SET disabled = TRUE WHERE code = $1", - code - ) - .execute(&state.db) - .await; - } + let _ = sqlx::query!( + "UPDATE invite_codes SET disabled = TRUE WHERE code = ANY($1)", + codes as &[String] + ) + .execute(&state.db) + .await; } if let Some(accounts) = &input.accounts { - for account in accounts { - let user = sqlx::query!("SELECT id FROM users WHERE did = $1", account) - .fetch_optional(&state.db) - .await; - if let Ok(Some(user_row)) = user { - let _ = sqlx::query!( - "UPDATE invite_codes SET disabled = TRUE WHERE created_by_user = $1", - user_row.id - ) - .execute(&state.db) - .await; - } - } + let _ = sqlx::query!( + "UPDATE invite_codes SET disabled = TRUE WHERE created_by_user IN (SELECT id FROM users WHERE did = ANY($1))", + accounts as &[String] + ) + .execute(&state.db) + .await; } EmptyResponse::ok().into_response() } @@ -70,7 +61,7 @@ pub struct InviteCodeInfo { pub uses: Vec, } -#[derive(Serialize)] +#[derive(Clone, Serialize)] #[serde(rename_all = "camelCase")] pub struct InviteCodeUseInfo { pub used_by: String, @@ -149,47 +140,71 @@ pub async fn get_invite_codes( return ApiError::InternalError(None).into_response(); } }; - let mut codes = Vec::new(); - for (code, available_uses, disabled, created_by_user, created_at) in &codes_rows { - let creator_did = - sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", created_by_user) - .fetch_optional(&state.db) - .await - .ok() - .flatten() - .unwrap_or_else(|| "unknown".to_string()); - let uses_result = sqlx::query!( + + let user_ids: Vec = codes_rows.iter().map(|(_, _, _, uid, _)| *uid).collect(); + let code_strings: Vec = codes_rows.iter().map(|(c, _, _, _, _)| c.clone()).collect(); + + let mut creator_dids: std::collections::HashMap = + std::collections::HashMap::new(); + sqlx::query!( + "SELECT id, did FROM users WHERE id = ANY($1)", + &user_ids + ) + .fetch_all(&state.db) + .await + .unwrap_or_default() + .into_iter() + .for_each(|r| { + creator_dids.insert(r.id, r.did); + }); + + let mut uses_by_code: std::collections::HashMap> = + std::collections::HashMap::new(); + if !code_strings.is_empty() { + sqlx::query!( r#" - SELECT u.did, icu.used_at + SELECT icu.code, u.did, icu.used_at FROM invite_code_uses icu JOIN users u ON icu.used_by_user = u.id - WHERE icu.code = $1 + WHERE icu.code = ANY($1) ORDER BY icu.used_at DESC "#, - code + &code_strings ) .fetch_all(&state.db) - .await; - let uses = match uses_result { - Ok(use_rows) => use_rows - .iter() - .map(|u| InviteCodeUseInfo { - used_by: u.did.clone(), - used_at: u.used_at.to_rfc3339(), - }) - .collect(), - Err(_) => Vec::new(), - }; - codes.push(InviteCodeInfo { - code: code.clone(), - available: *available_uses, - disabled: disabled.unwrap_or(false), - for_account: creator_did.clone(), - created_by: creator_did, - created_at: created_at.to_rfc3339(), - uses, + .await + .unwrap_or_default() + .into_iter() + .for_each(|r| { + uses_by_code + .entry(r.code) + .or_default() + .push(InviteCodeUseInfo { + used_by: r.did, + used_at: r.used_at.to_rfc3339(), + }); }); } + + let codes: Vec = codes_rows + .iter() + .map(|(code, available_uses, disabled, created_by_user, created_at)| { + let creator_did = creator_dids + .get(created_by_user) + .cloned() + .unwrap_or_else(|| "unknown".to_string()); + InviteCodeInfo { + code: code.clone(), + available: *available_uses, + disabled: disabled.unwrap_or(false), + for_account: creator_did.clone(), + created_by: creator_did, + created_at: created_at.to_rfc3339(), + uses: uses_by_code.get(code).cloned().unwrap_or_default(), + } + }) + .collect(); + let next_cursor = if codes_rows.len() == limit as usize { codes_rows.last().map(|(code, _, _, _, _)| code.clone()) } else { diff --git a/crates/tranquil-pds/src/api/error.rs b/crates/tranquil-pds/src/api/error.rs index 008bd49..add3c24 100644 --- a/crates/tranquil-pds/src/api/error.rs +++ b/crates/tranquil-pds/src/api/error.rs @@ -22,6 +22,7 @@ pub enum ApiError { InvalidRequest(String), InvalidToken(Option), ExpiredToken(Option), + OAuthExpiredToken(Option), TokenRequired, AccountDeactivated, AccountTakedown, @@ -127,7 +128,8 @@ impl ApiError { | Self::InvalidCode(_) | Self::InvalidPassword(_) | Self::InvalidToken(_) - | Self::PasskeyCounterAnomaly => StatusCode::UNAUTHORIZED, + | Self::PasskeyCounterAnomaly + | Self::OAuthExpiredToken(_) => StatusCode::UNAUTHORIZED, Self::ExpiredToken(_) => StatusCode::BAD_REQUEST, Self::Forbidden | Self::AdminRequired @@ -216,7 +218,7 @@ impl ApiError { Self::AuthenticationRequired => Cow::Borrowed("AuthenticationRequired"), Self::AuthenticationFailed(_) => Cow::Borrowed("AuthenticationFailed"), Self::InvalidToken(_) => Cow::Borrowed("InvalidToken"), - Self::ExpiredToken(_) => Cow::Borrowed("ExpiredToken"), + Self::ExpiredToken(_) | Self::OAuthExpiredToken(_) => Cow::Borrowed("ExpiredToken"), Self::TokenRequired => Cow::Borrowed("TokenRequired"), Self::AccountDeactivated => Cow::Borrowed("AccountDeactivated"), Self::AccountTakedown => Cow::Borrowed("AccountTakedown"), @@ -298,6 +300,7 @@ impl ApiError { | Self::AuthenticationFailed(msg) | Self::InvalidToken(msg) | Self::ExpiredToken(msg) + | Self::OAuthExpiredToken(msg) | Self::RepoNotFound(msg) | Self::BlobNotFound(msg) | Self::InvalidHandle(msg) @@ -428,13 +431,24 @@ impl IntoResponse for ApiError { message: self.message(), }; let mut response = (self.status_code(), Json(body)).into_response(); - if matches!(self, Self::ExpiredToken(_)) { - response.headers_mut().insert( - "WWW-Authenticate", - "Bearer error=\"invalid_token\", error_description=\"Token has expired\"" - .parse() - .unwrap(), - ); + match &self { + Self::ExpiredToken(_) => { + response.headers_mut().insert( + "WWW-Authenticate", + "Bearer error=\"invalid_token\", error_description=\"Token has expired\"" + .parse() + .unwrap(), + ); + } + Self::OAuthExpiredToken(_) => { + response.headers_mut().insert( + "WWW-Authenticate", + "DPoP error=\"invalid_token\", error_description=\"Token has expired\"" + .parse() + .unwrap(), + ); + } + _ => {} } response } @@ -457,6 +471,9 @@ impl From for ApiError { Self::AuthenticationFailed(None) } crate::auth::TokenValidationError::TokenExpired => Self::ExpiredToken(None), + crate::auth::TokenValidationError::OAuthTokenExpired => { + Self::OAuthExpiredToken(Some("Token has expired".to_string())) + } } } } diff --git a/crates/tranquil-pds/src/api/moderation/mod.rs b/crates/tranquil-pds/src/api/moderation/mod.rs index 70396e3..0d3dce9 100644 --- a/crates/tranquil-pds/src/api/moderation/mod.rs +++ b/crates/tranquil-pds/src/api/moderation/mod.rs @@ -211,7 +211,7 @@ async fn create_report_locally( } let created_at = chrono::Utc::now(); - let report_id = created_at.timestamp_millis(); + let report_id = (uuid::Uuid::now_v7().as_u128() & 0x7FFF_FFFF_FFFF_FFFF) as i64; let subject_json = json!(input.subject); let insert = sqlx::query!( diff --git a/crates/tranquil-pds/src/api/proxy.rs b/crates/tranquil-pds/src/api/proxy.rs index 78a978c..cb91b90 100644 --- a/crates/tranquil-pds/src/api/proxy.rs +++ b/crates/tranquil-pds/src/api/proxy.rs @@ -268,21 +268,8 @@ async fn proxy_handler( } Err(e) => { warn!("Token validation failed: {:?}", e); - if matches!(e, crate::auth::TokenValidationError::TokenExpired) && extracted.is_dpop - { - let www_auth = - "DPoP error=\"invalid_token\", error_description=\"Token has expired\""; - let mut response = - ApiError::ExpiredToken(Some("Token has expired".into())).into_response(); - *response.status_mut() = axum::http::StatusCode::UNAUTHORIZED; - response - .headers_mut() - .insert("WWW-Authenticate", www_auth.parse().unwrap()); - let nonce = crate::oauth::verify::generate_dpop_nonce(); - response - .headers_mut() - .insert("DPoP-Nonce", nonce.parse().unwrap()); - return response; + if matches!(e, crate::auth::TokenValidationError::OAuthTokenExpired) { + return ApiError::from(e).into_response(); } } } @@ -291,11 +278,12 @@ async fn proxy_handler( if let Some(val) = auth_header_val { request_builder = request_builder.header("Authorization", val); } - for header_name in crate::api::proxy_client::HEADERS_TO_FORWARD { - if let Some(val) = headers.get(*header_name) { - request_builder = request_builder.header(*header_name, val); - } - } + request_builder = crate::api::proxy_client::HEADERS_TO_FORWARD + .iter() + .filter_map(|name| headers.get(*name).map(|val| (*name, val))) + .fold(request_builder, |builder, (name, val)| { + builder.header(name, val) + }); if !body.is_empty() { request_builder = request_builder.body(body); } @@ -313,11 +301,12 @@ async fn proxy_handler( } }; let mut response_builder = Response::builder().status(status); - for header_name in crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD { - if let Some(val) = headers.get(*header_name) { - response_builder = response_builder.header(*header_name, val); - } - } + response_builder = crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD + .iter() + .filter_map(|name| headers.get(*name).map(|val| (*name, val))) + .fold(response_builder, |builder, (name, val)| { + builder.header(name, val) + }); match response_builder.body(axum::body::Body::from(body)) { Ok(r) => r, Err(e) => { diff --git a/crates/tranquil-pds/src/api/proxy_client.rs b/crates/tranquil-pds/src/api/proxy_client.rs index a364163..fa8d7b0 100644 --- a/crates/tranquil-pds/src/api/proxy_client.rs +++ b/crates/tranquil-pds/src/api/proxy_client.rs @@ -88,15 +88,13 @@ pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> { Ok(addrs) => addrs.collect(), Err(_) => return Err(SsrfError::DnsResolutionFailed(host.to_string())), }; - for addr in &socket_addrs { - if !is_unicast_ip(&addr.ip()) { - warn!( - "DNS resolution for {} returned non-unicast IP: {}", - host, - addr.ip() - ); - return Err(SsrfError::NonUnicastIp(addr.ip().to_string())); - } + if let Some(addr) = socket_addrs.iter().find(|addr| !is_unicast_ip(&addr.ip())) { + warn!( + "DNS resolution for {} returned non-unicast IP: {}", + host, + addr.ip() + ); + return Err(SsrfError::NonUnicastIp(addr.ip().to_string())); } Ok(()) } diff --git a/crates/tranquil-pds/src/api/repo/record/write.rs b/crates/tranquil-pds/src/api/repo/record/write.rs index b5afb53..5d4ca17 100644 --- a/crates/tranquil-pds/src/api/repo/record/write.rs +++ b/crates/tranquil-pds/src/api/repo/record/write.rs @@ -82,19 +82,7 @@ pub async fn prepare_repo_write( .await .map_err(|e| { tracing::warn!(error = ?e, is_dpop = extracted.is_dpop, "Token validation failed in prepare_repo_write"); - let mut response = ApiError::from(e).into_response(); - if matches!(e, crate::auth::TokenValidationError::TokenExpired) && extracted.is_dpop { - *response.status_mut() = axum::http::StatusCode::UNAUTHORIZED; - let www_auth = - "DPoP error=\"invalid_token\", error_description=\"Token has expired\""; - response.headers_mut().insert( - "WWW-Authenticate", - www_auth.parse().unwrap(), - ); - let nonce = crate::oauth::verify::generate_dpop_nonce(); - response.headers_mut().insert("DPoP-Nonce", nonce.parse().unwrap()); - } - response + ApiError::from(e).into_response() })?; if repo.as_str() != auth_user.did.as_str() { return Err( diff --git a/crates/tranquil-pds/src/api/validation.rs b/crates/tranquil-pds/src/api/validation.rs index a1c2170..d1b121c 100644 --- a/crates/tranquil-pds/src/api/validation.rs +++ b/crates/tranquil-pds/src/api/validation.rs @@ -181,10 +181,11 @@ fn validate_email_detailed(email: &str) -> Result<(), EmailValidationError> { if local.starts_with('.') || local.ends_with('.') || local.contains("..") { return Err(EmailValidationError::InvalidLocalPart); } - for c in local.chars() { - if !c.is_ascii_alphanumeric() && !EMAIL_LOCAL_SPECIAL_CHARS.contains(c) { - return Err(EmailValidationError::InvalidLocalPart); - } + if !local + .chars() + .all(|c| c.is_ascii_alphanumeric() || EMAIL_LOCAL_SPECIAL_CHARS.contains(c)) + { + return Err(EmailValidationError::InvalidLocalPart); } if domain.is_empty() { return Err(EmailValidationError::EmptyDomain); @@ -195,18 +196,14 @@ fn validate_email_detailed(email: &str) -> Result<(), EmailValidationError> { if !domain.contains('.') { return Err(EmailValidationError::MissingDomainDot); } - for label in domain.split('.') { - if label.is_empty() || label.len() > MAX_DOMAIN_LABEL_LENGTH { - return Err(EmailValidationError::InvalidDomainLabel); - } - if label.starts_with('-') || label.ends_with('-') { - return Err(EmailValidationError::InvalidDomainLabel); - } - for c in label.chars() { - if !c.is_ascii_alphanumeric() && c != '-' { - return Err(EmailValidationError::InvalidDomainLabel); - } - } + if !domain.split('.').all(|label| { + !label.is_empty() + && label.len() <= MAX_DOMAIN_LABEL_LENGTH + && !label.starts_with('-') + && !label.ends_with('-') + && label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') + }) { + return Err(EmailValidationError::InvalidDomainLabel); } Ok(()) } @@ -293,10 +290,11 @@ pub fn validate_service_handle( return Err(HandleValidationError::EndsWithInvalidChar); } - for c in handle.chars() { - if !c.is_ascii_alphanumeric() && c != '-' { - return Err(HandleValidationError::InvalidCharacters); - } + if !handle + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-') + { + return Err(HandleValidationError::InvalidCharacters); } if crate::moderation::has_explicit_slur(handle) { @@ -330,10 +328,11 @@ pub fn is_valid_email(email: &str) -> bool { if local.contains("..") { return false; } - for c in local.chars() { - if !c.is_ascii_alphanumeric() && !EMAIL_LOCAL_SPECIAL_CHARS.contains(c) { - return false; - } + if !local + .chars() + .all(|c| c.is_ascii_alphanumeric() || EMAIL_LOCAL_SPECIAL_CHARS.contains(c)) + { + return false; } if domain.is_empty() || domain.len() > MAX_DOMAIN_LENGTH { return false; @@ -341,20 +340,13 @@ pub fn is_valid_email(email: &str) -> bool { if !domain.contains('.') { return false; } - for label in domain.split('.') { - if label.is_empty() || label.len() > MAX_DOMAIN_LABEL_LENGTH { - return false; - } - if label.starts_with('-') || label.ends_with('-') { - return false; - } - for c in label.chars() { - if !c.is_ascii_alphanumeric() && c != '-' { - return false; - } - } - } - true + domain.split('.').all(|label| { + !label.is_empty() + && label.len() <= MAX_DOMAIN_LABEL_LENGTH + && !label.starts_with('-') + && !label.ends_with('-') + && label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') + }) } #[cfg(test)] diff --git a/crates/tranquil-pds/src/auth/mod.rs b/crates/tranquil-pds/src/auth/mod.rs index e6637be..baa2b33 100644 --- a/crates/tranquil-pds/src/auth/mod.rs +++ b/crates/tranquil-pds/src/auth/mod.rs @@ -61,6 +61,7 @@ pub enum TokenValidationError { KeyDecryptionFailed, AuthenticationFailed, TokenExpired, + OAuthTokenExpired, } impl fmt::Display for TokenValidationError { @@ -70,7 +71,7 @@ impl fmt::Display for TokenValidationError { Self::AccountTakedown => write!(f, "AccountTakedown"), Self::KeyDecryptionFailed => write!(f, "KeyDecryptionFailed"), Self::AuthenticationFailed => write!(f, "AuthenticationFailed"), - Self::TokenExpired => write!(f, "ExpiredToken"), + Self::TokenExpired | Self::OAuthTokenExpired => write!(f, "ExpiredToken"), } } } @@ -497,7 +498,9 @@ pub async fn validate_token_with_dpop( controller_did: None, }) } - Err(crate::oauth::OAuthError::ExpiredToken(_)) => Err(TokenValidationError::TokenExpired), + Err(crate::oauth::OAuthError::ExpiredToken(_)) => { + Err(TokenValidationError::OAuthTokenExpired) + } Err(_) => Err(TokenValidationError::AuthenticationFailed), } } diff --git a/crates/tranquil-pds/src/util.rs b/crates/tranquil-pds/src/util.rs index e2cc3ac..3ea4828 100644 --- a/crates/tranquil-pds/src/util.rs +++ b/crates/tranquil-pds/src/util.rs @@ -106,26 +106,29 @@ pub async fn is_account_migrated(db: &PgPool, did: &str) -> Result, key: &str) -> Vec { query .map(|q| { - let mut values = Vec::new(); - for pair in q.split('&') { - if let Some((k, v)) = pair.split_once('=') - && k == key - && let Ok(decoded) = urlencoding::decode(v) - { - let decoded = decoded.into_owned(); + q.split('&') + .filter_map(|pair| { + pair.split_once('=') + .filter(|(k, _)| *k == key) + .and_then(|(_, v)| urlencoding::decode(v).ok()) + .map(|decoded| decoded.into_owned()) + }) + .flat_map(|decoded| { if decoded.contains(',') { - for part in decoded.split(',') { - let trimmed = part.trim(); - if !trimmed.is_empty() { - values.push(trimmed.to_string()); - } - } - } else if !decoded.is_empty() { - values.push(decoded); + decoded + .split(',') + .filter_map(|part| { + let trimmed = part.trim(); + (!trimmed.is_empty()).then(|| trimmed.to_string()) + }) + .collect::>() + } else if decoded.is_empty() { + vec![] + } else { + vec![decoded] } - } - } - values + }) + .collect() }) .unwrap_or_default() } diff --git a/crates/tranquil-pds/tests/common/mod.rs b/crates/tranquil-pds/tests/common/mod.rs index c5b18c6..bb16086 100644 --- a/crates/tranquil-pds/tests/common/mod.rs +++ b/crates/tranquil-pds/tests/common/mod.rs @@ -437,7 +437,7 @@ async fn setup_mock_plc_directory() -> String { async fn spawn_app(database_url: String) -> String { use tranquil_pds::rate_limit::RateLimiters; let pool = PgPoolOptions::new() - .max_connections(3) + .max_connections(10) .acquire_timeout(std::time::Duration::from_secs(30)) .connect(&database_url) .await diff --git a/crates/tranquil-pds/tests/helpers/mod.rs b/crates/tranquil-pds/tests/helpers/mod.rs index a487ea6..8ae6cfc 100644 --- a/crates/tranquil-pds/tests/helpers/mod.rs +++ b/crates/tranquil-pds/tests/helpers/mod.rs @@ -4,12 +4,16 @@ use serde_json::{Value, json}; pub use crate::common::*; +fn unique_id() -> String { + uuid::Uuid::new_v4().simple().to_string()[..12].to_string() +} + #[allow(dead_code)] pub async fn setup_new_user(handle_prefix: &str) -> (String, String) { let client = client(); - let ts = Utc::now().timestamp_millis(); - let handle = format!("{}-{}.test", handle_prefix, ts); - let email = format!("{}-{}@test.com", handle_prefix, ts); + let uid = unique_id(); + let handle = format!("{}-{}.test", handle_prefix, uid); + let email = format!("{}-{}@test.com", handle_prefix, uid); let password = "E2epass123!"; let create_account_payload = json!({ "handle": handle, @@ -51,7 +55,7 @@ pub async fn create_post( text: &str, ) -> (String, String) { let collection = "app.bsky.feed.post"; - let rkey = format!("e2e_social_{}", Utc::now().timestamp_millis()); + let rkey = format!("e2e_social_{}", unique_id()); let now = Utc::now().to_rfc3339(); let create_payload = json!({ "repo": did, @@ -95,7 +99,7 @@ pub async fn create_follow( followee_did: &str, ) -> (String, String) { let collection = "app.bsky.graph.follow"; - let rkey = format!("e2e_follow_{}", Utc::now().timestamp_millis()); + let rkey = format!("e2e_follow_{}", unique_id()); let now = Utc::now().to_rfc3339(); let create_payload = json!({ "repo": follower_did, @@ -140,7 +144,7 @@ pub async fn create_like( subject_cid: &str, ) -> (String, String) { let collection = "app.bsky.feed.like"; - let rkey = format!("e2e_like_{}", Utc::now().timestamp_millis()); + let rkey = format!("e2e_like_{}", unique_id()); let now = Utc::now().to_rfc3339(); let payload = json!({ "repo": liker_did, @@ -182,7 +186,7 @@ pub async fn create_repost( subject_cid: &str, ) -> (String, String) { let collection = "app.bsky.feed.repost"; - let rkey = format!("e2e_repost_{}", Utc::now().timestamp_millis()); + let rkey = format!("e2e_repost_{}", unique_id()); let now = Utc::now().to_rfc3339(); let payload = json!({ "repo": reposter_did, diff --git a/crates/tranquil-scopes/src/parser.rs b/crates/tranquil-scopes/src/parser.rs index a14f99a..c392f6e 100644 --- a/crates/tranquil-scopes/src/parser.rs +++ b/crates/tranquil-scopes/src/parser.rs @@ -55,18 +55,14 @@ impl BlobScope { if self.accept.is_empty() || self.accept.contains("*/*") { return true; } - for pattern in &self.accept { - if pattern == mime { - return true; - } - if let Some(prefix) = pattern.strip_suffix("/*") - && mime.starts_with(prefix) - && mime.chars().nth(prefix.len()) == Some('/') - { - return true; - } - } - false + self.accept.iter().any(|pattern| { + pattern == mime + || pattern + .strip_suffix("/*") + .is_some_and(|prefix| { + mime.starts_with(prefix) && mime.chars().nth(prefix.len()) == Some('/') + }) + }) } } @@ -170,19 +166,20 @@ pub fn parse_scope(scope: &str) -> ParsedScope { Some(rest.to_string()) }; - let mut actions = HashSet::new(); - if let Some(action_values) = params.get("action") { - for action_str in action_values { - if let Some(action) = RepoAction::parse_str(action_str) { - actions.insert(action); - } - } - } - if actions.is_empty() { - actions.insert(RepoAction::Create); - actions.insert(RepoAction::Update); - actions.insert(RepoAction::Delete); - } + let actions: HashSet = params + .get("action") + .map(|action_values| { + action_values + .iter() + .filter_map(|s| RepoAction::parse_str(s)) + .collect() + }) + .filter(|set: &HashSet| !set.is_empty()) + .unwrap_or_else(|| { + [RepoAction::Create, RepoAction::Update, RepoAction::Delete] + .into_iter() + .collect() + }); return ParsedScope::Repo(RepoScope { collection, @@ -191,19 +188,20 @@ pub fn parse_scope(scope: &str) -> ParsedScope { } if base == "repo" { - let mut actions = HashSet::new(); - if let Some(action_values) = params.get("action") { - for action_str in action_values { - if let Some(action) = RepoAction::parse_str(action_str) { - actions.insert(action); - } - } - } - if actions.is_empty() { - actions.insert(RepoAction::Create); - actions.insert(RepoAction::Update); - actions.insert(RepoAction::Delete); - } + let actions: HashSet = params + .get("action") + .map(|action_values| { + action_values + .iter() + .filter_map(|s| RepoAction::parse_str(s)) + .collect() + }) + .filter(|set: &HashSet| !set.is_empty()) + .unwrap_or_else(|| { + [RepoAction::Create, RepoAction::Update, RepoAction::Delete] + .into_iter() + .collect() + }); return ParsedScope::Repo(RepoScope { collection: None, actions, @@ -212,16 +210,17 @@ pub fn parse_scope(scope: &str) -> ParsedScope { if base.starts_with("blob") { let positional = base.strip_prefix("blob:").unwrap_or(""); - let mut accept = HashSet::new(); - - if !positional.is_empty() { - accept.insert(positional.to_string()); - } - if let Some(accept_values) = params.get("accept") { - for v in accept_values { - accept.insert(v.to_string()); - } - } + let accept: HashSet = std::iter::once(positional) + .filter(|s| !s.is_empty()) + .map(String::from) + .chain( + params + .get("accept") + .into_iter() + .flatten() + .map(String::clone), + ) + .collect(); return ParsedScope::Blob(BlobScope { accept }); } diff --git a/crates/tranquil-scopes/src/permissions.rs b/crates/tranquil-scopes/src/permissions.rs index dac66fb..cbd5054 100644 --- a/crates/tranquil-scopes/src/permissions.rs +++ b/crates/tranquil-scopes/src/permissions.rs @@ -113,34 +113,32 @@ impl ScopePermissions { return Ok(()); } - for repo_scope in self.find_repo_scopes() { - if !repo_scope.actions.contains(&action) { - continue; - } - - match &repo_scope.collection { - None => return Ok(()), - Some(coll) if coll == collection => return Ok(()), - Some(coll) if coll.ends_with(".*") => { - let prefix = coll.strip_suffix(".*").unwrap(); - if collection.starts_with(prefix) - && collection.chars().nth(prefix.len()) == Some('.') - { - return Ok(()); + let has_permission = self.find_repo_scopes().any(|repo_scope| { + repo_scope.actions.contains(&action) + && match &repo_scope.collection { + None => true, + Some(coll) if coll == collection => true, + Some(coll) if coll.ends_with(".*") => { + let prefix = coll.strip_suffix(".*").unwrap(); + collection.starts_with(prefix) + && collection.chars().nth(prefix.len()) == Some('.') } + _ => false, } - _ => {} - } - } + }); - Err(ScopeError::InsufficientScope { - required: format!("repo:{}?action={}", collection, action_str(action)), - message: format!( - "Insufficient scope to {} records in {}", - action_str(action), - collection - ), - }) + if has_permission { + Ok(()) + } else { + Err(ScopeError::InsufficientScope { + required: format!("repo:{}?action={}", collection, action_str(action)), + message: format!( + "Insufficient scope to {} records in {}", + action_str(action), + collection + ), + }) + } } pub fn assert_blob(&self, mime: &str) -> Result<(), ScopeError> { @@ -148,16 +146,14 @@ impl ScopePermissions { return Ok(()); } - for blob_scope in self.find_blob_scopes() { - if blob_scope.matches_mime(mime) { - return Ok(()); - } + if self.find_blob_scopes().any(|blob_scope| blob_scope.matches_mime(mime)) { + Ok(()) + } else { + Err(ScopeError::InsufficientScope { + required: format!("blob:{}", mime), + message: format!("Insufficient scope to upload blob with mime type {}", mime), + }) } - - Err(ScopeError::InsufficientScope { - required: format!("blob:{}", mime), - message: format!("Insufficient scope to upload blob with mime type {}", mime), - }) } pub fn assert_rpc(&self, aud: &str, lxm: &str) -> Result<(), ScopeError> { @@ -169,7 +165,7 @@ impl ScopePermissions { return Ok(()); } - for rpc_scope in self.find_rpc_scopes() { + let has_permission = self.find_rpc_scopes().any(|rpc_scope| { let lxm_matches = match &rpc_scope.lxm { None => true, Some(scope_lxm) if scope_lxm == lxm => true, @@ -186,15 +182,17 @@ impl ScopePermissions { Some(scope_aud) => scope_aud == aud, }; - if lxm_matches && aud_matches { - return Ok(()); - } - } + lxm_matches && aud_matches + }); - Err(ScopeError::InsufficientScope { - required: format!("rpc:{}?aud={}", lxm, aud), - message: format!("Insufficient scope to call {} on {}", lxm, aud), - }) + if has_permission { + Ok(()) + } else { + Err(ScopeError::InsufficientScope { + required: format!("rpc:{}?aud={}", lxm, aud), + message: format!("Insufficient scope to call {} on {}", lxm, aud), + }) + } } pub fn assert_account( @@ -211,27 +209,28 @@ impl ScopePermissions { return Ok(()); } - for account_scope in self.find_account_scopes() { - if account_scope.attr == attr && account_scope.action == action { - return Ok(()); - } - if account_scope.attr == attr && account_scope.action == AccountAction::Manage { - return Ok(()); - } - } + let has_permission = self.find_account_scopes().any(|account_scope| { + account_scope.attr == attr + && (account_scope.action == action + || account_scope.action == AccountAction::Manage) + }); - Err(ScopeError::InsufficientScope { - required: format!( - "account:{}?action={}", - attr_str(attr), - action_str_account(action) - ), - message: format!( - "Insufficient scope to {} account {}", - action_str_account(action), - attr_str(attr) - ), - }) + if has_permission { + Ok(()) + } else { + Err(ScopeError::InsufficientScope { + required: format!( + "account:{}?action={}", + attr_str(attr), + action_str_account(action) + ), + message: format!( + "Insufficient scope to {} account {}", + action_str_account(action), + attr_str(attr) + ), + }) + } } pub fn allows_email_read(&self) -> bool { @@ -264,22 +263,23 @@ impl ScopePermissions { return Ok(()); } - for identity_scope in self.find_identity_scopes() { - if identity_scope.attr == IdentityAttr::Wildcard { - return Ok(()); - } - if identity_scope.attr == attr { - return Ok(()); - } - } + let has_permission = self + .find_identity_scopes() + .any(|identity_scope| { + identity_scope.attr == IdentityAttr::Wildcard || identity_scope.attr == attr + }); - Err(ScopeError::InsufficientScope { - required: format!("identity:{}", identity_attr_str(attr)), - message: format!( - "Insufficient scope to modify identity {}", - identity_attr_str(attr) - ), - }) + if has_permission { + Ok(()) + } else { + Err(ScopeError::InsufficientScope { + required: format!("identity:{}", identity_attr_str(attr)), + message: format!( + "Insufficient scope to modify identity {}", + identity_attr_str(attr) + ), + }) + } } pub fn allows_identity(&self, attr: IdentityAttr) -> bool {