mirror of
https://tangled.org/evan.jarrett.net/at-container-registry
synced 2026-04-23 18:00:32 +00:00
663 lines
25 KiB
Go
663 lines
25 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/distribution/distribution/v3"
|
|
"github.com/distribution/distribution/v3/registry/api/errcode"
|
|
registrymw "github.com/distribution/distribution/v3/registry/middleware/registry"
|
|
"github.com/distribution/distribution/v3/registry/storage/driver"
|
|
"github.com/distribution/reference"
|
|
|
|
"atcr.io/pkg/appview/readme"
|
|
"atcr.io/pkg/appview/storage"
|
|
"atcr.io/pkg/atproto"
|
|
"atcr.io/pkg/auth"
|
|
"atcr.io/pkg/auth/oauth"
|
|
"atcr.io/pkg/auth/token"
|
|
)
|
|
|
|
// authMethodKey is the context key for storing auth method from JWT
|
|
const authMethodKey contextKey = "auth.method"
|
|
|
|
// pullerDIDKey is the context key for storing the authenticated user's DID from JWT
|
|
const pullerDIDKey contextKey = "puller.did"
|
|
|
|
// hasPushScopeKey is the context key for storing whether the JWT has push scope
|
|
const hasPushScopeKey contextKey = "token.has_push_scope"
|
|
|
|
// Lifetimes of cached fetch outcomes.
const (
	// validationCacheTokenTTL is how long a successfully fetched service token
	// is served from the cache (covers a typical Docker push operation).
	validationCacheTokenTTL = 45 * time.Second

	// validationCacheErrorTTL is how long a failed fetch is remembered so that
	// subsequent requests fail fast instead of retrying immediately.
	validationCacheErrorTTL = 5 * time.Second
)

// validationCacheEntry stores the outcome of the most recent service-token
// fetch for one cache key, together with the state needed to let concurrent
// callers share a single fetch.
type validationCacheEntry struct {
	serviceToken string        // token from the last successful fetch; empty after a failure
	validUntil   time.Time     // instant after which the cached token or error is stale
	err          error         // error from the last failed fetch (cached for fast-fail)
	mu           sync.Mutex    // guards every other field of the entry
	inFlight     bool          // true while one goroutine is running the fetch
	done         chan struct{} // closed when the in-flight fetch completes
}

// validationCache provides request-level caching for service tokens.
// This prevents concurrent layer uploads from racing on OAuth/DPoP requests.
type validationCache struct {
	mu      sync.RWMutex                     // guards entries (the map itself, not the entries' fields)
	entries map[string]*validationCacheEntry // key: "did:holdDID"
}

// newValidationCache creates an empty validation cache ready for use.
func newValidationCache() *validationCache {
	return &validationCache{
		entries: make(map[string]*validationCacheEntry),
	}
}

// entryFor returns the entry for cacheKey, creating an empty one if needed.
// A new entry starts idle and already expired, so the first caller that locks
// it becomes the fetcher. Entries are never removed, so the returned pointer
// stays the canonical entry for the key.
func (vc *validationCache) entryFor(cacheKey string) *validationCacheEntry {
	vc.mu.RLock()
	entry, ok := vc.entries[cacheKey]
	vc.mu.RUnlock()
	if ok {
		return entry
	}

	vc.mu.Lock()
	defer vc.mu.Unlock()
	// Re-check under the write lock: another goroutine may have created it.
	if entry, ok = vc.entries[cacheKey]; ok {
		return entry
	}
	entry = &validationCacheEntry{}
	vc.entries[cacheKey] = entry
	return entry
}

// getOrFetch returns the cached service token (or cached error) for cacheKey,
// calling fetchFunc when nothing fresh is cached.
//
// At most one goroutine runs fetchFunc for a given key at a time. Every other
// caller waits for that fetch to finish and then re-evaluates the entry, so
// concurrent requests for the same DID:holdDID share one fetch. A waiting
// caller gives up with ctx.Err() if ctx is done first; ctx is not consulted
// when a fresh result is already cached or when the caller itself performs
// the fetch.
//
// Successful tokens are cached for validationCacheTokenTTL and errors for
// validationCacheErrorTTL. The fetching caller receives fetchFunc's return
// values unchanged.
func (vc *validationCache) getOrFetch(ctx context.Context, cacheKey string, fetchFunc func() (string, error)) (string, error) {
	entry := vc.entryFor(cacheKey)

	for {
		entry.mu.Lock()

		if entry.inFlight {
			// Someone else is fetching: wait on the channel that belongs to
			// that fetch, without holding the lock.
			done := entry.done
			entry.mu.Unlock()

			select {
			case <-done:
				// Re-evaluate from scratch; the result may be fresh, or a new
				// fetch may already have started.
				continue
			case <-ctx.Done():
				return "", ctx.Err()
			}
		}

		if time.Now().Before(entry.validUntil) {
			// Fresh cached outcome. serviceToken is "" whenever err is set.
			serviceToken, err := entry.serviceToken, entry.err
			entry.mu.Unlock()
			return serviceToken, err
		}

		// Nothing fresh and nobody fetching: this goroutine becomes the fetcher.
		// The channel is published under the same lock hold that sets inFlight,
		// so every waiter observes the channel that will actually be closed.
		done := make(chan struct{})
		entry.done = done
		entry.inFlight = true
		entry.mu.Unlock()

		// Perform the fetch outside the lock so waiters can observe ctx.
		serviceToken, err := fetchFunc()

		entry.mu.Lock()
		entry.inFlight = false
		if err != nil {
			entry.err = err
			entry.serviceToken = ""
			entry.validUntil = time.Now().Add(validationCacheErrorTTL)
		} else {
			entry.err = nil
			entry.serviceToken = serviceToken
			entry.validUntil = time.Now().Add(validationCacheTokenTTL)
		}
		// Wake the waiters; they re-lock the entry and read the stored result.
		close(done)
		entry.mu.Unlock()

		return serviceToken, err
	}
}
|
|
|
|
// Global variables for initialization only
|
|
// These are set by main.go during startup and copied into NamespaceResolver instances.
|
|
// After initialization, request handling uses the NamespaceResolver's instance fields.
|
|
var (
|
|
globalRefresher *oauth.Refresher
|
|
globalDatabase storage.HoldDIDLookup
|
|
globalAuthorizer auth.HoldAuthorizer
|
|
globalWebhookDispatcher storage.PushWebhookDispatcher
|
|
globalManifestRefChecker storage.ManifestReferenceChecker
|
|
)
|
|
|
|
// SetGlobalRefresher sets the OAuth refresher instance during initialization
|
|
// Must be called before the registry starts serving requests
|
|
func SetGlobalRefresher(refresher *oauth.Refresher) {
|
|
globalRefresher = refresher
|
|
}
|
|
|
|
// SetGlobalDatabase sets the database instance during initialization
|
|
// Must be called before the registry starts serving requests
|
|
func SetGlobalDatabase(database storage.HoldDIDLookup) {
|
|
globalDatabase = database
|
|
}
|
|
|
|
// SetGlobalManifestRefChecker sets the manifest reference checker during initialization
|
|
func SetGlobalManifestRefChecker(checker storage.ManifestReferenceChecker) {
|
|
globalManifestRefChecker = checker
|
|
}
|
|
|
|
// SetGlobalAuthorizer sets the authorizer instance during initialization
|
|
// Must be called before the registry starts serving requests
|
|
func SetGlobalAuthorizer(authorizer auth.HoldAuthorizer) {
|
|
globalAuthorizer = authorizer
|
|
}
|
|
|
|
// SetGlobalWebhookDispatcher sets the push webhook dispatcher during initialization
|
|
// Must be called before the registry starts serving requests
|
|
func SetGlobalWebhookDispatcher(dispatcher storage.PushWebhookDispatcher) {
|
|
globalWebhookDispatcher = dispatcher
|
|
}
|
|
|
|
// GetGlobalAuthorizer returns the global authorizer instance
|
|
// Used by components that need to clear denial cache (e.g., EnsureCrewMembership)
|
|
func GetGlobalAuthorizer() auth.HoldAuthorizer {
|
|
return globalAuthorizer
|
|
}
|
|
|
|
func init() {
|
|
// Register the name resolution middleware
|
|
if err := registrymw.Register("atproto-resolver", initATProtoResolver); err != nil {
|
|
panic("failed to register atproto-resolver middleware: " + err.Error())
|
|
}
|
|
}
|
|
|
|
// NamespaceResolver wraps a namespace and resolves names
|
|
type NamespaceResolver struct {
|
|
distribution.Namespace
|
|
defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io")
|
|
baseURL string // Base URL for error messages (e.g., "https://atcr.io")
|
|
testMode bool // If true, fallback to default hold when user's hold is unreachable
|
|
refresher *oauth.Refresher // OAuth session manager (copied from global on init)
|
|
database storage.HoldDIDLookup // Database for hold DID lookups (copied from global on init)
|
|
authorizer auth.HoldAuthorizer // Hold authorization (copied from global on init)
|
|
webhookDispatcher storage.PushWebhookDispatcher // Push webhook dispatcher (copied from global on init)
|
|
manifestRefChecker storage.ManifestReferenceChecker // Manifest reference checker (copied from global on init)
|
|
validationCache *validationCache // Request-level service token cache
|
|
readmeFetcher *readme.Fetcher // README fetcher for repo pages
|
|
}
|
|
|
|
// initATProtoResolver initializes the name resolution middleware
|
|
func initATProtoResolver(ctx context.Context, ns distribution.Namespace, _ driver.StorageDriver, options map[string]any) (distribution.Namespace, error) {
|
|
// Get default hold DID from config (required)
|
|
// Expected format: "did:web:hold01.atcr.io"
|
|
defaultHoldDID := ""
|
|
if holdDID, ok := options["default_hold_did"].(string); ok {
|
|
defaultHoldDID = holdDID
|
|
}
|
|
|
|
// Get base URL from config (for error messages)
|
|
baseURL := ""
|
|
if url, ok := options["base_url"].(string); ok {
|
|
baseURL = url
|
|
}
|
|
|
|
// Check test mode from options (passed via env var)
|
|
testMode := false
|
|
if tm, ok := options["test_mode"].(bool); ok {
|
|
testMode = tm
|
|
}
|
|
|
|
// Copy shared services from globals into the instance
|
|
// This avoids accessing globals during request handling
|
|
return &NamespaceResolver{
|
|
Namespace: ns,
|
|
defaultHoldDID: defaultHoldDID,
|
|
baseURL: baseURL,
|
|
testMode: testMode,
|
|
refresher: globalRefresher,
|
|
database: globalDatabase,
|
|
authorizer: globalAuthorizer,
|
|
webhookDispatcher: globalWebhookDispatcher,
|
|
manifestRefChecker: globalManifestRefChecker,
|
|
validationCache: newValidationCache(),
|
|
readmeFetcher: readme.NewFetcher(),
|
|
}, nil
|
|
}
|
|
|
|
// authErrorMessage creates a user-friendly auth error with login URL
|
|
func (nr *NamespaceResolver) authErrorMessage(message string) error {
|
|
loginURL := fmt.Sprintf("%s/auth/oauth/login", nr.baseURL)
|
|
fullMessage := fmt.Sprintf("%s - please re-authenticate at %s", message, loginURL)
|
|
return errcode.ErrorCodeUnauthorized.WithMessage(fullMessage)
|
|
}
|
|
|
|
// Repository resolves the repository name and delegates to underlying namespace
|
|
// Handles names like:
|
|
// - atcr.io/alice/myimage → resolve alice to DID
|
|
// - atcr.io/did:plc:xyz123/myimage → use DID directly
|
|
func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Named) (distribution.Repository, error) {
|
|
// Extract the first part of the name (username or DID)
|
|
repoPath := name.Name()
|
|
parts := strings.SplitN(repoPath, "/", 2)
|
|
|
|
if len(parts) < 2 {
|
|
// No user specified, use default or return error
|
|
return nil, fmt.Errorf("repository name must include user: %s", repoPath)
|
|
}
|
|
|
|
identityStr := parts[0]
|
|
imageName := parts[1]
|
|
|
|
// Support hyphen-encoded DIDs in image paths (e.g., did-plc-abc123/repo:tag)
|
|
// OCI reference grammar doesn't allow colons in path components, so DIDs must
|
|
// be encoded with hyphens instead: did:plc:abc123 → did-plc-abc123
|
|
if decoded, ok := auth.DecodeDIDFromHyphens(identityStr); ok {
|
|
identityStr = decoded
|
|
}
|
|
|
|
// Resolve identity to DID, handle, and PDS endpoint
|
|
did, handle, pdsEndpoint, err := atproto.ResolveIdentity(ctx, identityStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
slog.Debug("Resolved identity", "component", "registry/middleware", "did", did, "pds", pdsEndpoint, "handle", handle)
|
|
|
|
// Query for hold DID - either user's hold or default hold service
|
|
// Also returns the sailor profile so we can read preferences (e.g. AutoRemoveUntagged)
|
|
holdDID, sailorProfile := nr.findHoldDIDAndProfile(ctx, did, pdsEndpoint)
|
|
if holdDID == "" {
|
|
// This is a fatal configuration error - registry cannot function without a hold service
|
|
return nil, fmt.Errorf("no hold DID configured: ensure default_hold_did is set in middleware config")
|
|
}
|
|
|
|
// Single-hop hold migration: check if this hold has declared a successor
|
|
holdDID = nr.resolveSuccessor(ctx, holdDID)
|
|
|
|
// Resolve hold DID to HTTP URL via identity directory (cached 24h)
|
|
holdURL, err := atproto.ResolveHoldURL(ctx, holdDID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to resolve hold URL for %s: %w", holdDID, err)
|
|
}
|
|
|
|
// Auto-reconcile crew membership on first push/pull
|
|
// This ensures users can push immediately after docker login without web sign-in
|
|
// EnsureCrewMembership is best-effort and logs errors without failing the request
|
|
// Run synchronously to ensure crew record exists before write access checks
|
|
// (returns quickly if already a member - hold returns 200/201)
|
|
if holdDID != "" && nr.refresher != nil {
|
|
slog.Debug("Auto-reconciling crew membership", "component", "registry/middleware", "did", did, "hold_did", holdDID)
|
|
client := atproto.NewClient(pdsEndpoint, did, "")
|
|
storage.EnsureCrewMembership(ctx, client, nr.refresher, holdDID, nr.authorizer)
|
|
}
|
|
|
|
// Get service token for hold authentication (only if authenticated)
|
|
// Use validation cache to prevent concurrent requests from racing on OAuth/DPoP
|
|
// Route based on auth method from JWT token
|
|
// IMPORTANT: Use PULLER's DID/PDS for service token, not owner's!
|
|
// The puller (authenticated user) needs to authenticate to the hold service.
|
|
var serviceToken string
|
|
authMethod, _ := ctx.Value(authMethodKey).(string)
|
|
pullerDID, _ := ctx.Value(pullerDIDKey).(string)
|
|
hasPushScope, _ := ctx.Value(hasPushScopeKey).(bool)
|
|
var pullerPDSEndpoint string
|
|
|
|
// Only fetch service token if user is authenticated
|
|
// Unauthenticated requests (like /v2/ ping) should not trigger token fetching
|
|
if authMethod != "" && pullerDID != "" {
|
|
// Resolve puller's PDS endpoint for service token request
|
|
_, _, pullerPDSEndpoint, err = atproto.ResolveIdentity(ctx, pullerDID)
|
|
if err != nil {
|
|
slog.Warn("Failed to resolve puller's PDS, falling back to anonymous access",
|
|
"component", "registry/middleware",
|
|
"pullerDID", pullerDID,
|
|
"error", err)
|
|
// Continue without service token - hold will decide if anonymous access is allowed
|
|
} else {
|
|
// Create cache key: "pullerDID:holdDID"
|
|
cacheKey := fmt.Sprintf("%s:%s", pullerDID, holdDID)
|
|
|
|
// Fetch service token through validation cache
|
|
// This ensures only ONE request per pullerDID:holdDID pair fetches the token
|
|
// Concurrent requests will wait for the first request to complete
|
|
var fetchErr error
|
|
serviceToken, fetchErr = nr.validationCache.getOrFetch(ctx, cacheKey, func() (string, error) {
|
|
if authMethod == token.AuthMethodAppPassword {
|
|
// App-password flow: use Bearer token authentication
|
|
slog.Debug("Using app-password flow for service token",
|
|
"component", "registry/middleware",
|
|
"pullerDID", pullerDID,
|
|
"cacheKey", cacheKey)
|
|
|
|
token, err := auth.GetOrFetchServiceTokenWithAppPassword(ctx, pullerDID, holdDID, pullerPDSEndpoint)
|
|
if err != nil {
|
|
slog.Error("Failed to get service token with app-password",
|
|
"component", "registry/middleware",
|
|
"pullerDID", pullerDID,
|
|
"holdDID", holdDID,
|
|
"pullerPDSEndpoint", pullerPDSEndpoint,
|
|
"denial_reason", "service_token_app_password_failed",
|
|
"error", err)
|
|
return "", err
|
|
}
|
|
return token, nil
|
|
} else if nr.refresher != nil {
|
|
// OAuth flow: use DPoP authentication
|
|
slog.Debug("Using OAuth flow for service token",
|
|
"component", "registry/middleware",
|
|
"pullerDID", pullerDID,
|
|
"cacheKey", cacheKey)
|
|
|
|
token, err := auth.GetOrFetchServiceToken(ctx, nr.refresher, pullerDID, holdDID, pullerPDSEndpoint)
|
|
if err != nil {
|
|
slog.Error("Failed to get service token with OAuth",
|
|
"component", "registry/middleware",
|
|
"pullerDID", pullerDID,
|
|
"holdDID", holdDID,
|
|
"pullerPDSEndpoint", pullerPDSEndpoint,
|
|
"denial_reason", "service_token_oauth_failed",
|
|
"error", err)
|
|
return "", err
|
|
}
|
|
return token, nil
|
|
}
|
|
return "", fmt.Errorf("no authentication method available")
|
|
})
|
|
|
|
// Handle errors from cached fetch
|
|
if fetchErr != nil {
|
|
errMsg := fetchErr.Error()
|
|
|
|
// Check for app-password specific errors
|
|
if authMethod == token.AuthMethodAppPassword {
|
|
if strings.Contains(errMsg, "expired or invalid") || strings.Contains(errMsg, "no app-password") {
|
|
return nil, nr.authErrorMessage("App-password authentication failed. Please re-authenticate with: docker login")
|
|
}
|
|
}
|
|
|
|
// Check for OAuth specific errors
|
|
if strings.Contains(errMsg, "OAuth session") || strings.Contains(errMsg, "OAuth validation") {
|
|
return nil, nr.authErrorMessage("OAuth session expired or invalidated by PDS. Your session has been cleared")
|
|
}
|
|
|
|
// Generic service token error
|
|
return nil, nr.authErrorMessage(fmt.Sprintf("Failed to obtain storage credentials: %v", fetchErr))
|
|
}
|
|
}
|
|
} else {
|
|
slog.Debug("Skipping service token fetch for unauthenticated request",
|
|
"component", "registry/middleware",
|
|
"ownerDID", did)
|
|
}
|
|
|
|
// Create a new reference with identity/image format
|
|
// Use the resolved handle (not raw DID) to ensure the name is valid per OCI reference grammar.
|
|
// DIDs contain colons which are illegal in reference path components.
|
|
// This transforms: did-plc-abc123/myimage -> alice.bsky.social/myimage
|
|
canonicalName := fmt.Sprintf("%s/%s", handle, imageName)
|
|
ref, err := reference.ParseNamed(canonicalName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid image name %s: %w", imageName, err)
|
|
}
|
|
|
|
// Delegate to underlying namespace with modified name
|
|
repo, err := nr.Namespace.Repository(ctx, ref)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create ATProto client for manifest/tag operations
|
|
// Pulls: ATProto records are public, no auth needed
|
|
// Pushes: Need auth, but puller must be owner anyway
|
|
var atprotoClient *atproto.Client
|
|
|
|
if pullerDID == did {
|
|
// Puller is owner - may need auth for pushes
|
|
if authMethod == token.AuthMethodOAuth && nr.refresher != nil {
|
|
atprotoClient = atproto.NewClientWithSessionProvider(pdsEndpoint, did, nr.refresher)
|
|
} else if authMethod == token.AuthMethodAppPassword {
|
|
accessToken, _ := auth.GetGlobalTokenCache().Get(did)
|
|
atprotoClient = atproto.NewClient(pdsEndpoint, did, accessToken)
|
|
} else {
|
|
atprotoClient = atproto.NewClient(pdsEndpoint, did, "")
|
|
}
|
|
} else {
|
|
// Puller != owner - reads only, no auth needed
|
|
atprotoClient = atproto.NewClient(pdsEndpoint, did, "")
|
|
}
|
|
|
|
// IMPORTANT: Use only the image name (not identity/image) for ATProto storage
|
|
// ATProto records are scoped to the user's DID, so we don't need the identity prefix
|
|
// Example: "evan.jarrett.net/debian" -> store as "debian"
|
|
repositoryName := imageName
|
|
|
|
// Default auth method to OAuth if not already set (backward compatibility with old tokens)
|
|
if authMethod == "" {
|
|
authMethod = token.AuthMethodOAuth
|
|
}
|
|
|
|
// Create routing repository - routes manifests to ATProto, blobs to hold service
|
|
// The registry is stateless - no local storage is used
|
|
// Bundle all context into a single RegistryContext struct
|
|
//
|
|
// NOTE: We create a fresh RoutingRepository on every request (no caching) because:
|
|
// 1. Each layer upload is a separate HTTP request (possibly different process)
|
|
// 2. OAuth sessions can be refreshed/invalidated between requests
|
|
// 3. The refresher already caches sessions efficiently (in-memory + DB)
|
|
// 4. Caching the repository with a stale ATProtoClient causes refresh token errors
|
|
registryCtx := &storage.RegistryContext{
|
|
DID: did,
|
|
Handle: handle,
|
|
HoldDID: holdDID,
|
|
HoldURL: holdURL,
|
|
PDSEndpoint: pdsEndpoint,
|
|
Repository: repositoryName,
|
|
ServiceToken: serviceToken, // Cached service token from puller's PDS
|
|
ATProtoClient: atprotoClient,
|
|
AuthMethod: authMethod, // Auth method from JWT token
|
|
PullerDID: pullerDID, // Authenticated user making the request
|
|
PullerPDSEndpoint: pullerPDSEndpoint, // Puller's PDS for service token refresh
|
|
HasPushScope: hasPushScope, // Whether JWT has push scope (for pull stats filtering)
|
|
AutoRemoveUntagged: sailorProfile != nil && sailorProfile.AutoRemoveUntagged,
|
|
Database: nr.database,
|
|
Authorizer: nr.authorizer,
|
|
Refresher: nr.refresher,
|
|
ReadmeFetcher: nr.readmeFetcher,
|
|
WebhookDispatcher: nr.webhookDispatcher,
|
|
ManifestRefChecker: nr.manifestRefChecker,
|
|
}
|
|
|
|
return storage.NewRoutingRepository(repo, registryCtx), nil
|
|
}
|
|
|
|
// Repositories delegates to underlying namespace
|
|
func (nr *NamespaceResolver) Repositories(ctx context.Context, repos []string, last string) (int, error) {
|
|
return nr.Namespace.Repositories(ctx, repos, last)
|
|
}
|
|
|
|
// Blobs delegates to underlying namespace
|
|
func (nr *NamespaceResolver) Blobs() distribution.BlobEnumerator {
|
|
return nr.Namespace.Blobs()
|
|
}
|
|
|
|
// BlobStatter delegates to underlying namespace
|
|
func (nr *NamespaceResolver) BlobStatter() distribution.BlobStatter {
|
|
return nr.Namespace.BlobStatter()
|
|
}
|
|
|
|
// findHoldDIDAndProfile determines which hold DID to use for blob storage and
|
|
// returns the user's sailor profile (if available) for reading preferences like
|
|
// AutoRemoveUntagged without an extra PDS call.
|
|
// Priority order:
|
|
// 1. User's sailor profile defaultHold (if set)
|
|
// 2. AppView's default hold DID
|
|
// Returns a hold DID (e.g., "did:web:hold01.atcr.io"), or empty string if none configured
|
|
func (nr *NamespaceResolver) findHoldDIDAndProfile(ctx context.Context, did, pdsEndpoint string) (string, *atproto.SailorProfileRecord) {
|
|
// Create ATProto client (without auth - reading public records)
|
|
client := atproto.NewClient(pdsEndpoint, did, "")
|
|
|
|
// Check for sailor profile
|
|
profile, err := storage.GetProfile(ctx, client)
|
|
if err != nil {
|
|
// Error reading profile (not a 404) - log and continue
|
|
slog.Warn("Failed to read profile", "did", did, "error", err)
|
|
}
|
|
|
|
if profile != nil && profile.DefaultHold != "" {
|
|
// Profile exists with defaultHold set
|
|
// In test mode, verify it's reachable before using it
|
|
if nr.testMode {
|
|
if nr.isHoldReachable(ctx, profile.DefaultHold) {
|
|
return profile.DefaultHold, profile
|
|
}
|
|
slog.Debug("User's defaultHold unreachable, falling back to default", "component", "registry/middleware/testmode", "default_hold", profile.DefaultHold)
|
|
return nr.defaultHoldDID, profile
|
|
}
|
|
return profile.DefaultHold, profile
|
|
}
|
|
|
|
// No profile defaultHold - use AppView default
|
|
return nr.defaultHoldDID, profile
|
|
}
|
|
|
|
// resolveSuccessor checks if a hold has declared a successor and returns it.
|
|
// Single-hop only — does not follow chains. Returns the original holdDID if
|
|
// no successor is set or if the captain record can't be fetched.
|
|
func (nr *NamespaceResolver) resolveSuccessor(ctx context.Context, holdDID string) string {
|
|
if nr.authorizer == nil {
|
|
return holdDID
|
|
}
|
|
captain, err := nr.authorizer.GetCaptainRecord(ctx, holdDID)
|
|
if err != nil {
|
|
return holdDID
|
|
}
|
|
if captain != nil && captain.Successor != "" {
|
|
slog.Info("Hold successor redirect",
|
|
"component", "registry/middleware",
|
|
"from", holdDID,
|
|
"to", captain.Successor)
|
|
return captain.Successor
|
|
}
|
|
return holdDID
|
|
}
|
|
|
|
// isHoldReachable checks if a hold service is reachable
|
|
// Used in test mode to fallback to default hold when user's hold is unavailable
|
|
func (nr *NamespaceResolver) isHoldReachable(ctx context.Context, holdDID string) bool {
|
|
holdURL, err := atproto.ResolveHoldURL(ctx, holdDID)
|
|
if err != nil {
|
|
slog.Debug("Cannot resolve hold URL for reachability check", "component", "registry/middleware", "holdDID", holdDID, "error", err)
|
|
return false
|
|
}
|
|
|
|
testURL := holdURL + "/.well-known/did.json"
|
|
client := atproto.NewClient("", "", "")
|
|
_, err = client.FetchDIDDocument(ctx, testURL)
|
|
return err == nil
|
|
}
|
|
|
|
// ExtractAuthMethod is an HTTP middleware that extracts the auth method and puller DID from the JWT Authorization header
|
|
// and stores them in the request context for later use by the registry middleware.
|
|
// Also stores the HTTP method for routing decisions (GET/HEAD = pull, PUT/POST = push).
|
|
func ExtractAuthMethod(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
// Store HTTP method in context for routing decisions
|
|
// This is used by routing_repository.go to distinguish pull (GET/HEAD) from push (PUT/POST)
|
|
ctx = context.WithValue(ctx, storage.HTTPRequestMethod, r.Method)
|
|
|
|
// Extract Authorization header
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader != "" {
|
|
// Parse "Bearer <token>" format
|
|
parts := strings.SplitN(authHeader, " ", 2)
|
|
if len(parts) == 2 && strings.ToLower(parts[0]) == "bearer" {
|
|
tokenString := parts[1]
|
|
|
|
// Extract auth method from JWT (does not validate - just parses)
|
|
authMethod := token.ExtractAuthMethod(tokenString)
|
|
if authMethod != "" {
|
|
// Store in context for registry middleware
|
|
ctx = context.WithValue(ctx, authMethodKey, authMethod)
|
|
}
|
|
|
|
// Extract puller DID (Subject) from JWT
|
|
// This is the authenticated user's DID, used for service token requests
|
|
pullerDID := token.ExtractSubject(tokenString)
|
|
if pullerDID != "" {
|
|
ctx = context.WithValue(ctx, pullerDIDKey, pullerDID)
|
|
}
|
|
|
|
// Extract access scopes from JWT to detect push-scoped tokens
|
|
// Used to distinguish real pulls from manifest GETs during push/imagetools flows
|
|
access := token.ExtractAccess(tokenString)
|
|
if token.HasPushScope(access) {
|
|
ctx = context.WithValue(ctx, hasPushScopeKey, true)
|
|
}
|
|
|
|
slog.Debug("Extracted auth info from JWT",
|
|
"component", "registry/middleware",
|
|
"authMethod", authMethod,
|
|
"pullerDID", pullerDID,
|
|
"hasPushScope", token.HasPushScope(access),
|
|
"httpMethod", r.Method)
|
|
}
|
|
}
|
|
|
|
r = r.WithContext(ctx)
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|