From e61021feee2bf07c51cae4e28e8832cfe12938d9 Mon Sep 17 00:00:00 2001 From: nelind Date: Sun, 4 Jan 2026 00:46:34 +0100 Subject: [PATCH] extract proxying into a middleware instead of a fallback handler --- Cargo.lock | 4 ++ Cargo.toml | 4 ++ src/api/age_assurance.rs | 59 ++------------------------ src/api/proxy.rs | 83 +++++++++++++++++++++++++++++++++++-- src/api/repo/record/read.rs | 58 +------------------------- src/lib.rs | 16 ++++--- 6 files changed, 104 insertions(+), 120 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a4fb7fe..b506abe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6315,11 +6315,13 @@ dependencies = [ "dotenvy", "ed25519-dalek", "futures", + "futures-util", "governor", "hex", "hickory-resolver", "hkdf", "hmac", + "http 1.4.0", "image", "ipld-core", "iroh-car", @@ -6352,7 +6354,9 @@ dependencies = [ "tokio", "tokio-tungstenite", "totp-rs", + "tower", "tower-http", + "tower-layer", "tracing", "tracing-subscriber", "urlencoding", diff --git a/Cargo.toml b/Cargo.toml index c401fbb..69bf2fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,6 +64,10 @@ totp-rs = { version = "5", features = ["qr"] } webauthn-rs = { version = "0.5.4", features = ["danger-allow-state-serialisation", "danger-user-presence-only-security-keys"] } webauthn-rs-proto = "0.5.4" zip = { version = "7.0.0", default-features = false, features = ["deflate"] } +tower = "0.5.2" +tower-layer = "0.3.3" +futures-util = "0.3.31" +http = "1.4.0" [features] external-infra = [] [dev-dependencies] diff --git a/src/api/age_assurance.rs b/src/api/age_assurance.rs index d09f26d..ffa3c1f 100644 --- a/src/api/age_assurance.rs +++ b/src/api/age_assurance.rs @@ -2,22 +2,13 @@ use crate::auth::{extract_bearer_token_from_header, validate_bearer_token}; use crate::state::AppState; use axum::{ Json, - body::Bytes, - extract::{Path, RawQuery, State}, - http::{HeaderMap, Method, StatusCode}, + extract::State, + http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, }; use serde_json::json; -pub async fn get_state( - State(state): State, - headers: HeaderMap, - RawQuery(query): RawQuery, -) -> Response { - if std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err() { - return proxy_to_appview(state, headers, "app.bsky.ageassurance.getState", query).await; - } - +pub async fn get_state(State(state): State, headers: HeaderMap) -> Response { let created_at = get_account_created_at(&state, &headers).await; let now = chrono::Utc::now().to_rfc3339(); @@ -37,21 +28,7 @@ pub async fn get_state( .into_response() } -pub async fn get_age_assurance_state( - State(state): State, - headers: HeaderMap, - RawQuery(query): RawQuery, -) -> Response { - if std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err() { - return proxy_to_appview( - state, - headers, - "app.bsky.unspecced.getAgeAssuranceState", - query, - ) - .await; - } - +pub async fn get_age_assurance_state() -> Response { (StatusCode::OK, Json(json!({"status": "assured"}))).into_response() } @@ -89,31 +66,3 @@ async fn get_account_created_at(state: &AppState, headers: &HeaderMap) -> Option row.map(|r| r.created_at.to_rfc3339()) } - -async fn proxy_to_appview( - state: AppState, - headers: HeaderMap, - method: &str, - query: Option, -) -> Response { - if headers.get("atproto-proxy").is_none() { - return ( - StatusCode::BAD_REQUEST, - Json(json!({ - "error": "InvalidRequest", - "message": "Missing required atproto-proxy header" - })), - ) - .into_response(); - } - - crate::api::proxy::proxy_handler( - State(state), - Path(method.to_string()), - Method::GET, - headers, - RawQuery(query), - Bytes::new(), - ) - .await -} diff --git a/src/api/proxy.rs b/src/api/proxy.rs index ce8cd09..370e623 100644 --- a/src/api/proxy.rs +++ b/src/api/proxy.rs @@ -1,13 +1,18 @@ +use std::convert::Infallible; + use crate::api::proxy_client::proxy_client; use crate::state::AppState; use axum::{ Json, body::Bytes, - extract::{Path, RawQuery, State}, + extract::{RawQuery, Request, State}, + handler::Handler, http::{HeaderMap, Method, StatusCode}, response::{IntoResponse, Response}, }; +use futures_util::future::Either; use serde_json::json; +use tower::{Service, util::BoxCloneSyncService}; use tracing::{error, info, warn}; const PROTECTED_METHODS: &[&str] = &[ @@ -33,14 +38,86 @@ fn is_protected_method(method: &str) -> bool { PROTECTED_METHODS.contains(&method) } -pub async fn proxy_handler( +pub struct XrpcProxyLayer { + state: AppState, +} + +impl XrpcProxyLayer { + pub fn new(state: AppState) -> Self { + XrpcProxyLayer { state } + } +} + +impl tower_layer::Layer for XrpcProxyLayer { + type Service = XrpcProxyingService; + + fn layer(&self, inner: S) -> Self::Service { + XrpcProxyingService { + inner, + // TODO(nel): make our own service here instead of boxing a HandlerService + handler: BoxCloneSyncService::new(proxy_handler.with_state(self.state.clone())), + } + } +} + +#[derive(Clone)] +pub struct XrpcProxyingService { + inner: S, + handler: BoxCloneSyncService, +} + +impl> Service + for XrpcProxyingService +{ + type Response = Response; + + type Error = Infallible; + + type Future = Either< + as Service>::Future, + S::Future, + >; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + if req + .headers() + .contains_key(http::HeaderName::from(jacquard::xrpc::Header::AtprotoProxy)) + { + // If the age assurance override is set and this is an age assurance call then we dont want to proxy even if the client requests it. + if !std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err() + && (req.uri().path().ends_with("app.bsky.ageassurance.getState") + || req + .uri() + .path() + .ends_with("app.bsky.unspecced.getAgeAssuranceState")) + { + return Either::Right(self.inner.call(req)); + } + + Either::Left(self.handler.call(req)) + } else { + Either::Right(self.inner.call(req)) + } + } +} + +async fn proxy_handler( State(state): State, - Path(method): Path, + uri: http::Uri, method_verb: Method, headers: HeaderMap, RawQuery(query): RawQuery, body: Bytes, ) -> Response { + // This layer is nested under /xrpc in an axum router so the extracted uri will look like / and thus we can just strip the / + let method = uri.path().trim_start_matches("/"); if is_protected_method(&method) { warn!(method = %method, "Attempted to proxy protected method"); return ( diff --git a/src/api/repo/record/read.rs b/src/api/repo/record/read.rs index 8e74520..d252a28 100644 --- a/src/api/repo/record/read.rs +++ b/src/api/repo/record/read.rs @@ -1,4 +1,3 @@ -use crate::api::proxy_client::proxy_client; use crate::state::AppState; use axum::{ Json, @@ -14,7 +13,7 @@ use serde::{Deserialize, Serialize}; use serde_json::{Map, Value, json}; use std::collections::HashMap; use std::str::FromStr; -use tracing::{error, info}; +use tracing::error; fn ipld_to_json(ipld: Ipld) -> Value { match ipld { @@ -78,61 +77,6 @@ pub async fn get_record( let user_id: uuid::Uuid = match user_id_opt { Ok(Some(id)) => id, Ok(None) => { - if let Some(proxy_header) = headers.get("atproto-proxy").and_then(|h| h.to_str().ok()) { - let did = proxy_header.split('#').next().unwrap_or(proxy_header); - if let Some(resolved) = state.did_resolver.resolve_did(did).await { - let mut url = format!( - "{}/xrpc/com.atproto.repo.getRecord?repo={}&collection={}&rkey={}", - resolved.url.trim_end_matches('/'), - urlencoding::encode(&input.repo), - urlencoding::encode(&input.collection), - urlencoding::encode(&input.rkey) - ); - if let Some(cid) = &input.cid { - url.push_str(&format!("&cid={}", urlencoding::encode(cid))); - } - info!("Proxying getRecord to {}: {}", did, url); - match proxy_client().get(&url).send().await { - Ok(resp) => { - let status = resp.status(); - let body = match resp.bytes().await { - Ok(b) => b, - Err(e) => { - error!("Error reading proxy response: {:?}", e); - return ( - StatusCode::BAD_GATEWAY, - Json(json!({"error": "UpstreamFailure", "message": "Error reading upstream response"})), - ) - .into_response(); - } - }; - return Response::builder() - .status(status) - .header("content-type", "application/json") - .body(axum::body::Body::from(body)) - .unwrap_or_else(|_| { - (StatusCode::INTERNAL_SERVER_ERROR, "Internal error") - .into_response() - }); - } - Err(e) => { - error!("Error proxying request: {:?}", e); - return ( - StatusCode::BAD_GATEWAY, - Json(json!({"error": "UpstreamFailure", "message": "Failed to reach upstream service"})), - ) - .into_response(); - } - } - } else { - error!("Could not resolve DID from atproto-proxy header: {}", did); - return ( - StatusCode::BAD_GATEWAY, - Json(json!({"error": "UpstreamFailure", "message": "Could not resolve proxy DID"})), - ) - .into_response(); - } - } return ( StatusCode::NOT_FOUND, Json(json!({"error": "RepoNotFound", "message": "Repo not found"})), diff --git a/src/lib.rs b/src/lib.rs index 22415c9..b4f8b6c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,14 +22,18 @@ pub mod sync; pub mod util; pub mod validation; +use api::proxy::XrpcProxyLayer; use axum::{ - Router, + Json, Router, extract::DefaultBodyLimit, http::Method, middleware, - routing::{any, get, post}, + routing::{get, post}, }; +use http::StatusCode; +use serde_json::json; use state::AppState; +use tower::{Layer, ServiceBuilder}; use tower_http::cors::{Any, CorsLayer}; use tower_http::services::{ServeDir, ServeFile}; @@ -494,8 +498,10 @@ pub fn app(state: AppState) -> Router { .route( "/app.bsky.unspecced.getAgeAssuranceState", get(api::age_assurance::get_age_assurance_state), - ) - .route("/{*method}", any(api::proxy::proxy_handler)); + ); + let xrpc_service = ServiceBuilder::new() + .layer(XrpcProxyLayer::new(state.clone())) + .service(xrpc_router.with_state(state.clone())); let oauth_router = Router::new() .route("/jwks", get(oauth::endpoints::oauth_jwks)) @@ -559,7 +565,7 @@ pub fn app(state: AppState) -> Router { ); let router = Router::new() - .nest("/xrpc", xrpc_router) + .nest_service("/xrpc", xrpc_service) .nest("/oauth", oauth_router) .route("/metrics", get(metrics::metrics_handler)) .route("/health", get(api::server::health))