mirror of
https://tangled.org/tranquil.farm/tranquil-pds
synced 2026-02-08 21:30:08 +00:00
oauth error msg improvement, general code quality
52
.sqlx/query-2c8868a59ae63dc65d996cf21fc1bec0c2c86d5d5f17d1454440c6fcd8d4d27a.json
generated
Normal file
@@ -0,0 +1,52 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT ic.code, ic.available_uses, ic.disabled, ic.for_account, ic.created_at, u.did as created_by\n FROM invite_codes ic\n JOIN users u ON ic.created_by_user = u.id\n WHERE ic.created_by_user = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "code",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "available_uses",
"type_info": "Int4"
},
{
"ordinal": 2,
"name": "disabled",
"type_info": "Bool"
},
{
"ordinal": 3,
"name": "for_account",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 5,
"name": "created_by",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
false,
false,
true,
false,
false,
false
]
},
"hash": "2c8868a59ae63dc65d996cf21fc1bec0c2c86d5d5f17d1454440c6fcd8d4d27a"
}
28
.sqlx/query-46ea5ceff2a8f3f2b72c9c6a1bb69ce28efe8594fda026b6f9b298ef0597b40e.json
generated
Normal file
@@ -0,0 +1,28 @@
{
"db_name": "PostgreSQL",
"query": "SELECT id, did FROM users WHERE id = ANY($1)",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "did",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"UuidArray"
]
},
"nullable": [
false,
false
]
},
"hash": "46ea5ceff2a8f3f2b72c9c6a1bb69ce28efe8594fda026b6f9b298ef0597b40e"
}
@@ -1,22 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT code FROM invite_codes WHERE created_by_user = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "code",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
false
]
},
"hash": "5a98e015997942835800fcd326e69b4f54b9830d0490c4f8841f8435478c57d3"
}
@@ -1,28 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT u.did, icu.used_at\n FROM invite_code_uses icu\n JOIN users u ON icu.used_by_user = u.id\n WHERE icu.code = $1\n ORDER BY icu.used_at DESC\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "did",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "used_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false
]
},
"hash": "5d5442136932d4088873a935c41cb3a683c4771e4fb8c151b3fd5119fb6c1068"
}
@@ -1,14 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "UPDATE invite_codes SET disabled = TRUE WHERE code = $1",
"query": "UPDATE invite_codes SET disabled = TRUE WHERE code = ANY($1)",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Text"
"TextArray"
]
},
"nullable": []
},
"hash": "7b2d1d4ac06063e07a7c7a7d0fb434db08ce312eb2864405d7f96f4e985ed036"
"hash": "888f8724cfc2ed27391b661a5cfe423d28c77e1a368df7edc81708eb3038f600"
}
34
.sqlx/query-ae6695ae53fc5e5293f17ddf8cc4532d549d1ad8a9835da4a5c001eee89db076.json
generated
Normal file
@@ -0,0 +1,34 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT icu.code, u.did as used_by, icu.used_at\n FROM invite_code_uses icu\n JOIN users u ON icu.used_by_user = u.id\n WHERE icu.code = ANY($1)\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "code",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "used_by",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "used_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"TextArray"
]
},
"nullable": [
false,
false,
false
]
},
"hash": "ae6695ae53fc5e5293f17ddf8cc4532d549d1ad8a9835da4a5c001eee89db076"
}
34
.sqlx/query-ed1ccbaaed6e3f34982dc118ddd9bde7269879c0547ad43f30b78bfeeef5a920.json
generated
Normal file
@@ -0,0 +1,34 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT icu.code, u.did, icu.used_at\n FROM invite_code_uses icu\n JOIN users u ON icu.used_by_user = u.id\n WHERE icu.code = ANY($1)\n ORDER BY icu.used_at DESC\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "code",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "did",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "used_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"TextArray"
]
},
"nullable": [
false,
false,
false
]
},
"hash": "ed1ccbaaed6e3f34982dc118ddd9bde7269879c0547ad43f30b78bfeeef5a920"
}
@@ -1,14 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "UPDATE invite_codes SET disabled = TRUE WHERE created_by_user = $1",
"query": "UPDATE invite_codes SET disabled = TRUE WHERE created_by_user IN (SELECT id FROM users WHERE did = ANY($1))",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid"
"TextArray"
]
},
"nullable": []
},
"hash": "413c5b03501a399dca13f345fcae05770517091d73db93966853e944c68ee237"
"hash": "eec42a3a4b1440aa8dd580b5d0bbd1184b909f781d131aa2c69368ed021e87e4"
}
@@ -92,7 +92,7 @@ tower-layer = "0.3"
tracing = "0.1"
tracing-subscriber = "0.3"
urlencoding = "2.1"
uuid = { version = "1.19", features = ["v4", "v5", "fast-rng"] }
uuid = { version = "1.19", features = ["v4", "v5", "v7", "fast-rng"] }
webauthn-rs = { version = "0.5", features = ["danger-allow-state-serialisation", "danger-user-presence-only-security-keys"] }
webauthn-rs-proto = "0.5"
zip = { version = "7.0", default-features = false, features = ["deflate"] }
@@ -182,11 +182,10 @@ static STRINGS_FI: NotificationStrings = NotificationStrings {
};

pub fn format_message(template: &str, vars: &[(&str, &str)]) -> String {
let mut result = template.to_string();
for (key, value) in vars {
result = result.replace(&format!("{{{}}}", key), value);
}
result
vars.iter()
.fold(template.to_string(), |result, (key, value)| {
result.replace(&format!("{{{}}}", key), value)
})
}

#[cfg(test)]
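Note: a minimal usage sketch of the format_message helper above (hypothetical call site, not code from this commit):

    // Each "{key}" placeholder is replaced with its value by folding over the vars slice.
    let msg = format_message("Hello {name}, you have {count} invites", &[("name", "Ada"), ("count", "3")]);
    assert_eq!(msg, "Hello Ada, you have 3 invites");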
@@ -568,12 +568,21 @@ fn verify_es256(
.get("y")
.and_then(|v| v.as_str())
.ok_or_else(|| OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()))?;
let x_bytes = URL_SAFE_NO_PAD
let x_decoded = URL_SAFE_NO_PAD
.decode(x)
.map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
let y_bytes = URL_SAFE_NO_PAD
let y_decoded = URL_SAFE_NO_PAD
.decode(y)
.map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
if x_decoded.len() > 32 || y_decoded.len() > 32 {
return Err(OAuthError::InvalidClient(
"EC coordinate too long".to_string(),
));
}
let mut x_bytes = [0u8; 32];
let mut y_bytes = [0u8; 32];
x_bytes[32 - x_decoded.len()..].copy_from_slice(&x_decoded);
y_bytes[32 - y_decoded.len()..].copy_from_slice(&y_decoded);
let mut point_bytes = vec![0x04];
point_bytes.extend_from_slice(&x_bytes);
point_bytes.extend_from_slice(&y_bytes);
@@ -604,12 +613,21 @@ fn verify_es384(
.get("y")
.and_then(|v| v.as_str())
.ok_or_else(|| OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()))?;
let x_bytes = URL_SAFE_NO_PAD
let x_decoded = URL_SAFE_NO_PAD
.decode(x)
.map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
let y_bytes = URL_SAFE_NO_PAD
let y_decoded = URL_SAFE_NO_PAD
.decode(y)
.map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
if x_decoded.len() > 48 || y_decoded.len() > 48 {
return Err(OAuthError::InvalidClient(
"EC coordinate too long".to_string(),
));
}
let mut x_bytes = [0u8; 48];
let mut y_bytes = [0u8; 48];
x_bytes[48 - x_decoded.len()..].copy_from_slice(&x_decoded);
y_bytes[48 - y_decoded.len()..].copy_from_slice(&y_decoded);
let mut point_bytes = vec![0x04];
point_bytes.extend_from_slice(&x_bytes);
point_bytes.extend_from_slice(&y_bytes);
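Note: both verifiers now left-pad the base64url-decoded JWK EC coordinates to the curve's fixed field width (32 bytes for P-256, 48 for P-384) and reject oversized values instead of assuming full-length input. A standalone sketch of the same padding idea (illustrative helper, not the repository's code):

    // Left-pad a big-endian coordinate with zeros to `width` bytes; reject longer input.
    fn left_pad_coordinate(decoded: &[u8], width: usize) -> Option<Vec<u8>> {
        if decoded.len() > width {
            return None;
        }
        let mut out = vec![0u8; width];
        out[width - decoded.len()..].copy_from_slice(decoded);
        Some(out)
    }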
@@ -218,25 +218,30 @@ fn verify_es256(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), O
crv
)));
}
let x_bytes = URL_SAFE_NO_PAD
let x_decoded = URL_SAFE_NO_PAD
.decode(
jwk.x
.as_ref()
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing x coordinate".to_string()))?,
)
.map_err(|_| OAuthError::InvalidDpopProof("Invalid x encoding".to_string()))?;
let y_bytes = URL_SAFE_NO_PAD
let y_decoded = URL_SAFE_NO_PAD
.decode(
jwk.y
.as_ref()
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing y coordinate".to_string()))?,
)
.map_err(|_| OAuthError::InvalidDpopProof("Invalid y encoding".to_string()))?;
let point = EncodedPoint::from_affine_coordinates(
x_bytes.as_slice().into(),
y_bytes.as_slice().into(),
false,
);
let mut x_bytes = [0u8; 32];
let mut y_bytes = [0u8; 32];
if x_decoded.len() > 32 || y_decoded.len() > 32 {
return Err(OAuthError::InvalidDpopProof(
"EC coordinate too long".to_string(),
));
}
x_bytes[32 - x_decoded.len()..].copy_from_slice(&x_decoded);
y_bytes[32 - y_decoded.len()..].copy_from_slice(&y_decoded);
let point = EncodedPoint::from_affine_coordinates((&x_bytes).into(), (&y_bytes).into(), false);
let affine_opt: Option<AffinePoint> = AffinePoint::from_encoded_point(&point).into();
let affine =
affine_opt.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?;
@@ -264,25 +269,30 @@ fn verify_es384(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), O
crv
)));
}
let x_bytes = URL_SAFE_NO_PAD
let x_decoded = URL_SAFE_NO_PAD
.decode(
jwk.x
.as_ref()
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing x coordinate".to_string()))?,
)
.map_err(|_| OAuthError::InvalidDpopProof("Invalid x encoding".to_string()))?;
let y_bytes = URL_SAFE_NO_PAD
let y_decoded = URL_SAFE_NO_PAD
.decode(
jwk.y
.as_ref()
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing y coordinate".to_string()))?,
)
.map_err(|_| OAuthError::InvalidDpopProof("Invalid y encoding".to_string()))?;
let point = EncodedPoint::from_affine_coordinates(
x_bytes.as_slice().into(),
y_bytes.as_slice().into(),
false,
);
let mut x_bytes = [0u8; 48];
let mut y_bytes = [0u8; 48];
if x_decoded.len() > 48 || y_decoded.len() > 48 {
return Err(OAuthError::InvalidDpopProof(
"EC coordinate too long".to_string(),
));
}
x_bytes[48 - x_decoded.len()..].copy_from_slice(&x_decoded);
y_bytes[48 - y_decoded.len()..].copy_from_slice(&y_decoded);
let point = EncodedPoint::from_affine_coordinates((&x_bytes).into(), (&y_bytes).into(), false);
let affine_opt: Option<AffinePoint> = AffinePoint::from_encoded_point(&point).into();
let affine =
affine_opt.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?;
@@ -398,4 +408,52 @@ mod tests {
let thumbprint = compute_jwk_thumbprint(&jwk).unwrap();
assert!(!thumbprint.is_empty());
}

#[test]
fn test_es256_short_coordinate_no_panic() {
let short_31_bytes = vec![0x42u8; 31];
let short_30_bytes = vec![0x42u8; 30];
let x_b64 = URL_SAFE_NO_PAD.encode(&short_31_bytes);
let y_b64 = URL_SAFE_NO_PAD.encode(&short_30_bytes);
let jwk = DPoPJwk {
kty: "EC".to_string(),
crv: Some("P-256".to_string()),
x: Some(x_b64),
y: Some(y_b64),
};
let result = verify_es256(&jwk, b"test", &[0u8; 64]);
assert!(result.is_err(), "Invalid coordinates should return error, not panic");
}

#[test]
fn test_es256_valid_key_with_trimmed_coordinates() {
use p256::ecdsa::{SigningKey, signature::Signer};
use p256::elliptic_curve::rand_core::OsRng;

let signing_key = SigningKey::random(&mut OsRng);
let verifying_key = signing_key.verifying_key();
let point = verifying_key.to_encoded_point(false);
let x_bytes = point.x().unwrap();
let y_bytes = point.y().unwrap();
let x_trimmed: Vec<u8> = x_bytes.iter().copied().skip_while(|&b| b == 0).collect();
let y_trimmed: Vec<u8> = y_bytes.iter().copied().skip_while(|&b| b == 0).collect();
let x_b64 = URL_SAFE_NO_PAD.encode(&x_trimmed);
let y_b64 = URL_SAFE_NO_PAD.encode(&y_trimmed);
let jwk = DPoPJwk {
kty: "EC".to_string(),
crv: Some("P-256".to_string()),
x: Some(x_b64),
y: Some(y_b64),
};
let message = b"test message for signature verification";
let signature: p256::ecdsa::Signature = signing_key.sign(message);
let result = verify_es256(&jwk, message, signature.to_bytes().as_slice());
assert!(
result.is_ok(),
"Should verify signature with trimmed coordinates (x={}, y={}): {:?}",
x_trimmed.len(),
y_trimmed.len(),
result
);
}
}
@@ -130,24 +130,62 @@ async fn get_invites_for_user(
db: &sqlx::PgPool,
user_id: uuid::Uuid,
) -> Option<Vec<InviteCodeInfo>> {
let codes = sqlx::query_scalar!(
let invite_codes = sqlx::query!(
r#"
SELECT code FROM invite_codes WHERE created_by_user = $1
SELECT ic.code, ic.available_uses, ic.disabled, ic.for_account, ic.created_at, u.did as created_by
FROM invite_codes ic
JOIN users u ON ic.created_by_user = u.id
WHERE ic.created_by_user = $1
"#,
user_id
)
.fetch_all(db)
.await
.ok()?;
if codes.is_empty() {

if invite_codes.is_empty() {
return None;
}
let mut invites = Vec::new();
for code in codes {
if let Some(info) = get_invite_code_info(db, &code).await {
invites.push(info);
}
}

let code_strings: Vec<String> = invite_codes.iter().map(|ic| ic.code.clone()).collect();
let mut uses_by_code: std::collections::HashMap<String, Vec<InviteCodeUseInfo>> =
std::collections::HashMap::new();
sqlx::query!(
r#"
SELECT icu.code, u.did as used_by, icu.used_at
FROM invite_code_uses icu
JOIN users u ON icu.used_by_user = u.id
WHERE icu.code = ANY($1)
"#,
&code_strings
)
.fetch_all(db)
.await
.ok()?
.into_iter()
.for_each(|r| {
uses_by_code
.entry(r.code)
.or_default()
.push(InviteCodeUseInfo {
used_by: r.used_by.into(),
used_at: r.used_at.to_rfc3339(),
});
});

let invites: Vec<InviteCodeInfo> = invite_codes
.into_iter()
.map(|ic| InviteCodeInfo {
code: ic.code.clone(),
available: ic.available_uses,
disabled: ic.disabled.unwrap_or(false),
for_account: ic.for_account.into(),
created_by: ic.created_by.into(),
created_at: ic.created_at.to_rfc3339(),
uses: uses_by_code.get(&ic.code).cloned().unwrap_or_default(),
})
.collect();

if invites.is_empty() {
None
} else {
@@ -276,61 +314,62 @@ pub async fn get_account_infos(
.map(|r| (r.used_by_user, r.code))
.collect();

let mut uses_by_code: std::collections::HashMap<String, Vec<InviteCodeUseInfo>> =
std::collections::HashMap::new();
for u in all_invite_uses {
uses_by_code
.entry(u.code.clone())
.or_default()
.push(InviteCodeUseInfo {
used_by: u.used_by.into(),
used_at: u.used_at.to_rfc3339(),
let uses_by_code: std::collections::HashMap<String, Vec<InviteCodeUseInfo>> =
all_invite_uses
.into_iter()
.fold(std::collections::HashMap::new(), |mut acc, u| {
acc.entry(u.code.clone()).or_default().push(InviteCodeUseInfo {
used_by: u.used_by.into(),
used_at: u.used_at.to_rfc3339(),
});
acc
});
}

let mut codes_by_user: std::collections::HashMap<uuid::Uuid, Vec<InviteCodeInfo>> =
std::collections::HashMap::new();
let mut code_info_map: std::collections::HashMap<String, InviteCodeInfo> =
std::collections::HashMap::new();
for ic in all_invite_codes {
let info = InviteCodeInfo {
code: ic.code.clone(),
available: ic.available_uses,
disabled: ic.disabled.unwrap_or(false),
for_account: ic.for_account.into(),
created_by: ic.created_by.into(),
created_at: ic.created_at.to_rfc3339(),
uses: uses_by_code.get(&ic.code).cloned().unwrap_or_default(),
};
code_info_map.insert(ic.code.clone(), info.clone());
codes_by_user
.entry(ic.created_by_user)
.or_default()
.push(info);
}
let (codes_by_user, code_info_map): (
std::collections::HashMap<uuid::Uuid, Vec<InviteCodeInfo>>,
std::collections::HashMap<String, InviteCodeInfo>,
) = all_invite_codes.into_iter().fold(
(std::collections::HashMap::new(), std::collections::HashMap::new()),
|(mut by_user, mut by_code), ic| {
let info = InviteCodeInfo {
code: ic.code.clone(),
available: ic.available_uses,
disabled: ic.disabled.unwrap_or(false),
for_account: ic.for_account.into(),
created_by: ic.created_by.into(),
created_at: ic.created_at.to_rfc3339(),
uses: uses_by_code.get(&ic.code).cloned().unwrap_or_default(),
};
by_code.insert(ic.code.clone(), info.clone());
by_user.entry(ic.created_by_user).or_default().push(info);
(by_user, by_code)
},
);

let mut infos = Vec::with_capacity(users.len());
for row in users {
let invited_by = invited_by_map
.get(&row.id)
.and_then(|code| code_info_map.get(code).cloned());
let invites = codes_by_user.get(&row.id).cloned();
infos.push(AccountInfo {
did: row.did.into(),
handle: row.handle.into(),
email: row.email,
indexed_at: row.created_at.to_rfc3339(),
invite_note: None,
invites_disabled: row.invites_disabled.unwrap_or(false),
email_confirmed_at: if row.email_verified {
Some(row.created_at.to_rfc3339())
} else {
None
},
deactivated_at: row.deactivated_at.map(|dt| dt.to_rfc3339()),
invited_by,
invites,
});
}
let infos: Vec<AccountInfo> = users
.into_iter()
.map(|row| {
let invited_by = invited_by_map
.get(&row.id)
.and_then(|code| code_info_map.get(code).cloned());
let invites = codes_by_user.get(&row.id).cloned();
AccountInfo {
did: row.did.into(),
handle: row.handle.into(),
email: row.email,
indexed_at: row.created_at.to_rfc3339(),
invite_note: None,
invites_disabled: row.invites_disabled.unwrap_or(false),
email_confirmed_at: if row.email_verified {
Some(row.created_at.to_rfc3339())
} else {
None
},
deactivated_at: row.deactivated_at.map(|dt| dt.to_rfc3339()),
invited_by,
invites,
}
})
.collect();
(StatusCode::OK, Json(GetAccountInfosOutput { infos })).into_response()
}
@@ -48,32 +48,19 @@ pub async fn get_server_config(
.fetch_all(&state.db)
.await?;

let mut server_name = "Tranquil PDS".to_string();
let mut primary_color = None;
let mut primary_color_dark = None;
let mut secondary_color = None;
let mut secondary_color_dark = None;
let mut logo_cid = None;

for (key, value) in rows {
match key.as_str() {
"server_name" => server_name = value,
"primary_color" => primary_color = Some(value),
"primary_color_dark" => primary_color_dark = Some(value),
"secondary_color" => secondary_color = Some(value),
"secondary_color_dark" => secondary_color_dark = Some(value),
"logo_cid" => logo_cid = Some(value),
_ => {}
}
}
let config_map: std::collections::HashMap<String, String> =
rows.into_iter().collect();

Ok(Json(ServerConfigResponse {
server_name,
primary_color,
primary_color_dark,
secondary_color,
secondary_color_dark,
logo_cid,
server_name: config_map
.get("server_name")
.cloned()
.unwrap_or_else(|| "Tranquil PDS".to_string()),
primary_color: config_map.get("primary_color").cloned(),
primary_color_dark: config_map.get("primary_color_dark").cloned(),
secondary_color: config_map.get("secondary_color").cloned(),
secondary_color_dark: config_map.get("secondary_color_dark").cloned(),
logo_cid: config_map.get("logo_cid").cloned(),
}))
}
@@ -24,29 +24,20 @@ pub async fn disable_invite_codes(
Json(input): Json<DisableInviteCodesInput>,
) -> Response {
if let Some(codes) = &input.codes {
for code in codes {
let _ = sqlx::query!(
"UPDATE invite_codes SET disabled = TRUE WHERE code = $1",
code
)
.execute(&state.db)
.await;
}
let _ = sqlx::query!(
"UPDATE invite_codes SET disabled = TRUE WHERE code = ANY($1)",
codes as &[String]
)
.execute(&state.db)
.await;
}
if let Some(accounts) = &input.accounts {
for account in accounts {
let user = sqlx::query!("SELECT id FROM users WHERE did = $1", account)
.fetch_optional(&state.db)
.await;
if let Ok(Some(user_row)) = user {
let _ = sqlx::query!(
"UPDATE invite_codes SET disabled = TRUE WHERE created_by_user = $1",
user_row.id
)
.execute(&state.db)
.await;
}
}
let _ = sqlx::query!(
"UPDATE invite_codes SET disabled = TRUE WHERE created_by_user IN (SELECT id FROM users WHERE did = ANY($1))",
accounts as &[String]
)
.execute(&state.db)
.await;
}
EmptyResponse::ok().into_response()
}
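Note: the per-row loops above are replaced with single array-parameter statements; `WHERE code = ANY($1)` matches every element of the bound array in one round trip. A minimal sketch of the same pattern (illustrative bindings, not this handler's exact code):

    // Disable several invite codes with one UPDATE instead of one query per code.
    let codes: Vec<String> = vec!["code-a".into(), "code-b".into()];
    sqlx::query("UPDATE invite_codes SET disabled = TRUE WHERE code = ANY($1)")
        .bind(&codes)
        .execute(&pool)
        .await?;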
@@ -70,7 +61,7 @@ pub struct InviteCodeInfo {
pub uses: Vec<InviteCodeUseInfo>,
}

#[derive(Serialize)]
#[derive(Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct InviteCodeUseInfo {
pub used_by: String,
@@ -149,47 +140,71 @@ pub async fn get_invite_codes(
return ApiError::InternalError(None).into_response();
}
};
let mut codes = Vec::new();
for (code, available_uses, disabled, created_by_user, created_at) in &codes_rows {
let creator_did =
sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", created_by_user)
.fetch_optional(&state.db)
.await
.ok()
.flatten()
.unwrap_or_else(|| "unknown".to_string());
let uses_result = sqlx::query!(

let user_ids: Vec<uuid::Uuid> = codes_rows.iter().map(|(_, _, _, uid, _)| *uid).collect();
let code_strings: Vec<String> = codes_rows.iter().map(|(c, _, _, _, _)| c.clone()).collect();

let mut creator_dids: std::collections::HashMap<uuid::Uuid, String> =
std::collections::HashMap::new();
sqlx::query!(
"SELECT id, did FROM users WHERE id = ANY($1)",
&user_ids
)
.fetch_all(&state.db)
.await
.unwrap_or_default()
.into_iter()
.for_each(|r| {
creator_dids.insert(r.id, r.did);
});

let mut uses_by_code: std::collections::HashMap<String, Vec<InviteCodeUseInfo>> =
std::collections::HashMap::new();
if !code_strings.is_empty() {
sqlx::query!(
r#"
SELECT u.did, icu.used_at
SELECT icu.code, u.did, icu.used_at
FROM invite_code_uses icu
JOIN users u ON icu.used_by_user = u.id
WHERE icu.code = $1
WHERE icu.code = ANY($1)
ORDER BY icu.used_at DESC
"#,
code
&code_strings
)
.fetch_all(&state.db)
.await;
let uses = match uses_result {
Ok(use_rows) => use_rows
.iter()
.map(|u| InviteCodeUseInfo {
used_by: u.did.clone(),
used_at: u.used_at.to_rfc3339(),
})
.collect(),
Err(_) => Vec::new(),
};
codes.push(InviteCodeInfo {
code: code.clone(),
available: *available_uses,
disabled: disabled.unwrap_or(false),
for_account: creator_did.clone(),
created_by: creator_did,
created_at: created_at.to_rfc3339(),
uses,
.await
.unwrap_or_default()
.into_iter()
.for_each(|r| {
uses_by_code
.entry(r.code)
.or_default()
.push(InviteCodeUseInfo {
used_by: r.did,
used_at: r.used_at.to_rfc3339(),
});
});
}

let codes: Vec<InviteCodeInfo> = codes_rows
.iter()
.map(|(code, available_uses, disabled, created_by_user, created_at)| {
let creator_did = creator_dids
.get(created_by_user)
.cloned()
.unwrap_or_else(|| "unknown".to_string());
InviteCodeInfo {
code: code.clone(),
available: *available_uses,
disabled: disabled.unwrap_or(false),
for_account: creator_did.clone(),
created_by: creator_did,
created_at: created_at.to_rfc3339(),
uses: uses_by_code.get(code).cloned().unwrap_or_default(),
}
})
.collect();

let next_cursor = if codes_rows.len() == limit as usize {
codes_rows.last().map(|(code, _, _, _, _)| code.clone())
} else {
@@ -22,6 +22,7 @@ pub enum ApiError {
InvalidRequest(String),
InvalidToken(Option<String>),
ExpiredToken(Option<String>),
OAuthExpiredToken(Option<String>),
TokenRequired,
AccountDeactivated,
AccountTakedown,
@@ -127,7 +128,8 @@ impl ApiError {
| Self::InvalidCode(_)
| Self::InvalidPassword(_)
| Self::InvalidToken(_)
| Self::PasskeyCounterAnomaly => StatusCode::UNAUTHORIZED,
| Self::PasskeyCounterAnomaly
| Self::OAuthExpiredToken(_) => StatusCode::UNAUTHORIZED,
Self::ExpiredToken(_) => StatusCode::BAD_REQUEST,
Self::Forbidden
| Self::AdminRequired
@@ -216,7 +218,7 @@ impl ApiError {
Self::AuthenticationRequired => Cow::Borrowed("AuthenticationRequired"),
Self::AuthenticationFailed(_) => Cow::Borrowed("AuthenticationFailed"),
Self::InvalidToken(_) => Cow::Borrowed("InvalidToken"),
Self::ExpiredToken(_) => Cow::Borrowed("ExpiredToken"),
Self::ExpiredToken(_) | Self::OAuthExpiredToken(_) => Cow::Borrowed("ExpiredToken"),
Self::TokenRequired => Cow::Borrowed("TokenRequired"),
Self::AccountDeactivated => Cow::Borrowed("AccountDeactivated"),
Self::AccountTakedown => Cow::Borrowed("AccountTakedown"),
@@ -298,6 +300,7 @@ impl ApiError {
| Self::AuthenticationFailed(msg)
| Self::InvalidToken(msg)
| Self::ExpiredToken(msg)
| Self::OAuthExpiredToken(msg)
| Self::RepoNotFound(msg)
| Self::BlobNotFound(msg)
| Self::InvalidHandle(msg)
@@ -428,13 +431,24 @@ impl IntoResponse for ApiError {
message: self.message(),
};
let mut response = (self.status_code(), Json(body)).into_response();
if matches!(self, Self::ExpiredToken(_)) {
response.headers_mut().insert(
"WWW-Authenticate",
"Bearer error=\"invalid_token\", error_description=\"Token has expired\""
.parse()
.unwrap(),
);
match &self {
Self::ExpiredToken(_) => {
response.headers_mut().insert(
"WWW-Authenticate",
"Bearer error=\"invalid_token\", error_description=\"Token has expired\""
.parse()
.unwrap(),
);
}
Self::OAuthExpiredToken(_) => {
response.headers_mut().insert(
"WWW-Authenticate",
"DPoP error=\"invalid_token\", error_description=\"Token has expired\""
.parse()
.unwrap(),
);
}
_ => {}
}
response
}
@@ -457,6 +471,9 @@ impl From<crate::auth::TokenValidationError> for ApiError {
Self::AuthenticationFailed(None)
}
crate::auth::TokenValidationError::TokenExpired => Self::ExpiredToken(None),
crate::auth::TokenValidationError::OAuthTokenExpired => {
Self::OAuthExpiredToken(Some("Token has expired".to_string()))
}
}
}
}
@@ -211,7 +211,7 @@ async fn create_report_locally(
}

let created_at = chrono::Utc::now();
let report_id = created_at.timestamp_millis();
let report_id = (uuid::Uuid::now_v7().as_u128() & 0x7FFF_FFFF_FFFF_FFFF) as i64;
let subject_json = json!(input.subject);

let insert = sqlx::query!(
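Note: the report id now derives from a UUIDv7 instead of a millisecond timestamp; masking with 0x7FFF_FFFF_FFFF_FFFF keeps only the low 63 bits, so the cast to i64 cannot go negative. A tiny illustration of that property (standalone, outside the handler):

    let id = (uuid::Uuid::now_v7().as_u128() & 0x7FFF_FFFF_FFFF_FFFF) as i64;
    assert!(id >= 0); // the mask clears what would become the i64 sign bit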
@@ -268,21 +268,8 @@ async fn proxy_handler(
}
Err(e) => {
warn!("Token validation failed: {:?}", e);
if matches!(e, crate::auth::TokenValidationError::TokenExpired) && extracted.is_dpop
{
let www_auth =
"DPoP error=\"invalid_token\", error_description=\"Token has expired\"";
let mut response =
ApiError::ExpiredToken(Some("Token has expired".into())).into_response();
*response.status_mut() = axum::http::StatusCode::UNAUTHORIZED;
response
.headers_mut()
.insert("WWW-Authenticate", www_auth.parse().unwrap());
let nonce = crate::oauth::verify::generate_dpop_nonce();
response
.headers_mut()
.insert("DPoP-Nonce", nonce.parse().unwrap());
return response;
if matches!(e, crate::auth::TokenValidationError::OAuthTokenExpired) {
return ApiError::from(e).into_response();
}
}
}
@@ -291,11 +278,12 @@ async fn proxy_handler(
if let Some(val) = auth_header_val {
request_builder = request_builder.header("Authorization", val);
}
for header_name in crate::api::proxy_client::HEADERS_TO_FORWARD {
if let Some(val) = headers.get(*header_name) {
request_builder = request_builder.header(*header_name, val);
}
}
request_builder = crate::api::proxy_client::HEADERS_TO_FORWARD
.iter()
.filter_map(|name| headers.get(*name).map(|val| (*name, val)))
.fold(request_builder, |builder, (name, val)| {
builder.header(name, val)
});
if !body.is_empty() {
request_builder = request_builder.body(body);
}
@@ -313,11 +301,12 @@ async fn proxy_handler(
}
};
let mut response_builder = Response::builder().status(status);
for header_name in crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD {
if let Some(val) = headers.get(*header_name) {
response_builder = response_builder.header(*header_name, val);
}
}
response_builder = crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD
.iter()
.filter_map(|name| headers.get(*name).map(|val| (*name, val)))
.fold(response_builder, |builder, (name, val)| {
builder.header(name, val)
});
match response_builder.body(axum::body::Body::from(body)) {
Ok(r) => r,
Err(e) => {
@@ -88,15 +88,13 @@ pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> {
Ok(addrs) => addrs.collect(),
Err(_) => return Err(SsrfError::DnsResolutionFailed(host.to_string())),
};
for addr in &socket_addrs {
if !is_unicast_ip(&addr.ip()) {
warn!(
"DNS resolution for {} returned non-unicast IP: {}",
host,
addr.ip()
);
return Err(SsrfError::NonUnicastIp(addr.ip().to_string()));
}
if let Some(addr) = socket_addrs.iter().find(|addr| !is_unicast_ip(&addr.ip())) {
warn!(
"DNS resolution for {} returned non-unicast IP: {}",
host,
addr.ip()
);
return Err(SsrfError::NonUnicastIp(addr.ip().to_string()));
}
Ok(())
}
@@ -82,19 +82,7 @@ pub async fn prepare_repo_write(
.await
.map_err(|e| {
tracing::warn!(error = ?e, is_dpop = extracted.is_dpop, "Token validation failed in prepare_repo_write");
let mut response = ApiError::from(e).into_response();
if matches!(e, crate::auth::TokenValidationError::TokenExpired) && extracted.is_dpop {
*response.status_mut() = axum::http::StatusCode::UNAUTHORIZED;
let www_auth =
"DPoP error=\"invalid_token\", error_description=\"Token has expired\"";
response.headers_mut().insert(
"WWW-Authenticate",
www_auth.parse().unwrap(),
);
let nonce = crate::oauth::verify::generate_dpop_nonce();
response.headers_mut().insert("DPoP-Nonce", nonce.parse().unwrap());
}
response
ApiError::from(e).into_response()
})?;
if repo.as_str() != auth_user.did.as_str() {
return Err(
@@ -181,10 +181,11 @@ fn validate_email_detailed(email: &str) -> Result<(), EmailValidationError> {
if local.starts_with('.') || local.ends_with('.') || local.contains("..") {
return Err(EmailValidationError::InvalidLocalPart);
}
for c in local.chars() {
if !c.is_ascii_alphanumeric() && !EMAIL_LOCAL_SPECIAL_CHARS.contains(c) {
return Err(EmailValidationError::InvalidLocalPart);
}
if !local
.chars()
.all(|c| c.is_ascii_alphanumeric() || EMAIL_LOCAL_SPECIAL_CHARS.contains(c))
{
return Err(EmailValidationError::InvalidLocalPart);
}
if domain.is_empty() {
return Err(EmailValidationError::EmptyDomain);
@@ -195,18 +196,14 @@ fn validate_email_detailed(email: &str) -> Result<(), EmailValidationError> {
if !domain.contains('.') {
return Err(EmailValidationError::MissingDomainDot);
}
for label in domain.split('.') {
if label.is_empty() || label.len() > MAX_DOMAIN_LABEL_LENGTH {
return Err(EmailValidationError::InvalidDomainLabel);
}
if label.starts_with('-') || label.ends_with('-') {
return Err(EmailValidationError::InvalidDomainLabel);
}
for c in label.chars() {
if !c.is_ascii_alphanumeric() && c != '-' {
return Err(EmailValidationError::InvalidDomainLabel);
}
}
if !domain.split('.').all(|label| {
!label.is_empty()
&& label.len() <= MAX_DOMAIN_LABEL_LENGTH
&& !label.starts_with('-')
&& !label.ends_with('-')
&& label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-')
}) {
return Err(EmailValidationError::InvalidDomainLabel);
}
Ok(())
}
@@ -293,10 +290,11 @@ pub fn validate_service_handle(
return Err(HandleValidationError::EndsWithInvalidChar);
}

for c in handle.chars() {
if !c.is_ascii_alphanumeric() && c != '-' {
return Err(HandleValidationError::InvalidCharacters);
}
if !handle
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '-')
{
return Err(HandleValidationError::InvalidCharacters);
}

if crate::moderation::has_explicit_slur(handle) {
@@ -330,10 +328,11 @@ pub fn is_valid_email(email: &str) -> bool {
if local.contains("..") {
return false;
}
for c in local.chars() {
if !c.is_ascii_alphanumeric() && !EMAIL_LOCAL_SPECIAL_CHARS.contains(c) {
return false;
}
if !local
.chars()
.all(|c| c.is_ascii_alphanumeric() || EMAIL_LOCAL_SPECIAL_CHARS.contains(c))
{
return false;
}
if domain.is_empty() || domain.len() > MAX_DOMAIN_LENGTH {
return false;
@@ -341,20 +340,13 @@ pub fn is_valid_email(email: &str) -> bool {
if !domain.contains('.') {
return false;
}
for label in domain.split('.') {
if label.is_empty() || label.len() > MAX_DOMAIN_LABEL_LENGTH {
return false;
}
if label.starts_with('-') || label.ends_with('-') {
return false;
}
for c in label.chars() {
if !c.is_ascii_alphanumeric() && c != '-' {
return false;
}
}
}
true
domain.split('.').all(|label| {
!label.is_empty()
&& label.len() <= MAX_DOMAIN_LABEL_LENGTH
&& !label.starts_with('-')
&& !label.ends_with('-')
&& label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-')
})
}

#[cfg(test)]
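Note: illustrative checks implied by the domain-label rules above (hypothetical assertions, not copied from the test suite):

    assert!(is_valid_email("user@example.com"));
    assert!(!is_valid_email("user@example"));       // domain must contain a dot
    assert!(!is_valid_email("user@-example.com"));  // labels may not start or end with '-'
    assert!(!is_valid_email("user@exa_mple.com"));  // labels allow only alphanumerics and '-'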
@@ -61,6 +61,7 @@ pub enum TokenValidationError {
KeyDecryptionFailed,
AuthenticationFailed,
TokenExpired,
OAuthTokenExpired,
}

impl fmt::Display for TokenValidationError {
@@ -70,7 +71,7 @@ impl fmt::Display for TokenValidationError {
Self::AccountTakedown => write!(f, "AccountTakedown"),
Self::KeyDecryptionFailed => write!(f, "KeyDecryptionFailed"),
Self::AuthenticationFailed => write!(f, "AuthenticationFailed"),
Self::TokenExpired => write!(f, "ExpiredToken"),
Self::TokenExpired | Self::OAuthTokenExpired => write!(f, "ExpiredToken"),
}
}
}
@@ -497,7 +498,9 @@ pub async fn validate_token_with_dpop(
controller_did: None,
})
}
Err(crate::oauth::OAuthError::ExpiredToken(_)) => Err(TokenValidationError::TokenExpired),
Err(crate::oauth::OAuthError::ExpiredToken(_)) => {
Err(TokenValidationError::OAuthTokenExpired)
}
Err(_) => Err(TokenValidationError::AuthenticationFailed),
}
}
@@ -106,26 +106,29 @@ pub async fn is_account_migrated(db: &PgPool, did: &str) -> Result<bool, sqlx::E
pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> {
query
.map(|q| {
let mut values = Vec::new();
for pair in q.split('&') {
if let Some((k, v)) = pair.split_once('=')
&& k == key
&& let Ok(decoded) = urlencoding::decode(v)
{
let decoded = decoded.into_owned();
q.split('&')
.filter_map(|pair| {
pair.split_once('=')
.filter(|(k, _)| *k == key)
.and_then(|(_, v)| urlencoding::decode(v).ok())
.map(|decoded| decoded.into_owned())
})
.flat_map(|decoded| {
if decoded.contains(',') {
for part in decoded.split(',') {
let trimmed = part.trim();
if !trimmed.is_empty() {
values.push(trimmed.to_string());
}
}
} else if !decoded.is_empty() {
values.push(decoded);
decoded
.split(',')
.filter_map(|part| {
let trimmed = part.trim();
(!trimmed.is_empty()).then(|| trimmed.to_string())
})
.collect::<Vec<_>>()
} else if decoded.is_empty() {
vec![]
} else {
vec![decoded]
}
}
}
values
})
.collect()
})
.unwrap_or_default()
}
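Note: a usage sketch for the rewritten helper (hypothetical query string; behaviour follows from the code above):

    // Repeated keys and comma-separated values are both accepted and flattened.
    let q = Some("action=create&action=update,delete&other=x");
    let values = parse_repeated_query_param(q, "action");
    assert_eq!(values, vec!["create", "update", "delete"]);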
@@ -437,7 +437,7 @@ async fn setup_mock_plc_directory() -> String {
async fn spawn_app(database_url: String) -> String {
use tranquil_pds::rate_limit::RateLimiters;
let pool = PgPoolOptions::new()
.max_connections(3)
.max_connections(10)
.acquire_timeout(std::time::Duration::from_secs(30))
.connect(&database_url)
.await
@@ -4,12 +4,16 @@ use serde_json::{Value, json};

pub use crate::common::*;

fn unique_id() -> String {
uuid::Uuid::new_v4().simple().to_string()[..12].to_string()
}

#[allow(dead_code)]
pub async fn setup_new_user(handle_prefix: &str) -> (String, String) {
let client = client();
let ts = Utc::now().timestamp_millis();
let handle = format!("{}-{}.test", handle_prefix, ts);
let email = format!("{}-{}@test.com", handle_prefix, ts);
let uid = unique_id();
let handle = format!("{}-{}.test", handle_prefix, uid);
let email = format!("{}-{}@test.com", handle_prefix, uid);
let password = "E2epass123!";
let create_account_payload = json!({
"handle": handle,
@@ -51,7 +55,7 @@ pub async fn create_post(
text: &str,
) -> (String, String) {
let collection = "app.bsky.feed.post";
let rkey = format!("e2e_social_{}", Utc::now().timestamp_millis());
let rkey = format!("e2e_social_{}", unique_id());
let now = Utc::now().to_rfc3339();
let create_payload = json!({
"repo": did,
@@ -95,7 +99,7 @@ pub async fn create_follow(
followee_did: &str,
) -> (String, String) {
let collection = "app.bsky.graph.follow";
let rkey = format!("e2e_follow_{}", Utc::now().timestamp_millis());
let rkey = format!("e2e_follow_{}", unique_id());
let now = Utc::now().to_rfc3339();
let create_payload = json!({
"repo": follower_did,
@@ -140,7 +144,7 @@ pub async fn create_like(
subject_cid: &str,
) -> (String, String) {
let collection = "app.bsky.feed.like";
let rkey = format!("e2e_like_{}", Utc::now().timestamp_millis());
let rkey = format!("e2e_like_{}", unique_id());
let now = Utc::now().to_rfc3339();
let payload = json!({
"repo": liker_did,
@@ -182,7 +186,7 @@ pub async fn create_repost(
subject_cid: &str,
) -> (String, String) {
let collection = "app.bsky.feed.repost";
let rkey = format!("e2e_repost_{}", Utc::now().timestamp_millis());
let rkey = format!("e2e_repost_{}", unique_id());
let now = Utc::now().to_rfc3339();
let payload = json!({
"repo": reposter_did,
@@ -55,18 +55,14 @@ impl BlobScope {
if self.accept.is_empty() || self.accept.contains("*/*") {
return true;
}
for pattern in &self.accept {
if pattern == mime {
return true;
}
if let Some(prefix) = pattern.strip_suffix("/*")
&& mime.starts_with(prefix)
&& mime.chars().nth(prefix.len()) == Some('/')
{
return true;
}
}
false
self.accept.iter().any(|pattern| {
pattern == mime
|| pattern
.strip_suffix("/*")
.is_some_and(|prefix| {
mime.starts_with(prefix) && mime.chars().nth(prefix.len()) == Some('/')
})
})
}
}
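Note: illustrative behaviour of the wildcard matching above (assumes BlobScope can be built directly; field visibility may differ in the actual module):

    let scope = BlobScope { accept: ["image/*".to_string()].into_iter().collect() };
    assert!(scope.matches_mime("image/png"));   // "image/*" accepts any image subtype
    assert!(!scope.matches_mime("video/mp4"));  // other top-level types are rejected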
@@ -170,19 +166,20 @@ pub fn parse_scope(scope: &str) -> ParsedScope {
Some(rest.to_string())
};

let mut actions = HashSet::new();
if let Some(action_values) = params.get("action") {
for action_str in action_values {
if let Some(action) = RepoAction::parse_str(action_str) {
actions.insert(action);
}
}
}
if actions.is_empty() {
actions.insert(RepoAction::Create);
actions.insert(RepoAction::Update);
actions.insert(RepoAction::Delete);
}
let actions: HashSet<RepoAction> = params
.get("action")
.map(|action_values| {
action_values
.iter()
.filter_map(|s| RepoAction::parse_str(s))
.collect()
})
.filter(|set: &HashSet<RepoAction>| !set.is_empty())
.unwrap_or_else(|| {
[RepoAction::Create, RepoAction::Update, RepoAction::Delete]
.into_iter()
.collect()
});

return ParsedScope::Repo(RepoScope {
collection,
@@ -191,19 +188,20 @@ pub fn parse_scope(scope: &str) -> ParsedScope {
}

if base == "repo" {
let mut actions = HashSet::new();
if let Some(action_values) = params.get("action") {
for action_str in action_values {
if let Some(action) = RepoAction::parse_str(action_str) {
actions.insert(action);
}
}
}
if actions.is_empty() {
actions.insert(RepoAction::Create);
actions.insert(RepoAction::Update);
actions.insert(RepoAction::Delete);
}
let actions: HashSet<RepoAction> = params
.get("action")
.map(|action_values| {
action_values
.iter()
.filter_map(|s| RepoAction::parse_str(s))
.collect()
})
.filter(|set: &HashSet<RepoAction>| !set.is_empty())
.unwrap_or_else(|| {
[RepoAction::Create, RepoAction::Update, RepoAction::Delete]
.into_iter()
.collect()
});
return ParsedScope::Repo(RepoScope {
collection: None,
actions,
@@ -212,16 +210,17 @@ pub fn parse_scope(scope: &str) -> ParsedScope {

if base.starts_with("blob") {
let positional = base.strip_prefix("blob:").unwrap_or("");
let mut accept = HashSet::new();

if !positional.is_empty() {
accept.insert(positional.to_string());
}
if let Some(accept_values) = params.get("accept") {
for v in accept_values {
accept.insert(v.to_string());
}
}
let accept: HashSet<String> = std::iter::once(positional)
.filter(|s| !s.is_empty())
.map(String::from)
.chain(
params
.get("accept")
.into_iter()
.flatten()
.map(String::clone),
)
.collect();

return ParsedScope::Blob(BlobScope { accept });
}
|
||||
|
||||
@@ -113,34 +113,32 @@ impl ScopePermissions {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for repo_scope in self.find_repo_scopes() {
|
||||
if !repo_scope.actions.contains(&action) {
|
||||
continue;
|
||||
}
|
||||
|
||||
match &repo_scope.collection {
|
||||
None => return Ok(()),
|
||||
Some(coll) if coll == collection => return Ok(()),
|
||||
Some(coll) if coll.ends_with(".*") => {
|
||||
let prefix = coll.strip_suffix(".*").unwrap();
|
||||
if collection.starts_with(prefix)
|
||||
&& collection.chars().nth(prefix.len()) == Some('.')
|
||||
{
|
||||
return Ok(());
|
||||
let has_permission = self.find_repo_scopes().any(|repo_scope| {
|
||||
repo_scope.actions.contains(&action)
|
||||
&& match &repo_scope.collection {
|
||||
None => true,
|
||||
Some(coll) if coll == collection => true,
|
||||
Some(coll) if coll.ends_with(".*") => {
|
||||
let prefix = coll.strip_suffix(".*").unwrap();
|
||||
collection.starts_with(prefix)
|
||||
&& collection.chars().nth(prefix.len()) == Some('.')
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Err(ScopeError::InsufficientScope {
|
||||
required: format!("repo:{}?action={}", collection, action_str(action)),
|
||||
message: format!(
|
||||
"Insufficient scope to {} records in {}",
|
||||
action_str(action),
|
||||
collection
|
||||
),
|
||||
})
|
||||
if has_permission {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ScopeError::InsufficientScope {
|
||||
required: format!("repo:{}?action={}", collection, action_str(action)),
|
||||
message: format!(
|
||||
"Insufficient scope to {} records in {}",
|
||||
action_str(action),
|
||||
collection
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assert_blob(&self, mime: &str) -> Result<(), ScopeError> {
|
||||
@@ -148,16 +146,14 @@ impl ScopePermissions {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for blob_scope in self.find_blob_scopes() {
|
||||
if blob_scope.matches_mime(mime) {
|
||||
return Ok(());
|
||||
}
|
||||
if self.find_blob_scopes().any(|blob_scope| blob_scope.matches_mime(mime)) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ScopeError::InsufficientScope {
|
||||
required: format!("blob:{}", mime),
|
||||
message: format!("Insufficient scope to upload blob with mime type {}", mime),
|
||||
})
|
||||
}
|
||||
|
||||
Err(ScopeError::InsufficientScope {
|
||||
required: format!("blob:{}", mime),
|
||||
message: format!("Insufficient scope to upload blob with mime type {}", mime),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn assert_rpc(&self, aud: &str, lxm: &str) -> Result<(), ScopeError> {
|
||||
@@ -169,7 +165,7 @@ impl ScopePermissions {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for rpc_scope in self.find_rpc_scopes() {
|
||||
let has_permission = self.find_rpc_scopes().any(|rpc_scope| {
|
||||
let lxm_matches = match &rpc_scope.lxm {
|
||||
None => true,
|
||||
Some(scope_lxm) if scope_lxm == lxm => true,
|
||||
@@ -186,15 +182,17 @@ impl ScopePermissions {
|
||||
Some(scope_aud) => scope_aud == aud,
|
||||
};
|
||||
|
||||
if lxm_matches && aud_matches {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
lxm_matches && aud_matches
|
||||
});
|
||||
|
||||
Err(ScopeError::InsufficientScope {
|
||||
required: format!("rpc:{}?aud={}", lxm, aud),
|
||||
message: format!("Insufficient scope to call {} on {}", lxm, aud),
|
||||
})
|
||||
if has_permission {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ScopeError::InsufficientScope {
|
||||
required: format!("rpc:{}?aud={}", lxm, aud),
|
||||
message: format!("Insufficient scope to call {} on {}", lxm, aud),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assert_account(
|
||||
@@ -211,27 +209,28 @@ impl ScopePermissions {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for account_scope in self.find_account_scopes() {
|
||||
if account_scope.attr == attr && account_scope.action == action {
|
||||
return Ok(());
|
||||
}
|
||||
if account_scope.attr == attr && account_scope.action == AccountAction::Manage {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
let has_permission = self.find_account_scopes().any(|account_scope| {
|
||||
account_scope.attr == attr
|
||||
&& (account_scope.action == action
|
||||
|| account_scope.action == AccountAction::Manage)
|
||||
});
|
||||
|
||||
Err(ScopeError::InsufficientScope {
|
||||
required: format!(
|
||||
"account:{}?action={}",
|
||||
attr_str(attr),
|
||||
action_str_account(action)
|
||||
),
|
||||
message: format!(
|
||||
"Insufficient scope to {} account {}",
|
||||
action_str_account(action),
|
||||
attr_str(attr)
|
||||
),
|
||||
})
|
||||
if has_permission {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ScopeError::InsufficientScope {
|
||||
required: format!(
|
||||
"account:{}?action={}",
|
||||
attr_str(attr),
|
||||
action_str_account(action)
|
||||
),
|
||||
message: format!(
|
||||
"Insufficient scope to {} account {}",
|
||||
action_str_account(action),
|
||||
attr_str(attr)
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn allows_email_read(&self) -> bool {
|
||||
@@ -264,22 +263,23 @@ impl ScopePermissions {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for identity_scope in self.find_identity_scopes() {
|
||||
if identity_scope.attr == IdentityAttr::Wildcard {
|
||||
return Ok(());
|
||||
}
|
||||
if identity_scope.attr == attr {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
let has_permission = self
|
||||
.find_identity_scopes()
|
||||
.any(|identity_scope| {
|
||||
identity_scope.attr == IdentityAttr::Wildcard || identity_scope.attr == attr
|
||||
});
|
||||
|
||||
Err(ScopeError::InsufficientScope {
|
||||
required: format!("identity:{}", identity_attr_str(attr)),
|
||||
message: format!(
|
||||
"Insufficient scope to modify identity {}",
|
||||
identity_attr_str(attr)
|
||||
),
|
||||
})
|
||||
if has_permission {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ScopeError::InsufficientScope {
|
||||
required: format!("identity:{}", identity_attr_str(attr)),
|
||||
message: format!(
|
||||
"Insufficient scope to modify identity {}",
|
||||
identity_attr_str(attr)
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn allows_identity(&self, attr: IdentityAttr) -> bool {
|
||||
|
||||