// Source: at-container-registry/pkg/appview/middleware/registry.go
// Last modified 2026-04-07 22:26:21 -05:00 · 663 lines · 25 KiB · Go

package middleware
import (
"context"
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/registry/api/errcode"
registrymw "github.com/distribution/distribution/v3/registry/middleware/registry"
"github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/reference"
"atcr.io/pkg/appview/readme"
"atcr.io/pkg/appview/storage"
"atcr.io/pkg/atproto"
"atcr.io/pkg/auth"
"atcr.io/pkg/auth/oauth"
"atcr.io/pkg/auth/token"
)
// Context keys under which ExtractAuthMethod stores values parsed from the
// request's JWT; Repository reads them back from the request context.
const (
	// authMethodKey holds the auth method extracted from the JWT.
	authMethodKey contextKey = "auth.method"

	// pullerDIDKey holds the authenticated user's DID (the JWT subject).
	pullerDIDKey contextKey = "puller.did"

	// hasPushScopeKey holds whether the JWT carries push scope.
	hasPushScopeKey contextKey = "token.has_push_scope"
)
// validationCacheEntry holds the outcome of the most recent service-token
// fetch for one cache key, plus the state used to coordinate a fetch that is
// still running.
type validationCacheEntry struct {
	// mu is the per-entry lock; getOrFetch holds it whenever it reads or
	// writes the fields below.
	mu sync.Mutex

	// Outcome of the last completed fetch.
	serviceToken string    // token from a successful fetch; cleared when err is set
	err          error     // error from a failed fetch, cached for fast-fail
	validUntil   time.Time // the outcome above is served from cache until this instant

	// Coordination for a fetch in progress.
	inFlight bool          // true while a goroutine is fetching the token
	done     chan struct{} // closed when that fetch completes
}
// validationCache caches service tokens so that concurrent requests for the
// same key (for example parallel layer uploads) share one OAuth/DPoP fetch
// instead of racing each other. Entries are created by getOrFetch.
type validationCache struct {
// mu guards the entries map itself; each entry's contents are guarded by
// that entry's own mutex.
mu sync.RWMutex
// entries maps a cache key to its cached result. Repository builds the key
// as "<pullerDID>:<holdDID>".
entries map[string]*validationCacheEntry // key: "did:holdDID"
}
// newValidationCache returns an empty validationCache whose entries map is
// allocated and ready for writes.
func newValidationCache() *validationCache {
	cache := new(validationCache)
	cache.entries = map[string]*validationCacheEntry{}
	return cache
}
// getOrFetch returns the service token cached under cacheKey, calling
// fetchFunc to obtain one when the cache holds no unexpired result.
//
// At most one fetch per key runs at a time: the first caller to find the entry
// missing or expired runs fetchFunc, and every concurrent caller for the same
// key waits for that fetch and shares its outcome. A successful token is
// cached for 45 seconds and an error for 5 seconds (fast-fail).
//
// Parameters:
//   - ctx: bounds only the time spent waiting on another goroutine's fetch;
//     it is not passed to fetchFunc.
//   - cacheKey: identifies the entry (Repository uses "<pullerDID>:<holdDID>").
//   - fetchFunc: produces the token; called without any lock held.
//
// Returns the token, or the (possibly cached) fetch error, or ctx.Err() if ctx
// ends while waiting for an in-flight fetch.
func (vc *validationCache) getOrFetch(ctx context.Context, cacheKey string, fetchFunc func() (string, error)) (string, error) {
	entry := vc.entryFor(cacheKey)

	for {
		entry.mu.Lock()

		// Someone else is fetching: wait for them, then re-evaluate from the
		// top. Looping (rather than falling through while holding the lock)
		// avoids re-locking a mutex this goroutine already owns.
		if entry.inFlight {
			done := entry.done
			entry.mu.Unlock()
			select {
			case <-done:
				continue
			case <-ctx.Done():
				return "", ctx.Err()
			}
		}

		// A completed result that has not expired is served from cache.
		if time.Now().Before(entry.validUntil) {
			serviceToken, err := entry.serviceToken, entry.err
			entry.mu.Unlock()
			if err != nil {
				return "", err
			}
			return serviceToken, nil
		}

		// Nothing usable and nobody fetching: this goroutine becomes the
		// fetcher. The done channel is published under the lock in the same
		// critical section that sets inFlight, so every waiter observes the
		// exact channel that will be closed.
		done := make(chan struct{})
		entry.done = done
		entry.inFlight = true
		entry.mu.Unlock()

		return entry.runFetch(done, fetchFunc)
	}
}

// entryFor returns the entry for cacheKey, creating an empty one if needed.
// A freshly created entry has a zero validUntil, so it reads as expired.
func (vc *validationCache) entryFor(cacheKey string) *validationCacheEntry {
	vc.mu.RLock()
	entry, ok := vc.entries[cacheKey]
	vc.mu.RUnlock()
	if ok {
		return entry
	}

	vc.mu.Lock()
	defer vc.mu.Unlock()
	// Re-check: another goroutine may have inserted it between the two locks.
	if entry, ok = vc.entries[cacheKey]; ok {
		return entry
	}
	entry = &validationCacheEntry{}
	vc.entries[cacheKey] = entry
	return entry
}

// runFetch calls fetchFunc, records its outcome on the entry, and wakes the
// goroutines waiting on done. The caller must already have set e.inFlight and
// published done as e.done.
func (e *validationCacheEntry) runFetch(done chan struct{}, fetchFunc func() (string, error)) (serviceToken string, err error) {
	completed := false

	// Deferred so that a panic in fetchFunc still clears inFlight and releases
	// waiters; otherwise the key would be wedged until process restart.
	defer func() {
		e.mu.Lock()
		defer e.mu.Unlock()

		e.inFlight = false
		switch {
		case !completed:
			// fetchFunc panicked: cache nothing, so the next caller retries.
		case err != nil:
			// Cache errors for 5 seconds (fast-fail for subsequent requests).
			e.err = err
			e.serviceToken = ""
			e.validUntil = time.Now().Add(5 * time.Second)
		default:
			// Cache token for 45 seconds (covers typical Docker push operation).
			e.err = nil
			e.serviceToken = serviceToken
			e.validUntil = time.Now().Add(45 * time.Second)
		}
		close(done)
	}()

	serviceToken, err = fetchFunc()
	completed = true
	return serviceToken, err
}
// Global variables for initialization only.
// These are set by main.go during startup (through the SetGlobal* functions)
// and copied into each NamespaceResolver by initATProtoResolver.
// After initialization, request handling uses the NamespaceResolver's instance
// fields. The variables are assigned without synchronization, so they must be
// populated before the middleware is initialized.
var (
// globalRefresher is copied to NamespaceResolver.refresher.
globalRefresher *oauth.Refresher
// globalDatabase is copied to NamespaceResolver.database.
globalDatabase storage.HoldDIDLookup
// globalAuthorizer is copied to NamespaceResolver.authorizer and is also
// exposed through GetGlobalAuthorizer.
globalAuthorizer auth.HoldAuthorizer
// globalWebhookDispatcher is copied to NamespaceResolver.webhookDispatcher.
globalWebhookDispatcher storage.PushWebhookDispatcher
// globalManifestRefChecker is copied to NamespaceResolver.manifestRefChecker.
globalManifestRefChecker storage.ManifestReferenceChecker
)
// SetGlobalRefresher stores the OAuth refresher that initATProtoResolver later
// copies into each NamespaceResolver.
// The assignment is unsynchronized: it must be called during initialization,
// before the registry starts serving requests.
func SetGlobalRefresher(refresher *oauth.Refresher) {
globalRefresher = refresher
}
// SetGlobalDatabase stores the hold-DID lookup database that
// initATProtoResolver later copies into each NamespaceResolver.
// The assignment is unsynchronized: it must be called during initialization,
// before the registry starts serving requests.
func SetGlobalDatabase(database storage.HoldDIDLookup) {
globalDatabase = database
}
// SetGlobalManifestRefChecker stores the manifest reference checker that
// initATProtoResolver later copies into each NamespaceResolver.
// The assignment is unsynchronized, so call it during initialization, before
// the middleware is constructed.
func SetGlobalManifestRefChecker(checker storage.ManifestReferenceChecker) {
globalManifestRefChecker = checker
}
// SetGlobalAuthorizer stores the hold authorizer that initATProtoResolver later
// copies into each NamespaceResolver and that GetGlobalAuthorizer returns.
// The assignment is unsynchronized: it must be called during initialization,
// before the registry starts serving requests.
func SetGlobalAuthorizer(authorizer auth.HoldAuthorizer) {
globalAuthorizer = authorizer
}
// SetGlobalWebhookDispatcher stores the push webhook dispatcher that
// initATProtoResolver later copies into each NamespaceResolver.
// The assignment is unsynchronized: it must be called during initialization,
// before the registry starts serving requests.
func SetGlobalWebhookDispatcher(dispatcher storage.PushWebhookDispatcher) {
globalWebhookDispatcher = dispatcher
}
// GetGlobalAuthorizer returns the authorizer last stored by SetGlobalAuthorizer,
// or nil if none has been set.
// Used by components that need to clear denial cache (e.g., EnsureCrewMembership).
func GetGlobalAuthorizer() auth.HoldAuthorizer {
return globalAuthorizer
}
// init registers initATProtoResolver with the distribution registry under the
// middleware name "atproto-resolver". Registration failure is unrecoverable at
// package load time, so it panics.
func init() {
	err := registrymw.Register("atproto-resolver", initATProtoResolver)
	if err == nil {
		return
	}
	panic("failed to register atproto-resolver middleware: " + err.Error())
}
// NamespaceResolver wraps a distribution.Namespace and resolves repository
// names of the form "<identity>/<image>" to an ATProto identity and a hold
// service before delegating to the wrapped namespace. Instances are built by
// initATProtoResolver, which fills the service fields from the package globals.
type NamespaceResolver struct {
// Embedded namespace that all calls ultimately delegate to.
distribution.Namespace
defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io")
baseURL string // Base URL for error messages (e.g., "https://atcr.io")
testMode bool // If true, fallback to default hold when user's hold is unreachable
refresher *oauth.Refresher // OAuth session manager (copied from global on init)
database storage.HoldDIDLookup // Database for hold DID lookups (copied from global on init)
authorizer auth.HoldAuthorizer // Hold authorization (copied from global on init)
webhookDispatcher storage.PushWebhookDispatcher // Push webhook dispatcher (copied from global on init)
manifestRefChecker storage.ManifestReferenceChecker // Manifest reference checker (copied from global on init)
validationCache *validationCache // Service token cache, one per resolver instance
readmeFetcher *readme.Fetcher // README fetcher for repo pages
}
// initATProtoResolver builds the NamespaceResolver that wraps ns. It is the
// init function registered under "atproto-resolver".
//
// Recognized options (a missing key or a value of the wrong type leaves the
// zero value):
//   - "default_hold_did" (string): fallback hold DID, e.g. "did:web:hold01.atcr.io"
//   - "base_url" (string): base URL used in auth error messages
//   - "test_mode" (bool): fall back to the default hold when a user's hold is unreachable
//
// Shared services are copied from the package globals at this point so request
// handling never reads the globals. It always returns a nil error.
func initATProtoResolver(ctx context.Context, ns distribution.Namespace, _ driver.StorageDriver, options map[string]any) (distribution.Namespace, error) {
	// A failed type assertion yields the zero value, which is the default we want.
	holdDID, _ := options["default_hold_did"].(string)
	base, _ := options["base_url"].(string)
	inTestMode, _ := options["test_mode"].(bool)

	resolver := &NamespaceResolver{
		Namespace:          ns,
		defaultHoldDID:     holdDID,
		baseURL:            base,
		testMode:           inTestMode,
		refresher:          globalRefresher,
		database:           globalDatabase,
		authorizer:         globalAuthorizer,
		webhookDispatcher:  globalWebhookDispatcher,
		manifestRefChecker: globalManifestRefChecker,
		validationCache:    newValidationCache(),
		readmeFetcher:      readme.NewFetcher(),
	}
	return resolver, nil
}
// authErrorMessage wraps message in an Unauthorized registry error and appends
// the OAuth login URL (nr.baseURL + "/auth/oauth/login") so the user knows
// where to re-authenticate.
func (nr *NamespaceResolver) authErrorMessage(message string) error {
	loginURL := nr.baseURL + "/auth/oauth/login"
	return errcode.ErrorCodeUnauthorized.WithMessage(message + " - please re-authenticate at " + loginURL)
}
// Repository resolves the repository name and returns a routing repository
// that sends manifests to ATProto and blobs to the owner's hold service.
//
// Handles names like:
//   - atcr.io/alice/myimage → resolve alice to DID
//   - atcr.io/did-plc-xyz123/myimage → hyphen-encoded DID, decoded and used directly
//
// Steps, in order: resolve the owner identity, pick the hold (following one
// successor hop), reconcile crew membership, obtain a service token for the
// authenticated puller (if any), open the underlying repository under the
// canonical "<handle>/<image>" name, and bundle everything into a
// storage.RegistryContext.
//
// Returns an error when the name has no identity component, identity or hold
// resolution fails, no hold DID is configured, the service token cannot be
// obtained (as an Unauthorized error), or the underlying namespace fails.
func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Named) (distribution.Repository, error) {
	// Split "<identity>/<image>"; a name without an identity is rejected.
	repoPath := name.Name()
	identityStr, imageName, ok := strings.Cut(repoPath, "/")
	if !ok {
		return nil, fmt.Errorf("repository name must include user: %s", repoPath)
	}

	// Support hyphen-encoded DIDs in image paths (e.g., did-plc-abc123/repo:tag).
	// OCI reference grammar doesn't allow colons in path components, so DIDs must
	// be encoded with hyphens instead: did:plc:abc123 → did-plc-abc123
	if decoded, ok := auth.DecodeDIDFromHyphens(identityStr); ok {
		identityStr = decoded
	}

	// Resolve identity to DID, handle, and PDS endpoint
	did, handle, pdsEndpoint, err := atproto.ResolveIdentity(ctx, identityStr)
	if err != nil {
		return nil, err
	}
	slog.Debug("Resolved identity", "component", "registry/middleware", "did", did, "pds", pdsEndpoint, "handle", handle)

	// Query for hold DID - either user's hold or default hold service.
	// Also returns the sailor profile so we can read preferences (e.g. AutoRemoveUntagged).
	holdDID, sailorProfile := nr.findHoldDIDAndProfile(ctx, did, pdsEndpoint)
	if holdDID == "" {
		// This is a fatal configuration error - registry cannot function without a hold service
		return nil, fmt.Errorf("no hold DID configured: ensure default_hold_did is set in middleware config")
	}

	// Single-hop hold migration: check if this hold has declared a successor
	holdDID = nr.resolveSuccessor(ctx, holdDID)

	// Resolve hold DID to HTTP URL via identity directory
	holdURL, err := atproto.ResolveHoldURL(ctx, holdDID)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve hold URL for %s: %w", holdDID, err)
	}

	// Auto-reconcile crew membership on first push/pull so users can push
	// immediately after docker login without web sign-in. Best-effort, and run
	// synchronously so the crew record exists before write access checks.
	if holdDID != "" && nr.refresher != nil {
		slog.Debug("Auto-reconciling crew membership", "component", "registry/middleware", "did", did, "hold_did", holdDID)
		client := atproto.NewClient(pdsEndpoint, did, "")
		storage.EnsureCrewMembership(ctx, client, nr.refresher, holdDID, nr.authorizer)
	}

	// Values placed in the context by ExtractAuthMethod; absent for
	// unauthenticated requests.
	authMethod, _ := ctx.Value(authMethodKey).(string)
	pullerDID, _ := ctx.Value(pullerDIDKey).(string)
	hasPushScope, _ := ctx.Value(hasPushScopeKey).(bool)

	// IMPORTANT: the service token is for the PULLER (authenticated user), not
	// the repository owner - the puller is who authenticates to the hold.
	serviceToken, pullerPDSEndpoint, err := nr.serviceTokenFor(ctx, authMethod, pullerDID, holdDID, did)
	if err != nil {
		return nil, err
	}

	// Use the resolved handle (not raw DID) so the name is valid per OCI
	// reference grammar; DIDs contain colons, which are illegal in path
	// components. This transforms: did-plc-abc123/myimage -> alice.bsky.social/myimage
	canonicalName := fmt.Sprintf("%s/%s", handle, imageName)
	ref, err := reference.ParseNamed(canonicalName)
	if err != nil {
		// Report the string that actually failed to parse.
		return nil, fmt.Errorf("invalid image name %s: %w", canonicalName, err)
	}

	// Delegate to underlying namespace with modified name
	repo, err := nr.Namespace.Repository(ctx, ref)
	if err != nil {
		return nil, err
	}

	// Built before authMethod is defaulted below, so a request without an auth
	// method gets an unauthenticated client.
	atprotoClient := nr.atprotoClientFor(authMethod, pullerDID, did, pdsEndpoint)

	// Default auth method to OAuth if not already set (backward compatibility with old tokens)
	if authMethod == "" {
		authMethod = token.AuthMethodOAuth
	}

	// A fresh RoutingRepository is created on every request (no caching):
	// each layer upload is a separate HTTP request, OAuth sessions can be
	// refreshed or invalidated between requests, and a repository holding a
	// stale ATProtoClient causes refresh token errors.
	registryCtx := &storage.RegistryContext{
		DID:         did,
		Handle:      handle,
		HoldDID:     holdDID,
		HoldURL:     holdURL,
		PDSEndpoint: pdsEndpoint,
		// ATProto records are scoped to the owner's DID, so only the image
		// name is stored: "evan.jarrett.net/debian" -> "debian".
		Repository:         imageName,
		ServiceToken:       serviceToken, // Cached service token from puller's PDS
		ATProtoClient:      atprotoClient,
		AuthMethod:         authMethod,        // Auth method from JWT token
		PullerDID:          pullerDID,         // Authenticated user making the request
		PullerPDSEndpoint:  pullerPDSEndpoint, // Puller's PDS for service token refresh
		HasPushScope:       hasPushScope,      // Whether JWT has push scope (for pull stats filtering)
		AutoRemoveUntagged: sailorProfile != nil && sailorProfile.AutoRemoveUntagged,
		Database:           nr.database,
		Authorizer:         nr.authorizer,
		Refresher:          nr.refresher,
		ReadmeFetcher:      nr.readmeFetcher,
		WebhookDispatcher:  nr.webhookDispatcher,
		ManifestRefChecker: nr.manifestRefChecker,
	}
	return storage.NewRoutingRepository(repo, registryCtx), nil
}

// serviceTokenFor obtains the hold service token for the authenticated puller,
// going through the validation cache so concurrent requests for the same
// puller/hold pair share one fetch. It returns an empty token and nil error
// for unauthenticated requests and when the puller's PDS cannot be resolved
// (the hold then decides whether anonymous access is allowed). A non-nil
// error is always an Unauthorized error ready to return to the client.
func (nr *NamespaceResolver) serviceTokenFor(ctx context.Context, authMethod, pullerDID, holdDID, ownerDID string) (string, string, error) {
	// Unauthenticated requests (like /v2/ ping) should not trigger token fetching.
	if authMethod == "" || pullerDID == "" {
		slog.Debug("Skipping service token fetch for unauthenticated request",
			"component", "registry/middleware",
			"ownerDID", ownerDID)
		return "", "", nil
	}

	// Resolve puller's PDS endpoint for service token request
	_, _, pullerPDSEndpoint, err := atproto.ResolveIdentity(ctx, pullerDID)
	if err != nil {
		slog.Warn("Failed to resolve puller's PDS, falling back to anonymous access",
			"component", "registry/middleware",
			"pullerDID", pullerDID,
			"error", err)
		return "", pullerPDSEndpoint, nil
	}

	cacheKey := fmt.Sprintf("%s:%s", pullerDID, holdDID)
	serviceToken, err := nr.validationCache.getOrFetch(ctx, cacheKey, func() (string, error) {
		return nr.fetchServiceToken(ctx, authMethod, pullerDID, holdDID, pullerPDSEndpoint, cacheKey)
	})
	if err != nil {
		return "", "", nr.serviceTokenAuthError(authMethod, err)
	}
	return serviceToken, pullerPDSEndpoint, nil
}

// fetchServiceToken requests a service token using the flow that matches
// authMethod: app-password (Bearer) when the JWT says so, otherwise OAuth
// (DPoP) when a refresher is configured. Failures are logged and returned.
func (nr *NamespaceResolver) fetchServiceToken(ctx context.Context, authMethod, pullerDID, holdDID, pullerPDSEndpoint, cacheKey string) (string, error) {
	switch {
	case authMethod == token.AuthMethodAppPassword:
		slog.Debug("Using app-password flow for service token",
			"component", "registry/middleware",
			"pullerDID", pullerDID,
			"cacheKey", cacheKey)
		serviceToken, err := auth.GetOrFetchServiceTokenWithAppPassword(ctx, pullerDID, holdDID, pullerPDSEndpoint)
		if err != nil {
			slog.Error("Failed to get service token with app-password",
				"component", "registry/middleware",
				"pullerDID", pullerDID,
				"holdDID", holdDID,
				"pullerPDSEndpoint", pullerPDSEndpoint,
				"denial_reason", "service_token_app_password_failed",
				"error", err)
			return "", err
		}
		return serviceToken, nil

	case nr.refresher != nil:
		slog.Debug("Using OAuth flow for service token",
			"component", "registry/middleware",
			"pullerDID", pullerDID,
			"cacheKey", cacheKey)
		serviceToken, err := auth.GetOrFetchServiceToken(ctx, nr.refresher, pullerDID, holdDID, pullerPDSEndpoint)
		if err != nil {
			slog.Error("Failed to get service token with OAuth",
				"component", "registry/middleware",
				"pullerDID", pullerDID,
				"holdDID", holdDID,
				"pullerPDSEndpoint", pullerPDSEndpoint,
				"denial_reason", "service_token_oauth_failed",
				"error", err)
			return "", err
		}
		return serviceToken, nil

	default:
		return "", fmt.Errorf("no authentication method available")
	}
}

// serviceTokenAuthError converts a service-token fetch failure into a
// user-facing Unauthorized error, choosing the message from substrings of the
// underlying error text.
func (nr *NamespaceResolver) serviceTokenAuthError(authMethod string, fetchErr error) error {
	errMsg := fetchErr.Error()

	// App-password specific errors
	if authMethod == token.AuthMethodAppPassword &&
		(strings.Contains(errMsg, "expired or invalid") || strings.Contains(errMsg, "no app-password")) {
		return nr.authErrorMessage("App-password authentication failed. Please re-authenticate with: docker login")
	}

	// OAuth specific errors
	if strings.Contains(errMsg, "OAuth session") || strings.Contains(errMsg, "OAuth validation") {
		return nr.authErrorMessage("OAuth session expired or invalidated by PDS. Your session has been cleared")
	}

	// Generic service token error
	return nr.authErrorMessage(fmt.Sprintf("Failed to obtain storage credentials: %v", fetchErr))
}

// atprotoClientFor builds the ATProto client used for manifest/tag operations
// against the owner's PDS. Only a puller who is the owner gets an
// authenticated client (needed for pushes); everyone else gets an
// unauthenticated one, which suffices for reading public records.
func (nr *NamespaceResolver) atprotoClientFor(authMethod, pullerDID, ownerDID, pdsEndpoint string) *atproto.Client {
	if pullerDID != ownerDID {
		return atproto.NewClient(pdsEndpoint, ownerDID, "")
	}
	switch {
	case authMethod == token.AuthMethodOAuth && nr.refresher != nil:
		return atproto.NewClientWithSessionProvider(pdsEndpoint, ownerDID, nr.refresher)
	case authMethod == token.AuthMethodAppPassword:
		// A cache miss yields an empty access token, i.e. an unauthenticated client.
		accessToken, _ := auth.GetGlobalTokenCache().Get(ownerDID)
		return atproto.NewClient(pdsEndpoint, ownerDID, accessToken)
	default:
		return atproto.NewClient(pdsEndpoint, ownerDID, "")
	}
}
// Repositories forwards the call unchanged to the wrapped namespace and
// returns its result as-is; no name resolution is applied here.
func (nr *NamespaceResolver) Repositories(ctx context.Context, repos []string, last string) (int, error) {
return nr.Namespace.Repositories(ctx, repos, last)
}
// Blobs returns the blob enumerator of the wrapped namespace unchanged.
func (nr *NamespaceResolver) Blobs() distribution.BlobEnumerator {
return nr.Namespace.Blobs()
}
// BlobStatter returns the blob statter of the wrapped namespace unchanged.
func (nr *NamespaceResolver) BlobStatter() distribution.BlobStatter {
return nr.Namespace.BlobStatter()
}
// findHoldDIDAndProfile determines which hold DID to use for blob storage and
// returns the user's sailor profile alongside it, so callers can read
// preferences such as AutoRemoveUntagged without a second profile lookup.
//
// Priority order:
//  1. The profile's DefaultHold, if set (in test mode, only if it is reachable)
//  2. The resolver's default hold DID
//
// Returns the hold DID (e.g., "did:web:hold01.atcr.io"), or an empty string if
// neither is configured. The profile is returned exactly as storage.GetProfile
// produced it, so it may be nil. A profile read error is logged and otherwise
// ignored.
func (nr *NamespaceResolver) findHoldDIDAndProfile(ctx context.Context, did, pdsEndpoint string) (string, *atproto.SailorProfileRecord) {
	// Unauthenticated client: the profile is a public record.
	client := atproto.NewClient(pdsEndpoint, did, "")

	profile, err := storage.GetProfile(ctx, client)
	if err != nil {
		// Log and continue with whatever profile was returned. The component
		// attribute matches every other log call in this file.
		slog.Warn("Failed to read profile", "component", "registry/middleware", "did", did, "error", err)
	}

	// No usable defaultHold in the profile: use the AppView default.
	if profile == nil || profile.DefaultHold == "" {
		return nr.defaultHoldDID, profile
	}

	// In test mode, verify the user's hold is reachable before using it.
	if nr.testMode && !nr.isHoldReachable(ctx, profile.DefaultHold) {
		slog.Debug("User's defaultHold unreachable, falling back to default", "component", "registry/middleware/testmode", "default_hold", profile.DefaultHold)
		return nr.defaultHoldDID, profile
	}

	return profile.DefaultHold, profile
}
// resolveSuccessor returns the successor declared in holdDID's captain record,
// or holdDID itself when there is no authorizer, the captain record cannot be
// fetched, or no successor is set. Only one hop is followed; the successor's
// own record is not consulted.
func (nr *NamespaceResolver) resolveSuccessor(ctx context.Context, holdDID string) string {
	if nr.authorizer == nil {
		return holdDID
	}

	captain, err := nr.authorizer.GetCaptainRecord(ctx, holdDID)
	if err != nil || captain == nil || captain.Successor == "" {
		return holdDID
	}

	slog.Info("Hold successor redirect",
		"component", "registry/middleware",
		"from", holdDID,
		"to", captain.Successor)
	return captain.Successor
}
// isHoldReachable reports whether the hold identified by holdDID can be
// reached: its URL must resolve and its "/.well-known/did.json" document must
// be fetchable. Used in test mode to fall back to the default hold when the
// user's hold is unavailable.
func (nr *NamespaceResolver) isHoldReachable(ctx context.Context, holdDID string) bool {
	holdURL, err := atproto.ResolveHoldURL(ctx, holdDID)
	if err != nil {
		slog.Debug("Cannot resolve hold URL for reachability check", "component", "registry/middleware", "holdDID", holdDID, "error", err)
		return false
	}

	probe := atproto.NewClient("", "", "")
	if _, err := probe.FetchDIDDocument(ctx, holdURL+"/.well-known/did.json"); err != nil {
		return false
	}
	return true
}
// ExtractAuthMethod is an HTTP middleware that enriches the request context
// before calling next:
//   - the HTTP method is always stored under storage.HTTPRequestMethod
//     (GET/HEAD = pull, PUT/POST = push);
//   - when the Authorization header carries a Bearer token, the JWT's auth
//     method, subject (the puller's DID) and push-scope flag are stored for
//     the registry middleware, each only if present.
//
// The token is parsed but not validated here.
func ExtractAuthMethod(next http.Handler) http.Handler {
	handler := func(w http.ResponseWriter, r *http.Request) {
		ctx := context.WithValue(r.Context(), storage.HTTPRequestMethod, r.Method)

		// Expect "Bearer <token>"; the scheme is matched case-insensitively.
		scheme, tokenString, found := strings.Cut(r.Header.Get("Authorization"), " ")
		if found && strings.EqualFold(scheme, "bearer") {
			authMethod := token.ExtractAuthMethod(tokenString)
			if authMethod != "" {
				ctx = context.WithValue(ctx, authMethodKey, authMethod)
			}

			// Subject of the JWT: the authenticated user's DID, used for
			// service token requests.
			pullerDID := token.ExtractSubject(tokenString)
			if pullerDID != "" {
				ctx = context.WithValue(ctx, pullerDIDKey, pullerDID)
			}

			// Push scope distinguishes real pulls from manifest GETs made
			// during push/imagetools flows.
			pushScoped := token.HasPushScope(token.ExtractAccess(tokenString))
			if pushScoped {
				ctx = context.WithValue(ctx, hasPushScopeKey, true)
			}

			slog.Debug("Extracted auth info from JWT",
				"component", "registry/middleware",
				"authMethod", authMethod,
				"pullerDID", pullerDID,
				"hasPushScope", pushScoped,
				"httpMethod", r.Method)
		}

		next.ServeHTTP(w, r.WithContext(ctx))
	}
	return http.HandlerFunc(handler)
}