fix: smaller docker img

This commit is contained in:
lewis
2026-02-08 15:48:24 +02:00
committed by Tangled
parent ec36b8ddc7
commit ea27772a47
27 changed files with 855 additions and 767 deletions

83
Cargo.lock generated
View File

@@ -2744,22 +2744,6 @@ dependencies = [
"tower-service",
]
[[package]]
name = "hyper-tls"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper 1.8.1",
"hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
"tower-service",
]
[[package]]
name = "hyper-util"
version = "0.1.19"
@@ -3585,23 +3569,6 @@ dependencies = [
"web-time",
]
[[package]]
name = "native-tls"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework 2.11.1",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "nom"
version = "7.1.3"
@@ -4498,12 +4465,10 @@ dependencies = [
"http-body-util",
"hyper 1.8.1",
"hyper-rustls 0.27.7",
"hyper-tls",
"hyper-util",
"js-sys",
"log",
"mime",
"native-tls",
"percent-encoding",
"pin-project-lite",
"quinn",
@@ -4514,7 +4479,6 @@ dependencies = [
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-native-tls",
"tokio-rustls 0.26.4",
"tower",
"tower-http",
@@ -4661,7 +4625,7 @@ dependencies = [
"openssl-probe",
"rustls-pki-types",
"schannel",
"security-framework 3.5.1",
"security-framework",
]
[[package]]
@@ -4800,19 +4764,6 @@ dependencies = [
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags",
"core-foundation 0.9.4",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework"
version = "3.5.1"
@@ -5525,19 +5476,6 @@ dependencies = [
"libc",
]
[[package]]
name = "tempfile"
version = "3.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16"
dependencies = [
"fastrand",
"getrandom 0.3.4",
"once_cell",
"rustix",
"windows-sys 0.61.2",
]
[[package]]
name = "testcontainers"
version = "0.26.2"
@@ -5709,16 +5647,6 @@ dependencies = [
"syn 2.0.111",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.24.1"
@@ -5758,10 +5686,12 @@ checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857"
dependencies = [
"futures-util",
"log",
"native-tls",
"rustls 0.23.35",
"rustls-pki-types",
"tokio",
"tokio-native-tls",
"tokio-rustls 0.26.4",
"tungstenite",
"webpki-roots 0.26.11",
]
[[package]]
@@ -6274,8 +6204,9 @@ dependencies = [
"http 1.4.0",
"httparse",
"log",
"native-tls",
"rand 0.9.2",
"rustls 0.23.35",
"rustls-pki-types",
"sha1",
"thiserror 2.0.17",
"utf-8",

View File

@@ -81,7 +81,7 @@ p384 = { version = "0.13", features = ["ecdsa"] }
rand = "0.8"
redis = { version = "1.0", features = ["tokio-comp", "connection-manager"] }
regex = "1"
reqwest = { version = "0.12", features = ["json"] }
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls-webpki-roots", "http2", "charset", "macos-system-configuration"] }
serde = { version = "1.0", features = ["derive"] }
serde_bytes = "0.11"
serde_ipld_dagcbor = "0.6"
@@ -91,9 +91,9 @@ sha2 = "0.10"
sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "uuid", "chrono", "json"] }
subtle = "2.5"
thiserror = "2.0"
tokio = { version = "1.48", features = ["macros", "rt-multi-thread", "time", "signal", "process"] }
tokio = { version = "1.48", features = ["macros", "rt-multi-thread", "time", "signal", "process", "io-util", "fs"] }
tokio-util = "0.7.18"
tokio-tungstenite = { version = "0.28", features = ["native-tls"] }
tokio-tungstenite = { version = "0.28", features = ["rustls-tls-webpki-roots"] }
totp-rs = { version = "5", features = ["qr"] }
tower = "0.5"
tower-http = { version = "0.6", features = ["cors"] }
@@ -111,3 +111,8 @@ ctor = "0.6"
testcontainers = "0.26"
testcontainers-modules = { version = "0.14", features = ["postgres"] }
wiremock = "0.6"
# Release profile settings; this commit's stated goal is a smaller docker image.
[profile.release]
# Thin link-time optimisation across crates.
lto = "thin"
# Strip symbols from the final binary.
strip = true
# Compile each crate as a single codegen unit.
codegen-units = 1

View File

@@ -1,13 +1,18 @@
FROM rust:1.92-alpine AS builder
RUN apk add --no-cache ca-certificates openssl openssl-dev openssl-libs-static pkgconfig musl-dev
RUN apk add --no-cache ca-certificates musl-dev pkgconfig openssl-dev openssl-libs-static
WORKDIR /app
ARG SLIM="false"
COPY Cargo.toml Cargo.lock ./
COPY crates ./crates
COPY .sqlx ./.sqlx
COPY migrations ./crates/tranquil-pds/migrations
RUN --mount=type=cache,target=/usr/local/cargo/registry \
--mount=type=cache,target=/app/target \
SQLX_OFFLINE=true cargo build --release -p tranquil-pds && \
if [ "$SLIM" = "true" ]; then \
SQLX_OFFLINE=true cargo build --release -p tranquil-pds --no-default-features; \
else \
SQLX_OFFLINE=true cargo build --release -p tranquil-pds; \
fi && \
cp target/release/tranquil-pds /tmp/tranquil-pds
FROM alpine:3.23

View File

@@ -4,12 +4,16 @@ version.workspace = true
edition.workspace = true
license.workspace = true
[features]
default = []
valkey = ["dep:redis"]
[dependencies]
tranquil-infra = { workspace = true }
tranquil-ripple = { workspace = true }
async-trait = { workspace = true }
base64 = { workspace = true }
redis = { workspace = true }
redis = { workspace = true, optional = true }
tokio-util = { workspace = true }
tracing = { workspace = true }

View File

@@ -1,72 +1,135 @@
pub use tranquil_infra::{Cache, CacheError, DistributedRateLimiter};
use async_trait::async_trait;
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use std::sync::Arc;
use std::time::Duration;
#[derive(Clone)]
pub struct ValkeyCache {
conn: redis::aio::ConnectionManager,
}
#[cfg(feature = "valkey")]
mod valkey {
use super::*;
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
impl ValkeyCache {
pub async fn new(url: &str) -> Result<Self, CacheError> {
let client = redis::Client::open(url).map_err(|e| CacheError::Connection(e.to_string()))?;
let manager = client
.get_connection_manager()
.await
.map_err(|e| CacheError::Connection(e.to_string()))?;
Ok(Self { conn: manager })
#[derive(Clone)]
pub struct ValkeyCache {
conn: redis::aio::ConnectionManager,
}
pub fn connection(&self) -> redis::aio::ConnectionManager {
self.conn.clone()
impl ValkeyCache {
    /// Opens a client for `url` and wraps its connection manager.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Connection`] carrying the underlying error text
    /// when the URL cannot be opened or the connection manager cannot be
    /// obtained.
    pub async fn new(url: &str) -> Result<Self, CacheError> {
        // Both fallible steps report failures the same way, so the mapping
        // lives in one place.
        fn connection_error(cause: impl std::fmt::Display) -> CacheError {
            CacheError::Connection(cause.to_string())
        }

        let client = redis::Client::open(url).map_err(connection_error)?;
        client
            .get_connection_manager()
            .await
            .map(|conn| Self { conn })
            .map_err(connection_error)
    }

    /// Returns a clone of the underlying connection manager handle.
    pub fn connection(&self) -> redis::aio::ConnectionManager {
        self.conn.clone()
    }
}
#[async_trait]
impl Cache for ValkeyCache {
    /// Reads the string stored at `key` with a `GET` command.
    ///
    /// Returns `None` when the key is absent and also when the command
    /// fails; the error itself is discarded.
    async fn get(&self, key: &str) -> Option<String> {
        let mut conn = self.conn.clone();
        redis::cmd("GET")
            .arg(key)
            .query_async::<Option<String>>(&mut conn)
            .await
            .ok()
            .flatten()
    }

    /// Stores `value` at `key` with a millisecond expiry (`SET key value PX ms`).
    ///
    /// The expiry is clamped to `1..=i64::MAX` milliseconds. The lower bound
    /// matters because `Duration::as_millis` truncates, so a zero or
    /// sub-millisecond `ttl` would otherwise be sent as `PX 0`, which the
    /// server rejects as an invalid expire time.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Connection`] with the underlying error text when
    /// the command fails.
    async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
        // The clamp guarantees the value fits, so the cast cannot truncate.
        let ttl_ms = ttl.as_millis().clamp(1, i64::MAX as u128) as i64;
        let mut conn = self.conn.clone();
        redis::cmd("SET")
            .arg(key)
            .arg(value)
            .arg("PX")
            .arg(ttl_ms)
            .query_async::<()>(&mut conn)
            .await
            .map_err(|e| CacheError::Connection(e.to_string()))
    }

    /// Removes `key` with a `DEL` command.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Connection`] with the underlying error text when
    /// the command fails.
    async fn delete(&self, key: &str) -> Result<(), CacheError> {
        let mut conn = self.conn.clone();
        redis::cmd("DEL")
            .arg(key)
            .query_async::<()>(&mut conn)
            .await
            .map_err(|e| CacheError::Connection(e.to_string()))
    }

    /// Reads a binary value written by [`Cache::set_bytes`].
    ///
    /// The stored string is base64-decoded; `None` is returned when the key
    /// cannot be read or the stored text is not valid base64.
    async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> {
        self.get(key).await.and_then(|s| BASE64.decode(&s).ok())
    }

    /// Stores a binary value by base64-encoding it and delegating to
    /// [`Cache::set`], so the same expiry handling and errors apply.
    async fn set_bytes(
        &self,
        key: &str,
        value: &[u8],
        ttl: Duration,
    ) -> Result<(), CacheError> {
        let encoded = BASE64.encode(value);
        self.set(key, &encoded, ttl).await
    }
}
/// Rate limiter that issues its commands through a redis connection manager.
#[derive(Clone)]
pub struct RedisRateLimiter {
// Cloned for every operation rather than borrowed mutably.
conn: redis::aio::ConnectionManager,
}
impl RedisRateLimiter {
/// Wraps an existing connection manager; no connection work happens here.
pub fn new(conn: redis::aio::ConnectionManager) -> Self {
Self { conn }
}
}
#[async_trait]
impl DistributedRateLimiter for RedisRateLimiter {
    /// Counts one hit against `key` and reports whether it is still within `limit`.
    ///
    /// The counter is stored under `rl:{key}`. A server-side script increments
    /// it and applies an expiry when the counter was just created or has no
    /// TTL. `window_ms` is rounded up to whole seconds, with a minimum of one.
    ///
    /// Returns `true` when the post-increment count does not exceed `limit`.
    /// If the script fails, a warning is logged and the request is allowed.
    async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool {
        let redis_key = format!("rl:{}", key);
        let expiry_secs = window_ms.div_ceil(1000).max(1) as i64;
        let script = redis::Script::new(
            r"local c = redis.call('INCR', KEYS[1])
if c == 1 then redis.call('EXPIRE', KEYS[1], ARGV[1]) end
if redis.call('TTL', KEYS[1]) == -1 then redis.call('EXPIRE', KEYS[1], ARGV[1]) end
return c",
        );

        let mut connection = self.conn.clone();
        let outcome: Result<i64, _> = script
            .key(&redis_key)
            .arg(expiry_secs)
            .invoke_async(&mut connection)
            .await;

        outcome
            .map(|count| count <= i64::from(limit))
            .unwrap_or_else(|e| {
                tracing::warn!(error = %e, "redis rate limit script failed, allowing request");
                true
            })
    }

    /// Reads the current count for `key` without incrementing it.
    ///
    /// Returns `0` when the counter does not exist or the read fails. The
    /// window argument is not used by this implementation.
    async fn peek_rate_limit_count(&self, key: &str, _window_ms: u64) -> u64 {
        let redis_key = format!("rl:{}", key);
        let mut connection = self.conn.clone();
        let stored = redis::cmd("GET")
            .arg(&redis_key)
            .query_async::<Option<u64>>(&mut connection)
            .await;

        match stored {
            Ok(Some(count)) => count,
            Ok(None) | Err(_) => 0,
        }
    }
}
}
#[async_trait]
impl Cache for ValkeyCache {
async fn get(&self, key: &str) -> Option<String> {
let mut conn = self.conn.clone();
redis::cmd("GET")
.arg(key)
.query_async::<Option<String>>(&mut conn)
.await
.ok()
.flatten()
}
async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
let mut conn = self.conn.clone();
redis::cmd("SET")
.arg(key)
.arg(value)
.arg("PX")
.arg(ttl.as_millis().min(i64::MAX as u128) as i64)
.query_async::<()>(&mut conn)
.await
.map_err(|e| CacheError::Connection(e.to_string()))
}
async fn delete(&self, key: &str) -> Result<(), CacheError> {
let mut conn = self.conn.clone();
redis::cmd("DEL")
.arg(key)
.query_async::<()>(&mut conn)
.await
.map_err(|e| CacheError::Connection(e.to_string()))
}
async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> {
self.get(key).await.and_then(|s| BASE64.decode(&s).ok())
}
async fn set_bytes(&self, key: &str, value: &[u8], ttl: Duration) -> Result<(), CacheError> {
let encoded = BASE64.encode(value);
self.set(key, &encoded, ttl).await
}
}
#[cfg(feature = "valkey")]
pub use valkey::{RedisRateLimiter, ValkeyCache};
pub struct NoOpCache;
@@ -97,55 +160,6 @@ impl Cache for NoOpCache {
}
}
#[derive(Clone)]
pub struct RedisRateLimiter {
conn: redis::aio::ConnectionManager,
}
impl RedisRateLimiter {
pub fn new(conn: redis::aio::ConnectionManager) -> Self {
Self { conn }
}
}
#[async_trait]
impl DistributedRateLimiter for RedisRateLimiter {
async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool {
let mut conn = self.conn.clone();
let full_key = format!("rl:{}", key);
let window_secs = window_ms.div_ceil(1000).max(1) as i64;
let result: Result<i64, _> = redis::Script::new(
r"local c = redis.call('INCR', KEYS[1])
if c == 1 then redis.call('EXPIRE', KEYS[1], ARGV[1]) end
if redis.call('TTL', KEYS[1]) == -1 then redis.call('EXPIRE', KEYS[1], ARGV[1]) end
return c"
)
.key(&full_key)
.arg(window_secs)
.invoke_async(&mut conn)
.await;
match result {
Ok(count) => count <= limit as i64,
Err(e) => {
tracing::warn!(error = %e, "redis rate limit script failed, allowing request");
true
}
}
}
async fn peek_rate_limit_count(&self, key: &str, _window_ms: u64) -> u64 {
let mut conn = self.conn.clone();
let full_key = format!("rl:{}", key);
redis::cmd("GET")
.arg(&full_key)
.query_async::<Option<u64>>(&mut conn)
.await
.ok()
.flatten()
.unwrap_or(0)
}
}
pub struct NoOpRateLimiter;
#[async_trait]
@@ -158,6 +172,7 @@ impl DistributedRateLimiter for NoOpRateLimiter {
pub async fn create_cache(
shutdown: tokio_util::sync::CancellationToken,
) -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) {
#[cfg(feature = "valkey")]
if let Ok(url) = std::env::var("VALKEY_URL") {
match ValkeyCache::new(&url).await {
Ok(cache) => {
@@ -171,6 +186,13 @@ pub async fn create_cache(
}
}
#[cfg(not(feature = "valkey"))]
if std::env::var("VALKEY_URL").is_ok() {
tracing::warn!(
"VALKEY_URL is set but binary was compiled without valkey feature. using ripple."
);
}
match tranquil_ripple::RippleConfig::from_env() {
Ok(config) => {
let peer_count = config.seed_peers.len();

View File

@@ -21,8 +21,6 @@ aes-gcm = { workspace = true }
async-trait = { workspace = true }
backon = { workspace = true }
anyhow = { workspace = true }
aws-config = { workspace = true }
aws-sdk-s3 = { workspace = true }
axum = { workspace = true }
base32 = { workspace = true }
base64 = { workspace = true }
@@ -55,7 +53,7 @@ multibase = { workspace = true }
multihash = { workspace = true }
p256 = { workspace = true }
rand = { workspace = true }
redis = { workspace = true }
redis = { workspace = true, optional = true }
regex = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }
@@ -79,10 +77,15 @@ urlencoding = { workspace = true }
uuid = { workspace = true }
webauthn-rs = { workspace = true }
zip = { workspace = true }
aws-config = { workspace = true, optional = true }
aws-sdk-s3 = { workspace = true, optional = true }
[features]
default = ["s3", "valkey"]
external-infra = []
s3-storage = []
s3-storage = ["tranquil-storage/s3", "dep:aws-config", "dep:aws-sdk-s3"]
s3 = ["s3-storage"]
valkey = ["tranquil-cache/valkey", "dep:redis"]
[dev-dependencies]
ciborium = { workspace = true }

View File

@@ -162,7 +162,10 @@ pub async fn upload_blob(
if let Err(e) = state.blob_store.copy(&temp_key, &storage_key).await {
let _ = state.blob_store.delete(&temp_key).await;
if let Err(db_err) = state.blob_repo.delete_blob_by_cid(&cid_link).await {
error!("Failed to clean up orphaned blob record after copy failure: {:?}", db_err);
error!(
"Failed to clean up orphaned blob record after copy failure: {:?}",
db_err
);
}
error!("Failed to copy blob to final location: {:?}", e);
return Err(ApiError::InternalError(Some("Failed to store blob".into())));
@@ -170,8 +173,8 @@ pub async fn upload_blob(
let _ = state.blob_store.delete(&temp_key).await;
if let Some(ref controller) = controller_did {
if let Err(e) = state
if let Some(ref controller) = controller_did
&& let Err(e) = state
.delegation_repo
.log_delegation_action(
&did,
@@ -187,9 +190,8 @@ pub async fn upload_blob(
None,
)
.await
{
warn!("Failed to log delegation action for blob upload: {:?}", e);
}
{
warn!("Failed to log delegation action for blob upload: {:?}", e);
}
Ok(Json(json!({

View File

@@ -1,4 +1,6 @@
pub use tranquil_cache::{
Cache, CacheError, DistributedRateLimiter, NoOpCache, NoOpRateLimiter, RedisRateLimiter,
ValkeyCache, create_cache,
Cache, CacheError, DistributedRateLimiter, NoOpCache, NoOpRateLimiter, create_cache,
};
#[cfg(feature = "valkey")]
pub use tranquil_cache::{RedisRateLimiter, ValkeyCache};

View File

@@ -1,5 +1,8 @@
pub use tranquil_storage::{
BackupStorage, BlobStorage, FilesystemBackupStorage, FilesystemBlobStorage, S3BackupStorage,
S3BlobStorage, StorageError, StreamUploadResult, backup_interval_secs, backup_retention_count,
create_backup_storage, create_blob_storage,
BackupStorage, BlobStorage, FilesystemBackupStorage, FilesystemBlobStorage, StorageError,
StreamUploadResult, backup_interval_secs, backup_retention_count, create_backup_storage,
create_blob_storage,
};
#[cfg(feature = "s3-storage")]
pub use tranquil_storage::{S3BackupStorage, S3BlobStorage};

View File

@@ -330,10 +330,8 @@ unsafe fn configure_external_storage_env() {
);
std::env::set_var("S3_ENDPOINT", &s3_endpoint);
} else {
let process_dir = std::env::temp_dir().join(format!(
"tranquil-pds-test-{}",
std::process::id()
));
let process_dir =
std::env::temp_dir().join(format!("tranquil-pds-test-{}", std::process::id()));
let blob_path = process_dir.join("blobs");
let backup_path = process_dir.join("backups");
std::fs::create_dir_all(&blob_path).expect("Failed to create blob directory");
@@ -715,7 +713,8 @@ async fn setup_cluster_external_infra() -> String {
#[cfg(not(feature = "external-infra"))]
async fn setup_cluster_testcontainers() -> String {
let temp_dir = std::env::temp_dir().join(format!("tranquil-pds-cluster-{}", uuid::Uuid::new_v4()));
let temp_dir =
std::env::temp_dir().join(format!("tranquil-pds-cluster-{}", uuid::Uuid::new_v4()));
let blob_path = temp_dir.join("blobs");
let backup_path = temp_dir.join("backups");
std::fs::create_dir_all(&blob_path).expect("Failed to create blob temp directory");

View File

@@ -118,6 +118,7 @@ async fn test_account_creation_rate_limiting() {
);
}
#[cfg(feature = "valkey")]
#[tokio::test]
async fn test_valkey_connection() {
if std::env::var("VALKEY_URL").is_err() {
@@ -156,6 +157,7 @@ async fn test_valkey_connection() {
.expect("DEL failed");
}
#[cfg(feature = "valkey")]
#[tokio::test]
async fn test_distributed_rate_limiter_directly() {
if std::env::var("VALKEY_URL").is_err() {

View File

@@ -45,21 +45,22 @@ async fn cluster_formation() {
assert!(nodes.len() >= 3, "expected at least 3 cluster nodes");
let client = common::client();
let results: Vec<_> = futures::future::join_all(
nodes.iter().map(|node| {
let client = client.clone();
let url = node.url.clone();
async move {
client
.get(format!("{url}/xrpc/com.atproto.server.describeServer"))
.send()
.await
}
})
).await;
let results: Vec<_> = futures::future::join_all(nodes.iter().map(|node| {
let client = client.clone();
let url = node.url.clone();
async move {
client
.get(format!("{url}/xrpc/com.atproto.server.describeServer"))
.send()
.await
}
}))
.await;
results.iter().enumerate().for_each(|(i, result)| {
let resp = result.as_ref().unwrap_or_else(|e| panic!("node {i} unreachable: {e}"));
let resp = result
.as_ref()
.unwrap_or_else(|e| panic!("node {i} unreachable: {e}"));
assert_eq!(
resp.status(),
StatusCode::OK,
@@ -91,7 +92,10 @@ async fn cluster_any_node_access() {
assert_eq!(create_res.status(), StatusCode::OK);
let body: serde_json::Value = create_res.json().await.expect("invalid json");
let did = body["did"].as_str().expect("no did").to_string();
let access_jwt = body["accessJwt"].as_str().expect("no accessJwt").to_string();
let access_jwt = body["accessJwt"]
.as_str()
.expect("no accessJwt")
.to_string();
let pool = common::get_test_db_pool().await;
let body_text: String = sqlx::query_scalar!(
@@ -132,8 +136,10 @@ async fn cluster_any_node_access() {
let token = match confirm_res.status() {
StatusCode::OK => {
let confirm_body: serde_json::Value =
confirm_res.json().await.expect("invalid json from confirmSignup");
let confirm_body: serde_json::Value = confirm_res
.json()
.await
.expect("invalid json from confirmSignup");
confirm_body["accessJwt"]
.as_str()
.unwrap_or(&access_jwt)
@@ -164,14 +170,8 @@ async fn cluster_any_node_access() {
async fn cache_convergence() {
let nodes = common::cluster().await;
let cache_a = nodes[0]
.cache
.as_ref()
.expect("node 0 should have a cache");
let cache_b = nodes[1]
.cache
.as_ref()
.expect("node 1 should have a cache");
let cache_a = nodes[0].cache.as_ref().expect("node 0 should have a cache");
let cache_b = nodes[1].cache.as_ref().expect("node 1 should have a cache");
let test_key = format!("ripple-test-{}", uuid::Uuid::new_v4());
let test_value = "converged-value";
@@ -407,14 +407,13 @@ async fn cluster_bulk_key_convergence() {
})
.await;
let spot_checks: Vec<Option<String>> = futures::future::join_all(
[0, 99, 250, 499].iter().map(|&i| {
let spot_checks: Vec<Option<String>> =
futures::future::join_all([0, 99, 250, 499].iter().map(|&i| {
let c = cache_2.clone();
let p = prefix.clone();
async move { c.get(&format!("{p}-{i}")).await }
}),
)
.await;
}))
.await;
spot_checks.iter().enumerate().for_each(|(idx, val)| {
assert!(
@@ -620,7 +619,10 @@ fn create_account_on_node<'a>(
assert_eq!(create_res.status(), StatusCode::OK, "createAccount non-200");
let body: serde_json::Value = create_res.json().await.expect("invalid json");
let did = body["did"].as_str().expect("no did").to_string();
let access_jwt = body["accessJwt"].as_str().expect("no accessJwt").to_string();
let access_jwt = body["accessJwt"]
.as_str()
.expect("no accessJwt")
.to_string();
let pool = common::get_test_db_pool().await;
let body_text: String = sqlx::query_scalar!(
@@ -654,8 +656,10 @@ fn create_account_on_node<'a>(
let token = match confirm_res.status() {
StatusCode::OK => {
let confirm_body: serde_json::Value =
confirm_res.json().await.expect("invalid json from confirmSignup");
let confirm_body: serde_json::Value = confirm_res
.json()
.await
.expect("invalid json from confirmSignup");
confirm_body["accessJwt"]
.as_str()
.unwrap_or(&access_jwt)
@@ -744,7 +748,10 @@ async fn cross_node_handle_resolution_from_cache() {
let cache_0 = cache_for(nodes, 0);
let fake_handle = format!("cached-{}.test", uuid::Uuid::new_v4().simple());
let fake_did = format!("did:plc:cached{}", &uuid::Uuid::new_v4().simple().to_string()[..16]);
let fake_did = format!(
"did:plc:cached{}",
&uuid::Uuid::new_v4().simple().to_string()[..16]
);
cache_0
.set(
@@ -795,7 +802,10 @@ async fn cross_node_cache_delete_observable_via_http() {
let cache_1 = cache_for(nodes, 1);
let fake_handle = format!("deltest-{}.test", uuid::Uuid::new_v4().simple());
let fake_did = format!("did:plc:del{}", &uuid::Uuid::new_v4().simple().to_string()[..16]);
let fake_did = format!(
"did:plc:del{}",
&uuid::Uuid::new_v4().simple().to_string()[..16]
);
let cache_key = format!("handle:{fake_handle}");
cache_0

View File

@@ -190,7 +190,11 @@ async fn test_list_repos_shows_status_field() {
assert!(takendown_repo.is_some(), "Takendown repo should be in list");
let repo = takendown_repo.unwrap();
assert_eq!(repo["active"], false, "repo should be inactive: {:?}", repo);
assert_eq!(repo["status"], "takendown", "repo status should be takendown: {:?}", repo);
assert_eq!(
repo["status"], "takendown",
"repo status should be takendown: {:?}",
repo
);
}
#[tokio::test]

View File

@@ -1400,10 +1400,16 @@ async fn test_scale_many_users_social_graph() {
.iter()
.enumerate()
.flat_map(|(i, (follower_did, follower_jwt))| {
users.iter().enumerate()
users
.iter()
.enumerate()
.filter(move |(j, _)| *j != i)
.map(|(_, (followee_did, _))| {
(follower_did.clone(), follower_jwt.clone(), followee_did.clone())
(
follower_did.clone(),
follower_jwt.clone(),
followee_did.clone(),
)
})
.collect::<Vec<_>>()
})

View File

@@ -19,7 +19,11 @@ fn parse_env_with_warning<T: std::str::FromStr>(var_name: &str, raw: &str) -> Op
match raw.parse::<T>() {
Ok(v) => Some(v),
Err(_) => {
tracing::warn!(var = var_name, value = raw, "invalid env var value, using default");
tracing::warn!(
var = var_name,
value = raw,
"invalid env var value, using default"
);
None
}
}

View File

@@ -1,5 +1,5 @@
use super::lww_map::LwwDelta;
use super::g_counter::GCounterDelta;
use super::lww_map::LwwDelta;
use serde::{Deserialize, Serialize};
const SCHEMA_VERSION: u8 = 1;
@@ -21,7 +21,7 @@ impl CrdtDelta {
pub fn is_empty(&self) -> bool {
self.cache_delta
.as_ref()
.map_or(true, |d| d.entries.is_empty())
.is_none_or(|d| d.entries.is_empty())
&& self.rate_limit_deltas.is_empty()
}

View File

@@ -142,12 +142,10 @@ impl RateLimitStore {
self.dirty
.iter()
.filter_map(|key| {
self.counters
.get(key)
.map(|counter| GCounterDelta {
key: key.clone(),
counter: counter.clone(),
})
self.counters.get(key).map(|counter| GCounterDelta {
key: key.clone(),
counter: counter.clone(),
})
})
.collect()
}
@@ -167,7 +165,10 @@ impl RateLimitStore {
return 0;
}
match self.counters.get(key) {
Some(counter) if counter.window_start_ms == Self::aligned_window_start(now_wall_ms, window_ms) => {
Some(counter)
if counter.window_start_ms
== Self::aligned_window_start(now_wall_ms, window_ms) =>
{
counter.total()
}
_ => 0,

View File

@@ -124,6 +124,12 @@ pub struct LwwMap {
estimated_bytes: usize,
}
/// `Default` delegates to [`LwwMap::new`], so both constructors yield the same map.
impl Default for LwwMap {
fn default() -> Self {
Self::new()
}
}
impl LwwMap {
pub fn new() -> Self {
Self {
@@ -141,7 +147,14 @@ impl LwwMap {
entry.value.clone()
}
pub fn set(&mut self, key: String, value: Vec<u8>, timestamp: HlcTimestamp, ttl_ms: u64, wall_ms_now: u64) {
pub fn set(
&mut self,
key: String,
value: Vec<u8>,
timestamp: HlcTimestamp,
ttl_ms: u64,
wall_ms_now: u64,
) {
let entry = LwwEntry {
created_at_wall_ms: wall_ms_now,
value: Some(value),
@@ -248,6 +261,10 @@ impl LwwMap {
self.entries.len()
}
/// Returns `true` when the map holds no entries at all.
// NOTE(review): `entries` presumably also holds tombstones; confirm whether
// callers expect those to count as "non-empty".
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
fn remove_estimated_bytes(&mut self, key: &str) {
if let Some(existing) = self.entries.get(key) {
let size = existing.entry_byte_size(key);

View File

@@ -1,13 +1,13 @@
pub mod delta;
pub mod g_counter;
pub mod hlc;
pub mod lww_map;
pub mod g_counter;
use crate::config::fnv1a;
use delta::CrdtDelta;
use g_counter::RateLimitStore;
use hlc::{Hlc, HlcTimestamp};
use lww_map::{LwwDelta, LwwMap};
use g_counter::RateLimitStore;
use parking_lot::{Mutex, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};
@@ -45,9 +45,8 @@ impl ShardedCrdtStore {
let shards: Vec<RwLock<CrdtShard>> = (0..SHARD_COUNT)
.map(|_| RwLock::new(CrdtShard::new(node_id)))
.collect();
let promotions: Vec<Mutex<Vec<String>>> = (0..SHARD_COUNT)
.map(|_| Mutex::new(Vec::new()))
.collect();
let promotions: Vec<Mutex<Vec<String>>> =
(0..SHARD_COUNT).map(|_| Mutex::new(Vec::new())).collect();
Self {
hlc: Mutex::new(Hlc::new(node_id)),
shards: shards.into_boxed_slice(),
@@ -136,7 +135,9 @@ impl ShardedCrdtStore {
let cache_delta = match cache_entries.is_empty() {
true => None,
false => Some(LwwDelta { entries: cache_entries }),
false => Some(LwwDelta {
entries: cache_entries,
}),
};
CrdtDelta {
@@ -159,9 +160,8 @@ impl ShardedCrdtStore {
})
.unwrap_or_default();
let mut max_ts_per_shard: Vec<Option<HlcTimestamp>> = (0..self.shards.len())
.map(|_| None)
.collect();
let mut max_ts_per_shard: Vec<Option<HlcTimestamp>> =
(0..self.shards.len()).map(|_| None).collect();
cache_entries_by_shard.iter().for_each(|&(shard_idx, ts)| {
let slot = &mut max_ts_per_shard[shard_idx];
@@ -177,37 +177,39 @@ impl ShardedCrdtStore {
.map(|d| (d.key.as_str(), &d.counter))
.collect();
let mut shard_rl_keys: Vec<Vec<&str>> = (0..self.shards.len())
.map(|_| Vec::new())
.collect();
let mut shard_rl_keys: Vec<Vec<&str>> =
(0..self.shards.len()).map(|_| Vec::new()).collect();
rl_index.keys().for_each(|&key| {
shard_rl_keys[self.shard_for(key)].push(key);
});
self.shards.iter().enumerate().for_each(|(idx, shard_lock)| {
let has_cache_update = max_ts_per_shard[idx].is_some();
let has_rl_keys = !shard_rl_keys[idx].is_empty();
if !has_cache_update && !has_rl_keys {
return;
}
let mut shard = shard_lock.write();
if let Some(max_ts) = max_ts_per_shard[idx] {
shard.last_broadcast_ts = max_ts;
}
shard_rl_keys[idx].iter().for_each(|&key| {
let still_matches = shard
.rate_limits
.peek_dirty_counter(key)
.zip(rl_index.get(key))
.is_some_and(|(current, committed)| {
current.window_start_ms == committed.window_start_ms
&& current.total() == committed.total()
});
if still_matches {
shard.rate_limits.clear_single_dirty(key);
self.shards
.iter()
.enumerate()
.for_each(|(idx, shard_lock)| {
let has_cache_update = max_ts_per_shard[idx].is_some();
let has_rl_keys = !shard_rl_keys[idx].is_empty();
if !has_cache_update && !has_rl_keys {
return;
}
let mut shard = shard_lock.write();
if let Some(max_ts) = max_ts_per_shard[idx] {
shard.last_broadcast_ts = max_ts;
}
shard_rl_keys[idx].iter().for_each(|&key| {
let still_matches = shard
.rate_limits
.peek_dirty_counter(key)
.zip(rl_index.get(key))
.is_some_and(|(current, committed)| {
current.window_start_ms == committed.window_start_ms
&& current.total() == committed.total()
});
if still_matches {
shard.rate_limits.clear_single_dirty(key);
}
});
});
});
}
pub fn merge_delta(&self, delta: &CrdtDelta) -> bool {
@@ -219,10 +221,10 @@ impl ShardedCrdtStore {
return false;
}
if let Some(ref cache_delta) = delta.cache_delta {
if let Some(max_ts) = cache_delta.entries.iter().map(|(_, e)| e.timestamp).max() {
let _ = self.hlc.lock().receive(max_ts);
}
if let Some(ref cache_delta) = delta.cache_delta
&& let Some(max_ts) = cache_delta.entries.iter().map(|(_, e)| e.timestamp).max()
{
let _ = self.hlc.lock().receive(max_ts);
}
let mut changed = false;
@@ -235,17 +237,20 @@ impl ShardedCrdtStore {
entries_by_shard[self.shard_for(key)].push((key.clone(), entry.clone()));
});
entries_by_shard.into_iter().enumerate().for_each(|(idx, entries)| {
if entries.is_empty() {
return;
}
let mut shard = self.shards[idx].write();
entries.into_iter().for_each(|(key, entry)| {
if shard.cache.merge_entry(key, entry) {
changed = true;
entries_by_shard
.into_iter()
.enumerate()
.for_each(|(idx, entries)| {
if entries.is_empty() {
return;
}
let mut shard = self.shards[idx].write();
entries.into_iter().for_each(|(key, entry)| {
if shard.cache.merge_entry(key, entry) {
changed = true;
}
});
});
});
}
if !delta.rate_limit_deltas.is_empty() {
@@ -256,17 +261,20 @@ impl ShardedCrdtStore {
rl_by_shard[self.shard_for(&rd.key)].push((rd.key.clone(), &rd.counter));
});
rl_by_shard.into_iter().enumerate().for_each(|(idx, entries)| {
if entries.is_empty() {
return;
}
let mut shard = self.shards[idx].write();
entries.into_iter().for_each(|(key, counter)| {
if shard.rate_limits.merge_counter(key, counter) {
changed = true;
rl_by_shard
.into_iter()
.enumerate()
.for_each(|(idx, entries)| {
if entries.is_empty() {
return;
}
let mut shard = self.shards[idx].write();
entries.into_iter().for_each(|(key, counter)| {
if shard.rate_limits.merge_counter(key, counter) {
changed = true;
}
});
});
});
}
changed
@@ -274,14 +282,17 @@ impl ShardedCrdtStore {
pub fn run_maintenance(&self) {
let now = Self::wall_ms_now();
self.shards.iter().enumerate().for_each(|(idx, shard_lock)| {
let pending: Vec<String> = self.promotions[idx].lock().drain(..).collect();
let mut shard = shard_lock.write();
pending.iter().for_each(|key| shard.cache.touch(key));
shard.cache.gc_tombstones(now);
shard.cache.gc_expired(now);
shard.rate_limits.gc_expired(now);
});
self.shards
.iter()
.enumerate()
.for_each(|(idx, shard_lock)| {
let pending: Vec<String> = self.promotions[idx].lock().drain(..).collect();
let mut shard = shard_lock.write();
pending.iter().for_each(|key| shard.cache.touch(key));
shard.cache.gc_tombstones(now);
shard.cache.gc_expired(now);
shard.rate_limits.gc_expired(now);
});
}
pub fn peek_full_state(&self) -> CrdtDelta {
@@ -297,7 +308,9 @@ impl ShardedCrdtStore {
let cache_delta = match cache_entries.is_empty() {
true => None,
false => Some(LwwDelta { entries: cache_entries }),
false => Some(LwwDelta {
entries: cache_entries,
}),
};
CrdtDelta {
@@ -327,7 +340,10 @@ impl ShardedCrdtStore {
.iter()
.map(|s| {
let shard = s.read();
shard.cache.estimated_bytes().saturating_add(shard.rate_limits.estimated_bytes())
shard
.cache
.estimated_bytes()
.saturating_add(shard.rate_limits.estimated_bytes())
})
.fold(0usize, usize::saturating_add)
}
@@ -335,7 +351,7 @@ impl ShardedCrdtStore {
pub fn evict_lru_round_robin(&self, start_shard: usize) -> Option<(usize, usize)> {
(0..self.shards.len()).find_map(|offset| {
let idx = (start_shard + offset) & self.shard_mask;
let has_entries = self.shards[idx].read().cache.len() > 0;
let has_entries = !self.shards[idx].read().cache.is_empty();
match has_entries {
true => {
let mut shard = self.shards[idx].write();

View File

@@ -17,12 +17,14 @@ impl RippleEngine {
pub async fn start(
config: RippleConfig,
shutdown: CancellationToken,
) -> Result<(Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>, SocketAddr), RippleStartError> {
) -> Result<(Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>, SocketAddr), RippleStartError>
{
let store = Arc::new(ShardedCrdtStore::new(config.machine_id));
let (transport, incoming_rx) = Transport::bind(config.bind_addr, config.machine_id, shutdown.clone())
.await
.map_err(|e| RippleStartError::Bind(e.to_string()))?;
let (transport, incoming_rx) =
Transport::bind(config.bind_addr, config.machine_id, shutdown.clone())
.await
.map_err(|e| RippleStartError::Bind(e.to_string()))?;
let transport = Arc::new(transport);
@@ -79,8 +81,7 @@ impl RippleEngine {
});
let cache: Arc<dyn Cache> = Arc::new(RippleCache::new(store.clone()));
let rate_limiter: Arc<dyn DistributedRateLimiter> =
Arc::new(RippleRateLimiter::new(store));
let rate_limiter: Arc<dyn DistributedRateLimiter> = Arc::new(RippleRateLimiter::new(store));
metrics::describe_metrics();

View File

@@ -39,24 +39,22 @@ impl MemoryBudget {
let mut remaining = total_bytes;
let mut next_shard: usize = self.next_shard.load(std::sync::atomic::Ordering::Relaxed);
let mut evicted: usize = 0;
(0..batch_size).try_for_each(|_| {
match remaining > max_bytes {
true => {
match store.evict_lru_round_robin(next_shard) {
Some((ns, freed)) => {
next_shard = ns;
remaining = remaining.saturating_sub(freed);
evicted += 1;
Ok(())
}
None => Err(()),
(0..batch_size)
.try_for_each(|_| match remaining > max_bytes {
true => match store.evict_lru_round_robin(next_shard) {
Some((ns, freed)) => {
next_shard = ns;
remaining = remaining.saturating_sub(freed);
evicted += 1;
Ok(())
}
}
None => Err(()),
},
false => Err(()),
}
})
.ok();
self.next_shard.store(next_shard, std::sync::atomic::Ordering::Relaxed);
})
.ok();
self.next_shard
.store(next_shard, std::sync::atomic::Ordering::Relaxed);
if evicted > 0 {
metrics::record_evictions(evicted);
let cache_bytes_after = store.cache_estimated_bytes();
@@ -92,11 +90,7 @@ mod tests {
let store = ShardedCrdtStore::new(1);
let budget = MemoryBudget::new(100);
(0..50).for_each(|i| {
store.cache_set(
format!("key-{i}"),
vec![0u8; 64],
60_000,
);
store.cache_set(format!("key-{i}"), vec![0u8; 64], 60_000);
});
budget.enforce(&store);
assert!(store.total_estimated_bytes() <= 100);

View File

@@ -1,11 +1,11 @@
use crate::crdt::ShardedCrdtStore;
use crate::crdt::delta::CrdtDelta;
use crate::crdt::lww_map::LwwDelta;
use crate::crdt::ShardedCrdtStore;
use crate::metrics;
use crate::transport::{ChannelTag, IncomingFrame, Transport};
use foca::{Config, Foca, Notification, Runtime, Timer};
use rand::rngs::StdRng;
use rand::SeedableRng;
use rand::rngs::StdRng;
use std::collections::HashSet;
use std::fmt;
use std::net::SocketAddr;
@@ -135,12 +135,12 @@ impl Runtime<PeerId> for &mut BufferedRuntime {
}
fn send_to(&mut self, to: PeerId, data: &[u8]) {
self.actions
.push(RuntimeAction::SendTo(to, data.to_vec()));
self.actions.push(RuntimeAction::SendTo(to, data.to_vec()));
}
fn submit_after(&mut self, event: Timer<PeerId>, after: Duration) {
self.actions.push(RuntimeAction::ScheduleTimer(event, after));
self.actions
.push(RuntimeAction::ScheduleTimer(event, after));
}
}
@@ -151,11 +151,7 @@ pub struct GossipEngine {
}
impl GossipEngine {
pub fn new(
transport: Arc<Transport>,
store: Arc<ShardedCrdtStore>,
local_id: PeerId,
) -> Self {
pub fn new(transport: Arc<Transport>, store: Arc<ShardedCrdtStore>, local_id: PeerId) -> Self {
Self {
transport,
store,
@@ -208,10 +204,16 @@ impl GossipEngine {
}
});
drain_runtime_actions(&mut runtime, &transport, &timer_tx, &mut members, &store, &shutdown);
drain_runtime_actions(
&mut runtime,
&transport,
&timer_tx,
&mut members,
&store,
&shutdown,
);
let mut gossip_tick =
tokio::time::interval(Duration::from_millis(gossip_interval_ms));
let mut gossip_tick = tokio::time::interval(Duration::from_millis(gossip_interval_ms));
gossip_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
@@ -401,16 +403,18 @@ fn drain_runtime_actions(
metrics::set_gossip_peers(members.peer_count());
let snapshot = store.peek_full_state();
if !snapshot.is_empty() {
chunk_and_serialize(&snapshot).into_iter().for_each(|chunk| {
let t = transport.clone();
let c = shutdown.clone();
tokio::spawn(async move {
tokio::select! {
_ = c.cancelled() => {}
_ = t.send(addr, ChannelTag::CrdtSync, &chunk) => {}
}
chunk_and_serialize(&snapshot)
.into_iter()
.for_each(|chunk| {
let t = transport.clone();
let c = shutdown.clone();
tokio::spawn(async move {
tokio::select! {
_ = c.cancelled() => {}
_ = t.send(addr, ChannelTag::CrdtSync, &chunk) => {}
}
});
});
});
}
}
RuntimeAction::MemberDown(addr) => {
@@ -449,7 +453,9 @@ fn split_and_serialize(delta: CrdtDelta) -> Vec<Vec<u8>> {
source_node,
cache_delta: match cache_entries.is_empty() {
true => None,
false => Some(LwwDelta { entries: cache_entries }),
false => Some(LwwDelta {
entries: cache_entries,
}),
},
rate_limit_deltas: rl_deltas,
};

View File

@@ -13,10 +13,7 @@ pub fn describe_metrics() {
"tranquil_ripple_gossip_peers",
"Number of active gossip peers"
);
metrics::describe_counter!(
"tranquil_ripple_cache_hits_total",
"Total cache read hits"
);
metrics::describe_counter!("tranquil_ripple_cache_hits_total", "Total cache read hits");
metrics::describe_counter!(
"tranquil_ripple_cache_misses_total",
"Total cache read misses"

View File

@@ -222,7 +222,10 @@ impl Transport {
let conn_gen = self.conn_generation.fetch_add(1, Ordering::Relaxed);
self.connections.lock().insert(
target,
ConnectionWriter { tx: write_tx.clone(), generation: conn_gen },
ConnectionWriter {
tx: write_tx.clone(),
generation: conn_gen,
},
);
if let Some(frame) = encode_frame(tag, data) {
let _ = write_tx.try_send(frame);
@@ -412,7 +415,11 @@ fn decode_frame(buf: &mut BytesMut) -> DecodeResult {
}
let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
if len > MAX_FRAME_SIZE {
tracing::warn!(frame_len = len, max = MAX_FRAME_SIZE, "oversized frame, closing connection");
tracing::warn!(
frame_len = len,
max = MAX_FRAME_SIZE,
"oversized frame, closing connection"
);
buf.clear();
return DecodeResult::Corrupt;
}

View File

@@ -168,7 +168,10 @@ async fn two_node_lww_conflict_resolution() {
let val_a = cache_a.get(&key).await.expect("A should have the key");
let val_b = cache_b.get(&key).await.expect("B should have the key");
assert_eq!(val_a, val_b, "both nodes must agree on the same value after LWW resolution");
assert_eq!(
val_a, val_b,
"both nodes must agree on the same value after LWW resolution"
);
shutdown.cancel();
}
@@ -240,8 +243,14 @@ async fn two_node_ttl_expiration() {
tokio::time::sleep(Duration::from_secs(3)).await;
assert!(cache_a.get(&key).await.is_none(), "A should have expired the key");
assert!(cache_b.get(&key).await.is_none(), "B should have expired the key");
assert!(
cache_a.get(&key).await.is_none(),
"A should have expired the key"
);
assert!(
cache_b.get(&key).await.is_none(),
"B should have expired the key"
);
shutdown.cancel();
}
@@ -715,50 +724,49 @@ async fn two_node_stress_concurrent_load() {
let shutdown = CancellationToken::new();
let ((cache_a, rl_a), (cache_b, rl_b)) = spawn_pair(shutdown.clone()).await;
let tasks: Vec<tokio::task::JoinHandle<()>> = (0u32..8).map(|task_id| {
let cache = match task_id < 4 {
true => cache_a.clone(),
false => cache_b.clone(),
};
let rl = match task_id < 4 {
true => rl_a.clone(),
false => rl_b.clone(),
};
tokio::spawn(async move {
let value = vec![0xABu8; 1024];
futures::future::join_all((0u32..500).map(|op| {
let cache = cache.clone();
let rl = rl.clone();
let value = value.clone();
async move {
let key_idx = op % 100;
let key = format!("stress-{task_id}-{key_idx}");
match op % 4 {
0 | 1 => {
cache
.set_bytes(&key, &value, Duration::from_secs(120))
.await
.expect("set_bytes failed");
}
2 => {
let _ = cache.get(&key).await;
}
_ => {
let _ = rl.check_rate_limit(&key, 1000, 60_000).await;
let tasks: Vec<tokio::task::JoinHandle<()>> = (0u32..8)
.map(|task_id| {
let cache = match task_id < 4 {
true => cache_a.clone(),
false => cache_b.clone(),
};
let rl = match task_id < 4 {
true => rl_a.clone(),
false => rl_b.clone(),
};
tokio::spawn(async move {
let value = vec![0xABu8; 1024];
futures::future::join_all((0u32..500).map(|op| {
let cache = cache.clone();
let rl = rl.clone();
let value = value.clone();
async move {
let key_idx = op % 100;
let key = format!("stress-{task_id}-{key_idx}");
match op % 4 {
0 | 1 => {
cache
.set_bytes(&key, &value, Duration::from_secs(120))
.await
.expect("set_bytes failed");
}
2 => {
let _ = cache.get(&key).await;
}
_ => {
let _ = rl.check_rate_limit(&key, 1000, 60_000).await;
}
}
}
}
}))
.await;
}))
.await;
})
})
}).collect();
.collect();
let results = tokio::time::timeout(
Duration::from_secs(30),
futures::future::join_all(tasks),
)
.await
.expect("stress test timed out after 30s");
let results = tokio::time::timeout(Duration::from_secs(30), futures::future::join_all(tasks))
.await
.expect("stress test timed out after 30s");
results.into_iter().enumerate().for_each(|(i, r)| {
r.unwrap_or_else(|e| panic!("task {i} panicked: {e}"));

View File

@@ -4,12 +4,16 @@ version.workspace = true
edition.workspace = true
license.workspace = true
[features]
default = []
s3 = ["dep:aws-config", "dep:aws-sdk-s3"]
[dependencies]
tranquil-infra = { workspace = true }
async-trait = { workspace = true }
aws-config = { workspace = true }
aws-sdk-s3 = { workspace = true }
aws-config = { workspace = true, optional = true }
aws-sdk-s3 = { workspace = true, optional = true }
bytes = { workspace = true }
futures = { workspace = true }
sha2 = { workspace = true }

View File

@@ -4,12 +4,6 @@ pub use tranquil_infra::{
};
use async_trait::async_trait;
use aws_config::BehaviorVersion;
use aws_config::meta::region::RegionProviderChain;
use aws_sdk_s3::Client;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::types::CompletedMultipartUpload;
use aws_sdk_s3::types::CompletedPart;
use bytes::Bytes;
use futures::Stream;
use sha2::{Digest, Sha256};
@@ -17,7 +11,6 @@ use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
const MIN_PART_SIZE: usize = 5 * 1024 * 1024;
const EXDEV: i32 = 18;
const CID_SHARD_PREFIX_LEN: usize = 9;
@@ -109,360 +102,385 @@ fn map_io_not_found(key: &str) -> impl FnOnce(std::io::Error) -> StorageError +
}
}
/// Blob storage backed by a single S3 bucket.
pub struct S3BlobStorage {
/// Client through which every request of this storage is sent.
client: Client,
/// Name of the bucket all object keys are addressed in.
bucket: String,
}
#[cfg(feature = "s3")]
mod s3 {
use super::*;
use aws_config::BehaviorVersion;
use aws_config::meta::region::RegionProviderChain;
use aws_sdk_s3::Client;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::types::CompletedMultipartUpload;
use aws_sdk_s3::types::CompletedPart;
impl S3BlobStorage {
pub async fn new() -> Self {
let bucket = std::env::var("S3_BUCKET").expect("S3_BUCKET must be set");
let client = create_s3_client().await;
Self { client, bucket }
const MIN_PART_SIZE: usize = 5 * 1024 * 1024;
/// Blob storage backed by a single S3 bucket.
pub struct S3BlobStorage {
/// Client through which every request of this storage is sent.
client: Client,
/// Name of the bucket all object keys are addressed in.
bucket: String,
}
/// Builds a storage handle for `bucket`, creating a fresh S3 client for it.
pub async fn with_bucket(bucket: String) -> Self {
    Self {
        client: create_s3_client().await,
        bucket,
    }
}
}
/// Creates the S3 client used by the blob and backup storages.
///
/// The region is taken from the default provider chain, falling back to
/// `us-east-1`. When the `S3_ENDPOINT` environment variable can be read, the
/// client targets that endpoint with path-style addressing; otherwise the
/// loaded configuration is used unchanged.
async fn create_s3_client() -> Client {
    let region = RegionProviderChain::default_provider().or_else("us-east-1");
    let shared_config = aws_config::defaults(BehaviorVersion::latest())
        .region(region)
        .load()
        .await;
    match std::env::var("S3_ENDPOINT") {
        Ok(endpoint) => {
            let custom = aws_sdk_s3::config::Builder::from(&shared_config)
                .endpoint_url(endpoint)
                .force_path_style(true)
                .build();
            Client::from_conf(custom)
        }
        // Unset or unreadable endpoint: use the stock configuration.
        Err(_) => Client::new(&shared_config),
    }
}
/// Backup storage backed by a single S3 bucket.
pub struct S3BackupStorage {
/// Client through which every backup request is sent.
client: Client,
/// Name of the bucket that holds the backups.
bucket: String,
}
impl S3BackupStorage {
    /// Creates the backup storage from the environment.
    ///
    /// Returns `None` when `BACKUP_S3_BUCKET` cannot be read; no S3 client is
    /// created in that case.
    pub async fn new() -> Option<Self> {
        match std::env::var("BACKUP_S3_BUCKET") {
            Ok(bucket) => Some(Self {
                client: create_s3_client().await,
                bucket,
            }),
            Err(_) => None,
        }
    }
}
#[async_trait]
impl BackupStorage for S3BackupStorage {
    /// Stores `data` under the key `{did}/{rev}.car` and returns that key.
    ///
    /// # Errors
    /// Returns `StorageError::Backend` with the SDK error text when the upload fails.
    async fn put_backup(&self, did: &str, rev: &str, data: &[u8]) -> Result<String, StorageError> {
        let key = format!("{}/{}.car", did, rev);
        let body = ByteStream::from(Bytes::copy_from_slice(data));
        let outcome = self
            .client
            .put_object()
            .bucket(&self.bucket)
            .key(&key)
            .body(body)
            .send()
            .await;
        match outcome {
            Ok(_) => Ok(key),
            Err(err) => Err(StorageError::Backend(err.to_string())),
        }
    }

    /// Downloads the whole backup stored under `storage_key`.
    ///
    /// # Errors
    /// Returns `StorageError::Backend` when the request fails or the body cannot be read.
    async fn get_backup(&self, storage_key: &str) -> Result<Bytes, StorageError> {
        let object = match self
            .client
            .get_object()
            .bucket(&self.bucket)
            .key(storage_key)
            .send()
            .await
        {
            Ok(object) => object,
            Err(err) => return Err(StorageError::Backend(err.to_string())),
        };
        let aggregated = object
            .body
            .collect()
            .await
            .map_err(|err| StorageError::Backend(err.to_string()))?;
        Ok(aggregated.into_bytes())
    }

    /// Deletes the backup stored under `storage_key`.
    ///
    /// # Errors
    /// Returns `StorageError::Backend` when the delete request fails.
    async fn delete_backup(&self, storage_key: &str) -> Result<(), StorageError> {
        let outcome = self
            .client
            .delete_object()
            .bucket(&self.bucket)
            .key(storage_key)
            .send()
            .await;
        match outcome {
            Ok(_) => Ok(()),
            Err(err) => Err(StorageError::Backend(err.to_string())),
        }
    }
}
#[async_trait]
impl BlobStorage for S3BlobStorage {
/// Stores `data` under `key` by copying the slice into a `Bytes` buffer and
/// delegating to `put_bytes`.
async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> {
    let owned = Bytes::copy_from_slice(data);
    self.put_bytes(key, owned).await
}
/// Uploads `data` as the object `key` with a single put-object request.
///
/// # Errors
/// Returns `StorageError::Backend` with the SDK error text when the request fails.
async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError> {
    let request = self
        .client
        .put_object()
        .bucket(&self.bucket)
        .key(key)
        .body(ByteStream::from(data));
    match request.send().await {
        Ok(_) => Ok(()),
        Err(err) => Err(StorageError::Backend(err.to_string())),
    }
}
/// Fetches the object `key` via `get_bytes` and returns its contents as a `Vec<u8>`.
async fn get(&self, key: &str) -> Result<Vec<u8>, StorageError> {
    let bytes = self.get_bytes(key).await?;
    Ok(bytes.to_vec())
}
/// Downloads the whole object stored under `key`.
///
/// # Errors
/// Returns `StorageError::Backend` when the request fails or the body cannot be read.
async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError> {
    let object = match self
        .client
        .get_object()
        .bucket(&self.bucket)
        .key(key)
        .send()
        .await
    {
        Ok(object) => object,
        Err(err) => return Err(StorageError::Backend(err.to_string())),
    };
    let aggregated = object
        .body
        .collect()
        .await
        .map_err(|err| StorageError::Backend(err.to_string()))?;
    Ok(aggregated.into_bytes())
}
/// Fetches at most the first `size` bytes of the object stored under `key`.
///
/// A ranged GET (`bytes=0-{size-1}`) is issued and the collected body is then
/// clamped to `size` bytes. The clamp matters in two situations: `size == 0`,
/// where the smallest expressible range (`bytes=0-0`) still yields one byte,
/// and a backend that ignores the `Range` header and returns the whole object.
///
/// # Errors
/// Returns `StorageError::Backend` when the request fails or the body cannot be read.
async fn get_head(&self, key: &str, size: usize) -> Result<Bytes, StorageError> {
    // HTTP byte ranges are inclusive, hence `size - 1` as the last position.
    let range = format!("bytes=0-{}", size.saturating_sub(1));
    let resp = self
        .client
        .get_object()
        .bucket(&self.bucket)
        .key(key)
        .range(range)
        .send()
        .await
        .map_err(|e| StorageError::Backend(e.to_string()))?;
    let mut head = resp
        .body
        .collect()
        .await
        .map(|agg| agg.into_bytes())
        .map_err(|e| StorageError::Backend(e.to_string()))?;
    // Never hand back more than the caller asked for; a no-op when the
    // response is already short enough.
    head.truncate(size);
    Ok(head)
}
/// Deletes the object stored under `key`.
///
/// # Errors
/// Returns `StorageError::Backend` with the SDK error text when the request fails.
async fn delete(&self, key: &str) -> Result<(), StorageError> {
    let outcome = self
        .client
        .delete_object()
        .bucket(&self.bucket)
        .key(key)
        .send()
        .await;
    match outcome {
        Ok(_) => Ok(()),
        Err(err) => Err(StorageError::Backend(err.to_string())),
    }
}
async fn put_stream(
&self,
key: &str,
stream: Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>,
) -> Result<StreamUploadResult, StorageError> {
use futures::StreamExt;
let create_resp = self
.client
.create_multipart_upload()
.bucket(&self.bucket)
.key(key)
.send()
.await
.map_err(|e| {
StorageError::Backend(format!("Failed to create multipart upload: {}", e))
})?;
let upload_id = create_resp
.upload_id()
.ok_or_else(|| StorageError::Backend("No upload ID returned".to_string()))?
.to_string();
let upload_part = |client: &Client,
bucket: &str,
key: &str,
upload_id: &str,
part_num: i32,
data: Vec<u8>|
-> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<CompletedPart, StorageError>> + Send>,
> {
let client = client.clone();
let bucket = bucket.to_string();
let key = key.to_string();
let upload_id = upload_id.to_string();
Box::pin(async move {
let resp = client
.upload_part()
.bucket(&bucket)
.key(&key)
.upload_id(&upload_id)
.part_number(part_num)
.body(ByteStream::from(data))
.send()
.await
.map_err(|e| StorageError::Backend(format!("Failed to upload part: {}", e)))?;
let etag = resp
.e_tag()
.ok_or_else(|| StorageError::Backend("No ETag returned for part".to_string()))?
.to_string();
Ok(CompletedPart::builder()
.part_number(part_num)
.e_tag(etag)
.build())
})
};
struct UploadState {
hasher: Sha256,
total_size: u64,
part_number: i32,
completed_parts: Vec<CompletedPart>,
buffer: Vec<u8>,
impl S3BlobStorage {
pub async fn new() -> Self {
let bucket = std::env::var("S3_BUCKET").expect("S3_BUCKET must be set");
let client = create_s3_client().await;
Self { client, bucket }
}
let initial_state = UploadState {
hasher: Sha256::new(),
total_size: 0,
part_number: 1,
completed_parts: Vec::new(),
buffer: Vec::with_capacity(MIN_PART_SIZE),
};
/// Builds a storage handle for `bucket`, creating a fresh S3 client for it.
pub async fn with_bucket(bucket: String) -> Self {
    Self {
        client: create_s3_client().await,
        bucket,
    }
}
}
let abort_upload = || async {
let _ = self
/// Creates the S3 client used by the blob and backup storages.
///
/// The region is taken from the default provider chain, falling back to
/// `us-east-1`. When the `S3_ENDPOINT` environment variable can be read, the
/// client targets that endpoint with path-style addressing; otherwise the
/// loaded configuration is used unchanged.
async fn create_s3_client() -> Client {
    let region = RegionProviderChain::default_provider().or_else("us-east-1");
    let shared_config = aws_config::defaults(BehaviorVersion::latest())
        .region(region)
        .load()
        .await;
    match std::env::var("S3_ENDPOINT") {
        Ok(endpoint) => {
            let custom = aws_sdk_s3::config::Builder::from(&shared_config)
                .endpoint_url(endpoint)
                .force_path_style(true)
                .build();
            Client::from_conf(custom)
        }
        // Unset or unreadable endpoint: use the stock configuration.
        Err(_) => Client::new(&shared_config),
    }
}
/// Backup storage backed by a single S3 bucket.
pub struct S3BackupStorage {
/// Client through which every backup request is sent.
client: Client,
/// Name of the bucket that holds the backups.
bucket: String,
}
impl S3BackupStorage {
    /// Creates the backup storage from the environment.
    ///
    /// Returns `None` when `BACKUP_S3_BUCKET` cannot be read; no S3 client is
    /// created in that case.
    pub async fn new() -> Option<Self> {
        match std::env::var("BACKUP_S3_BUCKET") {
            Ok(bucket) => Some(Self {
                client: create_s3_client().await,
                bucket,
            }),
            Err(_) => None,
        }
    }
}
#[async_trait]
impl BackupStorage for S3BackupStorage {
/// Stores `data` under the key `{did}/{rev}.car` and returns that key.
///
/// # Errors
/// Returns `StorageError::Backend` with the SDK error text when the upload fails.
async fn put_backup(
    &self,
    did: &str,
    rev: &str,
    data: &[u8],
) -> Result<String, StorageError> {
    let key = format!("{}/{}.car", did, rev);
    let body = ByteStream::from(Bytes::copy_from_slice(data));
    let outcome = self
        .client
        .put_object()
        .bucket(&self.bucket)
        .key(&key)
        .body(body)
        .send()
        .await;
    match outcome {
        Ok(_) => Ok(key),
        Err(err) => Err(StorageError::Backend(err.to_string())),
    }
}
async fn get_backup(&self, storage_key: &str) -> Result<Bytes, StorageError> {
let resp = self
.client
.abort_multipart_upload()
.get_object()
.bucket(&self.bucket)
.key(storage_key)
.send()
.await
.map_err(|e| StorageError::Backend(e.to_string()))?;
resp.body
.collect()
.await
.map(|agg| agg.into_bytes())
.map_err(|e| StorageError::Backend(e.to_string()))
}
/// Deletes the backup stored under `storage_key`.
///
/// # Errors
/// Returns `StorageError::Backend` when the delete request fails.
async fn delete_backup(&self, storage_key: &str) -> Result<(), StorageError> {
    let outcome = self
        .client
        .delete_object()
        .bucket(&self.bucket)
        .key(storage_key)
        .send()
        .await;
    match outcome {
        Ok(_) => Ok(()),
        Err(err) => Err(StorageError::Backend(err.to_string())),
    }
}
}
#[async_trait]
impl BlobStorage for S3BlobStorage {
/// Stores `data` under `key` by copying the slice into a `Bytes` buffer and
/// delegating to `put_bytes`.
async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> {
    let owned = Bytes::copy_from_slice(data);
    self.put_bytes(key, owned).await
}
/// Uploads `data` as the object `key` with a single put-object request.
///
/// # Errors
/// Returns `StorageError::Backend` with the SDK error text when the request fails.
async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError> {
    let request = self
        .client
        .put_object()
        .bucket(&self.bucket)
        .key(key)
        .body(ByteStream::from(data));
    match request.send().await {
        Ok(_) => Ok(()),
        Err(err) => Err(StorageError::Backend(err.to_string())),
    }
}
/// Fetches the object `key` via `get_bytes` and returns its contents as a `Vec<u8>`.
async fn get(&self, key: &str) -> Result<Vec<u8>, StorageError> {
    let bytes = self.get_bytes(key).await?;
    Ok(bytes.to_vec())
}
/// Downloads the whole object stored under `key`.
///
/// # Errors
/// Returns `StorageError::Backend` when the request fails or the body cannot be read.
async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError> {
    let object = match self
        .client
        .get_object()
        .bucket(&self.bucket)
        .key(key)
        .send()
        .await
    {
        Ok(object) => object,
        Err(err) => return Err(StorageError::Backend(err.to_string())),
    };
    let aggregated = object
        .body
        .collect()
        .await
        .map_err(|err| StorageError::Backend(err.to_string()))?;
    Ok(aggregated.into_bytes())
}
/// Fetches at most the first `size` bytes of the object stored under `key`.
///
/// A ranged GET (`bytes=0-{size-1}`) is issued and the collected body is then
/// clamped to `size` bytes. The clamp matters in two situations: `size == 0`,
/// where the smallest expressible range (`bytes=0-0`) still yields one byte,
/// and a backend that ignores the `Range` header and returns the whole object.
///
/// # Errors
/// Returns `StorageError::Backend` when the request fails or the body cannot be read.
async fn get_head(&self, key: &str, size: usize) -> Result<Bytes, StorageError> {
    // HTTP byte ranges are inclusive, hence `size - 1` as the last position.
    let range = format!("bytes=0-{}", size.saturating_sub(1));
    let resp = self
        .client
        .get_object()
        .bucket(&self.bucket)
        .key(key)
        .range(range)
        .send()
        .await
        .map_err(|e| StorageError::Backend(e.to_string()))?;
    let mut head = resp
        .body
        .collect()
        .await
        .map(|agg| agg.into_bytes())
        .map_err(|e| StorageError::Backend(e.to_string()))?;
    // Never hand back more than the caller asked for; a no-op when the
    // response is already short enough.
    head.truncate(size);
    Ok(head)
}
/// Deletes the object stored under `key`.
///
/// # Errors
/// Returns `StorageError::Backend` with the SDK error text when the request fails.
async fn delete(&self, key: &str) -> Result<(), StorageError> {
    let outcome = self
        .client
        .delete_object()
        .bucket(&self.bucket)
        .key(key)
        .send()
        .await;
    match outcome {
        Ok(_) => Ok(()),
        Err(err) => Err(StorageError::Backend(err.to_string())),
    }
}
async fn put_stream(
&self,
key: &str,
stream: Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>,
) -> Result<StreamUploadResult, StorageError> {
use futures::StreamExt;
let create_resp = self
.client
.create_multipart_upload()
.bucket(&self.bucket)
.key(key)
.send()
.await
.map_err(|e| {
StorageError::Backend(format!("Failed to create multipart upload: {}", e))
})?;
let upload_id = create_resp
.upload_id()
.ok_or_else(|| StorageError::Backend("No upload ID returned".to_string()))?
.to_string();
let upload_part = |client: &Client,
bucket: &str,
key: &str,
upload_id: &str,
part_num: i32,
data: Vec<u8>|
-> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<CompletedPart, StorageError>> + Send>,
> {
let client = client.clone();
let bucket = bucket.to_string();
let key = key.to_string();
let upload_id = upload_id.to_string();
Box::pin(async move {
let resp = client
.upload_part()
.bucket(&bucket)
.key(&key)
.upload_id(&upload_id)
.part_number(part_num)
.body(ByteStream::from(data))
.send()
.await
.map_err(|e| {
StorageError::Backend(format!("Failed to upload part: {}", e))
})?;
let etag = resp
.e_tag()
.ok_or_else(|| {
StorageError::Backend("No ETag returned for part".to_string())
})?
.to_string();
Ok(CompletedPart::builder()
.part_number(part_num)
.e_tag(etag)
.build())
})
};
/// Running state of a multipart upload while the input stream is consumed.
struct UploadState {
/// Incremental SHA-256 over every chunk received so far.
hasher: Sha256,
/// Total number of payload bytes received so far.
total_size: u64,
/// Part number to assign to the next uploaded part (initialised to 1).
part_number: i32,
/// Parts uploaded so far, handed to the final completion request.
completed_parts: Vec<CompletedPart>,
/// Bytes accumulated until at least `MIN_PART_SIZE` are available for a part.
buffer: Vec<u8>,
}
let initial_state = UploadState {
hasher: Sha256::new(),
total_size: 0,
part_number: 1,
completed_parts: Vec::new(),
buffer: Vec::with_capacity(MIN_PART_SIZE),
};
let abort_upload = || async {
let _ = self
.client
.abort_multipart_upload()
.bucket(&self.bucket)
.key(key)
.upload_id(&upload_id)
.send()
.await;
};
let result: Result<UploadState, StorageError> = {
let mut state = initial_state;
let chunk_results: Vec<Result<Bytes, std::io::Error>> = stream.collect().await;
for chunk_result in chunk_results {
match chunk_result {
Ok(chunk) => {
state.hasher.update(&chunk);
state.total_size += chunk.len() as u64;
state.buffer.extend_from_slice(&chunk);
if state.buffer.len() >= MIN_PART_SIZE {
let part_data = std::mem::replace(
&mut state.buffer,
Vec::with_capacity(MIN_PART_SIZE),
);
let part = upload_part(
&self.client,
&self.bucket,
key,
&upload_id,
state.part_number,
part_data,
)
.await?;
state.completed_parts.push(part);
state.part_number += 1;
}
}
Err(e) => {
abort_upload().await;
return Err(StorageError::Io(e));
}
}
}
Ok(state)
};
let mut state = result?;
if !state.buffer.is_empty() {
let part = upload_part(
&self.client,
&self.bucket,
key,
&upload_id,
state.part_number,
std::mem::take(&mut state.buffer),
)
.await?;
state.completed_parts.push(part);
}
if state.completed_parts.is_empty() {
abort_upload().await;
return Err(StorageError::Other("Empty upload".to_string()));
}
let completed_upload = CompletedMultipartUpload::builder()
.set_parts(Some(state.completed_parts))
.build();
self.client
.complete_multipart_upload()
.bucket(&self.bucket)
.key(key)
.upload_id(&upload_id)
.multipart_upload(completed_upload)
.send()
.await;
};
.await
.map_err(|e| {
StorageError::Backend(format!("Failed to complete multipart upload: {}", e))
})?;
let result: Result<UploadState, StorageError> = {
let mut state = initial_state;
let chunk_results: Vec<Result<Bytes, std::io::Error>> = stream.collect().await;
for chunk_result in chunk_results {
match chunk_result {
Ok(chunk) => {
state.hasher.update(&chunk);
state.total_size += chunk.len() as u64;
state.buffer.extend_from_slice(&chunk);
if state.buffer.len() >= MIN_PART_SIZE {
let part_data = std::mem::replace(
&mut state.buffer,
Vec::with_capacity(MIN_PART_SIZE),
);
let part = upload_part(
&self.client,
&self.bucket,
key,
&upload_id,
state.part_number,
part_data,
)
.await?;
state.completed_parts.push(part);
state.part_number += 1;
}
}
Err(e) => {
abort_upload().await;
return Err(StorageError::Io(e));
}
}
}
Ok(state)
};
let mut state = result?;
if !state.buffer.is_empty() {
let part = upload_part(
&self.client,
&self.bucket,
key,
&upload_id,
state.part_number,
std::mem::take(&mut state.buffer),
)
.await?;
state.completed_parts.push(part);
let hash: [u8; 32] = state.hasher.finalize().into();
Ok(StreamUploadResult {
sha256_hash: hash,
size: state.total_size,
})
}
if state.completed_parts.is_empty() {
abort_upload().await;
return Err(StorageError::Other("Empty upload".to_string()));
/// Copies the object at `src_key` to `dst_key` inside this storage's bucket.
///
/// The copy source is `{bucket}/{src_key}`. S3 requires that value to be
/// URL-encoded, so it is percent-encoded here; keys made only of unreserved
/// characters and `/` are passed through unchanged.
///
/// # Errors
/// Returns `StorageError::Backend` when the copy request fails.
async fn copy(&self, src_key: &str, dst_key: &str) -> Result<(), StorageError> {
    /// Percent-encodes `raw`, leaving RFC 3986 unreserved characters and the
    /// `/` path separator untouched.
    fn encode_copy_source(raw: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut encoded = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            match byte {
                b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' | b'/' => {
                    encoded.push(char::from(byte));
                }
                _ => {
                    encoded.push('%');
                    encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                    encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
                }
            }
        }
        encoded
    }

    let copy_source = encode_copy_source(&format!("{}/{}", self.bucket, src_key));
    self.client
        .copy_object()
        .bucket(&self.bucket)
        .copy_source(&copy_source)
        .key(dst_key)
        .send()
        .await
        .map_err(|e| StorageError::Backend(format!("Failed to copy object: {}", e)))?;
    Ok(())
}
let completed_upload = CompletedMultipartUpload::builder()
.set_parts(Some(state.completed_parts))
.build();
self.client
.complete_multipart_upload()
.bucket(&self.bucket)
.key(key)
.upload_id(&upload_id)
.multipart_upload(completed_upload)
.send()
.await
.map_err(|e| {
StorageError::Backend(format!("Failed to complete multipart upload: {}", e))
})?;
let hash: [u8; 32] = state.hasher.finalize().into();
Ok(StreamUploadResult {
sha256_hash: hash,
size: state.total_size,
})
}
/// Copies the object at `src_key` to `dst_key` inside this storage's bucket.
///
/// The copy source is `{bucket}/{src_key}`. S3 requires that value to be
/// URL-encoded, so it is percent-encoded here; keys made only of unreserved
/// characters and `/` are passed through unchanged.
///
/// # Errors
/// Returns `StorageError::Backend` when the copy request fails.
async fn copy(&self, src_key: &str, dst_key: &str) -> Result<(), StorageError> {
    /// Percent-encodes `raw`, leaving RFC 3986 unreserved characters and the
    /// `/` path separator untouched.
    fn encode_copy_source(raw: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut encoded = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            match byte {
                b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' | b'/' => {
                    encoded.push(char::from(byte));
                }
                _ => {
                    encoded.push('%');
                    encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                    encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
                }
            }
        }
        encoded
    }

    let copy_source = encode_copy_source(&format!("{}/{}", self.bucket, src_key));
    self.client
        .copy_object()
        .bucket(&self.bucket)
        .copy_source(&copy_source)
        .key(dst_key)
        .send()
        .await
        .map_err(|e| StorageError::Backend(format!("Failed to copy object: {}", e)))?;
    Ok(())
}
}
#[cfg(feature = "s3")]
pub use s3::{S3BackupStorage, S3BlobStorage};
pub struct FilesystemBlobStorage {
base_path: PathBuf,
tmp_path: PathBuf,
@@ -686,10 +704,18 @@ pub async fn create_blob_storage() -> Arc<dyn BlobStorage> {
let backend = std::env::var("BLOB_STORAGE_BACKEND").unwrap_or_else(|_| "filesystem".into());
match backend.as_str() {
#[cfg(feature = "s3")]
"s3" => {
tracing::info!("Initializing S3 blob storage");
Arc::new(S3BlobStorage::new().await)
}
#[cfg(not(feature = "s3"))]
"s3" => {
panic!(
"BLOB_STORAGE_BACKEND=s3 but binary was compiled without s3 feature. \
Rebuild with --features s3 to enable S3 storage."
);
}
_ => {
tracing::info!("Initializing filesystem blob storage");
FilesystemBlobStorage::from_env()
@@ -719,6 +745,7 @@ pub async fn create_backup_storage() -> Option<Arc<dyn BackupStorage>> {
let backend = std::env::var("BACKUP_STORAGE_BACKEND").unwrap_or_else(|_| "filesystem".into());
match backend.as_str() {
#[cfg(feature = "s3")]
"s3" => S3BackupStorage::new().await.map_or_else(
|| {
tracing::error!(
@@ -732,6 +759,14 @@ pub async fn create_backup_storage() -> Option<Arc<dyn BackupStorage>> {
Some(Arc::new(storage) as Arc<dyn BackupStorage>)
},
),
#[cfg(not(feature = "s3"))]
"s3" => {
tracing::error!(
"BACKUP_STORAGE_BACKEND=s3 but binary was compiled without s3 feature. \
Backups will be disabled."
);
None
}
_ => FilesystemBackupStorage::from_env().await.map_or_else(
|e| {
tracing::error!(