Check did:web on account create

This commit is contained in:
Lewis
2025-12-06 15:18:57 +02:00
parent 9da8823d18
commit 7b90694066
3 changed files with 105 additions and 7 deletions

View File

@@ -20,7 +20,6 @@ Lewis' corrected big boy todofile
- [x] Initialize user repository (Root commit).
- [x] Return access JWT and DID.
- [x] Create DID for new user (did:web).
- [ ] Implement all TODOs regarding did:webs.
- [x] Session Management
- [x] Implement `com.atproto.server.createSession` (Login).
- [x] Implement `com.atproto.server.getSession`.
@@ -138,4 +137,5 @@ Lewis' corrected big boy todofile
- [ ] Implement CAR (Content Addressable Archive) encoding/decoding.
- [ ] Validation
- [ ] DID PLC Operations (Sign rotation keys).
- [ ] Fix any remaining TODOs in the code, everywhere, full stop.

View File

@@ -16,6 +16,7 @@ use std::sync::Arc;
use k256::SecretKey;
use rand::rngs::OsRng;
use base64::Engine;
use reqwest;
#[derive(Deserialize)]
pub struct CreateAccountInput {
@@ -50,10 +51,9 @@ pub async fn create_account(
format!("did:plc:{}", uuid::Uuid::new_v4())
} else {
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
let _expected_prefix = format!("did:web:{}", hostname);
// TODO: should verify we are the authority for it if it matches our hostname.
// TODO: if it's an external did:web, we should technically verify ownership via ServiceAuth, but skipping for now.
if let Err(e) = verify_did_web(d, &hostname, &input.handle).await {
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidDid", "message": e}))).into_response();
}
d.clone()
}
} else {
@@ -352,3 +352,73 @@ pub async fn user_did_doc(
}]
})).into_response()
}
async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
let expected_prefix = if hostname.contains(':') {
format!("did:web:{}", hostname.replace(':', "%3A"))
} else {
format!("did:web:{}", hostname)
};
if did.starts_with(&expected_prefix) {
let suffix = &did[expected_prefix.len()..];
let expected_suffix = format!(":u:{}", handle);
if suffix == expected_suffix {
Ok(())
} else {
Err(format!("Invalid DID path for this PDS. Expected {}", expected_suffix))
}
} else {
let parts: Vec<&str> = did.split(':').collect();
if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
return Err("Invalid did:web format".into());
}
let domain_segment = parts[2];
let domain = domain_segment.replace("%3A", ":");
let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
"http"
} else {
"https"
};
let url = if parts.len() == 3 {
format!("{}://{}/.well-known/did.json", scheme, domain)
} else {
let path = parts[3..].join("/");
format!("{}://{}/{}/did.json", scheme, domain, path)
};
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(5))
.build()
.map_err(|e| format!("Failed to create client: {}", e))?;
let resp = client.get(&url).send().await
.map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
if !resp.status().is_success() {
return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
}
let doc: serde_json::Value = resp.json().await
.map_err(|e| format!("Failed to parse DID doc: {}", e))?;
let services = doc["service"].as_array()
.ok_or("No services found in DID doc")?;
let pds_endpoint = format!("https://{}", hostname);
let has_valid_service = services.iter().any(|s| {
s["type"] == "AtprotoPersonalDataServer" &&
s["serviceEndpoint"] == pds_endpoint
});
if has_valid_service {
Ok(())
} else {
Err(format!("DID document does not list this PDS ({}) as AtprotoPersonalDataServer", pds_endpoint))
}
}
}

View File

@@ -2,6 +2,8 @@ mod common;
use common::*;
use reqwest::StatusCode;
use serde_json::{json, Value};
use wiremock::{MockServer, Mock, ResponseTemplate};
use wiremock::matchers::{method, path};
// #[tokio::test]
// async fn test_resolve_handle() {
@@ -36,9 +38,31 @@ async fn test_well_known_did() {
async fn test_create_did_web_account_and_resolve() {
let client = client();
let mock_server = MockServer::start().await;
let mock_uri = mock_server.uri();
let mock_addr = mock_uri.trim_start_matches("http://");
let did = format!("did:web:{}", mock_addr.replace(":", "%3A"));
let handle = format!("webuser_{}", uuid::Uuid::new_v4());
let did = format!("did:web:example.com:u:{}", handle);
let pds_endpoint = "https://localhost";
let did_doc = json!({
"@context": ["https://www.w3.org/ns/did/v1"],
"id": did,
"service": [{
"id": "#atproto_pds",
"type": "AtprotoPersonalDataServer",
"serviceEndpoint": pds_endpoint
}]
});
Mock::given(method("GET"))
.and(path("/.well-known/did.json"))
.respond_with(ResponseTemplate::new(200).set_body_json(did_doc))
.mount(&mock_server)
.await;
let payload = json!({
"handle": handle,
@@ -53,7 +77,11 @@ async fn test_create_did_web_account_and_resolve() {
.await
.expect("Failed to send request");
assert_eq!(res.status(), StatusCode::OK);
if res.status() != StatusCode::OK {
let status = res.status();
let body: Value = res.json().await.unwrap_or(json!({"error": "could not parse body"}));
panic!("createAccount failed with status {}: {:?}", status, body);
}
let body: Value = res.json().await.expect("createAccount response was not JSON");
assert_eq!(body["did"], did);