From b66e4fe291b27989cc2f48d056488e666fa8d2d8 Mon Sep 17 00:00:00 2001 From: lewis Date: Fri, 12 Dec 2025 19:20:39 +0200 Subject: [PATCH] Remaining endpoints for MVP --- .env.example | 17 +- ...b9884e874735e76b50c42933a94d9fa70425e.json | 34 + ...415c8f3218755a4cd1730d5f800057b9b369f.json | 61 ++ ...e2c1ab276e486b4c2ebaaac299fa77414ef09.json | 22 + ...f6c7d854dc48b8c0d3674f3cbdfd60b61c9d1.json | 30 - ...1ce88b9be04d6c7b4744891352a9a3bed4f3c.json | 30 - ...b7b0c241c734a5b7726019c2a59ae277caee6.json | 3 +- ...519ff631f160c97453e9e8aef04756cbc424e.json | 31 - ...582d20fdfa58a8e4c17c1628acc6b6f2ded15.json | 31 - ...2c1f84c5db138fdef9a8e8756771c30b66810.json | 76 ++ ...dc7b5fbd46d8b83c753319cba264ecf6d7df6.json | 22 + ...e748fa648c97f8109255120e969c957ff95bf.json | 3 +- ...5e395f71596a6d6a88d2be64ce86256a9860f.json | 46 ++ ...28771d527d4a683f4a273248e9cd91fdfd7a4.json | 14 + ...11ea0d6766d6a8840ae2b0589c42c0de79fc5.json | 14 + ...77c98667913225305df559559e36110516cfb.json | 40 + ...0d27c1afd00181ce495a5903a47fbfe243708.json | 58 ++ ...12b8d4cfb0b7555430eb45f16fe550fac4b43.json | 40 - ...5d17c3074c0eea7c2bf5b1f2ea07ba61eeb3b.json | 23 + ...6bc0f5e42d3511f9daa00a0bb677ce48072ac.json | 12 + ...3adb44d1196296b7f93fad19b2d17548ed3de.json | 3 +- Cargo.lock | 207 ++++- Cargo.toml | 2 + README.md | 130 ++-- TODO.md | 86 ++- migrations/202512211700_add_2fa.sql | 16 + src/api/identity/account.rs | 67 +- src/api/identity/plc/sign.rs | 29 +- src/api/identity/plc/submit.rs | 45 +- src/api/mod.rs | 1 + src/api/repo/record/read.rs | 92 +-- src/api/repo/record/utils.rs | 28 + src/api/server/meta.rs | 8 + src/api/server/mod.rs | 2 +- src/api/server/password.rs | 32 +- src/api/server/session.rs | 31 + src/api/temp.rs | 48 ++ src/auth/token.rs | 98 +++ src/auth/verify.rs | 106 +++ src/circuit_breaker.rs | 307 ++++++++ src/crawlers.rs | 170 +++++ src/image/mod.rs | 304 ++++++++ src/lib.rs | 22 + src/main.rs | 43 +- src/notifications/mod.rs | 11 +- src/notifications/sender.rs | 297 +++++++- src/notifications/service.rs | 36 + src/notifications/types.rs | 1 + src/oauth/client.rs | 274 ++++++- src/oauth/db/device.rs | 62 ++ src/oauth/db/mod.rs | 9 +- src/oauth/db/two_factor.rs | 153 ++++ src/oauth/endpoints/authorize.rs | 713 ++++++++++++++++- src/oauth/endpoints/token/grants.rs | 2 +- src/oauth/endpoints/token/mod.rs | 24 + src/oauth/mod.rs | 2 + src/oauth/templates.rs | 719 ++++++++++++++++++ src/plc/mod.rs | 158 ++++ src/rate_limit.rs | 216 ++++++ src/state.rs | 18 + src/sync/crawl.rs | 4 - src/sync/deprecated.rs | 209 +++++ src/sync/mod.rs | 5 +- src/sync/relay_client.rs | 83 -- src/validation/mod.rs | 504 ++++++++++++ tests/image_processing.rs | 315 ++++++++ tests/import_with_verification.rs | 5 + tests/list_records_pagination.rs | 554 ++++++++++++++ tests/oauth.rs | 633 +++++++++++++++ tests/oauth_security.rs | 1 + tests/plc_migration.rs | 2 + tests/plc_validation.rs | 513 +++++++++++++ tests/record_validation.rs | 590 ++++++++++++++ tests/relay_client.rs | 86 --- tests/security_fixes.rs | 377 +++++++++ tests/sync_deprecated.rs | 307 ++++++++ 76 files changed, 8803 insertions(+), 564 deletions(-) create mode 100644 .sqlx/query-0cbeeffaf2cf782de4e9d886e26b9884e874735e76b50c42933a94d9fa70425e.json create mode 100644 .sqlx/query-0d59bd89c410dfceaa7eabcd028415c8f3218755a4cd1730d5f800057b9b369f.json create mode 100644 .sqlx/query-180e9287d4e7f0d2b074518aa3ae2c1ab276e486b4c2ebaaac299fa77414ef09.json delete mode 100644 .sqlx/query-243ff4911a7d36354e60005af13f6c7d854dc48b8c0d3674f3cbdfd60b61c9d1.json delete mode 100644 
.sqlx/query-2d37a447ec4dbdb6dfd5ab8c12d1ce88b9be04d6c7b4744891352a9a3bed4f3c.json delete mode 100644 .sqlx/query-347e3570a201ee339904aab43f0519ff631f160c97453e9e8aef04756cbc424e.json delete mode 100644 .sqlx/query-4343e751b03e24387f6603908a1582d20fdfa58a8e4c17c1628acc6b6f2ded15.json create mode 100644 .sqlx/query-458c98edc9c01286dc2677fcff82c1f84c5db138fdef9a8e8756771c30b66810.json create mode 100644 .sqlx/query-4d52c04129df85efcab747dfd38dc7b5fbd46d8b83c753319cba264ecf6d7df6.json create mode 100644 .sqlx/query-62f66fad54498d5c598af54de795e395f71596a6d6a88d2be64ce86256a9860f.json create mode 100644 .sqlx/query-6be6cd9d43002aa9e2bd337e26228771d527d4a683f4a273248e9cd91fdfd7a4.json create mode 100644 .sqlx/query-7d1246fb9125ebdfa6d4896942411ea0d6766d6a8840ae2b0589c42c0de79fc5.json create mode 100644 .sqlx/query-841452a9e325ea5f4ae3bff00cd77c98667913225305df559559e36110516cfb.json create mode 100644 .sqlx/query-881b03f9f19aba8f65feaab97570d27c1afd00181ce495a5903a47fbfe243708.json delete mode 100644 .sqlx/query-91ab872f41891370baf9d405e8812b8d4cfb0b7555430eb45f16fe550fac4b43.json create mode 100644 .sqlx/query-a32e91d22d66deba7b9bfae2c965d17c3074c0eea7c2bf5b1f2ea07ba61eeb3b.json create mode 100644 .sqlx/query-bc31134e6927444993555d2f2d56bc0f5e42d3511f9daa00a0bb677ce48072ac.json create mode 100644 migrations/202512211700_add_2fa.sql create mode 100644 src/api/temp.rs create mode 100644 src/circuit_breaker.rs create mode 100644 src/crawlers.rs create mode 100644 src/image/mod.rs create mode 100644 src/oauth/db/two_factor.rs create mode 100644 src/oauth/templates.rs create mode 100644 src/rate_limit.rs create mode 100644 src/sync/deprecated.rs delete mode 100644 src/sync/relay_client.rs create mode 100644 src/validation/mod.rs create mode 100644 tests/image_processing.rs create mode 100644 tests/list_records_pagination.rs create mode 100644 tests/plc_validation.rs create mode 100644 tests/record_validation.rs delete mode 100644 tests/relay_client.rs create mode 100644 tests/security_fixes.rs create mode 100644 tests/sync_deprecated.rs diff --git a/.env.example b/.env.example index 942de84..44bcedd 100644 --- a/.env.example +++ b/.env.example @@ -13,26 +13,27 @@ AWS_SECRET_ACCESS_KEY=minioadmin PDS_HOSTNAME=localhost:3000 PLC_URL=plc.directory -# A comma-separated list of WebSocket URLs for firehose relays to push updates to. -# e.g., RELAYS=wss://relay.bsky.social,wss://another-relay.com -RELAYS= +# A comma-separated list of relay URLs to notify via requestCrawl when we have updates. +# e.g., CRAWLERS=https://bsky.network +CRAWLERS= # Notification Service Configuration # At least one notification channel should be configured for user notifications to work. + # Email notifications (via sendmail/msmtp) # MAIL_FROM_ADDRESS=noreply@example.com # MAIL_FROM_NAME=My PDS # SENDMAIL_PATH=/usr/sbin/sendmail -# Discord notifications (not yet implemented) -# DISCORD_BOT_TOKEN=your-bot-token +# Discord notifications (via webhook) +# DISCORD_WEBHOOK_URL=https://discord.com/api/webhooks/... 
-# Telegram notifications (not yet implemented) +# Telegram notifications (via bot) # TELEGRAM_BOT_TOKEN=your-bot-token -# Signal notifications (not yet implemented) +# Signal notifications (via signal-cli) # SIGNAL_CLI_PATH=/usr/local/bin/signal-cli -# SIGNAL_PHONE_NUMBER=+1234567890 +# SIGNAL_SENDER_NUMBER=+1234567890 CARGO_MOMMYS_LITTLE=mister CARGO_MOMMYS_PRONOUNS=his diff --git a/.sqlx/query-0cbeeffaf2cf782de4e9d886e26b9884e874735e76b50c42933a94d9fa70425e.json b/.sqlx/query-0cbeeffaf2cf782de4e9d886e26b9884e874735e76b50c42933a94d9fa70425e.json new file mode 100644 index 0000000..74512ad --- /dev/null +++ b/.sqlx/query-0cbeeffaf2cf782de4e9d886e26b9884e874735e76b50c42933a94d9fa70425e.json @@ -0,0 +1,34 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT preferred_notification_channel as \"channel: NotificationChannel\" FROM users WHERE did = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "channel: NotificationChannel", + "type_info": { + "Custom": { + "name": "notification_channel", + "kind": { + "Enum": [ + "email", + "discord", + "telegram", + "signal" + ] + } + } + } + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false + ] + }, + "hash": "0cbeeffaf2cf782de4e9d886e26b9884e874735e76b50c42933a94d9fa70425e" +} diff --git a/.sqlx/query-0d59bd89c410dfceaa7eabcd028415c8f3218755a4cd1730d5f800057b9b369f.json b/.sqlx/query-0d59bd89c410dfceaa7eabcd028415c8f3218755a4cd1730d5f800057b9b369f.json new file mode 100644 index 0000000..873befe --- /dev/null +++ b/.sqlx/query-0d59bd89c410dfceaa7eabcd028415c8f3218755a4cd1730d5f800057b9b369f.json @@ -0,0 +1,61 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO oauth_2fa_challenge (did, request_uri, code, expires_at)\n VALUES ($1, $2, $3, $4)\n RETURNING id, did, request_uri, code, attempts, created_at, expires_at\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "did", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "request_uri", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "code", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "attempts", + "type_info": "Int4" + }, + { + "ordinal": 5, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "expires_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Text", + "Text", + "Text", + "Timestamptz" + ] + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + false + ] + }, + "hash": "0d59bd89c410dfceaa7eabcd028415c8f3218755a4cd1730d5f800057b9b369f" +} diff --git a/.sqlx/query-180e9287d4e7f0d2b074518aa3ae2c1ab276e486b4c2ebaaac299fa77414ef09.json b/.sqlx/query-180e9287d4e7f0d2b074518aa3ae2c1ab276e486b4c2ebaaac299fa77414ef09.json new file mode 100644 index 0000000..5ddc8f3 --- /dev/null +++ b/.sqlx/query-180e9287d4e7f0d2b074518aa3ae2c1ab276e486b4c2ebaaac299fa77414ef09.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT two_factor_enabled\n FROM users\n WHERE did = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "two_factor_enabled", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false + ] + }, + "hash": "180e9287d4e7f0d2b074518aa3ae2c1ab276e486b4c2ebaaac299fa77414ef09" +} diff --git a/.sqlx/query-243ff4911a7d36354e60005af13f6c7d854dc48b8c0d3674f3cbdfd60b61c9d1.json b/.sqlx/query-243ff4911a7d36354e60005af13f6c7d854dc48b8c0d3674f3cbdfd60b61c9d1.json deleted file mode 100644 index 
2765d3e..0000000 --- a/.sqlx/query-243ff4911a7d36354e60005af13f6c7d854dc48b8c0d3674f3cbdfd60b61c9d1.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey ASC LIMIT $3", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "rkey", - "type_info": "Text" - }, - { - "ordinal": 1, - "name": "record_cid", - "type_info": "Text" - } - ], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Int8" - ] - }, - "nullable": [ - false, - false - ] - }, - "hash": "243ff4911a7d36354e60005af13f6c7d854dc48b8c0d3674f3cbdfd60b61c9d1" -} diff --git a/.sqlx/query-2d37a447ec4dbdb6dfd5ab8c12d1ce88b9be04d6c7b4744891352a9a3bed4f3c.json b/.sqlx/query-2d37a447ec4dbdb6dfd5ab8c12d1ce88b9be04d6c7b4744891352a9a3bed4f3c.json deleted file mode 100644 index aa58028..0000000 --- a/.sqlx/query-2d37a447ec4dbdb6dfd5ab8c12d1ce88b9be04d6c7b4744891352a9a3bed4f3c.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey DESC LIMIT $3", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "rkey", - "type_info": "Text" - }, - { - "ordinal": 1, - "name": "record_cid", - "type_info": "Text" - } - ], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Int8" - ] - }, - "nullable": [ - false, - false - ] - }, - "hash": "2d37a447ec4dbdb6dfd5ab8c12d1ce88b9be04d6c7b4744891352a9a3bed4f3c" -} diff --git a/.sqlx/query-303777d97e6ed344f8c699eae37b7b0c241c734a5b7726019c2a59ae277caee6.json b/.sqlx/query-303777d97e6ed344f8c699eae37b7b0c241c734a5b7726019c2a59ae277caee6.json index 4c7cf54..866b7e7 100644 --- a/.sqlx/query-303777d97e6ed344f8c699eae37b7b0c241c734a5b7726019c2a59ae277caee6.json +++ b/.sqlx/query-303777d97e6ed344f8c699eae37b7b0c241c734a5b7726019c2a59ae277caee6.json @@ -36,7 +36,8 @@ "email_update", "account_deletion", "admin_email", - "plc_operation" + "plc_operation", + "two_factor_code" ] } } diff --git a/.sqlx/query-347e3570a201ee339904aab43f0519ff631f160c97453e9e8aef04756cbc424e.json b/.sqlx/query-347e3570a201ee339904aab43f0519ff631f160c97453e9e8aef04756cbc424e.json deleted file mode 100644 index cb26727..0000000 --- a/.sqlx/query-347e3570a201ee339904aab43f0519ff631f160c97453e9e8aef04756cbc424e.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey > $3 ORDER BY rkey ASC LIMIT $4", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "rkey", - "type_info": "Text" - }, - { - "ordinal": 1, - "name": "record_cid", - "type_info": "Text" - } - ], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Text", - "Int8" - ] - }, - "nullable": [ - false, - false - ] - }, - "hash": "347e3570a201ee339904aab43f0519ff631f160c97453e9e8aef04756cbc424e" -} diff --git a/.sqlx/query-4343e751b03e24387f6603908a1582d20fdfa58a8e4c17c1628acc6b6f2ded15.json b/.sqlx/query-4343e751b03e24387f6603908a1582d20fdfa58a8e4c17c1628acc6b6f2ded15.json deleted file mode 100644 index aeab538..0000000 --- a/.sqlx/query-4343e751b03e24387f6603908a1582d20fdfa58a8e4c17c1628acc6b6f2ded15.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey < $3 ORDER BY rkey DESC LIMIT $4", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "rkey", - "type_info": "Text" - }, - { - "ordinal": 1, - "name": "record_cid", - 
"type_info": "Text" - } - ], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Text", - "Int8" - ] - }, - "nullable": [ - false, - false - ] - }, - "hash": "4343e751b03e24387f6603908a1582d20fdfa58a8e4c17c1628acc6b6f2ded15" -} diff --git a/.sqlx/query-458c98edc9c01286dc2677fcff82c1f84c5db138fdef9a8e8756771c30b66810.json b/.sqlx/query-458c98edc9c01286dc2677fcff82c1f84c5db138fdef9a8e8756771c30b66810.json new file mode 100644 index 0000000..7601810 --- /dev/null +++ b/.sqlx/query-458c98edc9c01286dc2677fcff82c1f84c5db138fdef9a8e8756771c30b66810.json @@ -0,0 +1,76 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT id, did, email, password_hash, two_factor_enabled,\n preferred_notification_channel as \"preferred_notification_channel: NotificationChannel\",\n deactivated_at, takedown_ref\n FROM users\n WHERE handle = $1 OR email = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "did", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "email", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "password_hash", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "two_factor_enabled", + "type_info": "Bool" + }, + { + "ordinal": 5, + "name": "preferred_notification_channel: NotificationChannel", + "type_info": { + "Custom": { + "name": "notification_channel", + "kind": { + "Enum": [ + "email", + "discord", + "telegram", + "signal" + ] + } + } + } + }, + { + "ordinal": 6, + "name": "deactivated_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 7, + "name": "takedown_ref", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + true, + true + ] + }, + "hash": "458c98edc9c01286dc2677fcff82c1f84c5db138fdef9a8e8756771c30b66810" +} diff --git a/.sqlx/query-4d52c04129df85efcab747dfd38dc7b5fbd46d8b83c753319cba264ecf6d7df6.json b/.sqlx/query-4d52c04129df85efcab747dfd38dc7b5fbd46d8b83c753319cba264ecf6d7df6.json new file mode 100644 index 0000000..027f158 --- /dev/null +++ b/.sqlx/query-4d52c04129df85efcab747dfd38dc7b5fbd46d8b83c753319cba264ecf6d7df6.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE oauth_2fa_challenge\n SET attempts = attempts + 1\n WHERE id = $1\n RETURNING attempts\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "attempts", + "type_info": "Int4" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false + ] + }, + "hash": "4d52c04129df85efcab747dfd38dc7b5fbd46d8b83c753319cba264ecf6d7df6" +} diff --git a/.sqlx/query-5d49bbf0307a0c642b0174d641de748fa648c97f8109255120e969c957ff95bf.json b/.sqlx/query-5d49bbf0307a0c642b0174d641de748fa648c97f8109255120e969c957ff95bf.json index 8e9d0e4..f8bf0e5 100644 --- a/.sqlx/query-5d49bbf0307a0c642b0174d641de748fa648c97f8109255120e969c957ff95bf.json +++ b/.sqlx/query-5d49bbf0307a0c642b0174d641de748fa648c97f8109255120e969c957ff95bf.json @@ -36,7 +36,8 @@ "email_update", "account_deletion", "admin_email", - "plc_operation" + "plc_operation", + "two_factor_code" ] } } diff --git a/.sqlx/query-62f66fad54498d5c598af54de795e395f71596a6d6a88d2be64ce86256a9860f.json b/.sqlx/query-62f66fad54498d5c598af54de795e395f71596a6d6a88d2be64ce86256a9860f.json new file mode 100644 index 0000000..4b9b1db --- /dev/null +++ b/.sqlx/query-62f66fad54498d5c598af54de795e395f71596a6d6a88d2be64ce86256a9860f.json @@ -0,0 +1,46 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT id, two_factor_enabled,\n 
preferred_notification_channel as \"preferred_notification_channel: NotificationChannel\"\n FROM users\n WHERE did = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "two_factor_enabled", + "type_info": "Bool" + }, + { + "ordinal": 2, + "name": "preferred_notification_channel: NotificationChannel", + "type_info": { + "Custom": { + "name": "notification_channel", + "kind": { + "Enum": [ + "email", + "discord", + "telegram", + "signal" + ] + } + } + } + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false + ] + }, + "hash": "62f66fad54498d5c598af54de795e395f71596a6d6a88d2be64ce86256a9860f" +} diff --git a/.sqlx/query-6be6cd9d43002aa9e2bd337e26228771d527d4a683f4a273248e9cd91fdfd7a4.json b/.sqlx/query-6be6cd9d43002aa9e2bd337e26228771d527d4a683f4a273248e9cd91fdfd7a4.json new file mode 100644 index 0000000..cfbc788 --- /dev/null +++ b/.sqlx/query-6be6cd9d43002aa9e2bd337e26228771d527d4a683f4a273248e9cd91fdfd7a4.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "\n DELETE FROM oauth_2fa_challenge WHERE request_uri = $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [] + }, + "hash": "6be6cd9d43002aa9e2bd337e26228771d527d4a683f4a273248e9cd91fdfd7a4" +} diff --git a/.sqlx/query-7d1246fb9125ebdfa6d4896942411ea0d6766d6a8840ae2b0589c42c0de79fc5.json b/.sqlx/query-7d1246fb9125ebdfa6d4896942411ea0d6766d6a8840ae2b0589c42c0de79fc5.json new file mode 100644 index 0000000..37c529e --- /dev/null +++ b/.sqlx/query-7d1246fb9125ebdfa6d4896942411ea0d6766d6a8840ae2b0589c42c0de79fc5.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "\n DELETE FROM oauth_2fa_challenge WHERE id = $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "7d1246fb9125ebdfa6d4896942411ea0d6766d6a8840ae2b0589c42c0de79fc5" +} diff --git a/.sqlx/query-841452a9e325ea5f4ae3bff00cd77c98667913225305df559559e36110516cfb.json b/.sqlx/query-841452a9e325ea5f4ae3bff00cd77c98667913225305df559559e36110516cfb.json new file mode 100644 index 0000000..04c9297 --- /dev/null +++ b/.sqlx/query-841452a9e325ea5f4ae3bff00cd77c98667913225305df559559e36110516cfb.json @@ -0,0 +1,40 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT u.did, u.handle, u.email, ad.updated_at as last_used_at\n FROM oauth_account_device ad\n JOIN users u ON u.did = ad.did\n WHERE ad.device_id = $1\n AND u.deactivated_at IS NULL\n AND u.takedown_ref IS NULL\n ORDER BY ad.updated_at DESC\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "did", + "type_info": "Text" + }, + { + "ordinal": 1, + "name": "handle", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "email", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "last_used_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "841452a9e325ea5f4ae3bff00cd77c98667913225305df559559e36110516cfb" +} diff --git a/.sqlx/query-881b03f9f19aba8f65feaab97570d27c1afd00181ce495a5903a47fbfe243708.json b/.sqlx/query-881b03f9f19aba8f65feaab97570d27c1afd00181ce495a5903a47fbfe243708.json new file mode 100644 index 0000000..1874789 --- /dev/null +++ b/.sqlx/query-881b03f9f19aba8f65feaab97570d27c1afd00181ce495a5903a47fbfe243708.json @@ -0,0 +1,58 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT id, did, request_uri, code, attempts, created_at, 
expires_at\n FROM oauth_2fa_challenge\n WHERE request_uri = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "did", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "request_uri", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "code", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "attempts", + "type_info": "Int4" + }, + { + "ordinal": 5, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "expires_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + false + ] + }, + "hash": "881b03f9f19aba8f65feaab97570d27c1afd00181ce495a5903a47fbfe243708" +} diff --git a/.sqlx/query-91ab872f41891370baf9d405e8812b8d4cfb0b7555430eb45f16fe550fac4b43.json b/.sqlx/query-91ab872f41891370baf9d405e8812b8d4cfb0b7555430eb45f16fe550fac4b43.json deleted file mode 100644 index 7a20bbb..0000000 --- a/.sqlx/query-91ab872f41891370baf9d405e8812b8d4cfb0b7555430eb45f16fe550fac4b43.json +++ /dev/null @@ -1,40 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT did, password_hash, deactivated_at, takedown_ref\n FROM users\n WHERE handle = $1 OR email = $1\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "did", - "type_info": "Text" - }, - { - "ordinal": 1, - "name": "password_hash", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "deactivated_at", - "type_info": "Timestamptz" - }, - { - "ordinal": 3, - "name": "takedown_ref", - "type_info": "Text" - } - ], - "parameters": { - "Left": [ - "Text" - ] - }, - "nullable": [ - false, - false, - true, - true - ] - }, - "hash": "91ab872f41891370baf9d405e8812b8d4cfb0b7555430eb45f16fe550fac4b43" -} diff --git a/.sqlx/query-a32e91d22d66deba7b9bfae2c965d17c3074c0eea7c2bf5b1f2ea07ba61eeb3b.json b/.sqlx/query-a32e91d22d66deba7b9bfae2c965d17c3074c0eea7c2bf5b1f2ea07ba61eeb3b.json new file mode 100644 index 0000000..6e152d7 --- /dev/null +++ b/.sqlx/query-a32e91d22d66deba7b9bfae2c965d17c3074c0eea7c2bf5b1f2ea07ba61eeb3b.json @@ -0,0 +1,23 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT 1 as exists\n FROM oauth_account_device ad\n JOIN users u ON u.did = ad.did\n WHERE ad.device_id = $1\n AND ad.did = $2\n AND u.deactivated_at IS NULL\n AND u.takedown_ref IS NULL\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "exists", + "type_info": "Int4" + } + ], + "parameters": { + "Left": [ + "Text", + "Text" + ] + }, + "nullable": [ + null + ] + }, + "hash": "a32e91d22d66deba7b9bfae2c965d17c3074c0eea7c2bf5b1f2ea07ba61eeb3b" +} diff --git a/.sqlx/query-bc31134e6927444993555d2f2d56bc0f5e42d3511f9daa00a0bb677ce48072ac.json b/.sqlx/query-bc31134e6927444993555d2f2d56bc0f5e42d3511f9daa00a0bb677ce48072ac.json new file mode 100644 index 0000000..289e029 --- /dev/null +++ b/.sqlx/query-bc31134e6927444993555d2f2d56bc0f5e42d3511f9daa00a0bb677ce48072ac.json @@ -0,0 +1,12 @@ +{ + "db_name": "PostgreSQL", + "query": "\n DELETE FROM oauth_2fa_challenge WHERE expires_at < NOW()\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + }, + "hash": "bc31134e6927444993555d2f2d56bc0f5e42d3511f9daa00a0bb677ce48072ac" +} diff --git a/.sqlx/query-cb6f48aaba124c79308d20e66c23adb44d1196296b7f93fad19b2d17548ed3de.json b/.sqlx/query-cb6f48aaba124c79308d20e66c23adb44d1196296b7f93fad19b2d17548ed3de.json index 5724dfb..c004720 100644 --- 
a/.sqlx/query-cb6f48aaba124c79308d20e66c23adb44d1196296b7f93fad19b2d17548ed3de.json +++ b/.sqlx/query-cb6f48aaba124c79308d20e66c23adb44d1196296b7f93fad19b2d17548ed3de.json @@ -44,7 +44,8 @@ "email_update", "account_deletion", "admin_email", - "plc_operation" + "plc_operation", + "two_factor_code" ] } } diff --git a/Cargo.lock b/Cargo.lock index eb3e3fd..6806171 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -915,8 +915,10 @@ dependencies = [ "dotenvy", "ed25519-dalek", "futures", + "governor", "hkdf", "hmac", + "image", "ipld-core", "iroh-car", "jacquard", @@ -985,12 +987,24 @@ version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e" +[[package]] +name = "bytemuck" +version = "1.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" + [[package]] name = "byteorder" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "byteorder-lite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" + [[package]] name = "bytes" version = "1.11.0" @@ -1157,6 +1171,12 @@ dependencies = [ "cc", ] +[[package]] +name = "color_quant" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" + [[package]] name = "compression-codecs" version = "0.4.33" @@ -1819,6 +1839,15 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "fdeflate" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c" +dependencies = [ + "simd-adler32", +] + [[package]] name = "ferroid" version = "0.8.7" @@ -1907,6 +1936,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "foreign-types" version = "0.3.2" @@ -2055,6 +2090,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -2135,6 +2176,16 @@ dependencies = [ "polyval", ] +[[package]] +name = "gif" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5df2ba84018d80c213569363bdcd0c64e6933c67fe4c1d60ecf822971a3c35e" +dependencies = [ + "color_quant", + "weezl", +] + [[package]] name = "glob" version = "0.3.3" @@ -2169,6 +2220,29 @@ dependencies = [ "web-sys", ] +[[package]] +name = "governor" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "6e23d5986fd4364c2fb7498523540618b4b8d92eec6c36a02e565f66748e2f79" +dependencies = [ + "cfg-if", + "dashmap 6.1.0", + "futures-sink", + "futures-timer", + "futures-util", + "getrandom 0.3.4", + "hashbrown 0.16.1", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand 0.9.2", + "smallvec", + "spinning_top", + "web-time", +] + [[package]] name = "group" version = "0.12.1" @@ -2260,7 +2334,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ "allocator-api2", "equivalent", - "foldhash", + "foldhash 0.1.5", ] [[package]] @@ -2268,6 +2342,11 @@ name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash 0.2.0", +] [[package]] name = "hashlink" @@ -2759,6 +2838,34 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "image" +version = "0.25.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6506c6c10786659413faa717ceebcb8f70731c0a60cbae39795fdf114519c1a" +dependencies = [ + "bytemuck", + "byteorder-lite", + "color_quant", + "gif", + "image-webp", + "moxcms", + "num-traits", + "png", + "zune-core", + "zune-jpeg", +] + +[[package]] +name = "image-webp" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3" +dependencies = [ + "byteorder-lite", + "quick-error", +] + [[package]] name = "indexmap" version = "1.9.3" @@ -3476,6 +3583,16 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "moxcms" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80986bbbcf925ebd3be54c26613d861255284584501595cf418320c078945608" +dependencies = [ + "num-traits", + "pxfm", +] + [[package]] name = "multibase" version = "0.9.2" @@ -3553,6 +3670,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -3975,6 +4098,19 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "png" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97baced388464909d42d89643fe4361939af9b7ce7a31ee32a168f832a70f2a0" +dependencies = [ + "bitflags", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide", +] + [[package]] name = "polyval" version = "0.6.2" @@ -4131,6 +4267,36 @@ dependencies = [ "unicase", ] +[[package]] +name = "pxfm" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7186d3822593aa4393561d186d1393b3923e9d6163d3fbfd6e825e3e6cf3e6a8" +dependencies = [ + "num-traits", +] + +[[package]] +name = "quanta" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + +[[package]] +name = "quick-error" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" + [[package]] name = "quinn" version = "0.11.9" @@ -4266,6 +4432,15 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d20581732dd76fa913c7dff1a2412b714afe3573e94d41c34719de73337cc8ab" +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -5033,6 +5208,15 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.6.0" @@ -6220,6 +6404,12 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "weezl" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88" + [[package]] name = "whoami" version = "1.6.1" @@ -6831,3 +7021,18 @@ dependencies = [ "quote", "syn 2.0.111", ] + +[[package]] +name = "zune-core" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "111f7d9820f05fd715df3144e254d6fc02ee4088b0644c0ffd0efc9e6d9d2773" + +[[package]] +name = "zune-jpeg" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f520eebad972262a1dde0ec455bce4f8b298b1e5154513de58c114c4c54303e8" +dependencies = [ + "zune-core", +] diff --git a/Cargo.toml b/Cargo.toml index ae9e1da..336722f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ chrono = { version = "0.4.42", features = ["serde"] } cid = "0.11.1" dotenvy = "0.15.7" futures = "0.3.30" +governor = "0.10" hkdf = "0.12" hmac = "0.12" aes-gcm = "0.10" @@ -47,6 +48,7 @@ tokio-tungstenite = { version = "0.28.0", features = ["native-tls"] } urlencoding = "2.1" uuid = { version = "1.19.0", features = ["v4", "fast-rng"] } iroh-car = "0.5.1" +image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } [features] external-infra = [] diff --git a/README.md b/README.md index 3ae31b3..547bc26 100644 --- a/README.md +++ b/README.md @@ -1,75 +1,103 @@ -# Lewis' BS PDS Sandbox +# BSPDS, a Personal Data Server -When I'm actually done then yeah let's make this into a proper official-looking repo perhaps under an official-looking account or something. +A production-grade Personal Data Server (PDS) implementation for the AT Protocol. -This project implements a Personal Data Server (PDS) implementation for the AT Protocol. +Uses PostgreSQL instead of SQLite, S3-compatible blob storage, and is designed to be a complete drop-in replacement for Bluesky's reference PDS implementation. -Uses PostgreSQL instead of SQLite, S3-compatible blob storage, and aims to be a complete drop-in replacement for Bluesky's reference PDS implementation. +## Features -In fact I aim to also implement a plugin system soon, so that we can add things onto our own PDSes on top of the default BS. +- Full AT Protocol support, all `com.atproto.*` endpoints implemented +- OAuth 2.1 Provider. 
PKCE, DPoP, and Pushed Authorization Requests
+- PostgreSQL as a production-ready database backend
+- S3-compatible object storage for blobs; works with AWS S3, UpCloud object storage, self-hosted MinIO, etc.
+- WebSocket `subscribeRepos` endpoint for real-time sync
+- Crawler notifications via `requestCrawl`
+- Multi-channel notifications: email, Discord, Telegram, Signal
+- Per-IP rate limiting on sensitive endpoints
-I'm also taking ideas on what other PDSes lack, such as an on-PDS webpage that users can access to manage their records and preferences.
+## Running Locally
-:3
+Requires Rust installed locally.
-# Running locally
+Run PostgreSQL and an S3-compatible object store (e.g., with podman/docker):
-The reader will need rust installed locally.
+```bash
+podman compose up db objsto -d
+```
-I personally run the postgres db, and an S3-compatible object store with podman compose up db objsto -d.
+Run the PDS:
-Run the PDS directly:
+```bash
+just run
+```
- just run
+## Configuration
-Configuration is via environment variables:
+### Required
- DATABASE_URL postgres connection string
- S3_BUCKET blob storage bucket name
- S3_ENDPOINT S3 endpoint URL (for MinIO etc)
- AWS_ACCESS_KEY_ID S3 credentials
- AWS_SECRET_ACCESS_KEY
- AWS_REGION
- PDS_HOSTNAME public hostname of this PDS
- APPVIEW_URL appview to proxy unimplemented endpoints to
- RELAYS comma-separated list of relay WebSocket URLs
+| Variable | Description |
+|----------|-------------|
+| `DATABASE_URL` | PostgreSQL connection string |
+| `S3_BUCKET` | Blob storage bucket name |
+| `S3_ENDPOINT` | S3 endpoint URL (for MinIO, etc.) |
+| `AWS_ACCESS_KEY_ID` | S3 credentials |
+| `AWS_SECRET_ACCESS_KEY` | S3 credentials |
+| `AWS_REGION` | S3 region |
+| `PDS_HOSTNAME` | Public hostname of this PDS |
+| `JWT_SECRET` | Secret for OAuth token signing (HS256) |
+| `KEY_ENCRYPTION_KEY` | Key for encrypting user signing keys (AES-256-GCM) |
-Optional email stuff:
+### Optional
- MAIL_FROM_ADDRESS sender address (enables email notifications)
- MAIL_FROM_NAME sender name (default: BSPDS)
- SENDMAIL_PATH path to sendmail binary
+| Variable | Description |
+|----------|-------------|
+| `APPVIEW_URL` | Appview URL to proxy unimplemented endpoints to |
+| `CRAWLERS` | Comma-separated list of relay URLs to notify via `requestCrawl` |
-Development
+### Notifications
- just shows available commands
- just test run tests (spins up postgres and minio via testcontainers)
- just lint clippy + fmt check
- just db-reset drop and recreate local database
+At least one channel should be configured for user notifications (password reset, email verification, etc.):
-The test suite uses testcontainers so you don't need to set up anything manually for running tests.
+| Variable | Description |
+|----------|-------------|
+| `MAIL_FROM_ADDRESS` | Email sender address (enables email via sendmail) |
+| `MAIL_FROM_NAME` | Email sender name (default: "BSPDS") |
+| `SENDMAIL_PATH` | Path to sendmail binary (default: /usr/sbin/sendmail) |
+| `DISCORD_WEBHOOK_URL` | Discord webhook URL for notifications |
+| `TELEGRAM_BOT_TOKEN` | Telegram bot token for notifications |
+| `SIGNAL_CLI_PATH` | Path to signal-cli binary |
+| `SIGNAL_SENDER_NUMBER` | Signal sender phone number (+1234567890 format) |
-## What's implemented
+## Development
-Most of the com.atproto.* namespace is done. Server endpoints, repo operations, sync, identity, admin, moderation. The firehose websocket works. OAuth is not done yet.
+```bash
+just            # Show available commands
+just test       # Run tests (auto-starts postgres/minio, runs nextest)
+just lint       # Clippy + fmt check
+just db-reset   # Drop and recreate local database
+```
-See TODO.md for the full breakdown of what's done and what's left.
+## Project Structure
-Structure
+```
+src/
+  main.rs             Server entrypoint
+  lib.rs              Router setup
+  state.rs            AppState (db pool, stores, rate limiters, circuit breakers)
+  api/                XRPC handlers organized by namespace
+  auth/               JWT authentication (ES256K per-user keys)
+  oauth/              OAuth 2.1 provider (HS256 server-wide)
+  repo/               PostgreSQL block store
+  storage/            S3 blob storage
+  sync/               Firehose, CAR export, crawler notifications
+  notifications/      Multi-channel notification service
+  plc/                PLC directory client
+  image/              Image resize/WebP pipeline for blob uploads
+  circuit_breaker.rs  Circuit breaker for external services
+  rate_limit.rs       Per-IP rate limiting
+tests/                Integration tests
+migrations/           SQLx migrations
+```
- src/
- main.rs server entrypoint
- lib.rs router setup
- state.rs app state (db pool, stores)
- api/ XRPC handlers organized by namespace
- auth/ JWT handling
- repo/ postgres block store
- storage/ S3 blob storage
- sync/ firehose, relay clients
- notifications/ email service
- tests/ integration tests
- migrations/ sqlx migrations
+## License
-License
-
-idk
+TBD

diff --git a/TODO.md b/TODO.md
index da4cad6..5fb306d 100644
--- a/TODO.md
+++ b/TODO.md
@@ -81,6 +81,9 @@ Lewis' corrected big boy todofile
 - [x] Implement `com.atproto.sync.listBlobs`.
 - [x] Crawler Interaction
   - [x] Implement `com.atproto.sync.requestCrawl` (Notify relays to index us).
+- [x] Deprecated Sync Endpoints (for compatibility)
+  - [x] Implement `com.atproto.sync.getCheckout` (deprecated).
+  - [x] Implement `com.atproto.sync.getHead` (deprecated).
 
 ## Identity (`com.atproto.identity`)
 - [x] Resolution
@@ -108,14 +111,17 @@
 - [x] Implement `com.atproto.moderation.createReport`.
 
 ## Temp Namespace (`com.atproto.temp`)
-- [ ] Implement `com.atproto.temp.checkSignupQueue` (signup queue status for gated signups).
+- [x] Implement `com.atproto.temp.checkSignupQueue` (signup queue status for gated signups).
+
+## Misc HTTP Endpoints
+- [x] Implement `/robots.txt` endpoint.
 
 ## OAuth 2.1 Support
 Full OAuth 2.1 provider for ATProto native app authentication.
 - [x] OAuth Provider Core
   - [x] Implement `/.well-known/oauth-protected-resource` metadata endpoint.
   - [x] Implement `/.well-known/oauth-authorization-server` metadata endpoint.
-  - [x] Implement `/oauth/authorize` authorization endpoint (headless JSON mode).
+  - [x] Implement `/oauth/authorize` authorization endpoint (with login UI).
   - [x] Implement `/oauth/par` Pushed Authorization Request endpoint.
   - [x] Implement `/oauth/token` token endpoint (authorization_code + refresh_token grants).
   - [x] Implement `/oauth/jwks` JSON Web Key Set endpoint.
@@ -132,12 +138,13 @@
 - [x] Client metadata fetching and validation.
 - [x] PKCE (S256) enforcement.
 - [x] OAuth token verification extractor for protected resources.
-- [ ] Authorization UI templates (currently headless-only, returns JSON for programmatic flows).
-- [ ] Implement `private_key_jwt` signature verification (currently rejects with clear error).
+- [x] Authorization UI templates (HTML login form).
+- [x] Implement `private_key_jwt` signature verification with async JWKS fetching.
+- [x] HS256 JWT support (matches reference PDS).
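For the HS256 items just above, the signature check itself is small: an HMAC-SHA256 over `header.payload` using the server-wide `JWT_SECRET`, compared in constant time. A minimal sketch using the `hmac` crate already in Cargo.toml (the function name and the `base64`/`sha2` wiring are illustrative, not the project's actual auth module):

```rust
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use hmac::{Hmac, Mac};
use sha2::Sha256;

type HmacSha256 = Hmac<Sha256>;

/// Recompute the HMAC over `header.payload` and compare it to the token's
/// signature. `verify_slice` performs a constant-time comparison internally,
/// which is what prevents the timing attack called out in the notes below.
fn verify_hs256_signature(token: &str, secret: &[u8]) -> bool {
    // Split "header.payload.signature" from the right, keeping the signing input intact.
    let mut parts = token.rsplitn(2, '.');
    let (Some(sig_b64), Some(signing_input)) = (parts.next(), parts.next()) else {
        return false;
    };
    let Ok(sig) = URL_SAFE_NO_PAD.decode(sig_b64) else {
        return false;
    };
    let Ok(mut mac) = HmacSha256::new_from_slice(secret) else {
        return false;
    };
    mac.update(signing_input.as_bytes());
    mac.verify_slice(&sig).is_ok()
}
```

Full verification would still parse the payload and check `exp`, `aud`, and the JTI against the session_tokens table; this shows only the signature step.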
## OAuth Security Notes -I've tried to ensure that this codebase is not vulnerable to the following: +Security measures implemented: - Constant-time comparison for signature verification (prevents timing attacks) - HMAC-SHA256 for access token signing with configurable secret @@ -151,12 +158,12 @@ I've tried to ensure that this codebase is not vulnerable to the following: - All database queries use parameterized statements (no SQL injection) - Deactivated/taken-down accounts blocked from OAuth authorization - Client ID validation on token exchange (defense-in-depth against cross-client attacks) +- HTML escaping in OAuth templates (XSS prevention) ### Auth Notes -- Algorithm choice: Using ES256K (secp256k1 ECDSA) with per-user keys. Ref PDS uses HS256 (HMAC) with single server key. Our approach provides better key isolation but differs from reference implementation. - - [ ] Support the ref PDS HS256 system too. -- Token storage: Now storing only token JTIs in session_tokens table (defense in depth against DB breaches). Refresh token family tracking enables detection of token reuse attacks. -- Key encryption: User signing keys encrypted at rest using AES-256-GCM with keys derived via HKDF from MASTER_KEY environment variable. Migration-safe: supports both encrypted (version 1) and plaintext (version 0) keys. +- Dual algorithm support: ES256K (secp256k1 ECDSA) with per-user keys AND HS256 (HMAC) for compatibility with reference PDS. +- Token storage: Storing only token JTIs in session_tokens table (defense in depth against DB breaches). Refresh token family tracking enables detection of token reuse attacks. +- Key encryption: User signing keys encrypted at rest using AES-256-GCM with keys derived via HKDF from KEY_ENCRYPTION_KEY environment variable. ## PDS-Level App Endpoints These endpoints need to be implemented at the PDS level (not just proxied to appview). @@ -178,22 +185,6 @@ These are implemented at PDS level to enable local-first reads (read-after-write ### Notification (`app.bsky.notification`) - [x] Implement `app.bsky.notification.registerPush` (push notification registration, proxied). -## Deprecated Sync Endpoints (for compatibility) -- [ ] Implement `com.atproto.sync.getCheckout` (deprecated, still needed for compatibility). -- [ ] Implement `com.atproto.sync.getHead` (deprecated, still needed for compatibility). - -## Misc HTTP Endpoints -- [ ] Implement `/robots.txt` endpoint. - -## Record Schema Validation -- [ ] Handle this generically. - -## Preference Storage -User preferences (for app.bsky.actor.getPreferences/putPreferences): -- [x] Create preferences table for storing user app preferences. -- [x] Implement `app.bsky.actor.getPreferences` handler (read from postgres, proxy fallback). -- [x] Implement `app.bsky.actor.putPreferences` handler (write to postgres). - ## Infrastructure & Core Components - [x] Sequencer (Event Log) - [x] Implement a `Sequencer` (backed by `repo_seq` table). @@ -206,32 +197,53 @@ User preferences (for app.bsky.actor.getPreferences/putPreferences): - [x] Manage Repo Root in `repos` table. - [x] Implement Atomic Repo Transactions. - [x] Ensure `blocks` write, `repo_root` update, `records` index update, and `sequencer` event are committed in a single transaction. - - [ ] Implement concurrency control (row-level locking on `repos` table) to prevent concurrent writes to the same repo. + - [x] Implement concurrency control (row-level locking via FOR UPDATE). - [ ] DID Cache - [ ] Implement caching layer for DID resolution (Redis or in-memory). 
- [ ] Handle cache invalidation/expiry. -- [ ] Background Jobs - - [ ] Implement `Crawlers` service (debounce notifications to relays). +- [x] Crawlers Service + - [x] Implement `Crawlers` service (debounce notifications to relays). + - [x] 20-minute notification debounce. + - [x] Circuit breaker for relay failures. - [x] Notification Service - [x] Queue-based notification system with database table - [x] Background worker polling for pending notifications - [x] Extensible sender trait for multiple channels - [x] Email sender via OS sendmail/msmtp - - [ ] Discord bot sender - - [ ] Telegram bot sender - - [ ] Signal bot sender + - [x] Discord webhook sender + - [x] Telegram bot sender + - [x] Signal CLI sender - [x] Helper functions for common notification types (welcome, password reset, email verification, etc.) - [x] Respect user's `preferred_notification_channel` setting for non-email-specific notifications -- [ ] Image Processing - - [ ] Implement image resize/formatting pipeline (for blob uploads). +- [x] Image Processing + - [x] Implement image resize/formatting pipeline (for blob uploads). + - [x] WebP conversion for thumbnails. + - [x] EXIF stripping. + - [x] File size limits (10MB default). - [x] IPLD & MST - [x] Implement Merkle Search Tree logic for repo signing. - [x] Implement CAR (Content Addressable Archive) encoding/decoding. -- [ ] Validation - - [ ] DID PLC Operations (Sign rotation keys). -- [ ] Fix any remaining TODOs in the code, everywhere, full stop. + - [x] Cycle detection in CAR export. +- [x] Rate Limiting + - [x] Per-IP rate limiting on login (10/min). + - [x] Per-IP rate limiting on OAuth token endpoint (30/min). + - [x] Per-IP rate limiting on password reset (5/hour). + - [x] Per-IP rate limiting on account creation (10/hour). +- [x] Circuit Breakers + - [x] PLC directory circuit breaker (5 failures → open, 60s timeout). + - [x] Relay notification circuit breaker (10 failures → open, 30s timeout). +- [x] Security Hardening + - [x] Email header injection prevention (CRLF sanitization). + - [x] Signal command injection prevention (phone number validation). + - [x] Constant-time signature comparison. + - [x] SSRF protection for outbound requests. -## Web Management UI +## Lewis' fabulous mini-list of remaining TODOs +- [ ] DID resolution caching (valkey). +- [ ] Record schema validation (generic validation framework). +- [ ] Fix any remaining TODOs in the code. + +## Future: Web Management UI A single-page web app for account management. The frontend (JS framework) calls existing ATProto XRPC endpoints - no server-side rendering or bespoke HTML form handlers. 
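Stepping back to the per-IP rate limiting items in the checklist above: they map directly onto the `governor` dependency added in Cargo.toml, with one keyed limiter per endpoint. A minimal sketch of how the `state.rate_limiters.account_creation.check_key(...)` call seen in account.rs can be wired up (struct and field names are inferred from the handler code, not necessarily the real src/rate_limit.rs):

```rust
use governor::{DefaultKeyedRateLimiter, Quota, RateLimiter};
use std::num::NonZeroU32;

/// Per-IP limiters carried on AppState. Field names follow the handlers in
/// this patch; quotas follow the checklist (login 10/min, signup 10/hour).
pub struct RateLimiters {
    pub login: DefaultKeyedRateLimiter<String>,
    pub account_creation: DefaultKeyedRateLimiter<String>,
}

impl RateLimiters {
    pub fn new() -> Self {
        Self {
            // 10 login attempts per minute per client IP.
            login: RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(10).unwrap())),
            // 10 account creations per hour per client IP.
            account_creation: RateLimiter::keyed(Quota::per_hour(NonZeroU32::new(10).unwrap())),
        }
    }
}
```

A handler then calls `check_key` with the caller's IP, as `create_account` does below; `Err` means the quota is exhausted and the endpoint answers 429.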
### Architecture

diff --git a/migrations/202512211700_add_2fa.sql b/migrations/202512211700_add_2fa.sql
new file mode 100644
index 0000000..8a6c332
--- /dev/null
+++ b/migrations/202512211700_add_2fa.sql
@@ -0,0 +1,16 @@
+ALTER TABLE users ADD COLUMN two_factor_enabled BOOLEAN NOT NULL DEFAULT FALSE;
+
+ALTER TYPE notification_type ADD VALUE 'two_factor_code';
+
+CREATE TABLE oauth_2fa_challenge (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    did TEXT NOT NULL REFERENCES users(did) ON DELETE CASCADE,
+    request_uri TEXT NOT NULL,
+    code TEXT NOT NULL,
+    attempts INTEGER NOT NULL DEFAULT 0,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '10 minutes'
+);
+
+CREATE INDEX idx_oauth_2fa_challenge_request_uri ON oauth_2fa_challenge(request_uri);
+CREATE INDEX idx_oauth_2fa_challenge_expires ON oauth_2fa_challenge(expires_at);

diff --git a/src/api/identity/account.rs b/src/api/identity/account.rs
index be166b2..0dbf59f 100644
--- a/src/api/identity/account.rs
+++ b/src/api/identity/account.rs
@@ -3,7 +3,7 @@ use crate::state::AppState;
 use axum::{
     Json,
     extract::State,
-    http::StatusCode,
+    http::{HeaderMap, StatusCode},
     response::{IntoResponse, Response},
 };
 use bcrypt::{DEFAULT_COST, hash};
@@ -16,6 +16,22 @@ use serde_json::json;
 use std::sync::Arc;
 use tracing::{error, info, warn};
 
+fn extract_client_ip(headers: &HeaderMap) -> String {
+    if let Some(forwarded) = headers.get("x-forwarded-for") {
+        if let Ok(value) = forwarded.to_str() {
+            if let Some(first_ip) = value.split(',').next() {
+                return first_ip.trim().to_string();
+            }
+        }
+    }
+    if let Some(real_ip) = headers.get("x-real-ip") {
+        if let Ok(value) = real_ip.to_str() {
+            return value.trim().to_string();
+        }
+    }
+    "unknown".to_string()
+}
+
 #[derive(Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct CreateAccountInput {
@@ -38,9 +54,24 @@ pub struct CreateAccountOutput {
 pub async fn create_account(
     State(state): State<Arc<AppState>>,
+    headers: HeaderMap,
     Json(input): Json<CreateAccountInput>,
 ) -> Response {
     info!("create_account called");
+
+    let client_ip = extract_client_ip(&headers);
+    if state.rate_limiters.account_creation.check_key(&client_ip).is_err() {
+        warn!(ip = %client_ip, "Account creation rate limit exceeded");
+        return (
+            StatusCode::TOO_MANY_REQUESTS,
+            Json(json!({
+                "error": "RateLimitExceeded",
+                "message": "Too many account creation attempts. Please try again later."
+ })), + ) + .into_response(); + } + if input.handle.contains('!') || input.handle.contains('@') { return ( StatusCode::BAD_REQUEST, @@ -184,8 +215,40 @@ pub async fn create_account( let user_id = match user_insert { Ok(row) => row.id, Err(e) => { + if let Some(db_err) = e.as_database_error() { + if db_err.code().as_deref() == Some("23505") { + let constraint = db_err.constraint().unwrap_or(""); + if constraint.contains("handle") || constraint.contains("users_handle") { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "HandleNotAvailable", + "message": "Handle already taken" + })), + ) + .into_response(); + } else if constraint.contains("email") || constraint.contains("users_email") { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "InvalidEmail", + "message": "Email already registered" + })), + ) + .into_response(); + } else if constraint.contains("did") || constraint.contains("users_did") { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "AccountAlreadyExists", + "message": "An account with this DID already exists" + })), + ) + .into_response(); + } + } + } error!("Error inserting user: {:?}", e); - // TODO: Check for unique constraint violation on email/did specifically return ( StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"})), diff --git a/src/api/identity/plc/sign.rs b/src/api/identity/plc/sign.rs index efb57aa..06d5cf1 100644 --- a/src/api/identity/plc/sign.rs +++ b/src/api/identity/plc/sign.rs @@ -1,6 +1,7 @@ use crate::api::ApiError; +use crate::circuit_breaker::{with_circuit_breaker, CircuitBreakerError}; use crate::plc::{ - create_update_op, sign_operation, PlcClient, PlcError, PlcService, + create_update_op, sign_operation, PlcClient, PlcError, PlcOpOrTombstone, PlcService, }; use crate::state::AppState; use axum::{ @@ -14,7 +15,7 @@ use k256::ecdsa::SigningKey; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::collections::HashMap; -use tracing::{error, info}; +use tracing::{error, info, warn}; #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] @@ -166,9 +167,27 @@ pub async fn sign_plc_operation( }; let plc_client = PlcClient::new(None); - let last_op = match plc_client.get_last_op(did).await { + let did_clone = did.clone(); + let result: Result> = with_circuit_breaker( + &state.circuit_breakers.plc_directory, + || async { plc_client.get_last_op(&did_clone).await }, + ) + .await; + + let last_op = match result { Ok(op) => op, - Err(PlcError::NotFound) => { + Err(CircuitBreakerError::CircuitOpen(e)) => { + warn!("PLC directory circuit breaker open: {}", e); + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ + "error": "ServiceUnavailable", + "message": "PLC directory service temporarily unavailable" + })), + ) + .into_response(); + } + Err(CircuitBreakerError::OperationFailed(PlcError::NotFound)) => { return ( StatusCode::NOT_FOUND, Json(json!({ @@ -178,7 +197,7 @@ pub async fn sign_plc_operation( ) .into_response(); } - Err(e) => { + Err(CircuitBreakerError::OperationFailed(e)) => { error!("Failed to fetch PLC operation: {:?}", e); return ( StatusCode::BAD_GATEWAY, diff --git a/src/api/identity/plc/submit.rs b/src/api/identity/plc/submit.rs index d91f4c5..6aedc47 100644 --- a/src/api/identity/plc/submit.rs +++ b/src/api/identity/plc/submit.rs @@ -1,5 +1,6 @@ use crate::api::ApiError; -use crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient}; +use crate::circuit_breaker::{with_circuit_breaker, CircuitBreakerError}; +use 
crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient, PlcError}; use crate::state::AppState; use axum::{ extract::State, @@ -183,16 +184,38 @@ pub async fn submit_plc_operation( } let plc_client = PlcClient::new(None); - if let Err(e) = plc_client.send_operation(did, &input.operation).await { - error!("Failed to submit PLC operation: {:?}", e); - return ( - StatusCode::BAD_GATEWAY, - Json(json!({ - "error": "UpstreamError", - "message": format!("Failed to submit to PLC directory: {}", e) - })), - ) - .into_response(); + let operation_clone = input.operation.clone(); + let did_clone = did.clone(); + let result: Result<(), CircuitBreakerError> = with_circuit_breaker( + &state.circuit_breakers.plc_directory, + || async { plc_client.send_operation(&did_clone, &operation_clone).await }, + ) + .await; + + match result { + Ok(()) => {} + Err(CircuitBreakerError::CircuitOpen(e)) => { + warn!("PLC directory circuit breaker open: {}", e); + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ + "error": "ServiceUnavailable", + "message": "PLC directory service temporarily unavailable" + })), + ) + .into_response(); + } + Err(CircuitBreakerError::OperationFailed(e)) => { + error!("Failed to submit PLC operation: {:?}", e); + return ( + StatusCode::BAD_GATEWAY, + Json(json!({ + "error": "UpstreamError", + "message": format!("Failed to submit to PLC directory: {}", e) + })), + ) + .into_response(); + } } if let Err(e) = sqlx::query!( diff --git a/src/api/mod.rs b/src/api/mod.rs index 61bba8e..018c563 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -10,6 +10,7 @@ pub mod proxy_client; pub mod read_after_write; pub mod repo; pub mod server; +pub mod temp; pub mod validation; pub use error::ApiError; diff --git a/src/api/repo/record/read.rs b/src/api/repo/record/read.rs index e1cd268..7f0701e 100644 --- a/src/api/repo/record/read.rs +++ b/src/api/repo/record/read.rs @@ -167,57 +167,57 @@ pub async fn list_records( let limit = input.limit.unwrap_or(50).clamp(1, 100); let reverse = input.reverse.unwrap_or(false); - - // Simplistic query construction - no sophisticated cursor handling or rkey ranges for now, just basic pagination - // TODO: Implement rkeyStart/End and correct cursor logic - let limit_i64 = limit as i64; - let rows_res = if let Some(cursor) = &input.cursor { - if reverse { - sqlx::query!( - "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey < $3 ORDER BY rkey DESC LIMIT $4", - user_id, - input.collection, - cursor, - limit_i64 - ) + let order = if reverse { "ASC" } else { "DESC" }; + + let rows_res: Result, sqlx::Error> = if let Some(cursor) = &input.cursor { + let comparator = if reverse { ">" } else { "<" }; + let query = format!( + "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey {} $3 ORDER BY rkey {} LIMIT $4", + comparator, order + ); + sqlx::query_as(&query) + .bind(user_id) + .bind(&input.collection) + .bind(cursor) + .bind(limit_i64) .fetch_all(&state.db) .await - .map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::>()) - } else { - sqlx::query!( - "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey > $3 ORDER BY rkey ASC LIMIT $4", - user_id, - input.collection, - cursor, - limit_i64 - ) - .fetch_all(&state.db) - .await - .map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::>()) - } } else { - if reverse { - sqlx::query!( - "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey 
DESC LIMIT $3", - user_id, - input.collection, - limit_i64 - ) - .fetch_all(&state.db) - .await - .map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::>()) - } else { - sqlx::query!( - "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey ASC LIMIT $3", - user_id, - input.collection, - limit_i64 - ) - .fetch_all(&state.db) - .await - .map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::>()) + let mut conditions = vec!["repo_id = $1", "collection = $2"]; + let mut param_idx = 3; + + if input.rkey_start.is_some() { + conditions.push("rkey > $3"); + param_idx += 1; } + + if input.rkey_end.is_some() { + conditions.push(if param_idx == 3 { "rkey < $3" } else { "rkey < $4" }); + param_idx += 1; + } + + let limit_idx = param_idx; + + let query = format!( + "SELECT rkey, record_cid FROM records WHERE {} ORDER BY rkey {} LIMIT ${}", + conditions.join(" AND "), + order, + limit_idx + ); + + let mut query_builder = sqlx::query_as::<_, (String, String)>(&query) + .bind(user_id) + .bind(&input.collection); + + if let Some(start) = &input.rkey_start { + query_builder = query_builder.bind(start); + } + if let Some(end) = &input.rkey_end { + query_builder = query_builder.bind(end); + } + + query_builder.bind(limit_i64).fetch_all(&state.db).await }; let rows = match rows_res { diff --git a/src/api/repo/record/utils.rs b/src/api/repo/record/utils.rs index 5dd8cb5..d74d7ff 100644 --- a/src/api/repo/record/utils.rs +++ b/src/api/repo/record/utils.rs @@ -58,6 +58,34 @@ pub async fn commit_and_log( let mut tx = state.db.begin().await .map_err(|e| format!("Failed to begin transaction: {}", e))?; + let lock_result = sqlx::query!( + "SELECT repo_root_cid FROM repos WHERE user_id = $1 FOR UPDATE NOWAIT", + user_id + ) + .fetch_optional(&mut *tx) + .await; + + match lock_result { + Err(e) => { + if let Some(db_err) = e.as_database_error() { + if db_err.code().as_deref() == Some("55P03") { + return Err("ConcurrentModification: Another request is modifying this repo".to_string()); + } + } + return Err(format!("Failed to acquire repo lock: {}", e)); + } + Ok(Some(row)) => { + if let Some(expected_root) = ¤t_root_cid { + if row.repo_root_cid != expected_root.to_string() { + return Err("ConcurrentModification: Repo has been modified since last read".to_string()); + } + } + } + Ok(None) => { + return Err("Repo not found".to_string()); + } + } + sqlx::query!("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", new_root_cid.to_string(), user_id) .execute(&mut *tx) .await diff --git a/src/api/server/meta.rs b/src/api/server/meta.rs index 9781dd2..6e52070 100644 --- a/src/api/server/meta.rs +++ b/src/api/server/meta.rs @@ -4,6 +4,14 @@ use serde_json::json; use tracing::error; +pub async fn robots_txt() -> impl IntoResponse { + ( + StatusCode::OK, + [("content-type", "text/plain")], + "# Hello!\n\n# Crawling the public API is allowed\nUser-agent: *\nAllow: /\n", + ) +} + pub async fn describe_server() -> impl IntoResponse { let domains_str = std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| "example.com".to_string()); diff --git a/src/api/server/mod.rs b/src/api/server/mod.rs index 64cdb1d..475010e 100644 --- a/src/api/server/mod.rs +++ b/src/api/server/mod.rs @@ -15,7 +15,7 @@ pub use account_status::{ pub use app_password::{create_app_password, list_app_passwords, revoke_app_password}; pub use email::{confirm_email, request_email_update, update_email}; pub use invite::{create_invite_code, create_invite_codes, get_account_invite_codes}; -pub 
diff --git a/src/api/server/mod.rs b/src/api/server/mod.rs
index 64cdb1d..475010e 100644
--- a/src/api/server/mod.rs
+++ b/src/api/server/mod.rs
@@ -15,7 +15,7 @@ pub use account_status::{
 pub use app_password::{create_app_password, list_app_passwords, revoke_app_password};
 pub use email::{confirm_email, request_email_update, update_email};
 pub use invite::{create_invite_code, create_invite_codes, get_account_invite_codes};
-pub use meta::{describe_server, health};
+pub use meta::{describe_server, health, robots_txt};
 pub use password::{request_password_reset, reset_password};
 pub use service_auth::get_service_auth;
 pub use session::{create_session, delete_session, get_session, refresh_session};
diff --git a/src/api/server/password.rs b/src/api/server/password.rs
index 52339f2..1ae645d 100644
--- a/src/api/server/password.rs
+++ b/src/api/server/password.rs
@@ -2,7 +2,7 @@ use crate::state::AppState;
 use axum::{
     Json,
     extract::State,
-    http::StatusCode,
+    http::{HeaderMap, StatusCode},
     response::{IntoResponse, Response},
 };
 use bcrypt::{hash, DEFAULT_COST};
@@ -15,6 +15,22 @@ fn generate_reset_code() -> String {
     crate::util::generate_token_code()
 }
 
+fn extract_client_ip(headers: &HeaderMap) -> String {
+    if let Some(forwarded) = headers.get("x-forwarded-for") {
+        if let Ok(value) = forwarded.to_str() {
+            if let Some(first_ip) = value.split(',').next() {
+                return first_ip.trim().to_string();
+            }
+        }
+    }
+    if let Some(real_ip) = headers.get("x-real-ip") {
+        if let Ok(value) = real_ip.to_str() {
+            return value.trim().to_string();
+        }
+    }
+    "unknown".to_string()
+}
+
 #[derive(Deserialize)]
 pub struct RequestPasswordResetInput {
     pub email: String,
@@ -22,8 +38,22 @@ pub struct RequestPasswordResetInput {
 
 pub async fn request_password_reset(
     State(state): State<AppState>,
+    headers: HeaderMap,
     Json(input): Json<RequestPasswordResetInput>,
 ) -> Response {
+    let client_ip = extract_client_ip(&headers);
+    if state.rate_limiters.password_reset.check_key(&client_ip).is_err() {
+        warn!(ip = %client_ip, "Password reset rate limit exceeded");
+        return (
+            StatusCode::TOO_MANY_REQUESTS,
+            Json(json!({
+                "error": "RateLimitExceeded",
+                "message": "Too many password reset requests. Please try again later."
+            })),
+        )
+            .into_response();
+    }
+
     let email = input.email.trim().to_lowercase();
     if email.is_empty() {
         return (
diff --git a/src/api/server/session.rs b/src/api/server/session.rs
index 56f27c1..f61f748 100644
--- a/src/api/server/session.rs
+++ b/src/api/server/session.rs
@@ -4,6 +4,7 @@
 use crate::state::AppState;
 use axum::{
     Json,
     extract::State,
+    http::{HeaderMap, StatusCode},
     response::{IntoResponse, Response},
 };
 use bcrypt::verify;
@@ -11,6 +12,22 @@
 use serde::{Deserialize, Serialize};
 use serde_json::json;
 use tracing::{error, info, warn};
 
+fn extract_client_ip(headers: &HeaderMap) -> String {
+    if let Some(forwarded) = headers.get("x-forwarded-for") {
+        if let Ok(value) = forwarded.to_str() {
+            if let Some(first_ip) = value.split(',').next() {
+                return first_ip.trim().to_string();
+            }
+        }
+    }
+    if let Some(real_ip) = headers.get("x-real-ip") {
+        if let Ok(value) = real_ip.to_str() {
+            return value.trim().to_string();
+        }
+    }
+    "unknown".to_string()
+}
+
 #[derive(Deserialize)]
 pub struct CreateSessionInput {
     pub identifier: String,
@@ -28,10 +45,24 @@ pub struct CreateSessionOutput {
 
 pub async fn create_session(
     State(state): State<AppState>,
+    headers: HeaderMap,
     Json(input): Json<CreateSessionInput>,
 ) -> Response {
     info!("create_session called");
 
+    let client_ip = extract_client_ip(&headers);
+    if state.rate_limiters.login.check_key(&client_ip).is_err() {
+        warn!(ip = %client_ip, "Login rate limit exceeded");
+        return (
+            StatusCode::TOO_MANY_REQUESTS,
+            Json(json!({
+                "error": "RateLimitExceeded",
+                "message": "Too many login attempts. Please try again later."
+            })),
+        )
+            .into_response();
+    }
+
     let row = match sqlx::query!(
         "SELECT u.id, u.did, u.handle, u.password_hash, k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.handle = $1 OR u.email = $1",
         input.identifier
diff --git a/src/api/temp.rs b/src/api/temp.rs
new file mode 100644
index 0000000..9af4a52
--- /dev/null
+++ b/src/api/temp.rs
@@ -0,0 +1,48 @@
+use axum::{
+    Json,
+    extract::State,
+    http::{HeaderMap, StatusCode},
+    response::{IntoResponse, Response},
+};
+use serde::Serialize;
+use serde_json::json;
+
+use crate::auth::{extract_bearer_token_from_header, validate_bearer_token};
+use crate::state::AppState;
+
+#[derive(Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CheckSignupQueueOutput {
+    pub activated: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub place_in_queue: Option<u64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub estimated_time_ms: Option<u64>,
+}
+
+pub async fn check_signup_queue(
+    State(state): State<AppState>,
+    headers: HeaderMap,
+) -> Response {
+    if let Some(token) = extract_bearer_token_from_header(
+        headers.get("Authorization").and_then(|h| h.to_str().ok())
+    ) {
+        if let Ok(user) = validate_bearer_token(&state.db, &token).await {
+            if user.is_oauth {
+                return (
+                    StatusCode::FORBIDDEN,
+                    Json(json!({
+                        "error": "Forbidden",
+                        "message": "OAuth credentials are not supported for this endpoint"
+                    })),
+                ).into_response();
+            }
+        }
+    }
+
+    Json(CheckSignupQueueOutput {
+        activated: true,
+        place_in_queue: None,
+        estimated_time_ms: None,
+    }).into_response()
+}
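For reference, this stub always reports an already-activated account; since the two optional fields are skipped when None, the serialized response body is simply:

    {"activated": true}

A queue-backed implementation would populate placeInQueue and estimatedTimeMs instead.
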
diff --git a/src/auth/token.rs b/src/auth/token.rs
index 0914b69..53eb857 100644
--- a/src/auth/token.rs
+++ b/src/auth/token.rs
@@ -3,9 +3,13 @@
 use anyhow::Result;
 use base64::Engine as _;
 use base64::engine::general_purpose::URL_SAFE_NO_PAD;
 use chrono::{DateTime, Duration, Utc};
+use hmac::{Hmac, Mac};
 use k256::ecdsa::{Signature, SigningKey, signature::Signer};
+use sha2::Sha256;
 use uuid;
 
+type HmacSha256 = Hmac<Sha256>;
+
 pub const TOKEN_TYPE_ACCESS: &str = "at+jwt";
 pub const TOKEN_TYPE_REFRESH: &str = "refresh+jwt";
 pub const TOKEN_TYPE_SERVICE: &str = "jwt";
@@ -118,3 +122,97 @@ fn sign_claims_with_type(claims: Claims, key: &SigningKey, typ: &str) -> Result<
 
     Ok(format!("{}.{}", message, signature_b64))
 }
+
+pub fn create_access_token_hs256(did: &str, secret: &[u8]) -> Result<String> {
+    Ok(create_access_token_hs256_with_metadata(did, secret)?.token)
+}
+
+pub fn create_refresh_token_hs256(did: &str, secret: &[u8]) -> Result<String> {
+    Ok(create_refresh_token_hs256_with_metadata(did, secret)?.token)
+}
+
+pub fn create_access_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
+    create_hs256_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, secret, Duration::minutes(120))
+}
+
+pub fn create_refresh_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
+    create_hs256_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, secret, Duration::days(90))
+}
+
+pub fn create_service_token_hs256(did: &str, aud: &str, lxm: &str, secret: &[u8]) -> Result<String> {
+    let expiration = Utc::now()
+        .checked_add_signed(Duration::seconds(60))
+        .expect("valid timestamp")
+        .timestamp();
+
+    let claims = Claims {
+        iss: did.to_owned(),
+        sub: did.to_owned(),
+        aud: aud.to_owned(),
+        exp: expiration as usize,
+        iat: Utc::now().timestamp() as usize,
+        scope: None,
+        lxm: Some(lxm.to_string()),
+        jti: uuid::Uuid::new_v4().to_string(),
+    };
+
+    sign_claims_hs256(claims, TOKEN_TYPE_SERVICE, secret)
+}
+
+fn create_hs256_token_with_metadata(
+    did: &str,
+    scope: &str,
+    typ: &str,
+    secret: &[u8],
+    duration: Duration,
+) -> Result<TokenWithMetadata> {
+    let expires_at = Utc::now()
+        .checked_add_signed(duration)
+        .expect("valid timestamp");
+    let expiration = expires_at.timestamp();
+    let jti = uuid::Uuid::new_v4().to_string();
+
+    let claims = Claims {
+        iss: did.to_owned(),
+        sub: did.to_owned(),
+        aud: format!(
+            "did:web:{}",
+            std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
+        ),
+        exp: expiration as usize,
+        iat: Utc::now().timestamp() as usize,
+        scope: Some(scope.to_string()),
+        lxm: None,
+        jti: jti.clone(),
+    };
+
+    let token = sign_claims_hs256(claims, typ, secret)?;
+    Ok(TokenWithMetadata {
+        token,
+        jti,
+        expires_at,
+    })
+}
+
+fn sign_claims_hs256(claims: Claims, typ: &str, secret: &[u8]) -> Result<String> {
+    let header = Header {
+        alg: "HS256".to_string(),
+        typ: typ.to_string(),
+    };
+
+    let header_json = serde_json::to_string(&header)?;
+    let claims_json = serde_json::to_string(&claims)?;
+
+    let header_b64 = URL_SAFE_NO_PAD.encode(header_json);
+    let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json);
+
+    let message = format!("{}.{}", header_b64, claims_b64);
+
+    let mut mac = HmacSha256::new_from_slice(secret)
+        .map_err(|e| anyhow::anyhow!("Invalid secret length: {}", e))?;
+    mac.update(message.as_bytes());
+    let signature = mac.finalize().into_bytes();
+    let signature_b64 = URL_SAFE_NO_PAD.encode(signature);
+
+    Ok(format!("{}.{}", message, signature_b64))
+}
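A minimal mint-and-verify sketch for the HS256 helpers added above; the DID and env var name are illustrative, and verify_access_token_hs256 comes from the verify.rs hunk that follows:

    let secret = std::env::var("JWT_SECRET")?; // assumed env var name
    let token = create_access_token_hs256("did:plc:exampleuser", secret.as_bytes())?;
    let data = verify_access_token_hs256(&token, secret.as_bytes())?;
    assert_eq!(data.claims.sub, "did:plc:exampleuser");
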
signature failed")?; + + let message = format!("{}.{}", header_b64, claims_b64); + let mut mac = HmacSha256::new_from_slice(secret) + .map_err(|e| anyhow!("Invalid secret: {}", e))?; + mac.update(message.as_bytes()); + let expected_signature = mac.finalize().into_bytes(); + + let is_valid: bool = signature_bytes.ct_eq(&expected_signature).into(); + if !is_valid { + return Err(anyhow!("Signature verification failed")); + } + + let claims_bytes = URL_SAFE_NO_PAD + .decode(claims_b64) + .context("Base64 decode of claims failed")?; + let claims: Claims = + serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?; + + let now = Utc::now().timestamp() as usize; + if claims.exp < now { + return Err(anyhow!("Token expired")); + } + + if let Some(scopes) = allowed_scopes { + let token_scope = claims.scope.as_deref().unwrap_or(""); + if !scopes.contains(&token_scope) { + return Err(anyhow!("Invalid token scope: {}", token_scope)); + } + } + + Ok(TokenData { claims }) +} + +pub fn get_algorithm_from_token(token: &str) -> Result { + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return Err("Invalid token format".to_string()); + } + + let header_bytes = URL_SAFE_NO_PAD + .decode(parts[0]) + .map_err(|e| format!("Base64 decode failed: {}", e))?; + + let header: Header = + serde_json::from_slice(&header_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; + + Ok(header.alg) +} diff --git a/src/circuit_breaker.rs b/src/circuit_breaker.rs new file mode 100644 index 0000000..7bfc203 --- /dev/null +++ b/src/circuit_breaker.rs @@ -0,0 +1,307 @@ +use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::RwLock; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CircuitState { + Closed, + Open, + HalfOpen, +} + +pub struct CircuitBreaker { + name: String, + failure_threshold: u32, + success_threshold: u32, + timeout: Duration, + state: Arc>, + failure_count: AtomicU32, + success_count: AtomicU32, + last_failure_time: AtomicU64, +} + +impl CircuitBreaker { + pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self { + Self { + name: name.to_string(), + failure_threshold, + success_threshold, + timeout: Duration::from_secs(timeout_secs), + state: Arc::new(RwLock::new(CircuitState::Closed)), + failure_count: AtomicU32::new(0), + success_count: AtomicU32::new(0), + last_failure_time: AtomicU64::new(0), + } + } + + pub async fn can_execute(&self) -> bool { + let state = self.state.read().await; + match *state { + CircuitState::Closed => true, + CircuitState::Open => { + let last_failure = self.last_failure_time.load(Ordering::SeqCst); + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + + if now - last_failure >= self.timeout.as_secs() { + drop(state); + let mut state = self.state.write().await; + if *state == CircuitState::Open { + *state = CircuitState::HalfOpen; + self.success_count.store(0, Ordering::SeqCst); + tracing::info!(circuit = %self.name, "Circuit breaker transitioning to half-open"); + return true; + } + } + false + } + CircuitState::HalfOpen => true, + } + } + + pub async fn record_success(&self) { + let state = *self.state.read().await; + + match state { + CircuitState::Closed => { + self.failure_count.store(0, Ordering::SeqCst); + } + CircuitState::HalfOpen => { + let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1; + if count >= self.success_threshold { + let 
+    pub async fn record_success(&self) {
+        let state = *self.state.read().await;
+
+        match state {
+            CircuitState::Closed => {
+                self.failure_count.store(0, Ordering::SeqCst);
+            }
+            CircuitState::HalfOpen => {
+                let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
+                if count >= self.success_threshold {
+                    let mut state = self.state.write().await;
+                    *state = CircuitState::Closed;
+                    self.failure_count.store(0, Ordering::SeqCst);
+                    self.success_count.store(0, Ordering::SeqCst);
+                    tracing::info!(circuit = %self.name, "Circuit breaker closed after successful recovery");
+                }
+            }
+            CircuitState::Open => {}
+        }
+    }
+
+    pub async fn record_failure(&self) {
+        let state = *self.state.read().await;
+
+        match state {
+            CircuitState::Closed => {
+                let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
+                if count >= self.failure_threshold {
+                    let mut state = self.state.write().await;
+                    *state = CircuitState::Open;
+                    let now = std::time::SystemTime::now()
+                        .duration_since(std::time::UNIX_EPOCH)
+                        .unwrap()
+                        .as_secs();
+                    self.last_failure_time.store(now, Ordering::SeqCst);
+                    tracing::warn!(
+                        circuit = %self.name,
+                        failures = count,
+                        "Circuit breaker opened after {} failures",
+                        count
+                    );
+                }
+            }
+            CircuitState::HalfOpen => {
+                let mut state = self.state.write().await;
+                *state = CircuitState::Open;
+                let now = std::time::SystemTime::now()
+                    .duration_since(std::time::UNIX_EPOCH)
+                    .unwrap()
+                    .as_secs();
+                self.last_failure_time.store(now, Ordering::SeqCst);
+                self.success_count.store(0, Ordering::SeqCst);
+                tracing::warn!(circuit = %self.name, "Circuit breaker reopened after failure in half-open state");
+            }
+            CircuitState::Open => {}
+        }
+    }
+
+    pub async fn state(&self) -> CircuitState {
+        *self.state.read().await
+    }
+
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+#[derive(Clone)]
+pub struct CircuitBreakers {
+    pub plc_directory: Arc<CircuitBreaker>,
+    pub relay_notification: Arc<CircuitBreaker>,
+}
+
+impl Default for CircuitBreakers {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl CircuitBreakers {
+    pub fn new() -> Self {
+        Self {
+            plc_directory: Arc::new(CircuitBreaker::new("plc_directory", 5, 3, 60)),
+            relay_notification: Arc::new(CircuitBreaker::new("relay_notification", 10, 5, 30)),
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct CircuitOpenError {
+    pub circuit_name: String,
+}
+
+impl std::fmt::Display for CircuitOpenError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "Circuit breaker '{}' is open", self.circuit_name)
+    }
+}
+
+impl std::error::Error for CircuitOpenError {}
+
+pub async fn with_circuit_breaker<T, E, F, Fut>(
+    circuit: &CircuitBreaker,
+    operation: F,
+) -> Result<T, CircuitBreakerError<E>>
+where
+    F: FnOnce() -> Fut,
+    Fut: std::future::Future<Output = Result<T, E>>,
+{
+    if !circuit.can_execute().await {
+        return Err(CircuitBreakerError::CircuitOpen(CircuitOpenError {
+            circuit_name: circuit.name().to_string(),
+        }));
+    }
+
+    match operation().await {
+        Ok(result) => {
+            circuit.record_success().await;
+            Ok(result)
+        }
+        Err(e) => {
+            circuit.record_failure().await;
+            Err(CircuitBreakerError::OperationFailed(e))
+        }
+    }
+}
+
+#[derive(Debug)]
+pub enum CircuitBreakerError<E> {
+    CircuitOpen(CircuitOpenError),
+    OperationFailed(E),
+}
+
+impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
+            CircuitBreakerError::OperationFailed(e) => write!(f, "Operation failed: {}", e),
+        }
+    }
+}
+
+impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        match self {
+            CircuitBreakerError::CircuitOpen(e) => Some(e),
+            CircuitBreakerError::OperationFailed(e) => Some(e),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
CircuitBreaker::new("test", 3, 2, 10); + assert_eq!(cb.state().await, CircuitState::Closed); + assert!(cb.can_execute().await); + } + + #[tokio::test] + async fn test_circuit_breaker_opens_after_failures() { + let cb = CircuitBreaker::new("test", 3, 2, 10); + + cb.record_failure().await; + assert_eq!(cb.state().await, CircuitState::Closed); + + cb.record_failure().await; + assert_eq!(cb.state().await, CircuitState::Closed); + + cb.record_failure().await; + assert_eq!(cb.state().await, CircuitState::Open); + assert!(!cb.can_execute().await); + } + + #[tokio::test] + async fn test_circuit_breaker_success_resets_failures() { + let cb = CircuitBreaker::new("test", 3, 2, 10); + + cb.record_failure().await; + cb.record_failure().await; + cb.record_success().await; + + cb.record_failure().await; + cb.record_failure().await; + assert_eq!(cb.state().await, CircuitState::Closed); + + cb.record_failure().await; + assert_eq!(cb.state().await, CircuitState::Open); + } + + #[tokio::test] + async fn test_circuit_breaker_half_open_closes_after_successes() { + let cb = CircuitBreaker::new("test", 3, 2, 0); + + for _ in 0..3 { + cb.record_failure().await; + } + assert_eq!(cb.state().await, CircuitState::Open); + + tokio::time::sleep(Duration::from_millis(100)).await; + + assert!(cb.can_execute().await); + assert_eq!(cb.state().await, CircuitState::HalfOpen); + + cb.record_success().await; + assert_eq!(cb.state().await, CircuitState::HalfOpen); + + cb.record_success().await; + assert_eq!(cb.state().await, CircuitState::Closed); + } + + #[tokio::test] + async fn test_circuit_breaker_half_open_reopens_on_failure() { + let cb = CircuitBreaker::new("test", 3, 2, 0); + + for _ in 0..3 { + cb.record_failure().await; + } + + tokio::time::sleep(Duration::from_millis(100)).await; + cb.can_execute().await; + + cb.record_failure().await; + assert_eq!(cb.state().await, CircuitState::Open); + } + + #[tokio::test] + async fn test_with_circuit_breaker_helper() { + let cb = CircuitBreaker::new("test", 3, 2, 10); + + let result: Result> = + with_circuit_breaker(&cb, || async { Ok(42) }).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), 42); + + let result: Result> = + with_circuit_breaker(&cb, || async { Err("error") }).await; + assert!(result.is_err()); + } +} diff --git a/src/crawlers.rs b/src/crawlers.rs new file mode 100644 index 0000000..46b4414 --- /dev/null +++ b/src/crawlers.rs @@ -0,0 +1,170 @@ +use crate::circuit_breaker::CircuitBreaker; +use crate::sync::firehose::SequencedEvent; +use reqwest::Client; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::{broadcast, watch}; +use tracing::{debug, error, info, warn}; + +const NOTIFY_THRESHOLD_SECS: u64 = 20 * 60; + +pub struct Crawlers { + hostname: String, + crawler_urls: Vec, + http_client: Client, + last_notified: AtomicU64, + circuit_breaker: Option>, +} + +impl Crawlers { + pub fn new(hostname: String, crawler_urls: Vec) -> Self { + Self { + hostname, + crawler_urls, + http_client: Client::builder() + .timeout(Duration::from_secs(30)) + .build() + .unwrap_or_default(), + last_notified: AtomicU64::new(0), + circuit_breaker: None, + } + } + + pub fn with_circuit_breaker(mut self, circuit_breaker: Arc) -> Self { + self.circuit_breaker = Some(circuit_breaker); + self + } + + pub fn from_env() -> Option { + let hostname = std::env::var("PDS_HOSTNAME").ok()?; + let crawler_urls: Vec = std::env::var("CRAWLERS") + .unwrap_or_default() + .split(',') + .filter(|s| !s.is_empty()) + .map(|s| 
diff --git a/src/crawlers.rs b/src/crawlers.rs
new file mode 100644
index 0000000..46b4414
--- /dev/null
+++ b/src/crawlers.rs
@@ -0,0 +1,170 @@
+use crate::circuit_breaker::CircuitBreaker;
+use crate::sync::firehose::SequencedEvent;
+use reqwest::Client;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::sync::{broadcast, watch};
+use tracing::{debug, error, info, warn};
+
+const NOTIFY_THRESHOLD_SECS: u64 = 20 * 60;
+
+pub struct Crawlers {
+    hostname: String,
+    crawler_urls: Vec<String>,
+    http_client: Client,
+    last_notified: AtomicU64,
+    circuit_breaker: Option<Arc<CircuitBreaker>>,
+}
+
+impl Crawlers {
+    pub fn new(hostname: String, crawler_urls: Vec<String>) -> Self {
+        Self {
+            hostname,
+            crawler_urls,
+            http_client: Client::builder()
+                .timeout(Duration::from_secs(30))
+                .build()
+                .unwrap_or_default(),
+            last_notified: AtomicU64::new(0),
+            circuit_breaker: None,
+        }
+    }
+
+    pub fn with_circuit_breaker(mut self, circuit_breaker: Arc<CircuitBreaker>) -> Self {
+        self.circuit_breaker = Some(circuit_breaker);
+        self
+    }
+
+    pub fn from_env() -> Option<Self> {
+        let hostname = std::env::var("PDS_HOSTNAME").ok()?;
+        let crawler_urls: Vec<String> = std::env::var("CRAWLERS")
+            .unwrap_or_default()
+            .split(',')
+            .filter(|s| !s.is_empty())
+            .map(|s| s.trim().to_string())
+            .collect();
+
+        if crawler_urls.is_empty() {
+            return None;
+        }
+
+        Some(Self::new(hostname, crawler_urls))
+    }
+
+    fn should_notify(&self) -> bool {
+        let now = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_secs();
+        let last = self.last_notified.load(Ordering::Relaxed);
+        now - last >= NOTIFY_THRESHOLD_SECS
+    }
+
+    fn mark_notified(&self) {
+        let now = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_secs();
+        self.last_notified.store(now, Ordering::Relaxed);
+    }
+
+    pub async fn notify_of_update(&self) {
+        if !self.should_notify() {
+            debug!("Skipping crawler notification due to debounce");
+            return;
+        }
+
+        if let Some(cb) = &self.circuit_breaker {
+            if !cb.can_execute().await {
+                debug!("Skipping crawler notification due to circuit breaker open");
+                return;
+            }
+        }
+
+        self.mark_notified();
+
+        let circuit_breaker = self.circuit_breaker.clone();
+
+        for crawler_url in &self.crawler_urls {
+            let url = format!("{}/xrpc/com.atproto.sync.requestCrawl", crawler_url.trim_end_matches('/'));
+            let hostname = self.hostname.clone();
+            let client = self.http_client.clone();
+            let cb = circuit_breaker.clone();
+
+            tokio::spawn(async move {
+                match client
+                    .post(&url)
+                    .json(&serde_json::json!({ "hostname": hostname }))
+                    .send()
+                    .await
+                {
+                    Ok(response) => {
+                        if response.status().is_success() {
+                            debug!(crawler = %url, "Successfully notified crawler");
+                            if let Some(cb) = cb {
+                                cb.record_success().await;
+                            }
+                        } else {
+                            warn!(
+                                crawler = %url,
+                                status = %response.status(),
+                                "Crawler notification returned non-success status"
+                            );
+                            if let Some(cb) = cb {
+                                cb.record_failure().await;
+                            }
+                        }
+                    }
+                    Err(e) => {
+                        warn!(crawler = %url, error = %e, "Failed to notify crawler");
+                        if let Some(cb) = cb {
+                            cb.record_failure().await;
+                        }
+                    }
+                }
+            });
+        }
+    }
+}
+
+pub async fn start_crawlers_service(
+    crawlers: Arc<Crawlers>,
+    mut firehose_rx: broadcast::Receiver<SequencedEvent>,
+    mut shutdown: watch::Receiver<bool>,
+) {
+    info!(
+        hostname = %crawlers.hostname,
+        crawler_count = crawlers.crawler_urls.len(),
+        crawlers = ?crawlers.crawler_urls,
+        "Starting crawlers notification service"
+    );
+
+    loop {
+        tokio::select! {
+            result = firehose_rx.recv() => {
+                match result {
+                    Ok(event) => {
+                        if event.event_type == "commit" {
+                            crawlers.notify_of_update().await;
+                        }
+                    }
+                    Err(broadcast::error::RecvError::Lagged(n)) => {
+                        warn!(skipped = n, "Crawlers service lagged behind firehose");
+                        crawlers.notify_of_update().await;
+                    }
+                    Err(broadcast::error::RecvError::Closed) => {
+                        error!("Firehose channel closed, stopping crawlers service");
+                        break;
+                    }
+                }
+            }
+            _ = shutdown.changed() => {
+                if *shutdown.borrow() {
+                    info!("Crawlers service shutting down");
+                    break;
+                }
+            }
+        }
+    }
+}
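Crawlers::from_env reads two variables; a hypothetical .env fragment that enables the service (hostnames are placeholders):

    PDS_HOSTNAME=pds.example.com
    CRAWLERS=https://relay1.example.com,https://relay2.example.com

With CRAWLERS empty or unset, from_env returns None and main.rs logs the service as disabled.
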
diff --git a/src/image/mod.rs b/src/image/mod.rs
new file mode 100644
index 0000000..fd71e10
--- /dev/null
+++ b/src/image/mod.rs
@@ -0,0 +1,304 @@
+use image::{DynamicImage, ImageFormat, ImageReader, imageops::FilterType};
+use std::io::Cursor;
+
+pub const THUMB_SIZE_FEED: u32 = 200;
+pub const THUMB_SIZE_FULL: u32 = 1000;
+
+#[derive(Debug, Clone)]
+pub struct ProcessedImage {
+    pub data: Vec<u8>,
+    pub mime_type: String,
+    pub width: u32,
+    pub height: u32,
+}
+
+#[derive(Debug, Clone)]
+pub struct ImageProcessingResult {
+    pub original: ProcessedImage,
+    pub thumbnail_feed: Option<ProcessedImage>,
+    pub thumbnail_full: Option<ProcessedImage>,
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum ImageError {
+    #[error("Failed to decode image: {0}")]
+    DecodeError(String),
+
+    #[error("Failed to encode image: {0}")]
+    EncodeError(String),
+
+    #[error("Unsupported image format: {0}")]
+    UnsupportedFormat(String),
+
+    #[error("Image too large: {width}x{height} exceeds maximum {max_dimension}")]
+    TooLarge {
+        width: u32,
+        height: u32,
+        max_dimension: u32,
+    },
+
+    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
+    FileTooLarge { size: usize, max_size: usize },
+}
+
+pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024; // 10MB
+
+pub struct ImageProcessor {
+    max_dimension: u32,
+    max_file_size: usize,
+    output_format: OutputFormat,
+    generate_thumbnails: bool,
+}
+
+#[derive(Debug, Clone, Copy)]
+pub enum OutputFormat {
+    WebP,
+    Jpeg,
+    Png,
+    Original,
+}
+
+impl Default for ImageProcessor {
+    fn default() -> Self {
+        Self {
+            max_dimension: 4096,
+            max_file_size: DEFAULT_MAX_FILE_SIZE,
+            output_format: OutputFormat::WebP,
+            generate_thumbnails: true,
+        }
+    }
+}
+
+impl ImageProcessor {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    pub fn with_max_dimension(mut self, max: u32) -> Self {
+        self.max_dimension = max;
+        self
+    }
+
+    pub fn with_max_file_size(mut self, max: usize) -> Self {
+        self.max_file_size = max;
+        self
+    }
+
+    pub fn with_output_format(mut self, format: OutputFormat) -> Self {
+        self.output_format = format;
+        self
+    }
+
+    pub fn with_thumbnails(mut self, generate: bool) -> Self {
+        self.generate_thumbnails = generate;
+        self
+    }
+
+    pub fn process(&self, data: &[u8], mime_type: &str) -> Result<ImageProcessingResult, ImageError> {
+        if data.len() > self.max_file_size {
+            return Err(ImageError::FileTooLarge {
+                size: data.len(),
+                max_size: self.max_file_size,
+            });
+        }
+
+        let format = self.detect_format(mime_type, data)?;
+        let img = self.decode_image(data, format)?;
+
+        if img.width() > self.max_dimension || img.height() > self.max_dimension {
+            return Err(ImageError::TooLarge {
+                width: img.width(),
+                height: img.height(),
+                max_dimension: self.max_dimension,
+            });
+        }
+
+        let original = self.encode_image(&img)?;
+
+        let thumbnail_feed = if self.generate_thumbnails && (img.width() > THUMB_SIZE_FEED || img.height() > THUMB_SIZE_FEED) {
+            Some(self.generate_thumbnail(&img, THUMB_SIZE_FEED)?)
+        } else {
+            None
+        };
+
+        let thumbnail_full = if self.generate_thumbnails && (img.width() > THUMB_SIZE_FULL || img.height() > THUMB_SIZE_FULL) {
+            Some(self.generate_thumbnail(&img, THUMB_SIZE_FULL)?)
+        } else {
+            None
+        };
+
+        Ok(ImageProcessingResult {
+            original,
+            thumbnail_feed,
+            thumbnail_full,
+        })
+    }
+
+    fn detect_format(&self, mime_type: &str, data: &[u8]) -> Result<ImageFormat, ImageError> {
+        match mime_type.to_lowercase().as_str() {
+            "image/jpeg" | "image/jpg" => Ok(ImageFormat::Jpeg),
+            "image/png" => Ok(ImageFormat::Png),
+            "image/gif" => Ok(ImageFormat::Gif),
+            "image/webp" => Ok(ImageFormat::WebP),
+            _ => {
+                if let Ok(format) = image::guess_format(data) {
+                    Ok(format)
+                } else {
+                    Err(ImageError::UnsupportedFormat(mime_type.to_string()))
+                }
+            }
+        }
+    }
+
+    fn decode_image(&self, data: &[u8], format: ImageFormat) -> Result<DynamicImage, ImageError> {
+        let cursor = Cursor::new(data);
+        let reader = ImageReader::with_format(cursor, format);
+        reader
+            .decode()
+            .map_err(|e| ImageError::DecodeError(e.to_string()))
+    }
+
+    fn encode_image(&self, img: &DynamicImage) -> Result<ProcessedImage, ImageError> {
+        let (data, mime_type) = match self.output_format {
+            OutputFormat::WebP => {
+                let mut buf = Vec::new();
+                img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP)
+                    .map_err(|e| ImageError::EncodeError(e.to_string()))?;
+                (buf, "image/webp".to_string())
+            }
+            OutputFormat::Jpeg => {
+                let mut buf = Vec::new();
+                img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg)
+                    .map_err(|e| ImageError::EncodeError(e.to_string()))?;
+                (buf, "image/jpeg".to_string())
+            }
+            OutputFormat::Png => {
+                let mut buf = Vec::new();
+                img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
+                    .map_err(|e| ImageError::EncodeError(e.to_string()))?;
+                (buf, "image/png".to_string())
+            }
+            OutputFormat::Original => {
+                let mut buf = Vec::new();
+                img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
+                    .map_err(|e| ImageError::EncodeError(e.to_string()))?;
+                (buf, "image/png".to_string())
+            }
+        };
+
+        Ok(ProcessedImage {
+            data,
+            mime_type,
+            width: img.width(),
+            height: img.height(),
+        })
+    }
+
+    fn generate_thumbnail(&self, img: &DynamicImage, max_size: u32) -> Result<ProcessedImage, ImageError> {
+        let (orig_width, orig_height) = (img.width(), img.height());
+
+        let (new_width, new_height) = if orig_width > orig_height {
+            let ratio = max_size as f64 / orig_width as f64;
+            (max_size, (orig_height as f64 * ratio) as u32)
+        } else {
+            let ratio = max_size as f64 / orig_height as f64;
+            ((orig_width as f64 * ratio) as u32, max_size)
+        };
+
+        let thumb = img.resize(new_width, new_height, FilterType::Lanczos3);
+        self.encode_image(&thumb)
+    }
+
+    pub fn is_supported_mime_type(mime_type: &str) -> bool {
+        matches!(
+            mime_type.to_lowercase().as_str(),
+            "image/jpeg" | "image/jpg" | "image/png" | "image/gif" | "image/webp"
+        )
+    }
+
+    pub fn strip_exif(data: &[u8]) -> Result<Vec<u8>, ImageError> {
+        let format = image::guess_format(data)
+            .map_err(|e| ImageError::DecodeError(e.to_string()))?;
+
+        let cursor = Cursor::new(data);
+        let img = ImageReader::with_format(cursor, format)
+            .decode()
+            .map_err(|e| ImageError::DecodeError(e.to_string()))?;
+
+        let mut buf = Vec::new();
+        img.write_to(&mut Cursor::new(&mut buf), format)
+            .map_err(|e| ImageError::EncodeError(e.to_string()))?;
+
+        Ok(buf)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn create_test_image(width: u32, height: u32) -> Vec<u8> {
+        let img = DynamicImage::new_rgb8(width, height);
+        let mut buf = Vec::new();
+        img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
+        buf
+    }
+
+    #[test]
+    fn test_process_small_image() {
+        let processor = ImageProcessor::new();
+        let data = create_test_image(100, 100);
+
+        let result = processor.process(&data, "image/png").unwrap();
+
+        assert!(result.thumbnail_feed.is_none());
+        assert!(result.thumbnail_full.is_none());
+    }
+
+    #[test]
+    fn test_process_large_image_generates_thumbnails() {
+        let processor = ImageProcessor::new();
+        let data = create_test_image(2000, 1500);
+
+        let result = processor.process(&data, "image/png").unwrap();
+
+        assert!(result.thumbnail_feed.is_some());
+        assert!(result.thumbnail_full.is_some());
+
+        let feed_thumb = result.thumbnail_feed.unwrap();
+        assert!(feed_thumb.width <= THUMB_SIZE_FEED);
+        assert!(feed_thumb.height <= THUMB_SIZE_FEED);
+
+        let full_thumb = result.thumbnail_full.unwrap();
+        assert!(full_thumb.width <= THUMB_SIZE_FULL);
+        assert!(full_thumb.height <= THUMB_SIZE_FULL);
+    }
+
+    #[test]
+    fn test_webp_conversion() {
+        let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
+        let data = create_test_image(500, 500);
+
+        let result = processor.process(&data, "image/png").unwrap();
+        assert_eq!(result.original.mime_type, "image/webp");
+    }
+
+    #[test]
+    fn test_reject_too_large() {
+        let processor = ImageProcessor::new().with_max_dimension(1000);
+        let data = create_test_image(2000, 2000);
+
+        let result = processor.process(&data, "image/png");
+        assert!(matches!(result, Err(ImageError::TooLarge { .. })));
+    }
+
+    #[test]
+    fn test_is_supported_mime_type() {
+        assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
+        assert!(ImageProcessor::is_supported_mime_type("image/png"));
+        assert!(ImageProcessor::is_supported_mime_type("image/gif"));
+        assert!(ImageProcessor::is_supported_mime_type("image/webp"));
+        assert!(!ImageProcessor::is_supported_mime_type("image/bmp"));
+        assert!(!ImageProcessor::is_supported_mime_type("text/plain"));
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index e84a740..15ea6ce 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,14 +1,19 @@
 pub mod api;
 pub mod auth;
+pub mod circuit_breaker;
 pub mod config;
+pub mod crawlers;
+pub mod image;
 pub mod notifications;
 pub mod oauth;
 pub mod plc;
+pub mod rate_limit;
 pub mod repo;
 pub mod state;
 pub mod storage;
 pub mod sync;
 pub mod util;
+pub mod validation;
 
 use axum::{
     Router,
@@ -20,6 +25,7 @@ pub fn app(state: AppState) -> Router {
     Router::new()
         .route("/health", get(api::server::health))
         .route("/xrpc/_health", get(api::server::health))
+        .route("/robots.txt", get(api::server::robots_txt))
         .route(
             "/xrpc/com.atproto.server.describeServer",
             get(api::server::describe_server),
@@ -140,6 +146,14 @@
         .route(
             "/xrpc/com.atproto.sync.subscribeRepos",
             get(sync::subscribe_repos),
         )
+        .route(
+            "/xrpc/com.atproto.sync.getHead",
+            get(sync::get_head),
+        )
+        .route(
+            "/xrpc/com.atproto.sync.getCheckout",
+            get(sync::get_checkout),
+        )
         .route(
             "/xrpc/com.atproto.moderation.createReport",
             post(api::moderation::create_report),
@@ -338,9 +352,17 @@
         )
         .route("/oauth/authorize", get(oauth::endpoints::authorize_get))
         .route("/oauth/authorize", post(oauth::endpoints::authorize_post))
+        .route("/oauth/authorize/select", post(oauth::endpoints::authorize_select))
+        .route("/oauth/authorize/2fa", get(oauth::endpoints::authorize_2fa_get))
+        .route("/oauth/authorize/2fa", post(oauth::endpoints::authorize_2fa_post))
+        .route("/oauth/authorize/deny", post(oauth::endpoints::authorize_deny))
         .route("/oauth/token", post(oauth::endpoints::token_endpoint))
.route("/oauth/revoke", post(oauth::endpoints::revoke_token)) .route("/oauth/introspect", post(oauth::endpoints::introspect_token)) + .route( + "/xrpc/com.atproto.temp.checkSignupQueue", + get(api::temp::check_signup_queue), + ) .route("/xrpc/{*method}", any(api::proxy::proxy_handler)) .with_state(state) } diff --git a/src/main.rs b/src/main.rs index 44f1df5..97cba57 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,9 @@ -use bspds::notifications::{EmailSender, NotificationService}; +use bspds::crawlers::{Crawlers, start_crawlers_service}; +use bspds::notifications::{DiscordSender, EmailSender, NotificationService, SignalSender, TelegramSender}; use bspds::state::AppState; use std::net::SocketAddr; use std::process::ExitCode; +use std::sync::Arc; use tokio::sync::watch; use tracing::{error, info, warn}; @@ -41,13 +43,6 @@ async fn run() -> Result<(), Box> { let state = AppState::new(pool.clone()).await; bspds::sync::listener::start_sequencer_listener(state.clone()).await; - let relays = std::env::var("RELAYS") - .unwrap_or_default() - .split(',') - .filter(|s| !s.is_empty()) - .map(|s| s.to_string()) - .collect(); - bspds::sync::relay_client::start_relay_clients(state.clone(), relays, None).await; let (shutdown_tx, shutdown_rx) = watch::channel(false); @@ -60,7 +55,34 @@ async fn run() -> Result<(), Box> { warn!("Email notifications disabled (MAIL_FROM_ADDRESS not set)"); } - let notification_handle = tokio::spawn(notification_service.run(shutdown_rx)); + if let Some(discord_sender) = DiscordSender::from_env() { + info!("Discord notifications enabled"); + notification_service = notification_service.register_sender(discord_sender); + } + + if let Some(telegram_sender) = TelegramSender::from_env() { + info!("Telegram notifications enabled"); + notification_service = notification_service.register_sender(telegram_sender); + } + + if let Some(signal_sender) = SignalSender::from_env() { + info!("Signal notifications enabled"); + notification_service = notification_service.register_sender(signal_sender); + } + + let notification_handle = tokio::spawn(notification_service.run(shutdown_rx.clone())); + + let crawlers_handle = if let Some(crawlers) = Crawlers::from_env() { + let crawlers = Arc::new( + crawlers.with_circuit_breaker(state.circuit_breakers.relay_notification.clone()) + ); + let firehose_rx = state.firehose_tx.subscribe(); + info!("Crawlers notification service enabled"); + Some(tokio::spawn(start_crawlers_service(crawlers, firehose_rx, shutdown_rx))) + } else { + warn!("Crawlers notification service disabled (PDS_HOSTNAME or CRAWLERS not set)"); + None + }; let app = bspds::app(state); @@ -75,6 +97,9 @@ async fn run() -> Result<(), Box> { .await; notification_handle.await.ok(); + if let Some(handle) = crawlers_handle { + handle.await.ok(); + } if let Err(e) = server_result { return Err(format!("Server error: {}", e).into()); diff --git a/src/notifications/mod.rs b/src/notifications/mod.rs index e9e8aad..7c54d1f 100644 --- a/src/notifications/mod.rs +++ b/src/notifications/mod.rs @@ -2,11 +2,14 @@ mod sender; mod service; mod types; -pub use sender::{EmailSender, NotificationSender}; +pub use sender::{ + DiscordSender, EmailSender, NotificationSender, SendError, SignalSender, TelegramSender, + is_valid_phone_number, sanitize_header_value, +}; pub use service::{ - enqueue_account_deletion, enqueue_email_update, enqueue_email_verification, - enqueue_notification, enqueue_password_reset, enqueue_plc_operation, enqueue_welcome, - NotificationService, + channel_display_name, 
diff --git a/src/notifications/mod.rs b/src/notifications/mod.rs
index e9e8aad..7c54d1f 100644
--- a/src/notifications/mod.rs
+++ b/src/notifications/mod.rs
@@ -2,11 +2,14 @@ mod sender;
 mod service;
 mod types;
 
-pub use sender::{EmailSender, NotificationSender};
+pub use sender::{
+    DiscordSender, EmailSender, NotificationSender, SendError, SignalSender, TelegramSender,
+    is_valid_phone_number, sanitize_header_value,
+};
 pub use service::{
-    enqueue_account_deletion, enqueue_email_update, enqueue_email_verification,
-    enqueue_notification, enqueue_password_reset, enqueue_plc_operation, enqueue_welcome,
-    NotificationService,
+    channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_email_update,
+    enqueue_email_verification, enqueue_notification, enqueue_password_reset,
+    enqueue_plc_operation, enqueue_welcome, NotificationService,
 };
 pub use types::{
     NewNotification, NotificationChannel, NotificationStatus, NotificationType, QueuedNotification,
diff --git a/src/notifications/sender.rs b/src/notifications/sender.rs
index 888c8c5..c1e8b8e 100644
--- a/src/notifications/sender.rs
+++ b/src/notifications/sender.rs
@@ -1,10 +1,17 @@
 use async_trait::async_trait;
+use reqwest::Client;
+use serde_json::json;
 use std::process::Stdio;
+use std::time::Duration;
 use tokio::io::AsyncWriteExt;
 use tokio::process::Command;
 
 use super::types::{NotificationChannel, QueuedNotification};
 
+const HTTP_TIMEOUT_SECS: u64 = 30;
+const MAX_RETRIES: u32 = 3;
+const INITIAL_RETRY_DELAY_MS: u64 = 500;
+
 #[async_trait]
 pub trait NotificationSender: Send + Sync {
     fn channel(&self) -> NotificationChannel;
@@ -24,6 +31,48 @@ pub enum SendError {
 
     #[error("External service error: {0}")]
     ExternalService(String),
+
+    #[error("Invalid recipient format: {0}")]
+    InvalidRecipient(String),
+
+    #[error("Request timeout")]
+    Timeout,
+
+    #[error("Max retries exceeded: {0}")]
+    MaxRetriesExceeded(String),
+}
+
+fn create_http_client() -> Client {
+    Client::builder()
+        .timeout(Duration::from_secs(HTTP_TIMEOUT_SECS))
+        .connect_timeout(Duration::from_secs(10))
+        .build()
+        .unwrap_or_else(|_| Client::new())
+}
+
+fn is_retryable_status(status: reqwest::StatusCode) -> bool {
+    status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
+}
+
+async fn retry_delay(attempt: u32) {
+    let delay_ms = INITIAL_RETRY_DELAY_MS * 2u64.pow(attempt);
+    tokio::time::sleep(Duration::from_millis(delay_ms)).await;
+}
+
+pub fn sanitize_header_value(value: &str) -> String {
+    value.replace(['\r', '\n'], " ").trim().to_string()
+}
+
+pub fn is_valid_phone_number(number: &str) -> bool {
+    if number.len() < 2 || number.len() > 20 {
+        return false;
+    }
+    let mut chars = number.chars();
+    if chars.next() != Some('+') {
+        return false;
+    }
+    let remaining: String = chars.collect();
+    !remaining.is_empty() && remaining.chars().all(|c| c.is_ascii_digit())
 }
 
 pub struct EmailSender {
@@ -47,18 +96,19 @@ impl EmailSender {
         Some(Self::new(from_address, from_name))
     }
 
-    fn format_email(&self, notification: &QueuedNotification) -> String {
-        let subject = notification.subject.as_deref().unwrap_or("Notification");
+    pub fn format_email(&self, notification: &QueuedNotification) -> String {
+        let subject = sanitize_header_value(notification.subject.as_deref().unwrap_or("Notification"));
+        let recipient = sanitize_header_value(&notification.recipient);
         let from_header = if self.from_name.is_empty() {
             self.from_address.clone()
         } else {
-            format!("{} <{}>", self.from_name, self.from_address)
+            format!("{} <{}>", sanitize_header_value(&self.from_name), self.from_address)
         };
 
         format!(
             "From: {}\r\nTo: {}\r\nSubject: {}\r\nContent-Type: text/plain; charset=utf-8\r\nMIME-Version: 1.0\r\n\r\n{}",
             from_header,
-            notification.recipient,
+            recipient,
             subject,
             notification.body
         )
@@ -96,3 +146,242 @@
 
         Ok(())
     }
 }
+
+pub struct DiscordSender {
+    webhook_url: String,
+    http_client: Client,
+}
+
+impl DiscordSender {
+    pub fn new(webhook_url: String) -> Self {
+        Self {
+            webhook_url,
+            http_client: create_http_client(),
+        }
+    }
+
+    pub fn from_env() -> Option<Self> {
+        let webhook_url = std::env::var("DISCORD_WEBHOOK_URL").ok()?;
+        Some(Self::new(webhook_url))
+    }
+}
+
+#[async_trait]
+impl NotificationSender for DiscordSender {
+    fn channel(&self) -> NotificationChannel {
+        NotificationChannel::Discord
+    }
+
+    async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
+        let subject = notification.subject.as_deref().unwrap_or("Notification");
+        let content = format!("**{}**\n\n{}", subject, notification.body);
+
+        let payload = json!({
+            "content": content,
+            "username": "BSPDS"
+        });
+
+        let mut last_error = None;
+        for attempt in 0..MAX_RETRIES {
+            let result = self
+                .http_client
+                .post(&self.webhook_url)
+                .json(&payload)
+                .send()
+                .await;
+
+            match result {
+                Ok(response) => {
+                    if response.status().is_success() {
+                        return Ok(());
+                    }
+
+                    let status = response.status();
+                    if is_retryable_status(status) && attempt < MAX_RETRIES - 1 {
+                        last_error = Some(format!("Discord webhook returned {}", status));
+                        retry_delay(attempt).await;
+                        continue;
+                    }
+
+                    let body = response.text().await.unwrap_or_default();
+                    return Err(SendError::ExternalService(format!(
+                        "Discord webhook returned {}: {}",
+                        status, body
+                    )));
+                }
+                Err(e) => {
+                    if e.is_timeout() {
+                        if attempt < MAX_RETRIES - 1 {
+                            last_error = Some(format!("Discord request timed out"));
+                            retry_delay(attempt).await;
+                            continue;
+                        }
+                        return Err(SendError::Timeout);
+                    }
+                    return Err(SendError::ExternalService(format!(
+                        "Discord request failed: {}",
+                        e
+                    )));
+                }
+            }
+        }
+
+        Err(SendError::MaxRetriesExceeded(
+            last_error.unwrap_or_else(|| "Unknown error".to_string()),
+        ))
+    }
+}
+
+pub struct TelegramSender {
+    bot_token: String,
+    http_client: Client,
+}
+
+impl TelegramSender {
+    pub fn new(bot_token: String) -> Self {
+        Self {
+            bot_token,
+            http_client: create_http_client(),
+        }
+    }
+
+    pub fn from_env() -> Option<Self> {
+        let bot_token = std::env::var("TELEGRAM_BOT_TOKEN").ok()?;
+        Some(Self::new(bot_token))
+    }
+}
+
+#[async_trait]
+impl NotificationSender for TelegramSender {
+    fn channel(&self) -> NotificationChannel {
+        NotificationChannel::Telegram
+    }
+
+    async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
+        let chat_id = &notification.recipient;
+        let subject = notification.subject.as_deref().unwrap_or("Notification");
+        let text = format!("*{}*\n\n{}", subject, notification.body);
+
+        let url = format!(
+            "https://api.telegram.org/bot{}/sendMessage",
+            self.bot_token
+        );
+
+        let payload = json!({
+            "chat_id": chat_id,
+            "text": text,
+            "parse_mode": "Markdown"
+        });
+
+        let mut last_error = None;
+        for attempt in 0..MAX_RETRIES {
+            let result = self
+                .http_client
+                .post(&url)
+                .json(&payload)
+                .send()
+                .await;
+
+            match result {
+                Ok(response) => {
+                    if response.status().is_success() {
+                        return Ok(());
+                    }
+
+                    let status = response.status();
+                    if is_retryable_status(status) && attempt < MAX_RETRIES - 1 {
+                        last_error = Some(format!("Telegram API returned {}", status));
+                        retry_delay(attempt).await;
+                        continue;
+                    }
+
+                    let body = response.text().await.unwrap_or_default();
+                    return Err(SendError::ExternalService(format!(
+                        "Telegram API returned {}: {}",
+                        status, body
+                    )));
+                }
+                Err(e) => {
+                    if e.is_timeout() {
+                        if attempt < MAX_RETRIES - 1 {
+                            last_error = Some(format!("Telegram request timed out"));
+                            retry_delay(attempt).await;
+                            continue;
+                        }
+                        return Err(SendError::Timeout);
+                    }
+                    return Err(SendError::ExternalService(format!(
+                        "Telegram request failed: {}",
+                        e
+                    )));
+                }
+            }
+        }
+
+        Err(SendError::MaxRetriesExceeded(
+            last_error.unwrap_or_else(|| "Unknown error".to_string()),
+        ))
+    }
+}
+
+pub struct SignalSender {
+    signal_cli_path: String,
+    sender_number: String,
+}
+
+impl SignalSender {
+    pub fn new(signal_cli_path: String, sender_number: String) -> Self {
+        Self {
+            signal_cli_path,
+            sender_number,
+        }
+    }
+
+    pub fn from_env() -> Option<Self> {
+        let signal_cli_path = std::env::var("SIGNAL_CLI_PATH")
+            .unwrap_or_else(|_| "/usr/local/bin/signal-cli".to_string());
+        let sender_number = std::env::var("SIGNAL_SENDER_NUMBER").ok()?;
+        Some(Self::new(signal_cli_path, sender_number))
+    }
+}
+
+#[async_trait]
+impl NotificationSender for SignalSender {
+    fn channel(&self) -> NotificationChannel {
+        NotificationChannel::Signal
+    }
+
+    async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
+        let recipient = &notification.recipient;
+
+        if !is_valid_phone_number(recipient) {
+            return Err(SendError::InvalidRecipient(format!(
+                "Invalid phone number format: {}",
+                recipient
+            )));
+        }
+
+        let subject = notification.subject.as_deref().unwrap_or("Notification");
+        let message = format!("{}\n\n{}", subject, notification.body);
+
+        let output = Command::new(&self.signal_cli_path)
+            .arg("-u")
+            .arg(&self.sender_number)
+            .arg("send")
+            .arg("-m")
+            .arg(&message)
+            .arg(recipient)
+            .output()
+            .await?;
+
+        if !output.status.success() {
+            let stderr = String::from_utf8_lossy(&output.stderr);
+            return Err(SendError::ExternalService(format!(
+                "signal-cli failed: {}",
+                stderr
+            )));
+        }
+
+        Ok(())
+    }
+}
diff --git a/src/notifications/service.rs b/src/notifications/service.rs
index aaf4027..0180bb1 100644
--- a/src/notifications/service.rs
+++ b/src/notifications/service.rs
@@ -443,3 +443,39 @@ pub async fn enqueue_plc_operation(
     )
     .await
 }
+
+pub async fn enqueue_2fa_code(
+    db: &PgPool,
+    user_id: Uuid,
+    code: &str,
+    hostname: &str,
+) -> Result<Uuid, sqlx::Error> {
+    let prefs = get_user_notification_prefs(db, user_id).await?;
+
+    let body = format!(
+        "Hello @{},\n\nYour sign-in verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.",
+        prefs.handle, code
+    );
+
+    enqueue_notification(
+        db,
+        NewNotification::new(
+            user_id,
+            prefs.channel,
+            super::types::NotificationType::TwoFactorCode,
+            prefs.email.clone(),
+            Some(format!("Sign-in Verification - {}", hostname)),
+            body,
+        ),
+    )
+    .await
+}
+
+pub fn channel_display_name(channel: NotificationChannel) -> &'static str {
+    match channel {
+        NotificationChannel::Email => "email",
+        NotificationChannel::Discord => "Discord",
+        NotificationChannel::Telegram => "Telegram",
+        NotificationChannel::Signal => "Signal",
+    }
+}
diff --git a/src/notifications/types.rs b/src/notifications/types.rs
index 9c56074..f8eacb3 100644
--- a/src/notifications/types.rs
+++ b/src/notifications/types.rs
@@ -31,6 +31,7 @@ pub enum NotificationType {
     AccountDeletion,
     AdminEmail,
     PlcOperation,
+    TwoFactorCode,
 }
 
 #[derive(Debug, Clone, FromRow)]
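A sketch of how these pieces are presumably chained by the OAuth authorize endpoints (which live elsewhere in this patch): a challenge row is created, then the code is pushed through the user's preferred notification channel.

    // Assumed orchestration; the authorize handlers are not shown in this excerpt.
    let challenge = create_2fa_challenge(&pool, &did, &request_uri).await?;
    enqueue_2fa_code(&pool, user_id, &challenge.code, &hostname).await?;
    // The user then submits the code to /oauth/authorize/2fa, where it is
    // checked against the stored challenge (with an attempt counter and a
    // 10-minute expiry) before the authorization request proceeds.
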
diff --git a/src/oauth/client.rs b/src/oauth/client.rs
index 6e2459a..fe42ea4 100644
--- a/src/oauth/client.rs
+++ b/src/oauth/client.rs
@@ -57,6 +57,7 @@ impl Default for ClientMetadata {
 #[derive(Clone)]
 pub struct ClientMetadataCache {
     cache: Arc<RwLock<HashMap<String, CachedMetadata>>>,
+    jwks_cache: Arc<RwLock<HashMap<String, CachedJwks>>>,
     http_client: Client,
     cache_ttl_secs: u64,
 }
@@ -66,11 +67,21 @@ struct CachedMetadata {
     cached_at: std::time::Instant,
 }
 
+struct CachedJwks {
+    jwks: serde_json::Value,
+    cached_at: std::time::Instant,
+}
+
 impl ClientMetadataCache {
     pub fn new(cache_ttl_secs: u64) -> Self {
         Self {
             cache: Arc::new(RwLock::new(HashMap::new())),
-            http_client: Client::new(),
+            jwks_cache: Arc::new(RwLock::new(HashMap::new())),
+            http_client: Client::builder()
+                .timeout(std::time::Duration::from_secs(30))
+                .connect_timeout(std::time::Duration::from_secs(10))
+                .build()
+                .unwrap_or_else(|_| Client::new()),
             cache_ttl_secs,
         }
     }
@@ -101,6 +112,84 @@
         Ok(metadata)
     }
 
+    pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> {
+        if let Some(jwks) = &metadata.jwks {
+            return Ok(jwks.clone());
+        }
+
+        let jwks_uri = metadata.jwks_uri.as_ref().ok_or_else(|| {
+            OAuthError::InvalidClient(
+                "Client using private_key_jwt must have jwks or jwks_uri".to_string(),
+            )
+        })?;
+
+        {
+            let cache = self.jwks_cache.read().await;
+            if let Some(cached) = cache.get(jwks_uri) {
+                if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
+                    return Ok(cached.jwks.clone());
+                }
+            }
+        }
+
+        let jwks = self.fetch_jwks(jwks_uri).await?;
+
+        {
+            let mut cache = self.jwks_cache.write().await;
+            cache.insert(
+                jwks_uri.clone(),
+                CachedJwks {
+                    jwks: jwks.clone(),
+                    cached_at: std::time::Instant::now(),
+                },
+            );
+        }
+
+        Ok(jwks)
+    }
+
+    async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> {
+        if !jwks_uri.starts_with("https://") {
+            if !jwks_uri.starts_with("http://")
+                || (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1"))
+            {
+                return Err(OAuthError::InvalidClient(
+                    "jwks_uri must use https (except for localhost)".to_string(),
+                ));
+            }
+        }
+
+        let response = self
+            .http_client
+            .get(jwks_uri)
+            .header("Accept", "application/json")
+            .send()
+            .await
+            .map_err(|e| {
+                OAuthError::InvalidClient(format!("Failed to fetch JWKS from {}: {}", jwks_uri, e))
+            })?;
+
+        if !response.status().is_success() {
+            return Err(OAuthError::InvalidClient(format!(
+                "Failed to fetch JWKS: HTTP {}",
+                response.status()
+            )));
+        }
+
+        let jwks: serde_json::Value = response
+            .json()
+            .await
+            .map_err(|e| OAuthError::InvalidClient(format!("Invalid JWKS JSON: {}", e)))?;
+
+        if jwks.get("keys").and_then(|k| k.as_array()).is_none() {
+            return Err(OAuthError::InvalidClient(
+                "JWKS must contain a 'keys' array".to_string(),
+            ));
+        }
+
+        Ok(jwks)
+    }
+
     async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
         if !client_id.starts_with("http://") && !client_id.starts_with("https://") {
             return Err(OAuthError::InvalidClient(
@@ -244,7 +333,8 @@ impl ClientMetadata {
     }
 }
 
-pub fn verify_client_auth(
+pub async fn verify_client_auth(
+    cache: &ClientMetadataCache,
     metadata: &ClientMetadata,
     client_auth: &super::ClientAuth,
 ) -> Result<(), OAuthError> {
@@ -258,7 +348,7 @@
         )),
 
         ("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => {
-            verify_private_key_jwt(metadata, client_assertion)
+            verify_private_key_jwt_async(cache, metadata, client_assertion).await
         }
 
         ("private_key_jwt", _) => Err(OAuthError::InvalidClient(
@@ -284,7 +374,8 @@
     }
 }
 
-fn verify_private_key_jwt(
+async fn verify_private_key_jwt_async(
+    cache: &ClientMetadataCache,
     metadata: &ClientMetadata,
     client_assertion: &str,
 ) -> Result<(), OAuthError> {
@@ -312,6 +403,8 @@
         )));
     }
 
+    let kid = header.get("kid").and_then(|k| k.as_str());
+
     let payload_bytes = URL_SAFE_NO_PAD
         .decode(parts[1])
         .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload encoding".to_string()))?;
@@ -353,13 +446,180 @@
         }
     }
 
-    if metadata.jwks.is_none() && metadata.jwks_uri.is_none() {
+    let jwks = cache.get_jwks(metadata).await?;
+    let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| {
+        OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string())
+    })?;
+
+    let matching_keys: Vec<&serde_json::Value> = if let Some(kid) = kid {
+        keys.iter()
+            .filter(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid))
+            .collect()
+    } else {
+        keys.iter().collect()
+    };
+
+    if matching_keys.is_empty() {
         return Err(OAuthError::InvalidClient(
-            "Client using private_key_jwt must have jwks or jwks_uri".to_string(),
+            "No matching key found in client JWKS".to_string(),
         ));
     }
 
+    let signing_input = format!("{}.{}", parts[0], parts[1]);
+    let signature_bytes = URL_SAFE_NO_PAD
+        .decode(parts[2])
+        .map_err(|_| OAuthError::InvalidClient("Invalid signature encoding".to_string()))?;
+
+    for key in matching_keys {
+        let key_alg = key.get("alg").and_then(|a| a.as_str());
+        if key_alg.is_some() && key_alg != Some(alg) {
+            continue;
+        }
+
+        let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or("");
+
+        let verified = match (alg, kty) {
+            ("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes),
+            ("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes),
+            ("RS256" | "RS384" | "RS512", "RSA") => {
+                verify_rsa(alg, key, &signing_input, &signature_bytes)
+            }
+            ("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes),
+            _ => continue,
+        };
+
+        if verified.is_ok() {
+            return Ok(());
+        }
+    }
+
     Err(OAuthError::InvalidClient(
-        "private_key_jwt signature verification not yet implemented - use 'none' auth method".to_string(),
+        "client_assertion signature verification failed".to_string(),
     ))
 }
+
+fn verify_es256(
+    key: &serde_json::Value,
+    signing_input: &str,
+    signature: &[u8],
+) -> Result<(), OAuthError> {
+    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
+    use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
+    use p256::EncodedPoint;
+
+    let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
+        OAuthError::InvalidClient("Missing x coordinate in EC key".to_string())
+    })?;
+    let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
+        OAuthError::InvalidClient("Missing y coordinate in EC key".to_string())
+    })?;
+
+    let x_bytes = URL_SAFE_NO_PAD.decode(x)
+        .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
+    let y_bytes = URL_SAFE_NO_PAD.decode(y)
+        .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
+
+    let mut point_bytes = vec![0x04];
+    point_bytes.extend_from_slice(&x_bytes);
+    point_bytes.extend_from_slice(&y_bytes);
+
+    let point = EncodedPoint::from_bytes(&point_bytes)
+        .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?;
+    let verifying_key = VerifyingKey::from_encoded_point(&point)
+        .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?;
+
+    let sig = Signature::from_slice(signature)
+        .map_err(|_| OAuthError::InvalidClient("Invalid ES256 signature format".to_string()))?;
+
+    verifying_key
+        .verify(signing_input.as_bytes(), &sig)
+        .map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string()))
+}
+
+fn verify_es384(
+    key: &serde_json::Value,
+    signing_input: &str,
+    signature: &[u8],
+) -> Result<(), OAuthError> {
+    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
+    use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier};
+    use p384::EncodedPoint;
+
+    let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
+        OAuthError::InvalidClient("Missing x coordinate in EC key".to_string())
+    })?;
let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
+        OAuthError::InvalidClient("Missing y coordinate in EC key".to_string())
+    })?;
+
+    let x_bytes = URL_SAFE_NO_PAD.decode(x)
+        .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
+    let y_bytes = URL_SAFE_NO_PAD.decode(y)
+        .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
+
+    let mut point_bytes = vec![0x04];
+    point_bytes.extend_from_slice(&x_bytes);
+    point_bytes.extend_from_slice(&y_bytes);
+
+    let point = EncodedPoint::from_bytes(&point_bytes)
+        .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?;
+    let verifying_key = VerifyingKey::from_encoded_point(&point)
+        .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?;
+
+    let sig = Signature::from_slice(signature)
+        .map_err(|_| OAuthError::InvalidClient("Invalid ES384 signature format".to_string()))?;
+
+    verifying_key
+        .verify(signing_input.as_bytes(), &sig)
+        .map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string()))
+}
+
+fn verify_rsa(
+    _alg: &str,
+    _key: &serde_json::Value,
+    _signing_input: &str,
+    _signature: &[u8],
+) -> Result<(), OAuthError> {
+    Err(OAuthError::InvalidClient(
+        "RSA signature verification not yet supported - use EC keys".to_string(),
+    ))
+}
+
+fn verify_eddsa(
+    key: &serde_json::Value,
+    signing_input: &str,
+    signature: &[u8],
+) -> Result<(), OAuthError> {
+    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
+    use ed25519_dalek::{Signature, Verifier, VerifyingKey};
+
+    let crv = key.get("crv").and_then(|c| c.as_str()).unwrap_or("");
+    if crv != "Ed25519" {
+        return Err(OAuthError::InvalidClient(format!(
+            "Unsupported EdDSA curve: {}",
+            crv
+        )));
+    }
+
+    let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
+        OAuthError::InvalidClient("Missing x in OKP key".to_string())
+    })?;
+
+    let x_bytes = URL_SAFE_NO_PAD.decode(x)
+        .map_err(|_| OAuthError::InvalidClient("Invalid x encoding".to_string()))?;
+
+    let key_bytes: [u8; 32] = x_bytes.try_into()
+        .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key length".to_string()))?;
+
+    let verifying_key = VerifyingKey::from_bytes(&key_bytes)
+        .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key".to_string()))?;
+
+    let sig_bytes: [u8; 64] = signature.try_into()
+        .map_err(|_| OAuthError::InvalidClient("Invalid EdDSA signature length".to_string()))?;
+
+    let sig = Signature::from_bytes(&sig_bytes);
+
+    verifying_key
+        .verify(signing_input.as_bytes(), &sig)
+        .map_err(|_| OAuthError::InvalidClient("EdDSA signature verification failed".to_string()))
+}
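For reference, a test-style sketch (not part of the patch) of how verify_es256 can be exercised end to end: generate a P-256 key, publish it the way a client would in its JWKS, sign a header.payload string, and verify. It assumes the p256 crate with its default ecdsa feature and an OS randomness source via rand_core's OsRng; the signing input is an arbitrary placeholder.

    #[cfg(test)]
    mod es256_sketch {
        use super::*;
        use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
        use p256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Signer};
        use p256::elliptic_curve::rand_core::OsRng;
        use p256::elliptic_curve::sec1::ToEncodedPoint;

        #[test]
        fn es256_roundtrip() {
            let sk = SigningKey::random(&mut OsRng);
            let point = VerifyingKey::from(&sk).to_encoded_point(false);

            // Build the public JWK as a client would publish it in its JWKS.
            let jwk = serde_json::json!({
                "kty": "EC",
                "crv": "P-256",
                "x": URL_SAFE_NO_PAD.encode(point.x().unwrap()),
                "y": URL_SAFE_NO_PAD.encode(point.y().unwrap()),
            });

            // Sign an arbitrary "header.payload" string, as in a client assertion.
            let signing_input = "header.payload";
            let sig: Signature = sk.sign(signing_input.as_bytes());

            assert!(verify_es256(&jwk, signing_input, sig.to_bytes().as_slice()).is_ok());
        }
    }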
diff --git a/src/oauth/db/device.rs b/src/oauth/db/device.rs
index cf5633e..c60c422 100644
--- a/src/oauth/db/device.rs
+++ b/src/oauth/db/device.rs
@@ -1,7 +1,15 @@
+use chrono::{DateTime, Utc};
 use sqlx::PgPool;
 
 use super::super::{DeviceData, OAuthError};
 
+pub struct DeviceAccountRow {
+    pub did: String,
+    pub handle: String,
+    pub email: String,
+    pub last_used_at: DateTime<Utc>,
+}
+
 pub async fn create_device(
     pool: &PgPool,
     device_id: &str,
@@ -94,3 +102,57 @@ pub async fn upsert_account_device(
 
     Ok(())
 }
+
+pub async fn get_device_accounts(
+    pool: &PgPool,
+    device_id: &str,
+) -> Result<Vec<DeviceAccountRow>, OAuthError> {
+    let rows = sqlx::query!(
+        r#"
+        SELECT u.did, u.handle, u.email, ad.updated_at as last_used_at
+        FROM oauth_account_device ad
+        JOIN users u ON u.did = ad.did
+        WHERE ad.device_id = $1
+          AND u.deactivated_at IS NULL
+          AND u.takedown_ref IS NULL
+        ORDER BY ad.updated_at DESC
+        "#,
+        device_id
+    )
+    .fetch_all(pool)
+    .await?;
+
+    Ok(rows
+        .into_iter()
+        .map(|r| DeviceAccountRow {
+            did: r.did,
+            handle: r.handle,
+            email: r.email,
+            last_used_at: r.last_used_at,
+        })
+        .collect())
+}
+
+pub async fn verify_account_on_device(
+    pool: &PgPool,
+    device_id: &str,
+    did: &str,
+) -> Result<bool, OAuthError> {
+    let row = sqlx::query!(
+        r#"
+        SELECT 1 as exists
+        FROM oauth_account_device ad
+        JOIN users u ON u.did = ad.did
+        WHERE ad.device_id = $1
+          AND ad.did = $2
+          AND u.deactivated_at IS NULL
+          AND u.takedown_ref IS NULL
+        "#,
+        device_id,
+        did
+    )
+    .fetch_optional(pool)
+    .await?;
+
+    Ok(row.is_some())
+}
diff --git a/src/oauth/db/mod.rs b/src/oauth/db/mod.rs
index c4c157f..c0f7619 100644
--- a/src/oauth/db/mod.rs
+++ b/src/oauth/db/mod.rs
@@ -4,10 +4,12 @@ mod dpop;
 mod helpers;
 mod request;
 mod token;
+mod two_factor;
 
 pub use client::{get_authorized_client, upsert_authorized_client};
 pub use device::{
-    create_device, delete_device, get_device, update_device_last_seen, upsert_account_device,
+    create_device, delete_device, get_device, get_device_accounts, update_device_last_seen,
+    upsert_account_device, verify_account_on_device, DeviceAccountRow,
 };
 pub use dpop::{check_and_record_dpop_jti, cleanup_expired_dpop_jtis};
 pub use request::{
@@ -20,3 +22,8 @@ pub use token::{
     delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id,
     get_token_by_refresh_token, list_tokens_for_user, rotate_token,
 };
+pub use two_factor::{
+    check_user_2fa_enabled, cleanup_expired_2fa_challenges, create_2fa_challenge,
+    delete_2fa_challenge, delete_2fa_challenge_by_request_uri, generate_2fa_code,
+    get_2fa_challenge, increment_2fa_attempts, TwoFactorChallenge,
+};
diff --git a/src/oauth/db/two_factor.rs b/src/oauth/db/two_factor.rs
new file mode 100644
index 0000000..89d3e00
--- /dev/null
+++ b/src/oauth/db/two_factor.rs
@@ -0,0 +1,153 @@
+use chrono::{DateTime, Duration, Utc};
+use rand::Rng;
+use sqlx::PgPool;
+use uuid::Uuid;
+
+use super::super::OAuthError;
+
+pub struct TwoFactorChallenge {
+    pub id: Uuid,
+    pub did: String,
+    pub request_uri: String,
+    pub code: String,
+    pub attempts: i32,
+    pub created_at: DateTime<Utc>,
+    pub expires_at: DateTime<Utc>,
+}
+
+pub fn generate_2fa_code() -> String {
+    let mut rng = rand::thread_rng();
+    let code: u32 = rng.gen_range(0..1_000_000);
+    format!("{:06}", code)
+}
+
+pub async fn create_2fa_challenge(
+    pool: &PgPool,
+    did: &str,
+    request_uri: &str,
+) -> Result<TwoFactorChallenge, OAuthError> {
+    let code = generate_2fa_code();
+    let expires_at = Utc::now() + Duration::minutes(10);
+
+    let row = sqlx::query!(
+        r#"
+        INSERT INTO oauth_2fa_challenge (did, request_uri, code, expires_at)
+        VALUES ($1, $2, $3, $4)
+        RETURNING id, did, request_uri, code, attempts, created_at, expires_at
+        "#,
+        did,
+        request_uri,
+        code,
+        expires_at,
+    )
+    .fetch_one(pool)
+    .await?;
+
+    Ok(TwoFactorChallenge {
+        id: row.id,
+        did: row.did,
+        request_uri: row.request_uri,
+        code: row.code,
+        attempts: row.attempts,
+        created_at: row.created_at,
+        expires_at: row.expires_at,
+    })
+}
+
+pub async fn get_2fa_challenge(
+    pool: &PgPool,
+    request_uri: &str,
+) -> Result<Option<TwoFactorChallenge>, OAuthError> {
+    let row = sqlx::query!(
+        r#"
+        SELECT id, did, request_uri, code, attempts, created_at, expires_at
+        FROM oauth_2fa_challenge
+        WHERE request_uri = $1
+        "#,
+        request_uri
+    )
+    .fetch_optional(pool)
+    .await?;
+
+    Ok(row.map(|r| TwoFactorChallenge {
+        id: r.id,
+        did: r.did,
+        request_uri: r.request_uri,
+        code: r.code,
+        attempts: r.attempts,
+        created_at: r.created_at,
+        expires_at: r.expires_at,
+    }))
+}
+
+pub async fn increment_2fa_attempts(pool: &PgPool, id: Uuid) -> Result<i32, OAuthError> {
+    let row = sqlx::query!(
+        r#"
+        UPDATE oauth_2fa_challenge
+        SET attempts = attempts + 1
+        WHERE id = $1
+        RETURNING attempts
+        "#,
+        id
+    )
+    .fetch_one(pool)
+    .await?;
+
+    Ok(row.attempts)
+}
+
+pub async fn delete_2fa_challenge(pool: &PgPool, id: Uuid) -> Result<(), OAuthError> {
+    sqlx::query!(
+        r#"
+        DELETE FROM oauth_2fa_challenge WHERE id = $1
+        "#,
+        id
+    )
+    .execute(pool)
+    .await?;
+
+    Ok(())
+}
+
+pub async fn delete_2fa_challenge_by_request_uri(
+    pool: &PgPool,
+    request_uri: &str,
+) -> Result<(), OAuthError> {
+    sqlx::query!(
+        r#"
+        DELETE FROM oauth_2fa_challenge WHERE request_uri = $1
+        "#,
+        request_uri
+    )
+    .execute(pool)
+    .await?;
+
+    Ok(())
+}
+
+pub async fn cleanup_expired_2fa_challenges(pool: &PgPool) -> Result<u64, OAuthError> {
+    let result = sqlx::query!(
+        r#"
+        DELETE FROM oauth_2fa_challenge WHERE expires_at < NOW()
+        "#
+    )
+    .execute(pool)
+    .await?;
+
+    Ok(result.rows_affected())
+}
+
+pub async fn check_user_2fa_enabled(pool: &PgPool, did: &str) -> Result<bool, OAuthError> {
+    let row = sqlx::query!(
+        r#"
+        SELECT two_factor_enabled
+        FROM users
+        WHERE did = $1
+        "#,
+        did
+    )
+    .fetch_optional(pool)
+    .await?;
+
+    Ok(row.map(|r| r.two_factor_enabled).unwrap_or(false))
+}
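A sketch (not in the patch) of how the challenge helpers above compose, assuming a live PgPool and the oauth_2fa_challenge table from the 2FA migration; the DID and request_uri values are placeholders.

    async fn two_factor_flow_sketch(pool: &sqlx::PgPool) -> Result<(), OAuthError> {
        let challenge = create_2fa_challenge(pool, "did:plc:example", "urn:example:request").await?;
        assert_eq!(challenge.code.len(), 6); // always a zero-padded six-digit code

        // A wrong guess burns one of the attempts the authorize endpoint allows.
        let attempts = increment_2fa_attempts(pool, challenge.id).await?;
        assert_eq!(attempts, 1);

        // On success (or expiry/exhaustion) the challenge row is removed.
        delete_2fa_challenge(pool, challenge.id).await?;
        Ok(())
    }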
diff --git a/src/oauth/endpoints/authorize.rs b/src/oauth/endpoints/authorize.rs
index af5b38a..39499bd 100644
--- a/src/oauth/endpoints/authorize.rs
+++ b/src/oauth/endpoints/authorize.rs
@@ -1,15 +1,34 @@
 use axum::{
     Form, Json,
     extract::{Query, State},
-    http::HeaderMap,
-    response::{IntoResponse, Redirect, Response},
+    http::{HeaderMap, header::SET_COOKIE},
+    response::{IntoResponse, Redirect, Response, Html},
 };
 use chrono::Utc;
 use serde::{Deserialize, Serialize};
+use subtle::ConstantTimeEq;
 use urlencoding::encode as url_encode;
 
 use crate::state::AppState;
-use crate::oauth::{Code, DeviceData, DeviceId, OAuthError, SessionId, db};
+use crate::oauth::{Code, DeviceAccount, DeviceData, DeviceId, OAuthError, SessionId, db, templates};
+use crate::notifications::{NotificationChannel, channel_display_name, enqueue_2fa_code};
+
+const DEVICE_COOKIE_NAME: &str = "oauth_device_id";
+
+fn extract_device_cookie(headers: &HeaderMap) -> Option<String> {
+    headers
+        .get("cookie")
+        .and_then(|v| v.to_str().ok())
+        .and_then(|cookie_str| {
+            for cookie in cookie_str.split(';') {
+                let cookie = cookie.trim();
+                if let Some(value) = cookie.strip_prefix(&format!("{}=", DEVICE_COOKIE_NAME)) {
+                    return Some(value.to_string());
+                }
+            }
+            None
+        })
+}
 
 fn extract_client_ip(headers: &HeaderMap) -> String {
     if let Some(forwarded) = headers.get("x-forwarded-for") {
@@ -36,10 +55,19 @@ fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
         .map(|s| s.to_string())
 }
 
+fn make_device_cookie(device_id: &str) -> String {
+    format!(
+        "{}={}; Path=/oauth; HttpOnly; Secure; SameSite=Lax; Max-Age=31536000",
+        DEVICE_COOKIE_NAME,
+        device_id
+    )
+}
+
 #[derive(Debug, Deserialize)]
 pub struct AuthorizeQuery {
     pub request_uri: Option<String>,
     pub client_id: Option<String>,
+    pub new_account: Option<bool>,
 }
 
 #[derive(Debug, Serialize)]
@@ -61,7 +89,156 @@ pub struct AuthorizeSubmit {
     pub remember_device: bool,
 }
 
+#[derive(Debug, Deserialize)]
+pub struct AuthorizeSelectSubmit {
+    pub request_uri: String,
+    pub did: String,
+}
+
+fn wants_json(headers: &HeaderMap) -> bool {
+    headers
+        .get("accept")
+        .and_then(|v| v.to_str().ok())
+        .map(|accept| accept.contains("application/json"))
+        .unwrap_or(false)
+}
+
 pub async fn authorize_get(
+    State(state): State<AppState>,
+    headers: HeaderMap,
+    Query(query): Query<AuthorizeQuery>,
+) -> Response {
+    let request_uri = match query.request_uri {
+        Some(uri) => uri,
+        None => {
+            if wants_json(&headers) {
+                return (
+                    axum::http::StatusCode::BAD_REQUEST,
+                    Json(serde_json::json!({
+                        "error": "invalid_request",
+                        "error_description": "Missing request_uri parameter. Use PAR to initiate authorization."
+                    })),
+                ).into_response();
+            }
+            return (
+                axum::http::StatusCode::BAD_REQUEST,
+                Html(templates::error_page(
+                    "invalid_request",
+                    Some("Missing request_uri parameter. Use PAR to initiate authorization."),
+                )),
+            ).into_response();
+        }
+    };
+
+    let request_data = match db::get_authorization_request(&state.db, &request_uri).await {
+        Ok(Some(data)) => data,
+        Ok(None) => {
+            if wants_json(&headers) {
+                return (
+                    axum::http::StatusCode::BAD_REQUEST,
+                    Json(serde_json::json!({
+                        "error": "invalid_request",
+                        "error_description": "Invalid or expired request_uri. Please start a new authorization request."
+                    })),
+                ).into_response();
+            }
+            return (
+                axum::http::StatusCode::BAD_REQUEST,
+                Html(templates::error_page(
+                    "invalid_request",
+                    Some("Invalid or expired request_uri. Please start a new authorization request."),
+                )),
+            ).into_response();
+        }
+        Err(e) => {
+            if wants_json(&headers) {
+                return (
+                    axum::http::StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": "server_error",
+                        "error_description": format!("Database error: {:?}", e)
+                    })),
+                ).into_response();
+            }
+            return (
+                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
+                Html(templates::error_page(
+                    "server_error",
+                    Some(&format!("Database error: {:?}", e)),
+                )),
+            ).into_response();
+        }
+    };
+
+    if request_data.expires_at < Utc::now() {
+        let _ = db::delete_authorization_request(&state.db, &request_uri).await;
+        if wants_json(&headers) {
+            return (
+                axum::http::StatusCode::BAD_REQUEST,
+                Json(serde_json::json!({
+                    "error": "invalid_request",
+                    "error_description": "Authorization request has expired. Please start a new request."
+                })),
+            ).into_response();
+        }
+        return (
+            axum::http::StatusCode::BAD_REQUEST,
+            Html(templates::error_page(
+                "invalid_request",
+                Some("Authorization request has expired. Please start a new request."),
+            )),
+        ).into_response();
+    }
+
+    if wants_json(&headers) {
+        return Json(AuthorizeResponse {
+            client_id: request_data.parameters.client_id.clone(),
+            client_name: None,
+            scope: request_data.parameters.scope.clone(),
+            redirect_uri: request_data.parameters.redirect_uri.clone(),
+            state: request_data.parameters.state.clone(),
+            login_hint: request_data.parameters.login_hint.clone(),
+        }).into_response();
+    }
+
+    let force_new_account = query.new_account.unwrap_or(false);
+
+    if !force_new_account {
+        if let Some(device_id) = extract_device_cookie(&headers) {
+            if let Ok(accounts) = db::get_device_accounts(&state.db, &device_id).await {
+                if !accounts.is_empty() {
+                    let device_accounts: Vec<DeviceAccount> = accounts
+                        .into_iter()
+                        .map(|row| DeviceAccount {
+                            did: row.did,
+                            handle: row.handle,
+                            email: row.email,
+                            last_used_at: row.last_used_at,
+                        })
+                        .collect();
+
+                    return Html(templates::account_selector_page(
+                        &request_data.parameters.client_id,
+                        None,
+                        &request_uri,
+                        &device_accounts,
+                    )).into_response();
+                }
+            }
+        }
+    }
+
+    Html(templates::login_page(
+        &request_data.parameters.client_id,
+        None,
+        request_data.parameters.scope.as_deref(),
+        &request_uri,
+        None,
+        request_data.parameters.login_hint.as_deref(),
+    )).into_response()
+}
+
+pub async fn authorize_get_json(
     State(state): State<AppState>,
     Query(query): Query<AuthorizeQuery>,
 ) -> Result<Json<AuthorizeResponse>, OAuthError> {
@@ -92,19 +269,85 @@ pub async fn authorize_post(
     State(state): State<AppState>,
     headers: HeaderMap,
     Form(form): Form<AuthorizeSubmit>,
-) -> Result<Response, OAuthError> {
-    let request_data = db::get_authorization_request(&state.db, &form.request_uri)
-        .await?
-        .ok_or_else(|| OAuthError::InvalidRequest("Invalid or expired request_uri".to_string()))?;
+) -> Response {
+    let json_response = wants_json(&headers);
+
+    let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
+        Ok(Some(data)) => data,
+        Ok(None) => {
+            if json_response {
+                return (
+                    axum::http::StatusCode::BAD_REQUEST,
+                    Json(serde_json::json!({
+                        "error": "invalid_request",
+                        "error_description": "Invalid or expired request_uri."
+                    })),
+                ).into_response();
+            }
+            return Html(templates::error_page(
+                "invalid_request",
+                Some("Invalid or expired request_uri. Please start a new authorization request."),
+            )).into_response();
+        }
+        Err(e) => {
+            if json_response {
+                return (
+                    axum::http::StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": "server_error",
+                        "error_description": format!("Database error: {:?}", e)
+                    })),
+                ).into_response();
+            }
+            return Html(templates::error_page(
+                "server_error",
+                Some(&format!("Database error: {:?}", e)),
+            )).into_response();
+        }
+    };
 
     if request_data.expires_at < Utc::now() {
-        db::delete_authorization_request(&state.db, &form.request_uri).await?;
-        return Err(OAuthError::InvalidRequest("request_uri has expired".to_string()));
+        let _ = db::delete_authorization_request(&state.db, &form.request_uri).await;
+        if json_response {
+            return (
+                axum::http::StatusCode::BAD_REQUEST,
+                Json(serde_json::json!({
+                    "error": "invalid_request",
+                    "error_description": "Authorization request has expired."
+                })),
+            ).into_response();
+        }
+        return Html(templates::error_page(
+            "invalid_request",
+            Some("Authorization request has expired. Please start a new request."),
+        )).into_response();
     }
 
-    let user = sqlx::query!(
+    let show_login_error = |error_msg: &str, json: bool| -> Response {
+        if json {
+            return (
+                axum::http::StatusCode::FORBIDDEN,
+                Json(serde_json::json!({
+                    "error": "access_denied",
+                    "error_description": error_msg
+                })),
+            ).into_response();
+        }
+        Html(templates::login_page(
+            &request_data.parameters.client_id,
+            None,
+            request_data.parameters.scope.as_deref(),
+            &form.request_uri,
+            Some(error_msg),
+            Some(&form.username),
+        )).into_response()
+    };
+
+    let user = match sqlx::query!(
         r#"
-        SELECT did, password_hash, deactivated_at, takedown_ref
+        SELECT id, did, email, password_hash, two_factor_enabled,
+               preferred_notification_channel as "preferred_notification_channel: NotificationChannel",
+               deactivated_at, takedown_ref
         FROM users
         WHERE handle = $1 OR email = $1
         "#,
     )
     .fetch_optional(&state.db)
     .await
-    .map_err(|e| OAuthError::ServerError(e.to_string()))?
-    .ok_or_else(|| OAuthError::AccessDenied("Invalid credentials".to_string()))?;
+    {
+        Ok(Some(u)) => u,
+        Ok(None) => return show_login_error("Invalid handle/email or password.", json_response),
+        Err(_) => return show_login_error("An error occurred. Please try again.", json_response),
+    };
 
     if user.deactivated_at.is_some() {
-        return Err(OAuthError::AccessDenied("Account is deactivated".to_string()));
+        return show_login_error("This account has been deactivated.", json_response);
     }
 
     if user.takedown_ref.is_some() {
-        return Err(OAuthError::AccessDenied("Account is taken down".to_string()));
+        return show_login_error("This account has been taken down.", json_response);
     }
 
-    let password_valid = bcrypt::verify(&form.password, &user.password_hash)
-        .map_err(|_| OAuthError::ServerError("Password verification failed".to_string()))?;
+    let password_valid = match bcrypt::verify(&form.password, &user.password_hash) {
+        Ok(valid) => valid,
+        Err(_) => return show_login_error("An error occurred. Please try again.", json_response),
+    };
 
     if !password_valid {
-        return Err(OAuthError::AccessDenied("Invalid credentials".to_string()));
+        return show_login_error("Invalid handle/email or password.", json_response);
+    }
+
+    if user.two_factor_enabled {
+        let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await;
+
+        match db::create_2fa_challenge(&state.db, &user.did, &form.request_uri).await {
+            Ok(challenge) => {
+                let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
+                if let Err(e) = enqueue_2fa_code(
+                    &state.db,
+                    user.id,
+                    &challenge.code,
+                    &hostname,
+                ).await {
+                    tracing::warn!(
+                        did = %user.did,
+                        error = %e,
+                        "Failed to enqueue 2FA notification"
+                    );
+                }
+
+                let channel_name = channel_display_name(user.preferred_notification_channel);
+                let redirect_url = format!(
+                    "/oauth/authorize/2fa?request_uri={}&channel={}",
+                    url_encode(&form.request_uri),
+                    url_encode(channel_name)
+                );
+                return Redirect::temporary(&redirect_url).into_response();
+            }
+            Err(_) => {
+                return show_login_error("An error occurred. Please try again.", json_response);
+            }
+        }
     }
 
     let code = Code::generate();
-    let mut device_id: Option<String> = None;
+    let mut device_id: Option<String> = extract_device_cookie(&headers);
+    let mut new_cookie: Option<String> = None;
 
     if form.remember_device {
-        let new_device_id = DeviceId::generate();
-        let device_data = DeviceData {
-            session_id: SessionId::generate().0,
-            user_agent: extract_user_agent(&headers),
-            ip_address: extract_client_ip(&headers),
-            last_seen_at: Utc::now(),
+        let final_device_id = if let Some(existing_id) = &device_id {
+            existing_id.clone()
+        } else {
+            let new_id = DeviceId::generate();
+            let device_data = DeviceData {
+                session_id: SessionId::generate().0,
+                user_agent: extract_user_agent(&headers),
+                ip_address: extract_client_ip(&headers),
+                last_seen_at: Utc::now(),
+            };
+
+            if db::create_device(&state.db, &new_id.0, &device_data).await.is_ok() {
+                new_cookie = Some(make_device_cookie(&new_id.0));
+                device_id = Some(new_id.0.clone());
+            }
+            new_id.0
         };
 
-        db::create_device(&state.db, &new_device_id.0, &device_data).await?;
-        db::upsert_account_device(&state.db, &user.did, &new_device_id.0).await?;
-        device_id = Some(new_device_id.0);
+        let _ = db::upsert_account_device(&state.db, &user.did, &final_device_id).await;
     }
 
-    db::update_authorization_request(
+    if let Err(_) = db::update_authorization_request(
         &state.db,
         &form.request_uri,
         &user.did,
         device_id.as_deref(),
         &code.0,
     )
-    .await?;
+    .await
+    {
+        return show_login_error("An error occurred. Please try again.", json_response);
+    }
 
-    let redirect_uri = &request_data.parameters.redirect_uri;
+    let redirect_url = build_success_redirect(
+        &request_data.parameters.redirect_uri,
+        &code.0,
+        request_data.parameters.state.as_deref(),
+    );
+
+    let redirect = Redirect::temporary(&redirect_url);
+
+    if let Some(cookie) = new_cookie {
+        ([(SET_COOKIE, cookie)], redirect).into_response()
+    } else {
+        redirect.into_response()
+    }
+}
+
+pub async fn authorize_select(
+    State(state): State<AppState>,
+    headers: HeaderMap,
+    Form(form): Form<AuthorizeSelectSubmit>,
+) -> Response {
+    let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
+        Ok(Some(data)) => data,
+        Ok(None) => {
+            return Html(templates::error_page(
+                "invalid_request",
+                Some("Invalid or expired request_uri. Please start a new authorization request."),
+            )).into_response();
+        }
+        Err(_) => {
+            return Html(templates::error_page(
+                "server_error",
+                Some("An error occurred. Please try again."),
+            )).into_response();
+        }
+    };
+
+    if request_data.expires_at < Utc::now() {
+        let _ = db::delete_authorization_request(&state.db, &form.request_uri).await;
+        return Html(templates::error_page(
+            "invalid_request",
+            Some("Authorization request has expired. Please start a new request."),
+        )).into_response();
+    }
+
+    let device_id = match extract_device_cookie(&headers) {
+        Some(id) => id,
+        None => {
+            return Html(templates::error_page(
+                "invalid_request",
+                Some("No device session found. Please sign in."),
+            )).into_response();
+        }
+    };
+
+    let account_valid = match db::verify_account_on_device(&state.db, &device_id, &form.did).await {
+        Ok(valid) => valid,
+        Err(_) => {
+            return Html(templates::error_page(
+                "server_error",
+                Some("An error occurred. Please try again."),
+            )).into_response();
+        }
+    };
+
+    if !account_valid {
+        return Html(templates::error_page(
+            "access_denied",
+            Some("This account is not available on this device. Please sign in."),
+        )).into_response();
+    }
+
+    let user = match sqlx::query!(
+        r#"
+        SELECT id, two_factor_enabled,
+               preferred_notification_channel as "preferred_notification_channel: NotificationChannel"
+        FROM users
+        WHERE did = $1
+        "#,
+        form.did
+    )
+    .fetch_optional(&state.db)
+    .await
+    {
+        Ok(Some(u)) => u,
+        Ok(None) => {
+            return Html(templates::error_page(
+                "access_denied",
+                Some("Account not found. Please sign in."),
+            )).into_response();
+        }
+        Err(_) => {
+            return Html(templates::error_page(
+                "server_error",
+                Some("An error occurred. Please try again."),
+            )).into_response();
+        }
+    };
+
+    if user.two_factor_enabled {
+        let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await;
+
+        match db::create_2fa_challenge(&state.db, &form.did, &form.request_uri).await {
+            Ok(challenge) => {
+                let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
+                if let Err(e) = enqueue_2fa_code(
+                    &state.db,
+                    user.id,
+                    &challenge.code,
+                    &hostname,
+                ).await {
+                    tracing::warn!(
+                        did = %form.did,
+                        error = %e,
+                        "Failed to enqueue 2FA notification"
+                    );
+                }
+
+                let channel_name = channel_display_name(user.preferred_notification_channel);
+                let redirect_url = format!(
+                    "/oauth/authorize/2fa?request_uri={}&channel={}",
+                    url_encode(&form.request_uri),
+                    url_encode(channel_name)
+                );
+                return Redirect::temporary(&redirect_url).into_response();
+            }
+            Err(_) => {
+                return Html(templates::error_page(
+                    "server_error",
+                    Some("An error occurred. Please try again."),
+                )).into_response();
+            }
+        }
+    }
+
+    let _ = db::upsert_account_device(&state.db, &form.did, &device_id).await;
+
+    let code = Code::generate();
+
+    if let Err(_) = db::update_authorization_request(
+        &state.db,
+        &form.request_uri,
+        &form.did,
+        Some(&device_id),
+        &code.0,
+    )
+    .await
+    {
+        return Html(templates::error_page(
+            "server_error",
+            Some("An error occurred. Please try again."),
+        )).into_response();
+    }
+
+    let redirect_url = build_success_redirect(
+        &request_data.parameters.redirect_uri,
+        &code.0,
+        request_data.parameters.state.as_deref(),
+    );
+
+    Redirect::temporary(&redirect_url).into_response()
+}
+
+fn build_success_redirect(redirect_uri: &str, code: &str, state: Option<&str>) -> String {
     let mut redirect_url = redirect_uri.to_string();
 
     let separator = if redirect_url.contains('?') { '&' } else { '?' };
     redirect_url.push(separator);
-    redirect_url.push_str(&format!("code={}", url_encode(&code.0)));
+    redirect_url.push_str(&format!("code={}", url_encode(code)));
 
-    if let Some(state) = &request_data.parameters.state {
-        redirect_url.push_str(&format!("&state={}", url_encode(state)));
+    if let Some(req_state) = state {
+        redirect_url.push_str(&format!("&state={}", url_encode(req_state)));
     }
 
     let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
     redirect_url.push_str(&format!("&iss={}", url_encode(&format!("https://{}", pds_hostname))));
 
-    Ok(Redirect::temporary(&redirect_url).into_response())
+    redirect_url
 }
 
 #[derive(Debug, Serialize)]
@@ -208,3 +663,191 @@ pub async fn authorize_deny(
 pub struct AuthorizeDenyForm {
     pub request_uri: String,
 }
+
+#[derive(Debug, Deserialize)]
+pub struct Authorize2faQuery {
+    pub request_uri: String,
+    pub channel: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Authorize2faSubmit {
+    pub request_uri: String,
+    pub code: String,
+}
+
+const MAX_2FA_ATTEMPTS: i32 = 5;
+
+pub async fn authorize_2fa_get(
+    State(state): State<AppState>,
+    Query(query): Query<Authorize2faQuery>,
+) -> Response {
+    let challenge = match db::get_2fa_challenge(&state.db, &query.request_uri).await {
+        Ok(Some(c)) => c,
+        Ok(None) => {
+            return Html(templates::error_page(
+                "invalid_request",
+                Some("No 2FA challenge found. Please start over."),
+            )).into_response();
+        }
+        Err(_) => {
+            return Html(templates::error_page(
+                "server_error",
+                Some("An error occurred. Please try again."),
+            )).into_response();
+        }
+    };
+
+    if challenge.expires_at < Utc::now() {
+        let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
+        return Html(templates::error_page(
+            "invalid_request",
+            Some("2FA code has expired. Please start over."),
+        )).into_response();
+    }
+
+    let _request_data = match db::get_authorization_request(&state.db, &query.request_uri).await {
+        Ok(Some(d)) => d,
+        Ok(None) => {
+            return Html(templates::error_page(
+                "invalid_request",
+                Some("Authorization request not found. Please start over."),
+            )).into_response();
+        }
+        Err(_) => {
+            return Html(templates::error_page(
+                "server_error",
+                Some("An error occurred. Please try again."),
+            )).into_response();
+        }
+    };
+
+    let channel = query.channel.as_deref().unwrap_or("email");
+
+    Html(templates::two_factor_page(
+        &query.request_uri,
+        channel,
+        None,
+    )).into_response()
+}
+
+pub async fn authorize_2fa_post(
+    State(state): State<AppState>,
+    headers: HeaderMap,
+    Form(form): Form<Authorize2faSubmit>,
+) -> Response {
+    let challenge = match db::get_2fa_challenge(&state.db, &form.request_uri).await {
+        Ok(Some(c)) => c,
+        Ok(None) => {
+            return Html(templates::error_page(
+                "invalid_request",
+                Some("No 2FA challenge found. Please start over."),
+            )).into_response();
+        }
+        Err(_) => {
+            return Html(templates::error_page(
+                "server_error",
+                Some("An error occurred. Please try again."),
+            )).into_response();
+        }
+    };
+
+    if challenge.expires_at < Utc::now() {
+        let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
+        return Html(templates::error_page(
+            "invalid_request",
+            Some("2FA code has expired. Please start over."),
+        )).into_response();
+    }
+
+    if challenge.attempts >= MAX_2FA_ATTEMPTS {
+        let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
+        return Html(templates::error_page(
+            "access_denied",
+            Some("Too many failed attempts. Please start over."),
+        )).into_response();
+    }
+
+    let code_valid: bool = form.code.trim().as_bytes().ct_eq(challenge.code.as_bytes()).into();
+
+    if !code_valid {
+        let _ = db::increment_2fa_attempts(&state.db, challenge.id).await;
+
+        let channel = match sqlx::query_scalar!(
+            r#"SELECT preferred_notification_channel as "channel: NotificationChannel" FROM users WHERE did = $1"#,
+            challenge.did
+        )
+        .fetch_optional(&state.db)
+        .await
+        {
+            Ok(Some(ch)) => channel_display_name(ch).to_string(),
+            Ok(None) | Err(_) => "email".to_string(),
+        };
+
+        let _request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
+            Ok(Some(d)) => d,
+            Ok(None) => {
+                return Html(templates::error_page(
+                    "invalid_request",
+                    Some("Authorization request not found. Please start over."),
+                )).into_response();
+            }
+            Err(_) => {
+                return Html(templates::error_page(
+                    "server_error",
+                    Some("An error occurred. Please try again."),
+                )).into_response();
+            }
+        };
+
+        return Html(templates::two_factor_page(
+            &form.request_uri,
+            &channel,
+            Some("Invalid verification code. Please try again."),
+        )).into_response();
+    }
+
+    let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
+
+    let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
+        Ok(Some(d)) => d,
+        Ok(None) => {
+            return Html(templates::error_page(
+                "invalid_request",
+                Some("Authorization request not found."),
+            )).into_response();
+        }
+        Err(_) => {
+            return Html(templates::error_page(
+                "server_error",
+                Some("An error occurred."),
+            )).into_response();
+        }
+    };
+
+    let code = Code::generate();
+    let device_id = extract_device_cookie(&headers);
+
+    if let Err(_) = db::update_authorization_request(
+        &state.db,
+        &form.request_uri,
+        &challenge.did,
+        device_id.as_deref(),
+        &code.0,
+    )
+    .await
+    {
+        return Html(templates::error_page(
+            "server_error",
+            Some("An error occurred. Please try again."),
+        )).into_response();
+    }
+
+    let redirect_url = build_success_redirect(
+        &request_data.parameters.redirect_uri,
+        &code.0,
+        request_data.parameters.state.as_deref(),
+    );
+
+    Redirect::temporary(&redirect_url).into_response()
+}
diff --git a/src/oauth/endpoints/token/grants.rs b/src/oauth/endpoints/token/grants.rs
index f451bd4..56a4f24 100644
--- a/src/oauth/endpoints/token/grants.rs
+++ b/src/oauth/endpoints/token/grants.rs
@@ -54,7 +54,7 @@ pub async fn handle_authorization_code_grant(
         .get(&auth_request.client_id)
         .await?;
     let client_auth = auth_request.client_auth.clone().unwrap_or(ClientAuth::None);
-    verify_client_auth(&client_metadata, &client_auth)?;
+    verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?;
 
     verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;
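For reference, the S256 rule that a PKCE check like verify_pkce enforces (per RFC 7636) is: the base64url encoding, without padding, of SHA-256(code_verifier) must equal the stored code_challenge. A standalone sketch, not the crate's implementation, assuming the sha2 and base64 crates:

    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
    use sha2::{Digest, Sha256};

    fn pkce_s256_matches(code_challenge: &str, code_verifier: &str) -> bool {
        // BASE64URL-ENCODE(SHA256(ASCII(code_verifier))) == code_challenge
        let digest = Sha256::digest(code_verifier.as_bytes());
        URL_SAFE_NO_PAD.encode(digest) == code_challenge
    }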
diff --git a/src/oauth/endpoints/token/mod.rs b/src/oauth/endpoints/token/mod.rs
index 0836730..2273cb7 100644
--- a/src/oauth/endpoints/token/mod.rs
+++ b/src/oauth/endpoints/token/mod.rs
@@ -19,11 +19,35 @@ pub use introspect::{
 };
 pub use types::{TokenRequest, TokenResponse};
 
+fn extract_client_ip(headers: &HeaderMap) -> String {
+    if let Some(forwarded) = headers.get("x-forwarded-for") {
+        if let Ok(value) = forwarded.to_str() {
+            if let Some(first_ip) = value.split(',').next() {
+                return first_ip.trim().to_string();
+            }
+        }
+    }
+    if let Some(real_ip) = headers.get("x-real-ip") {
+        if let Ok(value) = real_ip.to_str() {
+            return value.trim().to_string();
+        }
+    }
+    "unknown".to_string()
+}
+
 pub async fn token_endpoint(
     State(state): State<AppState>,
     headers: HeaderMap,
     Form(request): Form<TokenRequest>,
 ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
+    let client_ip = extract_client_ip(&headers);
+    if state.rate_limiters.oauth_token.check_key(&client_ip).is_err() {
+        tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
+        return Err(OAuthError::InvalidRequest(
+            "Too many requests. Please try again later.".to_string(),
+        ));
+    }
+
     let dpop_proof = headers
         .get("DPoP")
         .and_then(|v| v.to_str().ok())
diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs
index 6f2912a..59c9358 100644
--- a/src/oauth/mod.rs
+++ b/src/oauth/mod.rs
@@ -5,8 +5,10 @@ pub mod jwks;
 pub mod client;
 pub mod endpoints;
 pub mod error;
+pub mod templates;
 pub mod verify;
 
 pub use types::*;
 pub use error::OAuthError;
 pub use verify::{verify_oauth_access_token, generate_dpop_nonce, VerifyResult, OAuthUser, OAuthAuthError};
+pub use templates::{DeviceAccount, mask_email};
diff --git a/src/oauth/templates.rs b/src/oauth/templates.rs
new file mode 100644
index 0000000..b09a685
--- /dev/null
+++ b/src/oauth/templates.rs
@@ -0,0 +1,719 @@
+use chrono::{DateTime, Utc};
+
+fn base_styles() -> &'static str {
+    r#"
+    :root { --primary: #0085ff; --primary-hover: #0077e6; --primary-contrast: #ffffff; --primary-100: #dbeafe; --primary-400: #60a5fa; --primary-600-30: rgba(37, 99, 235, 0.3); --contrast-0: #ffffff; --contrast-25: #f8f9fa; --contrast-50: #f1f3f5; --contrast-100: #e9ecef; --contrast-200: #dee2e6; --contrast-300: #ced4da; --contrast-400: #adb5bd; --contrast-500: #6b7280; --contrast-600: #4b5563; --contrast-700: #374151; --contrast-800: #1f2937; --contrast-900: #111827; --error: #dc2626; --error-bg: #fef2f2; --success: #059669; --success-bg: #ecfdf5; }
+
+    @media (prefers-color-scheme: dark) { :root { --contrast-0: #111827; --contrast-25: #1f2937; --contrast-50: #374151; --contrast-100: #4b5563; --contrast-200: #6b7280; --contrast-300: #9ca3af; --contrast-400: #d1d5db; --contrast-500: #e5e7eb; --contrast-600: #f3f4f6; --contrast-700: #f9fafb; --contrast-800: #ffffff; --contrast-900: #ffffff; --error-bg: #451a1a; --success-bg: #064e3b; } }
+
+    * { box-sizing: border-box; margin: 0; padding: 0; }
+
+    body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif; background: var(--contrast-50); color: var(--contrast-900); min-height: 100vh; display: flex; align-items: center; justify-content: center; padding: 1rem; line-height: 1.5; }
+
+    .container { width: 100%; max-width: 400px; padding-top: 15vh; }
+
+    @media (max-width: 640px) { .container { padding-top: 2rem; } }
+
+    .card { background: var(--contrast-0); border: 1px solid var(--contrast-100); border-radius: 0.75rem; padding: 1.5rem; box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 8px 10px -6px rgba(0, 0, 0, 0.1); }
+
+    @media (prefers-color-scheme: dark) { .card { box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.4), 0 8px 10px -6px rgba(0, 0, 0, 0.3); } }
+
+    h1 { font-size: 1.5rem; font-weight: 600; color: var(--contrast-900); margin-bottom: 0.5rem; }
+
+    .subtitle { color: var(--contrast-500); font-size: 0.875rem; margin-bottom: 1.5rem; }
+    .subtitle strong { color: var(--contrast-700); }
+
+    .client-info { background: var(--contrast-25); border-radius: 0.5rem; padding: 1rem; margin-bottom: 1.5rem; }
+    .client-info .client-name { font-weight: 500; color: var(--contrast-900); display: block; margin-bottom: 0.25rem; }
+    .client-info .scope { color: var(--contrast-500); font-size: 0.875rem; }
+
+    .error-banner { background: var(--error-bg); color: var(--error); border-radius: 0.5rem; padding: 0.75rem 1rem; margin-bottom: 1rem; font-size: 0.875rem; }
+
+    .form-group { margin-bottom: 1.25rem; }
+
+ label { + display: block; + font-size: 0.875rem; + font-weight: 500; + color: var(--contrast-700); + margin-bottom: 0.375rem; + } + + input[type="text"], + input[type="email"], + input[type="password"] { + width: 100%; + padding: 0.625rem 0.875rem; + border: 2px solid var(--contrast-200); + border-radius: 0.375rem; + font-size: 1rem; + color: var(--contrast-900); + background: var(--contrast-0); + transition: border-color 0.15s, box-shadow 0.15s; + } + + input[type="text"]:focus, + input[type="email"]:focus, + input[type="password"]:focus { + outline: none; + border-color: var(--primary); + box-shadow: 0 0 0 3px var(--primary-600-30); + } + + input[type="text"]::placeholder, + input[type="email"]::placeholder, + input[type="password"]::placeholder { + color: var(--contrast-400); + } + + .checkbox-group { + display: flex; + align-items: center; + gap: 0.5rem; + margin-bottom: 1.5rem; + } + + .checkbox-group input[type="checkbox"] { + width: 1.125rem; + height: 1.125rem; + accent-color: var(--primary); + } + + .checkbox-group label { + margin-bottom: 0; + font-weight: normal; + color: var(--contrast-600); + cursor: pointer; + } + + .buttons { + display: flex; + gap: 0.75rem; + } + + .btn { + flex: 1; + padding: 0.625rem 1.25rem; + border-radius: 0.375rem; + font-size: 1rem; + font-weight: 500; + cursor: pointer; + transition: background-color 0.15s, transform 0.1s; + border: none; + text-align: center; + text-decoration: none; + display: inline-flex; + align-items: center; + justify-content: center; + } + + .btn:active { + transform: scale(0.98); + } + + .btn-primary { + background: var(--primary); + color: var(--primary-contrast); + } + + .btn-primary:hover { + background: var(--primary-hover); + } + + .btn-primary:disabled { + background: var(--primary-400); + cursor: not-allowed; + } + + .btn-secondary { + background: var(--contrast-500); + color: white; + } + + .btn-secondary:hover { + background: var(--contrast-600); + } + + .footer { + text-align: center; + margin-top: 1.5rem; + font-size: 0.75rem; + color: var(--contrast-400); + } + + .accounts { + display: flex; + flex-direction: column; + gap: 0.5rem; + margin-bottom: 1rem; + } + + .account-item { + display: flex; + align-items: center; + gap: 0.75rem; + width: 100%; + padding: 0.75rem; + background: var(--contrast-25); + border: 1px solid var(--contrast-100); + border-radius: 0.5rem; + cursor: pointer; + transition: background-color 0.15s, border-color 0.15s; + text-align: left; + } + + .account-item:hover { + background: var(--contrast-50); + border-color: var(--contrast-200); + } + + .avatar { + width: 2.5rem; + height: 2.5rem; + border-radius: 50%; + background: var(--primary); + color: var(--primary-contrast); + display: flex; + align-items: center; + justify-content: center; + font-weight: 600; + font-size: 0.875rem; + flex-shrink: 0; + } + + .account-info { + flex: 1; + min-width: 0; + } + + .account-info .handle { + display: block; + font-weight: 500; + color: var(--contrast-900); + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + } + + .account-info .email { + display: block; + font-size: 0.875rem; + color: var(--contrast-500); + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + } + + .chevron { + color: var(--contrast-400); + font-size: 1.25rem; + flex-shrink: 0; + } + + .divider { + height: 1px; + background: var(--contrast-100); + margin: 1rem 0; + } + + .link-button { + background: none; + border: none; + color: var(--primary); + cursor: pointer; + font-size: inherit; + 
padding: 0; text-decoration: underline; }
+
+    .link-button:hover { color: var(--primary-hover); }
+
+    .new-account-link { display: block; text-align: center; color: var(--primary); text-decoration: none; font-size: 0.875rem; }
+    .new-account-link:hover { text-decoration: underline; }
+
+    .help-text { text-align: center; margin-top: 1rem; font-size: 0.875rem; color: var(--contrast-500); }
+
+    .icon { font-size: 3rem; margin-bottom: 1rem; }
+
+    .error-code { background: var(--error-bg); color: var(--error); padding: 0.5rem 1rem; border-radius: 0.375rem; font-family: monospace; display: inline-block; margin-bottom: 1rem; }
+
+    .success-icon { width: 3rem; height: 3rem; border-radius: 50%; background: var(--success-bg); color: var(--success); display: flex; align-items: center; justify-content: center; font-size: 1.5rem; margin: 0 auto 1rem; }
+
+    .text-center { text-align: center; }
+
+    .code-input { letter-spacing: 0.5em; text-align: center; font-size: 1.5rem; font-family: monospace; }
+    "#
+}
+
+// NOTE: the HTML markup inside these template strings was lost when this patch
+// was extracted to text; the markup below is reconstructed from the surviving
+// text, CSS class names, and format! arguments. Form actions and field labels
+// are best-effort reconstructions.
+pub fn login_page(
+    client_id: &str,
+    client_name: Option<&str>,
+    scope: Option<&str>,
+    request_uri: &str,
+    error_message: Option<&str>,
+    login_hint: Option<&str>,
+) -> String {
+    let client_display = client_name.unwrap_or(client_id);
+    let scope_display = scope.unwrap_or("access your account");
+
+    let error_html = error_message
+        .map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg)))
+        .unwrap_or_default();
+
+    let login_hint_value = login_hint.unwrap_or("");
+
+    format!(
+        r#"<!DOCTYPE html>
+<html lang="en">
+<head>
+<meta charset="utf-8">
+<meta name="viewport" content="width=device-width, initial-scale=1">
+<title>Sign in</title>
+<style>{styles}</style>
+</head>
+<body>
+<div class="container">
+<div class="card">
+<h1>Sign in</h1>
+<p class="subtitle">to continue to <strong>{client_display}</strong></p>
+<div class="client-info">
+<span class="client-name">{client_display}</span>
+<span class="scope">wants to {scope_display}</span>
+</div>
+{error_html}
+<form method="post">
+<input type="hidden" name="request_uri" value="{request_uri}">
+<div class="form-group">
+<label for="username">Handle or email</label>
+<input type="text" id="username" name="username" value="{login_hint_value}" autocomplete="username" required autofocus>
+</div>
+<div class="form-group">
+<label for="password">Password</label>
+<input type="password" id="password" name="password" autocomplete="current-password" required>
+</div>
+<div class="checkbox-group">
+<input type="checkbox" id="remember_device" name="remember_device" value="true">
+<label for="remember_device">Remember this device</label>
+</div>
+<div class="buttons">
+<button type="submit" class="btn btn-primary">Sign in</button>
+</div>
+</form>
+</div>
+</div>
+</body>
+</html>"#,
+        styles = base_styles(),
+        client_display = html_escape(client_display),
+        scope_display = html_escape(scope_display),
+        request_uri = html_escape(request_uri),
+        error_html = error_html,
+        login_hint_value = html_escape(login_hint_value),
+    )
+}
+
+pub struct DeviceAccount {
+    pub did: String,
+    pub handle: String,
+    pub email: String,
+    pub last_used_at: DateTime<Utc>,
+}
+
+pub fn account_selector_page(
+    client_id: &str,
+    client_name: Option<&str>,
+    request_uri: &str,
+    accounts: &[DeviceAccount],
+) -> String {
+    let client_display = client_name.unwrap_or(client_id);
+
+    let accounts_html: String = accounts
+        .iter()
+        .map(|account| {
+            let initials = get_initials(&account.handle);
+            format!(
+                r#"<form method="post" action="/oauth/authorize/select">
+<input type="hidden" name="request_uri" value="{request_uri}">
+<input type="hidden" name="did" value="{did}">
+<button type="submit" class="account-item">
+<span class="avatar">{initials}</span>
+<span class="account-info"><span class="handle">{handle}</span><span class="email">{email}</span></span>
+<span class="chevron">&rsaquo;</span>
+</button>
+</form>"#,
+                request_uri = html_escape(request_uri),
+                did = html_escape(&account.did),
+                initials = html_escape(&initials),
+                handle = html_escape(&account.handle),
+                email = html_escape(&account.email),
+            )
+        })
+        .collect();
+
+    format!(
+        r#"<!DOCTYPE html>
+<html lang="en">
+<head>
+<meta charset="utf-8">
+<meta name="viewport" content="width=device-width, initial-scale=1">
+<title>Choose an account</title>
+<style>{styles}</style>
+</head>
+<body>
+<div class="container">
+<div class="card">
+<h1>Choose an account</h1>
+<p class="subtitle">to continue to <strong>{client_display}</strong></p>
+<div class="accounts">
+{accounts_html}
+</div>
+<div class="divider"></div>
+<a class="new-account-link" href="?request_uri={request_uri_encoded}&new_account=true">Use another account</a>
+</div>
+</div>
+</body>
+</html>"#,
+        styles = base_styles(),
+        client_display = html_escape(client_display),
+        accounts_html = accounts_html,
+        request_uri_encoded = urlencoding::encode(request_uri),
+    )
+}
+
+pub fn two_factor_page(
+    request_uri: &str,
+    channel: &str,
+    error_message: Option<&str>,
+) -> String {
+    let error_html = error_message
+        .map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg)))
+        .unwrap_or_default();
+
+    let (title, subtitle) = match channel {
+        "email" => ("Check your email", "We sent a verification code to your email"),
+        "Discord" => ("Check Discord", "We sent a verification code to your Discord"),
+        "Telegram" => ("Check Telegram", "We sent a verification code to your Telegram"),
+        "Signal" => ("Check Signal", "We sent a verification code to your Signal"),
+        _ => ("Check your messages", "We sent you a verification code"),
+    };
+
+    format!(
+        r#"<!DOCTYPE html>
+<html lang="en">
+<head>
+<meta charset="utf-8">
+<meta name="viewport" content="width=device-width, initial-scale=1">
+<title>Verify your identity</title>
+<style>{styles}</style>
+</head>
+<body>
+<div class="container">
+<div class="card">
+<h1>{title}</h1>
+<p class="subtitle">{subtitle}</p>
+{error_html}
+<form method="post">
+<input type="hidden" name="request_uri" value="{request_uri}">
+<div class="form-group">
+<label for="code">Verification code</label>
+<input type="text" id="code" name="code" class="code-input" inputmode="numeric" maxlength="6" required autofocus>
+</div>
+<div class="buttons">
+<button type="submit" class="btn btn-primary">Verify</button>
+</div>
+</form>
+<div class="help-text">Code expires in 10 minutes.</div>
+</div>
+</div>
+</body>
+</html>"#,
+        styles = base_styles(),
+        title = title,
+        subtitle = subtitle,
+        request_uri = html_escape(request_uri),
+        error_html = error_html,
+    )
+}
+
+pub fn error_page(error: &str, error_description: Option<&str>) -> String {
+    let description = error_description.unwrap_or("An error occurred during the authorization process.");
+
+    format!(
+        r#"<!DOCTYPE html>
+<html lang="en">
+<head>
+<meta charset="utf-8">
+<meta name="viewport" content="width=device-width, initial-scale=1">
+<title>Authorization Error</title>
+<style>{styles}</style>
+</head>
+<body>
+<div class="container">
+<div class="card text-center">
+<div class="icon">⚠️</div>
+<h1>Authorization Failed</h1>
+<div class="error-code">{error}</div>
+<p class="subtitle">{description}</p>
+</div>
+</div>
+</body>
+</html>"#,
+        styles = base_styles(),
+        error = html_escape(error),
+        description = html_escape(description),
+    )
+}
+
+pub fn success_page(client_name: Option<&str>) -> String {
+    let client_display = client_name.unwrap_or("The application");
+
+    format!(
+        r#"<!DOCTYPE html>
+<html lang="en">
+<head>
+<meta charset="utf-8">
+<meta name="viewport" content="width=device-width, initial-scale=1">
+<title>Authorization Successful</title>
+<style>{styles}</style>
+</head>
+<body>
+<div class="container">
+<div class="card text-center">
+<div class="success-icon">✓</div>
+<h1>Authorization Successful</h1>
+<p class="subtitle">{client_display} has been granted access to your account.</p>
+<p class="help-text">You can close this window and return to the application.</p>
+</div>
+</div>
+</body>
+</html>"#,
+        styles = base_styles(),
+        client_display = html_escape(client_display),
+    )
+}
+
+fn html_escape(s: &str) -> String {
+    s.replace('&', "&amp;")
+        .replace('<', "&lt;")
+        .replace('>', "&gt;")
+        .replace('"', "&quot;")
+        .replace('\'', "&#39;")
+}
+
+fn get_initials(handle: &str) -> String {
+    let clean = handle.trim_start_matches('@');
+    if clean.is_empty() {
+        return "?".to_string();
+    }
+    clean.chars().next().unwrap_or('?').to_uppercase().to_string()
+}
+
+pub fn mask_email(email: &str) -> String {
+    if let Some(at_pos) = email.find('@') {
+        let local = &email[..at_pos];
+        let domain = &email[at_pos..];
+
+        if local.len() <= 2 {
+            format!("{}***{}", local.chars().next().unwrap_or('*'), domain)
+        } else {
+            let first = local.chars().next().unwrap_or('*');
+            let last = local.chars().last().unwrap_or('*');
+            format!("{}***{}{}", first, last, domain)
+        }
+    } else {
+        "***".to_string()
+    }
+}
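The masking and initials helpers are easy to pin down with a small test; the expected strings below follow directly from the code above.

    #[cfg(test)]
    mod template_helper_tests {
        use super::*;

        #[test]
        fn masks_emails_and_derives_initials() {
            // Long local part: keep first and last characters.
            assert_eq!(mask_email("alice@example.com"), "a***e@example.com");
            // Short local part: keep only the first character.
            assert_eq!(mask_email("al@example.com"), "a***@example.com");
            // No '@' at all: mask everything.
            assert_eq!(mask_email("not-an-email"), "***");
            // Leading '@' is stripped before taking the initial.
            assert_eq!(get_initials("@alice.example.com"), "A");
        }
    }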
diff --git a/src/plc/mod.rs b/src/plc/mod.rs
index 472e846..07206dc 100644
--- a/src/plc/mod.rs
+++ b/src/plc/mod.rs
@@ -319,6 +319,164 @@ pub fn validate_plc_operation(op: &Value) -> Result<(), PlcError> {
     Ok(())
 }
 
+pub struct PlcValidationContext {
+    pub server_rotation_key: String,
+    pub expected_signing_key: String,
+    pub expected_handle: String,
+    pub expected_pds_endpoint: String,
+}
+
+pub fn validate_plc_operation_for_submission(
+    op: &Value,
+    ctx: &PlcValidationContext,
+) -> Result<(), PlcError> {
+    validate_plc_operation(op)?;
+
+    let obj = op.as_object()
+        .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
+
+    let op_type = obj.get("type")
+        .and_then(|v| v.as_str())
+        .unwrap_or("");
+
+    if op_type != "plc_operation" {
+        return Ok(());
+    }
+
+    let rotation_keys = obj.get("rotationKeys")
+        .and_then(|v| v.as_array())
+        .ok_or_else(|| PlcError::InvalidResponse("rotationKeys must be an array".to_string()))?;
+
+    let rotation_key_strings: Vec<&str> = rotation_keys
+        .iter()
+        .filter_map(|v| v.as_str())
+        .collect();
+
+    if !rotation_key_strings.contains(&ctx.server_rotation_key.as_str()) {
+        return Err(PlcError::InvalidResponse(
+            "Rotation keys do not include server's rotation key".to_string()
+        ));
+    }
+
+    let verification_methods = obj.get("verificationMethods")
+        .and_then(|v| v.as_object())
+        .ok_or_else(|| PlcError::InvalidResponse("verificationMethods must be an object".to_string()))?;
+
+    if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) {
+        if atproto_key != ctx.expected_signing_key {
+            return Err(PlcError::InvalidResponse("Incorrect signing key".to_string()));
+        }
+    }
+
+    let also_known_as = obj.get("alsoKnownAs")
+        .and_then(|v| v.as_array())
+        .ok_or_else(|| PlcError::InvalidResponse("alsoKnownAs must be an array".to_string()))?;
+
+    let expected_handle_uri = format!("at://{}", ctx.expected_handle);
+    let has_correct_handle = also_known_as
+        .iter()
+        .filter_map(|v| v.as_str())
+        .any(|s| s == expected_handle_uri);
+
+    if !has_correct_handle && !also_known_as.is_empty() {
+        return Err(PlcError::InvalidResponse(
+            "Incorrect handle in alsoKnownAs".to_string()
+        ));
+    }
+
+    let services = obj.get("services")
+        .and_then(|v| v.as_object())
+        .ok_or_else(|| PlcError::InvalidResponse("services must be an object".to_string()))?;
+
+    if let Some(pds_service) = services.get("atproto_pds").and_then(|v| v.as_object()) {
+        let service_type = pds_service.get("type").and_then(|v| v.as_str()).unwrap_or("");
+        if service_type != "AtprotoPersonalDataServer" {
+            return Err(PlcError::InvalidResponse(
+                "Incorrect type on atproto_pds service".to_string()
+            ));
+        }
+
+        let endpoint = pds_service.get("endpoint").and_then(|v| v.as_str()).unwrap_or("");
+        if endpoint != ctx.expected_pds_endpoint {
+            return Err(PlcError::InvalidResponse(
+                "Incorrect endpoint on atproto_pds service".to_string()
+            ));
+        }
+    }
+
+    Ok(())
+}
+
+pub fn verify_operation_signature(
+    op: &Value,
+    rotation_keys: &[String],
+) -> Result<bool, PlcError> {
+    let obj = op.as_object()
+        .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
+
+    let sig_b64 = obj.get("sig")
+        .and_then(|v| v.as_str())
+        .ok_or_else(|| PlcError::InvalidResponse("Missing sig".to_string()))?;
+
+    let sig_bytes = URL_SAFE_NO_PAD
+        .decode(sig_b64)
+        .map_err(|e| PlcError::InvalidResponse(format!("Invalid signature encoding: {}", e)))?;
+
+    let signature = Signature::from_slice(&sig_bytes)
+        .map_err(|e| PlcError::InvalidResponse(format!("Invalid signature format: {}", e)))?;
+
+    let mut unsigned_op = op.clone();
+    if let Some(unsigned_obj) = unsigned_op.as_object_mut() {
+        unsigned_obj.remove("sig");
+    }
+
+    let cbor_bytes = serde_ipld_dagcbor::to_vec(&unsigned_op)
+        .map_err(|e| PlcError::Serialization(e.to_string()))?;
+
+    for key_did in rotation_keys {
+        if let Ok(true) = verify_signature_with_did_key(key_did, &cbor_bytes, &signature) {
+            return Ok(true);
+        }
+    }
+
+    Ok(false)
+}
+
+fn verify_signature_with_did_key(
+    did_key: &str,
+    message: &[u8],
+    signature: &Signature,
+) -> Result<bool, PlcError> {
+    use k256::ecdsa::{VerifyingKey, signature::Verifier};
+
+    if !did_key.starts_with("did:key:z") {
+        return Err(PlcError::InvalidResponse("Invalid did:key format".to_string()));
+    }
+
+    let multibase_part = &did_key[8..];
+    let (_, decoded) = multibase::decode(multibase_part)
+        .map_err(|e| PlcError::InvalidResponse(format!("Failed to decode did:key: {}", e)))?;
+
+    if decoded.len() < 2 {
+        return Err(PlcError::InvalidResponse("Invalid did:key data".to_string()));
+    }
+
+    let (codec, key_bytes) = if decoded[0] == 0xe7 && decoded[1] == 0x01 {
+        (0xe701u16, &decoded[2..])
+    } else {
+        return Err(PlcError::InvalidResponse("Unsupported key type in did:key".to_string()));
+    };
+
+    if codec != 0xe701 {
+        return Err(PlcError::InvalidResponse("Only secp256k1 keys are supported".to_string()));
+    }
+
+    let verifying_key = VerifyingKey::from_sec1_bytes(key_bytes)
+        .map_err(|e| PlcError::InvalidResponse(format!("Invalid public key: {}", e)))?;
+
+    Ok(verifying_key.verify(message, signature).is_ok())
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/src/rate_limit.rs b/src/rate_limit.rs
new file mode 100644
index 0000000..ef196b3
--- /dev/null
+++ b/src/rate_limit.rs
@@ -0,0 +1,216 @@
+use axum::{
+    body::Body,
+    extract::ConnectInfo,
+    http::{HeaderMap, Request, StatusCode},
+    middleware::Next,
+    response::{IntoResponse, Response},
+    Json,
+};
+use governor::{
+    Quota, RateLimiter,
+    clock::DefaultClock,
+    state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
+};
+use std::{
+    net::SocketAddr,
+    num::NonZeroU32,
+    sync::Arc,
+};
+
+pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
+pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
+
+#[derive(Clone)]
+pub struct RateLimiters {
+    pub login: Arc<KeyedRateLimiter>,
+    pub oauth_token: Arc<KeyedRateLimiter>,
+    pub password_reset: Arc<KeyedRateLimiter>,
+    pub account_creation: Arc<KeyedRateLimiter>,
+}
+
+impl Default for RateLimiters {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl RateLimiters {
+    pub fn new() -> Self {
+        Self {
+            login: Arc::new(RateLimiter::keyed(
+                Quota::per_minute(NonZeroU32::new(10).unwrap())
+            )),
+            oauth_token: Arc::new(RateLimiter::keyed(
+                Quota::per_minute(NonZeroU32::new(30).unwrap())
+            )),
+            password_reset: Arc::new(RateLimiter::keyed(
+                Quota::per_hour(NonZeroU32::new(5).unwrap())
+            )),
+            account_creation: Arc::new(RateLimiter::keyed(
+                Quota::per_hour(NonZeroU32::new(10).unwrap())
+            )),
+        }
+    }
+
+    pub fn with_login_limit(mut self, per_minute: u32) -> Self {
+        self.login = Arc::new(RateLimiter::keyed(
+            Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
+        ));
+        self
+    }
+
+    pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
+        self.oauth_token = Arc::new(RateLimiter::keyed(
+            Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()))
+        ));
+        self
+    }
+
+    pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
+        self.password_reset = Arc::new(RateLimiter::keyed(
+            Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
+        ));
+        self
+    }
+
+    pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
+        self.account_creation = Arc::new(RateLimiter::keyed(
+            Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()))
+        ));
+        self
+    }
+}
+
+fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
+    if let Some(forwarded) = headers.get("x-forwarded-for") {
+        if let Ok(value) = forwarded.to_str() {
+            if let Some(first_ip) = value.split(',').next() {
+                return first_ip.trim().to_string();
+            }
+        }
+    }
+
+    if let Some(real_ip) = headers.get("x-real-ip") {
+        if let Ok(value) = real_ip.to_str() {
+            return value.trim().to_string();
+        }
+    }
+
+    addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string())
+}
+
+fn rate_limit_response() -> Response {
+    (
+        StatusCode::TOO_MANY_REQUESTS,
+        Json(serde_json::json!({
+            "error": "RateLimitExceeded",
+            "message": "Too many requests. Please try again later."
+        })),
+    )
+    .into_response()
+}
+
+pub async fn login_rate_limit(
+    ConnectInfo(addr): ConnectInfo<SocketAddr>,
+    axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
+    request: Request<Body>,
+    next: Next,
+) -> Response {
+    let client_ip = extract_client_ip(request.headers(), Some(addr));
+
+    if limiters.login.check_key(&client_ip).is_err() {
+        tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
+        return rate_limit_response();
+    }
+
+    next.run(request).await
+}
+
+pub async fn oauth_token_rate_limit(
+    ConnectInfo(addr): ConnectInfo<SocketAddr>,
+    axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
+    request: Request<Body>,
+    next: Next,
+) -> Response {
+    let client_ip = extract_client_ip(request.headers(), Some(addr));
+
+    if limiters.oauth_token.check_key(&client_ip).is_err() {
+        tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
+        return rate_limit_response();
+    }
+
+    next.run(request).await
+}
+
+pub async fn password_reset_rate_limit(
+    ConnectInfo(addr): ConnectInfo<SocketAddr>,
+    axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
+    request: Request<Body>,
+    next: Next,
+) -> Response {
+    let client_ip = extract_client_ip(request.headers(), Some(addr));
+
+    if limiters.password_reset.check_key(&client_ip).is_err() {
+        tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
+        return rate_limit_response();
+    }
+
+    next.run(request).await
+}
+
+pub async fn account_creation_rate_limit(
+    ConnectInfo(addr): ConnectInfo<SocketAddr>,
+    axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
+    request: Request<Body>,
+    next: Next,
+) -> Response {
+    let client_ip = extract_client_ip(request.headers(), Some(addr));
+
+    if limiters.account_creation.check_key(&client_ip).is_err() {
+        tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
+        return rate_limit_response();
+    }
+
+    next.run(request).await
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_rate_limiters_creation() {
+        let limiters = RateLimiters::new();
+        assert!(limiters.login.check_key(&"test".to_string()).is_ok());
+    }
+
+    #[test]
+    fn test_rate_limiter_exhaustion() {
+        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap()));
+        let key = "test_ip".to_string();
+
+        assert!(limiter.check_key(&key).is_ok());
+        assert!(limiter.check_key(&key).is_ok());
+        assert!(limiter.check_key(&key).is_err());
+    }
+
+    #[test]
+    fn test_different_keys_have_separate_limits() {
+        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap()));
+
+        assert!(limiter.check_key(&"ip1".to_string()).is_ok());
+        assert!(limiter.check_key(&"ip1".to_string()).is_err());
+        assert!(limiter.check_key(&"ip2".to_string()).is_ok());
+    }
+
+    #[test]
+    fn test_builder_pattern() {
+        let limiters = RateLimiters::new()
+            .with_login_limit(20)
+            .with_oauth_token_limit(60)
+            .with_password_reset_limit(3)
+            .with_account_creation_limit(5);
+
+        assert!(limiters.login.check_key(&"test".to_string()).is_ok());
+    }
+}
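A sketch of how these middleware functions attach to a router; the route path and handler are placeholders (the real wiring lives in main.rs and is not shown in this patch), and ConnectInfo is only populated when the app is served with into_make_service_with_connect_info.

    use axum::{middleware, routing::post, Router};
    use std::net::SocketAddr;
    use std::sync::Arc;

    fn guarded_login_router(limiters: Arc<RateLimiters>) -> Router {
        Router::new()
            // Placeholder handler; the real login handler lives elsewhere in the crate.
            .route("/login", post(|| async { "ok" }))
            .route_layer(middleware::from_fn_with_state(limiters, login_rate_limit))
    }

    // Serve with ConnectInfo so the middleware can see the peer address:
    // axum::serve(listener, guarded_login_router(limiters)
    //     .into_make_service_with_connect_info::<SocketAddr>()).await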
diff --git a/src/state.rs b/src/state.rs
index c4846ec..0b366ec 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -1,4 +1,6 @@
+use crate::circuit_breaker::CircuitBreakers;
 use crate::config::AuthConfig;
+use crate::rate_limit::RateLimiters;
 use crate::repo::PostgresBlockStore;
 use crate::storage::{BlobStorage, S3BlobStorage};
 use crate::sync::firehose::SequencedEvent;
@@ -12,6 +14,8 @@ pub struct AppState {
     pub block_store: PostgresBlockStore,
     pub blob_store: Arc<dyn BlobStorage>,
     pub firehose_tx: broadcast::Sender<SequencedEvent>,
+    pub rate_limiters: Arc<RateLimiters>,
+    pub circuit_breakers: Arc<CircuitBreakers>,
 }
 
@@ -21,11 +25,25 @@ impl AppState {
         let block_store = PostgresBlockStore::new(db.clone());
         let blob_store = S3BlobStorage::new().await;
         let (firehose_tx, _) = broadcast::channel(1000);
+        let rate_limiters = Arc::new(RateLimiters::new());
+        let circuit_breakers = Arc::new(CircuitBreakers::new());
 
         Self {
             db,
             block_store,
             blob_store: Arc::new(blob_store),
             firehose_tx,
+            rate_limiters,
+            circuit_breakers,
         }
     }
+
+    pub fn with_rate_limiters(mut self, rate_limiters: RateLimiters) -> Self {
+        self.rate_limiters = Arc::new(rate_limiters);
+        self
+    }
+
+    pub fn with_circuit_breakers(mut self, circuit_breakers: CircuitBreakers) -> Self {
+        self.circuit_breakers = Arc::new(circuit_breakers);
+        self
+    }
 }
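A sketch of composing AppState at startup with non-default limits; it assumes a connected PgPool and mirrors the builder methods added above.

    async fn build_state(db: sqlx::PgPool) -> crate::state::AppState {
        crate::state::AppState::new(db)
            .await
            // Loosen the login limit, keep everything else at the defaults.
            .with_rate_limiters(crate::rate_limit::RateLimiters::new().with_login_limit(20))
            .with_circuit_breakers(crate::circuit_breaker::CircuitBreakers::new())
    }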
+    .fetch_optional(&state.db)
+    .await
+    .unwrap_or(None);
+
+    let head_str = match repo_row {
+        Some(r) => r.repo_root_cid,
+        None => {
+            let user_exists = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
+                .fetch_optional(&state.db)
+                .await
+                .unwrap_or(None);
+
+            if user_exists.is_none() {
+                return (
+                    StatusCode::NOT_FOUND,
+                    Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
+                )
+                    .into_response();
+            } else {
+                return (
+                    StatusCode::NOT_FOUND,
+                    Json(json!({"error": "RepoNotFound", "message": "Repo not initialized"})),
+                )
+                    .into_response();
+            }
+        }
+    };
+
+    let head_cid = match Cid::from_str(&head_str) {
+        Ok(c) => c,
+        Err(_) => {
+            return (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(json!({"error": "InternalError", "message": "Invalid head CID"})),
+            )
+                .into_response();
+        }
+    };
+
+    let mut car_bytes = match encode_car_header(&head_cid) {
+        Ok(h) => h,
+        Err(e) => {
+            return (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(json!({"error": "InternalError", "message": format!("Failed to encode CAR header: {}", e)})),
+            )
+                .into_response();
+        }
+    };
+
+    // Bounded DFS over the repo's IPLD graph, capped by MAX_REPO_BLOCKS_TRAVERSAL.
+    let mut stack = vec![head_cid];
+    let mut visited = std::collections::HashSet::new();
+    let mut remaining = MAX_REPO_BLOCKS_TRAVERSAL;
+
+    while let Some(cid) = stack.pop() {
+        if visited.contains(&cid) {
+            continue;
+        }
+        visited.insert(cid);
+        if remaining == 0 {
+            break;
+        }
+        remaining -= 1;
+
+        if let Ok(Some(block)) = state.block_store.get(&cid).await {
+            // CAR section framing: varint(len(cid) + len(block)) || cid || block.
+            let cid_bytes = cid.to_bytes();
+            let total_len = cid_bytes.len() + block.len();
+            let mut writer = Vec::new();
+            crate::sync::car::write_varint(&mut writer, total_len as u64)
+                .expect("Writing to Vec should never fail");
+            writer.write_all(&cid_bytes)
+                .expect("Writing to Vec should never fail");
+            writer.write_all(&block)
+                .expect("Writing to Vec should never fail");
+            car_bytes.extend_from_slice(&writer);
+
+            if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) {
+                extract_links_ipld(&value, &mut stack);
+            }
+        }
+    }
+
+    (
+        StatusCode::OK,
+        [(axum::http::header::CONTENT_TYPE, "application/vnd.ipld.car")],
+        car_bytes,
+    )
+        .into_response()
+}
+
+fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) {
+    match value {
+        Ipld::Link(cid) => {
+            stack.push(*cid);
+        }
+        Ipld::Map(map) => {
+            for v in map.values() {
+                extract_links_ipld(v, stack);
+            }
+        }
+        Ipld::List(arr) => {
+            for v in arr {
+                extract_links_ipld(v, stack);
+            }
+        }
+        _ => {}
+    }
+}
diff --git a/src/sync/mod.rs b/src/sync/mod.rs
index 2b93ba0..a0c2b43 100644
--- a/src/sync/mod.rs
+++ b/src/sync/mod.rs
@@ -2,11 +2,11 @@ pub mod blob;
 pub mod car;
 pub mod commit;
 pub mod crawl;
+pub mod deprecated;
 pub mod firehose;
 pub mod frame;
 pub mod import;
 pub mod listener;
-pub mod relay_client;
 pub mod repo;
 pub mod subscribe_repos;
 pub mod util;
@@ -15,6 +15,7 @@ pub mod verify;
 pub use blob::{get_blob, list_blobs};
 pub use commit::{get_latest_commit, get_repo_status, list_repos};
 pub use crawl::{notify_of_update, request_crawl};
-pub use repo::{get_blocks, get_repo, get_record};
+pub use deprecated::{get_checkout, get_head};
+pub use repo::{get_blocks, get_record, get_repo};
 pub use subscribe_repos::subscribe_repos;
 pub use verify::{CarVerifier, VerifiedCar, VerifyError};
diff --git a/src/sync/relay_client.rs b/src/sync/relay_client.rs
deleted file mode 100644
index e37956e..0000000
--- a/src/sync/relay_client.rs
+++ /dev/null
@@ -1,83 +0,0 @@
-use crate::state::AppState;
-use crate::sync::util::format_event_for_sending;
-use futures::{sink::SinkExt, stream::StreamExt};
-use std::time::Duration;
-use tokio::sync::mpsc;
-use tokio_tungstenite::{connect_async, tungstenite::Message};
-use tracing::{error, info, warn};
-
-async fn run_relay_client(state: AppState, url: String, ready_tx: Option<mpsc::Sender<()>>) {
-    info!("Starting firehose client for relay: {}", url);
-    loop {
-        match connect_async(&url).await {
-            Ok((mut ws_stream, _)) => {
-                info!("Connected to firehose relay: {}", url);
-                let mut rx = state.firehose_tx.subscribe();
-                if let Some(tx) = ready_tx.as_ref() {
-                    tx.send(()).await.ok();
-                }
-
-                loop {
-                    tokio::select! {
-                        Ok(event) = rx.recv() => {
-                            match format_event_for_sending(&state, event).await {
-                                Ok(bytes) => {
-                                    if let Err(e) = ws_stream.send(Message::Binary(bytes.into())).await {
-                                        warn!("Failed to send event to {}: {}. Disconnecting.", url, e);
-                                        break;
-                                    }
-                                }
-                                Err(e) => {
-                                    error!("Failed to format event for relay {}: {}", url, e);
-                                }
-                            }
-                        }
-                        Some(msg) = ws_stream.next() => {
-                            if let Ok(Message::Close(_)) = msg {
-                                warn!("Relay {} closed connection.", url);
-                                break;
-                            }
-                        }
-                        else => break,
-                    }
-                }
-            }
-            Err(e) => {
-                error!("Failed to connect to firehose relay {}: {}", url, e);
-            }
-        }
-        warn!(
-            "Disconnected from {}. Reconnecting in 5 seconds...",
-            url
-        );
-        tokio::time::sleep(Duration::from_secs(5)).await;
-    }
-}
-
-pub async fn start_relay_clients(
-    state: AppState,
-    relays: Vec<String>,
-    mut ready_rx: Option<mpsc::Receiver<()>>,
-) {
-    if relays.is_empty() {
-        return;
-    }
-
-    let (ready_tx, mut internal_ready_rx) = mpsc::channel(1);
-
-    for url in relays {
-        let ready_tx = if ready_rx.is_some() {
-            Some(ready_tx.clone())
-        } else {
-            None
-        };
-        tokio::spawn(run_relay_client(state.clone(), url, ready_tx));
-    }
-
-    if let Some(mut rx) = ready_rx.take() {
-        tokio::spawn(async move {
-            internal_ready_rx.recv().await;
-            rx.close();
-        });
-    }
-}
diff --git a/src/validation/mod.rs b/src/validation/mod.rs
new file mode 100644
index 0000000..5620e80
--- /dev/null
+++ b/src/validation/mod.rs
@@ -0,0 +1,504 @@
+use serde_json::Value;
+use thiserror::Error;
+
+#[derive(Debug, Error)]
+pub enum ValidationError {
+    #[error("No $type provided")]
+    MissingType,
+
+    #[error("Invalid $type: expected {expected}, got {actual}")]
+    TypeMismatch { expected: String, actual: String },
+
+    #[error("Missing required field: {0}")]
+    MissingField(String),
+
+    #[error("Invalid field value at {path}: {message}")]
+    InvalidField { path: String, message: String },
+
+    #[error("Invalid datetime format at {path}: must be RFC-3339/ISO-8601")]
+    InvalidDatetime { path: String },
+
+    #[error("Invalid record: {0}")]
+    InvalidRecord(String),
+
+    #[error("Unknown record type: {0}")]
+    UnknownType(String),
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ValidationStatus {
+    Valid,
+    Unknown,
+    Invalid,
+}
+
+pub struct RecordValidator {
+    require_lexicon: bool,
+}
+
+impl Default for RecordValidator {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl RecordValidator {
+    pub fn new() -> Self {
+        Self {
+            require_lexicon: false,
+        }
+    }
+
+    pub fn require_lexicon(mut self, require: bool) -> Self {
+        self.require_lexicon = require;
+        self
+    }
+
+    pub fn validate(
+        &self,
+        record: &Value,
+        collection: &str,
+    ) -> Result<ValidationStatus, ValidationError> {
+        let obj = record
+            .as_object()
+            .ok_or_else(|| ValidationError::InvalidRecord("Record must be an object".to_string()))?;
+
+        let record_type = obj
+            .get("$type")
+            .and_then(|v| v.as_str())
+            .ok_or(ValidationError::MissingType)?;
+
+        if record_type != collection {
+            return Err(ValidationError::TypeMismatch {
+                expected: collection.to_string(),
+                actual: record_type.to_string(),
+            });
+        }
+
+        if let Some(created_at) = obj.get("createdAt").and_then(|v| v.as_str()) {
+            validate_datetime(created_at, "createdAt")?;
+        }
+
+        // Known lexicons get structural checks below; anything else passes
+        // through as Unknown unless require_lexicon is set.
+        match record_type {
+            "app.bsky.feed.post" => self.validate_post(obj)?,
+            "app.bsky.actor.profile" => self.validate_profile(obj)?,
+            "app.bsky.feed.like" => self.validate_like(obj)?,
+            "app.bsky.feed.repost" => self.validate_repost(obj)?,
+            "app.bsky.graph.follow" => self.validate_follow(obj)?,
+            "app.bsky.graph.block" => self.validate_block(obj)?,
+            "app.bsky.graph.list" => self.validate_list(obj)?,
+            "app.bsky.graph.listitem" => self.validate_list_item(obj)?,
+            "app.bsky.feed.generator" => self.validate_feed_generator(obj)?,
+            "app.bsky.feed.threadgate" => self.validate_threadgate(obj)?,
+            "app.bsky.labeler.service" => self.validate_labeler_service(obj)?,
+            _ => {
+                if self.require_lexicon {
+                    return Err(ValidationError::UnknownType(record_type.to_string()));
+                }
+                return Ok(ValidationStatus::Unknown);
+            }
+        }
+
+        Ok(ValidationStatus::Valid)
+    }
+
+    fn validate_post(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
+        if !obj.contains_key("text") {
+            return Err(ValidationError::MissingField("text".to_string()));
+        }
+
+        if !obj.contains_key("createdAt") {
+            return Err(ValidationError::MissingField("createdAt".to_string()));
+        }
+
+        if let Some(text) = obj.get("text").and_then(|v| v.as_str()) {
+            // chars() approximates a grapheme count; adequate for a coarse cap.
+            let grapheme_count = text.chars().count();
+            if grapheme_count > 3000 {
+                return Err(ValidationError::InvalidField {
+                    path: "text".to_string(),
+                    message: format!("Text exceeds maximum length of 3000 characters (got {})", grapheme_count),
+                });
+            }
+        }
+
+        if let Some(langs) = obj.get("langs").and_then(|v| v.as_array()) {
+            if langs.len() > 3 {
+                return Err(ValidationError::InvalidField {
+                    path: "langs".to_string(),
+                    message: "Maximum 3 languages allowed".to_string(),
+                });
+            }
+        }
+
+        if let Some(tags) = obj.get("tags").and_then(|v| v.as_array()) {
+            if tags.len() > 8 {
+                return Err(ValidationError::InvalidField {
+                    path: "tags".to_string(),
+                    message: "Maximum 8 tags allowed".to_string(),
+                });
+            }
+            for (i, tag) in tags.iter().enumerate() {
+                if let Some(tag_str) = tag.as_str() {
+                    if tag_str.len() > 640 {
+                        return Err(ValidationError::InvalidField {
+                            path: format!("tags/{}", i),
+                            message: "Tag exceeds maximum length of 640 bytes".to_string(),
+                        });
+                    }
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    fn validate_profile(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
+        if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) {
+            let grapheme_count = display_name.chars().count();
+            if grapheme_count > 640 {
+                return Err(ValidationError::InvalidField {
+                    path: "displayName".to_string(),
+                    message: format!("Display name exceeds maximum length of 640 characters (got {})", grapheme_count),
+                });
+            }
+        }
+
+        if let Some(description) = obj.get("description").and_then(|v| v.as_str()) {
+            let grapheme_count = description.chars().count();
+            if grapheme_count > 2560 {
+                return Err(ValidationError::InvalidField {
+                    path: "description".to_string(),
+                    message: format!("Description exceeds maximum length of 2560 characters (got {})", grapheme_count),
+                });
+            }
+        }
+
+        Ok(())
+    }
+
+    fn validate_like(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
+        if !obj.contains_key("subject") {
+            return Err(ValidationError::MissingField("subject".to_string()));
+        }
+        if !obj.contains_key("createdAt") {
+            return Err(ValidationError::MissingField("createdAt".to_string()));
+        }
+        self.validate_strong_ref(obj.get("subject"), "subject")?;
+        Ok(())
+    }
+
+    fn validate_repost(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
+        if !obj.contains_key("subject") {
+            return Err(ValidationError::MissingField("subject".to_string()));
+        }
+        if !obj.contains_key("createdAt") {
+            return Err(ValidationError::MissingField("createdAt".to_string()));
+        }
+        self.validate_strong_ref(obj.get("subject"), "subject")?;
+        Ok(())
+    }
+
+    fn validate_follow(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
+        if !obj.contains_key("subject") {
+            return Err(ValidationError::MissingField("subject".to_string()));
+        }
+        if !obj.contains_key("createdAt") {
+            return Err(ValidationError::MissingField("createdAt".to_string()));
+        }
+
+        if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) {
+            if !subject.starts_with("did:") {
+                return Err(ValidationError::InvalidField {
+                    path: "subject".to_string(),
+                    message: "Subject must be a DID".to_string(),
+                });
+            }
+        }
+
+        Ok(())
+    }
+
+    fn validate_block(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
+        if !obj.contains_key("subject") {
+            return Err(ValidationError::MissingField("subject".to_string()));
+        }
+        if !obj.contains_key("createdAt") {
+            return Err(ValidationError::MissingField("createdAt".to_string()));
+        }
+
+        if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) {
+            if !subject.starts_with("did:") {
+                return Err(ValidationError::InvalidField {
+                    path: "subject".to_string(),
+                    message: "Subject must be a DID".to_string(),
+                });
+            }
+        }
+
+        Ok(())
+    }
+
+    fn validate_list(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
+        if !obj.contains_key("name") {
+            return Err(ValidationError::MissingField("name".to_string()));
+        }
+        if !obj.contains_key("purpose") {
+            return Err(ValidationError::MissingField("purpose".to_string()));
+        }
+        if !obj.contains_key("createdAt") {
+            return Err(ValidationError::MissingField("createdAt".to_string()));
+        }
+
+        if let Some(name) = obj.get("name").and_then(|v| v.as_str()) {
+            if name.is_empty() || name.len() > 64 {
+                return Err(ValidationError::InvalidField {
+                    path: "name".to_string(),
+                    message: "Name must be 1-64 characters".to_string(),
+                });
+            }
+        }
+
+        Ok(())
+    }
+
+    fn validate_list_item(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
+        if !obj.contains_key("subject") {
+            return Err(ValidationError::MissingField("subject".to_string()));
+        }
+        if !obj.contains_key("list") {
+            return Err(ValidationError::MissingField("list".to_string()));
+        }
+        if !obj.contains_key("createdAt") {
+            return Err(ValidationError::MissingField("createdAt".to_string()));
+        }
+        Ok(())
+    }
+
+    fn validate_feed_generator(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
+        if !obj.contains_key("did") {
+            return Err(ValidationError::MissingField("did".to_string()));
+        }
+        if !obj.contains_key("displayName") {
+            return Err(ValidationError::MissingField("displayName".to_string()));
+        }
+        if !obj.contains_key("createdAt") {
+            return Err(ValidationError::MissingField("createdAt".to_string()));
+        }
+
+        if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) {
+            if display_name.is_empty() || display_name.len() > 240 {
+                return Err(ValidationError::InvalidField {
+                    path: "displayName".to_string(),
+                    message: "displayName must be 1-240 characters".to_string(),
+                });
+            }
+        }
+
+        Ok(())
+    }
+
+    fn validate_threadgate(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
+        if !obj.contains_key("post") {
+            return Err(ValidationError::MissingField("post".to_string()));
+        }
+        if !obj.contains_key("createdAt") {
+            return Err(ValidationError::MissingField("createdAt".to_string()));
+        }
+        Ok(())
+    }
+
+    fn validate_labeler_service(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
+        if !obj.contains_key("policies") {
+            return Err(ValidationError::MissingField("policies".to_string()));
+        }
+        if !obj.contains_key("createdAt") {
+            return Err(ValidationError::MissingField("createdAt".to_string()));
+        }
+        Ok(())
+    }
+
+    // A strong ref pins the subject by both at:// URI and content CID.
+    fn validate_strong_ref(&self, value: Option<&Value>, path: &str) -> Result<(), ValidationError> {
+        let obj = value
+            .and_then(|v| v.as_object())
+            .ok_or_else(|| ValidationError::InvalidField {
+                path: path.to_string(),
+                message: "Must be a strong reference object".to_string(),
+            })?;
+
+        if !obj.contains_key("uri") {
+            return Err(ValidationError::MissingField(format!("{}/uri", path)));
+        }
+        if !obj.contains_key("cid") {
+            return Err(ValidationError::MissingField(format!("{}/cid", path)));
+        }
+
+        if let Some(uri) = obj.get("uri").and_then(|v| v.as_str()) {
+            if !uri.starts_with("at://") {
+                return Err(ValidationError::InvalidField {
+                    path: format!("{}/uri", path),
+                    message: "URI must be an at:// URI".to_string(),
+                });
+            }
+        }
+
+        Ok(())
+    }
+}
+
+fn validate_datetime(value: &str, path: &str) -> Result<(), ValidationError> {
+    if chrono::DateTime::parse_from_rfc3339(value).is_err() {
+        return Err(ValidationError::InvalidDatetime {
+            path: path.to_string(),
+        });
+    }
+    Ok(())
+}
+
+pub fn validate_record_key(rkey: &str) -> Result<(), ValidationError> {
+    if rkey.is_empty() {
+        return Err(ValidationError::InvalidRecord("Record key cannot be empty".to_string()));
+    }
+
+    if rkey.len() > 512 {
+        return Err(ValidationError::InvalidRecord("Record key exceeds maximum length of 512".to_string()));
+    }
+
+    if rkey == "." || rkey == ".." {
+        return Err(ValidationError::InvalidRecord("Record key cannot be '.' or '..'".to_string()));
+    }
+
+    let valid_chars = rkey.chars().all(|c| {
+        c.is_ascii_alphanumeric() || c == '.'
|| c == '-' || c == '_' || c == '~' + }); + + if !valid_chars { + return Err(ValidationError::InvalidRecord( + "Record key contains invalid characters (must be alphanumeric, '.', '-', '_', or '~')".to_string() + )); + } + + Ok(()) +} + +pub fn validate_collection_nsid(collection: &str) -> Result<(), ValidationError> { + if collection.is_empty() { + return Err(ValidationError::InvalidRecord("Collection NSID cannot be empty".to_string())); + } + + let parts: Vec<&str> = collection.split('.').collect(); + if parts.len() < 3 { + return Err(ValidationError::InvalidRecord( + "Collection NSID must have at least 3 segments".to_string() + )); + } + + for part in &parts { + if part.is_empty() { + return Err(ValidationError::InvalidRecord( + "Collection NSID segments cannot be empty".to_string() + )); + } + if !part.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') { + return Err(ValidationError::InvalidRecord( + "Collection NSID segments must be alphanumeric or hyphens".to_string() + )); + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_validate_post() { + let validator = RecordValidator::new(); + + let valid_post = json!({ + "$type": "app.bsky.feed.post", + "text": "Hello, world!", + "createdAt": "2024-01-01T00:00:00.000Z" + }); + + assert_eq!( + validator.validate(&valid_post, "app.bsky.feed.post").unwrap(), + ValidationStatus::Valid + ); + } + + #[test] + fn test_validate_post_missing_text() { + let validator = RecordValidator::new(); + + let invalid_post = json!({ + "$type": "app.bsky.feed.post", + "createdAt": "2024-01-01T00:00:00.000Z" + }); + + assert!(validator.validate(&invalid_post, "app.bsky.feed.post").is_err()); + } + + #[test] + fn test_validate_type_mismatch() { + let validator = RecordValidator::new(); + + let record = json!({ + "$type": "app.bsky.feed.like", + "subject": {"uri": "at://did:plc:test/app.bsky.feed.post/123", "cid": "bafyrei..."}, + "createdAt": "2024-01-01T00:00:00.000Z" + }); + + let result = validator.validate(&record, "app.bsky.feed.post"); + assert!(matches!(result, Err(ValidationError::TypeMismatch { .. 
})));
+    }
+
+    #[test]
+    fn test_validate_unknown_type() {
+        let validator = RecordValidator::new();
+
+        let record = json!({
+            "$type": "com.example.custom",
+            "data": "test"
+        });
+
+        assert_eq!(
+            validator.validate(&record, "com.example.custom").unwrap(),
+            ValidationStatus::Unknown
+        );
+    }
+
+    #[test]
+    fn test_validate_unknown_type_strict() {
+        let validator = RecordValidator::new().require_lexicon(true);
+
+        let record = json!({
+            "$type": "com.example.custom",
+            "data": "test"
+        });
+
+        let result = validator.validate(&record, "com.example.custom");
+        assert!(matches!(result, Err(ValidationError::UnknownType(_))));
+    }
+
+    #[test]
+    fn test_validate_record_key() {
+        assert!(validate_record_key("valid-key_123").is_ok());
+        assert!(validate_record_key("3k2n5j2").is_ok());
+        assert!(validate_record_key(".").is_err());
+        assert!(validate_record_key("..").is_err());
+        assert!(validate_record_key("").is_err());
+        assert!(validate_record_key("invalid/key").is_err());
+    }
+
+    #[test]
+    fn test_validate_collection_nsid() {
+        assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
+        assert!(validate_collection_nsid("com.atproto.repo.record").is_ok());
+        assert!(validate_collection_nsid("invalid").is_err());
+        assert!(validate_collection_nsid("a.b").is_err());
+        assert!(validate_collection_nsid("").is_err());
+    }
+}
diff --git a/tests/image_processing.rs b/tests/image_processing.rs
new file mode 100644
index 0000000..6858ac6
--- /dev/null
+++ b/tests/image_processing.rs
@@ -0,0 +1,315 @@
+use bspds::image::{ImageProcessor, ImageError, OutputFormat, THUMB_SIZE_FEED, THUMB_SIZE_FULL, DEFAULT_MAX_FILE_SIZE};
+use image::{DynamicImage, ImageFormat};
+use std::io::Cursor;
+
+fn create_test_png(width: u32, height: u32) -> Vec<u8> {
+    let img = DynamicImage::new_rgb8(width, height);
+    let mut buf = Vec::new();
+    img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
+    buf
+}
+
+fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> {
+    let img = DynamicImage::new_rgb8(width, height);
+    let mut buf = Vec::new();
+    img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap();
+    buf
+}
+
+fn create_test_gif(width: u32, height: u32) -> Vec<u8> {
+    let img = DynamicImage::new_rgb8(width, height);
+    let mut buf = Vec::new();
+    img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap();
+    buf
+}
+
+fn create_test_webp(width: u32, height: u32) -> Vec<u8> {
+    let img = DynamicImage::new_rgb8(width, height);
+    let mut buf = Vec::new();
+    img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap();
+    buf
+}
+
+#[test]
+fn test_process_png() {
+    let processor = ImageProcessor::new();
+    let data = create_test_png(500, 500);
+    let result = processor.process(&data, "image/png").unwrap();
+    assert_eq!(result.original.width, 500);
+    assert_eq!(result.original.height, 500);
+}
+
+#[test]
+fn test_process_jpeg() {
+    let processor = ImageProcessor::new();
+    let data = create_test_jpeg(400, 300);
+    let result = processor.process(&data, "image/jpeg").unwrap();
+    assert_eq!(result.original.width, 400);
+    assert_eq!(result.original.height, 300);
+}
+
+#[test]
+fn test_process_gif() {
+    let processor = ImageProcessor::new();
+    let data = create_test_gif(200, 200);
+    let result = processor.process(&data, "image/gif").unwrap();
+    assert_eq!(result.original.width, 200);
+    assert_eq!(result.original.height, 200);
+}
+
+#[test]
+fn test_process_webp() {
+    let processor = ImageProcessor::new();
+    let data = create_test_webp(300, 200);
+    let result = processor.process(&data,
"image/webp").unwrap(); + assert_eq!(result.original.width, 300); + assert_eq!(result.original.height, 200); +} + +#[test] +fn test_thumbnail_feed_size() { + let processor = ImageProcessor::new(); + let data = create_test_png(800, 600); + let result = processor.process(&data, "image/png").unwrap(); + + let thumb = result.thumbnail_feed.expect("Should generate feed thumbnail for large image"); + assert!(thumb.width <= THUMB_SIZE_FEED); + assert!(thumb.height <= THUMB_SIZE_FEED); +} + +#[test] +fn test_thumbnail_full_size() { + let processor = ImageProcessor::new(); + let data = create_test_png(2000, 1500); + let result = processor.process(&data, "image/png").unwrap(); + + let thumb = result.thumbnail_full.expect("Should generate full thumbnail for large image"); + assert!(thumb.width <= THUMB_SIZE_FULL); + assert!(thumb.height <= THUMB_SIZE_FULL); +} + +#[test] +fn test_no_thumbnail_small_image() { + let processor = ImageProcessor::new(); + let data = create_test_png(100, 100); + let result = processor.process(&data, "image/png").unwrap(); + + assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail"); + assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail"); +} + +#[test] +fn test_webp_conversion() { + let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP); + let data = create_test_png(300, 300); + let result = processor.process(&data, "image/png").unwrap(); + + assert_eq!(result.original.mime_type, "image/webp"); +} + +#[test] +fn test_jpeg_output_format() { + let processor = ImageProcessor::new().with_output_format(OutputFormat::Jpeg); + let data = create_test_png(300, 300); + let result = processor.process(&data, "image/png").unwrap(); + + assert_eq!(result.original.mime_type, "image/jpeg"); +} + +#[test] +fn test_png_output_format() { + let processor = ImageProcessor::new().with_output_format(OutputFormat::Png); + let data = create_test_jpeg(300, 300); + let result = processor.process(&data, "image/jpeg").unwrap(); + + assert_eq!(result.original.mime_type, "image/png"); +} + +#[test] +fn test_max_dimension_enforced() { + let processor = ImageProcessor::new().with_max_dimension(1000); + let data = create_test_png(2000, 2000); + let result = processor.process(&data, "image/png"); + + assert!(matches!(result, Err(ImageError::TooLarge { .. }))); + if let Err(ImageError::TooLarge { width, height, max_dimension }) = result { + assert_eq!(width, 2000); + assert_eq!(height, 2000); + assert_eq!(max_dimension, 1000); + } +} + +#[test] +fn test_file_size_limit() { + let processor = ImageProcessor::new().with_max_file_size(100); + let data = create_test_png(500, 500); + let result = processor.process(&data, "image/png"); + + assert!(matches!(result, Err(ImageError::FileTooLarge { .. 
}))); + if let Err(ImageError::FileTooLarge { size, max_size }) = result { + assert!(size > 100); + assert_eq!(max_size, 100); + } +} + +#[test] +fn test_default_max_file_size() { + assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024); +} + +#[test] +fn test_unsupported_format_rejected() { + let processor = ImageProcessor::new(); + let data = b"this is not an image"; + let result = processor.process(data, "application/octet-stream"); + + assert!(matches!(result, Err(ImageError::UnsupportedFormat(_)))); +} + +#[test] +fn test_corrupted_image_handling() { + let processor = ImageProcessor::new(); + let data = b"\x89PNG\r\n\x1a\ncorrupted data here"; + let result = processor.process(data, "image/png"); + + assert!(matches!(result, Err(ImageError::DecodeError(_)))); +} + +#[test] +fn test_aspect_ratio_preserved_landscape() { + let processor = ImageProcessor::new(); + let data = create_test_png(1600, 800); + let result = processor.process(&data, "image/png").unwrap(); + + let thumb = result.thumbnail_full.expect("Should have thumbnail"); + let original_ratio = 1600.0 / 800.0; + let thumb_ratio = thumb.width as f64 / thumb.height as f64; + assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved"); +} + +#[test] +fn test_aspect_ratio_preserved_portrait() { + let processor = ImageProcessor::new(); + let data = create_test_png(800, 1600); + let result = processor.process(&data, "image/png").unwrap(); + + let thumb = result.thumbnail_full.expect("Should have thumbnail"); + let original_ratio = 800.0 / 1600.0; + let thumb_ratio = thumb.width as f64 / thumb.height as f64; + assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved"); +} + +#[test] +fn test_mime_type_detection_auto() { + let processor = ImageProcessor::new(); + let data = create_test_png(100, 100); + let result = processor.process(&data, "application/octet-stream"); + + assert!(result.is_ok(), "Should detect PNG format from data"); +} + +#[test] +fn test_is_supported_mime_type() { + assert!(ImageProcessor::is_supported_mime_type("image/jpeg")); + assert!(ImageProcessor::is_supported_mime_type("image/jpg")); + assert!(ImageProcessor::is_supported_mime_type("image/png")); + assert!(ImageProcessor::is_supported_mime_type("image/gif")); + assert!(ImageProcessor::is_supported_mime_type("image/webp")); + assert!(ImageProcessor::is_supported_mime_type("IMAGE/PNG")); + assert!(ImageProcessor::is_supported_mime_type("Image/Jpeg")); + + assert!(!ImageProcessor::is_supported_mime_type("image/bmp")); + assert!(!ImageProcessor::is_supported_mime_type("image/tiff")); + assert!(!ImageProcessor::is_supported_mime_type("text/plain")); + assert!(!ImageProcessor::is_supported_mime_type("application/json")); +} + +#[test] +fn test_strip_exif() { + let data = create_test_jpeg(100, 100); + let result = ImageProcessor::strip_exif(&data); + assert!(result.is_ok()); + let stripped = result.unwrap(); + assert!(!stripped.is_empty()); +} + +#[test] +fn test_with_thumbnails_disabled() { + let processor = ImageProcessor::new().with_thumbnails(false); + let data = create_test_png(2000, 2000); + let result = processor.process(&data, "image/png").unwrap(); + + assert!(result.thumbnail_feed.is_none(), "Thumbnails should be disabled"); + assert!(result.thumbnail_full.is_none(), "Thumbnails should be disabled"); +} + +#[test] +fn test_builder_chaining() { + let processor = ImageProcessor::new() + .with_max_dimension(2048) + .with_max_file_size(5 * 1024 * 1024) + .with_output_format(OutputFormat::Jpeg) + 
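// each with_* option consumes and returns Self, so the calls chain in any order +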
.with_thumbnails(true); + + let data = create_test_png(500, 500); + let result = processor.process(&data, "image/png").unwrap(); + assert_eq!(result.original.mime_type, "image/jpeg"); +} + +#[test] +fn test_processed_image_fields() { + let processor = ImageProcessor::new(); + let data = create_test_png(500, 500); + let result = processor.process(&data, "image/png").unwrap(); + + assert!(!result.original.data.is_empty()); + assert!(!result.original.mime_type.is_empty()); + assert!(result.original.width > 0); + assert!(result.original.height > 0); +} + +#[test] +fn test_only_feed_thumbnail_for_medium_images() { + let processor = ImageProcessor::new(); + let data = create_test_png(500, 500); + let result = processor.process(&data, "image/png").unwrap(); + + assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail"); + assert!(result.thumbnail_full.is_none(), "Should NOT have full thumbnail for 500px image"); +} + +#[test] +fn test_both_thumbnails_for_large_images() { + let processor = ImageProcessor::new(); + let data = create_test_png(2000, 2000); + let result = processor.process(&data, "image/png").unwrap(); + + assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail"); + assert!(result.thumbnail_full.is_some(), "Should have full thumbnail for 2000px image"); +} + +#[test] +fn test_exact_threshold_boundary_feed() { + let processor = ImageProcessor::new(); + + let at_threshold = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED); + let result = processor.process(&at_threshold, "image/png").unwrap(); + assert!(result.thumbnail_feed.is_none(), "Exact threshold should not generate thumbnail"); + + let above_threshold = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1); + let result = processor.process(&above_threshold, "image/png").unwrap(); + assert!(result.thumbnail_feed.is_some(), "Above threshold should generate thumbnail"); +} + +#[test] +fn test_exact_threshold_boundary_full() { + let processor = ImageProcessor::new(); + + let at_threshold = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL); + let result = processor.process(&at_threshold, "image/png").unwrap(); + assert!(result.thumbnail_full.is_none(), "Exact threshold should not generate thumbnail"); + + let above_threshold = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1); + let result = processor.process(&above_threshold, "image/png").unwrap(); + assert!(result.thumbnail_full.is_some(), "Above threshold should generate thumbnail"); +} diff --git a/tests/import_with_verification.rs b/tests/import_with_verification.rs index 28da40e..7efe76c 100644 --- a/tests/import_with_verification.rs +++ b/tests/import_with_verification.rs @@ -217,6 +217,7 @@ async fn get_user_signing_key(did: &str) -> Option> { } #[tokio::test] +#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_valid_signature_and_mock_plc -- --ignored --test-threads=1"] async fn test_import_with_valid_signature_and_mock_plc() { let client = client(); let (token, did) = create_account_and_login(&client).await; @@ -266,6 +267,7 @@ async fn test_import_with_valid_signature_and_mock_plc() { } #[tokio::test] +#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_wrong_signing_key_fails -- --ignored --test-threads=1"] async fn test_import_with_wrong_signing_key_fails() { let client = client(); let (token, did) = create_account_and_login(&client).await; @@ -322,6 +324,7 @@ async fn test_import_with_wrong_signing_key_fails() { } #[tokio::test] +#[ignore = "requires exclusive env var 
access; run with: cargo test test_import_with_did_mismatch_fails -- --ignored --test-threads=1"] async fn test_import_with_did_mismatch_fails() { let client = client(); let (token, did) = create_account_and_login(&client).await; @@ -373,6 +376,7 @@ async fn test_import_with_did_mismatch_fails() { } #[tokio::test] +#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_plc_resolution_failure -- --ignored --test-threads=1"] async fn test_import_with_plc_resolution_failure() { let client = client(); let (token, did) = create_account_and_login(&client).await; @@ -424,6 +428,7 @@ async fn test_import_with_plc_resolution_failure() { } #[tokio::test] +#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_no_signing_key_in_did_doc -- --ignored --test-threads=1"] async fn test_import_with_no_signing_key_in_did_doc() { let client = client(); let (token, did) = create_account_and_login(&client).await; diff --git a/tests/list_records_pagination.rs b/tests/list_records_pagination.rs new file mode 100644 index 0000000..8a67e70 --- /dev/null +++ b/tests/list_records_pagination.rs @@ -0,0 +1,554 @@ +mod common; +mod helpers; +use common::*; +use helpers::*; + +use chrono::Utc; +use reqwest::StatusCode; +use serde_json::{Value, json}; +use std::time::Duration; + +async fn create_post_with_rkey( + client: &reqwest::Client, + did: &str, + jwt: &str, + rkey: &str, + text: &str, +) -> (String, String) { + let payload = json!({ + "repo": did, + "collection": "app.bsky.feed.post", + "rkey": rkey, + "record": { + "$type": "app.bsky.feed.post", + "text": text, + "createdAt": Utc::now().to_rfc3339() + } + }); + + let res = client + .post(format!( + "{}/xrpc/com.atproto.repo.putRecord", + base_url().await + )) + .bearer_auth(jwt) + .json(&payload) + .send() + .await + .expect("Failed to create record"); + + assert_eq!(res.status(), StatusCode::OK); + let body: Value = res.json().await.unwrap(); + ( + body["uri"].as_str().unwrap().to_string(), + body["cid"].as_str().unwrap().to_string(), + ) +} + +#[tokio::test] +async fn test_list_records_default_order() { + let client = client(); + let (did, jwt) = setup_new_user("list-default-order").await; + + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await; + tokio::time::sleep(Duration::from_millis(50)).await; + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await; + tokio::time::sleep(Duration::from_millis(50)).await; + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await; + + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ]) + .send() + .await + .expect("Failed to list records"); + + assert_eq!(res.status(), StatusCode::OK); + let body: Value = res.json().await.unwrap(); + let records = body["records"].as_array().unwrap(); + + assert_eq!(records.len(), 3); + let rkeys: Vec<&str> = records + .iter() + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) + .collect(); + + assert_eq!(rkeys, vec!["cccc", "bbbb", "aaaa"], "Default order should be DESC (newest first)"); +} + +#[tokio::test] +async fn test_list_records_reverse_true() { + let client = client(); + let (did, jwt) = setup_new_user("list-reverse").await; + + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await; + tokio::time::sleep(Duration::from_millis(50)).await; + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await; + 
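// short pause so each post gets a distinct createdAt timestamp +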
tokio::time::sleep(Duration::from_millis(50)).await; + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await; + + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ("reverse", "true"), + ]) + .send() + .await + .expect("Failed to list records"); + + assert_eq!(res.status(), StatusCode::OK); + let body: Value = res.json().await.unwrap(); + let records = body["records"].as_array().unwrap(); + + let rkeys: Vec<&str> = records + .iter() + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) + .collect(); + + assert_eq!(rkeys, vec!["aaaa", "bbbb", "cccc"], "reverse=true should give ASC order (oldest first)"); +} + +#[tokio::test] +async fn test_list_records_cursor_pagination() { + let client = client(); + let (did, jwt) = setup_new_user("list-cursor").await; + + for i in 0..5 { + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; + tokio::time::sleep(Duration::from_millis(50)).await; + } + + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ("limit", "2"), + ]) + .send() + .await + .expect("Failed to list records"); + + assert_eq!(res.status(), StatusCode::OK); + let body: Value = res.json().await.unwrap(); + let records = body["records"].as_array().unwrap(); + assert_eq!(records.len(), 2); + + let cursor = body["cursor"].as_str().expect("Should have cursor with more records"); + + let res2 = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ("limit", "2"), + ("cursor", cursor), + ]) + .send() + .await + .expect("Failed to list records with cursor"); + + assert_eq!(res2.status(), StatusCode::OK); + let body2: Value = res2.json().await.unwrap(); + let records2 = body2["records"].as_array().unwrap(); + assert_eq!(records2.len(), 2); + + let all_uris: Vec<&str> = records + .iter() + .chain(records2.iter()) + .map(|r| r["uri"].as_str().unwrap()) + .collect(); + let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); + assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records"); +} + +#[tokio::test] +async fn test_list_records_rkey_start() { + let client = client(); + let (did, jwt) = setup_new_user("list-rkey-start").await; + + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; + create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; + + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ("rkeyStart", "bbbb"), + ("reverse", "true"), + ]) + .send() + .await + .expect("Failed to list records"); + + assert_eq!(res.status(), StatusCode::OK); + let body: Value = res.json().await.unwrap(); + let records = body["records"].as_array().unwrap(); + + let rkeys: Vec<&str> = records + .iter() + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) + .collect(); + + for rkey in &rkeys { + assert!(*rkey >= "bbbb", "rkeyStart should filter records >= start"); + } +} + +#[tokio::test] +async fn 
test_list_records_rkey_end() { + let client = client(); + let (did, jwt) = setup_new_user("list-rkey-end").await; + + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; + create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; + + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ("rkeyEnd", "cccc"), + ("reverse", "true"), + ]) + .send() + .await + .expect("Failed to list records"); + + assert_eq!(res.status(), StatusCode::OK); + let body: Value = res.json().await.unwrap(); + let records = body["records"].as_array().unwrap(); + + let rkeys: Vec<&str> = records + .iter() + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) + .collect(); + + for rkey in &rkeys { + assert!(*rkey <= "cccc", "rkeyEnd should filter records <= end"); + } +} + +#[tokio::test] +async fn test_list_records_rkey_range() { + let client = client(); + let (did, jwt) = setup_new_user("list-rkey-range").await; + + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; + create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; + create_post_with_rkey(&client, &did, &jwt, "eeee", "Fifth").await; + + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ("rkeyStart", "bbbb"), + ("rkeyEnd", "dddd"), + ("reverse", "true"), + ]) + .send() + .await + .expect("Failed to list records"); + + assert_eq!(res.status(), StatusCode::OK); + let body: Value = res.json().await.unwrap(); + let records = body["records"].as_array().unwrap(); + + let rkeys: Vec<&str> = records + .iter() + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) + .collect(); + + for rkey in &rkeys { + assert!(*rkey >= "bbbb" && *rkey <= "dddd", "Range should be inclusive, got {}", rkey); + } + assert!(!rkeys.is_empty(), "Should have at least some records in range"); +} + +#[tokio::test] +async fn test_list_records_limit_clamping_max() { + let client = client(); + let (did, jwt) = setup_new_user("list-limit-max").await; + + for i in 0..5 { + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; + } + + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ("limit", "1000"), + ]) + .send() + .await + .expect("Failed to list records"); + + assert_eq!(res.status(), StatusCode::OK); + let body: Value = res.json().await.unwrap(); + let records = body["records"].as_array().unwrap(); + assert!(records.len() <= 100, "Limit should be clamped to max 100"); +} + +#[tokio::test] +async fn test_list_records_limit_clamping_min() { + let client = client(); + let (did, jwt) = setup_new_user("list-limit-min").await; + + create_post_with_rkey(&client, &did, &jwt, "aaaa", "Post").await; + + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ("limit", "0"), + ]) + .send() + .await + 
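// limit=0 should not be rejected; the assertion below expects it clamped up to at least 1 +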
.expect("Failed to list records"); + + assert_eq!(res.status(), StatusCode::OK); + let body: Value = res.json().await.unwrap(); + let records = body["records"].as_array().unwrap(); + assert!(records.len() >= 1, "Limit should be clamped to min 1"); +} + +#[tokio::test] +async fn test_list_records_empty_collection() { + let client = client(); + let (did, _jwt) = setup_new_user("list-empty").await; + + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ]) + .send() + .await + .expect("Failed to list records"); + + assert_eq!(res.status(), StatusCode::OK); + let body: Value = res.json().await.unwrap(); + let records = body["records"].as_array().unwrap(); + assert!(records.is_empty(), "Empty collection should return empty array"); + assert!(body["cursor"].is_null(), "Empty collection should have no cursor"); +} + +#[tokio::test] +async fn test_list_records_exact_limit() { + let client = client(); + let (did, jwt) = setup_new_user("list-exact-limit").await; + + for i in 0..10 { + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; + } + + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ("limit", "5"), + ]) + .send() + .await + .expect("Failed to list records"); + + assert_eq!(res.status(), StatusCode::OK); + let body: Value = res.json().await.unwrap(); + let records = body["records"].as_array().unwrap(); + assert_eq!(records.len(), 5, "Should return exactly 5 records when limit=5"); +} + +#[tokio::test] +async fn test_list_records_cursor_exhaustion() { + let client = client(); + let (did, jwt) = setup_new_user("list-cursor-exhaust").await; + + for i in 0..3 { + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; + } + + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ("limit", "10"), + ]) + .send() + .await + .expect("Failed to list records"); + + assert_eq!(res.status(), StatusCode::OK); + let body: Value = res.json().await.unwrap(); + let records = body["records"].as_array().unwrap(); + assert_eq!(records.len(), 3); +} + +#[tokio::test] +async fn test_list_records_repo_not_found() { + let client = client(); + + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", "did:plc:nonexistent12345"), + ("collection", "app.bsky.feed.post"), + ]) + .send() + .await + .expect("Failed to list records"); + + assert_eq!(res.status(), StatusCode::NOT_FOUND); +} + +#[tokio::test] +async fn test_list_records_includes_cid() { + let client = client(); + let (did, jwt) = setup_new_user("list-includes-cid").await; + + create_post_with_rkey(&client, &did, &jwt, "test", "Test post").await; + + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ]) + .send() + .await + .expect("Failed to list records"); + + assert_eq!(res.status(), StatusCode::OK); + let body: Value = res.json().await.unwrap(); + let records = body["records"].as_array().unwrap(); + + for record in records { + assert!(record["uri"].is_string(), "Record should have uri"); + 
assert!(record["cid"].is_string(), "Record should have cid"); + assert!(record["value"].is_object(), "Record should have value"); + let cid = record["cid"].as_str().unwrap(); + assert!(cid.starts_with("bafy"), "CID should be valid"); + } +} + +#[tokio::test] +async fn test_list_records_cursor_with_reverse() { + let client = client(); + let (did, jwt) = setup_new_user("list-cursor-reverse").await; + + for i in 0..5 { + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; + } + + let res = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ("limit", "2"), + ("reverse", "true"), + ]) + .send() + .await + .expect("Failed to list records"); + + assert_eq!(res.status(), StatusCode::OK); + let body: Value = res.json().await.unwrap(); + let records = body["records"].as_array().unwrap(); + let first_rkeys: Vec<&str> = records + .iter() + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) + .collect(); + + assert_eq!(first_rkeys, vec!["post00", "post01"], "First page with reverse should start from oldest"); + + if let Some(cursor) = body["cursor"].as_str() { + let res2 = client + .get(format!( + "{}/xrpc/com.atproto.repo.listRecords", + base_url().await + )) + .query(&[ + ("repo", did.as_str()), + ("collection", "app.bsky.feed.post"), + ("limit", "2"), + ("reverse", "true"), + ("cursor", cursor), + ]) + .send() + .await + .expect("Failed to list records with cursor"); + + let body2: Value = res2.json().await.unwrap(); + let records2 = body2["records"].as_array().unwrap(); + let second_rkeys: Vec<&str> = records2 + .iter() + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) + .collect(); + + assert_eq!(second_rkeys, vec!["post02", "post03"], "Second page should continue in ASC order"); + } +} diff --git a/tests/oauth.rs b/tests/oauth.rs index bc23182..3b58987 100644 --- a/tests/oauth.rs +++ b/tests/oauth.rs @@ -323,6 +323,7 @@ async fn test_authorize_get_with_valid_request_uri() { let auth_res = client .get(format!("{}/oauth/authorize", url)) + .header("Accept", "application/json") .query(&[("request_uri", request_uri)]) .send() .await @@ -344,6 +345,7 @@ async fn test_authorize_rejects_invalid_request_uri() { let res = client .get(format!("{}/oauth/authorize", url)) + .header("Accept", "application/json") .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")]) .send() .await @@ -941,6 +943,7 @@ async fn test_wrong_credentials_denied() { let auth_res = http_client .post(format!("{}/oauth/authorize", url)) + .header("Accept", "application/json") .form(&[ ("request_uri", request_uri), ("username", &handle), @@ -1162,6 +1165,7 @@ async fn test_deactivated_account_cannot_authorize() { let auth_res = http_client .post(format!("{}/oauth/authorize", url)) + .header("Accept", "application/json") .form(&[ ("request_uri", request_uri), ("username", &handle), @@ -1184,6 +1188,7 @@ async fn test_expired_authorization_request() { let res = http_client .get(format!("{}/oauth/authorize", url)) + .header("Accept", "application/json") .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired-or-nonexistent")]) .send() .await @@ -1477,3 +1482,631 @@ async fn test_state_with_special_chars() { location ); } + +#[tokio::test] +async fn test_2fa_required_when_enabled() { + let url = base_url().await; + let http_client = client(); + + let ts = Utc::now().timestamp_millis(); + let handle = 
format!("2fa-required-{}", ts); + let email = format!("2fa-required-{}@example.com", ts); + let password = "2fa-test-password"; + + let create_res = http_client + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) + .json(&json!({ + "handle": handle, + "email": email, + "password": password + })) + .send() + .await + .unwrap(); + assert_eq!(create_res.status(), StatusCode::OK); + let account: Value = create_res.json().await.unwrap(); + let user_did = account["did"].as_str().unwrap(); + + let db_url = common::get_db_connection_string().await; + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(1) + .connect(&db_url) + .await + .expect("Failed to connect to database"); + + sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") + .bind(user_did) + .execute(&pool) + .await + .expect("Failed to enable 2FA"); + + let redirect_uri = "https://example.com/2fa-callback"; + let mock_client = setup_mock_client_metadata(redirect_uri).await; + let client_id = mock_client.uri(); + + let (_, code_challenge) = generate_pkce(); + + let par_body: Value = http_client + .post(format!("{}/oauth/par", url)) + .form(&[ + ("response_type", "code"), + ("client_id", &client_id), + ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), + ("code_challenge_method", "S256"), + ]) + .send() + .await + .unwrap() + .json() + .await + .unwrap(); + + let request_uri = par_body["request_uri"].as_str().unwrap(); + + let auth_client = no_redirect_client(); + let auth_res = auth_client + .post(format!("{}/oauth/authorize", url)) + .form(&[ + ("request_uri", request_uri), + ("username", &handle), + ("password", password), + ("remember_device", "false"), + ]) + .send() + .await + .unwrap(); + + assert!( + auth_res.status().is_redirection(), + "Should redirect to 2FA page, got status: {}", + auth_res.status() + ); + + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); + assert!( + location.contains("/oauth/authorize/2fa"), + "Should redirect to 2FA page, got: {}", + location + ); + assert!( + location.contains("request_uri="), + "2FA redirect should include request_uri" + ); +} + +#[tokio::test] +async fn test_2fa_invalid_code_rejected() { + let url = base_url().await; + let http_client = client(); + + let ts = Utc::now().timestamp_millis(); + let handle = format!("2fa-invalid-{}", ts); + let email = format!("2fa-invalid-{}@example.com", ts); + let password = "2fa-test-password"; + + let create_res = http_client + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) + .json(&json!({ + "handle": handle, + "email": email, + "password": password + })) + .send() + .await + .unwrap(); + assert_eq!(create_res.status(), StatusCode::OK); + let account: Value = create_res.json().await.unwrap(); + let user_did = account["did"].as_str().unwrap(); + + let db_url = common::get_db_connection_string().await; + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(1) + .connect(&db_url) + .await + .expect("Failed to connect to database"); + + sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") + .bind(user_did) + .execute(&pool) + .await + .expect("Failed to enable 2FA"); + + let redirect_uri = "https://example.com/2fa-invalid-callback"; + let mock_client = setup_mock_client_metadata(redirect_uri).await; + let client_id = mock_client.uri(); + + let (_, code_challenge) = generate_pkce(); + + let par_body: Value = http_client + .post(format!("{}/oauth/par", url)) + .form(&[ + ("response_type", "code"), + 
("client_id", &client_id), + ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), + ("code_challenge_method", "S256"), + ]) + .send() + .await + .unwrap() + .json() + .await + .unwrap(); + + let request_uri = par_body["request_uri"].as_str().unwrap(); + + let auth_client = no_redirect_client(); + let auth_res = auth_client + .post(format!("{}/oauth/authorize", url)) + .form(&[ + ("request_uri", request_uri), + ("username", &handle), + ("password", password), + ("remember_device", "false"), + ]) + .send() + .await + .unwrap(); + + assert!(auth_res.status().is_redirection()); + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); + assert!(location.contains("/oauth/authorize/2fa")); + + let twofa_res = http_client + .post(format!("{}/oauth/authorize/2fa", url)) + .form(&[ + ("request_uri", request_uri), + ("code", "000000"), + ]) + .send() + .await + .unwrap(); + + assert_eq!(twofa_res.status(), StatusCode::OK); + let body = twofa_res.text().await.unwrap(); + assert!( + body.contains("Invalid verification code") || body.contains("invalid"), + "Should show error for invalid code" + ); +} + +#[tokio::test] +async fn test_2fa_valid_code_completes_auth() { + let url = base_url().await; + let http_client = client(); + + let ts = Utc::now().timestamp_millis(); + let handle = format!("2fa-valid-{}", ts); + let email = format!("2fa-valid-{}@example.com", ts); + let password = "2fa-test-password"; + + let create_res = http_client + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) + .json(&json!({ + "handle": handle, + "email": email, + "password": password + })) + .send() + .await + .unwrap(); + assert_eq!(create_res.status(), StatusCode::OK); + let account: Value = create_res.json().await.unwrap(); + let user_did = account["did"].as_str().unwrap(); + + let db_url = common::get_db_connection_string().await; + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(1) + .connect(&db_url) + .await + .expect("Failed to connect to database"); + + sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") + .bind(user_did) + .execute(&pool) + .await + .expect("Failed to enable 2FA"); + + let redirect_uri = "https://example.com/2fa-valid-callback"; + let mock_client = setup_mock_client_metadata(redirect_uri).await; + let client_id = mock_client.uri(); + + let (code_verifier, code_challenge) = generate_pkce(); + + let par_body: Value = http_client + .post(format!("{}/oauth/par", url)) + .form(&[ + ("response_type", "code"), + ("client_id", &client_id), + ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), + ("code_challenge_method", "S256"), + ]) + .send() + .await + .unwrap() + .json() + .await + .unwrap(); + + let request_uri = par_body["request_uri"].as_str().unwrap(); + + let auth_client = no_redirect_client(); + let auth_res = auth_client + .post(format!("{}/oauth/authorize", url)) + .form(&[ + ("request_uri", request_uri), + ("username", &handle), + ("password", password), + ("remember_device", "false"), + ]) + .send() + .await + .unwrap(); + + assert!(auth_res.status().is_redirection()); + + let twofa_code: String = sqlx::query_scalar( + "SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1" + ) + .bind(request_uri) + .fetch_one(&pool) + .await + .expect("Failed to get 2FA code from database"); + + let twofa_res = auth_client + .post(format!("{}/oauth/authorize/2fa", url)) + .form(&[ + ("request_uri", request_uri), + ("code", &twofa_code), + ]) + .send() + .await + .unwrap(); + + assert!( + 
twofa_res.status().is_redirection(), + "Valid 2FA code should redirect to success, got status: {}", + twofa_res.status() + ); + + let location = twofa_res.headers().get("location").unwrap().to_str().unwrap(); + assert!( + location.starts_with(redirect_uri), + "Should redirect to client callback, got: {}", + location + ); + assert!( + location.contains("code="), + "Redirect should include authorization code" + ); + + let auth_code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); + + let token_res = http_client + .post(format!("{}/oauth/token", url)) + .form(&[ + ("grant_type", "authorization_code"), + ("code", auth_code), + ("redirect_uri", redirect_uri), + ("code_verifier", &code_verifier), + ("client_id", &client_id), + ]) + .send() + .await + .unwrap(); + + assert_eq!(token_res.status(), StatusCode::OK, "Token exchange should succeed"); + let token_body: Value = token_res.json().await.unwrap(); + assert!(token_body["access_token"].is_string()); + assert_eq!(token_body["sub"], user_did); +} + +#[tokio::test] +async fn test_2fa_lockout_after_max_attempts() { + let url = base_url().await; + let http_client = client(); + + let ts = Utc::now().timestamp_millis(); + let handle = format!("2fa-lockout-{}", ts); + let email = format!("2fa-lockout-{}@example.com", ts); + let password = "2fa-test-password"; + + let create_res = http_client + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) + .json(&json!({ + "handle": handle, + "email": email, + "password": password + })) + .send() + .await + .unwrap(); + assert_eq!(create_res.status(), StatusCode::OK); + let account: Value = create_res.json().await.unwrap(); + let user_did = account["did"].as_str().unwrap(); + + let db_url = common::get_db_connection_string().await; + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(1) + .connect(&db_url) + .await + .expect("Failed to connect to database"); + + sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") + .bind(user_did) + .execute(&pool) + .await + .expect("Failed to enable 2FA"); + + let redirect_uri = "https://example.com/2fa-lockout-callback"; + let mock_client = setup_mock_client_metadata(redirect_uri).await; + let client_id = mock_client.uri(); + + let (_, code_challenge) = generate_pkce(); + + let par_body: Value = http_client + .post(format!("{}/oauth/par", url)) + .form(&[ + ("response_type", "code"), + ("client_id", &client_id), + ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), + ("code_challenge_method", "S256"), + ]) + .send() + .await + .unwrap() + .json() + .await + .unwrap(); + + let request_uri = par_body["request_uri"].as_str().unwrap(); + + let auth_client = no_redirect_client(); + let auth_res = auth_client + .post(format!("{}/oauth/authorize", url)) + .form(&[ + ("request_uri", request_uri), + ("username", &handle), + ("password", password), + ("remember_device", "false"), + ]) + .send() + .await + .unwrap(); + + assert!(auth_res.status().is_redirection()); + + for i in 0..5 { + let res = http_client + .post(format!("{}/oauth/authorize/2fa", url)) + .form(&[ + ("request_uri", request_uri), + ("code", "999999"), + ]) + .send() + .await + .unwrap(); + + if i < 4 { + assert_eq!(res.status(), StatusCode::OK, "Attempt {} should show error page", i + 1); + let body = res.text().await.unwrap(); + assert!( + body.contains("Invalid verification code"), + "Should show invalid code error on attempt {}", i + 1 + ); + } + } + + let lockout_res = http_client + .post(format!("{}/oauth/authorize/2fa", 
url)) + .form(&[ + ("request_uri", request_uri), + ("code", "999999"), + ]) + .send() + .await + .unwrap(); + + assert_eq!(lockout_res.status(), StatusCode::OK); + let body = lockout_res.text().await.unwrap(); + assert!( + body.contains("Too many failed attempts") || body.contains("No 2FA challenge found"), + "Should be locked out after max attempts. Body: {}", + &body[..body.len().min(500)] + ); +} + +#[tokio::test] +async fn test_account_selector_with_2fa_requires_verification() { + let url = base_url().await; + let http_client = client(); + + let ts = Utc::now().timestamp_millis(); + let handle = format!("selector-2fa-{}", ts); + let email = format!("selector-2fa-{}@example.com", ts); + let password = "selector-2fa-password"; + + let create_res = http_client + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) + .json(&json!({ + "handle": handle, + "email": email, + "password": password + })) + .send() + .await + .unwrap(); + assert_eq!(create_res.status(), StatusCode::OK); + let account: Value = create_res.json().await.unwrap(); + let user_did = account["did"].as_str().unwrap().to_string(); + + let redirect_uri = "https://example.com/selector-2fa-callback"; + let mock_client = setup_mock_client_metadata(redirect_uri).await; + let client_id = mock_client.uri(); + + let (code_verifier, code_challenge) = generate_pkce(); + + let par_body: Value = http_client + .post(format!("{}/oauth/par", url)) + .form(&[ + ("response_type", "code"), + ("client_id", &client_id), + ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge), + ("code_challenge_method", "S256"), + ]) + .send() + .await + .unwrap() + .json() + .await + .unwrap(); + + let request_uri = par_body["request_uri"].as_str().unwrap(); + + let auth_client = no_redirect_client(); + let auth_res = auth_client + .post(format!("{}/oauth/authorize", url)) + .form(&[ + ("request_uri", request_uri), + ("username", &handle), + ("password", password), + ("remember_device", "true"), + ]) + .send() + .await + .unwrap(); + + assert!(auth_res.status().is_redirection()); + + let device_cookie = auth_res.headers() + .get("set-cookie") + .and_then(|v| v.to_str().ok()) + .map(|s| s.split(';').next().unwrap_or("").to_string()) + .expect("Should have received device cookie"); + + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); + assert!(location.contains("code="), "First auth should succeed"); + + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); + let _token_body: Value = http_client + .post(format!("{}/oauth/token", url)) + .form(&[ + ("grant_type", "authorization_code"), + ("code", code), + ("redirect_uri", redirect_uri), + ("code_verifier", &code_verifier), + ("client_id", &client_id), + ]) + .send() + .await + .unwrap() + .json() + .await + .unwrap(); + + let db_url = common::get_db_connection_string().await; + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(1) + .connect(&db_url) + .await + .expect("Failed to connect to database"); + + sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") + .bind(&user_did) + .execute(&pool) + .await + .expect("Failed to enable 2FA"); + + let (code_verifier2, code_challenge2) = generate_pkce(); + + let par_body2: Value = http_client + .post(format!("{}/oauth/par", url)) + .form(&[ + ("response_type", "code"), + ("client_id", &client_id), + ("redirect_uri", redirect_uri), + ("code_challenge", &code_challenge2), + ("code_challenge_method", "S256"), + ]) + .send() + .await + .unwrap() + .json() + 
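+        // second PAR: the account-selector flow below needs its own request_uri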
.await + .unwrap(); + + let request_uri2 = par_body2["request_uri"].as_str().unwrap(); + + let select_res = auth_client + .post(format!("{}/oauth/authorize/select", url)) + .header("cookie", &device_cookie) + .form(&[ + ("request_uri", request_uri2), + ("did", &user_did), + ]) + .send() + .await + .unwrap(); + + assert!( + select_res.status().is_redirection(), + "Account selector should redirect, got status: {}", + select_res.status() + ); + + let select_location = select_res.headers().get("location").unwrap().to_str().unwrap(); + assert!( + select_location.contains("/oauth/authorize/2fa"), + "Account selector with 2FA enabled should redirect to 2FA page, got: {}", + select_location + ); + + let twofa_code: String = sqlx::query_scalar( + "SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1" + ) + .bind(request_uri2) + .fetch_one(&pool) + .await + .expect("Failed to get 2FA code"); + + let twofa_res = auth_client + .post(format!("{}/oauth/authorize/2fa", url)) + .header("cookie", &device_cookie) + .form(&[ + ("request_uri", request_uri2), + ("code", &twofa_code), + ]) + .send() + .await + .unwrap(); + + assert!(twofa_res.status().is_redirection()); + let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap(); + assert!( + final_location.starts_with(redirect_uri) && final_location.contains("code="), + "After 2FA, should redirect to client with code, got: {}", + final_location + ); + + let final_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap(); + let token_res = http_client + .post(format!("{}/oauth/token", url)) + .form(&[ + ("grant_type", "authorization_code"), + ("code", final_code), + ("redirect_uri", redirect_uri), + ("code_verifier", &code_verifier2), + ("client_id", &client_id), + ]) + .send() + .await + .unwrap(); + + assert_eq!(token_res.status(), StatusCode::OK); + let final_token: Value = token_res.json().await.unwrap(); + assert_eq!(final_token["sub"], user_did, "Token should be for the correct user"); +} diff --git a/tests/oauth_security.rs b/tests/oauth_security.rs index e240cea..347cd12 100644 --- a/tests/oauth_security.rs +++ b/tests/oauth_security.rs @@ -735,6 +735,7 @@ async fn test_security_deactivated_account_blocked() { let auth_res = http_client .post(format!("{}/oauth/authorize", url)) + .header("Accept", "application/json") .form(&[ ("request_uri", request_uri), ("username", &handle), diff --git a/tests/plc_migration.rs b/tests/plc_migration.rs index 3ebf0d2..5b1f560 100644 --- a/tests/plc_migration.rs +++ b/tests/plc_migration.rs @@ -255,6 +255,7 @@ async fn test_full_plc_operation_flow() { } #[tokio::test] +#[ignore = "requires exclusive env var access; run with: cargo test test_sign_plc_operation_consumes_token -- --ignored --test-threads=1"] async fn test_sign_plc_operation_consumes_token() { let client = client(); let (token, did) = create_account_and_login(&client).await; @@ -902,6 +903,7 @@ async fn test_migration_rejects_wrong_did_document() { } #[tokio::test] +#[ignore = "requires exclusive env var access; run with: cargo test test_full_migration_flow_end_to_end -- --ignored --test-threads=1"] async fn test_full_migration_flow_end_to_end() { let client = client(); let (token, did) = create_account_and_login(&client).await; diff --git a/tests/plc_validation.rs b/tests/plc_validation.rs new file mode 100644 index 0000000..4c4e04e --- /dev/null +++ b/tests/plc_validation.rs @@ -0,0 +1,513 @@ +use bspds::plc::{ + PlcError, PlcOperation, PlcService, PlcValidationContext, + cid_for_cbor, sign_operation, 
signing_key_to_did_key, + validate_plc_operation, validate_plc_operation_for_submission, + verify_operation_signature, +}; +use k256::ecdsa::SigningKey; +use serde_json::json; +use std::collections::HashMap; + +fn create_valid_operation() -> serde_json::Value { + let key = SigningKey::random(&mut rand::thread_rng()); + let did_key = signing_key_to_did_key(&key); + + let op = json!({ + "type": "plc_operation", + "rotationKeys": [did_key.clone()], + "verificationMethods": { + "atproto": did_key.clone() + }, + "alsoKnownAs": ["at://test.handle"], + "services": { + "atproto_pds": { + "type": "AtprotoPersonalDataServer", + "endpoint": "https://pds.example.com" + } + }, + "prev": null + }); + + sign_operation(&op, &key).unwrap() +} + +#[test] +fn test_validate_plc_operation_valid() { + let op = create_valid_operation(); + let result = validate_plc_operation(&op); + assert!(result.is_ok()); +} + +#[test] +fn test_validate_plc_operation_missing_type() { + let op = json!({ + "rotationKeys": [], + "verificationMethods": {}, + "alsoKnownAs": [], + "services": {}, + "sig": "test" + }); + let result = validate_plc_operation(&op); + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type"))); +} + +#[test] +fn test_validate_plc_operation_invalid_type() { + let op = json!({ + "type": "invalid_type", + "sig": "test" + }); + let result = validate_plc_operation(&op); + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type"))); +} + +#[test] +fn test_validate_plc_operation_missing_sig() { + let op = json!({ + "type": "plc_operation", + "rotationKeys": [], + "verificationMethods": {}, + "alsoKnownAs": [], + "services": {} + }); + let result = validate_plc_operation(&op); + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig"))); +} + +#[test] +fn test_validate_plc_operation_missing_rotation_keys() { + let op = json!({ + "type": "plc_operation", + "verificationMethods": {}, + "alsoKnownAs": [], + "services": {}, + "sig": "test" + }); + let result = validate_plc_operation(&op); + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys"))); +} + +#[test] +fn test_validate_plc_operation_missing_verification_methods() { + let op = json!({ + "type": "plc_operation", + "rotationKeys": [], + "alsoKnownAs": [], + "services": {}, + "sig": "test" + }); + let result = validate_plc_operation(&op); + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods"))); +} + +#[test] +fn test_validate_plc_operation_missing_also_known_as() { + let op = json!({ + "type": "plc_operation", + "rotationKeys": [], + "verificationMethods": {}, + "services": {}, + "sig": "test" + }); + let result = validate_plc_operation(&op); + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs"))); +} + +#[test] +fn test_validate_plc_operation_missing_services() { + let op = json!({ + "type": "plc_operation", + "rotationKeys": [], + "verificationMethods": {}, + "alsoKnownAs": [], + "sig": "test" + }); + let result = validate_plc_operation(&op); + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("services"))); +} + +#[test] +fn test_validate_rotation_key_required() { + let key = SigningKey::random(&mut rand::thread_rng()); + let did_key = signing_key_to_did_key(&key); + let server_key = "did:key:zServer123"; + + let op = json!({ + "type": "plc_operation", + "rotationKeys": [did_key.clone()], + 
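+        // only the user's key is listed here; the server rotation key in ctx is
+        // deliberately absent, which submission validation must reject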
"verificationMethods": {"atproto": did_key.clone()}, + "alsoKnownAs": ["at://test.handle"], + "services": { + "atproto_pds": { + "type": "AtprotoPersonalDataServer", + "endpoint": "https://pds.example.com" + } + }, + "sig": "test" + }); + + let ctx = PlcValidationContext { + server_rotation_key: server_key.to_string(), + expected_signing_key: did_key.clone(), + expected_handle: "test.handle".to_string(), + expected_pds_endpoint: "https://pds.example.com".to_string(), + }; + + let result = validate_plc_operation_for_submission(&op, &ctx); + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key"))); +} + +#[test] +fn test_validate_signing_key_match() { + let key = SigningKey::random(&mut rand::thread_rng()); + let did_key = signing_key_to_did_key(&key); + let wrong_key = "did:key:zWrongKey456"; + + let op = json!({ + "type": "plc_operation", + "rotationKeys": [did_key.clone()], + "verificationMethods": {"atproto": wrong_key}, + "alsoKnownAs": ["at://test.handle"], + "services": { + "atproto_pds": { + "type": "AtprotoPersonalDataServer", + "endpoint": "https://pds.example.com" + } + }, + "sig": "test" + }); + + let ctx = PlcValidationContext { + server_rotation_key: did_key.clone(), + expected_signing_key: did_key.clone(), + expected_handle: "test.handle".to_string(), + expected_pds_endpoint: "https://pds.example.com".to_string(), + }; + + let result = validate_plc_operation_for_submission(&op, &ctx); + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key"))); +} + +#[test] +fn test_validate_handle_match() { + let key = SigningKey::random(&mut rand::thread_rng()); + let did_key = signing_key_to_did_key(&key); + + let op = json!({ + "type": "plc_operation", + "rotationKeys": [did_key.clone()], + "verificationMethods": {"atproto": did_key.clone()}, + "alsoKnownAs": ["at://wrong.handle"], + "services": { + "atproto_pds": { + "type": "AtprotoPersonalDataServer", + "endpoint": "https://pds.example.com" + } + }, + "sig": "test" + }); + + let ctx = PlcValidationContext { + server_rotation_key: did_key.clone(), + expected_signing_key: did_key.clone(), + expected_handle: "test.handle".to_string(), + expected_pds_endpoint: "https://pds.example.com".to_string(), + }; + + let result = validate_plc_operation_for_submission(&op, &ctx); + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle"))); +} + +#[test] +fn test_validate_pds_service_type() { + let key = SigningKey::random(&mut rand::thread_rng()); + let did_key = signing_key_to_did_key(&key); + + let op = json!({ + "type": "plc_operation", + "rotationKeys": [did_key.clone()], + "verificationMethods": {"atproto": did_key.clone()}, + "alsoKnownAs": ["at://test.handle"], + "services": { + "atproto_pds": { + "type": "WrongServiceType", + "endpoint": "https://pds.example.com" + } + }, + "sig": "test" + }); + + let ctx = PlcValidationContext { + server_rotation_key: did_key.clone(), + expected_signing_key: did_key.clone(), + expected_handle: "test.handle".to_string(), + expected_pds_endpoint: "https://pds.example.com".to_string(), + }; + + let result = validate_plc_operation_for_submission(&op, &ctx); + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type"))); +} + +#[test] +fn test_validate_pds_endpoint_match() { + let key = SigningKey::random(&mut rand::thread_rng()); + let did_key = signing_key_to_did_key(&key); + + let op = json!({ + "type": "plc_operation", + "rotationKeys": [did_key.clone()], + "verificationMethods": 
{"atproto": did_key.clone()}, + "alsoKnownAs": ["at://test.handle"], + "services": { + "atproto_pds": { + "type": "AtprotoPersonalDataServer", + "endpoint": "https://wrong.endpoint.com" + } + }, + "sig": "test" + }); + + let ctx = PlcValidationContext { + server_rotation_key: did_key.clone(), + expected_signing_key: did_key.clone(), + expected_handle: "test.handle".to_string(), + expected_pds_endpoint: "https://pds.example.com".to_string(), + }; + + let result = validate_plc_operation_for_submission(&op, &ctx); + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint"))); +} + +#[test] +fn test_verify_signature_secp256k1() { + let key = SigningKey::random(&mut rand::thread_rng()); + let did_key = signing_key_to_did_key(&key); + + let op = json!({ + "type": "plc_operation", + "rotationKeys": [did_key.clone()], + "verificationMethods": {}, + "alsoKnownAs": [], + "services": {}, + "prev": null + }); + + let signed = sign_operation(&op, &key).unwrap(); + let rotation_keys = vec![did_key]; + + let result = verify_operation_signature(&signed, &rotation_keys); + assert!(result.is_ok()); + assert!(result.unwrap()); +} + +#[test] +fn test_verify_signature_wrong_key() { + let key = SigningKey::random(&mut rand::thread_rng()); + let other_key = SigningKey::random(&mut rand::thread_rng()); + let other_did_key = signing_key_to_did_key(&other_key); + + let op = json!({ + "type": "plc_operation", + "rotationKeys": [], + "verificationMethods": {}, + "alsoKnownAs": [], + "services": {}, + "prev": null + }); + + let signed = sign_operation(&op, &key).unwrap(); + let wrong_rotation_keys = vec![other_did_key]; + + let result = verify_operation_signature(&signed, &wrong_rotation_keys); + assert!(result.is_ok()); + assert!(!result.unwrap()); +} + +#[test] +fn test_verify_signature_invalid_did_key_format() { + let key = SigningKey::random(&mut rand::thread_rng()); + + let op = json!({ + "type": "plc_operation", + "rotationKeys": [], + "verificationMethods": {}, + "alsoKnownAs": [], + "services": {}, + "prev": null + }); + + let signed = sign_operation(&op, &key).unwrap(); + let invalid_keys = vec!["not-a-did-key".to_string()]; + + let result = verify_operation_signature(&signed, &invalid_keys); + assert!(result.is_ok()); + assert!(!result.unwrap()); +} + +#[test] +fn test_tombstone_validation() { + let op = json!({ + "type": "plc_tombstone", + "prev": "bafyreig6xxxxxyyyyyzzzzzz", + "sig": "test" + }); + let result = validate_plc_operation(&op); + assert!(result.is_ok()); +} + +#[test] +fn test_cid_for_cbor_deterministic() { + let value = json!({ + "alpha": 1, + "beta": 2 + }); + + let cid1 = cid_for_cbor(&value).unwrap(); + let cid2 = cid_for_cbor(&value).unwrap(); + + assert_eq!(cid1, cid2, "CID generation should be deterministic"); + assert!(cid1.starts_with("bafyrei"), "CID should start with bafyrei (dag-cbor + sha256)"); +} + +#[test] +fn test_cid_different_for_different_data() { + let value1 = json!({"data": 1}); + let value2 = json!({"data": 2}); + + let cid1 = cid_for_cbor(&value1).unwrap(); + let cid2 = cid_for_cbor(&value2).unwrap(); + + assert_ne!(cid1, cid2, "Different data should produce different CIDs"); +} + +#[test] +fn test_signing_key_to_did_key_format() { + let key = SigningKey::random(&mut rand::thread_rng()); + let did_key = signing_key_to_did_key(&key); + + assert!(did_key.starts_with("did:key:z"), "Should start with did:key:z"); + assert!(did_key.len() > 50, "Did key should be reasonably long"); +} + +#[test] +fn test_signing_key_to_did_key_unique() { + let 
key1 = SigningKey::random(&mut rand::thread_rng()); + let key2 = SigningKey::random(&mut rand::thread_rng()); + + let did1 = signing_key_to_did_key(&key1); + let did2 = signing_key_to_did_key(&key2); + + assert_ne!(did1, did2, "Different keys should produce different did:keys"); +} + +#[test] +fn test_signing_key_to_did_key_consistent() { + let key = SigningKey::random(&mut rand::thread_rng()); + + let did1 = signing_key_to_did_key(&key); + let did2 = signing_key_to_did_key(&key); + + assert_eq!(did1, did2, "Same key should produce same did:key"); +} + +#[test] +fn test_sign_operation_removes_existing_sig() { + let key = SigningKey::random(&mut rand::thread_rng()); + let op = json!({ + "type": "plc_operation", + "rotationKeys": [], + "verificationMethods": {}, + "alsoKnownAs": [], + "services": {}, + "prev": null, + "sig": "old_signature" + }); + + let signed = sign_operation(&op, &key).unwrap(); + let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap(); + + assert_ne!(new_sig, "old_signature", "Should replace old signature"); +} + +#[test] +fn test_validate_plc_operation_not_object() { + let result = validate_plc_operation(&json!("not an object")); + assert!(matches!(result, Err(PlcError::InvalidResponse(_)))); +} + +#[test] +fn test_validate_for_submission_tombstone_passes() { + let key = SigningKey::random(&mut rand::thread_rng()); + let did_key = signing_key_to_did_key(&key); + + let op = json!({ + "type": "plc_tombstone", + "prev": "bafyreig6xxxxxyyyyyzzzzzz", + "sig": "test" + }); + + let ctx = PlcValidationContext { + server_rotation_key: did_key.clone(), + expected_signing_key: did_key, + expected_handle: "test.handle".to_string(), + expected_pds_endpoint: "https://pds.example.com".to_string(), + }; + + let result = validate_plc_operation_for_submission(&op, &ctx); + assert!(result.is_ok(), "Tombstone should pass submission validation"); +} + +#[test] +fn test_verify_signature_missing_sig() { + let op = json!({ + "type": "plc_operation", + "rotationKeys": [], + "verificationMethods": {}, + "alsoKnownAs": [], + "services": {} + }); + + let result = verify_operation_signature(&op, &[]); + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig"))); +} + +#[test] +fn test_verify_signature_invalid_base64() { + let op = json!({ + "type": "plc_operation", + "rotationKeys": [], + "verificationMethods": {}, + "alsoKnownAs": [], + "services": {}, + "sig": "not-valid-base64!!!" 
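+        // '!' is outside the base64 alphabet, so decoding the signature should
+        // fail cleanly with InvalidResponse rather than panicking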
+ }); + + let result = verify_operation_signature(&op, &[]); + assert!(matches!(result, Err(PlcError::InvalidResponse(_)))); +} + +#[test] +fn test_plc_operation_struct() { + let mut services = HashMap::new(); + services.insert("atproto_pds".to_string(), PlcService { + service_type: "AtprotoPersonalDataServer".to_string(), + endpoint: "https://pds.example.com".to_string(), + }); + + let mut verification_methods = HashMap::new(); + verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string()); + + let op = PlcOperation { + op_type: "plc_operation".to_string(), + rotation_keys: vec!["did:key:zTest123".to_string()], + verification_methods, + also_known_as: vec!["at://test.handle".to_string()], + services, + prev: None, + sig: Some("test".to_string()), + }; + + let json_value = serde_json::to_value(&op).unwrap(); + assert_eq!(json_value["type"], "plc_operation"); + assert!(json_value["rotationKeys"].is_array()); +} diff --git a/tests/record_validation.rs b/tests/record_validation.rs new file mode 100644 index 0000000..4897671 --- /dev/null +++ b/tests/record_validation.rs @@ -0,0 +1,590 @@ +use bspds::validation::{RecordValidator, ValidationError, ValidationStatus, validate_record_key, validate_collection_nsid}; +use serde_json::json; + +fn now() -> String { + chrono::Utc::now().to_rfc3339() +} + +#[test] +fn test_validate_post_valid() { + let validator = RecordValidator::new(); + let post = json!({ + "$type": "app.bsky.feed.post", + "text": "Hello world!", + "createdAt": now() + }); + let result = validator.validate(&post, "app.bsky.feed.post"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_post_missing_text() { + let validator = RecordValidator::new(); + let post = json!({ + "$type": "app.bsky.feed.post", + "createdAt": now() + }); + let result = validator.validate(&post, "app.bsky.feed.post"); + assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text")); +} + +#[test] +fn test_validate_post_missing_created_at() { + let validator = RecordValidator::new(); + let post = json!({ + "$type": "app.bsky.feed.post", + "text": "Hello" + }); + let result = validator.validate(&post, "app.bsky.feed.post"); + assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt")); +} + +#[test] +fn test_validate_post_text_too_long() { + let validator = RecordValidator::new(); + let long_text = "a".repeat(3001); + let post = json!({ + "$type": "app.bsky.feed.post", + "text": long_text, + "createdAt": now() + }); + let result = validator.validate(&post, "app.bsky.feed.post"); + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text")); +} + +#[test] +fn test_validate_post_text_at_limit() { + let validator = RecordValidator::new(); + let limit_text = "a".repeat(3000); + let post = json!({ + "$type": "app.bsky.feed.post", + "text": limit_text, + "createdAt": now() + }); + let result = validator.validate(&post, "app.bsky.feed.post"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_post_too_many_langs() { + let validator = RecordValidator::new(); + let post = json!({ + "$type": "app.bsky.feed.post", + "text": "Hello", + "createdAt": now(), + "langs": ["en", "fr", "de", "es"] + }); + let result = validator.validate(&post, "app.bsky.feed.post"); + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. 
}) if path == "langs")); +} + +#[test] +fn test_validate_post_three_langs_ok() { + let validator = RecordValidator::new(); + let post = json!({ + "$type": "app.bsky.feed.post", + "text": "Hello", + "createdAt": now(), + "langs": ["en", "fr", "de"] + }); + let result = validator.validate(&post, "app.bsky.feed.post"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_post_too_many_tags() { + let validator = RecordValidator::new(); + let post = json!({ + "$type": "app.bsky.feed.post", + "text": "Hello", + "createdAt": now(), + "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"] + }); + let result = validator.validate(&post, "app.bsky.feed.post"); + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags")); +} + +#[test] +fn test_validate_post_eight_tags_ok() { + let validator = RecordValidator::new(); + let post = json!({ + "$type": "app.bsky.feed.post", + "text": "Hello", + "createdAt": now(), + "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"] + }); + let result = validator.validate(&post, "app.bsky.feed.post"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_post_tag_too_long() { + let validator = RecordValidator::new(); + let long_tag = "t".repeat(641); + let post = json!({ + "$type": "app.bsky.feed.post", + "text": "Hello", + "createdAt": now(), + "tags": [long_tag] + }); + let result = validator.validate(&post, "app.bsky.feed.post"); + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/"))); +} + +#[test] +fn test_validate_profile_valid() { + let validator = RecordValidator::new(); + let profile = json!({ + "$type": "app.bsky.actor.profile", + "displayName": "Test User", + "description": "A test user profile" + }); + let result = validator.validate(&profile, "app.bsky.actor.profile"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_profile_empty_ok() { + let validator = RecordValidator::new(); + let profile = json!({ + "$type": "app.bsky.actor.profile" + }); + let result = validator.validate(&profile, "app.bsky.actor.profile"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_profile_displayname_too_long() { + let validator = RecordValidator::new(); + let long_name = "n".repeat(641); + let profile = json!({ + "$type": "app.bsky.actor.profile", + "displayName": long_name + }); + let result = validator.validate(&profile, "app.bsky.actor.profile"); + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); +} + +#[test] +fn test_validate_profile_description_too_long() { + let validator = RecordValidator::new(); + let long_desc = "d".repeat(2561); + let profile = json!({ + "$type": "app.bsky.actor.profile", + "description": long_desc + }); + let result = validator.validate(&profile, "app.bsky.actor.profile"); + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. 
}) if path == "description")); +} + +#[test] +fn test_validate_like_valid() { + let validator = RecordValidator::new(); + let like = json!({ + "$type": "app.bsky.feed.like", + "subject": { + "uri": "at://did:plc:test/app.bsky.feed.post/123", + "cid": "bafyreig6xxxxxyyyyyzzzzzz" + }, + "createdAt": now() + }); + let result = validator.validate(&like, "app.bsky.feed.like"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_like_missing_subject() { + let validator = RecordValidator::new(); + let like = json!({ + "$type": "app.bsky.feed.like", + "createdAt": now() + }); + let result = validator.validate(&like, "app.bsky.feed.like"); + assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); +} + +#[test] +fn test_validate_like_missing_subject_uri() { + let validator = RecordValidator::new(); + let like = json!({ + "$type": "app.bsky.feed.like", + "subject": { + "cid": "bafyreig6xxxxxyyyyyzzzzzz" + }, + "createdAt": now() + }); + let result = validator.validate(&like, "app.bsky.feed.like"); + assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri"))); +} + +#[test] +fn test_validate_like_invalid_subject_uri() { + let validator = RecordValidator::new(); + let like = json!({ + "$type": "app.bsky.feed.like", + "subject": { + "uri": "https://example.com/not-at-uri", + "cid": "bafyreig6xxxxxyyyyyzzzzzz" + }, + "createdAt": now() + }); + let result = validator.validate(&like, "app.bsky.feed.like"); + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))); +} + +#[test] +fn test_validate_repost_valid() { + let validator = RecordValidator::new(); + let repost = json!({ + "$type": "app.bsky.feed.repost", + "subject": { + "uri": "at://did:plc:test/app.bsky.feed.post/123", + "cid": "bafyreig6xxxxxyyyyyzzzzzz" + }, + "createdAt": now() + }); + let result = validator.validate(&repost, "app.bsky.feed.repost"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_repost_missing_subject() { + let validator = RecordValidator::new(); + let repost = json!({ + "$type": "app.bsky.feed.repost", + "createdAt": now() + }); + let result = validator.validate(&repost, "app.bsky.feed.repost"); + assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); +} + +#[test] +fn test_validate_follow_valid() { + let validator = RecordValidator::new(); + let follow = json!({ + "$type": "app.bsky.graph.follow", + "subject": "did:plc:test12345", + "createdAt": now() + }); + let result = validator.validate(&follow, "app.bsky.graph.follow"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_follow_missing_subject() { + let validator = RecordValidator::new(); + let follow = json!({ + "$type": "app.bsky.graph.follow", + "createdAt": now() + }); + let result = validator.validate(&follow, "app.bsky.graph.follow"); + assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); +} + +#[test] +fn test_validate_follow_invalid_subject() { + let validator = RecordValidator::new(); + let follow = json!({ + "$type": "app.bsky.graph.follow", + "subject": "not-a-did", + "createdAt": now() + }); + let result = validator.validate(&follow, "app.bsky.graph.follow"); + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. 
}) if path == "subject")); +} + +#[test] +fn test_validate_block_valid() { + let validator = RecordValidator::new(); + let block = json!({ + "$type": "app.bsky.graph.block", + "subject": "did:plc:blocked123", + "createdAt": now() + }); + let result = validator.validate(&block, "app.bsky.graph.block"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_block_invalid_subject() { + let validator = RecordValidator::new(); + let block = json!({ + "$type": "app.bsky.graph.block", + "subject": "not-a-did", + "createdAt": now() + }); + let result = validator.validate(&block, "app.bsky.graph.block"); + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject")); +} + +#[test] +fn test_validate_list_valid() { + let validator = RecordValidator::new(); + let list = json!({ + "$type": "app.bsky.graph.list", + "name": "My List", + "purpose": "app.bsky.graph.defs#modlist", + "createdAt": now() + }); + let result = validator.validate(&list, "app.bsky.graph.list"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_list_name_too_long() { + let validator = RecordValidator::new(); + let long_name = "n".repeat(65); + let list = json!({ + "$type": "app.bsky.graph.list", + "name": long_name, + "purpose": "app.bsky.graph.defs#modlist", + "createdAt": now() + }); + let result = validator.validate(&list, "app.bsky.graph.list"); + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name")); +} + +#[test] +fn test_validate_list_empty_name() { + let validator = RecordValidator::new(); + let list = json!({ + "$type": "app.bsky.graph.list", + "name": "", + "purpose": "app.bsky.graph.defs#modlist", + "createdAt": now() + }); + let result = validator.validate(&list, "app.bsky.graph.list"); + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name")); +} + +#[test] +fn test_validate_feed_generator_valid() { + let validator = RecordValidator::new(); + let generator = json!({ + "$type": "app.bsky.feed.generator", + "did": "did:web:example.com", + "displayName": "My Feed", + "createdAt": now() + }); + let result = validator.validate(&generator, "app.bsky.feed.generator"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_feed_generator_displayname_too_long() { + let validator = RecordValidator::new(); + let long_name = "f".repeat(241); + let generator = json!({ + "$type": "app.bsky.feed.generator", + "did": "did:web:example.com", + "displayName": long_name, + "createdAt": now() + }); + let result = validator.validate(&generator, "app.bsky.feed.generator"); + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. 
}) if path == "displayName")); +} + +#[test] +fn test_validate_unknown_type_returns_unknown() { + let validator = RecordValidator::new(); + let custom = json!({ + "$type": "com.custom.record", + "data": "test" + }); + let result = validator.validate(&custom, "com.custom.record"); + assert_eq!(result.unwrap(), ValidationStatus::Unknown); +} + +#[test] +fn test_validate_unknown_type_strict_rejects() { + let validator = RecordValidator::new().require_lexicon(true); + let custom = json!({ + "$type": "com.custom.record", + "data": "test" + }); + let result = validator.validate(&custom, "com.custom.record"); + assert!(matches!(result, Err(ValidationError::UnknownType(_)))); +} + +#[test] +fn test_validate_type_mismatch() { + let validator = RecordValidator::new(); + let record = json!({ + "$type": "app.bsky.feed.like", + "subject": {"uri": "at://test", "cid": "bafytest"}, + "createdAt": now() + }); + let result = validator.validate(&record, "app.bsky.feed.post"); + assert!(matches!(result, Err(ValidationError::TypeMismatch { expected, actual }) + if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like")); +} + +#[test] +fn test_validate_missing_type() { + let validator = RecordValidator::new(); + let record = json!({ + "text": "Hello" + }); + let result = validator.validate(&record, "app.bsky.feed.post"); + assert!(matches!(result, Err(ValidationError::MissingType))); +} + +#[test] +fn test_validate_not_object() { + let validator = RecordValidator::new(); + let record = json!("just a string"); + let result = validator.validate(&record, "app.bsky.feed.post"); + assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); +} + +#[test] +fn test_validate_datetime_format_valid() { + let validator = RecordValidator::new(); + let post = json!({ + "$type": "app.bsky.feed.post", + "text": "Test", + "createdAt": "2024-01-15T10:30:00.000Z" + }); + let result = validator.validate(&post, "app.bsky.feed.post"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_datetime_with_offset() { + let validator = RecordValidator::new(); + let post = json!({ + "$type": "app.bsky.feed.post", + "text": "Test", + "createdAt": "2024-01-15T10:30:00+05:30" + }); + let result = validator.validate(&post, "app.bsky.feed.post"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_datetime_invalid_format() { + let validator = RecordValidator::new(); + let post = json!({ + "$type": "app.bsky.feed.post", + "text": "Test", + "createdAt": "2024/01/15" + }); + let result = validator.validate(&post, "app.bsky.feed.post"); + assert!(matches!(result, Err(ValidationError::InvalidDatetime { .. 
}))); +} + +#[test] +fn test_validate_record_key_valid() { + assert!(validate_record_key("3k2n5j2").is_ok()); + assert!(validate_record_key("valid-key").is_ok()); + assert!(validate_record_key("valid_key").is_ok()); + assert!(validate_record_key("valid.key").is_ok()); + assert!(validate_record_key("valid~key").is_ok()); + assert!(validate_record_key("self").is_ok()); +} + +#[test] +fn test_validate_record_key_empty() { + let result = validate_record_key(""); + assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); +} + +#[test] +fn test_validate_record_key_dot() { + assert!(validate_record_key(".").is_err()); + assert!(validate_record_key("..").is_err()); +} + +#[test] +fn test_validate_record_key_invalid_chars() { + assert!(validate_record_key("invalid/key").is_err()); + assert!(validate_record_key("invalid key").is_err()); + assert!(validate_record_key("invalid@key").is_err()); + assert!(validate_record_key("invalid#key").is_err()); +} + +#[test] +fn test_validate_record_key_too_long() { + let long_key = "k".repeat(513); + let result = validate_record_key(&long_key); + assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); +} + +#[test] +fn test_validate_record_key_at_max_length() { + let max_key = "k".repeat(512); + assert!(validate_record_key(&max_key).is_ok()); +} + +#[test] +fn test_validate_collection_nsid_valid() { + assert!(validate_collection_nsid("app.bsky.feed.post").is_ok()); + assert!(validate_collection_nsid("com.atproto.repo.record").is_ok()); + assert!(validate_collection_nsid("a.b.c").is_ok()); + assert!(validate_collection_nsid("my-app.domain.record-type").is_ok()); +} + +#[test] +fn test_validate_collection_nsid_empty() { + let result = validate_collection_nsid(""); + assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); +} + +#[test] +fn test_validate_collection_nsid_too_few_segments() { + assert!(validate_collection_nsid("a").is_err()); + assert!(validate_collection_nsid("a.b").is_err()); +} + +#[test] +fn test_validate_collection_nsid_empty_segment() { + assert!(validate_collection_nsid("a..b.c").is_err()); + assert!(validate_collection_nsid(".a.b.c").is_err()); + assert!(validate_collection_nsid("a.b.c.").is_err()); +} + +#[test] +fn test_validate_collection_nsid_invalid_chars() { + assert!(validate_collection_nsid("a.b.c/d").is_err()); + assert!(validate_collection_nsid("a.b.c_d").is_err()); + assert!(validate_collection_nsid("a.b.c@d").is_err()); +} + +#[test] +fn test_validate_threadgate() { + let validator = RecordValidator::new(); + let gate = json!({ + "$type": "app.bsky.feed.threadgate", + "post": "at://did:plc:test/app.bsky.feed.post/123", + "createdAt": now() + }); + let result = validator.validate(&gate, "app.bsky.feed.threadgate"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_labeler_service() { + let validator = RecordValidator::new(); + let labeler = json!({ + "$type": "app.bsky.labeler.service", + "policies": { + "labelValues": ["spam", "nsfw"] + }, + "createdAt": now() + }); + let result = validator.validate(&labeler, "app.bsky.labeler.service"); + assert_eq!(result.unwrap(), ValidationStatus::Valid); +} + +#[test] +fn test_validate_list_item() { + let validator = RecordValidator::new(); + let item = json!({ + "$type": "app.bsky.graph.listitem", + "subject": "did:plc:test123", + "list": "at://did:plc:owner/app.bsky.graph.list/mylist", + "createdAt": now() + }); + let result = validator.validate(&item, "app.bsky.graph.listitem"); + assert_eq!(result.unwrap(), 
ValidationStatus::Valid);
+}
diff --git a/tests/relay_client.rs b/tests/relay_client.rs
deleted file mode 100644
index dcab6d2..0000000
--- a/tests/relay_client.rs
+++ /dev/null
@@ -1,86 +0,0 @@
-mod common;
-use common::*;
-
-use axum::{extract::ws::Message, routing::get, Router};
-use bspds::{
-    state::AppState,
-    sync::{firehose::SequencedEvent, relay_client::start_relay_clients},
-};
-use chrono::Utc;
-use tokio::net::TcpListener;
-use tokio::sync::mpsc;
-
-async fn mock_relay_server(
-    listener: TcpListener,
-    event_tx: mpsc::Sender<Vec<u8>>,
-    connected_tx: mpsc::Sender<()>,
-) {
-    let handler = |ws: axum::extract::ws::WebSocketUpgrade| async {
-        ws.on_upgrade(move |mut socket| async move {
-            let _ = connected_tx.send(()).await;
-            while let Some(Ok(msg)) = socket.recv().await {
-                if let Message::Binary(bytes) = msg {
-                    let _ = event_tx.send(bytes.to_vec()).await;
-                    break;
-                }
-            }
-        })
-    };
-    let app = Router::new().route("/", get(handler));
-
-    axum::serve(listener, app.into_make_service())
-        .await
-        .unwrap();
-}
-
-#[tokio::test]
-async fn test_outbound_relay_client() {
-    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
-    let addr = listener.local_addr().unwrap();
-    let (event_tx, mut event_rx) = mpsc::channel(1);
-    let (connected_tx, _connected_rx) = mpsc::channel::<()>(1);
-    tokio::spawn(mock_relay_server(listener, event_tx, connected_tx));
-    let relay_url = format!("ws://{}", addr);
-
-    let db_url = get_db_connection_string().await;
-    let pool = sqlx::postgres::PgPoolOptions::new()
-        .connect(&db_url)
-        .await
-        .unwrap();
-    let state = AppState::new(pool).await;
-
-    let (ready_tx, ready_rx) = mpsc::channel(1);
-    start_relay_clients(state.clone(), vec![relay_url], Some(ready_rx)).await;
-
-    tokio::time::timeout(
-        tokio::time::Duration::from_secs(5),
-        async {
-            ready_tx.closed().await;
-        }
-    )
-    .await
-    .expect("Timeout waiting for relay client to be ready");
-
-    let dummy_event = SequencedEvent {
-        seq: 1,
-        did: "did:plc:test".to_string(),
-        created_at: Utc::now(),
-        event_type: "commit".to_string(),
-        commit_cid: Some("bafyreihffx5a4o3qbv7vp6qmxpxok5mx5xvlsq6z4x3xv3zqv7vqvc7mzy".to_string()),
-        prev_cid: None,
-        ops: Some(serde_json::json!([])),
-        blobs: Some(vec![]),
-        blocks_cids: Some(vec![]),
-    };
-    state.firehose_tx.send(dummy_event).unwrap();
-
-    let received_bytes = tokio::time::timeout(
-        tokio::time::Duration::from_secs(5),
-        event_rx.recv()
-    )
-    .await
-    .expect("Timeout waiting for event")
-    .expect("Event channel closed");
-
-    assert!(!received_bytes.is_empty());
-}
diff --git a/tests/security_fixes.rs b/tests/security_fixes.rs
new file mode 100644
index 0000000..ad11072
--- /dev/null
+++ b/tests/security_fixes.rs
@@ -0,0 +1,377 @@
+mod common;
+
+use bspds::notifications::{
+    SendError, is_valid_phone_number, sanitize_header_value,
+};
+use bspds::oauth::templates::{login_page, error_page, success_page};
+use bspds::image::{ImageProcessor, ImageError};
+
+#[test]
+fn test_sanitize_header_value_removes_crlf() {
+    let malicious = "Injected\r\nBcc: attacker@evil.com";
+    let sanitized = sanitize_header_value(malicious);
+
+    assert!(!sanitized.contains('\r'), "CR should be removed");
+    assert!(!sanitized.contains('\n'), "LF should be removed");
+    assert!(sanitized.contains("Injected"), "Original content should be preserved");
+    assert!(sanitized.contains("Bcc:"), "Text after newline should be on same line (no header injection)");
+}
+
+#[test]
+fn test_sanitize_header_value_preserves_content() {
+    let normal = "Normal Subject Line";
+    let sanitized = sanitize_header_value(normal);
+
+    assert_eq!(sanitized, "Normal Subject Line");
+}
+
+#[test]
+fn test_sanitize_header_value_trims_whitespace() {
+    let padded = " Subject ";
+    let sanitized = sanitize_header_value(padded);
+
+    assert_eq!(sanitized, "Subject");
+}
+
+#[test]
+fn test_sanitize_header_value_handles_multiple_newlines() {
+    let input = "Line1\r\nLine2\nLine3\rLine4";
+    let sanitized = sanitize_header_value(input);
+
+    assert!(!sanitized.contains('\r'), "CR should be removed");
+    assert!(!sanitized.contains('\n'), "LF should be removed");
+    assert!(sanitized.contains("Line1"), "Content before newlines preserved");
+    assert!(sanitized.contains("Line4"), "Content after newlines preserved");
+}
+
+#[test]
+fn test_email_header_injection_sanitization() {
+    let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value";
+    let sanitized = sanitize_header_value(header_injection);
+
+    let lines: Vec<&str> = sanitized.split("\r\n").collect();
+    assert_eq!(lines.len(), 1, "Should be a single line after sanitization");
+    assert!(sanitized.contains("Normal Subject"), "Original content preserved");
+    assert!(sanitized.contains("Bcc:"), "Content after CRLF preserved as same line text");
+    assert!(sanitized.contains("X-Injected:"), "All content on same line");
+}
+
+#[test]
+fn test_valid_phone_number_accepts_correct_format() {
+    assert!(is_valid_phone_number("+1234567890"));
+    assert!(is_valid_phone_number("+12025551234"));
+    assert!(is_valid_phone_number("+442071234567"));
+    assert!(is_valid_phone_number("+4915123456789"));
+    assert!(is_valid_phone_number("+1"));
+}
+
+#[test]
+fn test_valid_phone_number_rejects_missing_plus() {
+    assert!(!is_valid_phone_number("1234567890"));
+    assert!(!is_valid_phone_number("12025551234"));
+}
+
+#[test]
+fn test_valid_phone_number_rejects_empty() {
+    assert!(!is_valid_phone_number(""));
+}
+
+#[test]
+fn test_valid_phone_number_rejects_just_plus() {
+    assert!(!is_valid_phone_number("+"));
+}
+
+#[test]
+fn test_valid_phone_number_rejects_too_long() {
+    assert!(!is_valid_phone_number("+12345678901234567890123"));
+}
+
+#[test]
+fn test_valid_phone_number_rejects_letters() {
+    assert!(!is_valid_phone_number("+abc123"));
+    assert!(!is_valid_phone_number("+1234abc"));
+    assert!(!is_valid_phone_number("+a"));
+}
+
+#[test]
+fn test_valid_phone_number_rejects_spaces() {
+    assert!(!is_valid_phone_number("+1234 5678"));
+    assert!(!is_valid_phone_number("+ 1234567890"));
+    assert!(!is_valid_phone_number("+1 "));
+}
+
+#[test]
+fn test_valid_phone_number_rejects_special_chars() {
+    assert!(!is_valid_phone_number("+123-456-7890"));
+    assert!(!is_valid_phone_number("+1(234)567890"));
+    assert!(!is_valid_phone_number("+1.234.567.890"));
+}
+
+#[test]
+fn test_signal_recipient_command_injection_blocked() {
+    let malicious_inputs = vec![
+        "+123; rm -rf /",
+        "+123 && cat /etc/passwd",
+        "+123`id`",
+        "+123$(whoami)",
+        "+123|cat /etc/shadow",
+        "+123\n--help",
+        "+123\r\n--version",
+        "+123--help",
+    ];
+
+    for input in malicious_inputs {
+        assert!(!is_valid_phone_number(input), "Malicious input '{}' should be rejected", input);
+    }
+}
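+
+// A minimal sketch of the rule the phone-number tests above assume
+// (hypothetical; the real check lives in src/notifications): an E.164-style
+// value, i.e. a leading '+' followed by 1..=15 ASCII digits and nothing else,
+// which also rules out every shell metacharacter exercised above:
+//
+//     fn is_valid_phone_number(s: &str) -> bool {
+//         match s.strip_prefix('+') {
+//             Some(d) => !d.is_empty() && d.len() <= 15 && d.bytes().all(|b| b.is_ascii_digit()),
+//             None => false,
+//         }
+//     }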
+
+#[test]
+fn test_image_file_size_limit_enforced() {
+    let processor = ImageProcessor::new();
+
+    let oversized_data: Vec<u8> = vec![0u8; 11 * 1024 * 1024];
+
+    let result = processor.process(&oversized_data, "image/jpeg");
+
+    match result {
+        Err(ImageError::FileTooLarge { .. }) => {}
+        Err(other) => {
+            let msg = format!("{:?}", other);
+            if !msg.to_lowercase().contains("size") && !msg.to_lowercase().contains("large") {
+                panic!("Expected FileTooLarge error, got: {:?}", other);
+            }
+        }
+        Ok(_) => panic!("Should reject files over size limit"),
+    }
+}
+
+#[test]
+fn test_image_file_size_limit_configurable() {
+    let processor = ImageProcessor::new().with_max_file_size(1024);
+
+    let data: Vec<u8> = vec![0u8; 2048];
+
+    let result = processor.process(&data, "image/jpeg");
+
+    assert!(result.is_err(), "Should reject files over configured limit");
+}
+
+#[test]
+fn test_oauth_template_xss_escaping_client_id() {
+    let malicious_client_id = "<script>alert('xss')</script>";
+    let html = login_page(malicious_client_id, None, None, "test-uri", None, None);
+
+    assert!(!html.contains("<script>alert"), "Script tag in client_id should be escaped");
+}
+
+#[test]
+fn test_oauth_template_xss_escaping_scope() {
+    let malicious_scope = "<script>alert('xss')</script>";
+    let html = login_page("client123", None, Some(malicious_scope), "test-uri", None, None);
+
+    assert!(!html.contains("<script>alert"), "Script tag in scope should be escaped");
+}
+
+#[test]
+fn test_oauth_template_xss_escaping_error() {
+    let malicious_error = "<script>alert('xss')</script>";
+    let html = login_page("client123", None, None, "test-uri", Some(malicious_error), None);
+
+    assert!(!html.contains("<script>alert"), "Script tag in error message should be escaped");
+}
+
+#[test]
+fn test_error_page_xss_escaping() {
+    let malicious_error = "<script>alert('xss')</script>";
+    let malicious_desc = "<img src=x onerror=alert('xss')>";
+
+    let html = error_page(malicious_error, Some(malicious_desc));
+
+    assert!(!html.contains("<script>alert"), "Script tag in error should be escaped");
+    assert!(!html.contains("<img src"), "Raw img tag in description should be escaped");
+}
+
+#[test]
+fn test_success_page_xss_escaping() {
+    let malicious_name = "<script>alert('xss')</script>";
+
+    let html = success_page(Some(malicious_name));
+
+    assert!(!html.contains("<script>alert"), "Script tag in display name should be escaped");
+}
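+
+// The template assertions above only require that a '<' from untrusted input
+// never reaches the page verbatim. A minimal escaper with that property, as a
+// sketch (not the actual src/oauth/templates.rs code):
+//
+//     fn escape_html(s: &str) -> String {
+//         s.replace('&', "&amp;")
+//             .replace('<', "&lt;")
+//             .replace('>', "&gt;")
+//             .replace('"', "&quot;")
+//             .replace('\'', "&#x27;")
+//     }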