From e3efcb43ba1484cb2d6e6fd119a742d0bf260b22 Mon Sep 17 00:00:00 2001 From: Lewis Date: Sat, 21 Mar 2026 13:17:48 +0200 Subject: [PATCH] test(signal): add protocol store integration tests --- crates/tranquil-signal/src/tests.rs | 466 ++++++++++++++++++++++++++++ 1 file changed, 466 insertions(+) create mode 100644 crates/tranquil-signal/src/tests.rs diff --git a/crates/tranquil-signal/src/tests.rs b/crates/tranquil-signal/src/tests.rs new file mode 100644 index 0000000..9df74f5 --- /dev/null +++ b/crates/tranquil-signal/src/tests.rs @@ -0,0 +1,466 @@ +use presage::libsignal_service::{ + pre_keys::{KyberPreKeyStoreExt, PreKeysStore}, + prelude::{ProfileKey, SessionStoreExt}, + protocol::{ + DeviceId, Direction, GenericSignedPreKey, IdentityKeyPair, IdentityKeyStore, KeyPair, + KyberPreKeyId, KyberPreKeyRecord, KyberPreKeyStore, PreKeyId, PreKeyRecord, PreKeyStore, + ProtocolAddress, SenderKeyStore, ServiceId, SessionRecord, SessionStore, SignedPreKeyId, + SignedPreKeyRecord, SignedPreKeyStore, Timestamp, + }, +}; +use presage::store::{ContentsStore, StateStore, Store}; +use sqlx::postgres::PgPoolOptions; +use uuid::Uuid; + +use crate::store::{IdentityType, PgProtocolStore, PgSignalStore}; + +async fn test_store() -> PgSignalStore { + let url = std::env::var("DATABASE_URL") + .unwrap_or_else(|_| "postgres://postgres:postgres@127.0.0.1:5432/postgres".into()); + + let pool = PgPoolOptions::new() + .max_connections(5) + .connect(&url) + .await + .unwrap(); + + sqlx::query("DELETE FROM signal_kv") + .execute(&pool) + .await + .ok(); + sqlx::query("DELETE FROM signal_base_keys_seen") + .execute(&pool) + .await + .ok(); + sqlx::query("DELETE FROM signal_sender_keys") + .execute(&pool) + .await + .ok(); + sqlx::query("DELETE FROM signal_sessions") + .execute(&pool) + .await + .ok(); + sqlx::query("DELETE FROM signal_identities") + .execute(&pool) + .await + .ok(); + sqlx::query("DELETE FROM signal_kyber_pre_keys") + .execute(&pool) + .await + .ok(); + sqlx::query("DELETE FROM signal_signed_pre_keys") + .execute(&pool) + .await + .ok(); + sqlx::query("DELETE FROM signal_pre_keys") + .execute(&pool) + .await + .ok(); + sqlx::query("DELETE FROM signal_profile_keys") + .execute(&pool) + .await + .ok(); + + PgSignalStore::new(pool) +} + +fn protocol_store(store: &PgSignalStore, identity: IdentityType) -> PgProtocolStore { + PgProtocolStore::new(store.clone(), identity) +} + +#[tokio::test] +async fn state_store_registration_empty() { + let store = test_store().await; + + assert!(store.load_registration_data().await.unwrap().is_none()); + assert!(!store.is_registered().await); +} + +#[tokio::test] +async fn state_store_kv_roundtrip() { + let store = test_store().await; + + let value = b"test-data".to_vec(); + sqlx::query("INSERT INTO signal_kv (key, value) VALUES ('test_key', $1)") + .bind(&value) + .execute(&store.db) + .await + .unwrap(); + + let loaded: Vec = sqlx::query_scalar("SELECT value FROM signal_kv WHERE key = 'test_key'") + .fetch_one(&store.db) + .await + .unwrap(); + assert_eq!(loaded, value); +} + +#[tokio::test] +async fn state_store_identity_keypairs() { + let store = test_store().await; + + let aci_pair = IdentityKeyPair::generate(&mut rand::rng()); + let pni_pair = IdentityKeyPair::generate(&mut rand::rng()); + + store.set_aci_identity_key_pair(aci_pair).await.unwrap(); + store.set_pni_identity_key_pair(pni_pair).await.unwrap(); + + let aci_store = protocol_store(&store, IdentityType::Aci); + let pni_store = protocol_store(&store, IdentityType::Pni); + + let loaded_aci = aci_store.get_identity_key_pair().await.unwrap(); + let loaded_pni = pni_store.get_identity_key_pair().await.unwrap(); + + assert_eq!(loaded_aci.serialize(), aci_pair.serialize()); + assert_eq!(loaded_pni.serialize(), pni_pair.serialize()); +} + +#[tokio::test] +async fn state_store_sender_certificate_roundtrip() { + let store = test_store().await; + assert!(store.sender_certificate().await.unwrap().is_none()); +} + +#[tokio::test] +async fn state_store_clear_registration() { + let mut store = test_store().await; + + sqlx::query("INSERT INTO signal_kv (key, value) VALUES ('registration', $1)") + .bind(b"dummy-data".as_slice()) + .execute(&store.db) + .await + .unwrap(); + + let mut ps = protocol_store(&store, IdentityType::Aci); + let keypair = KeyPair::generate(&mut rand::rng()); + let record = PreKeyRecord::new(PreKeyId::from(1u32), &keypair); + ps.save_pre_key(PreKeyId::from(1u32), &record) + .await + .unwrap(); + + store.clear_registration().await.unwrap(); + + let remaining: Option> = + sqlx::query_scalar("SELECT value FROM signal_kv WHERE key = 'registration'") + .fetch_optional(&store.db) + .await + .unwrap(); + assert!(remaining.is_none()); + + assert!(ps.get_pre_key(PreKeyId::from(1u32)).await.is_err()); +} + +#[tokio::test] +async fn session_store_crud() { + let store = test_store().await; + let mut ps = protocol_store(&store, IdentityType::Aci); + + let addr = ProtocolAddress::new("test-uuid".into(), DeviceId::new(1).unwrap()); + assert!(ps.load_session(&addr).await.unwrap().is_none()); + + let record = SessionRecord::new_fresh(); + ps.store_session(&addr, &record).await.unwrap(); + + let loaded = ps.load_session(&addr).await.unwrap(); + assert!(loaded.is_some()); + + ps.store_session(&addr, &record).await.unwrap(); + let loaded2 = ps.load_session(&addr).await.unwrap(); + assert!(loaded2.is_some()); +} + +#[tokio::test] +async fn session_store_sub_devices() { + let store = test_store().await; + let mut ps = protocol_store(&store, IdentityType::Aci); + + let uuid = Uuid::new_v4(); + let service_id: ServiceId = presage::libsignal_service::protocol::Aci::from(uuid).into(); + let addr1 = ProtocolAddress::new(uuid.to_string(), DeviceId::new(1).unwrap()); + let addr2 = ProtocolAddress::new(uuid.to_string(), DeviceId::new(2).unwrap()); + let addr3 = ProtocolAddress::new(uuid.to_string(), DeviceId::new(3).unwrap()); + + let record = SessionRecord::new_fresh(); + ps.store_session(&addr1, &record).await.unwrap(); + ps.store_session(&addr2, &record).await.unwrap(); + ps.store_session(&addr3, &record).await.unwrap(); + + let sub_devices = ps.get_sub_device_sessions(&service_id).await.unwrap(); + assert_eq!(sub_devices.len(), 2); + + let deleted = ps.delete_all_sessions(&service_id).await.unwrap(); + assert_eq!(deleted, 3); + + let sub_devices = ps.get_sub_device_sessions(&service_id).await.unwrap(); + assert!(sub_devices.is_empty()); +} + +#[tokio::test] +async fn pre_key_store_crud() { + let store = test_store().await; + let mut ps = protocol_store(&store, IdentityType::Aci); + + let keypair = KeyPair::generate(&mut rand::rng()); + let id = PreKeyId::from(42u32); + let record = PreKeyRecord::new(id, &keypair); + + ps.save_pre_key(id, &record).await.unwrap(); + let loaded = ps.get_pre_key(id).await.unwrap(); + assert_eq!(loaded.serialize().unwrap(), record.serialize().unwrap()); + + ps.remove_pre_key(id).await.unwrap(); + assert!(ps.get_pre_key(id).await.is_err()); +} + +#[tokio::test] +async fn pre_key_store_next_ids() { + let store = test_store().await; + let mut ps = protocol_store(&store, IdentityType::Aci); + + assert_eq!(ps.next_pre_key_id().await.unwrap(), 1); + + let keypair = KeyPair::generate(&mut rand::rng()); + let record = PreKeyRecord::new(PreKeyId::from(5u32), &keypair); + ps.save_pre_key(PreKeyId::from(5u32), &record) + .await + .unwrap(); + + assert_eq!(ps.next_pre_key_id().await.unwrap(), 6); +} + +#[tokio::test] +async fn signed_pre_key_store_crud() { + let store = test_store().await; + let mut ps = protocol_store(&store, IdentityType::Aci); + + let keypair = KeyPair::generate(&mut rand::rng()); + let id = SignedPreKeyId::from(1u32); + let signature = keypair + .private_key + .calculate_signature(&keypair.public_key.serialize(), &mut rand::rng()) + .unwrap(); + let record = + SignedPreKeyRecord::new(id, Timestamp::from_epoch_millis(1000), &keypair, &signature); + + ps.save_signed_pre_key(id, &record).await.unwrap(); + let loaded = ps.get_signed_pre_key(id).await.unwrap(); + assert_eq!(loaded.serialize().unwrap(), record.serialize().unwrap()); + + assert_eq!(ps.signed_pre_keys_count().await.unwrap(), 1); + assert_eq!(ps.next_signed_pre_key_id().await.unwrap(), 2); +} + +#[tokio::test] +async fn kyber_pre_key_one_time_mark_used_deletes() { + let store = test_store().await; + let mut ps = protocol_store(&store, IdentityType::Aci); + + let keypair = KeyPair::generate(&mut rand::rng()); + let id = KyberPreKeyId::from(1u32); + let record = KyberPreKeyRecord::generate( + presage::libsignal_service::protocol::kem::KeyType::Kyber1024, + id, + &keypair.private_key, + ) + .unwrap(); + + ps.save_kyber_pre_key(id, &record).await.unwrap(); + assert!(ps.get_kyber_pre_key(id).await.is_ok()); + + let ec_prekey_id = SignedPreKeyId::from(1u32); + ps.mark_kyber_pre_key_used(id, ec_prekey_id, &keypair.public_key) + .await + .unwrap(); + + assert!(ps.get_kyber_pre_key(id).await.is_err()); +} + +#[tokio::test] +async fn kyber_pre_key_last_resort_survives_mark_used() { + let store = test_store().await; + let mut ps = protocol_store(&store, IdentityType::Aci); + + let keypair = KeyPair::generate(&mut rand::rng()); + let id = KyberPreKeyId::from(1u32); + let record = KyberPreKeyRecord::generate( + presage::libsignal_service::protocol::kem::KeyType::Kyber1024, + id, + &keypair.private_key, + ) + .unwrap(); + + ps.store_last_resort_kyber_pre_key(id, &record) + .await + .unwrap(); + assert!(ps.get_kyber_pre_key(id).await.is_ok()); + + let ec_prekey_id = SignedPreKeyId::from(1u32); + ps.mark_kyber_pre_key_used(id, ec_prekey_id, &keypair.public_key) + .await + .unwrap(); + + assert!(ps.get_kyber_pre_key(id).await.is_ok()); +} + +#[tokio::test] +async fn kyber_pre_key_last_resort_rejects_replayed_base_key() { + let store = test_store().await; + let mut ps = protocol_store(&store, IdentityType::Aci); + + let keypair = KeyPair::generate(&mut rand::rng()); + let id = KyberPreKeyId::from(1u32); + let record = KyberPreKeyRecord::generate( + presage::libsignal_service::protocol::kem::KeyType::Kyber1024, + id, + &keypair.private_key, + ) + .unwrap(); + + ps.store_last_resort_kyber_pre_key(id, &record) + .await + .unwrap(); + + let ec_prekey_id = SignedPreKeyId::from(1u32); + ps.mark_kyber_pre_key_used(id, ec_prekey_id, &keypair.public_key) + .await + .unwrap(); + + let replay_result = ps + .mark_kyber_pre_key_used(id, ec_prekey_id, &keypair.public_key) + .await; + assert!(replay_result.is_err()); +} + +#[tokio::test] +async fn kyber_pre_key_last_resort_list() { + let store = test_store().await; + let mut ps = protocol_store(&store, IdentityType::Aci); + + let keypair = KeyPair::generate(&mut rand::rng()); + let id = KyberPreKeyId::from(1u32); + let record = KyberPreKeyRecord::generate( + presage::libsignal_service::protocol::kem::KeyType::Kyber1024, + id, + &keypair.private_key, + ) + .unwrap(); + + assert!( + ps.load_last_resort_kyber_pre_keys() + .await + .unwrap() + .is_empty() + ); + + ps.store_last_resort_kyber_pre_key(id, &record) + .await + .unwrap(); + + let last_resorts = ps.load_last_resort_kyber_pre_keys().await.unwrap(); + assert_eq!(last_resorts.len(), 1); +} + +#[tokio::test] +async fn identity_store_crud() { + let store = test_store().await; + let mut ps = protocol_store(&store, IdentityType::Aci); + + let addr = ProtocolAddress::new("test-addr".into(), DeviceId::new(1).unwrap()); + let keypair = IdentityKeyPair::generate(&mut rand::rng()); + let identity_key = keypair.identity_key(); + + assert!(ps.get_identity(&addr).await.unwrap().is_none()); + + ps.save_identity(&addr, identity_key).await.unwrap(); + let loaded = ps.get_identity(&addr).await.unwrap().unwrap(); + assert_eq!(loaded.serialize(), identity_key.serialize()); + + assert!( + ps.is_trusted_identity(&addr, identity_key, Direction::Receiving) + .await + .unwrap() + ); +} + +#[tokio::test] +async fn identity_store_aci_pni_isolation() { + let store = test_store().await; + let mut aci_store = protocol_store(&store, IdentityType::Aci); + let pni_store = protocol_store(&store, IdentityType::Pni); + + let addr = ProtocolAddress::new("same-addr".into(), DeviceId::new(1).unwrap()); + let keypair = IdentityKeyPair::generate(&mut rand::rng()); + + aci_store + .save_identity(&addr, keypair.identity_key()) + .await + .unwrap(); + + assert!(aci_store.get_identity(&addr).await.unwrap().is_some()); + assert!(pni_store.get_identity(&addr).await.unwrap().is_none()); +} + +#[tokio::test] +async fn sender_key_store_load_missing() { + let store = test_store().await; + let mut ps = protocol_store(&store, IdentityType::Aci); + + let sender = ProtocolAddress::new("sender-uuid".into(), DeviceId::new(1).unwrap()); + let dist_id = Uuid::new_v4(); + + assert!( + ps.load_sender_key(&sender, dist_id) + .await + .unwrap() + .is_none() + ); +} + +#[tokio::test] +async fn profile_key_store_roundtrip() { + let mut store = test_store().await; + + let uuid = Uuid::new_v4(); + let service_id: ServiceId = presage::libsignal_service::protocol::Aci::from(uuid).into(); + let key = ProfileKey { bytes: [42u8; 32] }; + + assert!(store.profile_key(&service_id).await.unwrap().is_none()); + + store.upsert_profile_key(&uuid, key).await.unwrap(); + + let loaded = store.profile_key(&service_id).await.unwrap().unwrap(); + assert_eq!(loaded.bytes, key.bytes); +} + +#[tokio::test] +async fn client_from_pool_returns_none_without_registration() { + let store = test_store().await; + let pool = store.db.clone(); + + let client = + crate::SignalClient::from_pool(&pool, tokio_util::sync::CancellationToken::new()).await; + assert!(client.is_none()); +} + +#[tokio::test] +async fn store_clear_removes_kv() { + let mut store = test_store().await; + + store + .set_aci_identity_key_pair(IdentityKeyPair::generate(&mut rand::rng())) + .await + .unwrap(); + + sqlx::query("INSERT INTO signal_kv (key, value) VALUES ('registration', $1)") + .bind(b"dummy".as_slice()) + .execute(&store.db) + .await + .unwrap(); + + store.clear().await.unwrap(); + + let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM signal_kv") + .fetch_one(&store.db) + .await + .unwrap(); + assert_eq!(count, 0); +}