Add admin functionality

This commit is contained in:
lewis
2025-12-16 18:28:20 +02:00
parent e2bfcdb74f
commit dea6c09aa0
24 changed files with 257 additions and 288 deletions

View File

@@ -1,22 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "SELECT takedown_ref FROM users WHERE did = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "takedown_ref",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
true
]
},
"hash": "1add22e111d5eff8beadbd832b4b8146d95da0a0ce8ce31dc9a2f930a26cc9ce"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "SELECT deactivated_at, takedown_ref FROM users WHERE did = $1",
"query": "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1",
"describe": {
"columns": [
{
@@ -12,6 +12,11 @@
"ordinal": 1,
"name": "takedown_ref",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "is_admin",
"type_info": "Bool"
}
],
"parameters": {
@@ -21,8 +26,9 @@
},
"nullable": [
true,
true
true,
false
]
},
"hash": "04c220298334c369872f0b0ad162b992c2353e28257b53f3f10cbff8abb26f5a"
"hash": "225c3844ce6962121e5cc0aa544c79d0f93bb3458487d79b64bd40ae9accd522"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref,\n k.key_bytes as \"key_bytes?\", k.encryption_version as \"encryption_version?\"\n FROM oauth_token t\n JOIN users u ON t.did = u.did\n LEFT JOIN user_keys k ON u.id = k.user_id\n WHERE t.token_id = $1",
"query": "SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin,\n k.key_bytes as \"key_bytes?\", k.encryption_version as \"encryption_version?\"\n FROM oauth_token t\n JOIN users u ON t.did = u.did\n LEFT JOIN user_keys k ON u.id = k.user_id\n WHERE t.token_id = $1",
"describe": {
"columns": [
{
@@ -25,11 +25,16 @@
},
{
"ordinal": 4,
"name": "is_admin",
"type_info": "Bool"
},
{
"ordinal": 5,
"name": "key_bytes?",
"type_info": "Bytea"
},
{
"ordinal": 5,
"ordinal": 6,
"name": "encryption_version?",
"type_info": "Int4"
}
@@ -45,8 +50,9 @@
true,
true,
false,
false,
true
]
},
"hash": "bee4276cbb537512cced16f7017d8f7c068d30f319ef965fa9ec9fb1a3490151"
"hash": "49cd5f335121f5eb4f578f6ca3af40e95264ded8021cfc7490b578a96fb8db3c"
}

View File

@@ -1,28 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "SELECT k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "key_bytes",
"type_info": "Bytea"
},
{
"ordinal": 1,
"name": "encryption_version",
"type_info": "Int4"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
true
]
},
"hash": "90bcc8fb97f73a0b5f427971aca891936b3f906c2d4cdb4bf203dd6a4c9aa060"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref\n FROM users u\n JOIN user_keys k ON u.id = k.user_id\n WHERE u.did = $1",
"query": "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin\n FROM users u\n JOIN user_keys k ON u.id = k.user_id\n WHERE u.did = $1",
"describe": {
"columns": [
{
@@ -22,6 +22,11 @@
"ordinal": 3,
"name": "takedown_ref",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "is_admin",
"type_info": "Bool"
}
],
"parameters": {
@@ -33,8 +38,9 @@
false,
true,
true,
true
true,
false
]
},
"hash": "6b67b2b6759f01be11d5997a3ad68d381f59a02235a6940877f62193af8d9761"
"hash": "cc68023c320bc4376925c2cd921cd48045a47ca5841eef8c8889894f2c2452f6"
}

View File

@@ -0,0 +1,46 @@
{
"db_name": "PostgreSQL",
"query": "SELECT u.deactivated_at, u.takedown_ref, u.is_admin,\n k.key_bytes as \"key_bytes?\", k.encryption_version as \"encryption_version?\"\n FROM users u\n LEFT JOIN user_keys k ON u.id = k.user_id\n WHERE u.did = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "deactivated_at",
"type_info": "Timestamptz"
},
{
"ordinal": 1,
"name": "takedown_ref",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "is_admin",
"type_info": "Bool"
},
{
"ordinal": 3,
"name": "key_bytes?",
"type_info": "Bytea"
},
{
"ordinal": 4,
"name": "encryption_version?",
"type_info": "Int4"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
true,
true,
false,
false,
true
]
},
"hash": "e6077393f797f94d6048f01edd45b27a89ea481427753a860215d6ee85f8dcf8"
}

View File

@@ -0,0 +1,20 @@
{
"db_name": "PostgreSQL",
"query": "SELECT COUNT(*) as count FROM users",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "count",
"type_info": "Int8"
}
],
"parameters": {
"Left": []
},
"nullable": [
null
]
},
"hash": "fd64104d130b93dd5fc9414b8710ad5183b647eaaff90decbce15e10d83c7538"
}

View File

@@ -0,0 +1 @@
ALTER TABLE users ADD COLUMN is_admin BOOLEAN NOT NULL DEFAULT FALSE;

View File

@@ -1,3 +1,4 @@
use crate::auth::BearerAuthAdmin;
use crate::state::AppState;
use axum::{
Json,
@@ -16,17 +17,9 @@ pub struct DeleteAccountInput {
pub async fn delete_account(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Json(input): Json<DeleteAccountInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let did = input.did.trim();
if did.is_empty() {
return (

View File

@@ -1,3 +1,4 @@
use crate::auth::BearerAuthAdmin;
use crate::state::AppState;
use axum::{
Json,
@@ -26,17 +27,9 @@ pub struct SendEmailOutput {
pub async fn send_email(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Json(input): Json<SendEmailInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let recipient_did = input.recipient_did.trim();
let content = input.content.trim();
if recipient_did.is_empty() {

View File

@@ -1,3 +1,4 @@
use crate::auth::BearerAuthAdmin;
use crate::state::AppState;
use axum::{
Json,
@@ -35,17 +36,9 @@ pub struct GetAccountInfosOutput {
pub async fn get_account_info(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Query(params): Query<GetAccountInfoParams>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let did = params.did.trim();
if did.is_empty() {
return (
@@ -102,17 +95,9 @@ pub struct GetAccountInfosParams {
pub async fn get_account_infos(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Query(params): Query<GetAccountInfosParams>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let dids: Vec<&str> = params.dids.split(',').map(|s| s.trim()).collect();
if dids.is_empty() {
return (

View File

@@ -1,4 +1,5 @@
use crate::api::repo::record::create_record_internal;
use crate::auth::BearerAuthAdmin;
use crate::state::AppState;
use axum::{
Json,
@@ -36,18 +37,9 @@ pub struct CreateProfileOutput {
pub async fn create_profile(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Json(input): Json<CreateProfileInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let did = input.did.trim();
if did.is_empty() {
return (
@@ -101,18 +93,9 @@ pub async fn create_profile(
pub async fn create_record_admin(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Json(input): Json<CreateRecordAdminInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let did = input.did.trim();
if did.is_empty() {
return (

View File

@@ -1,3 +1,4 @@
use crate::auth::BearerAuthAdmin;
use crate::state::AppState;
use axum::{
Json,
@@ -17,17 +18,9 @@ pub struct UpdateAccountEmailInput {
pub async fn update_account_email(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Json(input): Json<UpdateAccountEmailInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let account = input.account.trim();
let email = input.email.trim();
if account.is_empty() || email.is_empty() {
@@ -70,17 +63,9 @@ pub struct UpdateAccountHandleInput {
pub async fn update_account_handle(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Json(input): Json<UpdateAccountHandleInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let did = input.did.trim();
let handle = input.handle.trim();
if did.is_empty() || handle.is_empty() {
@@ -158,17 +143,9 @@ pub struct UpdateAccountPasswordInput {
pub async fn update_account_password(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Json(input): Json<UpdateAccountPasswordInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let did = input.did.trim();
let password = input.password.trim();
if did.is_empty() || password.is_empty() {

View File

@@ -1,3 +1,4 @@
use crate::auth::BearerAuthAdmin;
use crate::state::AppState;
use axum::{
Json,
@@ -18,17 +19,9 @@ pub struct DisableInviteCodesInput {
pub async fn disable_invite_codes(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Json(input): Json<DisableInviteCodesInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
if let Some(codes) = &input.codes {
for code in codes {
let _ = sqlx::query!(
@@ -91,17 +84,9 @@ pub struct GetInviteCodesOutput {
pub async fn get_invite_codes(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Query(params): Query<GetInviteCodesParams>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let limit = params.limit.unwrap_or(100).clamp(1, 500);
let sort = params.sort.as_deref().unwrap_or("recent");
let order_clause = match sort {
@@ -229,17 +214,9 @@ pub struct DisableAccountInvitesInput {
pub async fn disable_account_invites(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Json(input): Json<DisableAccountInvitesInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let account = input.account.trim();
if account.is_empty() {
return (
@@ -283,17 +260,9 @@ pub struct EnableAccountInvitesInput {
pub async fn enable_account_invites(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Json(input): Json<EnableAccountInvitesInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let account = input.account.trim();
if account.is_empty() {
return (

View File

@@ -1,12 +1,11 @@
use crate::auth::BearerAuthAdmin;
use crate::state::AppState;
use axum::{
Json,
extract::State,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use serde::Serialize;
use serde_json::json;
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
@@ -19,17 +18,8 @@ pub struct ServerStatsResponse {
pub async fn get_server_stats(
State(state): State<AppState>,
headers: HeaderMap,
_auth: BearerAuthAdmin,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let user_count: i64 = match sqlx::query_scalar!("SELECT COUNT(*) FROM users")
.fetch_one(&state.db)
.await

View File

@@ -1,3 +1,4 @@
use crate::auth::BearerAuthAdmin;
use crate::state::AppState;
use axum::{
Json,
@@ -32,17 +33,9 @@ pub struct StatusAttr {
pub async fn get_subject_status(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Query(params): Query<GetSubjectStatusParams>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
if params.did.is_none() && params.uri.is_none() && params.blob.is_none() {
return (
StatusCode::BAD_REQUEST,
@@ -208,17 +201,9 @@ pub struct StatusAttrInput {
pub async fn update_subject_status(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
_auth: BearerAuthAdmin,
Json(input): Json<UpdateSubjectStatusInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let subject_type = input.subject.get("$type").and_then(|t| t.as_str());
match subject_type {
Some("com.atproto.admin.defs#repoRef") => {

View File

@@ -379,12 +379,18 @@ pub async fn create_account(
};
let verification_code = format!("{:06}", rand::random::<u32>() % 1_000_000);
let code_expires_at = chrono::Utc::now() + chrono::Duration::minutes(30);
let is_first_user = sqlx::query_scalar!("SELECT COUNT(*) as count FROM users")
.fetch_one(&mut *tx)
.await
.map(|c| c.unwrap_or(0) == 0)
.unwrap_or(false);
let user_insert: Result<(uuid::Uuid,), _> = sqlx::query_as(
r#"INSERT INTO users (
handle, email, did, password_hash,
preferred_notification_channel,
discord_id, telegram_username, signal_number
) VALUES ($1, $2, $3, $4, $5::notification_channel, $6, $7, $8) RETURNING id"#,
discord_id, telegram_username, signal_number,
is_admin
) VALUES ($1, $2, $3, $4, $5::notification_channel, $6, $7, $8, $9) RETURNING id"#,
)
.bind(short_handle)
.bind(&email)
@@ -412,6 +418,7 @@ pub async fn create_account(
.map(|s| s.trim())
.filter(|s| !s.is_empty()),
)
.bind(is_first_user)
.fetch_one(&mut *tx)
.await;
let user_id = match user_insert {

View File

@@ -21,6 +21,7 @@ pub enum AuthError {
AuthenticationFailed,
AccountDeactivated,
AccountTakedown,
AdminRequired,
}
impl IntoResponse for AuthError {
@@ -51,6 +52,11 @@ impl IntoResponse for AuthError {
"AccountTakedown",
"Account has been taken down",
),
AuthError::AdminRequired => (
StatusCode::FORBIDDEN,
"AdminRequired",
"This action requires admin privileges",
),
};
(status, Json(json!({ "error": error, "message": message }))).into_response()
@@ -182,6 +188,38 @@ impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
}
}
pub struct BearerAuthAdmin(pub AuthenticatedUser);
impl FromRequestParts<AppState> for BearerAuthAdmin {
type Rejection = AuthError;
async fn from_request_parts(
parts: &mut Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let auth_header = parts
.headers
.get(AUTHORIZATION)
.ok_or(AuthError::MissingToken)?
.to_str()
.map_err(|_| AuthError::InvalidFormat)?;
let token = extract_bearer_token(auth_header)?;
match validate_bearer_token_cached(&state.db, &state.cache, token).await {
Ok(user) => {
if !user.is_admin {
return Err(AuthError::AdminRequired);
}
Ok(BearerAuthAdmin(user))
}
Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
Err(_) => Err(AuthError::AuthenticationFailed),
}
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -11,7 +11,7 @@ pub mod token;
pub mod verify;
pub use extractor::{
AuthError, BearerAuth, BearerAuthAllowDeactivated, ExtractedToken,
AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken,
extract_auth_token_from_header, extract_bearer_token_from_header,
};
pub use token::{
@@ -50,6 +50,7 @@ pub struct AuthenticatedUser {
pub did: String,
pub key_bytes: Option<Vec<u8>>,
pub is_oauth: bool,
pub is_admin: bool,
}
pub async fn validate_bearer_token(
@@ -103,9 +104,9 @@ async fn validate_bearer_token_with_options_internal(
}
}
let (decrypted_key, deactivated_at, takedown_ref) = if let Some(key) = cached_key {
let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key {
let user_status = sqlx::query!(
"SELECT deactivated_at, takedown_ref FROM users WHERE did = $1",
"SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1",
did
)
.fetch_optional(db)
@@ -114,11 +115,11 @@ async fn validate_bearer_token_with_options_internal(
.flatten();
match user_status {
Some(status) => (Some(key), status.deactivated_at, status.takedown_ref),
None => (None, None, None),
Some(status) => (Some(key), status.deactivated_at, status.takedown_ref, status.is_admin),
None => (None, None, None, false),
}
} else if let Some(user) = sqlx::query!(
"SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref
"SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin
FROM users u
JOIN user_keys k ON u.id = k.user_id
WHERE u.did = $1",
@@ -142,9 +143,9 @@ async fn validate_bearer_token_with_options_internal(
.await;
}
(Some(key), user.deactivated_at, user.takedown_ref)
(Some(key), user.deactivated_at, user.takedown_ref, user.is_admin)
} else {
(None, None, None)
(None, None, None, false)
};
if let Some(decrypted_key) = decrypted_key {
@@ -200,6 +201,7 @@ async fn validate_bearer_token_with_options_internal(
did: did.clone(),
key_bytes: Some(decrypted_key),
is_oauth: false,
is_admin,
});
}
}
@@ -208,7 +210,7 @@ async fn validate_bearer_token_with_options_internal(
if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token)
&& let Some(oauth_token) = sqlx::query!(
r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref,
r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin,
k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
FROM oauth_token t
JOIN users u ON t.did = u.did
@@ -242,6 +244,7 @@ async fn validate_bearer_token_with_options_internal(
did: oauth_token.did,
key_bytes,
is_oauth: true,
is_admin: oauth_token.is_admin,
});
}
}
@@ -280,43 +283,37 @@ pub async fn validate_token_with_dpop(
.await
{
Ok(result) => {
if !allow_deactivated {
let deactivated = sqlx::query_scalar!(
"SELECT deactivated_at FROM users WHERE did = $1",
result.did
)
.fetch_optional(db)
.await
.ok()
.flatten()
.flatten();
if deactivated.is_some() {
return Err(TokenValidationError::AccountDeactivated);
}
}
let takedown =
sqlx::query_scalar!("SELECT takedown_ref FROM users WHERE did = $1", result.did)
.fetch_optional(db)
.await
.ok()
.flatten()
.flatten();
if takedown.is_some() {
return Err(TokenValidationError::AccountTakedown);
}
let key_bytes = sqlx::query!(
"SELECT k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1",
let user_info = sqlx::query!(
r#"SELECT u.deactivated_at, u.takedown_ref, u.is_admin,
k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
FROM users u
LEFT JOIN user_keys k ON u.id = k.user_id
WHERE u.did = $1"#,
result.did
)
.fetch_optional(db)
.await
.ok()
.flatten()
.and_then(|row| crate::config::decrypt_key(&row.key_bytes, row.encryption_version).ok());
.flatten();
let Some(user_info) = user_info else {
return Err(TokenValidationError::AuthenticationFailed);
};
if !allow_deactivated && user_info.deactivated_at.is_some() {
return Err(TokenValidationError::AccountDeactivated);
}
if user_info.takedown_ref.is_some() {
return Err(TokenValidationError::AccountTakedown);
}
let key_bytes = if let (Some(kb), Some(ev)) = (&user_info.key_bytes, user_info.encryption_version) {
crate::config::decrypt_key(kb, Some(ev)).ok()
} else {
None
};
Ok(AuthenticatedUser {
did: result.did,
key_bytes,
is_oauth: true,
is_admin: user_info.is_admin,
})
}
Err(_) => Err(TokenValidationError::AuthenticationFailed),

View File

@@ -18,7 +18,7 @@ async fn test_send_email_success() {
let client = common::client();
let base_url = common::base_url().await;
let pool = get_pool().await;
let (access_jwt, did) = common::create_account_and_login(&client).await;
let (access_jwt, did) = common::create_admin_account_and_login(&client).await;
let res = client
.post(format!("{}/xrpc/com.atproto.admin.sendEmail", base_url))
.bearer_auth(&access_jwt)
@@ -58,7 +58,7 @@ async fn test_send_email_default_subject() {
let client = common::client();
let base_url = common::base_url().await;
let pool = get_pool().await;
let (access_jwt, did) = common::create_account_and_login(&client).await;
let (access_jwt, did) = common::create_admin_account_and_login(&client).await;
let res = client
.post(format!("{}/xrpc/com.atproto.admin.sendEmail", base_url))
.bearer_auth(&access_jwt)
@@ -92,7 +92,7 @@ async fn test_send_email_default_subject() {
async fn test_send_email_recipient_not_found() {
let client = common::client();
let base_url = common::base_url().await;
let (access_jwt, _) = common::create_account_and_login(&client).await;
let (access_jwt, _) = common::create_admin_account_and_login(&client).await;
let res = client
.post(format!("{}/xrpc/com.atproto.admin.sendEmail", base_url))
.bearer_auth(&access_jwt)
@@ -113,7 +113,7 @@ async fn test_send_email_recipient_not_found() {
async fn test_send_email_missing_content() {
let client = common::client();
let base_url = common::base_url().await;
let (access_jwt, did) = common::create_account_and_login(&client).await;
let (access_jwt, did) = common::create_admin_account_and_login(&client).await;
let res = client
.post(format!("{}/xrpc/com.atproto.admin.sendEmail", base_url))
.bearer_auth(&access_jwt)
@@ -134,7 +134,7 @@ async fn test_send_email_missing_content() {
async fn test_send_email_missing_recipient() {
let client = common::client();
let base_url = common::base_url().await;
let (access_jwt, _) = common::create_account_and_login(&client).await;
let (access_jwt, _) = common::create_admin_account_and_login(&client).await;
let res = client
.post(format!("{}/xrpc/com.atproto.admin.sendEmail", base_url))
.bearer_auth(&access_jwt)

View File

@@ -7,7 +7,7 @@ use serde_json::{Value, json};
#[tokio::test]
async fn test_admin_get_invite_codes_success() {
let client = client();
let (access_jwt, _did) = create_account_and_login(&client).await;
let (access_jwt, _did) = create_admin_account_and_login(&client).await;
let create_payload = json!({
"useCount": 3
});
@@ -38,7 +38,7 @@ async fn test_admin_get_invite_codes_success() {
#[tokio::test]
async fn test_admin_get_invite_codes_with_limit() {
let client = client();
let (access_jwt, _did) = create_account_and_login(&client).await;
let (access_jwt, _did) = create_admin_account_and_login(&client).await;
for _ in 0..5 {
let create_payload = json!({
"useCount": 1
@@ -86,7 +86,7 @@ async fn test_admin_get_invite_codes_no_auth() {
#[tokio::test]
async fn test_disable_account_invites_success() {
let client = client();
let (access_jwt, did) = create_account_and_login(&client).await;
let (access_jwt, did) = create_admin_account_and_login(&client).await;
let payload = json!({
"account": did
});
@@ -122,7 +122,7 @@ async fn test_disable_account_invites_success() {
#[tokio::test]
async fn test_enable_account_invites_success() {
let client = client();
let (access_jwt, did) = create_account_and_login(&client).await;
let (access_jwt, did) = create_admin_account_and_login(&client).await;
let disable_payload = json!({
"account": did
});
@@ -186,7 +186,7 @@ async fn test_disable_account_invites_no_auth() {
#[tokio::test]
async fn test_disable_account_invites_not_found() {
let client = client();
let (access_jwt, _did) = create_account_and_login(&client).await;
let (access_jwt, _did) = create_admin_account_and_login(&client).await;
let payload = json!({
"account": "did:plc:nonexistent"
});
@@ -206,7 +206,7 @@ async fn test_disable_account_invites_not_found() {
#[tokio::test]
async fn test_disable_invite_codes_by_code() {
let client = client();
let (access_jwt, _did) = create_account_and_login(&client).await;
let (access_jwt, _did) = create_admin_account_and_login(&client).await;
let create_payload = json!({
"useCount": 5
});
@@ -255,7 +255,7 @@ async fn test_disable_invite_codes_by_code() {
#[tokio::test]
async fn test_disable_invite_codes_by_account() {
let client = client();
let (access_jwt, did) = create_account_and_login(&client).await;
let (access_jwt, did) = create_admin_account_and_login(&client).await;
for _ in 0..3 {
let create_payload = json!({
"useCount": 1
@@ -321,7 +321,7 @@ async fn test_disable_invite_codes_no_auth() {
#[tokio::test]
async fn test_admin_enable_account_invites_not_found() {
let client = client();
let (access_jwt, _did) = create_account_and_login(&client).await;
let (access_jwt, _did) = create_admin_account_and_login(&client).await;
let payload = json!({
"account": "did:plc:nonexistent"
});

View File

@@ -7,7 +7,7 @@ use serde_json::{Value, json};
#[tokio::test]
async fn test_get_subject_status_user_success() {
let client = client();
let (access_jwt, did) = create_account_and_login(&client).await;
let (access_jwt, did) = create_admin_account_and_login(&client).await;
let res = client
.get(format!(
"{}/xrpc/com.atproto.admin.getSubjectStatus",
@@ -28,7 +28,7 @@ async fn test_get_subject_status_user_success() {
#[tokio::test]
async fn test_get_subject_status_not_found() {
let client = client();
let (access_jwt, _did) = create_account_and_login(&client).await;
let (access_jwt, _did) = create_admin_account_and_login(&client).await;
let res = client
.get(format!(
"{}/xrpc/com.atproto.admin.getSubjectStatus",
@@ -47,7 +47,7 @@ async fn test_get_subject_status_not_found() {
#[tokio::test]
async fn test_get_subject_status_no_param() {
let client = client();
let (access_jwt, _did) = create_account_and_login(&client).await;
let (access_jwt, _did) = create_admin_account_and_login(&client).await;
let res = client
.get(format!(
"{}/xrpc/com.atproto.admin.getSubjectStatus",
@@ -80,11 +80,12 @@ async fn test_get_subject_status_no_auth() {
#[tokio::test]
async fn test_update_subject_status_takedown_user() {
let client = client();
let (access_jwt, did) = create_account_and_login(&client).await;
let (admin_jwt, _) = create_admin_account_and_login(&client).await;
let (_, target_did) = create_account_and_login(&client).await;
let payload = json!({
"subject": {
"$type": "com.atproto.admin.defs#repoRef",
"did": did
"did": target_did
},
"takedown": {
"apply": true,
@@ -96,7 +97,7 @@ async fn test_update_subject_status_takedown_user() {
"{}/xrpc/com.atproto.admin.updateSubjectStatus",
base_url().await
))
.bearer_auth(&access_jwt)
.bearer_auth(&admin_jwt)
.json(&payload)
.send()
.await
@@ -111,8 +112,8 @@ async fn test_update_subject_status_takedown_user() {
"{}/xrpc/com.atproto.admin.getSubjectStatus",
base_url().await
))
.bearer_auth(&access_jwt)
.query(&[("did", did.as_str())])
.bearer_auth(&admin_jwt)
.query(&[("did", target_did.as_str())])
.send()
.await
.expect("Failed to send request");
@@ -125,11 +126,12 @@ async fn test_update_subject_status_takedown_user() {
#[tokio::test]
async fn test_update_subject_status_remove_takedown() {
let client = client();
let (access_jwt, did) = create_account_and_login(&client).await;
let (admin_jwt, _) = create_admin_account_and_login(&client).await;
let (_, target_did) = create_account_and_login(&client).await;
let takedown_payload = json!({
"subject": {
"$type": "com.atproto.admin.defs#repoRef",
"did": did
"did": target_did
},
"takedown": {
"apply": true,
@@ -141,14 +143,14 @@ async fn test_update_subject_status_remove_takedown() {
"{}/xrpc/com.atproto.admin.updateSubjectStatus",
base_url().await
))
.bearer_auth(&access_jwt)
.bearer_auth(&admin_jwt)
.json(&takedown_payload)
.send()
.await;
let remove_payload = json!({
"subject": {
"$type": "com.atproto.admin.defs#repoRef",
"did": did
"did": target_did
},
"takedown": {
"apply": false
@@ -159,7 +161,7 @@ async fn test_update_subject_status_remove_takedown() {
"{}/xrpc/com.atproto.admin.updateSubjectStatus",
base_url().await
))
.bearer_auth(&access_jwt)
.bearer_auth(&admin_jwt)
.json(&remove_payload)
.send()
.await
@@ -170,8 +172,8 @@ async fn test_update_subject_status_remove_takedown() {
"{}/xrpc/com.atproto.admin.getSubjectStatus",
base_url().await
))
.bearer_auth(&access_jwt)
.query(&[("did", did.as_str())])
.bearer_auth(&admin_jwt)
.query(&[("did", target_did.as_str())])
.send()
.await
.expect("Failed to send request");
@@ -187,11 +189,12 @@ async fn test_update_subject_status_remove_takedown() {
#[tokio::test]
async fn test_update_subject_status_deactivate_user() {
let client = client();
let (access_jwt, did) = create_account_and_login(&client).await;
let (admin_jwt, _) = create_admin_account_and_login(&client).await;
let (_, target_did) = create_account_and_login(&client).await;
let payload = json!({
"subject": {
"$type": "com.atproto.admin.defs#repoRef",
"did": did
"did": target_did
},
"deactivated": {
"apply": true
@@ -202,7 +205,7 @@ async fn test_update_subject_status_deactivate_user() {
"{}/xrpc/com.atproto.admin.updateSubjectStatus",
base_url().await
))
.bearer_auth(&access_jwt)
.bearer_auth(&admin_jwt)
.json(&payload)
.send()
.await
@@ -213,8 +216,8 @@ async fn test_update_subject_status_deactivate_user() {
"{}/xrpc/com.atproto.admin.getSubjectStatus",
base_url().await
))
.bearer_auth(&access_jwt)
.query(&[("did", did.as_str())])
.bearer_auth(&admin_jwt)
.query(&[("did", target_did.as_str())])
.send()
.await
.expect("Failed to send request");
@@ -226,7 +229,7 @@ async fn test_update_subject_status_deactivate_user() {
#[tokio::test]
async fn test_update_subject_status_invalid_type() {
let client = client();
let (access_jwt, _did) = create_account_and_login(&client).await;
let (access_jwt, _did) = create_admin_account_and_login(&client).await;
let payload = json!({
"subject": {
"$type": "invalid.type",

View File

@@ -1,14 +1,14 @@
mod common;
use common::{base_url, client, create_account_and_login};
use common::{base_url, client, create_admin_account_and_login};
use serde_json::Value;
#[tokio::test]
async fn test_get_server_stats() {
let client = client();
let base = base_url().await;
let (token1, _) = create_account_and_login(&client).await;
let (token1, _) = create_admin_account_and_login(&client).await;
let (_, _) = create_account_and_login(&client).await;
let (_, _) = create_admin_account_and_login(&client).await;
let resp = client
.get(format!("{}/xrpc/com.bspds.admin.getServerStats", base))

View File

@@ -511,6 +511,15 @@ pub async fn create_test_post(
#[allow(dead_code)]
pub async fn create_account_and_login(client: &Client) -> (String, String) {
create_account_and_login_internal(client, false).await
}
#[allow(dead_code)]
pub async fn create_admin_account_and_login(client: &Client) -> (String, String) {
create_account_and_login_internal(client, true).await
}
async fn create_account_and_login_internal(client: &Client, make_admin: bool) -> (String, String) {
let mut last_error = String::new();
for attempt in 0..3 {
if attempt > 0 {
@@ -539,10 +548,6 @@ pub async fn create_account_and_login(client: &Client) -> (String, String) {
};
if res.status() == StatusCode::OK {
let body: Value = res.json().await.expect("Invalid JSON");
if let Some(access_jwt) = body["accessJwt"].as_str() {
let did = body["did"].as_str().expect("No did").to_string();
return (access_jwt.to_string(), did);
}
let did = body["did"].as_str().expect("No did").to_string();
let conn_str = get_db_connection_string().await;
let pool = sqlx::postgres::PgPoolOptions::new()
@@ -550,6 +555,15 @@ pub async fn create_account_and_login(client: &Client) -> (String, String) {
.connect(&conn_str)
.await
.expect("Failed to connect to test database");
if make_admin {
sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did)
.execute(&pool)
.await
.expect("Failed to mark user as admin");
}
if let Some(access_jwt) = body["accessJwt"].as_str() {
return (access_jwt.to_string(), did);
}
let verification_code: String = sqlx::query_scalar!(
"SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE did = $1) AND channel = 'email'",
&did