From dea6c09aa06c4f16d58c48e9fe8d489e87352c58 Mon Sep 17 00:00:00 2001 From: lewis Date: Tue, 16 Dec 2025 18:28:20 +0200 Subject: [PATCH] Add admin functionality --- ...b8146d95da0a0ce8ce31dc9a2f930a26cc9ce.json | 22 ------ ...79d0f93bb3458487d79b64bd40ae9accd522.json} | 12 +++- ...40e95264ded8021cfc7490b578a96fb8db3c.json} | 12 +++- ...891936b3f906c2d4cdb4bf203dd6a4c9aa060.json | 28 -------- ...d48045a47ca5841eef8c8889894f2c2452f6.json} | 12 +++- ...5b27a89ea481427753a860215d6ee85f8dcf8.json | 46 ++++++++++++ ...0ad5183b647eaaff90decbce15e10d83c7538.json | 20 ++++++ migrations/20251218_add_is_admin.sql | 1 + src/api/admin/account/delete.rs | 11 +-- src/api/admin/account/email.rs | 11 +-- src/api/admin/account/info.rs | 21 +----- src/api/admin/account/profile.rs | 23 +----- src/api/admin/account/update.rs | 31 ++------ src/api/admin/invite.rs | 41 ++--------- src/api/admin/server_stats.rs | 14 +--- src/api/admin/status.rs | 21 +----- src/api/identity/account.rs | 11 ++- src/auth/extractor.rs | 38 ++++++++++ src/auth/mod.rs | 71 +++++++++---------- tests/admin_email.rs | 10 +-- tests/admin_invite.rs | 16 ++--- tests/admin_moderation.rs | 45 ++++++------ tests/admin_stats.rs | 6 +- tests/common/mod.rs | 22 ++++-- 24 files changed, 257 insertions(+), 288 deletions(-) delete mode 100644 .sqlx/query-1add22e111d5eff8beadbd832b4b8146d95da0a0ce8ce31dc9a2f930a26cc9ce.json rename .sqlx/{query-04c220298334c369872f0b0ad162b992c2353e28257b53f3f10cbff8abb26f5a.json => query-225c3844ce6962121e5cc0aa544c79d0f93bb3458487d79b64bd40ae9accd522.json} (57%) rename .sqlx/{query-bee4276cbb537512cced16f7017d8f7c068d30f319ef965fa9ec9fb1a3490151.json => query-49cd5f335121f5eb4f578f6ca3af40e95264ded8021cfc7490b578a96fb8db3c.json} (65%) delete mode 100644 .sqlx/query-90bcc8fb97f73a0b5f427971aca891936b3f906c2d4cdb4bf203dd6a4c9aa060.json rename .sqlx/{query-6b67b2b6759f01be11d5997a3ad68d381f59a02235a6940877f62193af8d9761.json => query-cc68023c320bc4376925c2cd921cd48045a47ca5841eef8c8889894f2c2452f6.json} (66%) create mode 100644 .sqlx/query-e6077393f797f94d6048f01edd45b27a89ea481427753a860215d6ee85f8dcf8.json create mode 100644 .sqlx/query-fd64104d130b93dd5fc9414b8710ad5183b647eaaff90decbce15e10d83c7538.json create mode 100644 migrations/20251218_add_is_admin.sql diff --git a/.sqlx/query-1add22e111d5eff8beadbd832b4b8146d95da0a0ce8ce31dc9a2f930a26cc9ce.json b/.sqlx/query-1add22e111d5eff8beadbd832b4b8146d95da0a0ce8ce31dc9a2f930a26cc9ce.json deleted file mode 100644 index 71fd54b..0000000 --- a/.sqlx/query-1add22e111d5eff8beadbd832b4b8146d95da0a0ce8ce31dc9a2f930a26cc9ce.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT takedown_ref FROM users WHERE did = $1", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "takedown_ref", - "type_info": "Text" - } - ], - "parameters": { - "Left": [ - "Text" - ] - }, - "nullable": [ - true - ] - }, - "hash": "1add22e111d5eff8beadbd832b4b8146d95da0a0ce8ce31dc9a2f930a26cc9ce" -} diff --git a/.sqlx/query-04c220298334c369872f0b0ad162b992c2353e28257b53f3f10cbff8abb26f5a.json b/.sqlx/query-225c3844ce6962121e5cc0aa544c79d0f93bb3458487d79b64bd40ae9accd522.json similarity index 57% rename from .sqlx/query-04c220298334c369872f0b0ad162b992c2353e28257b53f3f10cbff8abb26f5a.json rename to .sqlx/query-225c3844ce6962121e5cc0aa544c79d0f93bb3458487d79b64bd40ae9accd522.json index a7fb9e8..f2fdaaf 100644 --- a/.sqlx/query-04c220298334c369872f0b0ad162b992c2353e28257b53f3f10cbff8abb26f5a.json +++ b/.sqlx/query-225c3844ce6962121e5cc0aa544c79d0f93bb3458487d79b64bd40ae9accd522.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT deactivated_at, takedown_ref FROM users WHERE did = $1", + "query": "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1", "describe": { "columns": [ { @@ -12,6 +12,11 @@ "ordinal": 1, "name": "takedown_ref", "type_info": "Text" + }, + { + "ordinal": 2, + "name": "is_admin", + "type_info": "Bool" } ], "parameters": { @@ -21,8 +26,9 @@ }, "nullable": [ true, - true + true, + false ] }, - "hash": "04c220298334c369872f0b0ad162b992c2353e28257b53f3f10cbff8abb26f5a" + "hash": "225c3844ce6962121e5cc0aa544c79d0f93bb3458487d79b64bd40ae9accd522" } diff --git a/.sqlx/query-bee4276cbb537512cced16f7017d8f7c068d30f319ef965fa9ec9fb1a3490151.json b/.sqlx/query-49cd5f335121f5eb4f578f6ca3af40e95264ded8021cfc7490b578a96fb8db3c.json similarity index 65% rename from .sqlx/query-bee4276cbb537512cced16f7017d8f7c068d30f319ef965fa9ec9fb1a3490151.json rename to .sqlx/query-49cd5f335121f5eb4f578f6ca3af40e95264ded8021cfc7490b578a96fb8db3c.json index fd22e14..df81408 100644 --- a/.sqlx/query-bee4276cbb537512cced16f7017d8f7c068d30f319ef965fa9ec9fb1a3490151.json +++ b/.sqlx/query-49cd5f335121f5eb4f578f6ca3af40e95264ded8021cfc7490b578a96fb8db3c.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref,\n k.key_bytes as \"key_bytes?\", k.encryption_version as \"encryption_version?\"\n FROM oauth_token t\n JOIN users u ON t.did = u.did\n LEFT JOIN user_keys k ON u.id = k.user_id\n WHERE t.token_id = $1", + "query": "SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin,\n k.key_bytes as \"key_bytes?\", k.encryption_version as \"encryption_version?\"\n FROM oauth_token t\n JOIN users u ON t.did = u.did\n LEFT JOIN user_keys k ON u.id = k.user_id\n WHERE t.token_id = $1", "describe": { "columns": [ { @@ -25,11 +25,16 @@ }, { "ordinal": 4, + "name": "is_admin", + "type_info": "Bool" + }, + { + "ordinal": 5, "name": "key_bytes?", "type_info": "Bytea" }, { - "ordinal": 5, + "ordinal": 6, "name": "encryption_version?", "type_info": "Int4" } @@ -45,8 +50,9 @@ true, true, false, + false, true ] }, - "hash": "bee4276cbb537512cced16f7017d8f7c068d30f319ef965fa9ec9fb1a3490151" + "hash": "49cd5f335121f5eb4f578f6ca3af40e95264ded8021cfc7490b578a96fb8db3c" } diff --git a/.sqlx/query-90bcc8fb97f73a0b5f427971aca891936b3f906c2d4cdb4bf203dd6a4c9aa060.json b/.sqlx/query-90bcc8fb97f73a0b5f427971aca891936b3f906c2d4cdb4bf203dd6a4c9aa060.json deleted file mode 100644 index 93b0ff8..0000000 --- a/.sqlx/query-90bcc8fb97f73a0b5f427971aca891936b3f906c2d4cdb4bf203dd6a4c9aa060.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "key_bytes", - "type_info": "Bytea" - }, - { - "ordinal": 1, - "name": "encryption_version", - "type_info": "Int4" - } - ], - "parameters": { - "Left": [ - "Text" - ] - }, - "nullable": [ - false, - true - ] - }, - "hash": "90bcc8fb97f73a0b5f427971aca891936b3f906c2d4cdb4bf203dd6a4c9aa060" -} diff --git a/.sqlx/query-6b67b2b6759f01be11d5997a3ad68d381f59a02235a6940877f62193af8d9761.json b/.sqlx/query-cc68023c320bc4376925c2cd921cd48045a47ca5841eef8c8889894f2c2452f6.json similarity index 66% rename from .sqlx/query-6b67b2b6759f01be11d5997a3ad68d381f59a02235a6940877f62193af8d9761.json rename to .sqlx/query-cc68023c320bc4376925c2cd921cd48045a47ca5841eef8c8889894f2c2452f6.json index 344cdc0..e171db5 100644 --- a/.sqlx/query-6b67b2b6759f01be11d5997a3ad68d381f59a02235a6940877f62193af8d9761.json +++ b/.sqlx/query-cc68023c320bc4376925c2cd921cd48045a47ca5841eef8c8889894f2c2452f6.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref\n FROM users u\n JOIN user_keys k ON u.id = k.user_id\n WHERE u.did = $1", + "query": "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin\n FROM users u\n JOIN user_keys k ON u.id = k.user_id\n WHERE u.did = $1", "describe": { "columns": [ { @@ -22,6 +22,11 @@ "ordinal": 3, "name": "takedown_ref", "type_info": "Text" + }, + { + "ordinal": 4, + "name": "is_admin", + "type_info": "Bool" } ], "parameters": { @@ -33,8 +38,9 @@ false, true, true, - true + true, + false ] }, - "hash": "6b67b2b6759f01be11d5997a3ad68d381f59a02235a6940877f62193af8d9761" + "hash": "cc68023c320bc4376925c2cd921cd48045a47ca5841eef8c8889894f2c2452f6" } diff --git a/.sqlx/query-e6077393f797f94d6048f01edd45b27a89ea481427753a860215d6ee85f8dcf8.json b/.sqlx/query-e6077393f797f94d6048f01edd45b27a89ea481427753a860215d6ee85f8dcf8.json new file mode 100644 index 0000000..a01e17b --- /dev/null +++ b/.sqlx/query-e6077393f797f94d6048f01edd45b27a89ea481427753a860215d6ee85f8dcf8.json @@ -0,0 +1,46 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT u.deactivated_at, u.takedown_ref, u.is_admin,\n k.key_bytes as \"key_bytes?\", k.encryption_version as \"encryption_version?\"\n FROM users u\n LEFT JOIN user_keys k ON u.id = k.user_id\n WHERE u.did = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "deactivated_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 1, + "name": "takedown_ref", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "is_admin", + "type_info": "Bool" + }, + { + "ordinal": 3, + "name": "key_bytes?", + "type_info": "Bytea" + }, + { + "ordinal": 4, + "name": "encryption_version?", + "type_info": "Int4" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + true, + true, + false, + false, + true + ] + }, + "hash": "e6077393f797f94d6048f01edd45b27a89ea481427753a860215d6ee85f8dcf8" +} diff --git a/.sqlx/query-fd64104d130b93dd5fc9414b8710ad5183b647eaaff90decbce15e10d83c7538.json b/.sqlx/query-fd64104d130b93dd5fc9414b8710ad5183b647eaaff90decbce15e10d83c7538.json new file mode 100644 index 0000000..f193747 --- /dev/null +++ b/.sqlx/query-fd64104d130b93dd5fc9414b8710ad5183b647eaaff90decbce15e10d83c7538.json @@ -0,0 +1,20 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT COUNT(*) as count FROM users", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "count", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "fd64104d130b93dd5fc9414b8710ad5183b647eaaff90decbce15e10d83c7538" +} diff --git a/migrations/20251218_add_is_admin.sql b/migrations/20251218_add_is_admin.sql new file mode 100644 index 0000000..e3cf1ec --- /dev/null +++ b/migrations/20251218_add_is_admin.sql @@ -0,0 +1 @@ +ALTER TABLE users ADD COLUMN is_admin BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/src/api/admin/account/delete.rs b/src/api/admin/account/delete.rs index eab59f5..43904d5 100644 --- a/src/api/admin/account/delete.rs +++ b/src/api/admin/account/delete.rs @@ -1,3 +1,4 @@ +use crate::auth::BearerAuthAdmin; use crate::state::AppState; use axum::{ Json, @@ -16,17 +17,9 @@ pub struct DeleteAccountInput { pub async fn delete_account( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Json(input): Json, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } let did = input.did.trim(); if did.is_empty() { return ( diff --git a/src/api/admin/account/email.rs b/src/api/admin/account/email.rs index e2a52f5..d61066d 100644 --- a/src/api/admin/account/email.rs +++ b/src/api/admin/account/email.rs @@ -1,3 +1,4 @@ +use crate::auth::BearerAuthAdmin; use crate::state::AppState; use axum::{ Json, @@ -26,17 +27,9 @@ pub struct SendEmailOutput { pub async fn send_email( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Json(input): Json, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } let recipient_did = input.recipient_did.trim(); let content = input.content.trim(); if recipient_did.is_empty() { diff --git a/src/api/admin/account/info.rs b/src/api/admin/account/info.rs index b0c4541..50db839 100644 --- a/src/api/admin/account/info.rs +++ b/src/api/admin/account/info.rs @@ -1,3 +1,4 @@ +use crate::auth::BearerAuthAdmin; use crate::state::AppState; use axum::{ Json, @@ -35,17 +36,9 @@ pub struct GetAccountInfosOutput { pub async fn get_account_info( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Query(params): Query, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } let did = params.did.trim(); if did.is_empty() { return ( @@ -102,17 +95,9 @@ pub struct GetAccountInfosParams { pub async fn get_account_infos( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Query(params): Query, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } let dids: Vec<&str> = params.dids.split(',').map(|s| s.trim()).collect(); if dids.is_empty() { return ( diff --git a/src/api/admin/account/profile.rs b/src/api/admin/account/profile.rs index 7f94311..e9bd784 100644 --- a/src/api/admin/account/profile.rs +++ b/src/api/admin/account/profile.rs @@ -1,4 +1,5 @@ use crate::api::repo::record::create_record_internal; +use crate::auth::BearerAuthAdmin; use crate::state::AppState; use axum::{ Json, @@ -36,18 +37,9 @@ pub struct CreateProfileOutput { pub async fn create_profile( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Json(input): Json, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } - let did = input.did.trim(); if did.is_empty() { return ( @@ -101,18 +93,9 @@ pub async fn create_profile( pub async fn create_record_admin( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Json(input): Json, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } - let did = input.did.trim(); if did.is_empty() { return ( diff --git a/src/api/admin/account/update.rs b/src/api/admin/account/update.rs index 3afb0af..8882632 100644 --- a/src/api/admin/account/update.rs +++ b/src/api/admin/account/update.rs @@ -1,3 +1,4 @@ +use crate::auth::BearerAuthAdmin; use crate::state::AppState; use axum::{ Json, @@ -17,17 +18,9 @@ pub struct UpdateAccountEmailInput { pub async fn update_account_email( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Json(input): Json, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } let account = input.account.trim(); let email = input.email.trim(); if account.is_empty() || email.is_empty() { @@ -70,17 +63,9 @@ pub struct UpdateAccountHandleInput { pub async fn update_account_handle( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Json(input): Json, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } let did = input.did.trim(); let handle = input.handle.trim(); if did.is_empty() || handle.is_empty() { @@ -158,17 +143,9 @@ pub struct UpdateAccountPasswordInput { pub async fn update_account_password( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Json(input): Json, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } let did = input.did.trim(); let password = input.password.trim(); if did.is_empty() || password.is_empty() { diff --git a/src/api/admin/invite.rs b/src/api/admin/invite.rs index 15f4381..52bb3a7 100644 --- a/src/api/admin/invite.rs +++ b/src/api/admin/invite.rs @@ -1,3 +1,4 @@ +use crate::auth::BearerAuthAdmin; use crate::state::AppState; use axum::{ Json, @@ -18,17 +19,9 @@ pub struct DisableInviteCodesInput { pub async fn disable_invite_codes( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Json(input): Json, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } if let Some(codes) = &input.codes { for code in codes { let _ = sqlx::query!( @@ -91,17 +84,9 @@ pub struct GetInviteCodesOutput { pub async fn get_invite_codes( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Query(params): Query, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } let limit = params.limit.unwrap_or(100).clamp(1, 500); let sort = params.sort.as_deref().unwrap_or("recent"); let order_clause = match sort { @@ -229,17 +214,9 @@ pub struct DisableAccountInvitesInput { pub async fn disable_account_invites( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Json(input): Json, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } let account = input.account.trim(); if account.is_empty() { return ( @@ -283,17 +260,9 @@ pub struct EnableAccountInvitesInput { pub async fn enable_account_invites( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Json(input): Json, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } let account = input.account.trim(); if account.is_empty() { return ( diff --git a/src/api/admin/server_stats.rs b/src/api/admin/server_stats.rs index 6a24f31..c9da10b 100644 --- a/src/api/admin/server_stats.rs +++ b/src/api/admin/server_stats.rs @@ -1,12 +1,11 @@ +use crate::auth::BearerAuthAdmin; use crate::state::AppState; use axum::{ Json, extract::State, - http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, }; use serde::Serialize; -use serde_json::json; #[derive(Serialize)] #[serde(rename_all = "camelCase")] @@ -19,17 +18,8 @@ pub struct ServerStatsResponse { pub async fn get_server_stats( State(state): State, - headers: HeaderMap, + _auth: BearerAuthAdmin, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } - let user_count: i64 = match sqlx::query_scalar!("SELECT COUNT(*) FROM users") .fetch_one(&state.db) .await diff --git a/src/api/admin/status.rs b/src/api/admin/status.rs index 5b2a58c..6b1ded0 100644 --- a/src/api/admin/status.rs +++ b/src/api/admin/status.rs @@ -1,3 +1,4 @@ +use crate::auth::BearerAuthAdmin; use crate::state::AppState; use axum::{ Json, @@ -32,17 +33,9 @@ pub struct StatusAttr { pub async fn get_subject_status( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Query(params): Query, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } if params.did.is_none() && params.uri.is_none() && params.blob.is_none() { return ( StatusCode::BAD_REQUEST, @@ -208,17 +201,9 @@ pub struct StatusAttrInput { pub async fn update_subject_status( State(state): State, - headers: axum::http::HeaderMap, + _auth: BearerAuthAdmin, Json(input): Json, ) -> Response { - let auth_header = headers.get("Authorization"); - if auth_header.is_none() { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"error": "AuthenticationRequired"})), - ) - .into_response(); - } let subject_type = input.subject.get("$type").and_then(|t| t.as_str()); match subject_type { Some("com.atproto.admin.defs#repoRef") => { diff --git a/src/api/identity/account.rs b/src/api/identity/account.rs index b9513e1..e2eb37b 100644 --- a/src/api/identity/account.rs +++ b/src/api/identity/account.rs @@ -379,12 +379,18 @@ pub async fn create_account( }; let verification_code = format!("{:06}", rand::random::() % 1_000_000); let code_expires_at = chrono::Utc::now() + chrono::Duration::minutes(30); + let is_first_user = sqlx::query_scalar!("SELECT COUNT(*) as count FROM users") + .fetch_one(&mut *tx) + .await + .map(|c| c.unwrap_or(0) == 0) + .unwrap_or(false); let user_insert: Result<(uuid::Uuid,), _> = sqlx::query_as( r#"INSERT INTO users ( handle, email, did, password_hash, preferred_notification_channel, - discord_id, telegram_username, signal_number - ) VALUES ($1, $2, $3, $4, $5::notification_channel, $6, $7, $8) RETURNING id"#, + discord_id, telegram_username, signal_number, + is_admin + ) VALUES ($1, $2, $3, $4, $5::notification_channel, $6, $7, $8, $9) RETURNING id"#, ) .bind(short_handle) .bind(&email) @@ -412,6 +418,7 @@ pub async fn create_account( .map(|s| s.trim()) .filter(|s| !s.is_empty()), ) + .bind(is_first_user) .fetch_one(&mut *tx) .await; let user_id = match user_insert { diff --git a/src/auth/extractor.rs b/src/auth/extractor.rs index 54c26e0..52a2b2e 100644 --- a/src/auth/extractor.rs +++ b/src/auth/extractor.rs @@ -21,6 +21,7 @@ pub enum AuthError { AuthenticationFailed, AccountDeactivated, AccountTakedown, + AdminRequired, } impl IntoResponse for AuthError { @@ -51,6 +52,11 @@ impl IntoResponse for AuthError { "AccountTakedown", "Account has been taken down", ), + AuthError::AdminRequired => ( + StatusCode::FORBIDDEN, + "AdminRequired", + "This action requires admin privileges", + ), }; (status, Json(json!({ "error": error, "message": message }))).into_response() @@ -182,6 +188,38 @@ impl FromRequestParts for BearerAuthAllowDeactivated { } } +pub struct BearerAuthAdmin(pub AuthenticatedUser); + +impl FromRequestParts for BearerAuthAdmin { + type Rejection = AuthError; + + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + let auth_header = parts + .headers + .get(AUTHORIZATION) + .ok_or(AuthError::MissingToken)? + .to_str() + .map_err(|_| AuthError::InvalidFormat)?; + + let token = extract_bearer_token(auth_header)?; + + match validate_bearer_token_cached(&state.db, &state.cache, token).await { + Ok(user) => { + if !user.is_admin { + return Err(AuthError::AdminRequired); + } + Ok(BearerAuthAdmin(user)) + } + Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), + Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), + Err(_) => Err(AuthError::AuthenticationFailed), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 287759c..8a555b5 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -11,7 +11,7 @@ pub mod token; pub mod verify; pub use extractor::{ - AuthError, BearerAuth, BearerAuthAllowDeactivated, ExtractedToken, + AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken, extract_auth_token_from_header, extract_bearer_token_from_header, }; pub use token::{ @@ -50,6 +50,7 @@ pub struct AuthenticatedUser { pub did: String, pub key_bytes: Option>, pub is_oauth: bool, + pub is_admin: bool, } pub async fn validate_bearer_token( @@ -103,9 +104,9 @@ async fn validate_bearer_token_with_options_internal( } } - let (decrypted_key, deactivated_at, takedown_ref) = if let Some(key) = cached_key { + let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key { let user_status = sqlx::query!( - "SELECT deactivated_at, takedown_ref FROM users WHERE did = $1", + "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1", did ) .fetch_optional(db) @@ -114,11 +115,11 @@ async fn validate_bearer_token_with_options_internal( .flatten(); match user_status { - Some(status) => (Some(key), status.deactivated_at, status.takedown_ref), - None => (None, None, None), + Some(status) => (Some(key), status.deactivated_at, status.takedown_ref, status.is_admin), + None => (None, None, None, false), } } else if let Some(user) = sqlx::query!( - "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref + "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", @@ -142,9 +143,9 @@ async fn validate_bearer_token_with_options_internal( .await; } - (Some(key), user.deactivated_at, user.takedown_ref) + (Some(key), user.deactivated_at, user.takedown_ref, user.is_admin) } else { - (None, None, None) + (None, None, None, false) }; if let Some(decrypted_key) = decrypted_key { @@ -200,6 +201,7 @@ async fn validate_bearer_token_with_options_internal( did: did.clone(), key_bytes: Some(decrypted_key), is_oauth: false, + is_admin, }); } } @@ -208,7 +210,7 @@ async fn validate_bearer_token_with_options_internal( if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) && let Some(oauth_token) = sqlx::query!( - r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, + r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin, k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?" FROM oauth_token t JOIN users u ON t.did = u.did @@ -242,6 +244,7 @@ async fn validate_bearer_token_with_options_internal( did: oauth_token.did, key_bytes, is_oauth: true, + is_admin: oauth_token.is_admin, }); } } @@ -280,43 +283,37 @@ pub async fn validate_token_with_dpop( .await { Ok(result) => { - if !allow_deactivated { - let deactivated = sqlx::query_scalar!( - "SELECT deactivated_at FROM users WHERE did = $1", - result.did - ) - .fetch_optional(db) - .await - .ok() - .flatten() - .flatten(); - if deactivated.is_some() { - return Err(TokenValidationError::AccountDeactivated); - } - } - let takedown = - sqlx::query_scalar!("SELECT takedown_ref FROM users WHERE did = $1", result.did) - .fetch_optional(db) - .await - .ok() - .flatten() - .flatten(); - if takedown.is_some() { - return Err(TokenValidationError::AccountTakedown); - } - let key_bytes = sqlx::query!( - "SELECT k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", + let user_info = sqlx::query!( + r#"SELECT u.deactivated_at, u.takedown_ref, u.is_admin, + k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?" + FROM users u + LEFT JOIN user_keys k ON u.id = k.user_id + WHERE u.did = $1"#, result.did ) .fetch_optional(db) .await .ok() - .flatten() - .and_then(|row| crate::config::decrypt_key(&row.key_bytes, row.encryption_version).ok()); + .flatten(); + let Some(user_info) = user_info else { + return Err(TokenValidationError::AuthenticationFailed); + }; + if !allow_deactivated && user_info.deactivated_at.is_some() { + return Err(TokenValidationError::AccountDeactivated); + } + if user_info.takedown_ref.is_some() { + return Err(TokenValidationError::AccountTakedown); + } + let key_bytes = if let (Some(kb), Some(ev)) = (&user_info.key_bytes, user_info.encryption_version) { + crate::config::decrypt_key(kb, Some(ev)).ok() + } else { + None + }; Ok(AuthenticatedUser { did: result.did, key_bytes, is_oauth: true, + is_admin: user_info.is_admin, }) } Err(_) => Err(TokenValidationError::AuthenticationFailed), diff --git a/tests/admin_email.rs b/tests/admin_email.rs index bd10379..dfa9e05 100644 --- a/tests/admin_email.rs +++ b/tests/admin_email.rs @@ -18,7 +18,7 @@ async fn test_send_email_success() { let client = common::client(); let base_url = common::base_url().await; let pool = get_pool().await; - let (access_jwt, did) = common::create_account_and_login(&client).await; + let (access_jwt, did) = common::create_admin_account_and_login(&client).await; let res = client .post(format!("{}/xrpc/com.atproto.admin.sendEmail", base_url)) .bearer_auth(&access_jwt) @@ -58,7 +58,7 @@ async fn test_send_email_default_subject() { let client = common::client(); let base_url = common::base_url().await; let pool = get_pool().await; - let (access_jwt, did) = common::create_account_and_login(&client).await; + let (access_jwt, did) = common::create_admin_account_and_login(&client).await; let res = client .post(format!("{}/xrpc/com.atproto.admin.sendEmail", base_url)) .bearer_auth(&access_jwt) @@ -92,7 +92,7 @@ async fn test_send_email_default_subject() { async fn test_send_email_recipient_not_found() { let client = common::client(); let base_url = common::base_url().await; - let (access_jwt, _) = common::create_account_and_login(&client).await; + let (access_jwt, _) = common::create_admin_account_and_login(&client).await; let res = client .post(format!("{}/xrpc/com.atproto.admin.sendEmail", base_url)) .bearer_auth(&access_jwt) @@ -113,7 +113,7 @@ async fn test_send_email_recipient_not_found() { async fn test_send_email_missing_content() { let client = common::client(); let base_url = common::base_url().await; - let (access_jwt, did) = common::create_account_and_login(&client).await; + let (access_jwt, did) = common::create_admin_account_and_login(&client).await; let res = client .post(format!("{}/xrpc/com.atproto.admin.sendEmail", base_url)) .bearer_auth(&access_jwt) @@ -134,7 +134,7 @@ async fn test_send_email_missing_content() { async fn test_send_email_missing_recipient() { let client = common::client(); let base_url = common::base_url().await; - let (access_jwt, _) = common::create_account_and_login(&client).await; + let (access_jwt, _) = common::create_admin_account_and_login(&client).await; let res = client .post(format!("{}/xrpc/com.atproto.admin.sendEmail", base_url)) .bearer_auth(&access_jwt) diff --git a/tests/admin_invite.rs b/tests/admin_invite.rs index d1a22ca..52d6f43 100644 --- a/tests/admin_invite.rs +++ b/tests/admin_invite.rs @@ -7,7 +7,7 @@ use serde_json::{Value, json}; #[tokio::test] async fn test_admin_get_invite_codes_success() { let client = client(); - let (access_jwt, _did) = create_account_and_login(&client).await; + let (access_jwt, _did) = create_admin_account_and_login(&client).await; let create_payload = json!({ "useCount": 3 }); @@ -38,7 +38,7 @@ async fn test_admin_get_invite_codes_success() { #[tokio::test] async fn test_admin_get_invite_codes_with_limit() { let client = client(); - let (access_jwt, _did) = create_account_and_login(&client).await; + let (access_jwt, _did) = create_admin_account_and_login(&client).await; for _ in 0..5 { let create_payload = json!({ "useCount": 1 @@ -86,7 +86,7 @@ async fn test_admin_get_invite_codes_no_auth() { #[tokio::test] async fn test_disable_account_invites_success() { let client = client(); - let (access_jwt, did) = create_account_and_login(&client).await; + let (access_jwt, did) = create_admin_account_and_login(&client).await; let payload = json!({ "account": did }); @@ -122,7 +122,7 @@ async fn test_disable_account_invites_success() { #[tokio::test] async fn test_enable_account_invites_success() { let client = client(); - let (access_jwt, did) = create_account_and_login(&client).await; + let (access_jwt, did) = create_admin_account_and_login(&client).await; let disable_payload = json!({ "account": did }); @@ -186,7 +186,7 @@ async fn test_disable_account_invites_no_auth() { #[tokio::test] async fn test_disable_account_invites_not_found() { let client = client(); - let (access_jwt, _did) = create_account_and_login(&client).await; + let (access_jwt, _did) = create_admin_account_and_login(&client).await; let payload = json!({ "account": "did:plc:nonexistent" }); @@ -206,7 +206,7 @@ async fn test_disable_account_invites_not_found() { #[tokio::test] async fn test_disable_invite_codes_by_code() { let client = client(); - let (access_jwt, _did) = create_account_and_login(&client).await; + let (access_jwt, _did) = create_admin_account_and_login(&client).await; let create_payload = json!({ "useCount": 5 }); @@ -255,7 +255,7 @@ async fn test_disable_invite_codes_by_code() { #[tokio::test] async fn test_disable_invite_codes_by_account() { let client = client(); - let (access_jwt, did) = create_account_and_login(&client).await; + let (access_jwt, did) = create_admin_account_and_login(&client).await; for _ in 0..3 { let create_payload = json!({ "useCount": 1 @@ -321,7 +321,7 @@ async fn test_disable_invite_codes_no_auth() { #[tokio::test] async fn test_admin_enable_account_invites_not_found() { let client = client(); - let (access_jwt, _did) = create_account_and_login(&client).await; + let (access_jwt, _did) = create_admin_account_and_login(&client).await; let payload = json!({ "account": "did:plc:nonexistent" }); diff --git a/tests/admin_moderation.rs b/tests/admin_moderation.rs index c92e5f8..9d73379 100644 --- a/tests/admin_moderation.rs +++ b/tests/admin_moderation.rs @@ -7,7 +7,7 @@ use serde_json::{Value, json}; #[tokio::test] async fn test_get_subject_status_user_success() { let client = client(); - let (access_jwt, did) = create_account_and_login(&client).await; + let (access_jwt, did) = create_admin_account_and_login(&client).await; let res = client .get(format!( "{}/xrpc/com.atproto.admin.getSubjectStatus", @@ -28,7 +28,7 @@ async fn test_get_subject_status_user_success() { #[tokio::test] async fn test_get_subject_status_not_found() { let client = client(); - let (access_jwt, _did) = create_account_and_login(&client).await; + let (access_jwt, _did) = create_admin_account_and_login(&client).await; let res = client .get(format!( "{}/xrpc/com.atproto.admin.getSubjectStatus", @@ -47,7 +47,7 @@ async fn test_get_subject_status_not_found() { #[tokio::test] async fn test_get_subject_status_no_param() { let client = client(); - let (access_jwt, _did) = create_account_and_login(&client).await; + let (access_jwt, _did) = create_admin_account_and_login(&client).await; let res = client .get(format!( "{}/xrpc/com.atproto.admin.getSubjectStatus", @@ -80,11 +80,12 @@ async fn test_get_subject_status_no_auth() { #[tokio::test] async fn test_update_subject_status_takedown_user() { let client = client(); - let (access_jwt, did) = create_account_and_login(&client).await; + let (admin_jwt, _) = create_admin_account_and_login(&client).await; + let (_, target_did) = create_account_and_login(&client).await; let payload = json!({ "subject": { "$type": "com.atproto.admin.defs#repoRef", - "did": did + "did": target_did }, "takedown": { "apply": true, @@ -96,7 +97,7 @@ async fn test_update_subject_status_takedown_user() { "{}/xrpc/com.atproto.admin.updateSubjectStatus", base_url().await )) - .bearer_auth(&access_jwt) + .bearer_auth(&admin_jwt) .json(&payload) .send() .await @@ -111,8 +112,8 @@ async fn test_update_subject_status_takedown_user() { "{}/xrpc/com.atproto.admin.getSubjectStatus", base_url().await )) - .bearer_auth(&access_jwt) - .query(&[("did", did.as_str())]) + .bearer_auth(&admin_jwt) + .query(&[("did", target_did.as_str())]) .send() .await .expect("Failed to send request"); @@ -125,11 +126,12 @@ async fn test_update_subject_status_takedown_user() { #[tokio::test] async fn test_update_subject_status_remove_takedown() { let client = client(); - let (access_jwt, did) = create_account_and_login(&client).await; + let (admin_jwt, _) = create_admin_account_and_login(&client).await; + let (_, target_did) = create_account_and_login(&client).await; let takedown_payload = json!({ "subject": { "$type": "com.atproto.admin.defs#repoRef", - "did": did + "did": target_did }, "takedown": { "apply": true, @@ -141,14 +143,14 @@ async fn test_update_subject_status_remove_takedown() { "{}/xrpc/com.atproto.admin.updateSubjectStatus", base_url().await )) - .bearer_auth(&access_jwt) + .bearer_auth(&admin_jwt) .json(&takedown_payload) .send() .await; let remove_payload = json!({ "subject": { "$type": "com.atproto.admin.defs#repoRef", - "did": did + "did": target_did }, "takedown": { "apply": false @@ -159,7 +161,7 @@ async fn test_update_subject_status_remove_takedown() { "{}/xrpc/com.atproto.admin.updateSubjectStatus", base_url().await )) - .bearer_auth(&access_jwt) + .bearer_auth(&admin_jwt) .json(&remove_payload) .send() .await @@ -170,8 +172,8 @@ async fn test_update_subject_status_remove_takedown() { "{}/xrpc/com.atproto.admin.getSubjectStatus", base_url().await )) - .bearer_auth(&access_jwt) - .query(&[("did", did.as_str())]) + .bearer_auth(&admin_jwt) + .query(&[("did", target_did.as_str())]) .send() .await .expect("Failed to send request"); @@ -187,11 +189,12 @@ async fn test_update_subject_status_remove_takedown() { #[tokio::test] async fn test_update_subject_status_deactivate_user() { let client = client(); - let (access_jwt, did) = create_account_and_login(&client).await; + let (admin_jwt, _) = create_admin_account_and_login(&client).await; + let (_, target_did) = create_account_and_login(&client).await; let payload = json!({ "subject": { "$type": "com.atproto.admin.defs#repoRef", - "did": did + "did": target_did }, "deactivated": { "apply": true @@ -202,7 +205,7 @@ async fn test_update_subject_status_deactivate_user() { "{}/xrpc/com.atproto.admin.updateSubjectStatus", base_url().await )) - .bearer_auth(&access_jwt) + .bearer_auth(&admin_jwt) .json(&payload) .send() .await @@ -213,8 +216,8 @@ async fn test_update_subject_status_deactivate_user() { "{}/xrpc/com.atproto.admin.getSubjectStatus", base_url().await )) - .bearer_auth(&access_jwt) - .query(&[("did", did.as_str())]) + .bearer_auth(&admin_jwt) + .query(&[("did", target_did.as_str())]) .send() .await .expect("Failed to send request"); @@ -226,7 +229,7 @@ async fn test_update_subject_status_deactivate_user() { #[tokio::test] async fn test_update_subject_status_invalid_type() { let client = client(); - let (access_jwt, _did) = create_account_and_login(&client).await; + let (access_jwt, _did) = create_admin_account_and_login(&client).await; let payload = json!({ "subject": { "$type": "invalid.type", diff --git a/tests/admin_stats.rs b/tests/admin_stats.rs index 4aaa826..8b56637 100644 --- a/tests/admin_stats.rs +++ b/tests/admin_stats.rs @@ -1,14 +1,14 @@ mod common; -use common::{base_url, client, create_account_and_login}; +use common::{base_url, client, create_admin_account_and_login}; use serde_json::Value; #[tokio::test] async fn test_get_server_stats() { let client = client(); let base = base_url().await; - let (token1, _) = create_account_and_login(&client).await; + let (token1, _) = create_admin_account_and_login(&client).await; - let (_, _) = create_account_and_login(&client).await; + let (_, _) = create_admin_account_and_login(&client).await; let resp = client .get(format!("{}/xrpc/com.bspds.admin.getServerStats", base)) diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 3367858..8170491 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -511,6 +511,15 @@ pub async fn create_test_post( #[allow(dead_code)] pub async fn create_account_and_login(client: &Client) -> (String, String) { + create_account_and_login_internal(client, false).await +} + +#[allow(dead_code)] +pub async fn create_admin_account_and_login(client: &Client) -> (String, String) { + create_account_and_login_internal(client, true).await +} + +async fn create_account_and_login_internal(client: &Client, make_admin: bool) -> (String, String) { let mut last_error = String::new(); for attempt in 0..3 { if attempt > 0 { @@ -539,10 +548,6 @@ pub async fn create_account_and_login(client: &Client) -> (String, String) { }; if res.status() == StatusCode::OK { let body: Value = res.json().await.expect("Invalid JSON"); - if let Some(access_jwt) = body["accessJwt"].as_str() { - let did = body["did"].as_str().expect("No did").to_string(); - return (access_jwt.to_string(), did); - } let did = body["did"].as_str().expect("No did").to_string(); let conn_str = get_db_connection_string().await; let pool = sqlx::postgres::PgPoolOptions::new() @@ -550,6 +555,15 @@ pub async fn create_account_and_login(client: &Client) -> (String, String) { .connect(&conn_str) .await .expect("Failed to connect to test database"); + if make_admin { + sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did) + .execute(&pool) + .await + .expect("Failed to mark user as admin"); + } + if let Some(access_jwt) = body["accessJwt"].as_str() { + return (access_jwt.to_string(), did); + } let verification_code: String = sqlx::query_scalar!( "SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE did = $1) AND channel = 'email'", &did