From d1902506a592df31497ba6f79afedf57e8452da3 Mon Sep 17 00:00:00 2001 From: lewis Date: Sun, 11 Jan 2026 18:59:49 +0200 Subject: [PATCH] oauth jti fix, more code quality --- crates/tranquil-oauth/src/client.rs | 41 ++++--- .../tranquil-pds/src/api/identity/account.rs | 13 ++- crates/tranquil-pds/src/api/identity/did.rs | 21 ++-- .../src/api/server/account_status.rs | 4 +- .../tranquil-pds/src/api/server/migration.rs | 10 +- crates/tranquil-pds/src/api/server/session.rs | 8 +- crates/tranquil-pds/src/api/server/totp.rs | 6 +- .../src/auth/verification_token.rs | 17 ++- .../src/oauth/db/scope_preference.rs | 10 +- crates/tranquil-pds/src/oauth/db/token.rs | 40 +++---- .../src/oauth/endpoints/authorize.rs | 12 +-- .../tranquil-pds/src/oauth/endpoints/par.rs | 70 ++++++------ .../src/oauth/endpoints/token/grants.rs | 8 +- .../src/oauth/endpoints/token/helpers.rs | 18 +++- .../src/oauth/endpoints/token/introspect.rs | 2 +- crates/tranquil-pds/src/oauth/verify.rs | 4 +- crates/tranquil-pds/src/sync/deprecated.rs | 8 +- crates/tranquil-pds/src/sync/import.rs | 8 +- crates/tranquil-pds/src/sync/repo.rs | 42 ++++---- crates/tranquil-pds/src/sync/util.rs | 100 ++++++++---------- crates/tranquil-pds/tests/common/mod.rs | 30 +++--- .../tranquil-pds/tests/import_verification.rs | 5 +- 22 files changed, 226 insertions(+), 251 deletions(-) diff --git a/crates/tranquil-oauth/src/client.rs b/crates/tranquil-oauth/src/client.rs index cb7f5a9..43787be 100644 --- a/crates/tranquil-oauth/src/client.rs +++ b/crates/tranquil-oauth/src/client.rs @@ -529,28 +529,27 @@ async fn verify_private_key_jwt_async( let signature_bytes = URL_SAFE_NO_PAD .decode(parts[2]) .map_err(|_| OAuthError::InvalidClient("Invalid signature encoding".to_string()))?; - for key in matching_keys { - let key_alg = key.get("alg").and_then(|a| a.as_str()); - if key_alg.is_some() && key_alg != Some(alg) { - continue; - } - let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or(""); - let verified = match (alg, kty) { - ("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes), - ("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes), - ("RS256" | "RS384" | "RS512", "RSA") => { - verify_rsa(alg, key, &signing_input, &signature_bytes) + matching_keys + .into_iter() + .filter(|key| { + let key_alg = key.get("alg").and_then(|a| a.as_str()); + key_alg.is_none() || key_alg == Some(alg) + }) + .find_map(|key| { + let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or(""); + match (alg, kty) { + ("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes).ok(), + ("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes).ok(), + ("RS256" | "RS384" | "RS512", "RSA") => { + verify_rsa(alg, key, &signing_input, &signature_bytes).ok() + } + ("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes).ok(), + _ => None, } - ("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes), - _ => continue, - }; - if verified.is_ok() { - return Ok(()); - } - } - Err(OAuthError::InvalidClient( - "client_assertion signature verification failed".to_string(), - )) + }) + .ok_or_else(|| { + OAuthError::InvalidClient("client_assertion signature verification failed".to_string()) + }) } fn verify_es256( diff --git a/crates/tranquil-pds/src/api/identity/account.rs b/crates/tranquil-pds/src/api/identity/account.rs index c8ed282..4e9cb38 100644 --- a/crates/tranquil-pds/src/api/identity/account.rs +++ b/crates/tranquil-pds/src/api/identity/account.rs @@ -189,14 +189,13 @@ pub async fn create_account( if input.handle.contains(' ') || input.handle.contains('\t') { return ApiError::InvalidRequest("Handle cannot contain spaces".into()).into_response(); } - for c in input.handle.chars() { - if !c.is_ascii_alphanumeric() && c != '.' && c != '-' { - return ApiError::InvalidRequest(format!( - "Handle contains invalid character: {}", - c - )) + if let Some(c) = input + .handle + .chars() + .find(|c| !c.is_ascii_alphanumeric() && *c != '.' && *c != '-') + { + return ApiError::InvalidRequest(format!("Handle contains invalid character: {}", c)) .into_response(); - } } let handle_lower = input.handle.to_lowercase(); if crate::moderation::has_explicit_slur(&handle_lower) { diff --git a/crates/tranquil-pds/src/api/identity/did.rs b/crates/tranquil-pds/src/api/identity/did.rs index 391e791..a0cb1f3 100644 --- a/crates/tranquil-pds/src/api/identity/did.rs +++ b/crates/tranquil-pds/src/api/identity/did.rs @@ -639,17 +639,18 @@ pub async fn update_handle( return ApiError::InvalidHandle(Some("Handle contains invalid characters".into())) .into_response(); } - for segment in new_handle.split('.') { - if segment.is_empty() { - return ApiError::InvalidHandle(Some("Handle contains empty segment".into())) - .into_response(); - } - if segment.starts_with('-') || segment.ends_with('-') { - return ApiError::InvalidHandle(Some( - "Handle segment cannot start or end with hyphen".into(), - )) + if new_handle.split('.').any(|segment| segment.is_empty()) { + return ApiError::InvalidHandle(Some("Handle contains empty segment".into())) .into_response(); - } + } + if new_handle + .split('.') + .any(|segment| segment.starts_with('-') || segment.ends_with('-')) + { + return ApiError::InvalidHandle(Some( + "Handle segment cannot start or end with hyphen".into(), + )) + .into_response(); } if crate::moderation::has_explicit_slur(&new_handle) { return ApiError::InvalidHandle(Some("Inappropriate language in handle".into())) diff --git a/crates/tranquil-pds/src/api/server/account_status.rs b/crates/tranquil-pds/src/api/server/account_status.rs index d8aa725..9775d50 100644 --- a/crates/tranquil-pds/src/api/server/account_status.rs +++ b/crates/tranquil-pds/src/api/server/account_status.rs @@ -203,7 +203,9 @@ async fn assert_valid_did_document_for_service( Ok(data) => { let pds_endpoint = data .get("services") - .and_then(|s: &serde_json::Value| s.get("atproto_pds").or_else(|| s.get("atprotoPds"))) + .and_then(|s: &serde_json::Value| { + s.get("atproto_pds").or_else(|| s.get("atprotoPds")) + }) .and_then(|p: &serde_json::Value| p.get("endpoint")) .and_then(|e: &serde_json::Value| e.as_str()); diff --git a/crates/tranquil-pds/src/api/server/migration.rs b/crates/tranquil-pds/src/api/server/migration.rs index 292c1db..72e72bd 100644 --- a/crates/tranquil-pds/src/api/server/migration.rs +++ b/crates/tranquil-pds/src/api/server/migration.rs @@ -115,11 +115,11 @@ pub async fn update_did_document( } } - if let Some(ref handles) = input.also_known_as { - if handles.iter().any(|h| !h.starts_with("at://")) { - return ApiError::InvalidRequest("alsoKnownAs entries must be at:// URIs".into()) - .into_response(); - } + if let Some(ref handles) = input.also_known_as + && handles.iter().any(|h| !h.starts_with("at://")) + { + return ApiError::InvalidRequest("alsoKnownAs entries must be at:// URIs".into()) + .into_response(); } if let Some(ref endpoint) = input.service_endpoint { diff --git a/crates/tranquil-pds/src/api/server/session.rs b/crates/tranquil-pds/src/api/server/session.rs index 47615f7..6301306 100644 --- a/crates/tranquil-pds/src/api/server/session.rs +++ b/crates/tranquil-pds/src/api/server/session.rs @@ -949,16 +949,16 @@ pub async fn list_sessions( } }; - let jwt_sessions = jwt_rows.into_iter().map(|(id, access_jti, created_at, expires_at)| { - SessionInfo { + let jwt_sessions = jwt_rows + .into_iter() + .map(|(id, access_jti, created_at, expires_at)| SessionInfo { id: format!("jwt:{}", id), session_type: "legacy".to_string(), client_name: None, created_at: created_at.to_rfc3339(), expires_at: expires_at.to_rfc3339(), is_current: current_jti.as_ref() == Some(&access_jti), - } - }); + }); let is_oauth = auth.0.is_oauth; let oauth_sessions = diff --git a/crates/tranquil-pds/src/api/server/totp.rs b/crates/tranquil-pds/src/api/server/totp.rs index a0c87c3..74e6b4c 100644 --- a/crates/tranquil-pds/src/api/server/totp.rs +++ b/crates/tranquil-pds/src/api/server/totp.rs @@ -195,7 +195,8 @@ pub async fn enable_totp( return ApiError::InternalError(None).into_response(); } - let backup_hashes: Result, _> = backup_codes.iter().map(|c| hash_backup_code(c)).collect(); + let backup_hashes: Result, _> = + backup_codes.iter().map(|c| hash_backup_code(c)).collect(); let backup_hashes = match backup_hashes { Ok(hashes) => hashes, Err(e) => { @@ -484,7 +485,8 @@ pub async fn regenerate_backup_codes( return ApiError::InternalError(None).into_response(); } - let backup_hashes: Result, _> = backup_codes.iter().map(|c| hash_backup_code(c)).collect(); + let backup_hashes: Result, _> = + backup_codes.iter().map(|c| hash_backup_code(c)).collect(); let backup_hashes = match backup_hashes { Ok(hashes) => hashes, Err(e) => { diff --git a/crates/tranquil-pds/src/auth/verification_token.rs b/crates/tranquil-pds/src/auth/verification_token.rs index 033801c..3d62aef 100644 --- a/crates/tranquil-pds/src/auth/verification_token.rs +++ b/crates/tranquil-pds/src/auth/verification_token.rs @@ -296,15 +296,14 @@ pub fn verify_token_signature(token: &str) -> Result String { - let clean = token.replace(['-', ' '], ""); - let mut result = String::new(); - for (i, c) in clean.chars().enumerate() { - if i > 0 && i % 4 == 0 { - result.push('-'); - } - result.push(c); - } - result + token + .replace(['-', ' '], "") + .chars() + .collect::>() + .chunks(4) + .map(|chunk| chunk.iter().collect::()) + .collect::>() + .join("-") } pub fn normalize_token_input(input: &str) -> String { diff --git a/crates/tranquil-pds/src/oauth/db/scope_preference.rs b/crates/tranquil-pds/src/oauth/db/scope_preference.rs index b5459f6..f0b9d2f 100644 --- a/crates/tranquil-pds/src/oauth/db/scope_preference.rs +++ b/crates/tranquil-pds/src/oauth/db/scope_preference.rs @@ -75,13 +75,9 @@ pub async fn should_show_consent( let stored_scopes: std::collections::HashSet<&str> = stored_prefs.iter().map(|p| p.scope.as_str()).collect(); - for scope in requested_scopes { - if !stored_scopes.contains(scope.as_str()) { - return Ok(true); - } - } - - Ok(false) + Ok(requested_scopes + .iter() + .any(|scope| !stored_scopes.contains(scope.as_str()))) } pub async fn delete_scope_preferences( diff --git a/crates/tranquil-pds/src/oauth/db/token.rs b/crates/tranquil-pds/src/oauth/db/token.rs index 6b72a80..b626644 100644 --- a/crates/tranquil-pds/src/oauth/db/token.rs +++ b/crates/tranquil-pds/src/oauth/db/token.rs @@ -315,26 +315,26 @@ pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result Result { diff --git a/crates/tranquil-pds/src/oauth/endpoints/authorize.rs b/crates/tranquil-pds/src/oauth/endpoints/authorize.rs index c9bee75..e1858b7 100644 --- a/crates/tranquil-pds/src/oauth/endpoints/authorize.rs +++ b/crates/tranquil-pds/src/oauth/endpoints/authorize.rs @@ -102,13 +102,11 @@ fn extract_device_cookie(headers: &HeaderMap) -> Option { .get("cookie") .and_then(|v| v.to_str().ok()) .and_then(|cookie_str| { - for cookie in cookie_str.split(';') { - let cookie = cookie.trim(); - if let Some(value) = cookie.strip_prefix(&format!("{}=", DEVICE_COOKIE_NAME)) { - return crate::config::AuthConfig::get().verify_device_cookie(value); - } - } - None + cookie_str.split(';').map(|c| c.trim()).find_map(|cookie| { + cookie + .strip_prefix(&format!("{}=", DEVICE_COOKIE_NAME)) + .and_then(|value| crate::config::AuthConfig::get().verify_device_cookie(value)) + }) }) } diff --git a/crates/tranquil-pds/src/oauth/endpoints/par.rs b/crates/tranquil-pds/src/oauth/endpoints/par.rs index f317daa..f3472f2 100644 --- a/crates/tranquil-pds/src/oauth/endpoints/par.rs +++ b/crates/tranquil-pds/src/oauth/endpoints/par.rs @@ -182,35 +182,36 @@ fn validate_scope( if requested_scopes.is_empty() { return Ok(Some("atproto".to_string())); } - let mut has_transition = false; - let mut has_granular = false; - - for scope in &requested_scopes { - let parsed = parse_scope(scope); - match &parsed { - ParsedScope::Unknown(_) => { - return Err(OAuthError::InvalidScope(format!( - "Unsupported scope: {}", - scope - ))); - } - ParsedScope::TransitionGeneric - | ParsedScope::TransitionChat - | ParsedScope::TransitionEmail => { - has_transition = true; - } - ParsedScope::Repo(_) - | ParsedScope::Blob(_) - | ParsedScope::Rpc(_) - | ParsedScope::Account(_) - | ParsedScope::Identity(_) - | ParsedScope::Include(_) => { - has_granular = true; - } - ParsedScope::Atproto => {} - } + if let Some(unknown) = requested_scopes + .iter() + .find(|s| matches!(parse_scope(s), ParsedScope::Unknown(_))) + { + return Err(OAuthError::InvalidScope(format!( + "Unsupported scope: {}", + unknown + ))); } + let has_transition = requested_scopes.iter().any(|s| { + matches!( + parse_scope(s), + ParsedScope::TransitionGeneric + | ParsedScope::TransitionChat + | ParsedScope::TransitionEmail + ) + }); + let has_granular = requested_scopes.iter().any(|s| { + matches!( + parse_scope(s), + ParsedScope::Repo(_) + | ParsedScope::Blob(_) + | ParsedScope::Rpc(_) + | ParsedScope::Account(_) + | ParsedScope::Identity(_) + | ParsedScope::Include(_) + ) + }); + if has_transition && has_granular { return Err(OAuthError::InvalidScope( "Cannot mix transition scopes with granular scopes. Use either transition:* scopes OR granular scopes (repo:*, blob:*, rpc:*, account:*, include:*), not both.".to_string() @@ -219,13 +220,14 @@ fn validate_scope( if let Some(client_scope) = &client_metadata.scope { let client_scopes: Vec<&str> = client_scope.split_whitespace().collect(); - for scope in &requested_scopes { - if !client_scopes.iter().any(|cs| scope_matches(cs, scope)) { - return Err(OAuthError::InvalidScope(format!( - "Scope '{}' not registered for this client", - scope - ))); - } + if let Some(unregistered) = requested_scopes + .iter() + .find(|scope| !client_scopes.iter().any(|cs| scope_matches(cs, scope))) + { + return Err(OAuthError::InvalidScope(format!( + "Scope '{}' not registered for this client", + unregistered + ))); } } Ok(Some(requested_scopes.join(" "))) diff --git a/crates/tranquil-pds/src/oauth/endpoints/token/grants.rs b/crates/tranquil-pds/src/oauth/endpoints/token/grants.rs index 62a36c9..28ab02a 100644 --- a/crates/tranquil-pds/src/oauth/endpoints/token/grants.rs +++ b/crates/tranquil-pds/src/oauth/endpoints/token/grants.rs @@ -334,13 +334,7 @@ pub async fn handle_refresh_token_grant( REFRESH_TOKEN_EXPIRY_DAYS_CONFIDENTIAL }; let new_expires_at = Utc::now() + Duration::days(refresh_expiry_days); - db::rotate_token( - &state.db, - db_id, - &new_refresh_token.0, - new_expires_at, - ) - .await?; + db::rotate_token(&state.db, db_id, &new_refresh_token.0, new_expires_at).await?; tracing::info!( did = %token_data.did, new_expires_at = %new_expires_at, diff --git a/crates/tranquil-pds/src/oauth/endpoints/token/helpers.rs b/crates/tranquil-pds/src/oauth/endpoints/token/helpers.rs index 1754ef3..f145987 100644 --- a/crates/tranquil-pds/src/oauth/endpoints/token/helpers.rs +++ b/crates/tranquil-pds/src/oauth/endpoints/token/helpers.rs @@ -11,6 +11,7 @@ const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 300; pub struct TokenClaims { pub jti: String, + pub sid: String, pub exp: i64, pub iat: i64, } @@ -33,22 +34,23 @@ pub fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAut } pub fn create_access_token( - token_id: &str, + session_id: &str, sub: &str, dpop_jkt: Option<&str>, scope: Option<&str>, ) -> Result { - create_access_token_with_delegation(token_id, sub, dpop_jkt, scope, None) + create_access_token_with_delegation(session_id, sub, dpop_jkt, scope, None) } pub fn create_access_token_with_delegation( - token_id: &str, + session_id: &str, sub: &str, dpop_jkt: Option<&str>, scope: Option<&str>, controller_did: Option<&str>, ) -> Result { use serde_json::json; + let jti = uuid::Uuid::new_v4().to_string(); let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); let issuer = format!("https://{}", pds_hostname); let now = Utc::now().timestamp(); @@ -60,7 +62,8 @@ pub fn create_access_token_with_delegation( "aud": issuer, "iat": now, "exp": exp, - "jti": token_id, + "jti": jti, + "sid": session_id, "scope": actual_scope }); if let Some(jkt) = dpop_jkt { @@ -132,6 +135,11 @@ pub fn extract_token_claims(token: &str) -> Result { .and_then(|j| j.as_str()) .ok_or_else(|| OAuthError::InvalidToken("Missing jti claim".to_string()))? .to_string(); + let sid = payload + .get("sid") + .and_then(|s| s.as_str()) + .ok_or_else(|| OAuthError::InvalidToken("Missing sid claim".to_string()))? + .to_string(); let exp = payload .get("exp") .and_then(|e| e.as_i64()) @@ -140,5 +148,5 @@ pub fn extract_token_claims(token: &str) -> Result { .get("iat") .and_then(|i| i.as_i64()) .ok_or_else(|| OAuthError::InvalidToken("Missing iat claim".to_string()))?; - Ok(TokenClaims { jti, exp, iat }) + Ok(TokenClaims { jti, sid, exp, iat }) } diff --git a/crates/tranquil-pds/src/oauth/endpoints/token/introspect.rs b/crates/tranquil-pds/src/oauth/endpoints/token/introspect.rs index df5cf42..624d0dc 100644 --- a/crates/tranquil-pds/src/oauth/endpoints/token/introspect.rs +++ b/crates/tranquil-pds/src/oauth/endpoints/token/introspect.rs @@ -102,7 +102,7 @@ pub async fn introspect_token( Ok(info) => info, Err(_) => return Ok(Json(inactive_response)), }; - let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await { + let token_data = match db::get_token_by_id(&state.db, &token_info.sid).await { Ok(Some(data)) => data, _ => return Ok(Json(inactive_response)), }; diff --git a/crates/tranquil-pds/src/oauth/verify.rs b/crates/tranquil-pds/src/oauth/verify.rs index 7f0aa00..8232d08 100644 --- a/crates/tranquil-pds/src/oauth/verify.rs +++ b/crates/tranquil-pds/src/oauth/verify.rs @@ -142,9 +142,9 @@ pub fn extract_oauth_token_info(token: &str) -> Result) { stack.push(*cid); } Ipld::Map(map) => { - for v in map.values() { - extract_links_ipld(v, stack); - } + map.values().for_each(|v| extract_links_ipld(v, stack)); } Ipld::List(arr) => { - for v in arr { - extract_links_ipld(v, stack); - } + arr.iter().for_each(|v| extract_links_ipld(v, stack)); } _ => {} } diff --git a/crates/tranquil-pds/src/sync/import.rs b/crates/tranquil-pds/src/sync/import.rs index f9d2d58..c4b122f 100644 --- a/crates/tranquil-pds/src/sync/import.rs +++ b/crates/tranquil-pds/src/sync/import.rs @@ -148,14 +148,10 @@ pub fn extract_links(value: &Ipld, links: &mut Vec) { links.push(*cid); } Ipld::Map(map) => { - for v in map.values() { - extract_links(v, links); - } + map.values().for_each(|v| extract_links(v, links)); } Ipld::List(arr) => { - for v in arr { - extract_links(v, links); - } + arr.iter().for_each(|v| extract_links(v, links)); } _ => {} } diff --git a/crates/tranquil-pds/src/sync/repo.rs b/crates/tranquil-pds/src/sync/repo.rs index 9023d76..560de6a 100644 --- a/crates/tranquil-pds/src/sync/repo.rs +++ b/crates/tranquil-pds/src/sync/repo.rs @@ -181,24 +181,26 @@ async fn get_repo_since(state: &AppState, did: &str, head_cid: &Cid, since: &str } }; - let mut block_cids: Vec = Vec::new(); - for event in &events { - if let Some(cids) = &event.blocks_cids { - for cid_str in cids { - if let Ok(cid) = Cid::from_str(cid_str) - && !block_cids.contains(&cid) - { - block_cids.push(cid); - } + let block_cids: Vec = events + .iter() + .flat_map(|event| { + let block_cids = event + .blocks_cids + .as_ref() + .map(|cids| cids.iter().filter_map(|s| Cid::from_str(s).ok()).collect()) + .unwrap_or_else(Vec::new); + let commit_cid = event + .commit_cid + .as_ref() + .and_then(|s| Cid::from_str(s).ok()); + block_cids.into_iter().chain(commit_cid) + }) + .fold(Vec::new(), |mut acc, cid| { + if !acc.contains(&cid) { + acc.push(cid); } - } - if let Some(commit_cid_str) = &event.commit_cid - && let Ok(cid) = Cid::from_str(commit_cid_str) - && !block_cids.contains(&cid) - { - block_cids.push(cid); - } - } + acc + }); let mut car_bytes = match encode_car_header(head_cid) { Ok(h) => h, @@ -334,9 +336,9 @@ pub async fn get_record( car.extend_from_slice(&writer); }; write_block(&mut car_bytes, &commit_cid, &commit_bytes); - for (cid, data) in &proof_blocks { - write_block(&mut car_bytes, cid, data); - } + proof_blocks + .iter() + .for_each(|(cid, data)| write_block(&mut car_bytes, cid, data)); write_block(&mut car_bytes, &record_cid, &record_block); ( StatusCode::OK, diff --git a/crates/tranquil-pds/src/sync/util.rs b/crates/tranquil-pds/src/sync/util.rs index 8452c00..fb931aa 100644 --- a/crates/tranquil-pds/src/sync/util.rs +++ b/crates/tranquil-pds/src/sync/util.rs @@ -210,13 +210,11 @@ async fn write_car_blocks( let mut buffer = Cursor::new(Vec::new()); let header = CarHeader::new_v1(vec![commit_cid]); let mut writer = CarWriter::new(header, &mut buffer); - for (cid, data) in other_blocks { - if cid != commit_cid { - writer - .write(cid, data.as_ref()) - .await - .map_err(|e| anyhow::anyhow!("writing block {}: {}", cid, e))?; - } + for (cid, data) in other_blocks.iter().filter(|(c, _)| **c != commit_cid) { + writer + .write(*cid, data.as_ref()) + .await + .map_err(|e| anyhow::anyhow!("writing block {}: {}", cid, e))?; } if let Some(data) = commit_bytes { writer @@ -360,20 +358,18 @@ pub async fn format_event_for_sending( } let car_bytes = if !all_cids.is_empty() { let fetched = state.block_store.get_many(&all_cids).await?; - let mut blocks = std::collections::BTreeMap::new(); - let mut commit_bytes: Option = None; - for (cid, data_opt) in all_cids.iter().zip(fetched.iter()) { - if let Some(data) = data_opt { - if *cid == commit_cid { - commit_bytes = Some(data.clone()); - if let Some(rev) = extract_rev_from_commit_bytes(data) { - frame.rev = rev; - } - } else { - blocks.insert(*cid, data.clone()); - } - } + let (commit_data, other_blocks): (Vec<_>, Vec<_>) = all_cids + .iter() + .zip(fetched.iter()) + .filter_map(|(cid, data_opt)| data_opt.as_ref().map(|data| (*cid, data.clone()))) + .partition(|(cid, _)| *cid == commit_cid); + let commit_bytes = commit_data.into_iter().next().map(|(_, data)| data); + if let Some(ref cb) = commit_bytes + && let Some(rev) = extract_rev_from_commit_bytes(cb) + { + frame.rev = rev; } + let blocks: std::collections::BTreeMap = other_blocks.into_iter().collect(); write_car_blocks(commit_cid, commit_bytes, blocks).await? } else { Vec::new() @@ -393,38 +389,33 @@ pub async fn prefetch_blocks_for_events( state: &AppState, events: &[SequencedEvent], ) -> Result, anyhow::Error> { - let mut all_cids: Vec = Vec::new(); - for event in events { - if let Some(ref commit_cid_str) = event.commit_cid - && let Ok(cid) = Cid::from_str(commit_cid_str) - { - all_cids.push(cid); - } - if let Some(ref prev_cid_str) = event.prev_cid - && let Ok(cid) = Cid::from_str(prev_cid_str) - { - all_cids.push(cid); - } - if let Some(ref block_cids_str) = event.blocks_cids { - for s in block_cids_str { - if let Ok(cid) = Cid::from_str(s) { - all_cids.push(cid); - } - } - } - } + let mut all_cids: Vec = events + .iter() + .flat_map(|event| { + let commit_cid = event + .commit_cid + .as_ref() + .and_then(|s| Cid::from_str(s).ok()); + let prev_cid = event.prev_cid.as_ref().and_then(|s| Cid::from_str(s).ok()); + let block_cids = event + .blocks_cids + .as_ref() + .map(|cids| cids.iter().filter_map(|s| Cid::from_str(s).ok()).collect()) + .unwrap_or_else(Vec::new); + commit_cid.into_iter().chain(prev_cid).chain(block_cids) + }) + .collect(); all_cids.sort(); all_cids.dedup(); if all_cids.is_empty() { return Ok(HashMap::new()); } let fetched = state.block_store.get_many(&all_cids).await?; - let mut blocks_map = HashMap::with_capacity(all_cids.len()); - for (cid, data_opt) in all_cids.into_iter().zip(fetched.into_iter()) { - if let Some(data) = data_opt { - blocks_map.insert(cid, data); - } - } + let blocks_map: HashMap = all_cids + .into_iter() + .zip(fetched) + .filter_map(|(cid, data_opt)| data_opt.map(|data| (cid, data))) + .collect(); Ok(blocks_map) } @@ -511,17 +502,12 @@ pub async fn format_event_with_prefetched_blocks( frame.since = Some(rev); } let car_bytes = if !all_cids.is_empty() { - let mut blocks = BTreeMap::new(); - let mut commit_bytes_for_car: Option = None; - for cid in all_cids { - if let Some(data) = prefetched.get(&cid) { - if cid == commit_cid { - commit_bytes_for_car = Some(data.clone()); - } else { - blocks.insert(cid, data.clone()); - } - } - } + let (commit_data, other_blocks): (Vec<_>, Vec<_>) = all_cids + .into_iter() + .filter_map(|cid| prefetched.get(&cid).map(|data| (cid, data.clone()))) + .partition(|(cid, _)| *cid == commit_cid); + let commit_bytes_for_car = commit_data.into_iter().next().map(|(_, data)| data); + let blocks: BTreeMap = other_blocks.into_iter().collect(); write_car_blocks(commit_cid, commit_bytes_for_car, blocks).await? } else { Vec::new() diff --git a/crates/tranquil-pds/tests/common/mod.rs b/crates/tranquil-pds/tests/common/mod.rs index 65dabab..c5b18c6 100644 --- a/crates/tranquil-pds/tests/common/mod.rs +++ b/crates/tranquil-pds/tests/common/mod.rs @@ -256,10 +256,10 @@ impl Respond for PlcPostResponder { .unwrap_or_default() .to_string(); - if let Ok(body) = serde_json::from_slice::(request.body.as_slice()) { - if let Ok(mut store) = self.store.write() { - store.insert(did, body); - } + if let Ok(body) = serde_json::from_slice::(request.body.as_slice()) + && let Ok(mut store) = self.store.write() + { + store.insert(did, body); } ResponseTemplate::new(200) } @@ -298,18 +298,16 @@ impl Respond for PlcGetResponder { match endpoint { "/log/last" => { - let response = operation - .cloned() - .unwrap_or_else(|| { - json!({ - "type": "plc_operation", - "rotationKeys": [], - "verificationMethods": {}, - "alsoKnownAs": [], - "services": {}, - "prev": null - }) - }); + let response = operation.cloned().unwrap_or_else(|| { + json!({ + "type": "plc_operation", + "rotationKeys": [], + "verificationMethods": {}, + "alsoKnownAs": [], + "services": {}, + "prev": null + }) + }); ResponseTemplate::new(200).set_body_json(response) } "/log/audit" => ResponseTemplate::new(200).set_body_json(json!([])), diff --git a/crates/tranquil-pds/tests/import_verification.rs b/crates/tranquil-pds/tests/import_verification.rs index 2726d42..b260452 100644 --- a/crates/tranquil-pds/tests/import_verification.rs +++ b/crates/tranquil-pds/tests/import_verification.rs @@ -159,10 +159,7 @@ async fn test_import_accepts_own_exported_repo() { let status = import_res.status(); if status != StatusCode::OK { let body = import_res.text().await.unwrap_or_default(); - panic!( - "Import failed with status {}: {}", - status, body - ); + panic!("Import failed with status {}: {}", status, body); } }