extract proxying into a middleware instead of a fallback handler

This commit is contained in:
nelind
2026-01-04 00:46:34 +01:00
committed by Tangled
parent b4e2b5d300
commit e61021feee
6 changed files with 104 additions and 120 deletions

4
Cargo.lock generated
View File

@@ -6315,11 +6315,13 @@ dependencies = [
"dotenvy",
"ed25519-dalek",
"futures",
"futures-util",
"governor",
"hex",
"hickory-resolver",
"hkdf",
"hmac",
"http 1.4.0",
"image",
"ipld-core",
"iroh-car",
@@ -6352,7 +6354,9 @@ dependencies = [
"tokio",
"tokio-tungstenite",
"totp-rs",
"tower",
"tower-http",
"tower-layer",
"tracing",
"tracing-subscriber",
"urlencoding",

View File

@@ -64,6 +64,10 @@ totp-rs = { version = "5", features = ["qr"] }
webauthn-rs = { version = "0.5.4", features = ["danger-allow-state-serialisation", "danger-user-presence-only-security-keys"] }
webauthn-rs-proto = "0.5.4"
zip = { version = "7.0.0", default-features = false, features = ["deflate"] }
tower = "0.5.2"
tower-layer = "0.3.3"
futures-util = "0.3.31"
http = "1.4.0"
[features]
external-infra = []
[dev-dependencies]

View File

@@ -2,22 +2,13 @@ use crate::auth::{extract_bearer_token_from_header, validate_bearer_token};
use crate::state::AppState;
use axum::{
Json,
body::Bytes,
extract::{Path, RawQuery, State},
http::{HeaderMap, Method, StatusCode},
extract::State,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use serde_json::json;
pub async fn get_state(
State(state): State<AppState>,
headers: HeaderMap,
RawQuery(query): RawQuery,
) -> Response {
if std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err() {
return proxy_to_appview(state, headers, "app.bsky.ageassurance.getState", query).await;
}
pub async fn get_state(State(state): State<AppState>, headers: HeaderMap) -> Response {
let created_at = get_account_created_at(&state, &headers).await;
let now = chrono::Utc::now().to_rfc3339();
@@ -37,21 +28,7 @@ pub async fn get_state(
.into_response()
}
pub async fn get_age_assurance_state(
State(state): State<AppState>,
headers: HeaderMap,
RawQuery(query): RawQuery,
) -> Response {
if std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err() {
return proxy_to_appview(
state,
headers,
"app.bsky.unspecced.getAgeAssuranceState",
query,
)
.await;
}
pub async fn get_age_assurance_state() -> Response {
(StatusCode::OK, Json(json!({"status": "assured"}))).into_response()
}
@@ -89,31 +66,3 @@ async fn get_account_created_at(state: &AppState, headers: &HeaderMap) -> Option
row.map(|r| r.created_at.to_rfc3339())
}
async fn proxy_to_appview(
state: AppState,
headers: HeaderMap,
method: &str,
query: Option<String>,
) -> Response {
if headers.get("atproto-proxy").is_none() {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "InvalidRequest",
"message": "Missing required atproto-proxy header"
})),
)
.into_response();
}
crate::api::proxy::proxy_handler(
State(state),
Path(method.to_string()),
Method::GET,
headers,
RawQuery(query),
Bytes::new(),
)
.await
}

View File

@@ -1,13 +1,18 @@
use std::convert::Infallible;
use crate::api::proxy_client::proxy_client;
use crate::state::AppState;
use axum::{
Json,
body::Bytes,
extract::{Path, RawQuery, State},
extract::{RawQuery, Request, State},
handler::Handler,
http::{HeaderMap, Method, StatusCode},
response::{IntoResponse, Response},
};
use futures_util::future::Either;
use serde_json::json;
use tower::{Service, util::BoxCloneSyncService};
use tracing::{error, info, warn};
const PROTECTED_METHODS: &[&str] = &[
@@ -33,14 +38,86 @@ fn is_protected_method(method: &str) -> bool {
PROTECTED_METHODS.contains(&method)
}
pub async fn proxy_handler(
pub struct XrpcProxyLayer {
state: AppState,
}
impl XrpcProxyLayer {
pub fn new(state: AppState) -> Self {
XrpcProxyLayer { state }
}
}
impl<S> tower_layer::Layer<S> for XrpcProxyLayer {
type Service = XrpcProxyingService<S>;
fn layer(&self, inner: S) -> Self::Service {
XrpcProxyingService {
inner,
// TODO(nel): make our own service here instead of boxing a HandlerService
handler: BoxCloneSyncService::new(proxy_handler.with_state(self.state.clone())),
}
}
}
#[derive(Clone)]
pub struct XrpcProxyingService<S> {
inner: S,
handler: BoxCloneSyncService<Request, Response, Infallible>,
}
impl<S: Service<Request, Response = Response, Error = Infallible>> Service<Request>
for XrpcProxyingService<S>
{
type Response = Response;
type Error = Infallible;
type Future = Either<
<BoxCloneSyncService<Request, Response, Infallible> as Service<Request>>::Future,
S::Future,
>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request) -> Self::Future {
if req
.headers()
.contains_key(http::HeaderName::from(jacquard::xrpc::Header::AtprotoProxy))
{
// If the age assurance override is set and this is an age assurance call then we dont want to proxy even if the client requests it.
if !std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err()
&& (req.uri().path().ends_with("app.bsky.ageassurance.getState")
|| req
.uri()
.path()
.ends_with("app.bsky.unspecced.getAgeAssuranceState"))
{
return Either::Right(self.inner.call(req));
}
Either::Left(self.handler.call(req))
} else {
Either::Right(self.inner.call(req))
}
}
}
async fn proxy_handler(
State(state): State<AppState>,
Path(method): Path<String>,
uri: http::Uri,
method_verb: Method,
headers: HeaderMap,
RawQuery(query): RawQuery,
body: Bytes,
) -> Response {
// This layer is nested under /xrpc in an axum router so the extracted uri will look like /<method> and thus we can just strip the /
let method = uri.path().trim_start_matches("/");
if is_protected_method(&method) {
warn!(method = %method, "Attempted to proxy protected method");
return (

View File

@@ -1,4 +1,3 @@
use crate::api::proxy_client::proxy_client;
use crate::state::AppState;
use axum::{
Json,
@@ -14,7 +13,7 @@ use serde::{Deserialize, Serialize};
use serde_json::{Map, Value, json};
use std::collections::HashMap;
use std::str::FromStr;
use tracing::{error, info};
use tracing::error;
fn ipld_to_json(ipld: Ipld) -> Value {
match ipld {
@@ -78,61 +77,6 @@ pub async fn get_record(
let user_id: uuid::Uuid = match user_id_opt {
Ok(Some(id)) => id,
Ok(None) => {
if let Some(proxy_header) = headers.get("atproto-proxy").and_then(|h| h.to_str().ok()) {
let did = proxy_header.split('#').next().unwrap_or(proxy_header);
if let Some(resolved) = state.did_resolver.resolve_did(did).await {
let mut url = format!(
"{}/xrpc/com.atproto.repo.getRecord?repo={}&collection={}&rkey={}",
resolved.url.trim_end_matches('/'),
urlencoding::encode(&input.repo),
urlencoding::encode(&input.collection),
urlencoding::encode(&input.rkey)
);
if let Some(cid) = &input.cid {
url.push_str(&format!("&cid={}", urlencoding::encode(cid)));
}
info!("Proxying getRecord to {}: {}", did, url);
match proxy_client().get(&url).send().await {
Ok(resp) => {
let status = resp.status();
let body = match resp.bytes().await {
Ok(b) => b,
Err(e) => {
error!("Error reading proxy response: {:?}", e);
return (
StatusCode::BAD_GATEWAY,
Json(json!({"error": "UpstreamFailure", "message": "Error reading upstream response"})),
)
.into_response();
}
};
return Response::builder()
.status(status)
.header("content-type", "application/json")
.body(axum::body::Body::from(body))
.unwrap_or_else(|_| {
(StatusCode::INTERNAL_SERVER_ERROR, "Internal error")
.into_response()
});
}
Err(e) => {
error!("Error proxying request: {:?}", e);
return (
StatusCode::BAD_GATEWAY,
Json(json!({"error": "UpstreamFailure", "message": "Failed to reach upstream service"})),
)
.into_response();
}
}
} else {
error!("Could not resolve DID from atproto-proxy header: {}", did);
return (
StatusCode::BAD_GATEWAY,
Json(json!({"error": "UpstreamFailure", "message": "Could not resolve proxy DID"})),
)
.into_response();
}
}
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),

View File

@@ -22,14 +22,18 @@ pub mod sync;
pub mod util;
pub mod validation;
use api::proxy::XrpcProxyLayer;
use axum::{
Router,
Json, Router,
extract::DefaultBodyLimit,
http::Method,
middleware,
routing::{any, get, post},
routing::{get, post},
};
use http::StatusCode;
use serde_json::json;
use state::AppState;
use tower::{Layer, ServiceBuilder};
use tower_http::cors::{Any, CorsLayer};
use tower_http::services::{ServeDir, ServeFile};
@@ -494,8 +498,10 @@ pub fn app(state: AppState) -> Router {
.route(
"/app.bsky.unspecced.getAgeAssuranceState",
get(api::age_assurance::get_age_assurance_state),
)
.route("/{*method}", any(api::proxy::proxy_handler));
);
let xrpc_service = ServiceBuilder::new()
.layer(XrpcProxyLayer::new(state.clone()))
.service(xrpc_router.with_state(state.clone()));
let oauth_router = Router::new()
.route("/jwks", get(oauth::endpoints::oauth_jwks))
@@ -559,7 +565,7 @@ pub fn app(state: AppState) -> Router {
);
let router = Router::new()
.nest("/xrpc", xrpc_router)
.nest_service("/xrpc", xrpc_service)
.nest("/oauth", oauth_router)
.route("/metrics", get(metrics::metrics_handler))
.route("/health", get(api::server::health))