diff --git a/Cargo.lock b/Cargo.lock index 8b33e22..a6204b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2744,22 +2744,6 @@ dependencies = [ "tower-service", ] -[[package]] -name = "hyper-tls" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" -dependencies = [ - "bytes", - "http-body-util", - "hyper 1.8.1", - "hyper-util", - "native-tls", - "tokio", - "tokio-native-tls", - "tower-service", -] - [[package]] name = "hyper-util" version = "0.1.19" @@ -3585,23 +3569,6 @@ dependencies = [ "web-time", ] -[[package]] -name = "native-tls" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" -dependencies = [ - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework 2.11.1", - "security-framework-sys", - "tempfile", -] - [[package]] name = "nom" version = "7.1.3" @@ -4498,12 +4465,10 @@ dependencies = [ "http-body-util", "hyper 1.8.1", "hyper-rustls 0.27.7", - "hyper-tls", "hyper-util", "js-sys", "log", "mime", - "native-tls", "percent-encoding", "pin-project-lite", "quinn", @@ -4514,7 +4479,6 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", - "tokio-native-tls", "tokio-rustls 0.26.4", "tower", "tower-http", @@ -4661,7 +4625,7 @@ dependencies = [ "openssl-probe", "rustls-pki-types", "schannel", - "security-framework 3.5.1", + "security-framework", ] [[package]] @@ -4800,19 +4764,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "security-framework" -version = "2.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" -dependencies = [ - "bitflags", - "core-foundation 0.9.4", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - [[package]] name = "security-framework" version = "3.5.1" @@ -5525,19 +5476,6 @@ dependencies = [ "libc", ] -[[package]] -name = "tempfile" -version = "3.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" -dependencies = [ - "fastrand", - "getrandom 0.3.4", - "once_cell", - "rustix", - "windows-sys 0.61.2", -] - [[package]] name = "testcontainers" version = "0.26.2" @@ -5709,16 +5647,6 @@ dependencies = [ "syn 2.0.111", ] -[[package]] -name = "tokio-native-tls" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" -dependencies = [ - "native-tls", - "tokio", -] - [[package]] name = "tokio-rustls" version = "0.24.1" @@ -5758,10 +5686,12 @@ checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" dependencies = [ "futures-util", "log", - "native-tls", + "rustls 0.23.35", + "rustls-pki-types", "tokio", - "tokio-native-tls", + "tokio-rustls 0.26.4", "tungstenite", + "webpki-roots 0.26.11", ] [[package]] @@ -6274,8 +6204,9 @@ dependencies = [ "http 1.4.0", "httparse", "log", - "native-tls", "rand 0.9.2", + "rustls 0.23.35", + "rustls-pki-types", "sha1", "thiserror 2.0.17", "utf-8", diff --git a/Cargo.toml b/Cargo.toml index 424ae39..b3ee013 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,7 +81,7 @@ p384 = { version = "0.13", features = ["ecdsa"] } rand = "0.8" redis = { version = "1.0", features = ["tokio-comp", "connection-manager"] } regex = "1" -reqwest = { version = "0.12", features = ["json"] } +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls-webpki-roots", "http2", "charset", "macos-system-configuration"] } serde = { version = "1.0", features = ["derive"] } serde_bytes = "0.11" serde_ipld_dagcbor = "0.6" @@ -91,9 +91,9 @@ sha2 = "0.10" sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "uuid", "chrono", "json"] } subtle = "2.5" thiserror = "2.0" -tokio = { version = "1.48", features = ["macros", "rt-multi-thread", "time", "signal", "process"] } +tokio = { version = "1.48", features = ["macros", "rt-multi-thread", "time", "signal", "process", "io-util", "fs"] } tokio-util = "0.7.18" -tokio-tungstenite = { version = "0.28", features = ["native-tls"] } +tokio-tungstenite = { version = "0.28", features = ["rustls-tls-webpki-roots"] } totp-rs = { version = "5", features = ["qr"] } tower = "0.5" tower-http = { version = "0.6", features = ["cors"] } @@ -111,3 +111,8 @@ ctor = "0.6" testcontainers = "0.26" testcontainers-modules = { version = "0.14", features = ["postgres"] } wiremock = "0.6" + +[profile.release] +lto = "thin" +strip = true +codegen-units = 1 diff --git a/Dockerfile b/Dockerfile index 649b8ec..69fcb5f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,13 +1,18 @@ FROM rust:1.92-alpine AS builder -RUN apk add --no-cache ca-certificates openssl openssl-dev openssl-libs-static pkgconfig musl-dev +RUN apk add --no-cache ca-certificates musl-dev pkgconfig openssl-dev openssl-libs-static WORKDIR /app +ARG SLIM="false" COPY Cargo.toml Cargo.lock ./ COPY crates ./crates COPY .sqlx ./.sqlx COPY migrations ./crates/tranquil-pds/migrations RUN --mount=type=cache,target=/usr/local/cargo/registry \ --mount=type=cache,target=/app/target \ - SQLX_OFFLINE=true cargo build --release -p tranquil-pds && \ + if [ "$SLIM" = "true" ]; then \ + SQLX_OFFLINE=true cargo build --release -p tranquil-pds --no-default-features; \ + else \ + SQLX_OFFLINE=true cargo build --release -p tranquil-pds; \ + fi && \ cp target/release/tranquil-pds /tmp/tranquil-pds FROM alpine:3.23 diff --git a/crates/tranquil-cache/Cargo.toml b/crates/tranquil-cache/Cargo.toml index 0f9f986..d374b03 100644 --- a/crates/tranquil-cache/Cargo.toml +++ b/crates/tranquil-cache/Cargo.toml @@ -4,12 +4,16 @@ version.workspace = true edition.workspace = true license.workspace = true +[features] +default = [] +valkey = ["dep:redis"] + [dependencies] tranquil-infra = { workspace = true } tranquil-ripple = { workspace = true } async-trait = { workspace = true } base64 = { workspace = true } -redis = { workspace = true } +redis = { workspace = true, optional = true } tokio-util = { workspace = true } tracing = { workspace = true } diff --git a/crates/tranquil-cache/src/lib.rs b/crates/tranquil-cache/src/lib.rs index f7804c8..87ff83d 100644 --- a/crates/tranquil-cache/src/lib.rs +++ b/crates/tranquil-cache/src/lib.rs @@ -1,72 +1,135 @@ pub use tranquil_infra::{Cache, CacheError, DistributedRateLimiter}; use async_trait::async_trait; -use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; use std::sync::Arc; use std::time::Duration; -#[derive(Clone)] -pub struct ValkeyCache { - conn: redis::aio::ConnectionManager, -} +#[cfg(feature = "valkey")] +mod valkey { + use super::*; + use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; -impl ValkeyCache { - pub async fn new(url: &str) -> Result { - let client = redis::Client::open(url).map_err(|e| CacheError::Connection(e.to_string()))?; - let manager = client - .get_connection_manager() - .await - .map_err(|e| CacheError::Connection(e.to_string()))?; - Ok(Self { conn: manager }) + #[derive(Clone)] + pub struct ValkeyCache { + conn: redis::aio::ConnectionManager, } - pub fn connection(&self) -> redis::aio::ConnectionManager { - self.conn.clone() + impl ValkeyCache { + pub async fn new(url: &str) -> Result { + let client = + redis::Client::open(url).map_err(|e| CacheError::Connection(e.to_string()))?; + let manager = client + .get_connection_manager() + .await + .map_err(|e| CacheError::Connection(e.to_string()))?; + Ok(Self { conn: manager }) + } + + pub fn connection(&self) -> redis::aio::ConnectionManager { + self.conn.clone() + } + } + + #[async_trait] + impl Cache for ValkeyCache { + async fn get(&self, key: &str) -> Option { + let mut conn = self.conn.clone(); + redis::cmd("GET") + .arg(key) + .query_async::>(&mut conn) + .await + .ok() + .flatten() + } + + async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { + let mut conn = self.conn.clone(); + redis::cmd("SET") + .arg(key) + .arg(value) + .arg("PX") + .arg(ttl.as_millis().min(i64::MAX as u128) as i64) + .query_async::<()>(&mut conn) + .await + .map_err(|e| CacheError::Connection(e.to_string())) + } + + async fn delete(&self, key: &str) -> Result<(), CacheError> { + let mut conn = self.conn.clone(); + redis::cmd("DEL") + .arg(key) + .query_async::<()>(&mut conn) + .await + .map_err(|e| CacheError::Connection(e.to_string())) + } + + async fn get_bytes(&self, key: &str) -> Option> { + self.get(key).await.and_then(|s| BASE64.decode(&s).ok()) + } + + async fn set_bytes( + &self, + key: &str, + value: &[u8], + ttl: Duration, + ) -> Result<(), CacheError> { + let encoded = BASE64.encode(value); + self.set(key, &encoded, ttl).await + } + } + + #[derive(Clone)] + pub struct RedisRateLimiter { + conn: redis::aio::ConnectionManager, + } + + impl RedisRateLimiter { + pub fn new(conn: redis::aio::ConnectionManager) -> Self { + Self { conn } + } + } + + #[async_trait] + impl DistributedRateLimiter for RedisRateLimiter { + async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { + let mut conn = self.conn.clone(); + let full_key = format!("rl:{}", key); + let window_secs = window_ms.div_ceil(1000).max(1) as i64; + let result: Result = redis::Script::new( + r"local c = redis.call('INCR', KEYS[1]) +if c == 1 then redis.call('EXPIRE', KEYS[1], ARGV[1]) end +if redis.call('TTL', KEYS[1]) == -1 then redis.call('EXPIRE', KEYS[1], ARGV[1]) end +return c", + ) + .key(&full_key) + .arg(window_secs) + .invoke_async(&mut conn) + .await; + match result { + Ok(count) => count <= limit as i64, + Err(e) => { + tracing::warn!(error = %e, "redis rate limit script failed, allowing request"); + true + } + } + } + + async fn peek_rate_limit_count(&self, key: &str, _window_ms: u64) -> u64 { + let mut conn = self.conn.clone(); + let full_key = format!("rl:{}", key); + redis::cmd("GET") + .arg(&full_key) + .query_async::>(&mut conn) + .await + .ok() + .flatten() + .unwrap_or(0) + } } } -#[async_trait] -impl Cache for ValkeyCache { - async fn get(&self, key: &str) -> Option { - let mut conn = self.conn.clone(); - redis::cmd("GET") - .arg(key) - .query_async::>(&mut conn) - .await - .ok() - .flatten() - } - - async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { - let mut conn = self.conn.clone(); - redis::cmd("SET") - .arg(key) - .arg(value) - .arg("PX") - .arg(ttl.as_millis().min(i64::MAX as u128) as i64) - .query_async::<()>(&mut conn) - .await - .map_err(|e| CacheError::Connection(e.to_string())) - } - - async fn delete(&self, key: &str) -> Result<(), CacheError> { - let mut conn = self.conn.clone(); - redis::cmd("DEL") - .arg(key) - .query_async::<()>(&mut conn) - .await - .map_err(|e| CacheError::Connection(e.to_string())) - } - - async fn get_bytes(&self, key: &str) -> Option> { - self.get(key).await.and_then(|s| BASE64.decode(&s).ok()) - } - - async fn set_bytes(&self, key: &str, value: &[u8], ttl: Duration) -> Result<(), CacheError> { - let encoded = BASE64.encode(value); - self.set(key, &encoded, ttl).await - } -} +#[cfg(feature = "valkey")] +pub use valkey::{RedisRateLimiter, ValkeyCache}; pub struct NoOpCache; @@ -97,55 +160,6 @@ impl Cache for NoOpCache { } } -#[derive(Clone)] -pub struct RedisRateLimiter { - conn: redis::aio::ConnectionManager, -} - -impl RedisRateLimiter { - pub fn new(conn: redis::aio::ConnectionManager) -> Self { - Self { conn } - } -} - -#[async_trait] -impl DistributedRateLimiter for RedisRateLimiter { - async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { - let mut conn = self.conn.clone(); - let full_key = format!("rl:{}", key); - let window_secs = window_ms.div_ceil(1000).max(1) as i64; - let result: Result = redis::Script::new( - r"local c = redis.call('INCR', KEYS[1]) -if c == 1 then redis.call('EXPIRE', KEYS[1], ARGV[1]) end -if redis.call('TTL', KEYS[1]) == -1 then redis.call('EXPIRE', KEYS[1], ARGV[1]) end -return c" - ) - .key(&full_key) - .arg(window_secs) - .invoke_async(&mut conn) - .await; - match result { - Ok(count) => count <= limit as i64, - Err(e) => { - tracing::warn!(error = %e, "redis rate limit script failed, allowing request"); - true - } - } - } - - async fn peek_rate_limit_count(&self, key: &str, _window_ms: u64) -> u64 { - let mut conn = self.conn.clone(); - let full_key = format!("rl:{}", key); - redis::cmd("GET") - .arg(&full_key) - .query_async::>(&mut conn) - .await - .ok() - .flatten() - .unwrap_or(0) - } -} - pub struct NoOpRateLimiter; #[async_trait] @@ -158,6 +172,7 @@ impl DistributedRateLimiter for NoOpRateLimiter { pub async fn create_cache( shutdown: tokio_util::sync::CancellationToken, ) -> (Arc, Arc) { + #[cfg(feature = "valkey")] if let Ok(url) = std::env::var("VALKEY_URL") { match ValkeyCache::new(&url).await { Ok(cache) => { @@ -171,6 +186,13 @@ pub async fn create_cache( } } + #[cfg(not(feature = "valkey"))] + if std::env::var("VALKEY_URL").is_ok() { + tracing::warn!( + "VALKEY_URL is set but binary was compiled without valkey feature. using ripple." + ); + } + match tranquil_ripple::RippleConfig::from_env() { Ok(config) => { let peer_count = config.seed_peers.len(); diff --git a/crates/tranquil-pds/Cargo.toml b/crates/tranquil-pds/Cargo.toml index 70a64fc..873b289 100644 --- a/crates/tranquil-pds/Cargo.toml +++ b/crates/tranquil-pds/Cargo.toml @@ -21,8 +21,6 @@ aes-gcm = { workspace = true } async-trait = { workspace = true } backon = { workspace = true } anyhow = { workspace = true } -aws-config = { workspace = true } -aws-sdk-s3 = { workspace = true } axum = { workspace = true } base32 = { workspace = true } base64 = { workspace = true } @@ -55,7 +53,7 @@ multibase = { workspace = true } multihash = { workspace = true } p256 = { workspace = true } rand = { workspace = true } -redis = { workspace = true } +redis = { workspace = true, optional = true } regex = { workspace = true } reqwest = { workspace = true } serde = { workspace = true } @@ -79,10 +77,15 @@ urlencoding = { workspace = true } uuid = { workspace = true } webauthn-rs = { workspace = true } zip = { workspace = true } +aws-config = { workspace = true, optional = true } +aws-sdk-s3 = { workspace = true, optional = true } [features] +default = ["s3", "valkey"] external-infra = [] -s3-storage = [] +s3-storage = ["tranquil-storage/s3", "dep:aws-config", "dep:aws-sdk-s3"] +s3 = ["s3-storage"] +valkey = ["tranquil-cache/valkey", "dep:redis"] [dev-dependencies] ciborium = { workspace = true } diff --git a/crates/tranquil-pds/src/api/repo/blob.rs b/crates/tranquil-pds/src/api/repo/blob.rs index 8126f97..81eac65 100644 --- a/crates/tranquil-pds/src/api/repo/blob.rs +++ b/crates/tranquil-pds/src/api/repo/blob.rs @@ -162,7 +162,10 @@ pub async fn upload_blob( if let Err(e) = state.blob_store.copy(&temp_key, &storage_key).await { let _ = state.blob_store.delete(&temp_key).await; if let Err(db_err) = state.blob_repo.delete_blob_by_cid(&cid_link).await { - error!("Failed to clean up orphaned blob record after copy failure: {:?}", db_err); + error!( + "Failed to clean up orphaned blob record after copy failure: {:?}", + db_err + ); } error!("Failed to copy blob to final location: {:?}", e); return Err(ApiError::InternalError(Some("Failed to store blob".into()))); @@ -170,8 +173,8 @@ pub async fn upload_blob( let _ = state.blob_store.delete(&temp_key).await; - if let Some(ref controller) = controller_did { - if let Err(e) = state + if let Some(ref controller) = controller_did + && let Err(e) = state .delegation_repo .log_delegation_action( &did, @@ -187,9 +190,8 @@ pub async fn upload_blob( None, ) .await - { - warn!("Failed to log delegation action for blob upload: {:?}", e); - } + { + warn!("Failed to log delegation action for blob upload: {:?}", e); } Ok(Json(json!({ diff --git a/crates/tranquil-pds/src/cache/mod.rs b/crates/tranquil-pds/src/cache/mod.rs index a15dbf2..c48865b 100644 --- a/crates/tranquil-pds/src/cache/mod.rs +++ b/crates/tranquil-pds/src/cache/mod.rs @@ -1,4 +1,6 @@ pub use tranquil_cache::{ - Cache, CacheError, DistributedRateLimiter, NoOpCache, NoOpRateLimiter, RedisRateLimiter, - ValkeyCache, create_cache, + Cache, CacheError, DistributedRateLimiter, NoOpCache, NoOpRateLimiter, create_cache, }; + +#[cfg(feature = "valkey")] +pub use tranquil_cache::{RedisRateLimiter, ValkeyCache}; diff --git a/crates/tranquil-pds/src/storage/mod.rs b/crates/tranquil-pds/src/storage/mod.rs index 3827d35..db45e73 100644 --- a/crates/tranquil-pds/src/storage/mod.rs +++ b/crates/tranquil-pds/src/storage/mod.rs @@ -1,5 +1,8 @@ pub use tranquil_storage::{ - BackupStorage, BlobStorage, FilesystemBackupStorage, FilesystemBlobStorage, S3BackupStorage, - S3BlobStorage, StorageError, StreamUploadResult, backup_interval_secs, backup_retention_count, - create_backup_storage, create_blob_storage, + BackupStorage, BlobStorage, FilesystemBackupStorage, FilesystemBlobStorage, StorageError, + StreamUploadResult, backup_interval_secs, backup_retention_count, create_backup_storage, + create_blob_storage, }; + +#[cfg(feature = "s3-storage")] +pub use tranquil_storage::{S3BackupStorage, S3BlobStorage}; diff --git a/crates/tranquil-pds/tests/common/mod.rs b/crates/tranquil-pds/tests/common/mod.rs index d2cd3df..79a350a 100644 --- a/crates/tranquil-pds/tests/common/mod.rs +++ b/crates/tranquil-pds/tests/common/mod.rs @@ -330,10 +330,8 @@ unsafe fn configure_external_storage_env() { ); std::env::set_var("S3_ENDPOINT", &s3_endpoint); } else { - let process_dir = std::env::temp_dir().join(format!( - "tranquil-pds-test-{}", - std::process::id() - )); + let process_dir = + std::env::temp_dir().join(format!("tranquil-pds-test-{}", std::process::id())); let blob_path = process_dir.join("blobs"); let backup_path = process_dir.join("backups"); std::fs::create_dir_all(&blob_path).expect("Failed to create blob directory"); @@ -715,7 +713,8 @@ async fn setup_cluster_external_infra() -> String { #[cfg(not(feature = "external-infra"))] async fn setup_cluster_testcontainers() -> String { - let temp_dir = std::env::temp_dir().join(format!("tranquil-pds-cluster-{}", uuid::Uuid::new_v4())); + let temp_dir = + std::env::temp_dir().join(format!("tranquil-pds-cluster-{}", uuid::Uuid::new_v4())); let blob_path = temp_dir.join("blobs"); let backup_path = temp_dir.join("backups"); std::fs::create_dir_all(&blob_path).expect("Failed to create blob temp directory"); diff --git a/crates/tranquil-pds/tests/rate_limit.rs b/crates/tranquil-pds/tests/rate_limit.rs index 4163380..a3ace16 100644 --- a/crates/tranquil-pds/tests/rate_limit.rs +++ b/crates/tranquil-pds/tests/rate_limit.rs @@ -118,6 +118,7 @@ async fn test_account_creation_rate_limiting() { ); } +#[cfg(feature = "valkey")] #[tokio::test] async fn test_valkey_connection() { if std::env::var("VALKEY_URL").is_err() { @@ -156,6 +157,7 @@ async fn test_valkey_connection() { .expect("DEL failed"); } +#[cfg(feature = "valkey")] #[tokio::test] async fn test_distributed_rate_limiter_directly() { if std::env::var("VALKEY_URL").is_err() { diff --git a/crates/tranquil-pds/tests/ripple_cluster.rs b/crates/tranquil-pds/tests/ripple_cluster.rs index 50e367b..c8e917f 100644 --- a/crates/tranquil-pds/tests/ripple_cluster.rs +++ b/crates/tranquil-pds/tests/ripple_cluster.rs @@ -45,21 +45,22 @@ async fn cluster_formation() { assert!(nodes.len() >= 3, "expected at least 3 cluster nodes"); let client = common::client(); - let results: Vec<_> = futures::future::join_all( - nodes.iter().map(|node| { - let client = client.clone(); - let url = node.url.clone(); - async move { - client - .get(format!("{url}/xrpc/com.atproto.server.describeServer")) - .send() - .await - } - }) - ).await; + let results: Vec<_> = futures::future::join_all(nodes.iter().map(|node| { + let client = client.clone(); + let url = node.url.clone(); + async move { + client + .get(format!("{url}/xrpc/com.atproto.server.describeServer")) + .send() + .await + } + })) + .await; results.iter().enumerate().for_each(|(i, result)| { - let resp = result.as_ref().unwrap_or_else(|e| panic!("node {i} unreachable: {e}")); + let resp = result + .as_ref() + .unwrap_or_else(|e| panic!("node {i} unreachable: {e}")); assert_eq!( resp.status(), StatusCode::OK, @@ -91,7 +92,10 @@ async fn cluster_any_node_access() { assert_eq!(create_res.status(), StatusCode::OK); let body: serde_json::Value = create_res.json().await.expect("invalid json"); let did = body["did"].as_str().expect("no did").to_string(); - let access_jwt = body["accessJwt"].as_str().expect("no accessJwt").to_string(); + let access_jwt = body["accessJwt"] + .as_str() + .expect("no accessJwt") + .to_string(); let pool = common::get_test_db_pool().await; let body_text: String = sqlx::query_scalar!( @@ -132,8 +136,10 @@ async fn cluster_any_node_access() { let token = match confirm_res.status() { StatusCode::OK => { - let confirm_body: serde_json::Value = - confirm_res.json().await.expect("invalid json from confirmSignup"); + let confirm_body: serde_json::Value = confirm_res + .json() + .await + .expect("invalid json from confirmSignup"); confirm_body["accessJwt"] .as_str() .unwrap_or(&access_jwt) @@ -164,14 +170,8 @@ async fn cluster_any_node_access() { async fn cache_convergence() { let nodes = common::cluster().await; - let cache_a = nodes[0] - .cache - .as_ref() - .expect("node 0 should have a cache"); - let cache_b = nodes[1] - .cache - .as_ref() - .expect("node 1 should have a cache"); + let cache_a = nodes[0].cache.as_ref().expect("node 0 should have a cache"); + let cache_b = nodes[1].cache.as_ref().expect("node 1 should have a cache"); let test_key = format!("ripple-test-{}", uuid::Uuid::new_v4()); let test_value = "converged-value"; @@ -407,14 +407,13 @@ async fn cluster_bulk_key_convergence() { }) .await; - let spot_checks: Vec> = futures::future::join_all( - [0, 99, 250, 499].iter().map(|&i| { + let spot_checks: Vec> = + futures::future::join_all([0, 99, 250, 499].iter().map(|&i| { let c = cache_2.clone(); let p = prefix.clone(); async move { c.get(&format!("{p}-{i}")).await } - }), - ) - .await; + })) + .await; spot_checks.iter().enumerate().for_each(|(idx, val)| { assert!( @@ -620,7 +619,10 @@ fn create_account_on_node<'a>( assert_eq!(create_res.status(), StatusCode::OK, "createAccount non-200"); let body: serde_json::Value = create_res.json().await.expect("invalid json"); let did = body["did"].as_str().expect("no did").to_string(); - let access_jwt = body["accessJwt"].as_str().expect("no accessJwt").to_string(); + let access_jwt = body["accessJwt"] + .as_str() + .expect("no accessJwt") + .to_string(); let pool = common::get_test_db_pool().await; let body_text: String = sqlx::query_scalar!( @@ -654,8 +656,10 @@ fn create_account_on_node<'a>( let token = match confirm_res.status() { StatusCode::OK => { - let confirm_body: serde_json::Value = - confirm_res.json().await.expect("invalid json from confirmSignup"); + let confirm_body: serde_json::Value = confirm_res + .json() + .await + .expect("invalid json from confirmSignup"); confirm_body["accessJwt"] .as_str() .unwrap_or(&access_jwt) @@ -744,7 +748,10 @@ async fn cross_node_handle_resolution_from_cache() { let cache_0 = cache_for(nodes, 0); let fake_handle = format!("cached-{}.test", uuid::Uuid::new_v4().simple()); - let fake_did = format!("did:plc:cached{}", &uuid::Uuid::new_v4().simple().to_string()[..16]); + let fake_did = format!( + "did:plc:cached{}", + &uuid::Uuid::new_v4().simple().to_string()[..16] + ); cache_0 .set( @@ -795,7 +802,10 @@ async fn cross_node_cache_delete_observable_via_http() { let cache_1 = cache_for(nodes, 1); let fake_handle = format!("deltest-{}.test", uuid::Uuid::new_v4().simple()); - let fake_did = format!("did:plc:del{}", &uuid::Uuid::new_v4().simple().to_string()[..16]); + let fake_did = format!( + "did:plc:del{}", + &uuid::Uuid::new_v4().simple().to_string()[..16] + ); let cache_key = format!("handle:{fake_handle}"); cache_0 diff --git a/crates/tranquil-pds/tests/sync_conformance.rs b/crates/tranquil-pds/tests/sync_conformance.rs index 8b7a15b..1db281f 100644 --- a/crates/tranquil-pds/tests/sync_conformance.rs +++ b/crates/tranquil-pds/tests/sync_conformance.rs @@ -190,7 +190,11 @@ async fn test_list_repos_shows_status_field() { assert!(takendown_repo.is_some(), "Takendown repo should be in list"); let repo = takendown_repo.unwrap(); assert_eq!(repo["active"], false, "repo should be inactive: {:?}", repo); - assert_eq!(repo["status"], "takendown", "repo status should be takendown: {:?}", repo); + assert_eq!( + repo["status"], "takendown", + "repo status should be takendown: {:?}", + repo + ); } #[tokio::test] diff --git a/crates/tranquil-pds/tests/whole_story.rs b/crates/tranquil-pds/tests/whole_story.rs index ce6cd5b..3272335 100644 --- a/crates/tranquil-pds/tests/whole_story.rs +++ b/crates/tranquil-pds/tests/whole_story.rs @@ -1400,10 +1400,16 @@ async fn test_scale_many_users_social_graph() { .iter() .enumerate() .flat_map(|(i, (follower_did, follower_jwt))| { - users.iter().enumerate() + users + .iter() + .enumerate() .filter(move |(j, _)| *j != i) .map(|(_, (followee_did, _))| { - (follower_did.clone(), follower_jwt.clone(), followee_did.clone()) + ( + follower_did.clone(), + follower_jwt.clone(), + followee_did.clone(), + ) }) .collect::>() }) diff --git a/crates/tranquil-ripple/src/config.rs b/crates/tranquil-ripple/src/config.rs index f770efa..014b71b 100644 --- a/crates/tranquil-ripple/src/config.rs +++ b/crates/tranquil-ripple/src/config.rs @@ -19,7 +19,11 @@ fn parse_env_with_warning(var_name: &str, raw: &str) -> Op match raw.parse::() { Ok(v) => Some(v), Err(_) => { - tracing::warn!(var = var_name, value = raw, "invalid env var value, using default"); + tracing::warn!( + var = var_name, + value = raw, + "invalid env var value, using default" + ); None } } diff --git a/crates/tranquil-ripple/src/crdt/delta.rs b/crates/tranquil-ripple/src/crdt/delta.rs index 1562e5a..c556c26 100644 --- a/crates/tranquil-ripple/src/crdt/delta.rs +++ b/crates/tranquil-ripple/src/crdt/delta.rs @@ -1,5 +1,5 @@ -use super::lww_map::LwwDelta; use super::g_counter::GCounterDelta; +use super::lww_map::LwwDelta; use serde::{Deserialize, Serialize}; const SCHEMA_VERSION: u8 = 1; @@ -21,7 +21,7 @@ impl CrdtDelta { pub fn is_empty(&self) -> bool { self.cache_delta .as_ref() - .map_or(true, |d| d.entries.is_empty()) + .is_none_or(|d| d.entries.is_empty()) && self.rate_limit_deltas.is_empty() } diff --git a/crates/tranquil-ripple/src/crdt/g_counter.rs b/crates/tranquil-ripple/src/crdt/g_counter.rs index 2a88647..37cac82 100644 --- a/crates/tranquil-ripple/src/crdt/g_counter.rs +++ b/crates/tranquil-ripple/src/crdt/g_counter.rs @@ -142,12 +142,10 @@ impl RateLimitStore { self.dirty .iter() .filter_map(|key| { - self.counters - .get(key) - .map(|counter| GCounterDelta { - key: key.clone(), - counter: counter.clone(), - }) + self.counters.get(key).map(|counter| GCounterDelta { + key: key.clone(), + counter: counter.clone(), + }) }) .collect() } @@ -167,7 +165,10 @@ impl RateLimitStore { return 0; } match self.counters.get(key) { - Some(counter) if counter.window_start_ms == Self::aligned_window_start(now_wall_ms, window_ms) => { + Some(counter) + if counter.window_start_ms + == Self::aligned_window_start(now_wall_ms, window_ms) => + { counter.total() } _ => 0, diff --git a/crates/tranquil-ripple/src/crdt/lww_map.rs b/crates/tranquil-ripple/src/crdt/lww_map.rs index cfea9eb..186618d 100644 --- a/crates/tranquil-ripple/src/crdt/lww_map.rs +++ b/crates/tranquil-ripple/src/crdt/lww_map.rs @@ -124,6 +124,12 @@ pub struct LwwMap { estimated_bytes: usize, } +impl Default for LwwMap { + fn default() -> Self { + Self::new() + } +} + impl LwwMap { pub fn new() -> Self { Self { @@ -141,7 +147,14 @@ impl LwwMap { entry.value.clone() } - pub fn set(&mut self, key: String, value: Vec, timestamp: HlcTimestamp, ttl_ms: u64, wall_ms_now: u64) { + pub fn set( + &mut self, + key: String, + value: Vec, + timestamp: HlcTimestamp, + ttl_ms: u64, + wall_ms_now: u64, + ) { let entry = LwwEntry { created_at_wall_ms: wall_ms_now, value: Some(value), @@ -248,6 +261,10 @@ impl LwwMap { self.entries.len() } + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } + fn remove_estimated_bytes(&mut self, key: &str) { if let Some(existing) = self.entries.get(key) { let size = existing.entry_byte_size(key); diff --git a/crates/tranquil-ripple/src/crdt/mod.rs b/crates/tranquil-ripple/src/crdt/mod.rs index 1fa6853..808f245 100644 --- a/crates/tranquil-ripple/src/crdt/mod.rs +++ b/crates/tranquil-ripple/src/crdt/mod.rs @@ -1,13 +1,13 @@ pub mod delta; +pub mod g_counter; pub mod hlc; pub mod lww_map; -pub mod g_counter; use crate::config::fnv1a; use delta::CrdtDelta; +use g_counter::RateLimitStore; use hlc::{Hlc, HlcTimestamp}; use lww_map::{LwwDelta, LwwMap}; -use g_counter::RateLimitStore; use parking_lot::{Mutex, RwLock}; use std::time::{SystemTime, UNIX_EPOCH}; @@ -45,9 +45,8 @@ impl ShardedCrdtStore { let shards: Vec> = (0..SHARD_COUNT) .map(|_| RwLock::new(CrdtShard::new(node_id))) .collect(); - let promotions: Vec>> = (0..SHARD_COUNT) - .map(|_| Mutex::new(Vec::new())) - .collect(); + let promotions: Vec>> = + (0..SHARD_COUNT).map(|_| Mutex::new(Vec::new())).collect(); Self { hlc: Mutex::new(Hlc::new(node_id)), shards: shards.into_boxed_slice(), @@ -136,7 +135,9 @@ impl ShardedCrdtStore { let cache_delta = match cache_entries.is_empty() { true => None, - false => Some(LwwDelta { entries: cache_entries }), + false => Some(LwwDelta { + entries: cache_entries, + }), }; CrdtDelta { @@ -159,9 +160,8 @@ impl ShardedCrdtStore { }) .unwrap_or_default(); - let mut max_ts_per_shard: Vec> = (0..self.shards.len()) - .map(|_| None) - .collect(); + let mut max_ts_per_shard: Vec> = + (0..self.shards.len()).map(|_| None).collect(); cache_entries_by_shard.iter().for_each(|&(shard_idx, ts)| { let slot = &mut max_ts_per_shard[shard_idx]; @@ -177,37 +177,39 @@ impl ShardedCrdtStore { .map(|d| (d.key.as_str(), &d.counter)) .collect(); - let mut shard_rl_keys: Vec> = (0..self.shards.len()) - .map(|_| Vec::new()) - .collect(); + let mut shard_rl_keys: Vec> = + (0..self.shards.len()).map(|_| Vec::new()).collect(); rl_index.keys().for_each(|&key| { shard_rl_keys[self.shard_for(key)].push(key); }); - self.shards.iter().enumerate().for_each(|(idx, shard_lock)| { - let has_cache_update = max_ts_per_shard[idx].is_some(); - let has_rl_keys = !shard_rl_keys[idx].is_empty(); - if !has_cache_update && !has_rl_keys { - return; - } - let mut shard = shard_lock.write(); - if let Some(max_ts) = max_ts_per_shard[idx] { - shard.last_broadcast_ts = max_ts; - } - shard_rl_keys[idx].iter().for_each(|&key| { - let still_matches = shard - .rate_limits - .peek_dirty_counter(key) - .zip(rl_index.get(key)) - .is_some_and(|(current, committed)| { - current.window_start_ms == committed.window_start_ms - && current.total() == committed.total() - }); - if still_matches { - shard.rate_limits.clear_single_dirty(key); + self.shards + .iter() + .enumerate() + .for_each(|(idx, shard_lock)| { + let has_cache_update = max_ts_per_shard[idx].is_some(); + let has_rl_keys = !shard_rl_keys[idx].is_empty(); + if !has_cache_update && !has_rl_keys { + return; } + let mut shard = shard_lock.write(); + if let Some(max_ts) = max_ts_per_shard[idx] { + shard.last_broadcast_ts = max_ts; + } + shard_rl_keys[idx].iter().for_each(|&key| { + let still_matches = shard + .rate_limits + .peek_dirty_counter(key) + .zip(rl_index.get(key)) + .is_some_and(|(current, committed)| { + current.window_start_ms == committed.window_start_ms + && current.total() == committed.total() + }); + if still_matches { + shard.rate_limits.clear_single_dirty(key); + } + }); }); - }); } pub fn merge_delta(&self, delta: &CrdtDelta) -> bool { @@ -219,10 +221,10 @@ impl ShardedCrdtStore { return false; } - if let Some(ref cache_delta) = delta.cache_delta { - if let Some(max_ts) = cache_delta.entries.iter().map(|(_, e)| e.timestamp).max() { - let _ = self.hlc.lock().receive(max_ts); - } + if let Some(ref cache_delta) = delta.cache_delta + && let Some(max_ts) = cache_delta.entries.iter().map(|(_, e)| e.timestamp).max() + { + let _ = self.hlc.lock().receive(max_ts); } let mut changed = false; @@ -235,17 +237,20 @@ impl ShardedCrdtStore { entries_by_shard[self.shard_for(key)].push((key.clone(), entry.clone())); }); - entries_by_shard.into_iter().enumerate().for_each(|(idx, entries)| { - if entries.is_empty() { - return; - } - let mut shard = self.shards[idx].write(); - entries.into_iter().for_each(|(key, entry)| { - if shard.cache.merge_entry(key, entry) { - changed = true; + entries_by_shard + .into_iter() + .enumerate() + .for_each(|(idx, entries)| { + if entries.is_empty() { + return; } + let mut shard = self.shards[idx].write(); + entries.into_iter().for_each(|(key, entry)| { + if shard.cache.merge_entry(key, entry) { + changed = true; + } + }); }); - }); } if !delta.rate_limit_deltas.is_empty() { @@ -256,17 +261,20 @@ impl ShardedCrdtStore { rl_by_shard[self.shard_for(&rd.key)].push((rd.key.clone(), &rd.counter)); }); - rl_by_shard.into_iter().enumerate().for_each(|(idx, entries)| { - if entries.is_empty() { - return; - } - let mut shard = self.shards[idx].write(); - entries.into_iter().for_each(|(key, counter)| { - if shard.rate_limits.merge_counter(key, counter) { - changed = true; + rl_by_shard + .into_iter() + .enumerate() + .for_each(|(idx, entries)| { + if entries.is_empty() { + return; } + let mut shard = self.shards[idx].write(); + entries.into_iter().for_each(|(key, counter)| { + if shard.rate_limits.merge_counter(key, counter) { + changed = true; + } + }); }); - }); } changed @@ -274,14 +282,17 @@ impl ShardedCrdtStore { pub fn run_maintenance(&self) { let now = Self::wall_ms_now(); - self.shards.iter().enumerate().for_each(|(idx, shard_lock)| { - let pending: Vec = self.promotions[idx].lock().drain(..).collect(); - let mut shard = shard_lock.write(); - pending.iter().for_each(|key| shard.cache.touch(key)); - shard.cache.gc_tombstones(now); - shard.cache.gc_expired(now); - shard.rate_limits.gc_expired(now); - }); + self.shards + .iter() + .enumerate() + .for_each(|(idx, shard_lock)| { + let pending: Vec = self.promotions[idx].lock().drain(..).collect(); + let mut shard = shard_lock.write(); + pending.iter().for_each(|key| shard.cache.touch(key)); + shard.cache.gc_tombstones(now); + shard.cache.gc_expired(now); + shard.rate_limits.gc_expired(now); + }); } pub fn peek_full_state(&self) -> CrdtDelta { @@ -297,7 +308,9 @@ impl ShardedCrdtStore { let cache_delta = match cache_entries.is_empty() { true => None, - false => Some(LwwDelta { entries: cache_entries }), + false => Some(LwwDelta { + entries: cache_entries, + }), }; CrdtDelta { @@ -327,7 +340,10 @@ impl ShardedCrdtStore { .iter() .map(|s| { let shard = s.read(); - shard.cache.estimated_bytes().saturating_add(shard.rate_limits.estimated_bytes()) + shard + .cache + .estimated_bytes() + .saturating_add(shard.rate_limits.estimated_bytes()) }) .fold(0usize, usize::saturating_add) } @@ -335,7 +351,7 @@ impl ShardedCrdtStore { pub fn evict_lru_round_robin(&self, start_shard: usize) -> Option<(usize, usize)> { (0..self.shards.len()).find_map(|offset| { let idx = (start_shard + offset) & self.shard_mask; - let has_entries = self.shards[idx].read().cache.len() > 0; + let has_entries = !self.shards[idx].read().cache.is_empty(); match has_entries { true => { let mut shard = self.shards[idx].write(); diff --git a/crates/tranquil-ripple/src/engine.rs b/crates/tranquil-ripple/src/engine.rs index 32f3adb..ade193d 100644 --- a/crates/tranquil-ripple/src/engine.rs +++ b/crates/tranquil-ripple/src/engine.rs @@ -17,12 +17,14 @@ impl RippleEngine { pub async fn start( config: RippleConfig, shutdown: CancellationToken, - ) -> Result<(Arc, Arc, SocketAddr), RippleStartError> { + ) -> Result<(Arc, Arc, SocketAddr), RippleStartError> + { let store = Arc::new(ShardedCrdtStore::new(config.machine_id)); - let (transport, incoming_rx) = Transport::bind(config.bind_addr, config.machine_id, shutdown.clone()) - .await - .map_err(|e| RippleStartError::Bind(e.to_string()))?; + let (transport, incoming_rx) = + Transport::bind(config.bind_addr, config.machine_id, shutdown.clone()) + .await + .map_err(|e| RippleStartError::Bind(e.to_string()))?; let transport = Arc::new(transport); @@ -79,8 +81,7 @@ impl RippleEngine { }); let cache: Arc = Arc::new(RippleCache::new(store.clone())); - let rate_limiter: Arc = - Arc::new(RippleRateLimiter::new(store)); + let rate_limiter: Arc = Arc::new(RippleRateLimiter::new(store)); metrics::describe_metrics(); diff --git a/crates/tranquil-ripple/src/eviction.rs b/crates/tranquil-ripple/src/eviction.rs index 9bbbdd8..8d82e60 100644 --- a/crates/tranquil-ripple/src/eviction.rs +++ b/crates/tranquil-ripple/src/eviction.rs @@ -39,24 +39,22 @@ impl MemoryBudget { let mut remaining = total_bytes; let mut next_shard: usize = self.next_shard.load(std::sync::atomic::Ordering::Relaxed); let mut evicted: usize = 0; - (0..batch_size).try_for_each(|_| { - match remaining > max_bytes { - true => { - match store.evict_lru_round_robin(next_shard) { - Some((ns, freed)) => { - next_shard = ns; - remaining = remaining.saturating_sub(freed); - evicted += 1; - Ok(()) - } - None => Err(()), + (0..batch_size) + .try_for_each(|_| match remaining > max_bytes { + true => match store.evict_lru_round_robin(next_shard) { + Some((ns, freed)) => { + next_shard = ns; + remaining = remaining.saturating_sub(freed); + evicted += 1; + Ok(()) } - } + None => Err(()), + }, false => Err(()), - } - }) - .ok(); - self.next_shard.store(next_shard, std::sync::atomic::Ordering::Relaxed); + }) + .ok(); + self.next_shard + .store(next_shard, std::sync::atomic::Ordering::Relaxed); if evicted > 0 { metrics::record_evictions(evicted); let cache_bytes_after = store.cache_estimated_bytes(); @@ -92,11 +90,7 @@ mod tests { let store = ShardedCrdtStore::new(1); let budget = MemoryBudget::new(100); (0..50).for_each(|i| { - store.cache_set( - format!("key-{i}"), - vec![0u8; 64], - 60_000, - ); + store.cache_set(format!("key-{i}"), vec![0u8; 64], 60_000); }); budget.enforce(&store); assert!(store.total_estimated_bytes() <= 100); diff --git a/crates/tranquil-ripple/src/gossip.rs b/crates/tranquil-ripple/src/gossip.rs index 9e3bec4..defb2a0 100644 --- a/crates/tranquil-ripple/src/gossip.rs +++ b/crates/tranquil-ripple/src/gossip.rs @@ -1,11 +1,11 @@ +use crate::crdt::ShardedCrdtStore; use crate::crdt::delta::CrdtDelta; use crate::crdt::lww_map::LwwDelta; -use crate::crdt::ShardedCrdtStore; use crate::metrics; use crate::transport::{ChannelTag, IncomingFrame, Transport}; use foca::{Config, Foca, Notification, Runtime, Timer}; -use rand::rngs::StdRng; use rand::SeedableRng; +use rand::rngs::StdRng; use std::collections::HashSet; use std::fmt; use std::net::SocketAddr; @@ -135,12 +135,12 @@ impl Runtime for &mut BufferedRuntime { } fn send_to(&mut self, to: PeerId, data: &[u8]) { - self.actions - .push(RuntimeAction::SendTo(to, data.to_vec())); + self.actions.push(RuntimeAction::SendTo(to, data.to_vec())); } fn submit_after(&mut self, event: Timer, after: Duration) { - self.actions.push(RuntimeAction::ScheduleTimer(event, after)); + self.actions + .push(RuntimeAction::ScheduleTimer(event, after)); } } @@ -151,11 +151,7 @@ pub struct GossipEngine { } impl GossipEngine { - pub fn new( - transport: Arc, - store: Arc, - local_id: PeerId, - ) -> Self { + pub fn new(transport: Arc, store: Arc, local_id: PeerId) -> Self { Self { transport, store, @@ -208,10 +204,16 @@ impl GossipEngine { } }); - drain_runtime_actions(&mut runtime, &transport, &timer_tx, &mut members, &store, &shutdown); + drain_runtime_actions( + &mut runtime, + &transport, + &timer_tx, + &mut members, + &store, + &shutdown, + ); - let mut gossip_tick = - tokio::time::interval(Duration::from_millis(gossip_interval_ms)); + let mut gossip_tick = tokio::time::interval(Duration::from_millis(gossip_interval_ms)); gossip_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); loop { @@ -401,16 +403,18 @@ fn drain_runtime_actions( metrics::set_gossip_peers(members.peer_count()); let snapshot = store.peek_full_state(); if !snapshot.is_empty() { - chunk_and_serialize(&snapshot).into_iter().for_each(|chunk| { - let t = transport.clone(); - let c = shutdown.clone(); - tokio::spawn(async move { - tokio::select! { - _ = c.cancelled() => {} - _ = t.send(addr, ChannelTag::CrdtSync, &chunk) => {} - } + chunk_and_serialize(&snapshot) + .into_iter() + .for_each(|chunk| { + let t = transport.clone(); + let c = shutdown.clone(); + tokio::spawn(async move { + tokio::select! { + _ = c.cancelled() => {} + _ = t.send(addr, ChannelTag::CrdtSync, &chunk) => {} + } + }); }); - }); } } RuntimeAction::MemberDown(addr) => { @@ -449,7 +453,9 @@ fn split_and_serialize(delta: CrdtDelta) -> Vec> { source_node, cache_delta: match cache_entries.is_empty() { true => None, - false => Some(LwwDelta { entries: cache_entries }), + false => Some(LwwDelta { + entries: cache_entries, + }), }, rate_limit_deltas: rl_deltas, }; diff --git a/crates/tranquil-ripple/src/metrics.rs b/crates/tranquil-ripple/src/metrics.rs index c49e629..34ae8f6 100644 --- a/crates/tranquil-ripple/src/metrics.rs +++ b/crates/tranquil-ripple/src/metrics.rs @@ -13,10 +13,7 @@ pub fn describe_metrics() { "tranquil_ripple_gossip_peers", "Number of active gossip peers" ); - metrics::describe_counter!( - "tranquil_ripple_cache_hits_total", - "Total cache read hits" - ); + metrics::describe_counter!("tranquil_ripple_cache_hits_total", "Total cache read hits"); metrics::describe_counter!( "tranquil_ripple_cache_misses_total", "Total cache read misses" diff --git a/crates/tranquil-ripple/src/transport.rs b/crates/tranquil-ripple/src/transport.rs index 97a6a38..511d6f2 100644 --- a/crates/tranquil-ripple/src/transport.rs +++ b/crates/tranquil-ripple/src/transport.rs @@ -222,7 +222,10 @@ impl Transport { let conn_gen = self.conn_generation.fetch_add(1, Ordering::Relaxed); self.connections.lock().insert( target, - ConnectionWriter { tx: write_tx.clone(), generation: conn_gen }, + ConnectionWriter { + tx: write_tx.clone(), + generation: conn_gen, + }, ); if let Some(frame) = encode_frame(tag, data) { let _ = write_tx.try_send(frame); @@ -412,7 +415,11 @@ fn decode_frame(buf: &mut BytesMut) -> DecodeResult { } let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; if len > MAX_FRAME_SIZE { - tracing::warn!(frame_len = len, max = MAX_FRAME_SIZE, "oversized frame, closing connection"); + tracing::warn!( + frame_len = len, + max = MAX_FRAME_SIZE, + "oversized frame, closing connection" + ); buf.clear(); return DecodeResult::Corrupt; } diff --git a/crates/tranquil-ripple/tests/two_node_convergence.rs b/crates/tranquil-ripple/tests/two_node_convergence.rs index 72e7d25..8410570 100644 --- a/crates/tranquil-ripple/tests/two_node_convergence.rs +++ b/crates/tranquil-ripple/tests/two_node_convergence.rs @@ -168,7 +168,10 @@ async fn two_node_lww_conflict_resolution() { let val_a = cache_a.get(&key).await.expect("A should have the key"); let val_b = cache_b.get(&key).await.expect("B should have the key"); - assert_eq!(val_a, val_b, "both nodes must agree on the same value after LWW resolution"); + assert_eq!( + val_a, val_b, + "both nodes must agree on the same value after LWW resolution" + ); shutdown.cancel(); } @@ -240,8 +243,14 @@ async fn two_node_ttl_expiration() { tokio::time::sleep(Duration::from_secs(3)).await; - assert!(cache_a.get(&key).await.is_none(), "A should have expired the key"); - assert!(cache_b.get(&key).await.is_none(), "B should have expired the key"); + assert!( + cache_a.get(&key).await.is_none(), + "A should have expired the key" + ); + assert!( + cache_b.get(&key).await.is_none(), + "B should have expired the key" + ); shutdown.cancel(); } @@ -715,50 +724,49 @@ async fn two_node_stress_concurrent_load() { let shutdown = CancellationToken::new(); let ((cache_a, rl_a), (cache_b, rl_b)) = spawn_pair(shutdown.clone()).await; - let tasks: Vec> = (0u32..8).map(|task_id| { - let cache = match task_id < 4 { - true => cache_a.clone(), - false => cache_b.clone(), - }; - let rl = match task_id < 4 { - true => rl_a.clone(), - false => rl_b.clone(), - }; - tokio::spawn(async move { - let value = vec![0xABu8; 1024]; - futures::future::join_all((0u32..500).map(|op| { - let cache = cache.clone(); - let rl = rl.clone(); - let value = value.clone(); - async move { - let key_idx = op % 100; - let key = format!("stress-{task_id}-{key_idx}"); - match op % 4 { - 0 | 1 => { - cache - .set_bytes(&key, &value, Duration::from_secs(120)) - .await - .expect("set_bytes failed"); - } - 2 => { - let _ = cache.get(&key).await; - } - _ => { - let _ = rl.check_rate_limit(&key, 1000, 60_000).await; + let tasks: Vec> = (0u32..8) + .map(|task_id| { + let cache = match task_id < 4 { + true => cache_a.clone(), + false => cache_b.clone(), + }; + let rl = match task_id < 4 { + true => rl_a.clone(), + false => rl_b.clone(), + }; + tokio::spawn(async move { + let value = vec![0xABu8; 1024]; + futures::future::join_all((0u32..500).map(|op| { + let cache = cache.clone(); + let rl = rl.clone(); + let value = value.clone(); + async move { + let key_idx = op % 100; + let key = format!("stress-{task_id}-{key_idx}"); + match op % 4 { + 0 | 1 => { + cache + .set_bytes(&key, &value, Duration::from_secs(120)) + .await + .expect("set_bytes failed"); + } + 2 => { + let _ = cache.get(&key).await; + } + _ => { + let _ = rl.check_rate_limit(&key, 1000, 60_000).await; + } } } - } - })) - .await; + })) + .await; + }) }) - }).collect(); + .collect(); - let results = tokio::time::timeout( - Duration::from_secs(30), - futures::future::join_all(tasks), - ) - .await - .expect("stress test timed out after 30s"); + let results = tokio::time::timeout(Duration::from_secs(30), futures::future::join_all(tasks)) + .await + .expect("stress test timed out after 30s"); results.into_iter().enumerate().for_each(|(i, r)| { r.unwrap_or_else(|e| panic!("task {i} panicked: {e}")); diff --git a/crates/tranquil-storage/Cargo.toml b/crates/tranquil-storage/Cargo.toml index 815236a..513415d 100644 --- a/crates/tranquil-storage/Cargo.toml +++ b/crates/tranquil-storage/Cargo.toml @@ -4,12 +4,16 @@ version.workspace = true edition.workspace = true license.workspace = true +[features] +default = [] +s3 = ["dep:aws-config", "dep:aws-sdk-s3"] + [dependencies] tranquil-infra = { workspace = true } async-trait = { workspace = true } -aws-config = { workspace = true } -aws-sdk-s3 = { workspace = true } +aws-config = { workspace = true, optional = true } +aws-sdk-s3 = { workspace = true, optional = true } bytes = { workspace = true } futures = { workspace = true } sha2 = { workspace = true } diff --git a/crates/tranquil-storage/src/lib.rs b/crates/tranquil-storage/src/lib.rs index b683688..2fc6dfd 100644 --- a/crates/tranquil-storage/src/lib.rs +++ b/crates/tranquil-storage/src/lib.rs @@ -4,12 +4,6 @@ pub use tranquil_infra::{ }; use async_trait::async_trait; -use aws_config::BehaviorVersion; -use aws_config::meta::region::RegionProviderChain; -use aws_sdk_s3::Client; -use aws_sdk_s3::primitives::ByteStream; -use aws_sdk_s3::types::CompletedMultipartUpload; -use aws_sdk_s3::types::CompletedPart; use bytes::Bytes; use futures::Stream; use sha2::{Digest, Sha256}; @@ -17,7 +11,6 @@ use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::Arc; -const MIN_PART_SIZE: usize = 5 * 1024 * 1024; const EXDEV: i32 = 18; const CID_SHARD_PREFIX_LEN: usize = 9; @@ -109,360 +102,385 @@ fn map_io_not_found(key: &str) -> impl FnOnce(std::io::Error) -> StorageError + } } -pub struct S3BlobStorage { - client: Client, - bucket: String, -} +#[cfg(feature = "s3")] +mod s3 { + use super::*; + use aws_config::BehaviorVersion; + use aws_config::meta::region::RegionProviderChain; + use aws_sdk_s3::Client; + use aws_sdk_s3::primitives::ByteStream; + use aws_sdk_s3::types::CompletedMultipartUpload; + use aws_sdk_s3::types::CompletedPart; -impl S3BlobStorage { - pub async fn new() -> Self { - let bucket = std::env::var("S3_BUCKET").expect("S3_BUCKET must be set"); - let client = create_s3_client().await; - Self { client, bucket } + const MIN_PART_SIZE: usize = 5 * 1024 * 1024; + + pub struct S3BlobStorage { + client: Client, + bucket: String, } - pub async fn with_bucket(bucket: String) -> Self { - let client = create_s3_client().await; - Self { client, bucket } - } -} - -async fn create_s3_client() -> Client { - let region_provider = RegionProviderChain::default_provider().or_else("us-east-1"); - - let config = aws_config::defaults(BehaviorVersion::latest()) - .region(region_provider) - .load() - .await; - - std::env::var("S3_ENDPOINT").ok().map_or_else( - || Client::new(&config), - |endpoint| { - let s3_config = aws_sdk_s3::config::Builder::from(&config) - .endpoint_url(endpoint) - .force_path_style(true) - .build(); - Client::from_conf(s3_config) - }, - ) -} - -pub struct S3BackupStorage { - client: Client, - bucket: String, -} - -impl S3BackupStorage { - pub async fn new() -> Option { - let bucket = std::env::var("BACKUP_S3_BUCKET").ok()?; - let client = create_s3_client().await; - Some(Self { client, bucket }) - } -} - -#[async_trait] -impl BackupStorage for S3BackupStorage { - async fn put_backup(&self, did: &str, rev: &str, data: &[u8]) -> Result { - let key = format!("{}/{}.car", did, rev); - self.client - .put_object() - .bucket(&self.bucket) - .key(&key) - .body(ByteStream::from(Bytes::copy_from_slice(data))) - .send() - .await - .map_err(|e| StorageError::Backend(e.to_string()))?; - - Ok(key) - } - - async fn get_backup(&self, storage_key: &str) -> Result { - let resp = self - .client - .get_object() - .bucket(&self.bucket) - .key(storage_key) - .send() - .await - .map_err(|e| StorageError::Backend(e.to_string()))?; - - resp.body - .collect() - .await - .map(|agg| agg.into_bytes()) - .map_err(|e| StorageError::Backend(e.to_string())) - } - - async fn delete_backup(&self, storage_key: &str) -> Result<(), StorageError> { - self.client - .delete_object() - .bucket(&self.bucket) - .key(storage_key) - .send() - .await - .map_err(|e| StorageError::Backend(e.to_string()))?; - - Ok(()) - } -} - -#[async_trait] -impl BlobStorage for S3BlobStorage { - async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> { - self.put_bytes(key, Bytes::copy_from_slice(data)).await - } - - async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError> { - self.client - .put_object() - .bucket(&self.bucket) - .key(key) - .body(ByteStream::from(data)) - .send() - .await - .map_err(|e| StorageError::Backend(e.to_string()))?; - - Ok(()) - } - - async fn get(&self, key: &str) -> Result, StorageError> { - self.get_bytes(key).await.map(|b| b.to_vec()) - } - - async fn get_bytes(&self, key: &str) -> Result { - let resp = self - .client - .get_object() - .bucket(&self.bucket) - .key(key) - .send() - .await - .map_err(|e| StorageError::Backend(e.to_string()))?; - - resp.body - .collect() - .await - .map(|agg| agg.into_bytes()) - .map_err(|e| StorageError::Backend(e.to_string())) - } - - async fn get_head(&self, key: &str, size: usize) -> Result { - let range = format!("bytes=0-{}", size.saturating_sub(1)); - let resp = self - .client - .get_object() - .bucket(&self.bucket) - .key(key) - .range(range) - .send() - .await - .map_err(|e| StorageError::Backend(e.to_string()))?; - - resp.body - .collect() - .await - .map(|agg| agg.into_bytes()) - .map_err(|e| StorageError::Backend(e.to_string())) - } - - async fn delete(&self, key: &str) -> Result<(), StorageError> { - self.client - .delete_object() - .bucket(&self.bucket) - .key(key) - .send() - .await - .map_err(|e| StorageError::Backend(e.to_string()))?; - - Ok(()) - } - - async fn put_stream( - &self, - key: &str, - stream: Pin> + Send>>, - ) -> Result { - use futures::StreamExt; - - let create_resp = self - .client - .create_multipart_upload() - .bucket(&self.bucket) - .key(key) - .send() - .await - .map_err(|e| { - StorageError::Backend(format!("Failed to create multipart upload: {}", e)) - })?; - - let upload_id = create_resp - .upload_id() - .ok_or_else(|| StorageError::Backend("No upload ID returned".to_string()))? - .to_string(); - - let upload_part = |client: &Client, - bucket: &str, - key: &str, - upload_id: &str, - part_num: i32, - data: Vec| - -> std::pin::Pin< - Box> + Send>, - > { - let client = client.clone(); - let bucket = bucket.to_string(); - let key = key.to_string(); - let upload_id = upload_id.to_string(); - Box::pin(async move { - let resp = client - .upload_part() - .bucket(&bucket) - .key(&key) - .upload_id(&upload_id) - .part_number(part_num) - .body(ByteStream::from(data)) - .send() - .await - .map_err(|e| StorageError::Backend(format!("Failed to upload part: {}", e)))?; - - let etag = resp - .e_tag() - .ok_or_else(|| StorageError::Backend("No ETag returned for part".to_string()))? - .to_string(); - - Ok(CompletedPart::builder() - .part_number(part_num) - .e_tag(etag) - .build()) - }) - }; - - struct UploadState { - hasher: Sha256, - total_size: u64, - part_number: i32, - completed_parts: Vec, - buffer: Vec, + impl S3BlobStorage { + pub async fn new() -> Self { + let bucket = std::env::var("S3_BUCKET").expect("S3_BUCKET must be set"); + let client = create_s3_client().await; + Self { client, bucket } } - let initial_state = UploadState { - hasher: Sha256::new(), - total_size: 0, - part_number: 1, - completed_parts: Vec::new(), - buffer: Vec::with_capacity(MIN_PART_SIZE), - }; + pub async fn with_bucket(bucket: String) -> Self { + let client = create_s3_client().await; + Self { client, bucket } + } + } - let abort_upload = || async { - let _ = self + async fn create_s3_client() -> Client { + let region_provider = RegionProviderChain::default_provider().or_else("us-east-1"); + + let config = aws_config::defaults(BehaviorVersion::latest()) + .region(region_provider) + .load() + .await; + + std::env::var("S3_ENDPOINT").ok().map_or_else( + || Client::new(&config), + |endpoint| { + let s3_config = aws_sdk_s3::config::Builder::from(&config) + .endpoint_url(endpoint) + .force_path_style(true) + .build(); + Client::from_conf(s3_config) + }, + ) + } + + pub struct S3BackupStorage { + client: Client, + bucket: String, + } + + impl S3BackupStorage { + pub async fn new() -> Option { + let bucket = std::env::var("BACKUP_S3_BUCKET").ok()?; + let client = create_s3_client().await; + Some(Self { client, bucket }) + } + } + + #[async_trait] + impl BackupStorage for S3BackupStorage { + async fn put_backup( + &self, + did: &str, + rev: &str, + data: &[u8], + ) -> Result { + let key = format!("{}/{}.car", did, rev); + self.client + .put_object() + .bucket(&self.bucket) + .key(&key) + .body(ByteStream::from(Bytes::copy_from_slice(data))) + .send() + .await + .map_err(|e| StorageError::Backend(e.to_string()))?; + + Ok(key) + } + + async fn get_backup(&self, storage_key: &str) -> Result { + let resp = self .client - .abort_multipart_upload() + .get_object() + .bucket(&self.bucket) + .key(storage_key) + .send() + .await + .map_err(|e| StorageError::Backend(e.to_string()))?; + + resp.body + .collect() + .await + .map(|agg| agg.into_bytes()) + .map_err(|e| StorageError::Backend(e.to_string())) + } + + async fn delete_backup(&self, storage_key: &str) -> Result<(), StorageError> { + self.client + .delete_object() + .bucket(&self.bucket) + .key(storage_key) + .send() + .await + .map_err(|e| StorageError::Backend(e.to_string()))?; + + Ok(()) + } + } + + #[async_trait] + impl BlobStorage for S3BlobStorage { + async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> { + self.put_bytes(key, Bytes::copy_from_slice(data)).await + } + + async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError> { + self.client + .put_object() + .bucket(&self.bucket) + .key(key) + .body(ByteStream::from(data)) + .send() + .await + .map_err(|e| StorageError::Backend(e.to_string()))?; + + Ok(()) + } + + async fn get(&self, key: &str) -> Result, StorageError> { + self.get_bytes(key).await.map(|b| b.to_vec()) + } + + async fn get_bytes(&self, key: &str) -> Result { + let resp = self + .client + .get_object() + .bucket(&self.bucket) + .key(key) + .send() + .await + .map_err(|e| StorageError::Backend(e.to_string()))?; + + resp.body + .collect() + .await + .map(|agg| agg.into_bytes()) + .map_err(|e| StorageError::Backend(e.to_string())) + } + + async fn get_head(&self, key: &str, size: usize) -> Result { + let range = format!("bytes=0-{}", size.saturating_sub(1)); + let resp = self + .client + .get_object() + .bucket(&self.bucket) + .key(key) + .range(range) + .send() + .await + .map_err(|e| StorageError::Backend(e.to_string()))?; + + resp.body + .collect() + .await + .map(|agg| agg.into_bytes()) + .map_err(|e| StorageError::Backend(e.to_string())) + } + + async fn delete(&self, key: &str) -> Result<(), StorageError> { + self.client + .delete_object() + .bucket(&self.bucket) + .key(key) + .send() + .await + .map_err(|e| StorageError::Backend(e.to_string()))?; + + Ok(()) + } + + async fn put_stream( + &self, + key: &str, + stream: Pin> + Send>>, + ) -> Result { + use futures::StreamExt; + + let create_resp = self + .client + .create_multipart_upload() + .bucket(&self.bucket) + .key(key) + .send() + .await + .map_err(|e| { + StorageError::Backend(format!("Failed to create multipart upload: {}", e)) + })?; + + let upload_id = create_resp + .upload_id() + .ok_or_else(|| StorageError::Backend("No upload ID returned".to_string()))? + .to_string(); + + let upload_part = |client: &Client, + bucket: &str, + key: &str, + upload_id: &str, + part_num: i32, + data: Vec| + -> std::pin::Pin< + Box> + Send>, + > { + let client = client.clone(); + let bucket = bucket.to_string(); + let key = key.to_string(); + let upload_id = upload_id.to_string(); + Box::pin(async move { + let resp = client + .upload_part() + .bucket(&bucket) + .key(&key) + .upload_id(&upload_id) + .part_number(part_num) + .body(ByteStream::from(data)) + .send() + .await + .map_err(|e| { + StorageError::Backend(format!("Failed to upload part: {}", e)) + })?; + + let etag = resp + .e_tag() + .ok_or_else(|| { + StorageError::Backend("No ETag returned for part".to_string()) + })? + .to_string(); + + Ok(CompletedPart::builder() + .part_number(part_num) + .e_tag(etag) + .build()) + }) + }; + + struct UploadState { + hasher: Sha256, + total_size: u64, + part_number: i32, + completed_parts: Vec, + buffer: Vec, + } + + let initial_state = UploadState { + hasher: Sha256::new(), + total_size: 0, + part_number: 1, + completed_parts: Vec::new(), + buffer: Vec::with_capacity(MIN_PART_SIZE), + }; + + let abort_upload = || async { + let _ = self + .client + .abort_multipart_upload() + .bucket(&self.bucket) + .key(key) + .upload_id(&upload_id) + .send() + .await; + }; + + let result: Result = { + let mut state = initial_state; + + let chunk_results: Vec> = stream.collect().await; + + for chunk_result in chunk_results { + match chunk_result { + Ok(chunk) => { + state.hasher.update(&chunk); + state.total_size += chunk.len() as u64; + state.buffer.extend_from_slice(&chunk); + + if state.buffer.len() >= MIN_PART_SIZE { + let part_data = std::mem::replace( + &mut state.buffer, + Vec::with_capacity(MIN_PART_SIZE), + ); + let part = upload_part( + &self.client, + &self.bucket, + key, + &upload_id, + state.part_number, + part_data, + ) + .await?; + state.completed_parts.push(part); + state.part_number += 1; + } + } + Err(e) => { + abort_upload().await; + return Err(StorageError::Io(e)); + } + } + } + + Ok(state) + }; + + let mut state = result?; + + if !state.buffer.is_empty() { + let part = upload_part( + &self.client, + &self.bucket, + key, + &upload_id, + state.part_number, + std::mem::take(&mut state.buffer), + ) + .await?; + state.completed_parts.push(part); + } + + if state.completed_parts.is_empty() { + abort_upload().await; + return Err(StorageError::Other("Empty upload".to_string())); + } + + let completed_upload = CompletedMultipartUpload::builder() + .set_parts(Some(state.completed_parts)) + .build(); + + self.client + .complete_multipart_upload() .bucket(&self.bucket) .key(key) .upload_id(&upload_id) + .multipart_upload(completed_upload) .send() - .await; - }; + .await + .map_err(|e| { + StorageError::Backend(format!("Failed to complete multipart upload: {}", e)) + })?; - let result: Result = { - let mut state = initial_state; - - let chunk_results: Vec> = stream.collect().await; - - for chunk_result in chunk_results { - match chunk_result { - Ok(chunk) => { - state.hasher.update(&chunk); - state.total_size += chunk.len() as u64; - state.buffer.extend_from_slice(&chunk); - - if state.buffer.len() >= MIN_PART_SIZE { - let part_data = std::mem::replace( - &mut state.buffer, - Vec::with_capacity(MIN_PART_SIZE), - ); - let part = upload_part( - &self.client, - &self.bucket, - key, - &upload_id, - state.part_number, - part_data, - ) - .await?; - state.completed_parts.push(part); - state.part_number += 1; - } - } - Err(e) => { - abort_upload().await; - return Err(StorageError::Io(e)); - } - } - } - - Ok(state) - }; - - let mut state = result?; - - if !state.buffer.is_empty() { - let part = upload_part( - &self.client, - &self.bucket, - key, - &upload_id, - state.part_number, - std::mem::take(&mut state.buffer), - ) - .await?; - state.completed_parts.push(part); + let hash: [u8; 32] = state.hasher.finalize().into(); + Ok(StreamUploadResult { + sha256_hash: hash, + size: state.total_size, + }) } - if state.completed_parts.is_empty() { - abort_upload().await; - return Err(StorageError::Other("Empty upload".to_string())); + async fn copy(&self, src_key: &str, dst_key: &str) -> Result<(), StorageError> { + let copy_source = format!("{}/{}", self.bucket, src_key); + + self.client + .copy_object() + .bucket(&self.bucket) + .copy_source(©_source) + .key(dst_key) + .send() + .await + .map_err(|e| StorageError::Backend(format!("Failed to copy object: {}", e)))?; + + Ok(()) } - - let completed_upload = CompletedMultipartUpload::builder() - .set_parts(Some(state.completed_parts)) - .build(); - - self.client - .complete_multipart_upload() - .bucket(&self.bucket) - .key(key) - .upload_id(&upload_id) - .multipart_upload(completed_upload) - .send() - .await - .map_err(|e| { - StorageError::Backend(format!("Failed to complete multipart upload: {}", e)) - })?; - - let hash: [u8; 32] = state.hasher.finalize().into(); - Ok(StreamUploadResult { - sha256_hash: hash, - size: state.total_size, - }) - } - - async fn copy(&self, src_key: &str, dst_key: &str) -> Result<(), StorageError> { - let copy_source = format!("{}/{}", self.bucket, src_key); - - self.client - .copy_object() - .bucket(&self.bucket) - .copy_source(©_source) - .key(dst_key) - .send() - .await - .map_err(|e| StorageError::Backend(format!("Failed to copy object: {}", e)))?; - - Ok(()) } } +#[cfg(feature = "s3")] +pub use s3::{S3BackupStorage, S3BlobStorage}; + pub struct FilesystemBlobStorage { base_path: PathBuf, tmp_path: PathBuf, @@ -686,10 +704,18 @@ pub async fn create_blob_storage() -> Arc { let backend = std::env::var("BLOB_STORAGE_BACKEND").unwrap_or_else(|_| "filesystem".into()); match backend.as_str() { + #[cfg(feature = "s3")] "s3" => { tracing::info!("Initializing S3 blob storage"); Arc::new(S3BlobStorage::new().await) } + #[cfg(not(feature = "s3"))] + "s3" => { + panic!( + "BLOB_STORAGE_BACKEND=s3 but binary was compiled without s3 feature. \ + Rebuild with --features s3 to enable S3 storage." + ); + } _ => { tracing::info!("Initializing filesystem blob storage"); FilesystemBlobStorage::from_env() @@ -719,6 +745,7 @@ pub async fn create_backup_storage() -> Option> { let backend = std::env::var("BACKUP_STORAGE_BACKEND").unwrap_or_else(|_| "filesystem".into()); match backend.as_str() { + #[cfg(feature = "s3")] "s3" => S3BackupStorage::new().await.map_or_else( || { tracing::error!( @@ -732,6 +759,14 @@ pub async fn create_backup_storage() -> Option> { Some(Arc::new(storage) as Arc) }, ), + #[cfg(not(feature = "s3"))] + "s3" => { + tracing::error!( + "BACKUP_STORAGE_BACKEND=s3 but binary was compiled without s3 feature. \ + Backups will be disabled." + ); + None + } _ => FilesystemBackupStorage::from_env().await.map_or_else( |e| { tracing::error!(