mirror of
https://tangled.org/tranquil.farm/tranquil-pds
synced 2026-02-08 21:30:08 +00:00
extract proxying into a middleware instead of a fallback handler
This commit is contained in:
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -6315,11 +6315,13 @@ dependencies = [
|
||||
"dotenvy",
|
||||
"ed25519-dalek",
|
||||
"futures",
|
||||
"futures-util",
|
||||
"governor",
|
||||
"hex",
|
||||
"hickory-resolver",
|
||||
"hkdf",
|
||||
"hmac",
|
||||
"http 1.4.0",
|
||||
"image",
|
||||
"ipld-core",
|
||||
"iroh-car",
|
||||
@@ -6352,7 +6354,9 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"totp-rs",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tower-layer",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"urlencoding",
|
||||
|
||||
@@ -64,6 +64,10 @@ totp-rs = { version = "5", features = ["qr"] }
|
||||
webauthn-rs = { version = "0.5.4", features = ["danger-allow-state-serialisation", "danger-user-presence-only-security-keys"] }
|
||||
webauthn-rs-proto = "0.5.4"
|
||||
zip = { version = "7.0.0", default-features = false, features = ["deflate"] }
|
||||
tower = "0.5.2"
|
||||
tower-layer = "0.3.3"
|
||||
futures-util = "0.3.31"
|
||||
http = "1.4.0"
|
||||
[features]
|
||||
external-infra = []
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -2,22 +2,13 @@ use crate::auth::{extract_bearer_token_from_header, validate_bearer_token};
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
body::Bytes,
|
||||
extract::{Path, RawQuery, State},
|
||||
http::{HeaderMap, Method, StatusCode},
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
pub async fn get_state(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
RawQuery(query): RawQuery,
|
||||
) -> Response {
|
||||
if std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err() {
|
||||
return proxy_to_appview(state, headers, "app.bsky.ageassurance.getState", query).await;
|
||||
}
|
||||
|
||||
pub async fn get_state(State(state): State<AppState>, headers: HeaderMap) -> Response {
|
||||
let created_at = get_account_created_at(&state, &headers).await;
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
@@ -37,21 +28,7 @@ pub async fn get_state(
|
||||
.into_response()
|
||||
}
|
||||
|
||||
pub async fn get_age_assurance_state(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
RawQuery(query): RawQuery,
|
||||
) -> Response {
|
||||
if std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err() {
|
||||
return proxy_to_appview(
|
||||
state,
|
||||
headers,
|
||||
"app.bsky.unspecced.getAgeAssuranceState",
|
||||
query,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn get_age_assurance_state() -> Response {
|
||||
(StatusCode::OK, Json(json!({"status": "assured"}))).into_response()
|
||||
}
|
||||
|
||||
@@ -89,31 +66,3 @@ async fn get_account_created_at(state: &AppState, headers: &HeaderMap) -> Option
|
||||
|
||||
row.map(|r| r.created_at.to_rfc3339())
|
||||
}
|
||||
|
||||
async fn proxy_to_appview(
|
||||
state: AppState,
|
||||
headers: HeaderMap,
|
||||
method: &str,
|
||||
query: Option<String>,
|
||||
) -> Response {
|
||||
if headers.get("atproto-proxy").is_none() {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": "InvalidRequest",
|
||||
"message": "Missing required atproto-proxy header"
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
crate::api::proxy::proxy_handler(
|
||||
State(state),
|
||||
Path(method.to_string()),
|
||||
Method::GET,
|
||||
headers,
|
||||
RawQuery(query),
|
||||
Bytes::new(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
use std::convert::Infallible;
|
||||
|
||||
use crate::api::proxy_client::proxy_client;
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
body::Bytes,
|
||||
extract::{Path, RawQuery, State},
|
||||
extract::{RawQuery, Request, State},
|
||||
handler::Handler,
|
||||
http::{HeaderMap, Method, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use futures_util::future::Either;
|
||||
use serde_json::json;
|
||||
use tower::{Service, util::BoxCloneSyncService};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
const PROTECTED_METHODS: &[&str] = &[
|
||||
@@ -33,14 +38,86 @@ fn is_protected_method(method: &str) -> bool {
|
||||
PROTECTED_METHODS.contains(&method)
|
||||
}
|
||||
|
||||
pub async fn proxy_handler(
|
||||
pub struct XrpcProxyLayer {
|
||||
state: AppState,
|
||||
}
|
||||
|
||||
impl XrpcProxyLayer {
|
||||
pub fn new(state: AppState) -> Self {
|
||||
XrpcProxyLayer { state }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> tower_layer::Layer<S> for XrpcProxyLayer {
|
||||
type Service = XrpcProxyingService<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
XrpcProxyingService {
|
||||
inner,
|
||||
// TODO(nel): make our own service here instead of boxing a HandlerService
|
||||
handler: BoxCloneSyncService::new(proxy_handler.with_state(self.state.clone())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct XrpcProxyingService<S> {
|
||||
inner: S,
|
||||
handler: BoxCloneSyncService<Request, Response, Infallible>,
|
||||
}
|
||||
|
||||
impl<S: Service<Request, Response = Response, Error = Infallible>> Service<Request>
|
||||
for XrpcProxyingService<S>
|
||||
{
|
||||
type Response = Response;
|
||||
|
||||
type Error = Infallible;
|
||||
|
||||
type Future = Either<
|
||||
<BoxCloneSyncService<Request, Response, Infallible> as Service<Request>>::Future,
|
||||
S::Future,
|
||||
>;
|
||||
|
||||
fn poll_ready(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Request) -> Self::Future {
|
||||
if req
|
||||
.headers()
|
||||
.contains_key(http::HeaderName::from(jacquard::xrpc::Header::AtprotoProxy))
|
||||
{
|
||||
// If the age assurance override is set and this is an age assurance call then we dont want to proxy even if the client requests it.
|
||||
if !std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err()
|
||||
&& (req.uri().path().ends_with("app.bsky.ageassurance.getState")
|
||||
|| req
|
||||
.uri()
|
||||
.path()
|
||||
.ends_with("app.bsky.unspecced.getAgeAssuranceState"))
|
||||
{
|
||||
return Either::Right(self.inner.call(req));
|
||||
}
|
||||
|
||||
Either::Left(self.handler.call(req))
|
||||
} else {
|
||||
Either::Right(self.inner.call(req))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn proxy_handler(
|
||||
State(state): State<AppState>,
|
||||
Path(method): Path<String>,
|
||||
uri: http::Uri,
|
||||
method_verb: Method,
|
||||
headers: HeaderMap,
|
||||
RawQuery(query): RawQuery,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
// This layer is nested under /xrpc in an axum router so the extracted uri will look like /<method> and thus we can just strip the /
|
||||
let method = uri.path().trim_start_matches("/");
|
||||
if is_protected_method(&method) {
|
||||
warn!(method = %method, "Attempted to proxy protected method");
|
||||
return (
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use crate::api::proxy_client::proxy_client;
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
Json,
|
||||
@@ -14,7 +13,7 @@ use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Map, Value, json};
|
||||
use std::collections::HashMap;
|
||||
use std::str::FromStr;
|
||||
use tracing::{error, info};
|
||||
use tracing::error;
|
||||
|
||||
fn ipld_to_json(ipld: Ipld) -> Value {
|
||||
match ipld {
|
||||
@@ -78,61 +77,6 @@ pub async fn get_record(
|
||||
let user_id: uuid::Uuid = match user_id_opt {
|
||||
Ok(Some(id)) => id,
|
||||
Ok(None) => {
|
||||
if let Some(proxy_header) = headers.get("atproto-proxy").and_then(|h| h.to_str().ok()) {
|
||||
let did = proxy_header.split('#').next().unwrap_or(proxy_header);
|
||||
if let Some(resolved) = state.did_resolver.resolve_did(did).await {
|
||||
let mut url = format!(
|
||||
"{}/xrpc/com.atproto.repo.getRecord?repo={}&collection={}&rkey={}",
|
||||
resolved.url.trim_end_matches('/'),
|
||||
urlencoding::encode(&input.repo),
|
||||
urlencoding::encode(&input.collection),
|
||||
urlencoding::encode(&input.rkey)
|
||||
);
|
||||
if let Some(cid) = &input.cid {
|
||||
url.push_str(&format!("&cid={}", urlencoding::encode(cid)));
|
||||
}
|
||||
info!("Proxying getRecord to {}: {}", did, url);
|
||||
match proxy_client().get(&url).send().await {
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
let body = match resp.bytes().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
error!("Error reading proxy response: {:?}", e);
|
||||
return (
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(json!({"error": "UpstreamFailure", "message": "Error reading upstream response"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
return Response::builder()
|
||||
.status(status)
|
||||
.header("content-type", "application/json")
|
||||
.body(axum::body::Body::from(body))
|
||||
.unwrap_or_else(|_| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Internal error")
|
||||
.into_response()
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error proxying request: {:?}", e);
|
||||
return (
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(json!({"error": "UpstreamFailure", "message": "Failed to reach upstream service"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
error!("Could not resolve DID from atproto-proxy header: {}", did);
|
||||
return (
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(json!({"error": "UpstreamFailure", "message": "Could not resolve proxy DID"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
|
||||
|
||||
16
src/lib.rs
16
src/lib.rs
@@ -22,14 +22,18 @@ pub mod sync;
|
||||
pub mod util;
|
||||
pub mod validation;
|
||||
|
||||
use api::proxy::XrpcProxyLayer;
|
||||
use axum::{
|
||||
Router,
|
||||
Json, Router,
|
||||
extract::DefaultBodyLimit,
|
||||
http::Method,
|
||||
middleware,
|
||||
routing::{any, get, post},
|
||||
routing::{get, post},
|
||||
};
|
||||
use http::StatusCode;
|
||||
use serde_json::json;
|
||||
use state::AppState;
|
||||
use tower::{Layer, ServiceBuilder};
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tower_http::services::{ServeDir, ServeFile};
|
||||
|
||||
@@ -494,8 +498,10 @@ pub fn app(state: AppState) -> Router {
|
||||
.route(
|
||||
"/app.bsky.unspecced.getAgeAssuranceState",
|
||||
get(api::age_assurance::get_age_assurance_state),
|
||||
)
|
||||
.route("/{*method}", any(api::proxy::proxy_handler));
|
||||
);
|
||||
let xrpc_service = ServiceBuilder::new()
|
||||
.layer(XrpcProxyLayer::new(state.clone()))
|
||||
.service(xrpc_router.with_state(state.clone()));
|
||||
|
||||
let oauth_router = Router::new()
|
||||
.route("/jwks", get(oauth::endpoints::oauth_jwks))
|
||||
@@ -559,7 +565,7 @@ pub fn app(state: AppState) -> Router {
|
||||
);
|
||||
|
||||
let router = Router::new()
|
||||
.nest("/xrpc", xrpc_router)
|
||||
.nest_service("/xrpc", xrpc_service)
|
||||
.nest("/oauth", oauth_router)
|
||||
.route("/metrics", get(metrics::metrics_handler))
|
||||
.route("/health", get(api::server::health))
|
||||
|
||||
Reference in New Issue
Block a user