Files
at-container-registry/pkg/auth/proxy/assertion.go
2025-12-18 23:23:38 -06:00

323 lines
8.8 KiB
Go

// Package proxy provides proxy assertion creation and validation for trusted proxy authentication.
// Proxy assertions allow AppView to vouch for users when communicating with Hold services,
// eliminating the need for per-request service token validation.
package proxy
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync"
"time"
"github.com/bluesky-social/indigo/atproto/atcrypto"
"github.com/bluesky-social/indigo/atproto/syntax"
"github.com/golang-jwt/jwt/v5"
"atcr.io/pkg/atproto"
)
// ProxyAssertionClaims is the JWT payload of a proxy assertion: the standard
// registered claims (iss, sub, aud, exp, iat, ...) plus fields describing the
// proxied user and how that user originally authenticated.
type ProxyAssertionClaims struct {
jwt.RegisteredClaims
UserDID string `json:"user_did"` // DID of the user being proxied; CreateAssertion sets it to the same value as sub
AuthMethod string `json:"auth_method"` // Original auth method: "oauth", "app_password", "service_token"
Proof string `json:"proof"` // Audit value for the original credential; expected to be a truncated hash, not the full token
}
// Asserter creates proxy assertions signed by AppView.
// Construct one with NewAsserter.
type Asserter struct {
proxyDID string // AppView's DID (e.g., "did:web:atcr.io"); written to the iss claim
signingKey *atcrypto.PrivateKeyK256 // AppView's K-256 signing key, used to sign every assertion
}
// NewAsserter returns an Asserter that issues assertions as proxyDID and signs
// them with signingKey.
func NewAsserter(proxyDID string, signingKey *atcrypto.PrivateKeyK256) *Asserter {
	asserter := new(Asserter)
	asserter.proxyDID = proxyDID
	asserter.signingKey = signingKey
	return asserter
}
// CreateAssertion issues a signed proxy assertion JWT vouching for a user.
//
// Parameters:
//   - userDID: the user being proxied (written to both sub and user_did)
//   - holdDID: the target hold service (written to aud)
//   - authMethod: how the user authenticated ("oauth", "app_password", "service_token")
//   - proofHash: a hash of the original authentication proof (for audit trail)
//
// The assertion expires 60 seconds after it is issued. The returned string is
// the compact JWT form: base64url(header).base64url(payload).base64url(signature).
func (a *Asserter) CreateAssertion(userDID, holdDID, authMethod, proofHash string) (string, error) {
	issuedAt := time.Now()
	payload := ProxyAssertionClaims{
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    a.proxyDID,
			Subject:   userDID,
			Audience:  jwt.ClaimStrings{holdDID},
			ExpiresAt: jwt.NewNumericDate(issuedAt.Add(60 * time.Second)), // Short-lived
			IssuedAt:  jwt.NewNumericDate(issuedAt),
		},
		UserDID:    userDID,
		AuthMethod: authMethod,
		Proof:      proofHash,
	}

	// All three JWT segments use unpadded base64url.
	b64 := base64.RawURLEncoding

	// Header segment.
	rawHeader, err := json.Marshal(map[string]string{
		"alg": "ES256K",
		"typ": "JWT",
	})
	if err != nil {
		return "", fmt.Errorf("failed to marshal header: %w", err)
	}

	// Payload segment.
	rawPayload, err := json.Marshal(payload)
	if err != nil {
		return "", fmt.Errorf("failed to marshal claims: %w", err)
	}

	// The signature covers "<header>.<payload>".
	unsigned := b64.EncodeToString(rawHeader) + "." + b64.EncodeToString(rawPayload)
	sig, err := a.signingKey.HashAndSign([]byte(unsigned))
	if err != nil {
		return "", fmt.Errorf("failed to sign assertion: %w", err)
	}

	assertion := unsigned + "." + b64.EncodeToString(sig)

	slog.Debug("Created proxy assertion",
		"proxyDID", a.proxyDID,
		"userDID", userDID,
		"holdDID", holdDID,
		"authMethod", authMethod)

	return assertion, nil
}
// ValidatedUser describes the user vouched for by a successfully validated
// proxy assertion, together with the proxy that issued it.
type ValidatedUser struct {
DID string // User DID from the sub claim (or user_did when sub is empty)
ProxyDID string // Proxy DID from the iss claim
AuthMethod string // Original auth method, copied from the auth_method claim
}
// Validator validates proxy assertions from trusted proxies.
// Construct one with NewValidator.
type Validator struct {
trustedProxies []string // Trusted proxy DIDs; an assertion's iss claim must equal one of these exactly
pubKeyCache *publicKeyCache // Cache of resolved proxy public keys, keyed by proxy DID
}
// NewValidator returns a Validator that accepts assertions issued by any DID
// in trustedProxies. Resolved proxy public keys are cached for 24 hours.
func NewValidator(trustedProxies []string) *Validator {
	const publicKeyTTL = 24 * time.Hour

	validator := new(Validator)
	validator.trustedProxies = trustedProxies
	validator.pubKeyCache = newPublicKeyCache(publicKeyTTL)
	return validator
}
// ValidateAssertion validates a proxy assertion JWT addressed to holdDID and
// returns the user it vouches for.
//
// The checks run in this order, and the first failure is returned:
//  1. the token has three dot-separated segments and a decodable JSON payload;
//  2. iss is present and is one of the validator's trusted proxies;
//  3. the first aud entry equals holdDID;
//  4. exp is present and not in the past;
//  5. the signature over "<header>.<payload>" verifies against the proxy's
//     public key (resolved from its DID, with caching);
//  6. a user DID is present, and sub and user_did agree when both are set.
//
// The issuer is checked before the key lookup so that keys are only resolved
// for DIDs that are already on the trusted list.
func (v *Validator) ValidateAssertion(ctx context.Context, tokenString, holdDID string) (*ValidatedUser, error) {
	// Parse JWT parts
	parts := strings.Split(tokenString, ".")
	if len(parts) != 3 {
		return nil, fmt.Errorf("invalid JWT format")
	}

	// Decode and parse the payload. Nothing read from it is trusted until the
	// signature has been verified below.
	payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return nil, fmt.Errorf("failed to decode payload: %w", err)
	}
	var claims ProxyAssertionClaims
	if err := json.Unmarshal(payloadBytes, &claims); err != nil {
		return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
	}

	// Get issuer (proxy DID) and make sure it is trusted
	proxyDID := claims.Issuer
	if proxyDID == "" {
		return nil, fmt.Errorf("missing iss claim")
	}
	if !v.isTrustedProxy(proxyDID) {
		return nil, fmt.Errorf("proxy %s not in trustedProxies", proxyDID)
	}

	// Verify audience matches this hold
	audiences, err := claims.GetAudience()
	if err != nil {
		return nil, fmt.Errorf("failed to get audience: %w", err)
	}
	if len(audiences) == 0 || audiences[0] != holdDID {
		return nil, fmt.Errorf("audience mismatch: expected %s, got %v", holdDID, audiences)
	}

	// Verify expiration. exp is mandatory: an assertion without it would
	// otherwise be valid forever.
	exp, err := claims.GetExpirationTime()
	if err != nil {
		return nil, fmt.Errorf("failed to get expiration: %w", err)
	}
	if exp == nil {
		return nil, fmt.Errorf("missing exp claim")
	}
	if time.Now().After(exp.Time) {
		return nil, fmt.Errorf("assertion has expired")
	}

	// Fetch proxy's public key (with caching)
	publicKey, err := v.getProxyPublicKey(ctx, proxyDID)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch public key for proxy %s: %w", proxyDID, err)
	}

	// Verify signature over the exact header and payload segments received
	signedData := []byte(parts[0] + "." + parts[1])
	signature, err := base64.RawURLEncoding.DecodeString(parts[2])
	if err != nil {
		return nil, fmt.Errorf("failed to decode signature: %w", err)
	}
	if err := publicKey.HashAndVerify(signedData, signature); err != nil {
		return nil, fmt.Errorf("signature verification failed: %w", err)
	}

	// Get user DID from sub claim, falling back to the explicit field
	userDID := claims.Subject
	if userDID == "" {
		userDID = claims.UserDID
	}
	if userDID == "" {
		return nil, fmt.Errorf("missing user DID in assertion")
	}
	// sub and user_did name the same user; refuse a token in which they
	// disagree rather than silently picking one.
	if claims.UserDID != "" && claims.UserDID != userDID {
		return nil, fmt.Errorf("sub claim %s does not match user_did claim %s", userDID, claims.UserDID)
	}

	slog.Debug("Validated proxy assertion",
		"proxyDID", proxyDID,
		"userDID", userDID,
		"authMethod", claims.AuthMethod)

	return &ValidatedUser{
		DID:        userDID,
		ProxyDID:   proxyDID,
		AuthMethod: claims.AuthMethod,
	}, nil
}
// isTrustedProxy reports whether proxyDID exactly equals one of the
// validator's trusted proxy DIDs.
func (v *Validator) isTrustedProxy(proxyDID string) bool {
	found := false
	for i := 0; i < len(v.trustedProxies) && !found; i++ {
		found = v.trustedProxies[i] == proxyDID
	}
	return found
}
// getProxyPublicKey returns the public key for proxyDID, serving it from the
// cache when a live entry exists and otherwise resolving it from the DID
// document and caching the result. Failed lookups are not cached.
func (v *Validator) getProxyPublicKey(ctx context.Context, proxyDID string) (atcrypto.PublicKey, error) {
	cached := v.pubKeyCache.get(proxyDID)
	if cached != nil {
		return cached, nil
	}

	// Cache miss (or expired entry): resolve the key from the DID document.
	resolved, err := fetchPublicKeyFromDID(ctx, proxyDID)
	if err != nil {
		return nil, err
	}

	v.pubKeyCache.set(proxyDID, resolved)
	return resolved, nil
}
// publicKeyCache caches public keys for proxy DIDs.
// Each entry expires ttl after it was stored; get ignores expired entries,
// but nothing in this type deletes them.
type publicKeyCache struct {
mu sync.RWMutex // guards entries
entries map[string]cacheEntry // DID -> cached key and its expiry
ttl time.Duration // lifetime given to an entry when it is stored
}
// cacheEntry is a single publicKeyCache slot: a key and the instant after
// which it is no longer served.
type cacheEntry struct {
key atcrypto.PublicKey
expiresAt time.Time // set by publicKeyCache.set to the store time plus the cache TTL
}
// newPublicKeyCache returns an empty cache whose entries live for ttl.
func newPublicKeyCache(ttl time.Duration) *publicKeyCache {
	cache := new(publicKeyCache)
	cache.ttl = ttl
	cache.entries = map[string]cacheEntry{}
	return cache
}
// get returns the cached key for did, or nil when there is no entry or the
// entry has expired.
func (c *publicKeyCache) get(did string) atcrypto.PublicKey {
	c.mu.RLock()
	defer c.mu.RUnlock()

	if cached, found := c.entries[did]; found && !time.Now().After(cached.expiresAt) {
		return cached.key
	}
	return nil
}
// set stores key for did, replacing any existing entry, with an expiry of
// now plus the cache TTL.
func (c *publicKeyCache) set(did string, key atcrypto.PublicKey) {
	c.mu.Lock()
	defer c.mu.Unlock()

	deadline := time.Now().Add(c.ttl)
	c.entries[did] = cacheEntry{expiresAt: deadline, key: key}
}
// fetchPublicKeyFromDID resolves did through the shared identity directory and
// returns the public key from the resolved identity. It fails if did does not
// parse, the lookup fails, or the identity has no usable public key.
func fetchPublicKeyFromDID(ctx context.Context, did string) (atcrypto.PublicKey, error) {
	resolver := atproto.GetDirectory()

	identifier, parseErr := syntax.ParseAtIdentifier(did)
	if parseErr != nil {
		return nil, fmt.Errorf("invalid DID format: %w", parseErr)
	}

	identity, lookupErr := resolver.Lookup(ctx, *identifier)
	if lookupErr != nil {
		return nil, fmt.Errorf("failed to resolve DID: %w", lookupErr)
	}

	key, keyErr := identity.PublicKey()
	if keyErr != nil {
		return nil, fmt.Errorf("failed to get public key from DID: %w", keyErr)
	}
	return key, nil
}
// HashProofForAudit returns a 16-hex-digit fingerprint of token for audit
// purposes, so requests can be correlated without storing the full sensitive
// token. An empty token yields an empty string.
//
// The fingerprint is a 64-bit polynomial hash (multiplier 31) over the token's
// runes. It is NOT cryptographic: collisions are easy to construct, so it must
// only be used for correlation, never to authenticate or compare secrets.
func HashProofForAudit(token string) string {
	if token == "" {
		return ""
	}
	// Accumulate in uint64 rather than int: int is only 32 bits wide on 32-bit
	// targets, which would produce a different (sign-extended) fingerprint
	// there. Unsigned arithmetic wraps modulo 2^64 on every platform and
	// matches what 64-bit builds have always produced.
	var hash uint64
	for _, r := range token {
		hash = hash*31 + uint64(r)
	}
	return fmt.Sprintf("%016x", hash)
}