323 lines
8.8 KiB
Go
323 lines
8.8 KiB
Go
// Package proxy provides proxy assertion creation and validation for trusted proxy authentication.
// Proxy assertions allow the AppView to vouch for users when communicating with Hold services,
// eliminating the need for per-request service token validation.
|
|
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/bluesky-social/indigo/atproto/atcrypto"
|
|
"github.com/bluesky-social/indigo/atproto/syntax"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
|
|
"atcr.io/pkg/atproto"
|
|
)
|
|
|
|
// ProxyAssertionClaims represents the claims in a proxy assertion JWT
|
|
type ProxyAssertionClaims struct {
|
|
jwt.RegisteredClaims
|
|
UserDID string `json:"user_did"` // User being proxied (for clarity, also in sub)
|
|
AuthMethod string `json:"auth_method"` // Original auth method: "oauth", "app_password", "service_token"
|
|
Proof string `json:"proof"` // Original token (truncated hash for audit, not full token)
|
|
}
|
|
|
|
// Asserter creates proxy assertions signed by AppView
|
|
type Asserter struct {
|
|
proxyDID string // AppView's DID (e.g., "did:web:atcr.io")
|
|
signingKey *atcrypto.PrivateKeyK256 // AppView's K-256 signing key
|
|
}
|
|
|
|
// NewAsserter creates a new proxy assertion creator
|
|
func NewAsserter(proxyDID string, signingKey *atcrypto.PrivateKeyK256) *Asserter {
|
|
return &Asserter{
|
|
proxyDID: proxyDID,
|
|
signingKey: signingKey,
|
|
}
|
|
}
|
|
|
|
// CreateAssertion creates a proxy assertion JWT for a user
|
|
// userDID: the user being proxied
|
|
// holdDID: the target hold service
|
|
// authMethod: how the user authenticated ("oauth", "app_password", "service_token")
|
|
// proofHash: a hash of the original authentication proof (for audit trail)
|
|
func (a *Asserter) CreateAssertion(userDID, holdDID, authMethod, proofHash string) (string, error) {
|
|
now := time.Now()
|
|
|
|
claims := ProxyAssertionClaims{
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Issuer: a.proxyDID,
|
|
Subject: userDID,
|
|
Audience: jwt.ClaimStrings{holdDID},
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(60 * time.Second)), // Short-lived
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
},
|
|
UserDID: userDID,
|
|
AuthMethod: authMethod,
|
|
Proof: proofHash,
|
|
}
|
|
|
|
// Create JWT header
|
|
header := map[string]string{
|
|
"alg": "ES256K",
|
|
"typ": "JWT",
|
|
}
|
|
|
|
// Encode header
|
|
headerJSON, err := json.Marshal(header)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to marshal header: %w", err)
|
|
}
|
|
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
|
|
|
// Encode payload
|
|
payloadJSON, err := json.Marshal(claims)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to marshal claims: %w", err)
|
|
}
|
|
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
|
|
|
// Create signing input
|
|
signingInput := headerB64 + "." + payloadB64
|
|
|
|
// Sign using K-256
|
|
signature, err := a.signingKey.HashAndSign([]byte(signingInput))
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to sign assertion: %w", err)
|
|
}
|
|
|
|
// Encode signature
|
|
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)
|
|
|
|
// Combine into JWT
|
|
token := signingInput + "." + signatureB64
|
|
|
|
slog.Debug("Created proxy assertion",
|
|
"proxyDID", a.proxyDID,
|
|
"userDID", userDID,
|
|
"holdDID", holdDID,
|
|
"authMethod", authMethod)
|
|
|
|
return token, nil
|
|
}
|
|
|
|
// ValidatedUser is the result of a successfully validated proxy assertion:
// the user the proxy vouched for, the proxy that vouched, and how the user
// originally authenticated.
type ValidatedUser struct {
	// DID is the proxied user's DID, taken from the sub claim.
	DID string

	// ProxyDID is the vouching proxy's DID, taken from the iss claim.
	ProxyDID string

	// AuthMethod is the user's original authentication method.
	AuthMethod string
}
|
|
|
|
// Validator validates proxy assertions from trusted proxies
|
|
type Validator struct {
|
|
trustedProxies []string // List of trusted proxy DIDs
|
|
pubKeyCache *publicKeyCache // Cache for proxy public keys
|
|
}
|
|
|
|
// NewValidator creates a new proxy assertion validator
|
|
func NewValidator(trustedProxies []string) *Validator {
|
|
return &Validator{
|
|
trustedProxies: trustedProxies,
|
|
pubKeyCache: newPublicKeyCache(24 * time.Hour), // Cache public keys for 24 hours
|
|
}
|
|
}
|
|
|
|
// ValidateAssertion validates a proxy assertion JWT
|
|
// Returns the validated user info if successful
|
|
func (v *Validator) ValidateAssertion(ctx context.Context, tokenString, holdDID string) (*ValidatedUser, error) {
|
|
// Parse JWT parts
|
|
parts := strings.Split(tokenString, ".")
|
|
if len(parts) != 3 {
|
|
return nil, fmt.Errorf("invalid JWT format")
|
|
}
|
|
|
|
// Decode payload
|
|
payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode payload: %w", err)
|
|
}
|
|
|
|
// Parse claims
|
|
var claims ProxyAssertionClaims
|
|
if err := json.Unmarshal(payloadBytes, &claims); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
|
|
}
|
|
|
|
// Get issuer (proxy DID)
|
|
proxyDID := claims.Issuer
|
|
if proxyDID == "" {
|
|
return nil, fmt.Errorf("missing iss claim")
|
|
}
|
|
|
|
// Check if issuer is trusted
|
|
if !v.isTrustedProxy(proxyDID) {
|
|
return nil, fmt.Errorf("proxy %s not in trustedProxies", proxyDID)
|
|
}
|
|
|
|
// Verify audience matches this hold
|
|
audiences, err := claims.GetAudience()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get audience: %w", err)
|
|
}
|
|
if len(audiences) == 0 || audiences[0] != holdDID {
|
|
return nil, fmt.Errorf("audience mismatch: expected %s, got %v", holdDID, audiences)
|
|
}
|
|
|
|
// Verify expiration
|
|
exp, err := claims.GetExpirationTime()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get expiration: %w", err)
|
|
}
|
|
if exp != nil && time.Now().After(exp.Time) {
|
|
return nil, fmt.Errorf("assertion has expired")
|
|
}
|
|
|
|
// Fetch proxy's public key (with caching)
|
|
publicKey, err := v.getProxyPublicKey(ctx, proxyDID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch public key for proxy %s: %w", proxyDID, err)
|
|
}
|
|
|
|
// Verify signature
|
|
signedData := []byte(parts[0] + "." + parts[1])
|
|
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode signature: %w", err)
|
|
}
|
|
|
|
if err := publicKey.HashAndVerify(signedData, signature); err != nil {
|
|
return nil, fmt.Errorf("signature verification failed: %w", err)
|
|
}
|
|
|
|
// Get user DID from sub claim
|
|
userDID := claims.Subject
|
|
if userDID == "" {
|
|
userDID = claims.UserDID // Fallback to explicit field
|
|
}
|
|
if userDID == "" {
|
|
return nil, fmt.Errorf("missing user DID in assertion")
|
|
}
|
|
|
|
slog.Debug("Validated proxy assertion",
|
|
"proxyDID", proxyDID,
|
|
"userDID", userDID,
|
|
"authMethod", claims.AuthMethod)
|
|
|
|
return &ValidatedUser{
|
|
DID: userDID,
|
|
ProxyDID: proxyDID,
|
|
AuthMethod: claims.AuthMethod,
|
|
}, nil
|
|
}
|
|
|
|
// isTrustedProxy checks if a proxy DID is in the trusted list
|
|
func (v *Validator) isTrustedProxy(proxyDID string) bool {
|
|
for _, trusted := range v.trustedProxies {
|
|
if trusted == proxyDID {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// getProxyPublicKey fetches and caches a proxy's public key
|
|
func (v *Validator) getProxyPublicKey(ctx context.Context, proxyDID string) (atcrypto.PublicKey, error) {
|
|
// Check cache first
|
|
if key := v.pubKeyCache.get(proxyDID); key != nil {
|
|
return key, nil
|
|
}
|
|
|
|
// Fetch from DID document
|
|
key, err := fetchPublicKeyFromDID(ctx, proxyDID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Cache the key
|
|
v.pubKeyCache.set(proxyDID, key)
|
|
|
|
return key, nil
|
|
}
|
|
|
|
// publicKeyCache caches public keys for proxy DIDs
|
|
type publicKeyCache struct {
|
|
mu sync.RWMutex
|
|
entries map[string]cacheEntry
|
|
ttl time.Duration
|
|
}
|
|
|
|
type cacheEntry struct {
|
|
key atcrypto.PublicKey
|
|
expiresAt time.Time
|
|
}
|
|
|
|
func newPublicKeyCache(ttl time.Duration) *publicKeyCache {
|
|
return &publicKeyCache{
|
|
entries: make(map[string]cacheEntry),
|
|
ttl: ttl,
|
|
}
|
|
}
|
|
|
|
func (c *publicKeyCache) get(did string) atcrypto.PublicKey {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
|
|
entry, ok := c.entries[did]
|
|
if !ok || time.Now().After(entry.expiresAt) {
|
|
return nil
|
|
}
|
|
return entry.key
|
|
}
|
|
|
|
func (c *publicKeyCache) set(did string, key atcrypto.PublicKey) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
c.entries[did] = cacheEntry{
|
|
key: key,
|
|
expiresAt: time.Now().Add(c.ttl),
|
|
}
|
|
}
|
|
|
|
// fetchPublicKeyFromDID fetches a public key from a DID document
|
|
func fetchPublicKeyFromDID(ctx context.Context, did string) (atcrypto.PublicKey, error) {
|
|
directory := atproto.GetDirectory()
|
|
atID, err := syntax.ParseAtIdentifier(did)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid DID format: %w", err)
|
|
}
|
|
|
|
ident, err := directory.Lookup(ctx, *atID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to resolve DID: %w", err)
|
|
}
|
|
|
|
publicKey, err := ident.PublicKey()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get public key from DID: %w", err)
|
|
}
|
|
|
|
return publicKey, nil
|
|
}
|
|
|
|
// HashProofForAudit derives a short correlation value from a credential so
// that requests can be tracked in audit logs without storing the token.
//
// It returns "" for an empty token. Otherwise it returns a 64-bit polynomial
// hash (h = h*31 + b over the token's bytes, wrapping mod 2^64), formatted as
// 16 lowercase hex digits.
//
// The hash is computed over bytes in uint64 arithmetic, so the result is the
// same on 32- and 64-bit platforms. Tokens containing invalid UTF-8 no longer
// collapse to the same value. For ASCII tokens the output matches what
// earlier versions produced on 64-bit platforms.
//
// NOTE(review): this hash is not preimage-resistant. If the proof value may be
// exposed to untrusted parties, consider a truncated cryptographic hash such
// as SHA-256. That would change the emitted values.
func HashProofForAudit(token string) string {
	if token == "" {
		return ""
	}

	var hash uint64
	for i := 0; i < len(token); i++ {
		hash = hash*31 + uint64(token[i])
	}
	return fmt.Sprintf("%016x", hash)
}
|