Files
at-container-registry/pkg/auth/serviceauth/validator.go
2025-12-18 23:23:38 -06:00

224 lines
5.9 KiB
Go

// Package serviceauth provides service token validation for ATProto service authentication.
// Service tokens are JWTs issued by a user's PDS via com.atproto.server.getServiceAuth.
// They allow services to authenticate users on behalf of other services.
package serviceauth
import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"log/slog"
	"strings"
	"sync"
	"time"

	"github.com/bluesky-social/indigo/atproto/atcrypto"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/golang-jwt/jwt/v5"

	"atcr.io/pkg/atproto"
)
// ValidatedUser is the result of a successful service token validation.
type ValidatedUser struct {
	// DID is the user's DID, taken from the token's iss claim.
	DID string
}
// ServiceTokenClaims is the JWT payload of an ATProto service token: the
// registered JWT claims plus the ATProto-specific lxm claim.
type ServiceTokenClaims struct {
	jwt.RegisteredClaims

	// Lxm is the lexicon method identifier the token is bound to
	// (e.g., "io.atcr.registry.push"). It is omitted from JSON when empty.
	Lxm string `json:"lxm,omitempty"`
}
// Validator checks ATProto service tokens addressed to a single service.
type Validator struct {
	// serviceDID is this service's own DID; a token's aud claim is compared
	// against it.
	serviceDID string

	// pubKeyCache holds public keys that were already resolved, keyed by DID.
	pubKeyCache *publicKeyCache
}
// NewValidator builds a Validator for the service identified by serviceDID
// (for example "did:web:atcr.io"). Tokens are later checked against this DID
// through their aud claim. Resolved public keys are cached for 24 hours.
func NewValidator(serviceDID string) *Validator {
	validator := new(Validator)
	validator.serviceDID = serviceDID
	validator.pubKeyCache = newPublicKeyCache(24 * time.Hour)
	return validator
}
// Validate checks a service token and returns the user it authenticates.
//
// tokenString is the raw compact JWT, without any "Bearer " prefix. The token
// is accepted only if all of the following hold:
//   - it consists of three dot-separated parts and the payload decodes into
//     ServiceTokenClaims,
//   - the iss claim is present (it becomes the returned user's DID),
//   - the first aud entry equals this validator's service DID,
//   - the exp claim is present and not in the past,
//   - the signature verifies against the public key resolved for the issuer.
//
// On any failure a nil user and a descriptive error are returned.
func (v *Validator) Validate(ctx context.Context, tokenString string) (*ValidatedUser, error) {
	// The JWT is taken apart by hand rather than through golang-jwt, because
	// golang-jwt doesn't support the ES256K algorithm used by ATProto.
	parts := splitJWT(tokenString)
	if parts == nil {
		return nil, fmt.Errorf("invalid JWT format")
	}

	// Decode the payload segment and parse the claims out of it.
	payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
	}
	var claims ServiceTokenClaims
	if err := json.Unmarshal(payloadBytes, &claims); err != nil {
		return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
	}

	// The issuer is the user's DID; it is what gets returned on success.
	issuerDID := claims.Issuer
	if issuerDID == "" {
		return nil, fmt.Errorf("missing iss claim")
	}

	// The token must be addressed to this service. Only the first audience
	// entry is compared.
	audiences, err := claims.GetAudience()
	if err != nil {
		return nil, fmt.Errorf("failed to get audience: %w", err)
	}
	if len(audiences) == 0 || audiences[0] != v.serviceDID {
		return nil, fmt.Errorf("audience mismatch: expected %s, got %v", v.serviceDID, audiences)
	}

	// The expiration claim is mandatory: a token without exp would otherwise
	// be accepted forever, so its absence is treated as a validation failure.
	exp, err := claims.GetExpirationTime()
	if err != nil {
		return nil, fmt.Errorf("failed to get expiration: %w", err)
	}
	if exp == nil {
		return nil, fmt.Errorf("missing exp claim")
	}
	if time.Now().After(exp.Time) {
		return nil, fmt.Errorf("token has expired")
	}

	// Resolve the issuer's public key (served from the cache when possible).
	publicKey, err := v.getPublicKey(ctx, issuerDID)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch public key for issuer %s: %w", issuerDID, err)
	}

	// The signature covers the encoded header and payload joined by a dot.
	signedData := []byte(parts[0] + "." + parts[1])
	signature, err := base64.RawURLEncoding.DecodeString(parts[2])
	if err != nil {
		return nil, fmt.Errorf("failed to decode signature: %w", err)
	}
	if err := publicKey.HashAndVerify(signedData, signature); err != nil {
		return nil, fmt.Errorf("signature verification failed: %w", err)
	}

	slog.Debug("Successfully validated service token",
		"userDID", issuerDID,
		"serviceDID", v.serviceDID)

	return &ValidatedUser{
		DID: issuerDID,
	}, nil
}
// splitJWT splits a compact JWT into its three dot-separated segments
// (header, payload, signature).
//
// It returns nil when the token does not contain exactly two dots. Segments
// are not otherwise validated and may be empty.
func splitJWT(token string) []string {
	// Ask for at most four pieces: a fourth piece proves there are too many
	// dots, without splitting (and allocating for) every dot in a hostile
	// input.
	parts := strings.SplitN(token, ".", 4)
	if len(parts) != 3 {
		return nil
	}
	return parts
}
// getPublicKey returns the public key for did. A cached key is used when one
// is available; otherwise the key is fetched via fetchPublicKeyFromDID and
// stored in the cache before being returned. Fetch errors are returned as-is
// and nothing is cached in that case.
func (v *Validator) getPublicKey(ctx context.Context, did string) (atcrypto.PublicKey, error) {
	cached := v.pubKeyCache.get(did)
	if cached != nil {
		return cached, nil
	}

	// Cache miss: resolve the key from the DID document.
	resolved, err := fetchPublicKeyFromDID(ctx, did)
	if err != nil {
		return nil, err
	}

	v.pubKeyCache.set(did, resolved)
	return resolved, nil
}
// fetchPublicKeyFromDID resolves did through the directory returned by
// atproto.GetDirectory and returns the public key of the resolved identity.
//
// An error is returned when did cannot be parsed, when the lookup fails, or
// when the identity yields no usable public key.
//
// NOTE(review): the identifier is parsed with syntax.ParseAtIdentifier, which
// may accept more than DIDs (e.g. handles) — confirm whether that is intended.
func fetchPublicKeyFromDID(ctx context.Context, did string) (atcrypto.PublicKey, error) {
	dir := atproto.GetDirectory()

	identifier, parseErr := syntax.ParseAtIdentifier(did)
	if parseErr != nil {
		return nil, fmt.Errorf("invalid DID format: %w", parseErr)
	}

	identity, lookupErr := dir.Lookup(ctx, *identifier)
	if lookupErr != nil {
		return nil, fmt.Errorf("failed to resolve DID: %w", lookupErr)
	}

	key, keyErr := identity.PublicKey()
	if keyErr != nil {
		return nil, fmt.Errorf("failed to get public key from DID: %w", keyErr)
	}

	return key, nil
}
// publicKeyCache is an in-memory cache of public keys keyed by DID, with a
// fixed lifetime per entry.
type publicKeyCache struct {
	// mu guards entries.
	mu sync.RWMutex

	// entries maps a DID to its cached key and expiry time.
	entries map[string]cacheEntry

	// ttl is how long an entry stays valid after it is stored.
	ttl time.Duration
}
// cacheEntry is one cached public key together with the moment it expires.
type cacheEntry struct {
	// key is the cached public key.
	key atcrypto.PublicKey

	// expiresAt is the time after which the entry is no longer served.
	expiresAt time.Time
}
// newPublicKeyCache returns an empty cache whose entries stay valid for ttl
// after they are stored.
func newPublicKeyCache(ttl time.Duration) *publicKeyCache {
	cache := &publicKeyCache{ttl: ttl}
	cache.entries = make(map[string]cacheEntry)
	return cache
}
// get returns the cached public key for did, or nil when there is no entry
// for it or the entry has expired.
func (c *publicKeyCache) get(did string) atcrypto.PublicKey {
	c.mu.RLock()
	defer c.mu.RUnlock()

	if cached, found := c.entries[did]; found && !time.Now().After(cached.expiresAt) {
		return cached.key
	}
	return nil
}
// set stores key as the cached public key for did, valid for the cache's ttl
// starting now. Any existing entry for did is replaced.
//
// Expired entries are removed as part of the insert. Lookups never delete
// anything, so without this sweep the map would keep an entry for every DID
// ever stored, long after the entries stopped being served.
func (c *publicKeyCache) set(did string, key atcrypto.PublicKey) {
	c.mu.Lock()
	defer c.mu.Unlock()

	now := time.Now()

	// Deleting from a map while ranging over it is safe in Go.
	for cachedDID, entry := range c.entries {
		if now.After(entry.expiresAt) {
			delete(c.entries, cachedDID)
		}
	}

	c.entries[did] = cacheEntry{
		key:       key,
		expiresAt: now.Add(c.ttl),
	}
}