mirror of
https://tangled.org/tranquil.farm/tranquil-pds
synced 2026-02-09 05:40:09 +00:00
530 lines
18 KiB
Rust
530 lines
18 KiB
Rust
use aws_config::BehaviorVersion;
|
|
use aws_sdk_s3::Client as S3Client;
|
|
use aws_sdk_s3::config::Credentials;
|
|
use chrono::Utc;
|
|
use reqwest::{Client, StatusCode, header};
|
|
use serde_json::{Value, json};
|
|
use sqlx::postgres::PgPoolOptions;
|
|
#[allow(unused_imports)]
|
|
use std::collections::HashMap;
|
|
use std::sync::OnceLock;
|
|
#[allow(unused_imports)]
|
|
use std::time::Duration;
|
|
use tokio::net::TcpListener;
|
|
use tranquil_pds::state::AppState;
|
|
use wiremock::matchers::{method, path};
|
|
use wiremock::{Mock, MockServer, ResponseTemplate};
|
|
|
|
static SERVER_URL: OnceLock<String> = OnceLock::new();
|
|
static APP_PORT: OnceLock<u16> = OnceLock::new();
|
|
static MOCK_APPVIEW: OnceLock<MockServer> = OnceLock::new();
|
|
|
|
#[cfg(not(feature = "external-infra"))]
|
|
use testcontainers::core::ContainerPort;
|
|
#[cfg(not(feature = "external-infra"))]
|
|
use testcontainers::{ContainerAsync, GenericImage, ImageExt, runners::AsyncRunner};
|
|
#[cfg(not(feature = "external-infra"))]
|
|
use testcontainers_modules::postgres::Postgres;
|
|
#[cfg(not(feature = "external-infra"))]
|
|
static DB_CONTAINER: OnceLock<ContainerAsync<Postgres>> = OnceLock::new();
|
|
#[cfg(not(feature = "external-infra"))]
|
|
static S3_CONTAINER: OnceLock<ContainerAsync<GenericImage>> = OnceLock::new();
|
|
|
|
// `dead_code` is allowed on these presumably because this module is shared by several test
// crates and not every crate uses every constant.

/// Bearer token sent by `upload_test_blob` and `create_test_post`.
#[allow(dead_code)]
pub const AUTH_TOKEN: &str = "test-token";

/// A token distinct from `AUTH_TOKEN` (presumably for tests expecting auth rejection).
#[allow(dead_code)]
pub const BAD_AUTH_TOKEN: &str = "bad-token";

/// DID used as the `repo` by `create_test_post`.
#[allow(dead_code)]
pub const AUTH_DID: &str = "did:plc:fake";

/// A second DID, distinct from `AUTH_DID`, for tests that need another party.
#[allow(dead_code)]
pub const TARGET_DID: &str = "did:plc:target";
|
|
|
|
/// Reports whether tests run against externally provisioned Postgres/S3 instead of starting
/// containers.
///
/// External infra is selected when `TRANQUIL_PDS_TEST_INFRA_READY` is set, or when both
/// `DATABASE_URL` and `S3_ENDPOINT` are set.
///
/// The answer is computed once and then cached. The testcontainers setup path itself exports
/// `S3_ENDPOINT`, so re-reading the environment later would flip the answer to `true`. That
/// happens whenever `DATABASE_URL` is already exported (e.g. for sqlx's compile-time query
/// checks). `get_db_connection_string` and `cleanup` would then disagree with the
/// infrastructure that was actually started.
fn has_external_infra() -> bool {
    static EXTERNAL_INFRA: OnceLock<bool> = OnceLock::new();
    *EXTERNAL_INFRA.get_or_init(|| {
        std::env::var("TRANQUIL_PDS_TEST_INFRA_READY").is_ok()
            || (std::env::var("DATABASE_URL").is_ok() && std::env::var("S3_ENDPOINT").is_ok())
    })
}
|
|
/// Process-exit hook that removes the containers started for the tests.
///
/// Containers parked in `DB_CONTAINER` / `S3_CONTAINER` are never dropped, because Rust does not
/// run destructors for statics. They therefore keep running until removed here by ID. This step
/// matters on Docker hosts: `docker container prune` only deletes *stopped* containers, so
/// without it this process's still-running containers leak.
///
/// The label-based sweep afterwards is kept so leftovers from earlier runs are cleaned up too.
/// Everything is best-effort: CLI failures, including a missing `podman`/`docker` binary, are
/// ignored.
#[cfg(test)]
#[ctor::dtor]
fn cleanup() {
    #[cfg(not(feature = "external-infra"))]
    {
        let own_ids: Vec<String> = DB_CONTAINER
            .get()
            .map(|c| c.id().to_string())
            .into_iter()
            .chain(S3_CONTAINER.get().map(|c| c.id().to_string()))
            .collect();
        if !own_ids.is_empty() {
            // Either runtime may own the containers (DOCKER_HOST may point at a Podman
            // socket); trying both is harmless.
            for program in ["podman", "docker"] {
                let _ = std::process::Command::new(program)
                    .args(["rm", "-f"])
                    .args(&own_ids)
                    .output();
            }
        }
    }
    if has_external_infra() {
        return;
    }
    if std::env::var("XDG_RUNTIME_DIR").is_ok() {
        let _ = std::process::Command::new("podman")
            .args(["rm", "-f", "--filter", "label=tranquil_pds_test=true"])
            .output();
    }
    let _ = std::process::Command::new("docker")
        .args([
            "container",
            "prune",
            "-f",
            "--filter",
            "label=tranquil_pds_test=true",
        ])
        .output();
}
|
|
|
|
#[allow(dead_code)]
|
|
pub fn client() -> Client {
|
|
Client::new()
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub fn app_port() -> u16 {
|
|
*APP_PORT.get().expect("APP_PORT not initialized")
|
|
}
|
|
|
|
pub async fn base_url() -> &'static str {
|
|
SERVER_URL.get_or_init(|| {
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
|
std::thread::spawn(move || {
|
|
unsafe {
|
|
std::env::set_var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS", "1");
|
|
}
|
|
if std::env::var("DOCKER_HOST").is_err() {
|
|
if let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR") {
|
|
let podman_sock = std::path::Path::new(&runtime_dir).join("podman/podman.sock");
|
|
if podman_sock.exists() {
|
|
unsafe {
|
|
std::env::set_var(
|
|
"DOCKER_HOST",
|
|
format!("unix://{}", podman_sock.display()),
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
let rt = tokio::runtime::Runtime::new().unwrap();
|
|
rt.block_on(async move {
|
|
if has_external_infra() {
|
|
let url = setup_with_external_infra().await;
|
|
tx.send(url).unwrap();
|
|
} else {
|
|
let url = setup_with_testcontainers().await;
|
|
tx.send(url).unwrap();
|
|
}
|
|
std::future::pending::<()>().await;
|
|
});
|
|
});
|
|
rx.recv().expect("Failed to start test server")
|
|
})
|
|
}
|
|
|
|
async fn setup_with_external_infra() -> String {
|
|
let database_url =
|
|
std::env::var("DATABASE_URL").expect("DATABASE_URL must be set when using external infra");
|
|
let s3_endpoint =
|
|
std::env::var("S3_ENDPOINT").expect("S3_ENDPOINT must be set when using external infra");
|
|
unsafe {
|
|
std::env::set_var(
|
|
"S3_BUCKET",
|
|
std::env::var("S3_BUCKET").unwrap_or_else(|_| "test-bucket".to_string()),
|
|
);
|
|
std::env::set_var(
|
|
"AWS_ACCESS_KEY_ID",
|
|
std::env::var("AWS_ACCESS_KEY_ID").unwrap_or_else(|_| "minioadmin".to_string()),
|
|
);
|
|
std::env::set_var(
|
|
"AWS_SECRET_ACCESS_KEY",
|
|
std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_else(|_| "minioadmin".to_string()),
|
|
);
|
|
std::env::set_var(
|
|
"AWS_REGION",
|
|
std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string()),
|
|
);
|
|
std::env::set_var("S3_ENDPOINT", &s3_endpoint);
|
|
}
|
|
let mock_server = MockServer::start().await;
|
|
setup_mock_appview(&mock_server).await;
|
|
let mock_uri = mock_server.uri();
|
|
let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
|
|
let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
|
|
setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
|
|
MOCK_APPVIEW.set(mock_server).ok();
|
|
spawn_app(database_url).await
|
|
}
|
|
|
|
#[cfg(not(feature = "external-infra"))]
|
|
async fn setup_with_testcontainers() -> String {
|
|
let s3_container = GenericImage::new("minio/minio", "latest")
|
|
.with_exposed_port(ContainerPort::Tcp(9000))
|
|
.with_env_var("MINIO_ROOT_USER", "minioadmin")
|
|
.with_env_var("MINIO_ROOT_PASSWORD", "minioadmin")
|
|
.with_cmd(vec!["server".to_string(), "/data".to_string()])
|
|
.with_label("tranquil_pds_test", "true")
|
|
.start()
|
|
.await
|
|
.expect("Failed to start MinIO");
|
|
let s3_port = s3_container
|
|
.get_host_port_ipv4(9000)
|
|
.await
|
|
.expect("Failed to get S3 port");
|
|
let s3_endpoint = format!("http://127.0.0.1:{}", s3_port);
|
|
unsafe {
|
|
std::env::set_var("S3_BUCKET", "test-bucket");
|
|
std::env::set_var("AWS_ACCESS_KEY_ID", "minioadmin");
|
|
std::env::set_var("AWS_SECRET_ACCESS_KEY", "minioadmin");
|
|
std::env::set_var("AWS_REGION", "us-east-1");
|
|
std::env::set_var("S3_ENDPOINT", &s3_endpoint);
|
|
}
|
|
let sdk_config = aws_config::defaults(BehaviorVersion::latest())
|
|
.region("us-east-1")
|
|
.endpoint_url(&s3_endpoint)
|
|
.credentials_provider(Credentials::new(
|
|
"minioadmin",
|
|
"minioadmin",
|
|
None,
|
|
None,
|
|
"test",
|
|
))
|
|
.load()
|
|
.await;
|
|
let s3_config = aws_sdk_s3::config::Builder::from(&sdk_config)
|
|
.force_path_style(true)
|
|
.build();
|
|
let s3_client = S3Client::from_conf(s3_config);
|
|
let _ = s3_client.create_bucket().bucket("test-bucket").send().await;
|
|
let mock_server = MockServer::start().await;
|
|
setup_mock_appview(&mock_server).await;
|
|
let mock_uri = mock_server.uri();
|
|
let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
|
|
let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
|
|
setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
|
|
MOCK_APPVIEW.set(mock_server).ok();
|
|
S3_CONTAINER.set(s3_container).ok();
|
|
let container = Postgres::default()
|
|
.with_tag("18-alpine")
|
|
.with_label("tranquil_pds_test", "true")
|
|
.start()
|
|
.await
|
|
.expect("Failed to start Postgres");
|
|
let connection_string = format!(
|
|
"postgres://postgres:postgres@127.0.0.1:{}",
|
|
container
|
|
.get_host_port_ipv4(5432)
|
|
.await
|
|
.expect("Failed to get port")
|
|
);
|
|
DB_CONTAINER.set(container).ok();
|
|
spawn_app(connection_string).await
|
|
}
|
|
|
|
/// Stand-in used when the `external-infra` feature compiles testcontainers out.
///
/// Reaching it means `has_external_infra()` returned false, i.e. the external database/S3
/// environment was not provided.
///
/// # Panics
/// Always.
#[cfg(feature = "external-infra")]
async fn setup_with_testcontainers() -> String {
    panic!(
        "Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT."
    )
}
|
|
|
|
async fn setup_mock_did_document(mock_server: &MockServer, did: &str, service_endpoint: &str) {
|
|
Mock::given(method("GET"))
|
|
.and(path("/.well-known/did.json"))
|
|
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
|
"id": did,
|
|
"service": [{
|
|
"id": "#atproto_appview",
|
|
"type": "AtprotoAppView",
|
|
"serviceEndpoint": service_endpoint
|
|
}]
|
|
})))
|
|
.mount(mock_server)
|
|
.await;
|
|
}
|
|
|
|
/// Hook for registering AppView endpoint mocks on the mock server; currently mounts nothing.
async fn setup_mock_appview(_mock_server: &MockServer) {}
|
|
|
|
async fn spawn_app(database_url: String) -> String {
|
|
use tranquil_pds::rate_limit::RateLimiters;
|
|
let pool = PgPoolOptions::new()
|
|
.max_connections(50)
|
|
.connect(&database_url)
|
|
.await
|
|
.expect("Failed to connect to Postgres. Make sure the database is running.");
|
|
sqlx::migrate!("./migrations")
|
|
.run(&pool)
|
|
.await
|
|
.expect("Failed to run migrations");
|
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
let addr = listener.local_addr().unwrap();
|
|
APP_PORT.set(addr.port()).ok();
|
|
unsafe {
|
|
std::env::set_var("PDS_HOSTNAME", addr.to_string());
|
|
}
|
|
let rate_limiters = RateLimiters::new()
|
|
.with_login_limit(10000)
|
|
.with_account_creation_limit(10000)
|
|
.with_password_reset_limit(10000)
|
|
.with_email_update_limit(10000)
|
|
.with_oauth_authorize_limit(10000)
|
|
.with_oauth_token_limit(10000);
|
|
let state = AppState::new(pool).await.with_rate_limiters(rate_limiters);
|
|
tranquil_pds::sync::listener::start_sequencer_listener(state.clone()).await;
|
|
let app = tranquil_pds::app(state);
|
|
tokio::spawn(async move {
|
|
axum::serve(listener, app).await.unwrap();
|
|
});
|
|
format!("http://{}", addr)
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub async fn get_db_connection_string() -> String {
|
|
base_url().await;
|
|
if has_external_infra() {
|
|
std::env::var("DATABASE_URL").expect("DATABASE_URL not set")
|
|
} else {
|
|
#[cfg(not(feature = "external-infra"))]
|
|
{
|
|
let container = DB_CONTAINER.get().expect("DB container not initialized");
|
|
let port = container
|
|
.get_host_port_ipv4(5432)
|
|
.await
|
|
.expect("Failed to get port");
|
|
format!("postgres://postgres:postgres@127.0.0.1:{}/postgres", port)
|
|
}
|
|
#[cfg(feature = "external-infra")]
|
|
{
|
|
panic!("DATABASE_URL must be set with external-infra feature");
|
|
}
|
|
}
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub async fn verify_new_account(client: &Client, did: &str) -> String {
|
|
let conn_str = get_db_connection_string().await;
|
|
let pool = sqlx::postgres::PgPoolOptions::new()
|
|
.max_connections(2)
|
|
.connect(&conn_str)
|
|
.await
|
|
.expect("Failed to connect to test database");
|
|
let body_text: String = sqlx::query_scalar!(
|
|
"SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
|
|
did
|
|
)
|
|
.fetch_one(&pool)
|
|
.await
|
|
.expect("Failed to get verification code");
|
|
|
|
let lines: Vec<&str> = body_text.lines().collect();
|
|
let verification_code = lines
|
|
.iter()
|
|
.enumerate()
|
|
.find(|(_, line)| {
|
|
line.contains("verification code is:") || line.contains("code is:")
|
|
})
|
|
.and_then(|(i, _)| lines.get(i + 1).map(|s| s.trim().to_string()))
|
|
.or_else(|| {
|
|
body_text
|
|
.split_whitespace()
|
|
.find(|word| {
|
|
word.contains('-') && word.chars().filter(|c| *c == '-').count() >= 3
|
|
})
|
|
.map(|s| s.to_string())
|
|
})
|
|
.unwrap_or_else(|| body_text.clone());
|
|
|
|
let confirm_payload = json!({
|
|
"did": did,
|
|
"verificationCode": verification_code
|
|
});
|
|
let confirm_res = client
|
|
.post(format!(
|
|
"{}/xrpc/com.atproto.server.confirmSignup",
|
|
base_url().await
|
|
))
|
|
.json(&confirm_payload)
|
|
.send()
|
|
.await
|
|
.expect("confirmSignup request failed");
|
|
assert_eq!(confirm_res.status(), StatusCode::OK, "confirmSignup failed");
|
|
let confirm_body: Value = confirm_res
|
|
.json()
|
|
.await
|
|
.expect("Invalid JSON from confirmSignup");
|
|
confirm_body["accessJwt"]
|
|
.as_str()
|
|
.expect("No accessJwt in confirmSignup response")
|
|
.to_string()
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'static str) -> Value {
|
|
let res = client
|
|
.post(format!(
|
|
"{}/xrpc/com.atproto.repo.uploadBlob",
|
|
base_url().await
|
|
))
|
|
.header(header::CONTENT_TYPE, mime)
|
|
.bearer_auth(AUTH_TOKEN)
|
|
.body(data)
|
|
.send()
|
|
.await
|
|
.expect("Failed to send uploadBlob request");
|
|
assert_eq!(res.status(), StatusCode::OK, "Failed to upload blob");
|
|
let body: Value = res.json().await.expect("Blob upload response was not JSON");
|
|
body["blob"].clone()
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub async fn create_test_post(
|
|
client: &Client,
|
|
text: &str,
|
|
reply_to: Option<Value>,
|
|
) -> (String, String, String) {
|
|
let collection = "app.bsky.feed.post";
|
|
let mut record = json!({
|
|
"$type": collection,
|
|
"text": text,
|
|
"createdAt": Utc::now().to_rfc3339()
|
|
});
|
|
if let Some(reply_obj) = reply_to {
|
|
record["reply"] = reply_obj;
|
|
}
|
|
let payload = json!({
|
|
"repo": AUTH_DID,
|
|
"collection": collection,
|
|
"record": record
|
|
});
|
|
let res = client
|
|
.post(format!(
|
|
"{}/xrpc/com.atproto.repo.createRecord",
|
|
base_url().await
|
|
))
|
|
.bearer_auth(AUTH_TOKEN)
|
|
.json(&payload)
|
|
.send()
|
|
.await
|
|
.expect("Failed to send createRecord");
|
|
assert_eq!(res.status(), StatusCode::OK, "Failed to create post record");
|
|
let body: Value = res
|
|
.json()
|
|
.await
|
|
.expect("createRecord response was not JSON");
|
|
let uri = body["uri"]
|
|
.as_str()
|
|
.expect("Response had no URI")
|
|
.to_string();
|
|
let cid = body["cid"]
|
|
.as_str()
|
|
.expect("Response had no CID")
|
|
.to_string();
|
|
let rkey = uri
|
|
.split('/')
|
|
.last()
|
|
.expect("URI was malformed")
|
|
.to_string();
|
|
(uri, cid, rkey)
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub async fn create_account_and_login(client: &Client) -> (String, String) {
|
|
create_account_and_login_internal(client, false).await
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub async fn create_admin_account_and_login(client: &Client) -> (String, String) {
|
|
create_account_and_login_internal(client, true).await
|
|
}
|
|
|
|
async fn create_account_and_login_internal(client: &Client, make_admin: bool) -> (String, String) {
|
|
let mut last_error = String::new();
|
|
for attempt in 0..3 {
|
|
if attempt > 0 {
|
|
tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
|
|
}
|
|
let handle = format!("user_{}", uuid::Uuid::new_v4());
|
|
let payload = json!({
|
|
"handle": handle,
|
|
"email": format!("{}@example.com", handle),
|
|
"password": "Testpass123!"
|
|
});
|
|
let res = match client
|
|
.post(format!(
|
|
"{}/xrpc/com.atproto.server.createAccount",
|
|
base_url().await
|
|
))
|
|
.json(&payload)
|
|
.send()
|
|
.await
|
|
{
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
last_error = format!("Request failed: {}", e);
|
|
continue;
|
|
}
|
|
};
|
|
if res.status() == StatusCode::OK {
|
|
let body: Value = res.json().await.expect("Invalid JSON");
|
|
let did = body["did"].as_str().expect("No did").to_string();
|
|
let conn_str = get_db_connection_string().await;
|
|
let pool = sqlx::postgres::PgPoolOptions::new()
|
|
.max_connections(2)
|
|
.connect(&conn_str)
|
|
.await
|
|
.expect("Failed to connect to test database");
|
|
if make_admin {
|
|
sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did)
|
|
.execute(&pool)
|
|
.await
|
|
.expect("Failed to mark user as admin");
|
|
}
|
|
if let Some(access_jwt) = body["accessJwt"].as_str() {
|
|
return (access_jwt.to_string(), did);
|
|
}
|
|
let body_text: String = sqlx::query_scalar!(
|
|
"SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
|
|
&did
|
|
)
|
|
.fetch_one(&pool)
|
|
.await
|
|
.expect("Failed to get verification from comms_queue");
|
|
let lines: Vec<&str> = body_text.lines().collect();
|
|
let verification_code = lines
|
|
.iter()
|
|
.enumerate()
|
|
.find(|(_, line)| {
|
|
line.contains("verification code is:") || line.contains("code is:")
|
|
})
|
|
.and_then(|(i, _)| lines.get(i + 1).map(|s| s.trim().to_string()))
|
|
.or_else(|| {
|
|
body_text
|
|
.split_whitespace()
|
|
.find(|word| {
|
|
word.contains('-') && word.chars().filter(|c| *c == '-').count() >= 3
|
|
})
|
|
.map(|s| s.to_string())
|
|
})
|
|
.unwrap_or_else(|| body_text.clone());
|
|
|
|
let confirm_payload = json!({
|
|
"did": did,
|
|
"verificationCode": verification_code
|
|
});
|
|
let confirm_res = client
|
|
.post(format!(
|
|
"{}/xrpc/com.atproto.server.confirmSignup",
|
|
base_url().await
|
|
))
|
|
.json(&confirm_payload)
|
|
.send()
|
|
.await
|
|
.expect("confirmSignup request failed");
|
|
if confirm_res.status() == StatusCode::OK {
|
|
let confirm_body: Value = confirm_res
|
|
.json()
|
|
.await
|
|
.expect("Invalid JSON from confirmSignup");
|
|
let access_jwt = confirm_body["accessJwt"]
|
|
.as_str()
|
|
.expect("No accessJwt in confirmSignup response")
|
|
.to_string();
|
|
return (access_jwt, did);
|
|
}
|
|
last_error = format!("confirmSignup failed: {:?}", confirm_res.text().await);
|
|
continue;
|
|
}
|
|
last_error = format!("Status {}: {:?}", res.status(), res.text().await);
|
|
}
|
|
panic!("Failed to create account after 3 attempts: {}", last_error);
|
|
}
|