Files
tranquil-pds/src/api/server/invite.rs

503 lines
14 KiB
Rust

use crate::state::AppState;
use axum::{
Json,
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::error;
use uuid::Uuid;
/// Request body for `createInviteCode`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateInviteCodeInput {
    /// Number of times the new code may be redeemed (`useCount`); the handler
    /// rejects values below 1.
    pub use_count: i32,
    /// Optional DID (`forAccount`) of the account that should own the code.
    /// When absent, the code is owned by the authenticated caller.
    pub for_account: Option<String>,
}
/// Response body for `createInviteCode`.
#[derive(Debug, Serialize)]
pub struct CreateInviteCodeOutput {
    /// The newly generated invite code.
    pub code: String,
}
pub async fn create_invite_code(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
Json(input): Json<CreateInviteCodeInput>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
if input.use_count < 1 {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "InvalidRequest", "message": "useCount must be at least 1"})),
)
.into_response();
}
let token = auth_header
.unwrap()
.to_str()
.unwrap_or("")
.replace("Bearer ", "");
let session = sqlx::query!(
r#"
SELECT s.did, k.key_bytes, u.id as user_id
FROM sessions s
JOIN users u ON s.did = u.did
JOIN user_keys k ON u.id = k.user_id
WHERE s.access_jwt = $1
"#,
token
)
.fetch_optional(&state.db)
.await;
let (did, key_bytes, user_id) = match session {
Ok(Some(row)) => (row.did, row.key_bytes, row.user_id),
Ok(None) => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed"})),
)
.into_response();
}
Err(e) => {
error!("DB error in create_invite_code: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
)
.into_response();
}
let creator_user_id = if let Some(for_account) = &input.for_account {
let target = sqlx::query!("SELECT id FROM users WHERE did = $1", for_account)
.fetch_optional(&state.db)
.await;
match target {
Ok(Some(row)) => row.id,
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "AccountNotFound", "message": "Target account not found"})),
)
.into_response();
}
Err(e) => {
error!("DB error looking up target account: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
}
} else {
user_id
};
let user_invites_disabled = sqlx::query_scalar!(
"SELECT invites_disabled FROM users WHERE did = $1",
did
)
.fetch_optional(&state.db)
.await
.ok()
.flatten()
.flatten()
.unwrap_or(false);
if user_invites_disabled {
return (
StatusCode::FORBIDDEN,
Json(json!({"error": "InvitesDisabled", "message": "Invites are disabled for this account"})),
)
.into_response();
}
let code = Uuid::new_v4().to_string();
let result = sqlx::query!(
"INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
code,
input.use_count,
creator_user_id
)
.execute(&state.db)
.await;
match result {
Ok(_) => (StatusCode::OK, Json(CreateInviteCodeOutput { code })).into_response(),
Err(e) => {
error!("DB error creating invite code: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response()
}
}
}
/// Request body for `createInviteCodes`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateInviteCodesInput {
    /// Codes to mint per target account (`codeCount`). The handler defaults it
    /// to 1 and clamps values below 1 up to 1.
    pub code_count: Option<i32>,
    /// Number of times each code may be redeemed (`useCount`); the handler
    /// rejects values below 1.
    pub use_count: i32,
    /// DIDs (`forAccounts`) to mint codes for. When absent or empty, codes are
    /// minted for the caller and reported under the account name `"admin"`.
    pub for_accounts: Option<Vec<String>>,
}
/// Response body for `createInviteCodes`: one entry per target account.
#[derive(Debug, Serialize)]
pub struct CreateInviteCodesOutput {
    pub codes: Vec<AccountCodes>,
}
/// The codes minted for a single account.
#[derive(Debug, Serialize)]
pub struct AccountCodes {
    /// The target DID, or the literal `"admin"` when no accounts were requested.
    pub account: String,
    /// The generated invite codes.
    pub codes: Vec<String>,
}
/// Handler for `createInviteCodes`: mints `codeCount` codes for each requested
/// account.
///
/// Authentication works the same way as in `create_invite_code`: a session row
/// matching the bearer token, plus a valid JWT signature.
///
/// When `forAccounts` is absent or empty, codes are owned by the caller and
/// reported under the account name `"admin"`. Requested DIDs with no matching
/// user are skipped and get no entry in the response. All accounts are
/// resolved before any code is inserted, so a lookup failure inserts nothing.
///
/// NOTE(review): the inserts are not wrapped in a transaction. A failure
/// part-way through leaves earlier codes in place even though the request
/// returns 500. `codeCount` also has no upper bound.
///
/// Responses:
/// - `200` with the minted codes.
/// - `400 InvalidRequest` when `useCount` is below 1.
/// - `401` when credentials are missing or invalid.
/// - `500 InternalError` on any database error.
pub async fn create_invite_codes(
    State(state): State<AppState>,
    headers: axum::http::HeaderMap,
    Json(input): Json<CreateInviteCodesInput>,
) -> Response {
    let Some(auth_header) = headers.get("Authorization") else {
        return (
            StatusCode::UNAUTHORIZED,
            Json(json!({"error": "AuthenticationRequired"})),
        )
            .into_response();
    };
    if input.use_count < 1 {
        return (
            StatusCode::BAD_REQUEST,
            Json(json!({"error": "InvalidRequest", "message": "useCount must be at least 1"})),
        )
            .into_response();
    }
    // Strip only a leading "Bearer " scheme; `str::replace` would also rewrite
    // any later occurrence inside the credential.
    let raw_header = auth_header.to_str().unwrap_or("");
    let token = raw_header
        .strip_prefix("Bearer ")
        .unwrap_or(raw_header)
        .to_string();
    let session = sqlx::query!(
        r#"
SELECT s.did, k.key_bytes, u.id as user_id
FROM sessions s
JOIN users u ON s.did = u.did
JOIN user_keys k ON u.id = k.user_id
WHERE s.access_jwt = $1
"#,
        token
    )
    .fetch_optional(&state.db)
    .await;
    let (_did, key_bytes, user_id) = match session {
        Ok(Some(row)) => (row.did, row.key_bytes, row.user_id),
        Ok(None) => {
            return (
                StatusCode::UNAUTHORIZED,
                Json(json!({"error": "AuthenticationFailed"})),
            )
                .into_response();
        }
        Err(e) => {
            error!("DB error in create_invite_codes: {:?}", e);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({"error": "InternalError"})),
            )
                .into_response();
        }
    };
    if crate::auth::verify_token(&token, &key_bytes).is_err() {
        return (
            StatusCode::UNAUTHORIZED,
            Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
        )
            .into_response();
    }
    let code_count = input.code_count.unwrap_or(1).max(1);
    let for_accounts = input.for_accounts.unwrap_or_default();

    // Build the list of (reported account name, owning user id) pairs.
    let targets = if for_accounts.is_empty() {
        vec![("admin".to_string(), user_id)]
    } else {
        let mut resolved = Vec::with_capacity(for_accounts.len());
        for account_did in for_accounts {
            let target = sqlx::query!("SELECT id FROM users WHERE did = $1", account_did)
                .fetch_optional(&state.db)
                .await;
            match target {
                Ok(Some(row)) => resolved.push((account_did, row.id)),
                // Unknown accounts are skipped rather than failing the request.
                Ok(None) => {}
                Err(e) => {
                    error!("DB error looking up target account: {:?}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(json!({"error": "InternalError"})),
                    )
                        .into_response();
                }
            }
        }
        resolved
    };

    let mut result_codes = Vec::with_capacity(targets.len());
    for (account, target_user_id) in targets {
        // No `with_capacity(code_count)`: the count comes from the client and is unbounded.
        let mut codes = Vec::new();
        for _ in 0..code_count {
            let code = Uuid::new_v4().to_string();
            let insert = sqlx::query!(
                "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
                code,
                input.use_count,
                target_user_id
            )
            .execute(&state.db)
            .await;
            if let Err(e) = insert {
                error!("DB error creating invite code: {:?}", e);
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(json!({"error": "InternalError"})),
                )
                    .into_response();
            }
            codes.push(code);
        }
        result_codes.push(AccountCodes { account, codes });
    }
    (StatusCode::OK, Json(CreateInviteCodesOutput { codes: result_codes })).into_response()
}
/// Query parameters for `getAccountInviteCodes`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GetAccountInviteCodesParams {
    /// Whether to include codes with no remaining uses (`includeUsed`).
    /// The handler defaults this to `true`.
    pub include_used: Option<bool>,
    /// Accepted for API compatibility (`createAvailable`).
    /// NOTE(review): the `get_account_invite_codes` handler in this file never
    /// reads this field.
    pub create_available: Option<bool>,
}
/// A single invite code as returned by `getAccountInviteCodes`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct InviteCode {
    /// The invite code string.
    pub code: String,
    /// Remaining redemptions, taken from the `available_uses` column.
    pub available: i32,
    /// Whether the code is disabled; a NULL column is reported as `false`.
    pub disabled: bool,
    /// DID of the account the code belongs to (the caller, as populated by the handler).
    pub for_account: String,
    /// DID of the creator (also set to the caller by the handler).
    pub created_by: String,
    /// Creation timestamp in RFC 3339 format.
    pub created_at: String,
    /// Redemptions of this code, newest first.
    pub uses: Vec<InviteCodeUse>,
}
/// One redemption of an invite code.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct InviteCodeUse {
    /// DID of the account that redeemed the code.
    pub used_by: String,
    /// Redemption timestamp in RFC 3339 format.
    pub used_at: String,
}
/// Response body for `getAccountInviteCodes`.
#[derive(Debug, Serialize)]
pub struct GetAccountInviteCodesOutput {
    pub codes: Vec<InviteCode>,
}
pub async fn get_account_invite_codes(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
) -> Response {
let auth_header = headers.get("Authorization");
if auth_header.is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationRequired"})),
)
.into_response();
}
let token = auth_header
.unwrap()
.to_str()
.unwrap_or("")
.replace("Bearer ", "");
let session = sqlx::query!(
r#"
SELECT s.did, k.key_bytes, u.id as user_id
FROM sessions s
JOIN users u ON s.did = u.did
JOIN user_keys k ON u.id = k.user_id
WHERE s.access_jwt = $1
"#,
token
)
.fetch_optional(&state.db)
.await;
let (did, key_bytes, user_id) = match session {
Ok(Some(row)) => (row.did, row.key_bytes, row.user_id),
Ok(None) => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed"})),
)
.into_response();
}
Err(e) => {
error!("DB error in get_account_invite_codes: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
return (
StatusCode::UNAUTHORIZED,
Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
)
.into_response();
}
let include_used = params.include_used.unwrap_or(true);
let codes_result = sqlx::query!(
r#"
SELECT code, available_uses, created_at, disabled
FROM invite_codes
WHERE created_by_user = $1
ORDER BY created_at DESC
"#,
user_id
)
.fetch_all(&state.db)
.await;
let codes_rows = match codes_result {
Ok(rows) => {
if include_used {
rows
} else {
rows.into_iter().filter(|r| r.available_uses > 0).collect()
}
}
Err(e) => {
error!("DB error fetching invite codes: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": "InternalError"})),
)
.into_response();
}
};
let mut codes = Vec::new();
for row in codes_rows {
let uses_result = sqlx::query!(
r#"
SELECT u.did, icu.used_at
FROM invite_code_uses icu
JOIN users u ON icu.used_by_user = u.id
WHERE icu.code = $1
ORDER BY icu.used_at DESC
"#,
row.code
)
.fetch_all(&state.db)
.await;
let uses = match uses_result {
Ok(use_rows) => use_rows
.iter()
.map(|u| InviteCodeUse {
used_by: u.did.clone(),
used_at: u.used_at.to_rfc3339(),
})
.collect(),
Err(_) => Vec::new(),
};
codes.push(InviteCode {
code: row.code,
available: row.available_uses,
disabled: row.disabled.unwrap_or(false),
for_account: did.clone(),
created_by: did.clone(),
created_at: row.created_at.to_rfc3339(),
uses,
});
}
(StatusCode::OK, Json(GetAccountInviteCodesOutput { codes })).into_response()
}