// Package proxy provides proxy assertion creation and validation for trusted proxy authentication. // Proxy assertions allow AppView to vouch for users when communicating with Hold services, // eliminating the need for per-request service token validation. package proxy import ( "context" "encoding/base64" "encoding/json" "fmt" "log/slog" "strings" "sync" "time" "github.com/bluesky-social/indigo/atproto/atcrypto" "github.com/bluesky-social/indigo/atproto/syntax" "github.com/golang-jwt/jwt/v5" "atcr.io/pkg/atproto" ) // ProxyAssertionClaims represents the claims in a proxy assertion JWT type ProxyAssertionClaims struct { jwt.RegisteredClaims UserDID string `json:"user_did"` // User being proxied (for clarity, also in sub) AuthMethod string `json:"auth_method"` // Original auth method: "oauth", "app_password", "service_token" Proof string `json:"proof"` // Original token (truncated hash for audit, not full token) } // Asserter creates proxy assertions signed by AppView type Asserter struct { proxyDID string // AppView's DID (e.g., "did:web:atcr.io") signingKey *atcrypto.PrivateKeyK256 // AppView's K-256 signing key } // NewAsserter creates a new proxy assertion creator func NewAsserter(proxyDID string, signingKey *atcrypto.PrivateKeyK256) *Asserter { return &Asserter{ proxyDID: proxyDID, signingKey: signingKey, } } // CreateAssertion creates a proxy assertion JWT for a user // userDID: the user being proxied // holdDID: the target hold service // authMethod: how the user authenticated ("oauth", "app_password", "service_token") // proofHash: a hash of the original authentication proof (for audit trail) func (a *Asserter) CreateAssertion(userDID, holdDID, authMethod, proofHash string) (string, error) { now := time.Now() claims := ProxyAssertionClaims{ RegisteredClaims: jwt.RegisteredClaims{ Issuer: a.proxyDID, Subject: userDID, Audience: jwt.ClaimStrings{holdDID}, ExpiresAt: jwt.NewNumericDate(now.Add(60 * time.Second)), // Short-lived IssuedAt: jwt.NewNumericDate(now), }, UserDID: userDID, AuthMethod: authMethod, Proof: proofHash, } // Create JWT header header := map[string]string{ "alg": "ES256K", "typ": "JWT", } // Encode header headerJSON, err := json.Marshal(header) if err != nil { return "", fmt.Errorf("failed to marshal header: %w", err) } headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) // Encode payload payloadJSON, err := json.Marshal(claims) if err != nil { return "", fmt.Errorf("failed to marshal claims: %w", err) } payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON) // Create signing input signingInput := headerB64 + "." + payloadB64 // Sign using K-256 signature, err := a.signingKey.HashAndSign([]byte(signingInput)) if err != nil { return "", fmt.Errorf("failed to sign assertion: %w", err) } // Encode signature signatureB64 := base64.RawURLEncoding.EncodeToString(signature) // Combine into JWT token := signingInput + "." + signatureB64 slog.Debug("Created proxy assertion", "proxyDID", a.proxyDID, "userDID", userDID, "holdDID", holdDID, "authMethod", authMethod) return token, nil } // ValidatedUser represents a validated proxy assertion issuer type ValidatedUser struct { DID string // User DID from sub claim ProxyDID string // Proxy DID from iss claim AuthMethod string // Original auth method } // Validator validates proxy assertions from trusted proxies type Validator struct { trustedProxies []string // List of trusted proxy DIDs pubKeyCache *publicKeyCache // Cache for proxy public keys } // NewValidator creates a new proxy assertion validator func NewValidator(trustedProxies []string) *Validator { return &Validator{ trustedProxies: trustedProxies, pubKeyCache: newPublicKeyCache(24 * time.Hour), // Cache public keys for 24 hours } } // ValidateAssertion validates a proxy assertion JWT // Returns the validated user info if successful func (v *Validator) ValidateAssertion(ctx context.Context, tokenString, holdDID string) (*ValidatedUser, error) { // Parse JWT parts parts := strings.Split(tokenString, ".") if len(parts) != 3 { return nil, fmt.Errorf("invalid JWT format") } // Decode payload payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return nil, fmt.Errorf("failed to decode payload: %w", err) } // Parse claims var claims ProxyAssertionClaims if err := json.Unmarshal(payloadBytes, &claims); err != nil { return nil, fmt.Errorf("failed to unmarshal claims: %w", err) } // Get issuer (proxy DID) proxyDID := claims.Issuer if proxyDID == "" { return nil, fmt.Errorf("missing iss claim") } // Check if issuer is trusted if !v.isTrustedProxy(proxyDID) { return nil, fmt.Errorf("proxy %s not in trustedProxies", proxyDID) } // Verify audience matches this hold audiences, err := claims.GetAudience() if err != nil { return nil, fmt.Errorf("failed to get audience: %w", err) } if len(audiences) == 0 || audiences[0] != holdDID { return nil, fmt.Errorf("audience mismatch: expected %s, got %v", holdDID, audiences) } // Verify expiration exp, err := claims.GetExpirationTime() if err != nil { return nil, fmt.Errorf("failed to get expiration: %w", err) } if exp != nil && time.Now().After(exp.Time) { return nil, fmt.Errorf("assertion has expired") } // Fetch proxy's public key (with caching) publicKey, err := v.getProxyPublicKey(ctx, proxyDID) if err != nil { return nil, fmt.Errorf("failed to fetch public key for proxy %s: %w", proxyDID, err) } // Verify signature signedData := []byte(parts[0] + "." + parts[1]) signature, err := base64.RawURLEncoding.DecodeString(parts[2]) if err != nil { return nil, fmt.Errorf("failed to decode signature: %w", err) } if err := publicKey.HashAndVerify(signedData, signature); err != nil { return nil, fmt.Errorf("signature verification failed: %w", err) } // Get user DID from sub claim userDID := claims.Subject if userDID == "" { userDID = claims.UserDID // Fallback to explicit field } if userDID == "" { return nil, fmt.Errorf("missing user DID in assertion") } slog.Debug("Validated proxy assertion", "proxyDID", proxyDID, "userDID", userDID, "authMethod", claims.AuthMethod) return &ValidatedUser{ DID: userDID, ProxyDID: proxyDID, AuthMethod: claims.AuthMethod, }, nil } // isTrustedProxy checks if a proxy DID is in the trusted list func (v *Validator) isTrustedProxy(proxyDID string) bool { for _, trusted := range v.trustedProxies { if trusted == proxyDID { return true } } return false } // getProxyPublicKey fetches and caches a proxy's public key func (v *Validator) getProxyPublicKey(ctx context.Context, proxyDID string) (atcrypto.PublicKey, error) { // Check cache first if key := v.pubKeyCache.get(proxyDID); key != nil { return key, nil } // Fetch from DID document key, err := fetchPublicKeyFromDID(ctx, proxyDID) if err != nil { return nil, err } // Cache the key v.pubKeyCache.set(proxyDID, key) return key, nil } // publicKeyCache caches public keys for proxy DIDs type publicKeyCache struct { mu sync.RWMutex entries map[string]cacheEntry ttl time.Duration } type cacheEntry struct { key atcrypto.PublicKey expiresAt time.Time } func newPublicKeyCache(ttl time.Duration) *publicKeyCache { return &publicKeyCache{ entries: make(map[string]cacheEntry), ttl: ttl, } } func (c *publicKeyCache) get(did string) atcrypto.PublicKey { c.mu.RLock() defer c.mu.RUnlock() entry, ok := c.entries[did] if !ok || time.Now().After(entry.expiresAt) { return nil } return entry.key } func (c *publicKeyCache) set(did string, key atcrypto.PublicKey) { c.mu.Lock() defer c.mu.Unlock() c.entries[did] = cacheEntry{ key: key, expiresAt: time.Now().Add(c.ttl), } } // fetchPublicKeyFromDID fetches a public key from a DID document func fetchPublicKeyFromDID(ctx context.Context, did string) (atcrypto.PublicKey, error) { directory := atproto.GetDirectory() atID, err := syntax.ParseAtIdentifier(did) if err != nil { return nil, fmt.Errorf("invalid DID format: %w", err) } ident, err := directory.Lookup(ctx, *atID) if err != nil { return nil, fmt.Errorf("failed to resolve DID: %w", err) } publicKey, err := ident.PublicKey() if err != nil { return nil, fmt.Errorf("failed to get public key from DID: %w", err) } return publicKey, nil } // HashProofForAudit creates a truncated hash of a token for audit purposes // This allows tracking without storing the full sensitive token func HashProofForAudit(token string) string { if token == "" { return "" } // Use first 16 chars of a simple hash (not cryptographic, just for tracking) // We don't need security here, just a way to correlate requests hash := 0 for _, c := range token { hash = hash*31 + int(c) } return fmt.Sprintf("%016x", uint64(hash)) }