diff --git a/TODO.md b/TODO.md index 708e330..3bbfa3e 100644 --- a/TODO.md +++ b/TODO.md @@ -20,7 +20,6 @@ Lewis' corrected big boy todofile - [x] Initialize user repository (Root commit). - [x] Return access JWT and DID. - [x] Create DID for new user (did:web). - - [ ] Implement all TODOs regarding did:webs. - [x] Session Management - [x] Implement `com.atproto.server.createSession` (Login). - [x] Implement `com.atproto.server.getSession`. @@ -138,4 +137,5 @@ Lewis' corrected big boy todofile - [ ] Implement CAR (Content Addressable Archive) encoding/decoding. - [ ] Validation - [ ] DID PLC Operations (Sign rotation keys). +- [ ] Fix any remaining TODOs in the code, everywhere, full stop. diff --git a/src/api/identity.rs b/src/api/identity.rs index 8931c97..1c990e7 100644 --- a/src/api/identity.rs +++ b/src/api/identity.rs @@ -16,6 +16,7 @@ use std::sync::Arc; use k256::SecretKey; use rand::rngs::OsRng; use base64::Engine; +use reqwest; #[derive(Deserialize)] pub struct CreateAccountInput { @@ -50,10 +51,9 @@ pub async fn create_account( format!("did:plc:{}", uuid::Uuid::new_v4()) } else { let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); - let _expected_prefix = format!("did:web:{}", hostname); - - // TODO: should verify we are the authority for it if it matches our hostname. - // TODO: if it's an external did:web, we should technically verify ownership via ServiceAuth, but skipping for now. + if let Err(e) = verify_did_web(d, &hostname, &input.handle).await { + return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidDid", "message": e}))).into_response(); + } d.clone() } } else { @@ -352,3 +352,73 @@ pub async fn user_did_doc( }] })).into_response() } + +async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { + let expected_prefix = if hostname.contains(':') { + format!("did:web:{}", hostname.replace(':', "%3A")) + } else { + format!("did:web:{}", hostname) + }; + + if did.starts_with(&expected_prefix) { + let suffix = &did[expected_prefix.len()..]; + let expected_suffix = format!(":u:{}", handle); + if suffix == expected_suffix { + Ok(()) + } else { + Err(format!("Invalid DID path for this PDS. Expected {}", expected_suffix)) + } + } else { + let parts: Vec<&str> = did.split(':').collect(); + if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { + return Err("Invalid did:web format".into()); + } + + let domain_segment = parts[2]; + let domain = domain_segment.replace("%3A", ":"); + + let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { + "http" + } else { + "https" + }; + + let url = if parts.len() == 3 { + format!("{}://{}/.well-known/did.json", scheme, domain) + } else { + let path = parts[3..].join("/"); + format!("{}://{}/{}/did.json", scheme, domain, path) + }; + + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(5)) + .build() + .map_err(|e| format!("Failed to create client: {}", e))?; + + let resp = client.get(&url).send().await + .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; + + if !resp.status().is_success() { + return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); + } + + let doc: serde_json::Value = resp.json().await + .map_err(|e| format!("Failed to parse DID doc: {}", e))?; + + let services = doc["service"].as_array() + .ok_or("No services found in DID doc")?; + + let pds_endpoint = format!("https://{}", hostname); + + let has_valid_service = services.iter().any(|s| { + s["type"] == "AtprotoPersonalDataServer" && + s["serviceEndpoint"] == pds_endpoint + }); + + if has_valid_service { + Ok(()) + } else { + Err(format!("DID document does not list this PDS ({}) as AtprotoPersonalDataServer", pds_endpoint)) + } + } +} diff --git a/tests/identity.rs b/tests/identity.rs index 52eb47f..6b45c35 100644 --- a/tests/identity.rs +++ b/tests/identity.rs @@ -2,6 +2,8 @@ mod common; use common::*; use reqwest::StatusCode; use serde_json::{json, Value}; +use wiremock::{MockServer, Mock, ResponseTemplate}; +use wiremock::matchers::{method, path}; // #[tokio::test] // async fn test_resolve_handle() { @@ -36,9 +38,31 @@ async fn test_well_known_did() { async fn test_create_did_web_account_and_resolve() { let client = client(); + let mock_server = MockServer::start().await; + let mock_uri = mock_server.uri(); + let mock_addr = mock_uri.trim_start_matches("http://"); + + let did = format!("did:web:{}", mock_addr.replace(":", "%3A")); + let handle = format!("webuser_{}", uuid::Uuid::new_v4()); - let did = format!("did:web:example.com:u:{}", handle); + let pds_endpoint = "https://localhost"; + + let did_doc = json!({ + "@context": ["https://www.w3.org/ns/did/v1"], + "id": did, + "service": [{ + "id": "#atproto_pds", + "type": "AtprotoPersonalDataServer", + "serviceEndpoint": pds_endpoint + }] + }); + + Mock::given(method("GET")) + .and(path("/.well-known/did.json")) + .respond_with(ResponseTemplate::new(200).set_body_json(did_doc)) + .mount(&mock_server) + .await; let payload = json!({ "handle": handle, @@ -53,7 +77,11 @@ async fn test_create_did_web_account_and_resolve() { .await .expect("Failed to send request"); - assert_eq!(res.status(), StatusCode::OK); + if res.status() != StatusCode::OK { + let status = res.status(); + let body: Value = res.json().await.unwrap_or(json!({"error": "could not parse body"})); + panic!("createAccount failed with status {}: {:?}", status, body); + } let body: Value = res.json().await.expect("createAccount response was not JSON"); assert_eq!(body["did"], did);