mirror of
https://tangled.org/evan.jarrett.net/at-container-registry
synced 2026-04-20 00:20:31 +00:00
224 lines
5.9 KiB
Go
224 lines
5.9 KiB
Go
// Package serviceauth provides service token validation for ATProto service authentication.
//
// Service tokens are JWTs issued by a user's PDS via com.atproto.server.getServiceAuth.
// They allow services to authenticate users on behalf of other services.
|
|
package serviceauth
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/bluesky-social/indigo/atproto/atcrypto"
|
|
"github.com/bluesky-social/indigo/atproto/syntax"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
|
|
"atcr.io/pkg/atproto"
|
|
)
|
|
|
|
// ValidatedUser is the identity produced by a successful service token
// validation.
type ValidatedUser struct {
	// DID is the user's DID, taken from the token's iss claim.
	DID string
}
|
|
|
|
// ServiceTokenClaims represents the claims in an ATProto service token
|
|
type ServiceTokenClaims struct {
|
|
jwt.RegisteredClaims
|
|
Lxm string `json:"lxm,omitempty"` // Lexicon method identifier (e.g., "io.atcr.registry.push")
|
|
}
|
|
|
|
// Validator validates ATProto service tokens
|
|
type Validator struct {
|
|
serviceDID string // This service's DID (expected in aud claim)
|
|
pubKeyCache *publicKeyCache // Cache for public keys
|
|
}
|
|
|
|
// NewValidator creates a new service token validator
|
|
// serviceDID is the DID of this service (e.g., "did:web:atcr.io")
|
|
// Tokens will be validated to ensure they are intended for this service (aud claim)
|
|
func NewValidator(serviceDID string) *Validator {
|
|
return &Validator{
|
|
serviceDID: serviceDID,
|
|
pubKeyCache: newPublicKeyCache(24 * time.Hour),
|
|
}
|
|
}
|
|
|
|
// Validate validates a service token and returns the authenticated user
|
|
// tokenString is the raw JWT token (without "Bearer " prefix)
|
|
// Returns the user DID if validation succeeds
|
|
func (v *Validator) Validate(ctx context.Context, tokenString string) (*ValidatedUser, error) {
|
|
// Parse JWT parts manually (golang-jwt doesn't support ES256K algorithm used by ATProto)
|
|
parts := splitJWT(tokenString)
|
|
if parts == nil {
|
|
return nil, fmt.Errorf("invalid JWT format")
|
|
}
|
|
|
|
// Decode payload to extract claims
|
|
payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
|
|
}
|
|
|
|
// Parse claims
|
|
var claims ServiceTokenClaims
|
|
if err := json.Unmarshal(payloadBytes, &claims); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
|
|
}
|
|
|
|
// Get issuer DID (the user's DID - they own the PDS that issued this token)
|
|
issuerDID := claims.Issuer
|
|
if issuerDID == "" {
|
|
return nil, fmt.Errorf("missing iss claim")
|
|
}
|
|
|
|
// Verify audience matches this service
|
|
audiences, err := claims.GetAudience()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get audience: %w", err)
|
|
}
|
|
if len(audiences) == 0 || audiences[0] != v.serviceDID {
|
|
return nil, fmt.Errorf("audience mismatch: expected %s, got %v", v.serviceDID, audiences)
|
|
}
|
|
|
|
// Verify expiration
|
|
exp, err := claims.GetExpirationTime()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get expiration: %w", err)
|
|
}
|
|
if exp != nil && time.Now().After(exp.Time) {
|
|
return nil, fmt.Errorf("token has expired")
|
|
}
|
|
|
|
// Fetch public key from issuer's DID document (with caching)
|
|
publicKey, err := v.getPublicKey(ctx, issuerDID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch public key for issuer %s: %w", issuerDID, err)
|
|
}
|
|
|
|
// Verify signature using ATProto's secp256k1 crypto
|
|
signedData := []byte(parts[0] + "." + parts[1])
|
|
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode signature: %w", err)
|
|
}
|
|
|
|
if err := publicKey.HashAndVerify(signedData, signature); err != nil {
|
|
return nil, fmt.Errorf("signature verification failed: %w", err)
|
|
}
|
|
|
|
slog.Debug("Successfully validated service token",
|
|
"userDID", issuerDID,
|
|
"serviceDID", v.serviceDID)
|
|
|
|
return &ValidatedUser{
|
|
DID: issuerDID,
|
|
}, nil
|
|
}
|
|
|
|
// splitJWT splits a JWT into its three dot-separated parts (header, payload,
// signature).
//
// It returns nil if the token does not contain exactly two '.' separators.
// Empty segments are allowed here; callers reject them when decoding.
func splitJWT(token string) []string {
	const segments = 3

	parts := make([]string, 0, segments)
	start := 0
	for i := 0; i < len(token); i++ {
		if token[i] != '.' {
			continue
		}
		// A separator beyond the second one means there are too many segments.
		// Stop immediately so a hostile token full of dots costs no extra
		// scanning or allocation.
		if len(parts) == segments-1 {
			return nil
		}
		parts = append(parts, token[start:i])
		start = i + 1
	}

	// Fewer than two separators: not a three-part token.
	if len(parts) != segments-1 {
		return nil
	}

	// Add the final part
	return append(parts, token[start:])
}
|
|
|
|
// getPublicKey fetches and caches a public key for a DID
|
|
func (v *Validator) getPublicKey(ctx context.Context, did string) (atcrypto.PublicKey, error) {
|
|
// Check cache first
|
|
if key := v.pubKeyCache.get(did); key != nil {
|
|
return key, nil
|
|
}
|
|
|
|
// Fetch from DID document
|
|
key, err := fetchPublicKeyFromDID(ctx, did)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Cache the key
|
|
v.pubKeyCache.set(did, key)
|
|
|
|
return key, nil
|
|
}
|
|
|
|
// fetchPublicKeyFromDID fetches the public key from a DID document
|
|
func fetchPublicKeyFromDID(ctx context.Context, did string) (atcrypto.PublicKey, error) {
|
|
directory := atproto.GetDirectory()
|
|
atID, err := syntax.ParseAtIdentifier(did)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid DID format: %w", err)
|
|
}
|
|
|
|
ident, err := directory.Lookup(ctx, *atID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to resolve DID: %w", err)
|
|
}
|
|
|
|
publicKey, err := ident.PublicKey()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get public key from DID: %w", err)
|
|
}
|
|
|
|
return publicKey, nil
|
|
}
|
|
|
|
// publicKeyCache caches public keys for DIDs
|
|
type publicKeyCache struct {
|
|
mu sync.RWMutex
|
|
entries map[string]cacheEntry
|
|
ttl time.Duration
|
|
}
|
|
|
|
type cacheEntry struct {
|
|
key atcrypto.PublicKey
|
|
expiresAt time.Time
|
|
}
|
|
|
|
func newPublicKeyCache(ttl time.Duration) *publicKeyCache {
|
|
return &publicKeyCache{
|
|
entries: make(map[string]cacheEntry),
|
|
ttl: ttl,
|
|
}
|
|
}
|
|
|
|
func (c *publicKeyCache) get(did string) atcrypto.PublicKey {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
|
|
entry, ok := c.entries[did]
|
|
if !ok || time.Now().After(entry.expiresAt) {
|
|
return nil
|
|
}
|
|
return entry.key
|
|
}
|
|
|
|
func (c *publicKeyCache) set(did string, key atcrypto.PublicKey) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
c.entries[did] = cacheEntry{
|
|
key: key,
|
|
expiresAt: time.Now().Add(c.ttl),
|
|
}
|
|
}
|