mirror of
https://tangled.org/tranquil.farm/tranquil-pds
synced 2026-04-29 20:46:55 +00:00
oauth jti fix, more code quality
This commit is contained in:
@@ -529,28 +529,27 @@ async fn verify_private_key_jwt_async(
|
||||
let signature_bytes = URL_SAFE_NO_PAD
|
||||
.decode(parts[2])
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid signature encoding".to_string()))?;
|
||||
for key in matching_keys {
|
||||
let key_alg = key.get("alg").and_then(|a| a.as_str());
|
||||
if key_alg.is_some() && key_alg != Some(alg) {
|
||||
continue;
|
||||
}
|
||||
let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or("");
|
||||
let verified = match (alg, kty) {
|
||||
("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes),
|
||||
("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes),
|
||||
("RS256" | "RS384" | "RS512", "RSA") => {
|
||||
verify_rsa(alg, key, &signing_input, &signature_bytes)
|
||||
matching_keys
|
||||
.into_iter()
|
||||
.filter(|key| {
|
||||
let key_alg = key.get("alg").and_then(|a| a.as_str());
|
||||
key_alg.is_none() || key_alg == Some(alg)
|
||||
})
|
||||
.find_map(|key| {
|
||||
let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or("");
|
||||
match (alg, kty) {
|
||||
("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes).ok(),
|
||||
("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes).ok(),
|
||||
("RS256" | "RS384" | "RS512", "RSA") => {
|
||||
verify_rsa(alg, key, &signing_input, &signature_bytes).ok()
|
||||
}
|
||||
("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes).ok(),
|
||||
_ => None,
|
||||
}
|
||||
("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes),
|
||||
_ => continue,
|
||||
};
|
||||
if verified.is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Err(OAuthError::InvalidClient(
|
||||
"client_assertion signature verification failed".to_string(),
|
||||
))
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
OAuthError::InvalidClient("client_assertion signature verification failed".to_string())
|
||||
})
|
||||
}
|
||||
|
||||
fn verify_es256(
|
||||
|
||||
@@ -189,14 +189,13 @@ pub async fn create_account(
|
||||
if input.handle.contains(' ') || input.handle.contains('\t') {
|
||||
return ApiError::InvalidRequest("Handle cannot contain spaces".into()).into_response();
|
||||
}
|
||||
for c in input.handle.chars() {
|
||||
if !c.is_ascii_alphanumeric() && c != '.' && c != '-' {
|
||||
return ApiError::InvalidRequest(format!(
|
||||
"Handle contains invalid character: {}",
|
||||
c
|
||||
))
|
||||
if let Some(c) = input
|
||||
.handle
|
||||
.chars()
|
||||
.find(|c| !c.is_ascii_alphanumeric() && *c != '.' && *c != '-')
|
||||
{
|
||||
return ApiError::InvalidRequest(format!("Handle contains invalid character: {}", c))
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
let handle_lower = input.handle.to_lowercase();
|
||||
if crate::moderation::has_explicit_slur(&handle_lower) {
|
||||
|
||||
@@ -639,17 +639,18 @@ pub async fn update_handle(
|
||||
return ApiError::InvalidHandle(Some("Handle contains invalid characters".into()))
|
||||
.into_response();
|
||||
}
|
||||
for segment in new_handle.split('.') {
|
||||
if segment.is_empty() {
|
||||
return ApiError::InvalidHandle(Some("Handle contains empty segment".into()))
|
||||
.into_response();
|
||||
}
|
||||
if segment.starts_with('-') || segment.ends_with('-') {
|
||||
return ApiError::InvalidHandle(Some(
|
||||
"Handle segment cannot start or end with hyphen".into(),
|
||||
))
|
||||
if new_handle.split('.').any(|segment| segment.is_empty()) {
|
||||
return ApiError::InvalidHandle(Some("Handle contains empty segment".into()))
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
if new_handle
|
||||
.split('.')
|
||||
.any(|segment| segment.starts_with('-') || segment.ends_with('-'))
|
||||
{
|
||||
return ApiError::InvalidHandle(Some(
|
||||
"Handle segment cannot start or end with hyphen".into(),
|
||||
))
|
||||
.into_response();
|
||||
}
|
||||
if crate::moderation::has_explicit_slur(&new_handle) {
|
||||
return ApiError::InvalidHandle(Some("Inappropriate language in handle".into()))
|
||||
|
||||
@@ -203,7 +203,9 @@ async fn assert_valid_did_document_for_service(
|
||||
Ok(data) => {
|
||||
let pds_endpoint = data
|
||||
.get("services")
|
||||
.and_then(|s: &serde_json::Value| s.get("atproto_pds").or_else(|| s.get("atprotoPds")))
|
||||
.and_then(|s: &serde_json::Value| {
|
||||
s.get("atproto_pds").or_else(|| s.get("atprotoPds"))
|
||||
})
|
||||
.and_then(|p: &serde_json::Value| p.get("endpoint"))
|
||||
.and_then(|e: &serde_json::Value| e.as_str());
|
||||
|
||||
|
||||
@@ -115,11 +115,11 @@ pub async fn update_did_document(
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref handles) = input.also_known_as {
|
||||
if handles.iter().any(|h| !h.starts_with("at://")) {
|
||||
return ApiError::InvalidRequest("alsoKnownAs entries must be at:// URIs".into())
|
||||
.into_response();
|
||||
}
|
||||
if let Some(ref handles) = input.also_known_as
|
||||
&& handles.iter().any(|h| !h.starts_with("at://"))
|
||||
{
|
||||
return ApiError::InvalidRequest("alsoKnownAs entries must be at:// URIs".into())
|
||||
.into_response();
|
||||
}
|
||||
|
||||
if let Some(ref endpoint) = input.service_endpoint {
|
||||
|
||||
@@ -949,16 +949,16 @@ pub async fn list_sessions(
|
||||
}
|
||||
};
|
||||
|
||||
let jwt_sessions = jwt_rows.into_iter().map(|(id, access_jti, created_at, expires_at)| {
|
||||
SessionInfo {
|
||||
let jwt_sessions = jwt_rows
|
||||
.into_iter()
|
||||
.map(|(id, access_jti, created_at, expires_at)| SessionInfo {
|
||||
id: format!("jwt:{}", id),
|
||||
session_type: "legacy".to_string(),
|
||||
client_name: None,
|
||||
created_at: created_at.to_rfc3339(),
|
||||
expires_at: expires_at.to_rfc3339(),
|
||||
is_current: current_jti.as_ref() == Some(&access_jti),
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
let is_oauth = auth.0.is_oauth;
|
||||
let oauth_sessions =
|
||||
|
||||
@@ -195,7 +195,8 @@ pub async fn enable_totp(
|
||||
return ApiError::InternalError(None).into_response();
|
||||
}
|
||||
|
||||
let backup_hashes: Result<Vec<_>, _> = backup_codes.iter().map(|c| hash_backup_code(c)).collect();
|
||||
let backup_hashes: Result<Vec<_>, _> =
|
||||
backup_codes.iter().map(|c| hash_backup_code(c)).collect();
|
||||
let backup_hashes = match backup_hashes {
|
||||
Ok(hashes) => hashes,
|
||||
Err(e) => {
|
||||
@@ -484,7 +485,8 @@ pub async fn regenerate_backup_codes(
|
||||
return ApiError::InternalError(None).into_response();
|
||||
}
|
||||
|
||||
let backup_hashes: Result<Vec<_>, _> = backup_codes.iter().map(|c| hash_backup_code(c)).collect();
|
||||
let backup_hashes: Result<Vec<_>, _> =
|
||||
backup_codes.iter().map(|c| hash_backup_code(c)).collect();
|
||||
let backup_hashes = match backup_hashes {
|
||||
Ok(hashes) => hashes,
|
||||
Err(e) => {
|
||||
|
||||
@@ -296,15 +296,14 @@ pub fn verify_token_signature(token: &str) -> Result<VerificationToken, VerifyEr
|
||||
}
|
||||
|
||||
pub fn format_token_for_display(token: &str) -> String {
|
||||
let clean = token.replace(['-', ' '], "");
|
||||
let mut result = String::new();
|
||||
for (i, c) in clean.chars().enumerate() {
|
||||
if i > 0 && i % 4 == 0 {
|
||||
result.push('-');
|
||||
}
|
||||
result.push(c);
|
||||
}
|
||||
result
|
||||
token
|
||||
.replace(['-', ' '], "")
|
||||
.chars()
|
||||
.collect::<Vec<_>>()
|
||||
.chunks(4)
|
||||
.map(|chunk| chunk.iter().collect::<String>())
|
||||
.collect::<Vec<_>>()
|
||||
.join("-")
|
||||
}
|
||||
|
||||
pub fn normalize_token_input(input: &str) -> String {
|
||||
|
||||
@@ -75,13 +75,9 @@ pub async fn should_show_consent(
|
||||
let stored_scopes: std::collections::HashSet<&str> =
|
||||
stored_prefs.iter().map(|p| p.scope.as_str()).collect();
|
||||
|
||||
for scope in requested_scopes {
|
||||
if !stored_scopes.contains(scope.as_str()) {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
Ok(requested_scopes
|
||||
.iter()
|
||||
.any(|scope| !stored_scopes.contains(scope.as_str())))
|
||||
}
|
||||
|
||||
pub async fn delete_scope_preferences(
|
||||
|
||||
@@ -315,26 +315,26 @@ pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result<Vec<TokenD
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
let mut tokens = Vec::with_capacity(rows.len());
|
||||
for r in rows {
|
||||
tokens.push(TokenData {
|
||||
did: r.did,
|
||||
token_id: r.token_id,
|
||||
created_at: r.created_at,
|
||||
updated_at: r.updated_at,
|
||||
expires_at: r.expires_at,
|
||||
client_id: r.client_id,
|
||||
client_auth: from_json(r.client_auth)?,
|
||||
device_id: r.device_id,
|
||||
parameters: from_json(r.parameters)?,
|
||||
details: r.details,
|
||||
code: r.code,
|
||||
current_refresh_token: r.current_refresh_token,
|
||||
scope: r.scope,
|
||||
controller_did: r.controller_did,
|
||||
});
|
||||
}
|
||||
Ok(tokens)
|
||||
rows.into_iter()
|
||||
.map(|r| {
|
||||
Ok(TokenData {
|
||||
did: r.did,
|
||||
token_id: r.token_id,
|
||||
created_at: r.created_at,
|
||||
updated_at: r.updated_at,
|
||||
expires_at: r.expires_at,
|
||||
client_id: r.client_id,
|
||||
client_auth: from_json(r.client_auth)?,
|
||||
device_id: r.device_id,
|
||||
parameters: from_json(r.parameters)?,
|
||||
details: r.details,
|
||||
code: r.code,
|
||||
current_refresh_token: r.current_refresh_token,
|
||||
scope: r.scope,
|
||||
controller_did: r.controller_did,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
|
||||
|
||||
@@ -102,13 +102,11 @@ fn extract_device_cookie(headers: &HeaderMap) -> Option<String> {
|
||||
.get("cookie")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|cookie_str| {
|
||||
for cookie in cookie_str.split(';') {
|
||||
let cookie = cookie.trim();
|
||||
if let Some(value) = cookie.strip_prefix(&format!("{}=", DEVICE_COOKIE_NAME)) {
|
||||
return crate::config::AuthConfig::get().verify_device_cookie(value);
|
||||
}
|
||||
}
|
||||
None
|
||||
cookie_str.split(';').map(|c| c.trim()).find_map(|cookie| {
|
||||
cookie
|
||||
.strip_prefix(&format!("{}=", DEVICE_COOKIE_NAME))
|
||||
.and_then(|value| crate::config::AuthConfig::get().verify_device_cookie(value))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -182,35 +182,36 @@ fn validate_scope(
|
||||
if requested_scopes.is_empty() {
|
||||
return Ok(Some("atproto".to_string()));
|
||||
}
|
||||
let mut has_transition = false;
|
||||
let mut has_granular = false;
|
||||
|
||||
for scope in &requested_scopes {
|
||||
let parsed = parse_scope(scope);
|
||||
match &parsed {
|
||||
ParsedScope::Unknown(_) => {
|
||||
return Err(OAuthError::InvalidScope(format!(
|
||||
"Unsupported scope: {}",
|
||||
scope
|
||||
)));
|
||||
}
|
||||
ParsedScope::TransitionGeneric
|
||||
| ParsedScope::TransitionChat
|
||||
| ParsedScope::TransitionEmail => {
|
||||
has_transition = true;
|
||||
}
|
||||
ParsedScope::Repo(_)
|
||||
| ParsedScope::Blob(_)
|
||||
| ParsedScope::Rpc(_)
|
||||
| ParsedScope::Account(_)
|
||||
| ParsedScope::Identity(_)
|
||||
| ParsedScope::Include(_) => {
|
||||
has_granular = true;
|
||||
}
|
||||
ParsedScope::Atproto => {}
|
||||
}
|
||||
if let Some(unknown) = requested_scopes
|
||||
.iter()
|
||||
.find(|s| matches!(parse_scope(s), ParsedScope::Unknown(_)))
|
||||
{
|
||||
return Err(OAuthError::InvalidScope(format!(
|
||||
"Unsupported scope: {}",
|
||||
unknown
|
||||
)));
|
||||
}
|
||||
|
||||
let has_transition = requested_scopes.iter().any(|s| {
|
||||
matches!(
|
||||
parse_scope(s),
|
||||
ParsedScope::TransitionGeneric
|
||||
| ParsedScope::TransitionChat
|
||||
| ParsedScope::TransitionEmail
|
||||
)
|
||||
});
|
||||
let has_granular = requested_scopes.iter().any(|s| {
|
||||
matches!(
|
||||
parse_scope(s),
|
||||
ParsedScope::Repo(_)
|
||||
| ParsedScope::Blob(_)
|
||||
| ParsedScope::Rpc(_)
|
||||
| ParsedScope::Account(_)
|
||||
| ParsedScope::Identity(_)
|
||||
| ParsedScope::Include(_)
|
||||
)
|
||||
});
|
||||
|
||||
if has_transition && has_granular {
|
||||
return Err(OAuthError::InvalidScope(
|
||||
"Cannot mix transition scopes with granular scopes. Use either transition:* scopes OR granular scopes (repo:*, blob:*, rpc:*, account:*, include:*), not both.".to_string()
|
||||
@@ -219,13 +220,14 @@ fn validate_scope(
|
||||
|
||||
if let Some(client_scope) = &client_metadata.scope {
|
||||
let client_scopes: Vec<&str> = client_scope.split_whitespace().collect();
|
||||
for scope in &requested_scopes {
|
||||
if !client_scopes.iter().any(|cs| scope_matches(cs, scope)) {
|
||||
return Err(OAuthError::InvalidScope(format!(
|
||||
"Scope '{}' not registered for this client",
|
||||
scope
|
||||
)));
|
||||
}
|
||||
if let Some(unregistered) = requested_scopes
|
||||
.iter()
|
||||
.find(|scope| !client_scopes.iter().any(|cs| scope_matches(cs, scope)))
|
||||
{
|
||||
return Err(OAuthError::InvalidScope(format!(
|
||||
"Scope '{}' not registered for this client",
|
||||
unregistered
|
||||
)));
|
||||
}
|
||||
}
|
||||
Ok(Some(requested_scopes.join(" ")))
|
||||
|
||||
@@ -334,13 +334,7 @@ pub async fn handle_refresh_token_grant(
|
||||
REFRESH_TOKEN_EXPIRY_DAYS_CONFIDENTIAL
|
||||
};
|
||||
let new_expires_at = Utc::now() + Duration::days(refresh_expiry_days);
|
||||
db::rotate_token(
|
||||
&state.db,
|
||||
db_id,
|
||||
&new_refresh_token.0,
|
||||
new_expires_at,
|
||||
)
|
||||
.await?;
|
||||
db::rotate_token(&state.db, db_id, &new_refresh_token.0, new_expires_at).await?;
|
||||
tracing::info!(
|
||||
did = %token_data.did,
|
||||
new_expires_at = %new_expires_at,
|
||||
|
||||
@@ -11,6 +11,7 @@ const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 300;
|
||||
|
||||
pub struct TokenClaims {
|
||||
pub jti: String,
|
||||
pub sid: String,
|
||||
pub exp: i64,
|
||||
pub iat: i64,
|
||||
}
|
||||
@@ -33,22 +34,23 @@ pub fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAut
|
||||
}
|
||||
|
||||
pub fn create_access_token(
|
||||
token_id: &str,
|
||||
session_id: &str,
|
||||
sub: &str,
|
||||
dpop_jkt: Option<&str>,
|
||||
scope: Option<&str>,
|
||||
) -> Result<String, OAuthError> {
|
||||
create_access_token_with_delegation(token_id, sub, dpop_jkt, scope, None)
|
||||
create_access_token_with_delegation(session_id, sub, dpop_jkt, scope, None)
|
||||
}
|
||||
|
||||
pub fn create_access_token_with_delegation(
|
||||
token_id: &str,
|
||||
session_id: &str,
|
||||
sub: &str,
|
||||
dpop_jkt: Option<&str>,
|
||||
scope: Option<&str>,
|
||||
controller_did: Option<&str>,
|
||||
) -> Result<String, OAuthError> {
|
||||
use serde_json::json;
|
||||
let jti = uuid::Uuid::new_v4().to_string();
|
||||
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
|
||||
let issuer = format!("https://{}", pds_hostname);
|
||||
let now = Utc::now().timestamp();
|
||||
@@ -60,7 +62,8 @@ pub fn create_access_token_with_delegation(
|
||||
"aud": issuer,
|
||||
"iat": now,
|
||||
"exp": exp,
|
||||
"jti": token_id,
|
||||
"jti": jti,
|
||||
"sid": session_id,
|
||||
"scope": actual_scope
|
||||
});
|
||||
if let Some(jkt) = dpop_jkt {
|
||||
@@ -132,6 +135,11 @@ pub fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> {
|
||||
.and_then(|j| j.as_str())
|
||||
.ok_or_else(|| OAuthError::InvalidToken("Missing jti claim".to_string()))?
|
||||
.to_string();
|
||||
let sid = payload
|
||||
.get("sid")
|
||||
.and_then(|s| s.as_str())
|
||||
.ok_or_else(|| OAuthError::InvalidToken("Missing sid claim".to_string()))?
|
||||
.to_string();
|
||||
let exp = payload
|
||||
.get("exp")
|
||||
.and_then(|e| e.as_i64())
|
||||
@@ -140,5 +148,5 @@ pub fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> {
|
||||
.get("iat")
|
||||
.and_then(|i| i.as_i64())
|
||||
.ok_or_else(|| OAuthError::InvalidToken("Missing iat claim".to_string()))?;
|
||||
Ok(TokenClaims { jti, exp, iat })
|
||||
Ok(TokenClaims { jti, sid, exp, iat })
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ pub async fn introspect_token(
|
||||
Ok(info) => info,
|
||||
Err(_) => return Ok(Json(inactive_response)),
|
||||
};
|
||||
let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await {
|
||||
let token_data = match db::get_token_by_id(&state.db, &token_info.sid).await {
|
||||
Ok(Some(data)) => data,
|
||||
_ => return Ok(Json(inactive_response)),
|
||||
};
|
||||
|
||||
@@ -142,9 +142,9 @@ pub fn extract_oauth_token_info(token: &str) -> Result<OAuthTokenInfo, OAuthErro
|
||||
return Err(OAuthError::ExpiredToken("Token has expired".to_string()));
|
||||
}
|
||||
let token_id = payload
|
||||
.get("jti")
|
||||
.get("sid")
|
||||
.and_then(|j| j.as_str())
|
||||
.ok_or_else(|| OAuthError::InvalidToken("Missing jti claim".to_string()))?
|
||||
.ok_or_else(|| OAuthError::InvalidToken("Missing sid claim".to_string()))?
|
||||
.to_string();
|
||||
let did = payload
|
||||
.get("sub")
|
||||
|
||||
@@ -146,14 +146,10 @@ fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) {
|
||||
stack.push(*cid);
|
||||
}
|
||||
Ipld::Map(map) => {
|
||||
for v in map.values() {
|
||||
extract_links_ipld(v, stack);
|
||||
}
|
||||
map.values().for_each(|v| extract_links_ipld(v, stack));
|
||||
}
|
||||
Ipld::List(arr) => {
|
||||
for v in arr {
|
||||
extract_links_ipld(v, stack);
|
||||
}
|
||||
arr.iter().for_each(|v| extract_links_ipld(v, stack));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
@@ -148,14 +148,10 @@ pub fn extract_links(value: &Ipld, links: &mut Vec<Cid>) {
|
||||
links.push(*cid);
|
||||
}
|
||||
Ipld::Map(map) => {
|
||||
for v in map.values() {
|
||||
extract_links(v, links);
|
||||
}
|
||||
map.values().for_each(|v| extract_links(v, links));
|
||||
}
|
||||
Ipld::List(arr) => {
|
||||
for v in arr {
|
||||
extract_links(v, links);
|
||||
}
|
||||
arr.iter().for_each(|v| extract_links(v, links));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
@@ -181,24 +181,26 @@ async fn get_repo_since(state: &AppState, did: &str, head_cid: &Cid, since: &str
|
||||
}
|
||||
};
|
||||
|
||||
let mut block_cids: Vec<Cid> = Vec::new();
|
||||
for event in &events {
|
||||
if let Some(cids) = &event.blocks_cids {
|
||||
for cid_str in cids {
|
||||
if let Ok(cid) = Cid::from_str(cid_str)
|
||||
&& !block_cids.contains(&cid)
|
||||
{
|
||||
block_cids.push(cid);
|
||||
}
|
||||
let block_cids: Vec<Cid> = events
|
||||
.iter()
|
||||
.flat_map(|event| {
|
||||
let block_cids = event
|
||||
.blocks_cids
|
||||
.as_ref()
|
||||
.map(|cids| cids.iter().filter_map(|s| Cid::from_str(s).ok()).collect())
|
||||
.unwrap_or_else(Vec::new);
|
||||
let commit_cid = event
|
||||
.commit_cid
|
||||
.as_ref()
|
||||
.and_then(|s| Cid::from_str(s).ok());
|
||||
block_cids.into_iter().chain(commit_cid)
|
||||
})
|
||||
.fold(Vec::new(), |mut acc, cid| {
|
||||
if !acc.contains(&cid) {
|
||||
acc.push(cid);
|
||||
}
|
||||
}
|
||||
if let Some(commit_cid_str) = &event.commit_cid
|
||||
&& let Ok(cid) = Cid::from_str(commit_cid_str)
|
||||
&& !block_cids.contains(&cid)
|
||||
{
|
||||
block_cids.push(cid);
|
||||
}
|
||||
}
|
||||
acc
|
||||
});
|
||||
|
||||
let mut car_bytes = match encode_car_header(head_cid) {
|
||||
Ok(h) => h,
|
||||
@@ -334,9 +336,9 @@ pub async fn get_record(
|
||||
car.extend_from_slice(&writer);
|
||||
};
|
||||
write_block(&mut car_bytes, &commit_cid, &commit_bytes);
|
||||
for (cid, data) in &proof_blocks {
|
||||
write_block(&mut car_bytes, cid, data);
|
||||
}
|
||||
proof_blocks
|
||||
.iter()
|
||||
.for_each(|(cid, data)| write_block(&mut car_bytes, cid, data));
|
||||
write_block(&mut car_bytes, &record_cid, &record_block);
|
||||
(
|
||||
StatusCode::OK,
|
||||
|
||||
@@ -210,13 +210,11 @@ async fn write_car_blocks(
|
||||
let mut buffer = Cursor::new(Vec::new());
|
||||
let header = CarHeader::new_v1(vec![commit_cid]);
|
||||
let mut writer = CarWriter::new(header, &mut buffer);
|
||||
for (cid, data) in other_blocks {
|
||||
if cid != commit_cid {
|
||||
writer
|
||||
.write(cid, data.as_ref())
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("writing block {}: {}", cid, e))?;
|
||||
}
|
||||
for (cid, data) in other_blocks.iter().filter(|(c, _)| **c != commit_cid) {
|
||||
writer
|
||||
.write(*cid, data.as_ref())
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("writing block {}: {}", cid, e))?;
|
||||
}
|
||||
if let Some(data) = commit_bytes {
|
||||
writer
|
||||
@@ -360,20 +358,18 @@ pub async fn format_event_for_sending(
|
||||
}
|
||||
let car_bytes = if !all_cids.is_empty() {
|
||||
let fetched = state.block_store.get_many(&all_cids).await?;
|
||||
let mut blocks = std::collections::BTreeMap::new();
|
||||
let mut commit_bytes: Option<Bytes> = None;
|
||||
for (cid, data_opt) in all_cids.iter().zip(fetched.iter()) {
|
||||
if let Some(data) = data_opt {
|
||||
if *cid == commit_cid {
|
||||
commit_bytes = Some(data.clone());
|
||||
if let Some(rev) = extract_rev_from_commit_bytes(data) {
|
||||
frame.rev = rev;
|
||||
}
|
||||
} else {
|
||||
blocks.insert(*cid, data.clone());
|
||||
}
|
||||
}
|
||||
let (commit_data, other_blocks): (Vec<_>, Vec<_>) = all_cids
|
||||
.iter()
|
||||
.zip(fetched.iter())
|
||||
.filter_map(|(cid, data_opt)| data_opt.as_ref().map(|data| (*cid, data.clone())))
|
||||
.partition(|(cid, _)| *cid == commit_cid);
|
||||
let commit_bytes = commit_data.into_iter().next().map(|(_, data)| data);
|
||||
if let Some(ref cb) = commit_bytes
|
||||
&& let Some(rev) = extract_rev_from_commit_bytes(cb)
|
||||
{
|
||||
frame.rev = rev;
|
||||
}
|
||||
let blocks: std::collections::BTreeMap<Cid, Bytes> = other_blocks.into_iter().collect();
|
||||
write_car_blocks(commit_cid, commit_bytes, blocks).await?
|
||||
} else {
|
||||
Vec::new()
|
||||
@@ -393,38 +389,33 @@ pub async fn prefetch_blocks_for_events(
|
||||
state: &AppState,
|
||||
events: &[SequencedEvent],
|
||||
) -> Result<HashMap<Cid, Bytes>, anyhow::Error> {
|
||||
let mut all_cids: Vec<Cid> = Vec::new();
|
||||
for event in events {
|
||||
if let Some(ref commit_cid_str) = event.commit_cid
|
||||
&& let Ok(cid) = Cid::from_str(commit_cid_str)
|
||||
{
|
||||
all_cids.push(cid);
|
||||
}
|
||||
if let Some(ref prev_cid_str) = event.prev_cid
|
||||
&& let Ok(cid) = Cid::from_str(prev_cid_str)
|
||||
{
|
||||
all_cids.push(cid);
|
||||
}
|
||||
if let Some(ref block_cids_str) = event.blocks_cids {
|
||||
for s in block_cids_str {
|
||||
if let Ok(cid) = Cid::from_str(s) {
|
||||
all_cids.push(cid);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut all_cids: Vec<Cid> = events
|
||||
.iter()
|
||||
.flat_map(|event| {
|
||||
let commit_cid = event
|
||||
.commit_cid
|
||||
.as_ref()
|
||||
.and_then(|s| Cid::from_str(s).ok());
|
||||
let prev_cid = event.prev_cid.as_ref().and_then(|s| Cid::from_str(s).ok());
|
||||
let block_cids = event
|
||||
.blocks_cids
|
||||
.as_ref()
|
||||
.map(|cids| cids.iter().filter_map(|s| Cid::from_str(s).ok()).collect())
|
||||
.unwrap_or_else(Vec::new);
|
||||
commit_cid.into_iter().chain(prev_cid).chain(block_cids)
|
||||
})
|
||||
.collect();
|
||||
all_cids.sort();
|
||||
all_cids.dedup();
|
||||
if all_cids.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
let fetched = state.block_store.get_many(&all_cids).await?;
|
||||
let mut blocks_map = HashMap::with_capacity(all_cids.len());
|
||||
for (cid, data_opt) in all_cids.into_iter().zip(fetched.into_iter()) {
|
||||
if let Some(data) = data_opt {
|
||||
blocks_map.insert(cid, data);
|
||||
}
|
||||
}
|
||||
let blocks_map: HashMap<Cid, Bytes> = all_cids
|
||||
.into_iter()
|
||||
.zip(fetched)
|
||||
.filter_map(|(cid, data_opt)| data_opt.map(|data| (cid, data)))
|
||||
.collect();
|
||||
Ok(blocks_map)
|
||||
}
|
||||
|
||||
@@ -511,17 +502,12 @@ pub async fn format_event_with_prefetched_blocks(
|
||||
frame.since = Some(rev);
|
||||
}
|
||||
let car_bytes = if !all_cids.is_empty() {
|
||||
let mut blocks = BTreeMap::new();
|
||||
let mut commit_bytes_for_car: Option<Bytes> = None;
|
||||
for cid in all_cids {
|
||||
if let Some(data) = prefetched.get(&cid) {
|
||||
if cid == commit_cid {
|
||||
commit_bytes_for_car = Some(data.clone());
|
||||
} else {
|
||||
blocks.insert(cid, data.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
let (commit_data, other_blocks): (Vec<_>, Vec<_>) = all_cids
|
||||
.into_iter()
|
||||
.filter_map(|cid| prefetched.get(&cid).map(|data| (cid, data.clone())))
|
||||
.partition(|(cid, _)| *cid == commit_cid);
|
||||
let commit_bytes_for_car = commit_data.into_iter().next().map(|(_, data)| data);
|
||||
let blocks: BTreeMap<Cid, Bytes> = other_blocks.into_iter().collect();
|
||||
write_car_blocks(commit_cid, commit_bytes_for_car, blocks).await?
|
||||
} else {
|
||||
Vec::new()
|
||||
|
||||
@@ -256,10 +256,10 @@ impl Respond for PlcPostResponder {
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
|
||||
if let Ok(body) = serde_json::from_slice::<Value>(request.body.as_slice()) {
|
||||
if let Ok(mut store) = self.store.write() {
|
||||
store.insert(did, body);
|
||||
}
|
||||
if let Ok(body) = serde_json::from_slice::<Value>(request.body.as_slice())
|
||||
&& let Ok(mut store) = self.store.write()
|
||||
{
|
||||
store.insert(did, body);
|
||||
}
|
||||
ResponseTemplate::new(200)
|
||||
}
|
||||
@@ -298,18 +298,16 @@ impl Respond for PlcGetResponder {
|
||||
|
||||
match endpoint {
|
||||
"/log/last" => {
|
||||
let response = operation
|
||||
.cloned()
|
||||
.unwrap_or_else(|| {
|
||||
json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"prev": null
|
||||
})
|
||||
});
|
||||
let response = operation.cloned().unwrap_or_else(|| {
|
||||
json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"prev": null
|
||||
})
|
||||
});
|
||||
ResponseTemplate::new(200).set_body_json(response)
|
||||
}
|
||||
"/log/audit" => ResponseTemplate::new(200).set_body_json(json!([])),
|
||||
|
||||
@@ -159,10 +159,7 @@ async fn test_import_accepts_own_exported_repo() {
|
||||
let status = import_res.status();
|
||||
if status != StatusCode::OK {
|
||||
let body = import_res.text().await.unwrap_or_default();
|
||||
panic!(
|
||||
"Import failed with status {}: {}",
|
||||
status, body
|
||||
);
|
||||
panic!("Import failed with status {}: {}", status, body);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user