Compare commits
1 Commits
label-serv
...
refactor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
31dc4b4f53 |
@@ -150,8 +150,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
middleware.SetGlobalRefresher(refresher)
|
||||
|
||||
// Set global database for pull/push metrics tracking
|
||||
metricsDB := db.NewMetricsDB(uiDatabase)
|
||||
middleware.SetGlobalDatabase(metricsDB)
|
||||
middleware.SetGlobalDatabase(uiDatabase)
|
||||
|
||||
// Create RemoteHoldAuthorizer for hold authorization with caching
|
||||
holdAuthorizer := auth.NewRemoteHoldAuthorizer(uiDatabase, testMode)
|
||||
@@ -191,6 +190,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
HealthChecker: healthChecker,
|
||||
ReadmeFetcher: readmeFetcher,
|
||||
Templates: uiTemplates,
|
||||
DefaultHoldDID: defaultHoldDID,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -212,14 +212,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
// Create ATProto client with session provider (uses DoWithSession for DPoP nonce safety)
|
||||
client := atproto.NewClientWithSessionProvider(pdsEndpoint, did, refresher)
|
||||
|
||||
// Ensure sailor profile exists (creates with default hold if configured)
|
||||
slog.Debug("Ensuring profile exists", "component", "appview/callback", "did", did, "default_hold_did", defaultHoldDID)
|
||||
if err := storage.EnsureProfile(ctx, client, defaultHoldDID); err != nil {
|
||||
slog.Warn("Failed to ensure profile", "component", "appview/callback", "did", did, "error", err)
|
||||
// Continue anyway - profile creation is not critical for avatar fetch
|
||||
} else {
|
||||
slog.Debug("Profile ensured", "component", "appview/callback", "did", did)
|
||||
}
|
||||
// Note: Profile and crew setup now happen automatically via UserContext.EnsureUserSetup()
|
||||
|
||||
// Fetch user's profile record from PDS (contains blob references)
|
||||
profileRecord, err := client.GetProfileRecord(ctx, did)
|
||||
@@ -270,7 +263,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
return nil // Non-fatal
|
||||
}
|
||||
|
||||
var holdDID string
|
||||
// Migrate profile URL→DID if needed (legacy migration, crew registration now handled by UserContext)
|
||||
if profile != nil && profile.DefaultHold != "" {
|
||||
// Check if defaultHold is a URL (needs migration)
|
||||
if strings.HasPrefix(profile.DefaultHold, "http://") || strings.HasPrefix(profile.DefaultHold, "https://") {
|
||||
@@ -286,19 +279,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
} else {
|
||||
slog.Debug("Updated profile with hold DID", "component", "appview/callback", "hold_did", holdDID)
|
||||
}
|
||||
} else {
|
||||
// Already a DID - use it
|
||||
holdDID = profile.DefaultHold
|
||||
}
|
||||
// Register crew regardless of migration (outside the migration block)
|
||||
// Run in background to avoid blocking OAuth callback if hold is offline
|
||||
// Use background context - don't inherit request context which gets canceled on response
|
||||
slog.Debug("Attempting crew registration", "component", "appview/callback", "did", did, "hold_did", holdDID)
|
||||
go func(client *atproto.Client, refresher *oauth.Refresher, holdDID string) {
|
||||
ctx := context.Background()
|
||||
storage.EnsureCrewMembership(ctx, client, refresher, holdDID)
|
||||
}(client, refresher, holdDID)
|
||||
|
||||
}
|
||||
|
||||
return nil // All errors are non-fatal, logged for debugging
|
||||
@@ -320,10 +301,19 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
ctx := context.Background()
|
||||
app := handlers.NewApp(ctx, cfg.Distribution)
|
||||
|
||||
// Wrap registry app with auth method extraction middleware
|
||||
// This extracts the auth method from the JWT and stores it in the request context
|
||||
// Wrap registry app with middleware chain:
|
||||
// 1. ExtractAuthMethod - extracts auth method from JWT and stores in context
|
||||
// 2. UserContextMiddleware - builds UserContext with identity, permissions, service tokens
|
||||
wrappedApp := middleware.ExtractAuthMethod(app)
|
||||
|
||||
// Create dependencies for UserContextMiddleware
|
||||
userContextDeps := &auth.Dependencies{
|
||||
Refresher: refresher,
|
||||
Authorizer: holdAuthorizer,
|
||||
DefaultHoldDID: defaultHoldDID,
|
||||
}
|
||||
wrappedApp = middleware.UserContextMiddleware(userContextDeps)(wrappedApp)
|
||||
|
||||
// Mount registry at /v2/
|
||||
mainRouter.Handle("/v2/*", wrappedApp)
|
||||
|
||||
@@ -412,23 +402,11 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
// Prevents the flood of errors when a stale session is discovered during push
|
||||
tokenHandler.SetOAuthSessionValidator(refresher)
|
||||
|
||||
// Register token post-auth callback for profile management
|
||||
// This decouples the token package from AppView-specific dependencies
|
||||
// Register token post-auth callback
|
||||
// Note: Profile and crew setup now happen automatically via UserContext.EnsureUserSetup()
|
||||
tokenHandler.SetPostAuthCallback(func(ctx context.Context, did, handle, pdsEndpoint, accessToken string) error {
|
||||
slog.Debug("Token post-auth callback", "component", "appview/callback", "did", did)
|
||||
|
||||
// Create ATProto client with validated token
|
||||
atprotoClient := atproto.NewClient(pdsEndpoint, did, accessToken)
|
||||
|
||||
// Ensure profile exists (will create with default hold if not exists and default is configured)
|
||||
if err := storage.EnsureProfile(ctx, atprotoClient, defaultHoldDID); err != nil {
|
||||
// Log error but don't fail auth - profile management is not critical
|
||||
slog.Warn("Failed to ensure profile", "component", "appview/callback", "did", did, "error", err)
|
||||
} else {
|
||||
slog.Debug("Profile ensured with default hold", "component", "appview/callback", "did", did, "default_hold_did", defaultHoldDID)
|
||||
}
|
||||
|
||||
return nil // All errors are non-fatal
|
||||
return nil
|
||||
})
|
||||
|
||||
mainRouter.Get("/auth/token", tokenHandler.ServeHTTP)
|
||||
|
||||
@@ -1634,31 +1634,6 @@ func parseTimestamp(s string) (time.Time, error) {
|
||||
return time.Time{}, fmt.Errorf("unable to parse timestamp: %s", s)
|
||||
}
|
||||
|
||||
// MetricsDB wraps a sql.DB and implements the metrics interface for middleware
|
||||
type MetricsDB struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewMetricsDB creates a new metrics database wrapper
|
||||
func NewMetricsDB(db *sql.DB) *MetricsDB {
|
||||
return &MetricsDB{db: db}
|
||||
}
|
||||
|
||||
// IncrementPullCount increments the pull count for a repository
|
||||
func (m *MetricsDB) IncrementPullCount(did, repository string) error {
|
||||
return IncrementPullCount(m.db, did, repository)
|
||||
}
|
||||
|
||||
// IncrementPushCount increments the push count for a repository
|
||||
func (m *MetricsDB) IncrementPushCount(did, repository string) error {
|
||||
return IncrementPushCount(m.db, did, repository)
|
||||
}
|
||||
|
||||
// GetLatestHoldDIDForRepo returns the hold DID from the most recent manifest for a repository
|
||||
func (m *MetricsDB) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
|
||||
return GetLatestHoldDIDForRepo(m.db, did, repository)
|
||||
}
|
||||
|
||||
// GetFeaturedRepositories fetches top repositories sorted by stars and pulls
|
||||
func GetFeaturedRepositories(db *sql.DB, limit int, currentUserDID string) ([]FeaturedRepository, error) {
|
||||
query := `
|
||||
|
||||
@@ -11,14 +11,32 @@ import (
|
||||
"net/url"
|
||||
|
||||
"atcr.io/pkg/appview/db"
|
||||
"atcr.io/pkg/auth"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const userKey contextKey = "user"
|
||||
|
||||
// WebAuthDeps contains dependencies for web auth middleware
|
||||
type WebAuthDeps struct {
|
||||
SessionStore *db.SessionStore
|
||||
Database *sql.DB
|
||||
Refresher *oauth.Refresher
|
||||
DefaultHoldDID string
|
||||
}
|
||||
|
||||
// RequireAuth is middleware that requires authentication
|
||||
func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) http.Handler {
|
||||
return RequireAuthWithDeps(WebAuthDeps{
|
||||
SessionStore: store,
|
||||
Database: database,
|
||||
})
|
||||
}
|
||||
|
||||
// RequireAuthWithDeps is middleware that requires authentication and creates UserContext
|
||||
func RequireAuthWithDeps(deps WebAuthDeps) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID, ok := getSessionID(r)
|
||||
@@ -32,7 +50,7 @@ func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) ht
|
||||
return
|
||||
}
|
||||
|
||||
sess, ok := store.Get(sessionID)
|
||||
sess, ok := deps.SessionStore.Get(sessionID)
|
||||
if !ok {
|
||||
// Build return URL with query parameters preserved
|
||||
returnTo := r.URL.Path
|
||||
@@ -44,7 +62,7 @@ func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) ht
|
||||
}
|
||||
|
||||
// Look up full user from database to get avatar
|
||||
user, err := db.GetUserByDID(database, sess.DID)
|
||||
user, err := db.GetUserByDID(deps.Database, sess.DID)
|
||||
if err != nil || user == nil {
|
||||
// Fallback to session data if DB lookup fails
|
||||
user = &db.User{
|
||||
@@ -54,7 +72,20 @@ func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) ht
|
||||
}
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), userKey, user)
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, userKey, user)
|
||||
|
||||
// Create UserContext for authenticated users (enables EnsureUserSetup)
|
||||
if deps.Refresher != nil {
|
||||
userCtx := auth.NewUserContext(sess.DID, auth.AuthMethodOAuth, r.Method, &auth.Dependencies{
|
||||
Refresher: deps.Refresher,
|
||||
DefaultHoldDID: deps.DefaultHoldDID,
|
||||
})
|
||||
userCtx.SetPDS(sess.Handle, sess.PDSEndpoint)
|
||||
userCtx.EnsureUserSetup()
|
||||
ctx = auth.WithUserContext(ctx, userCtx)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
@@ -62,13 +93,21 @@ func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) ht
|
||||
|
||||
// OptionalAuth is middleware that optionally includes user if authenticated
|
||||
func OptionalAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) http.Handler {
|
||||
return OptionalAuthWithDeps(WebAuthDeps{
|
||||
SessionStore: store,
|
||||
Database: database,
|
||||
})
|
||||
}
|
||||
|
||||
// OptionalAuthWithDeps is middleware that optionally includes user and UserContext if authenticated
|
||||
func OptionalAuthWithDeps(deps WebAuthDeps) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID, ok := getSessionID(r)
|
||||
if ok {
|
||||
if sess, ok := store.Get(sessionID); ok {
|
||||
if sess, ok := deps.SessionStore.Get(sessionID); ok {
|
||||
// Look up full user from database to get avatar
|
||||
user, err := db.GetUserByDID(database, sess.DID)
|
||||
user, err := db.GetUserByDID(deps.Database, sess.DID)
|
||||
if err != nil || user == nil {
|
||||
// Fallback to session data if DB lookup fails
|
||||
user = &db.User{
|
||||
@@ -77,7 +116,21 @@ func OptionalAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) h
|
||||
PDSEndpoint: sess.PDSEndpoint,
|
||||
}
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), userKey, user)
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, userKey, user)
|
||||
|
||||
// Create UserContext for authenticated users (enables EnsureUserSetup)
|
||||
if deps.Refresher != nil {
|
||||
userCtx := auth.NewUserContext(sess.DID, auth.AuthMethodOAuth, r.Method, &auth.Dependencies{
|
||||
Refresher: deps.Refresher,
|
||||
DefaultHoldDID: deps.DefaultHoldDID,
|
||||
})
|
||||
userCtx.SetPDS(sess.Handle, sess.PDSEndpoint)
|
||||
userCtx.EnsureUserSetup()
|
||||
ctx = auth.WithUserContext(ctx, userCtx)
|
||||
}
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,20 +2,17 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/distribution/distribution/v3"
|
||||
"github.com/distribution/distribution/v3/registry/api/errcode"
|
||||
registrymw "github.com/distribution/distribution/v3/registry/middleware/registry"
|
||||
"github.com/distribution/distribution/v3/registry/storage/driver"
|
||||
"github.com/distribution/reference"
|
||||
|
||||
"atcr.io/pkg/appview/readme"
|
||||
"atcr.io/pkg/appview/storage"
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
@@ -32,149 +29,12 @@ const authMethodKey contextKey = "auth.method"
|
||||
// pullerDIDKey is the context key for storing the authenticated user's DID from JWT
|
||||
const pullerDIDKey contextKey = "puller.did"
|
||||
|
||||
// validationCacheEntry stores a validated service token with expiration
|
||||
type validationCacheEntry struct {
|
||||
serviceToken string
|
||||
validUntil time.Time
|
||||
err error // Cached error for fast-fail
|
||||
mu sync.Mutex // Per-entry lock to serialize cache population
|
||||
inFlight bool // True if another goroutine is fetching the token
|
||||
done chan struct{} // Closed when fetch completes
|
||||
}
|
||||
|
||||
// validationCache provides request-level caching for service tokens
|
||||
// This prevents concurrent layer uploads from racing on OAuth/DPoP requests
|
||||
type validationCache struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]*validationCacheEntry // key: "did:holdDID"
|
||||
}
|
||||
|
||||
// newValidationCache creates a new validation cache
|
||||
func newValidationCache() *validationCache {
|
||||
return &validationCache{
|
||||
entries: make(map[string]*validationCacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// getOrFetch retrieves a service token from cache or fetches it
|
||||
// Multiple concurrent requests for the same DID:holdDID will share the fetch operation
|
||||
func (vc *validationCache) getOrFetch(ctx context.Context, cacheKey string, fetchFunc func() (string, error)) (string, error) {
|
||||
// Fast path: check cache with read lock
|
||||
vc.mu.RLock()
|
||||
entry, exists := vc.entries[cacheKey]
|
||||
vc.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
// Entry exists, check if it's still valid
|
||||
entry.mu.Lock()
|
||||
|
||||
// If another goroutine is fetching, wait for it
|
||||
if entry.inFlight {
|
||||
done := entry.done
|
||||
entry.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Fetch completed, check result
|
||||
entry.mu.Lock()
|
||||
defer entry.mu.Unlock()
|
||||
|
||||
if entry.err != nil {
|
||||
return "", entry.err
|
||||
}
|
||||
if time.Now().Before(entry.validUntil) {
|
||||
return entry.serviceToken, nil
|
||||
}
|
||||
// Fall through to refetch
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
}
|
||||
} else {
|
||||
// Check if cached token is still valid
|
||||
if entry.err != nil && time.Now().Before(entry.validUntil) {
|
||||
// Return cached error (fast-fail)
|
||||
entry.mu.Unlock()
|
||||
return "", entry.err
|
||||
}
|
||||
if entry.err == nil && time.Now().Before(entry.validUntil) {
|
||||
// Return cached token
|
||||
token := entry.serviceToken
|
||||
entry.mu.Unlock()
|
||||
return token, nil
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: need to fetch token
|
||||
vc.mu.Lock()
|
||||
entry, exists = vc.entries[cacheKey]
|
||||
if !exists {
|
||||
// Create new entry
|
||||
entry = &validationCacheEntry{
|
||||
inFlight: true,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
vc.entries[cacheKey] = entry
|
||||
}
|
||||
vc.mu.Unlock()
|
||||
|
||||
// Lock the entry to perform fetch
|
||||
entry.mu.Lock()
|
||||
|
||||
// Double-check: another goroutine may have fetched while we waited
|
||||
if !entry.inFlight {
|
||||
if entry.err != nil && time.Now().Before(entry.validUntil) {
|
||||
err := entry.err
|
||||
entry.mu.Unlock()
|
||||
return "", err
|
||||
}
|
||||
if entry.err == nil && time.Now().Before(entry.validUntil) {
|
||||
token := entry.serviceToken
|
||||
entry.mu.Unlock()
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Mark as in-flight and create fresh done channel for this fetch
|
||||
// IMPORTANT: Always create a new channel - a closed channel is not nil
|
||||
entry.done = make(chan struct{})
|
||||
entry.inFlight = true
|
||||
done := entry.done
|
||||
entry.mu.Unlock()
|
||||
|
||||
// Perform the fetch (outside the lock to allow other operations)
|
||||
serviceToken, err := fetchFunc()
|
||||
|
||||
// Update the entry with result
|
||||
entry.mu.Lock()
|
||||
entry.inFlight = false
|
||||
|
||||
if err != nil {
|
||||
// Cache errors for 5 seconds (fast-fail for subsequent requests)
|
||||
entry.err = err
|
||||
entry.validUntil = time.Now().Add(5 * time.Second)
|
||||
entry.serviceToken = ""
|
||||
} else {
|
||||
// Cache token for 45 seconds (covers typical Docker push operation)
|
||||
entry.err = nil
|
||||
entry.serviceToken = serviceToken
|
||||
entry.validUntil = time.Now().Add(45 * time.Second)
|
||||
}
|
||||
|
||||
// Signal completion to waiting goroutines
|
||||
close(done)
|
||||
entry.mu.Unlock()
|
||||
|
||||
return serviceToken, err
|
||||
}
|
||||
|
||||
// Global variables for initialization only
|
||||
// These are set by main.go during startup and copied into NamespaceResolver instances.
|
||||
// After initialization, request handling uses the NamespaceResolver's instance fields.
|
||||
var (
|
||||
globalRefresher *oauth.Refresher
|
||||
globalDatabase storage.DatabaseMetrics
|
||||
globalDatabase *sql.DB
|
||||
globalAuthorizer auth.HoldAuthorizer
|
||||
)
|
||||
|
||||
@@ -186,7 +46,7 @@ func SetGlobalRefresher(refresher *oauth.Refresher) {
|
||||
|
||||
// SetGlobalDatabase sets the database instance during initialization
|
||||
// Must be called before the registry starts serving requests
|
||||
func SetGlobalDatabase(database storage.DatabaseMetrics) {
|
||||
func SetGlobalDatabase(database *sql.DB) {
|
||||
globalDatabase = database
|
||||
}
|
||||
|
||||
@@ -204,14 +64,12 @@ func init() {
|
||||
// NamespaceResolver wraps a namespace and resolves names
|
||||
type NamespaceResolver struct {
|
||||
distribution.Namespace
|
||||
defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io")
|
||||
baseURL string // Base URL for error messages (e.g., "https://atcr.io")
|
||||
testMode bool // If true, fallback to default hold when user's hold is unreachable
|
||||
refresher *oauth.Refresher // OAuth session manager (copied from global on init)
|
||||
database storage.DatabaseMetrics // Metrics database (copied from global on init)
|
||||
authorizer auth.HoldAuthorizer // Hold authorization (copied from global on init)
|
||||
validationCache *validationCache // Request-level service token cache
|
||||
readmeFetcher *readme.Fetcher // README fetcher for repo pages
|
||||
defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io")
|
||||
baseURL string // Base URL for error messages (e.g., "https://atcr.io")
|
||||
testMode bool // If true, fallback to default hold when user's hold is unreachable
|
||||
refresher *oauth.Refresher // OAuth session manager (copied from global on init)
|
||||
sqlDB *sql.DB // Database for hold DID lookup and metrics (copied from global on init)
|
||||
authorizer auth.HoldAuthorizer // Hold authorization (copied from global on init)
|
||||
}
|
||||
|
||||
// initATProtoResolver initializes the name resolution middleware
|
||||
@@ -238,25 +96,16 @@ func initATProtoResolver(ctx context.Context, ns distribution.Namespace, _ drive
|
||||
// Copy shared services from globals into the instance
|
||||
// This avoids accessing globals during request handling
|
||||
return &NamespaceResolver{
|
||||
Namespace: ns,
|
||||
defaultHoldDID: defaultHoldDID,
|
||||
baseURL: baseURL,
|
||||
testMode: testMode,
|
||||
refresher: globalRefresher,
|
||||
database: globalDatabase,
|
||||
authorizer: globalAuthorizer,
|
||||
validationCache: newValidationCache(),
|
||||
readmeFetcher: readme.NewFetcher(),
|
||||
Namespace: ns,
|
||||
defaultHoldDID: defaultHoldDID,
|
||||
baseURL: baseURL,
|
||||
testMode: testMode,
|
||||
refresher: globalRefresher,
|
||||
sqlDB: globalDatabase,
|
||||
authorizer: globalAuthorizer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// authErrorMessage creates a user-friendly auth error with login URL
|
||||
func (nr *NamespaceResolver) authErrorMessage(message string) error {
|
||||
loginURL := fmt.Sprintf("%s/auth/oauth/login", nr.baseURL)
|
||||
fullMessage := fmt.Sprintf("%s - please re-authenticate at %s", message, loginURL)
|
||||
return errcode.ErrorCodeUnauthorized.WithMessage(fullMessage)
|
||||
}
|
||||
|
||||
// Repository resolves the repository name and delegates to underlying namespace
|
||||
// Handles names like:
|
||||
// - atcr.io/alice/myimage → resolve alice to DID
|
||||
@@ -290,113 +139,8 @@ func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Name
|
||||
}
|
||||
ctx = context.WithValue(ctx, holdDIDKey, holdDID)
|
||||
|
||||
// Auto-reconcile crew membership on first push/pull
|
||||
// This ensures users can push immediately after docker login without web sign-in
|
||||
// EnsureCrewMembership is best-effort and logs errors without failing the request
|
||||
// Run in background to avoid blocking registry operations if hold is offline
|
||||
if holdDID != "" && nr.refresher != nil {
|
||||
slog.Debug("Auto-reconciling crew membership", "component", "registry/middleware", "did", did, "hold_did", holdDID)
|
||||
client := atproto.NewClient(pdsEndpoint, did, "")
|
||||
go func(ctx context.Context, client *atproto.Client, refresher *oauth.Refresher, holdDID string) {
|
||||
storage.EnsureCrewMembership(ctx, client, refresher, holdDID)
|
||||
}(ctx, client, nr.refresher, holdDID)
|
||||
}
|
||||
|
||||
// Get service token for hold authentication (only if authenticated)
|
||||
// Use validation cache to prevent concurrent requests from racing on OAuth/DPoP
|
||||
// Route based on auth method from JWT token
|
||||
// IMPORTANT: Use PULLER's DID/PDS for service token, not owner's!
|
||||
// The puller (authenticated user) needs to authenticate to the hold service.
|
||||
var serviceToken string
|
||||
authMethod, _ := ctx.Value(authMethodKey).(string)
|
||||
pullerDID, _ := ctx.Value(pullerDIDKey).(string)
|
||||
var pullerPDSEndpoint string
|
||||
|
||||
// Only fetch service token if user is authenticated
|
||||
// Unauthenticated requests (like /v2/ ping) should not trigger token fetching
|
||||
if authMethod != "" && pullerDID != "" {
|
||||
// Resolve puller's PDS endpoint for service token request
|
||||
_, _, pullerPDSEndpoint, err = atproto.ResolveIdentity(ctx, pullerDID)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to resolve puller's PDS, falling back to anonymous access",
|
||||
"component", "registry/middleware",
|
||||
"pullerDID", pullerDID,
|
||||
"error", err)
|
||||
// Continue without service token - hold will decide if anonymous access is allowed
|
||||
} else {
|
||||
// Create cache key: "pullerDID:holdDID"
|
||||
cacheKey := fmt.Sprintf("%s:%s", pullerDID, holdDID)
|
||||
|
||||
// Fetch service token through validation cache
|
||||
// This ensures only ONE request per pullerDID:holdDID pair fetches the token
|
||||
// Concurrent requests will wait for the first request to complete
|
||||
var fetchErr error
|
||||
serviceToken, fetchErr = nr.validationCache.getOrFetch(ctx, cacheKey, func() (string, error) {
|
||||
if authMethod == token.AuthMethodAppPassword {
|
||||
// App-password flow: use Bearer token authentication
|
||||
slog.Debug("Using app-password flow for service token",
|
||||
"component", "registry/middleware",
|
||||
"pullerDID", pullerDID,
|
||||
"cacheKey", cacheKey)
|
||||
|
||||
token, err := auth.GetOrFetchServiceTokenWithAppPassword(ctx, pullerDID, holdDID, pullerPDSEndpoint)
|
||||
if err != nil {
|
||||
slog.Error("Failed to get service token with app-password",
|
||||
"component", "registry/middleware",
|
||||
"pullerDID", pullerDID,
|
||||
"holdDID", holdDID,
|
||||
"pullerPDSEndpoint", pullerPDSEndpoint,
|
||||
"error", err)
|
||||
return "", err
|
||||
}
|
||||
return token, nil
|
||||
} else if nr.refresher != nil {
|
||||
// OAuth flow: use DPoP authentication
|
||||
slog.Debug("Using OAuth flow for service token",
|
||||
"component", "registry/middleware",
|
||||
"pullerDID", pullerDID,
|
||||
"cacheKey", cacheKey)
|
||||
|
||||
token, err := auth.GetOrFetchServiceToken(ctx, nr.refresher, pullerDID, holdDID, pullerPDSEndpoint)
|
||||
if err != nil {
|
||||
slog.Error("Failed to get service token with OAuth",
|
||||
"component", "registry/middleware",
|
||||
"pullerDID", pullerDID,
|
||||
"holdDID", holdDID,
|
||||
"pullerPDSEndpoint", pullerPDSEndpoint,
|
||||
"error", err)
|
||||
return "", err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
return "", fmt.Errorf("no authentication method available")
|
||||
})
|
||||
|
||||
// Handle errors from cached fetch
|
||||
if fetchErr != nil {
|
||||
errMsg := fetchErr.Error()
|
||||
|
||||
// Check for app-password specific errors
|
||||
if authMethod == token.AuthMethodAppPassword {
|
||||
if strings.Contains(errMsg, "expired or invalid") || strings.Contains(errMsg, "no app-password") {
|
||||
return nil, nr.authErrorMessage("App-password authentication failed. Please re-authenticate with: docker login")
|
||||
}
|
||||
}
|
||||
|
||||
// Check for OAuth specific errors
|
||||
if strings.Contains(errMsg, "OAuth session") || strings.Contains(errMsg, "OAuth validation") {
|
||||
return nil, nr.authErrorMessage("OAuth session expired or invalidated by PDS. Your session has been cleared")
|
||||
}
|
||||
|
||||
// Generic service token error
|
||||
return nil, nr.authErrorMessage(fmt.Sprintf("Failed to obtain storage credentials: %v", fetchErr))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
slog.Debug("Skipping service token fetch for unauthenticated request",
|
||||
"component", "registry/middleware",
|
||||
"ownerDID", did)
|
||||
}
|
||||
// Note: Profile and crew membership are now ensured in UserContextMiddleware
|
||||
// via EnsureUserSetup() - no need to call here
|
||||
|
||||
// Create a new reference with identity/image format
|
||||
// Use the identity (or DID) as the namespace to ensure canonical format
|
||||
@@ -413,63 +157,30 @@ func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Name
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create ATProto client for manifest/tag operations
|
||||
// Pulls: ATProto records are public, no auth needed
|
||||
// Pushes: Need auth, but puller must be owner anyway
|
||||
var atprotoClient *atproto.Client
|
||||
|
||||
if pullerDID == did {
|
||||
// Puller is owner - may need auth for pushes
|
||||
if authMethod == token.AuthMethodOAuth && nr.refresher != nil {
|
||||
atprotoClient = atproto.NewClientWithSessionProvider(pdsEndpoint, did, nr.refresher)
|
||||
} else if authMethod == token.AuthMethodAppPassword {
|
||||
accessToken, _ := auth.GetGlobalTokenCache().Get(did)
|
||||
atprotoClient = atproto.NewClient(pdsEndpoint, did, accessToken)
|
||||
} else {
|
||||
atprotoClient = atproto.NewClient(pdsEndpoint, did, "")
|
||||
}
|
||||
} else {
|
||||
// Puller != owner - reads only, no auth needed
|
||||
atprotoClient = atproto.NewClient(pdsEndpoint, did, "")
|
||||
}
|
||||
|
||||
// IMPORTANT: Use only the image name (not identity/image) for ATProto storage
|
||||
// ATProto records are scoped to the user's DID, so we don't need the identity prefix
|
||||
// Example: "evan.jarrett.net/debian" -> store as "debian"
|
||||
repositoryName := imageName
|
||||
|
||||
// Default auth method to OAuth if not already set (backward compatibility with old tokens)
|
||||
if authMethod == "" {
|
||||
authMethod = token.AuthMethodOAuth
|
||||
// Get UserContext from request context (set by UserContextMiddleware)
|
||||
userCtx := auth.FromContext(ctx)
|
||||
if userCtx == nil {
|
||||
return nil, fmt.Errorf("UserContext not set in request context - ensure UserContextMiddleware is configured")
|
||||
}
|
||||
|
||||
// Set target repository info on UserContext
|
||||
// ATProtoClient is cached lazily via userCtx.GetATProtoClient()
|
||||
userCtx.SetTarget(did, handle, pdsEndpoint, repositoryName, holdDID)
|
||||
|
||||
// Create routing repository - routes manifests to ATProto, blobs to hold service
|
||||
// The registry is stateless - no local storage is used
|
||||
// Bundle all context into a single RegistryContext struct
|
||||
//
|
||||
// NOTE: We create a fresh RoutingRepository on every request (no caching) because:
|
||||
// 1. Each layer upload is a separate HTTP request (possibly different process)
|
||||
// 2. OAuth sessions can be refreshed/invalidated between requests
|
||||
// 3. The refresher already caches sessions efficiently (in-memory + DB)
|
||||
// 4. Caching the repository with a stale ATProtoClient causes refresh token errors
|
||||
registryCtx := &storage.RegistryContext{
|
||||
DID: did,
|
||||
Handle: handle,
|
||||
HoldDID: holdDID,
|
||||
PDSEndpoint: pdsEndpoint,
|
||||
Repository: repositoryName,
|
||||
ServiceToken: serviceToken, // Cached service token from puller's PDS
|
||||
ATProtoClient: atprotoClient,
|
||||
AuthMethod: authMethod, // Auth method from JWT token
|
||||
PullerDID: pullerDID, // Authenticated user making the request
|
||||
PullerPDSEndpoint: pullerPDSEndpoint, // Puller's PDS for service token refresh
|
||||
Database: nr.database,
|
||||
Authorizer: nr.authorizer,
|
||||
Refresher: nr.refresher,
|
||||
ReadmeFetcher: nr.readmeFetcher,
|
||||
}
|
||||
|
||||
return storage.NewRoutingRepository(repo, registryCtx), nil
|
||||
// 4. ATProtoClient is now cached in UserContext via GetATProtoClient()
|
||||
return storage.NewRoutingRepository(repo, userCtx, nr.sqlDB), nil
|
||||
}
|
||||
|
||||
// Repositories delegates to underlying namespace
|
||||
@@ -504,8 +215,8 @@ func (nr *NamespaceResolver) findHoldDID(ctx context.Context, did, pdsEndpoint s
|
||||
}
|
||||
|
||||
if profile != nil && profile.DefaultHold != "" {
|
||||
// Profile exists with defaultHold set
|
||||
// In test mode, verify it's reachable before using it
|
||||
// In test mode, verify the hold is reachable (fall back to default if not)
|
||||
// In production, trust the user's profile and return their hold
|
||||
if nr.testMode {
|
||||
if nr.isHoldReachable(ctx, profile.DefaultHold) {
|
||||
return profile.DefaultHold
|
||||
@@ -584,3 +295,49 @@ func ExtractAuthMethod(next http.Handler) http.Handler {
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// UserContextMiddleware creates a UserContext from the extracted JWT claims
|
||||
// and stores it in the request context for use throughout request processing.
|
||||
// This middleware should be chained AFTER ExtractAuthMethod.
|
||||
func UserContextMiddleware(deps *auth.Dependencies) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Get values set by ExtractAuthMethod
|
||||
authMethod, _ := ctx.Value(authMethodKey).(string)
|
||||
pullerDID, _ := ctx.Value(pullerDIDKey).(string)
|
||||
|
||||
// Build UserContext with all dependencies
|
||||
userCtx := auth.NewUserContext(pullerDID, authMethod, r.Method, deps)
|
||||
|
||||
// Eagerly resolve user's PDS for authenticated users
|
||||
// This is a fast path that avoids lazy loading in most cases
|
||||
if userCtx.IsAuthenticated {
|
||||
if err := userCtx.ResolvePDS(ctx); err != nil {
|
||||
slog.Warn("Failed to resolve puller's PDS",
|
||||
"component", "registry/middleware",
|
||||
"did", pullerDID,
|
||||
"error", err)
|
||||
// Continue without PDS - will fail on service token request
|
||||
}
|
||||
|
||||
// Ensure user has profile and crew membership (runs in background, cached)
|
||||
userCtx.EnsureUserSetup()
|
||||
}
|
||||
|
||||
// Store UserContext in request context
|
||||
ctx = auth.WithUserContext(ctx, userCtx)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
slog.Debug("Created UserContext",
|
||||
"component", "registry/middleware",
|
||||
"isAuthenticated", userCtx.IsAuthenticated,
|
||||
"authMethod", userCtx.AuthMethod,
|
||||
"action", userCtx.Action.String(),
|
||||
"pullerDID", pullerDID)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,17 +129,6 @@ func TestInitATProtoResolver(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthErrorMessage tests the error message formatting
|
||||
func TestAuthErrorMessage(t *testing.T) {
|
||||
resolver := &NamespaceResolver{
|
||||
baseURL: "https://atcr.io",
|
||||
}
|
||||
|
||||
err := resolver.authErrorMessage("OAuth session expired")
|
||||
assert.Contains(t, err.Error(), "OAuth session expired")
|
||||
assert.Contains(t, err.Error(), "https://atcr.io/auth/oauth/login")
|
||||
}
|
||||
|
||||
// TestFindHoldDID_DefaultFallback tests default hold DID fallback
|
||||
func TestFindHoldDID_DefaultFallback(t *testing.T) {
|
||||
// Start a mock PDS server that returns 404 for profile and empty list for holds
|
||||
|
||||
@@ -29,6 +29,7 @@ type UIDependencies struct {
|
||||
HealthChecker *holdhealth.Checker
|
||||
ReadmeFetcher *readme.Fetcher
|
||||
Templates *template.Template
|
||||
DefaultHoldDID string // For UserContext creation
|
||||
}
|
||||
|
||||
// RegisterUIRoutes registers all web UI and API routes on the provided router
|
||||
@@ -36,6 +37,14 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
// Extract trimmed registry URL for templates
|
||||
registryURL := trimRegistryURL(deps.BaseURL)
|
||||
|
||||
// Create web auth dependencies for middleware (enables UserContext in web routes)
|
||||
webAuthDeps := middleware.WebAuthDeps{
|
||||
SessionStore: deps.SessionStore,
|
||||
Database: deps.Database,
|
||||
Refresher: deps.Refresher,
|
||||
DefaultHoldDID: deps.DefaultHoldDID,
|
||||
}
|
||||
|
||||
// OAuth login routes (public)
|
||||
router.Get("/auth/oauth/login", (&uihandlers.LoginHandler{
|
||||
Templates: deps.Templates,
|
||||
@@ -45,7 +54,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
|
||||
// Public routes (with optional auth for navbar)
|
||||
// SECURITY: Public pages use read-only DB
|
||||
router.Get("/", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.HomeHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Templates: deps.Templates,
|
||||
@@ -53,7 +62,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
},
|
||||
).ServeHTTP)
|
||||
|
||||
router.Get("/api/recent-pushes", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/api/recent-pushes", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.RecentPushesHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Templates: deps.Templates,
|
||||
@@ -63,7 +72,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
).ServeHTTP)
|
||||
|
||||
// SECURITY: Search uses read-only DB to prevent writes and limit access to sensitive tables
|
||||
router.Get("/search", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/search", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.SearchHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Templates: deps.Templates,
|
||||
@@ -71,7 +80,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
},
|
||||
).ServeHTTP)
|
||||
|
||||
router.Get("/api/search-results", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/api/search-results", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.SearchResultsHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Templates: deps.Templates,
|
||||
@@ -80,7 +89,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
).ServeHTTP)
|
||||
|
||||
// Install page (public)
|
||||
router.Get("/install", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/install", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.InstallHandler{
|
||||
Templates: deps.Templates,
|
||||
RegistryURL: registryURL,
|
||||
@@ -88,7 +97,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
).ServeHTTP)
|
||||
|
||||
// API route for repository stats (public, read-only)
|
||||
router.Get("/api/stats/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/api/stats/{handle}/{repository}", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.GetStatsHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Directory: deps.OAuthClientApp.Dir,
|
||||
@@ -96,7 +105,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
).ServeHTTP)
|
||||
|
||||
// API routes for stars (require authentication)
|
||||
router.Post("/api/stars/{handle}/{repository}", middleware.RequireAuth(deps.SessionStore, deps.Database)(
|
||||
router.Post("/api/stars/{handle}/{repository}", middleware.RequireAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.StarRepositoryHandler{
|
||||
DB: deps.Database, // Needs write access
|
||||
Directory: deps.OAuthClientApp.Dir,
|
||||
@@ -104,7 +113,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
},
|
||||
).ServeHTTP)
|
||||
|
||||
router.Delete("/api/stars/{handle}/{repository}", middleware.RequireAuth(deps.SessionStore, deps.Database)(
|
||||
router.Delete("/api/stars/{handle}/{repository}", middleware.RequireAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.UnstarRepositoryHandler{
|
||||
DB: deps.Database, // Needs write access
|
||||
Directory: deps.OAuthClientApp.Dir,
|
||||
@@ -112,7 +121,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
},
|
||||
).ServeHTTP)
|
||||
|
||||
router.Get("/api/stars/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/api/stars/{handle}/{repository}", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.CheckStarHandler{
|
||||
DB: deps.ReadOnlyDB, // Read-only check
|
||||
Directory: deps.OAuthClientApp.Dir,
|
||||
@@ -121,7 +130,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
).ServeHTTP)
|
||||
|
||||
// Manifest detail API endpoint
|
||||
router.Get("/api/manifests/{handle}/{repository}/{digest}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/api/manifests/{handle}/{repository}/{digest}", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.ManifestDetailHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Directory: deps.OAuthClientApp.Dir,
|
||||
@@ -133,7 +142,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
HealthChecker: deps.HealthChecker,
|
||||
}).ServeHTTP)
|
||||
|
||||
router.Get("/u/{handle}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/u/{handle}", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.UserPageHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Templates: deps.Templates,
|
||||
@@ -152,7 +161,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
DB: deps.ReadOnlyDB,
|
||||
}).ServeHTTP)
|
||||
|
||||
router.Get("/r/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/r/{handle}/{repository}", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.RepositoryPageHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Templates: deps.Templates,
|
||||
@@ -166,7 +175,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
|
||||
// Authenticated routes
|
||||
router.Group(func(r chi.Router) {
|
||||
r.Use(middleware.RequireAuth(deps.SessionStore, deps.Database))
|
||||
r.Use(middleware.RequireAuthWithDeps(webAuthDeps))
|
||||
|
||||
r.Get("/settings", (&uihandlers.SettingsHandler{
|
||||
Templates: deps.Templates,
|
||||
@@ -226,7 +235,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
router.Post("/auth/logout", logoutHandler.ServeHTTP)
|
||||
|
||||
// Custom 404 handler
|
||||
router.NotFound(middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.NotFound(middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.NotFoundHandler{
|
||||
Templates: deps.Templates,
|
||||
RegistryURL: registryURL,
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"atcr.io/pkg/appview/readme"
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
)
|
||||
|
||||
// DatabaseMetrics interface for tracking pull/push counts and querying hold DIDs
|
||||
type DatabaseMetrics interface {
|
||||
IncrementPullCount(did, repository string) error
|
||||
IncrementPushCount(did, repository string) error
|
||||
GetLatestHoldDIDForRepo(did, repository string) (string, error)
|
||||
}
|
||||
|
||||
// RegistryContext bundles all the context needed for registry operations
|
||||
// This includes both per-request data (DID, hold) and shared services
|
||||
type RegistryContext struct {
|
||||
// Per-request identity and routing information
|
||||
// Owner = the user whose repository is being accessed
|
||||
// Puller = the authenticated user making the request (from JWT Subject)
|
||||
DID string // Owner's DID - whose repo is being accessed (e.g., "did:plc:abc123")
|
||||
Handle string // Owner's handle (e.g., "alice.bsky.social")
|
||||
HoldDID string // Hold service DID (e.g., "did:web:hold01.atcr.io")
|
||||
PDSEndpoint string // Owner's PDS endpoint URL
|
||||
Repository string // Image repository name (e.g., "debian")
|
||||
ServiceToken string // Service token for hold authentication (from puller's PDS)
|
||||
ATProtoClient *atproto.Client // Authenticated ATProto client for the owner
|
||||
AuthMethod string // Auth method used ("oauth" or "app_password")
|
||||
PullerDID string // Puller's DID - who is making the request (from JWT Subject)
|
||||
PullerPDSEndpoint string // Puller's PDS endpoint URL
|
||||
|
||||
// Shared services (same for all requests)
|
||||
Database DatabaseMetrics // Metrics tracking database
|
||||
Authorizer auth.HoldAuthorizer // Hold access authorization
|
||||
Refresher *oauth.Refresher // OAuth session manager
|
||||
ReadmeFetcher *readme.Fetcher // README fetcher for repo pages
|
||||
}
|
||||
@@ -1,113 +0,0 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
)
|
||||
|
||||
// Mock implementations for testing
|
||||
type mockDatabaseMetrics struct {
|
||||
mu sync.Mutex
|
||||
pullCount int
|
||||
pushCount int
|
||||
}
|
||||
|
||||
func (m *mockDatabaseMetrics) IncrementPullCount(did, repository string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.pullCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDatabaseMetrics) IncrementPushCount(did, repository string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.pushCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDatabaseMetrics) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
|
||||
// Return empty string for mock - tests can override if needed
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *mockDatabaseMetrics) getPullCount() int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.pullCount
|
||||
}
|
||||
|
||||
func (m *mockDatabaseMetrics) getPushCount() int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.pushCount
|
||||
}
|
||||
|
||||
type mockHoldAuthorizer struct{}
|
||||
|
||||
func (m *mockHoldAuthorizer) Authorize(holdDID, userDID, permission string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func TestRegistryContext_Fields(t *testing.T) {
|
||||
// Create a sample RegistryContext
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Handle: "alice.bsky.social",
|
||||
HoldDID: "did:web:hold01.atcr.io",
|
||||
PDSEndpoint: "https://bsky.social",
|
||||
Repository: "debian",
|
||||
ServiceToken: "test-token",
|
||||
ATProtoClient: &atproto.Client{
|
||||
// Mock client - would need proper initialization in real tests
|
||||
},
|
||||
Database: &mockDatabaseMetrics{},
|
||||
}
|
||||
|
||||
// Verify fields are accessible
|
||||
if ctx.DID != "did:plc:test123" {
|
||||
t.Errorf("Expected DID %q, got %q", "did:plc:test123", ctx.DID)
|
||||
}
|
||||
if ctx.Handle != "alice.bsky.social" {
|
||||
t.Errorf("Expected Handle %q, got %q", "alice.bsky.social", ctx.Handle)
|
||||
}
|
||||
if ctx.HoldDID != "did:web:hold01.atcr.io" {
|
||||
t.Errorf("Expected HoldDID %q, got %q", "did:web:hold01.atcr.io", ctx.HoldDID)
|
||||
}
|
||||
if ctx.PDSEndpoint != "https://bsky.social" {
|
||||
t.Errorf("Expected PDSEndpoint %q, got %q", "https://bsky.social", ctx.PDSEndpoint)
|
||||
}
|
||||
if ctx.Repository != "debian" {
|
||||
t.Errorf("Expected Repository %q, got %q", "debian", ctx.Repository)
|
||||
}
|
||||
if ctx.ServiceToken != "test-token" {
|
||||
t.Errorf("Expected ServiceToken %q, got %q", "test-token", ctx.ServiceToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryContext_DatabaseInterface(t *testing.T) {
|
||||
db := &mockDatabaseMetrics{}
|
||||
ctx := &RegistryContext{
|
||||
Database: db,
|
||||
}
|
||||
|
||||
// Test that interface methods are callable
|
||||
err := ctx.Database.IncrementPullCount("did:plc:test", "repo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
err = ctx.Database.IncrementPushCount("did:plc:test", "repo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add more comprehensive tests:
|
||||
// - Test ATProtoClient integration
|
||||
// - Test OAuth Refresher integration
|
||||
// - Test HoldAuthorizer integration
|
||||
// - Test nil handling for optional fields
|
||||
// - Integration tests with real components
|
||||
@@ -1,93 +0,0 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
)
|
||||
|
||||
// EnsureCrewMembership attempts to register the user as a crew member on their default hold.
|
||||
// The hold's requestCrew endpoint handles all authorization logic (checking allowAllCrew, existing membership, etc).
|
||||
// This is best-effort and does not fail on errors.
|
||||
func EnsureCrewMembership(ctx context.Context, client *atproto.Client, refresher *oauth.Refresher, defaultHoldDID string) {
|
||||
if defaultHoldDID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Normalize URL to DID if needed
|
||||
holdDID := atproto.ResolveHoldDIDFromURL(defaultHoldDID)
|
||||
if holdDID == "" {
|
||||
slog.Warn("failed to resolve hold DID", "defaultHold", defaultHoldDID)
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve hold DID to HTTP endpoint
|
||||
holdEndpoint := atproto.ResolveHoldURL(holdDID)
|
||||
|
||||
// Get service token for the hold
|
||||
// Only works with OAuth (refresher required) - app passwords can't get service tokens
|
||||
if refresher == nil {
|
||||
slog.Debug("skipping crew registration - no OAuth refresher (app password flow)", "holdDID", holdDID)
|
||||
return
|
||||
}
|
||||
|
||||
// Wrap the refresher to match OAuthSessionRefresher interface
|
||||
serviceToken, err := auth.GetOrFetchServiceToken(ctx, refresher, client.DID(), holdDID, client.PDSEndpoint())
|
||||
if err != nil {
|
||||
slog.Warn("failed to get service token", "holdDID", holdDID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Call requestCrew endpoint - it handles all the logic:
|
||||
// - Checks allowAllCrew flag
|
||||
// - Checks if already a crew member (returns success if so)
|
||||
// - Creates crew record if authorized
|
||||
if err := requestCrewMembership(ctx, holdEndpoint, serviceToken); err != nil {
|
||||
slog.Warn("failed to request crew membership", "holdDID", holdDID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("successfully registered as crew member", "holdDID", holdDID, "userDID", client.DID())
|
||||
}
|
||||
|
||||
// requestCrewMembership calls the hold's requestCrew endpoint
|
||||
// The endpoint handles all authorization and duplicate checking internally
|
||||
func requestCrewMembership(ctx context.Context, holdEndpoint, serviceToken string) error {
|
||||
// Add 5 second timeout to prevent hanging on offline holds
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
url := fmt.Sprintf("%s%s", holdEndpoint, atproto.HoldRequestCrew)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+serviceToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
// Read response body to capture actual error message from hold
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return fmt.Errorf("requestCrew failed with status %d (failed to read error body: %w)", resp.StatusCode, readErr)
|
||||
}
|
||||
return fmt.Errorf("requestCrew failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEnsureCrewMembership_EmptyHoldDID(t *testing.T) {
|
||||
// Test that empty hold DID returns early without error (best-effort function)
|
||||
EnsureCrewMembership(context.Background(), nil, nil, "")
|
||||
// If we get here without panic, test passes
|
||||
}
|
||||
|
||||
// TODO: Add comprehensive tests with HTTP client mocking
|
||||
@@ -3,6 +3,7 @@ package storage
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -12,8 +13,10 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/appview/db"
|
||||
"atcr.io/pkg/appview/readme"
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
"github.com/distribution/distribution/v3"
|
||||
"github.com/opencontainers/go-digest"
|
||||
)
|
||||
@@ -21,22 +24,24 @@ import (
|
||||
// ManifestStore implements distribution.ManifestService
|
||||
// It stores manifests in ATProto as records
|
||||
type ManifestStore struct {
|
||||
ctx *RegistryContext // Context with user/hold info
|
||||
blobStore distribution.BlobStore // Blob store for fetching config during push
|
||||
ctx *auth.UserContext // User context with identity, target, permissions
|
||||
blobStore distribution.BlobStore // Blob store for fetching config during push
|
||||
sqlDB *sql.DB // Database for pull/push counts
|
||||
}
|
||||
|
||||
// NewManifestStore creates a new ATProto-backed manifest store
|
||||
func NewManifestStore(ctx *RegistryContext, blobStore distribution.BlobStore) *ManifestStore {
|
||||
func NewManifestStore(userCtx *auth.UserContext, blobStore distribution.BlobStore, sqlDB *sql.DB) *ManifestStore {
|
||||
return &ManifestStore{
|
||||
ctx: ctx,
|
||||
ctx: userCtx,
|
||||
blobStore: blobStore,
|
||||
sqlDB: sqlDB,
|
||||
}
|
||||
}
|
||||
|
||||
// Exists checks if a manifest exists by digest
|
||||
func (s *ManifestStore) Exists(ctx context.Context, dgst digest.Digest) (bool, error) {
|
||||
rkey := digestToRKey(dgst)
|
||||
_, err := s.ctx.ATProtoClient.GetRecord(ctx, atproto.ManifestCollection, rkey)
|
||||
_, err := s.ctx.GetATProtoClient().GetRecord(ctx, atproto.ManifestCollection, rkey)
|
||||
if err != nil {
|
||||
// If not found, return false without error
|
||||
if errors.Is(err, atproto.ErrRecordNotFound) {
|
||||
@@ -50,10 +55,10 @@ func (s *ManifestStore) Exists(ctx context.Context, dgst digest.Digest) (bool, e
|
||||
// Get retrieves a manifest by digest
|
||||
func (s *ManifestStore) Get(ctx context.Context, dgst digest.Digest, options ...distribution.ManifestServiceOption) (distribution.Manifest, error) {
|
||||
rkey := digestToRKey(dgst)
|
||||
record, err := s.ctx.ATProtoClient.GetRecord(ctx, atproto.ManifestCollection, rkey)
|
||||
record, err := s.ctx.GetATProtoClient().GetRecord(ctx, atproto.ManifestCollection, rkey)
|
||||
if err != nil {
|
||||
return nil, distribution.ErrManifestUnknownRevision{
|
||||
Name: s.ctx.Repository,
|
||||
Name: s.ctx.TargetRepo,
|
||||
Revision: dgst,
|
||||
}
|
||||
}
|
||||
@@ -67,7 +72,7 @@ func (s *ManifestStore) Get(ctx context.Context, dgst digest.Digest, options ...
|
||||
|
||||
// New records: Download blob from ATProto blob storage
|
||||
if manifestRecord.ManifestBlob != nil && manifestRecord.ManifestBlob.Ref.Link != "" {
|
||||
ociManifest, err = s.ctx.ATProtoClient.GetBlob(ctx, manifestRecord.ManifestBlob.Ref.Link)
|
||||
ociManifest, err = s.ctx.GetATProtoClient().GetBlob(ctx, manifestRecord.ManifestBlob.Ref.Link)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download manifest blob: %w", err)
|
||||
}
|
||||
@@ -75,12 +80,12 @@ func (s *ManifestStore) Get(ctx context.Context, dgst digest.Digest, options ...
|
||||
|
||||
// Track pull count (increment asynchronously to avoid blocking the response)
|
||||
// Only count GET requests (actual downloads), not HEAD requests (existence checks)
|
||||
if s.ctx.Database != nil {
|
||||
if s.sqlDB != nil {
|
||||
// Check HTTP method from context (distribution library stores it as "http.request.method")
|
||||
if method, ok := ctx.Value("http.request.method").(string); ok && method == "GET" {
|
||||
go func() {
|
||||
if err := s.ctx.Database.IncrementPullCount(s.ctx.DID, s.ctx.Repository); err != nil {
|
||||
slog.Warn("Failed to increment pull count", "did", s.ctx.DID, "repository", s.ctx.Repository, "error", err)
|
||||
if err := db.IncrementPullCount(s.sqlDB, s.ctx.TargetOwnerDID, s.ctx.TargetRepo); err != nil {
|
||||
slog.Warn("Failed to increment pull count", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -107,20 +112,20 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
|
||||
dgst := digest.FromBytes(payload)
|
||||
|
||||
// Upload manifest as blob to PDS
|
||||
blobRef, err := s.ctx.ATProtoClient.UploadBlob(ctx, payload, mediaType)
|
||||
blobRef, err := s.ctx.GetATProtoClient().UploadBlob(ctx, payload, mediaType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload manifest blob: %w", err)
|
||||
}
|
||||
|
||||
// Create manifest record with structured metadata
|
||||
manifestRecord, err := atproto.NewManifestRecord(s.ctx.Repository, dgst.String(), payload)
|
||||
manifestRecord, err := atproto.NewManifestRecord(s.ctx.TargetRepo, dgst.String(), payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create manifest record: %w", err)
|
||||
}
|
||||
|
||||
// Set the blob reference, hold DID, and hold endpoint
|
||||
manifestRecord.ManifestBlob = blobRef
|
||||
manifestRecord.HoldDID = s.ctx.HoldDID // Primary reference (DID)
|
||||
manifestRecord.HoldDID = s.ctx.TargetHoldDID // Primary reference (DID)
|
||||
|
||||
// Extract Dockerfile labels from config blob and add to annotations
|
||||
// Only for image manifests (not manifest lists which don't have config blobs)
|
||||
@@ -150,7 +155,7 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
|
||||
platform = fmt.Sprintf("%s/%s", ref.Platform.OS, ref.Platform.Architecture)
|
||||
}
|
||||
slog.Warn("Manifest list references non-existent child manifest",
|
||||
"repository", s.ctx.Repository,
|
||||
"repository", s.ctx.TargetRepo,
|
||||
"missingDigest", ref.Digest,
|
||||
"platform", platform)
|
||||
return "", distribution.ErrManifestBlobUnknown{Digest: refDigest}
|
||||
@@ -185,16 +190,16 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
|
||||
|
||||
// Store manifest record in ATProto
|
||||
rkey := digestToRKey(dgst)
|
||||
_, err = s.ctx.ATProtoClient.PutRecord(ctx, atproto.ManifestCollection, rkey, manifestRecord)
|
||||
_, err = s.ctx.GetATProtoClient().PutRecord(ctx, atproto.ManifestCollection, rkey, manifestRecord)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to store manifest record in ATProto: %w", err)
|
||||
}
|
||||
|
||||
// Track push count (increment asynchronously to avoid blocking the response)
|
||||
if s.ctx.Database != nil {
|
||||
if s.sqlDB != nil {
|
||||
go func() {
|
||||
if err := s.ctx.Database.IncrementPushCount(s.ctx.DID, s.ctx.Repository); err != nil {
|
||||
slog.Warn("Failed to increment push count", "did", s.ctx.DID, "repository", s.ctx.Repository, "error", err)
|
||||
if err := db.IncrementPushCount(s.sqlDB, s.ctx.TargetOwnerDID, s.ctx.TargetRepo); err != nil {
|
||||
slog.Warn("Failed to increment push count", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -204,9 +209,9 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
|
||||
for _, option := range options {
|
||||
if tagOpt, ok := option.(distribution.WithTagOption); ok {
|
||||
tag = tagOpt.Tag
|
||||
tagRecord := atproto.NewTagRecord(s.ctx.ATProtoClient.DID(), s.ctx.Repository, tag, dgst.String())
|
||||
tagRKey := atproto.RepositoryTagToRKey(s.ctx.Repository, tag)
|
||||
_, err = s.ctx.ATProtoClient.PutRecord(ctx, atproto.TagCollection, tagRKey, tagRecord)
|
||||
tagRecord := atproto.NewTagRecord(s.ctx.GetATProtoClient().DID(), s.ctx.TargetRepo, tag, dgst.String())
|
||||
tagRKey := atproto.RepositoryTagToRKey(s.ctx.TargetRepo, tag)
|
||||
_, err = s.ctx.GetATProtoClient().PutRecord(ctx, atproto.TagCollection, tagRKey, tagRecord)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to store tag in ATProto: %w", err)
|
||||
}
|
||||
@@ -215,17 +220,19 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
|
||||
|
||||
// Notify hold about manifest upload (for layer tracking and Bluesky posts)
|
||||
// Do this asynchronously to avoid blocking the push
|
||||
if tag != "" && s.ctx.ServiceToken != "" && s.ctx.Handle != "" {
|
||||
go func() {
|
||||
// Get service token before goroutine (requires context)
|
||||
serviceToken, _ := s.ctx.GetServiceToken(ctx)
|
||||
if tag != "" && serviceToken != "" && s.ctx.TargetOwnerHandle != "" {
|
||||
go func(serviceToken string) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("Panic in notifyHoldAboutManifest", "panic", r)
|
||||
}
|
||||
}()
|
||||
if err := s.notifyHoldAboutManifest(context.Background(), manifestRecord, tag, dgst.String()); err != nil {
|
||||
if err := s.notifyHoldAboutManifest(context.Background(), manifestRecord, tag, dgst.String(), serviceToken); err != nil {
|
||||
slog.Warn("Failed to notify hold about manifest", "error", err)
|
||||
}
|
||||
}()
|
||||
}(serviceToken)
|
||||
}
|
||||
|
||||
// Create or update repo page asynchronously if manifest has relevant annotations
|
||||
@@ -245,7 +252,7 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
|
||||
// Delete removes a manifest
|
||||
func (s *ManifestStore) Delete(ctx context.Context, dgst digest.Digest) error {
|
||||
rkey := digestToRKey(dgst)
|
||||
return s.ctx.ATProtoClient.DeleteRecord(ctx, atproto.ManifestCollection, rkey)
|
||||
return s.ctx.GetATProtoClient().DeleteRecord(ctx, atproto.ManifestCollection, rkey)
|
||||
}
|
||||
|
||||
// digestToRKey converts a digest to an ATProto record key
|
||||
@@ -300,18 +307,17 @@ func (s *ManifestStore) extractConfigLabels(ctx context.Context, configDigestStr
|
||||
|
||||
// notifyHoldAboutManifest notifies the hold service about a manifest upload
|
||||
// This enables the hold to create layer records and Bluesky posts
|
||||
func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRecord *atproto.ManifestRecord, tag, manifestDigest string) error {
|
||||
// Skip if no service token configured (e.g., anonymous pulls)
|
||||
if s.ctx.ServiceToken == "" {
|
||||
func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRecord *atproto.ManifestRecord, tag, manifestDigest, serviceToken string) error {
|
||||
// Skip if no service token provided
|
||||
if serviceToken == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve hold DID to HTTP endpoint
|
||||
// For did:web, this is straightforward (e.g., did:web:hold01.atcr.io → https://hold01.atcr.io)
|
||||
holdEndpoint := atproto.ResolveHoldURL(s.ctx.HoldDID)
|
||||
holdEndpoint := atproto.ResolveHoldURL(s.ctx.TargetHoldDID)
|
||||
|
||||
// Use service token from middleware (already cached and validated)
|
||||
serviceToken := s.ctx.ServiceToken
|
||||
// Service token is passed in (already cached and validated)
|
||||
|
||||
// Build notification request
|
||||
manifestData := map[string]any{
|
||||
@@ -360,10 +366,10 @@ func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRec
|
||||
}
|
||||
|
||||
notifyReq := map[string]any{
|
||||
"repository": s.ctx.Repository,
|
||||
"repository": s.ctx.TargetRepo,
|
||||
"tag": tag,
|
||||
"userDid": s.ctx.DID,
|
||||
"userHandle": s.ctx.Handle,
|
||||
"userDid": s.ctx.TargetOwnerDID,
|
||||
"userHandle": s.ctx.TargetOwnerHandle,
|
||||
"manifest": manifestData,
|
||||
}
|
||||
|
||||
@@ -401,7 +407,7 @@ func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRec
|
||||
// Parse response (optional logging)
|
||||
var notifyResp map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(¬ifyResp); err == nil {
|
||||
slog.Info("Hold notification successful", "repository", s.ctx.Repository, "tag", tag, "response", notifyResp)
|
||||
slog.Info("Hold notification successful", "repository", s.ctx.TargetRepo, "tag", tag, "response", notifyResp)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -412,17 +418,17 @@ func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRec
|
||||
// Only creates a new record if one doesn't exist (doesn't overwrite user's custom content)
|
||||
func (s *ManifestStore) ensureRepoPage(ctx context.Context, manifestRecord *atproto.ManifestRecord) {
|
||||
// Check if repo page already exists (don't overwrite user's custom content)
|
||||
rkey := s.ctx.Repository
|
||||
_, err := s.ctx.ATProtoClient.GetRecord(ctx, atproto.RepoPageCollection, rkey)
|
||||
rkey := s.ctx.TargetRepo
|
||||
_, err := s.ctx.GetATProtoClient().GetRecord(ctx, atproto.RepoPageCollection, rkey)
|
||||
if err == nil {
|
||||
// Record already exists - don't overwrite
|
||||
slog.Debug("Repo page already exists, skipping creation", "did", s.ctx.DID, "repository", s.ctx.Repository)
|
||||
slog.Debug("Repo page already exists, skipping creation", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo)
|
||||
return
|
||||
}
|
||||
|
||||
// Only continue if it's a "not found" error - other errors mean we should skip
|
||||
if !errors.Is(err, atproto.ErrRecordNotFound) {
|
||||
slog.Warn("Failed to check for existing repo page", "did", s.ctx.DID, "repository", s.ctx.Repository, "error", err)
|
||||
slog.Warn("Failed to check for existing repo page", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -448,26 +454,23 @@ func (s *ManifestStore) ensureRepoPage(ctx context.Context, manifestRecord *atpr
|
||||
}
|
||||
|
||||
// Create new repo page record with description and optional avatar
|
||||
repoPage := atproto.NewRepoPageRecord(s.ctx.Repository, description, avatarRef)
|
||||
repoPage := atproto.NewRepoPageRecord(s.ctx.TargetRepo, description, avatarRef)
|
||||
|
||||
slog.Info("Creating repo page from manifest annotations", "did", s.ctx.DID, "repository", s.ctx.Repository, "descriptionLength", len(description), "hasAvatar", avatarRef != nil)
|
||||
slog.Info("Creating repo page from manifest annotations", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "descriptionLength", len(description), "hasAvatar", avatarRef != nil)
|
||||
|
||||
_, err = s.ctx.ATProtoClient.PutRecord(ctx, atproto.RepoPageCollection, rkey, repoPage)
|
||||
_, err = s.ctx.GetATProtoClient().PutRecord(ctx, atproto.RepoPageCollection, rkey, repoPage)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to create repo page", "did", s.ctx.DID, "repository", s.ctx.Repository, "error", err)
|
||||
slog.Warn("Failed to create repo page", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("Repo page created successfully", "did", s.ctx.DID, "repository", s.ctx.Repository)
|
||||
slog.Info("Repo page created successfully", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo)
|
||||
}
|
||||
|
||||
// fetchReadmeContent attempts to fetch README content from external sources
|
||||
// Priority: io.atcr.readme annotation > derived from org.opencontainers.image.source
|
||||
// Returns the raw markdown content, or empty string if not available
|
||||
func (s *ManifestStore) fetchReadmeContent(ctx context.Context, annotations map[string]string) string {
|
||||
if s.ctx.ReadmeFetcher == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Create a context with timeout for README fetching (don't block push too long)
|
||||
fetchCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
@@ -614,7 +617,7 @@ func (s *ManifestStore) fetchAndUploadIcon(ctx context.Context, iconURL string)
|
||||
}
|
||||
|
||||
// Upload the icon as a blob to the user's PDS
|
||||
blobRef, err := s.ctx.ATProtoClient.UploadBlob(ctx, iconData, mimeType)
|
||||
blobRef, err := s.ctx.GetATProtoClient().UploadBlob(ctx, iconData, mimeType)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to upload icon blob", "url", iconURL, "error", err)
|
||||
return nil
|
||||
|
||||
@@ -8,15 +8,13 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
"github.com/distribution/distribution/v3"
|
||||
"github.com/opencontainers/go-digest"
|
||||
)
|
||||
|
||||
// mockDatabaseMetrics removed - using the one from context_test.go
|
||||
|
||||
// mockBlobStore is a minimal mock of distribution.BlobStore for testing
|
||||
type mockBlobStore struct {
|
||||
blobs map[digest.Digest][]byte
|
||||
@@ -72,16 +70,11 @@ func (m *mockBlobStore) Open(ctx context.Context, dgst digest.Digest) (io.ReadSe
|
||||
return nil, nil // Not needed for current tests
|
||||
}
|
||||
|
||||
// mockRegistryContext creates a mock RegistryContext for testing
|
||||
func mockRegistryContext(client *atproto.Client, repository, holdDID, did, handle string, database DatabaseMetrics) *RegistryContext {
|
||||
return &RegistryContext{
|
||||
ATProtoClient: client,
|
||||
Repository: repository,
|
||||
HoldDID: holdDID,
|
||||
DID: did,
|
||||
Handle: handle,
|
||||
Database: database,
|
||||
}
|
||||
// mockUserContextForManifest creates a mock auth.UserContext for manifest store testing
|
||||
func mockUserContextForManifest(pdsEndpoint, repository, holdDID, ownerDID, ownerHandle string) *auth.UserContext {
|
||||
userCtx := auth.NewUserContext(ownerDID, "oauth", "PUT", nil)
|
||||
userCtx.SetTarget(ownerDID, ownerHandle, pdsEndpoint, repository, holdDID)
|
||||
return userCtx
|
||||
}
|
||||
|
||||
// TestDigestToRKey tests digest to record key conversion
|
||||
@@ -115,24 +108,27 @@ func TestDigestToRKey(t *testing.T) {
|
||||
|
||||
// TestNewManifestStore tests creating a new manifest store
|
||||
func TestNewManifestStore(t *testing.T) {
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
blobStore := newMockBlobStore()
|
||||
db := &mockDatabaseMetrics{}
|
||||
userCtx := mockUserContextForManifest(
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:alice123",
|
||||
"alice.test",
|
||||
)
|
||||
store := NewManifestStore(userCtx, blobStore, nil)
|
||||
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:alice123", "alice.test", db)
|
||||
store := NewManifestStore(ctx, blobStore)
|
||||
|
||||
if store.ctx.Repository != "myapp" {
|
||||
t.Errorf("repository = %v, want myapp", store.ctx.Repository)
|
||||
if store.ctx.TargetRepo != "myapp" {
|
||||
t.Errorf("repository = %v, want myapp", store.ctx.TargetRepo)
|
||||
}
|
||||
if store.ctx.HoldDID != "did:web:hold.example.com" {
|
||||
t.Errorf("holdDID = %v, want did:web:hold.example.com", store.ctx.HoldDID)
|
||||
if store.ctx.TargetHoldDID != "did:web:hold.example.com" {
|
||||
t.Errorf("holdDID = %v, want did:web:hold.example.com", store.ctx.TargetHoldDID)
|
||||
}
|
||||
if store.ctx.DID != "did:plc:alice123" {
|
||||
t.Errorf("did = %v, want did:plc:alice123", store.ctx.DID)
|
||||
if store.ctx.TargetOwnerDID != "did:plc:alice123" {
|
||||
t.Errorf("did = %v, want did:plc:alice123", store.ctx.TargetOwnerDID)
|
||||
}
|
||||
if store.ctx.Handle != "alice.test" {
|
||||
t.Errorf("handle = %v, want alice.test", store.ctx.Handle)
|
||||
if store.ctx.TargetOwnerHandle != "alice.test" {
|
||||
t.Errorf("handle = %v, want alice.test", store.ctx.TargetOwnerHandle)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,9 +183,14 @@ func TestExtractConfigLabels(t *testing.T) {
|
||||
blobStore.blobs[configDigest] = configData
|
||||
|
||||
// Create manifest store
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, blobStore)
|
||||
userCtx := mockUserContextForManifest(
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, blobStore, nil)
|
||||
|
||||
// Extract labels
|
||||
labels, err := store.extractConfigLabels(context.Background(), configDigest.String())
|
||||
@@ -227,9 +228,14 @@ func TestExtractConfigLabels_NoLabels(t *testing.T) {
|
||||
configDigest := digest.FromBytes(configData)
|
||||
blobStore.blobs[configDigest] = configData
|
||||
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, blobStore)
|
||||
userCtx := mockUserContextForManifest(
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, blobStore, nil)
|
||||
|
||||
labels, err := store.extractConfigLabels(context.Background(), configDigest.String())
|
||||
if err != nil {
|
||||
@@ -245,9 +251,14 @@ func TestExtractConfigLabels_NoLabels(t *testing.T) {
|
||||
// TestExtractConfigLabels_InvalidDigest tests error handling for invalid digest
|
||||
func TestExtractConfigLabels_InvalidDigest(t *testing.T) {
|
||||
blobStore := newMockBlobStore()
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, blobStore)
|
||||
userCtx := mockUserContextForManifest(
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, blobStore, nil)
|
||||
|
||||
_, err := store.extractConfigLabels(context.Background(), "invalid-digest")
|
||||
if err == nil {
|
||||
@@ -264,9 +275,14 @@ func TestExtractConfigLabels_InvalidJSON(t *testing.T) {
|
||||
configDigest := digest.FromBytes(configData)
|
||||
blobStore.blobs[configDigest] = configData
|
||||
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, blobStore)
|
||||
userCtx := mockUserContextForManifest(
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, blobStore, nil)
|
||||
|
||||
_, err := store.extractConfigLabels(context.Background(), configDigest.String())
|
||||
if err == nil {
|
||||
@@ -274,28 +290,18 @@ func TestExtractConfigLabels_InvalidJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestManifestStore_WithMetrics tests that metrics are tracked
|
||||
func TestManifestStore_WithMetrics(t *testing.T) {
|
||||
db := &mockDatabaseMetrics{}
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:alice123", "alice.test", db)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
// TestManifestStore_WithoutDatabase tests that nil database is acceptable
|
||||
func TestManifestStore_WithoutDatabase(t *testing.T) {
|
||||
userCtx := mockUserContextForManifest(
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:alice123",
|
||||
"alice.test",
|
||||
)
|
||||
store := NewManifestStore(userCtx, nil, nil)
|
||||
|
||||
if store.ctx.Database != db {
|
||||
t.Error("ManifestStore should store database reference")
|
||||
}
|
||||
|
||||
// Note: Actual metrics tracking happens in Put() and Get() which require
|
||||
// full mock setup. The important thing is that the database is wired up.
|
||||
}
|
||||
|
||||
// TestManifestStore_WithoutMetrics tests that nil database is acceptable
|
||||
func TestManifestStore_WithoutMetrics(t *testing.T) {
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:alice123", "alice.test", nil)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
|
||||
if store.ctx.Database != nil {
|
||||
if store.sqlDB != nil {
|
||||
t.Error("ManifestStore should accept nil database")
|
||||
}
|
||||
}
|
||||
@@ -345,9 +351,14 @@ func TestManifestStore_Exists(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
userCtx := mockUserContextForManifest(
|
||||
server.URL,
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, nil, nil)
|
||||
|
||||
exists, err := store.Exists(context.Background(), tt.digest)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -463,10 +474,14 @@ func TestManifestStore_Get(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
db := &mockDatabaseMetrics{}
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", db)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
userCtx := mockUserContextForManifest(
|
||||
server.URL,
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, nil, nil)
|
||||
|
||||
manifest, err := store.Get(context.Background(), tt.digest)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -487,82 +502,6 @@ func TestManifestStore_Get(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestManifestStore_Get_OnlyCountsGETRequests verifies that HEAD requests don't increment pull count
|
||||
func TestManifestStore_Get_OnlyCountsGETRequests(t *testing.T) {
|
||||
ociManifest := []byte(`{"schemaVersion":2}`)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
httpMethod string
|
||||
expectPullIncrement bool
|
||||
}{
|
||||
{
|
||||
name: "GET request increments pull count",
|
||||
httpMethod: "GET",
|
||||
expectPullIncrement: true,
|
||||
},
|
||||
{
|
||||
name: "HEAD request does not increment pull count",
|
||||
httpMethod: "HEAD",
|
||||
expectPullIncrement: false,
|
||||
},
|
||||
{
|
||||
name: "POST request does not increment pull count",
|
||||
httpMethod: "POST",
|
||||
expectPullIncrement: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == atproto.SyncGetBlob {
|
||||
w.Write(ociManifest)
|
||||
return
|
||||
}
|
||||
w.Write([]byte(`{
|
||||
"uri": "at://did:plc:test123/io.atcr.manifest/abc123",
|
||||
"value": {
|
||||
"$type":"io.atcr.manifest",
|
||||
"holdDid":"did:web:hold01.atcr.io",
|
||||
"mediaType":"application/vnd.oci.image.manifest.v1+json",
|
||||
"manifestBlob":{"ref":{"$link":"bafytest"},"size":100}
|
||||
}
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
mockDB := &mockDatabaseMetrics{}
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold01.atcr.io", "did:plc:test123", "test.handle", mockDB)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
|
||||
// Create a context with the HTTP method stored (as distribution library does)
|
||||
testCtx := context.WithValue(context.Background(), "http.request.method", tt.httpMethod)
|
||||
|
||||
_, err := store.Get(testCtx, "sha256:abc123")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
|
||||
// Wait for async goroutine to complete (metrics are incremented asynchronously)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if tt.expectPullIncrement {
|
||||
// Check that IncrementPullCount was called
|
||||
if mockDB.getPullCount() == 0 {
|
||||
t.Error("Expected pull count to be incremented for GET request, but it wasn't")
|
||||
}
|
||||
} else {
|
||||
// Check that IncrementPullCount was NOT called
|
||||
if mockDB.getPullCount() > 0 {
|
||||
t.Errorf("Expected pull count NOT to be incremented for %s request, but it was (count=%d)", tt.httpMethod, mockDB.getPullCount())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestManifestStore_Put tests storing manifests
|
||||
func TestManifestStore_Put(t *testing.T) {
|
||||
ociManifest := []byte(`{
|
||||
@@ -654,10 +593,14 @@ func TestManifestStore_Put(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
db := &mockDatabaseMetrics{}
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", db)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
userCtx := mockUserContextForManifest(
|
||||
server.URL,
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, nil, nil)
|
||||
|
||||
dgst, err := store.Put(context.Background(), tt.manifest, tt.options...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -706,8 +649,13 @@ func TestManifestStore_Put_WithConfigLabels(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", nil)
|
||||
userCtx := mockUserContextForManifest(
|
||||
server.URL,
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
|
||||
// Use config digest in manifest
|
||||
ociManifestWithConfig := []byte(`{
|
||||
@@ -722,7 +670,7 @@ func TestManifestStore_Put_WithConfigLabels(t *testing.T) {
|
||||
payload: ociManifestWithConfig,
|
||||
}
|
||||
|
||||
store := NewManifestStore(ctx, blobStore)
|
||||
store := NewManifestStore(userCtx, blobStore, nil)
|
||||
|
||||
_, err := store.Put(context.Background(), manifest)
|
||||
if err != nil {
|
||||
@@ -782,9 +730,14 @@ func TestManifestStore_Delete(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
userCtx := mockUserContextForManifest(
|
||||
server.URL,
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, nil, nil)
|
||||
|
||||
err := store.Delete(context.Background(), tt.digest)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -938,10 +891,14 @@ func TestManifestStore_Put_ManifestListValidation(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
db := &mockDatabaseMetrics{}
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", db)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
userCtx := mockUserContextForManifest(
|
||||
server.URL,
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, nil, nil)
|
||||
|
||||
manifest := &rawManifest{
|
||||
mediaType: "application/vnd.oci.image.index.v1+json",
|
||||
@@ -1015,9 +972,14 @@ func TestManifestStore_Put_ManifestListValidation_MultipleChildren(t *testing.T)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
userCtx := mockUserContextForManifest(
|
||||
server.URL,
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, nil, nil)
|
||||
|
||||
// Create manifest list with both children
|
||||
manifestList := []byte(`{
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
"github.com/distribution/distribution/v3"
|
||||
"github.com/distribution/distribution/v3/registry/api/errcode"
|
||||
"github.com/opencontainers/go-digest"
|
||||
@@ -32,20 +33,20 @@ var (
|
||||
|
||||
// ProxyBlobStore proxies blob requests to an external storage service
|
||||
type ProxyBlobStore struct {
|
||||
ctx *RegistryContext // All context and services
|
||||
holdURL string // Resolved HTTP URL for XRPC requests
|
||||
ctx *auth.UserContext // User context with identity, target, permissions
|
||||
holdURL string // Resolved HTTP URL for XRPC requests
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewProxyBlobStore creates a new proxy blob store
|
||||
func NewProxyBlobStore(ctx *RegistryContext) *ProxyBlobStore {
|
||||
func NewProxyBlobStore(userCtx *auth.UserContext) *ProxyBlobStore {
|
||||
// Resolve DID to URL once at construction time
|
||||
holdURL := atproto.ResolveHoldURL(ctx.HoldDID)
|
||||
holdURL := atproto.ResolveHoldURL(userCtx.TargetHoldDID)
|
||||
|
||||
slog.Debug("NewProxyBlobStore created", "component", "proxy_blob_store", "hold_did", ctx.HoldDID, "hold_url", holdURL, "user_did", ctx.DID, "repo", ctx.Repository)
|
||||
slog.Debug("NewProxyBlobStore created", "component", "proxy_blob_store", "hold_did", userCtx.TargetHoldDID, "hold_url", holdURL, "user_did", userCtx.TargetOwnerDID, "repo", userCtx.TargetRepo)
|
||||
|
||||
return &ProxyBlobStore{
|
||||
ctx: ctx,
|
||||
ctx: userCtx,
|
||||
holdURL: holdURL,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Minute, // Timeout for presigned URL requests and uploads
|
||||
@@ -61,32 +62,33 @@ func NewProxyBlobStore(ctx *RegistryContext) *ProxyBlobStore {
|
||||
}
|
||||
|
||||
// doAuthenticatedRequest performs an HTTP request with service token authentication
|
||||
// Uses the service token from middleware to authenticate requests to the hold service
|
||||
// Uses the service token from UserContext to authenticate requests to the hold service
|
||||
func (p *ProxyBlobStore) doAuthenticatedRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
// Use service token that middleware already validated and cached
|
||||
// Middleware fails fast with HTTP 401 if OAuth session is invalid
|
||||
if p.ctx.ServiceToken == "" {
|
||||
// Get service token from UserContext (lazy-loaded and cached per holdDID)
|
||||
serviceToken, err := p.ctx.GetServiceToken(ctx)
|
||||
if err != nil {
|
||||
slog.Error("Failed to get service token", "component", "proxy_blob_store", "did", p.ctx.DID, "error", err)
|
||||
return nil, fmt.Errorf("failed to get service token: %w", err)
|
||||
}
|
||||
if serviceToken == "" {
|
||||
// Should never happen - middleware validates OAuth before handlers run
|
||||
slog.Error("No service token in context", "component", "proxy_blob_store", "did", p.ctx.DID)
|
||||
return nil, fmt.Errorf("no service token available (middleware should have validated)")
|
||||
}
|
||||
|
||||
// Add Bearer token to Authorization header
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", p.ctx.ServiceToken))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", serviceToken))
|
||||
|
||||
return p.httpClient.Do(req)
|
||||
}
|
||||
|
||||
// checkReadAccess validates that the user has read access to blobs in this hold
|
||||
func (p *ProxyBlobStore) checkReadAccess(ctx context.Context) error {
|
||||
if p.ctx.Authorizer == nil {
|
||||
return nil // No authorization check if authorizer not configured
|
||||
}
|
||||
allowed, err := p.ctx.Authorizer.CheckReadAccess(ctx, p.ctx.HoldDID, p.ctx.DID)
|
||||
canRead, err := p.ctx.CanRead(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authorization check failed: %w", err)
|
||||
}
|
||||
if !allowed {
|
||||
if !canRead {
|
||||
// Return 403 Forbidden instead of masquerading as missing blob
|
||||
return errcode.ErrorCodeDenied.WithMessage("read access denied")
|
||||
}
|
||||
@@ -95,21 +97,17 @@ func (p *ProxyBlobStore) checkReadAccess(ctx context.Context) error {
|
||||
|
||||
// checkWriteAccess validates that the user has write access to blobs in this hold
|
||||
func (p *ProxyBlobStore) checkWriteAccess(ctx context.Context) error {
|
||||
if p.ctx.Authorizer == nil {
|
||||
return nil // No authorization check if authorizer not configured
|
||||
}
|
||||
|
||||
slog.Debug("Checking write access", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.HoldDID)
|
||||
allowed, err := p.ctx.Authorizer.CheckWriteAccess(ctx, p.ctx.HoldDID, p.ctx.DID)
|
||||
slog.Debug("Checking write access", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.TargetHoldDID)
|
||||
canWrite, err := p.ctx.CanWrite(ctx)
|
||||
if err != nil {
|
||||
slog.Error("Authorization check error", "component", "proxy_blob_store", "error", err)
|
||||
return fmt.Errorf("authorization check failed: %w", err)
|
||||
}
|
||||
if !allowed {
|
||||
slog.Warn("Write access denied", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.HoldDID)
|
||||
return errcode.ErrorCodeDenied.WithMessage(fmt.Sprintf("write access denied to hold %s", p.ctx.HoldDID))
|
||||
if !canWrite {
|
||||
slog.Warn("Write access denied", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.TargetHoldDID)
|
||||
return errcode.ErrorCodeDenied.WithMessage(fmt.Sprintf("write access denied to hold %s", p.ctx.TargetHoldDID))
|
||||
}
|
||||
slog.Debug("Write access allowed", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.HoldDID)
|
||||
slog.Debug("Write access allowed", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.TargetHoldDID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -356,10 +354,10 @@ func (p *ProxyBlobStore) Resume(ctx context.Context, id string) (distribution.Bl
|
||||
// getPresignedURL returns the XRPC endpoint URL for blob operations
|
||||
func (p *ProxyBlobStore) getPresignedURL(ctx context.Context, operation string, dgst digest.Digest) (string, error) {
|
||||
// Use XRPC endpoint: /xrpc/com.atproto.sync.getBlob?did={userDID}&cid={digest}
|
||||
// The 'did' parameter is the USER's DID (whose blob we're fetching), not the hold service DID
|
||||
// The 'did' parameter is the TARGET OWNER's DID (whose blob we're fetching), not the hold service DID
|
||||
// Per migration doc: hold accepts OCI digest directly as cid parameter (checks for sha256: prefix)
|
||||
xrpcURL := fmt.Sprintf("%s%s?did=%s&cid=%s&method=%s",
|
||||
p.holdURL, atproto.SyncGetBlob, p.ctx.DID, dgst.String(), operation)
|
||||
p.holdURL, atproto.SyncGetBlob, p.ctx.TargetOwnerDID, dgst.String(), operation)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", xrpcURL, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,24 +1,20 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
"github.com/opencontainers/go-digest"
|
||||
)
|
||||
|
||||
// TestGetServiceToken_CachingLogic tests the token caching mechanism
|
||||
// TestGetServiceToken_CachingLogic tests the global service token caching mechanism
|
||||
// These tests use the global auth cache functions directly
|
||||
func TestGetServiceToken_CachingLogic(t *testing.T) {
|
||||
userDID := "did:plc:test"
|
||||
userDID := "did:plc:cache-test"
|
||||
holdDID := "did:web:hold.example.com"
|
||||
|
||||
// Test 1: Empty cache - invalidate any existing token
|
||||
@@ -30,7 +26,6 @@ func TestGetServiceToken_CachingLogic(t *testing.T) {
|
||||
|
||||
// Test 2: Insert token into cache
|
||||
// Create a JWT-like token with exp claim for testing
|
||||
// Format: header.payload.signature where payload has exp claim
|
||||
testPayload := fmt.Sprintf(`{"exp":%d}`, time.Now().Add(50*time.Second).Unix())
|
||||
testToken := "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(testPayload) + ".signature"
|
||||
|
||||
@@ -70,129 +65,33 @@ func base64URLEncode(data string) string {
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString([]byte(data)), "=")
|
||||
}
|
||||
|
||||
// TestServiceToken_EmptyInContext tests that operations fail when service token is missing
|
||||
func TestServiceToken_EmptyInContext(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
ServiceToken: "", // No service token (middleware didn't set it)
|
||||
Refresher: nil,
|
||||
}
|
||||
// mockUserContextForProxy creates a mock auth.UserContext for proxy blob store testing.
|
||||
// It sets up both the user identity and target info, and configures test helpers
|
||||
// to bypass network calls.
|
||||
func mockUserContextForProxy(did, holdDID, pdsEndpoint, repository string) *auth.UserContext {
|
||||
userCtx := auth.NewUserContext(did, "oauth", "PUT", nil)
|
||||
userCtx.SetTarget(did, "test.handle", pdsEndpoint, repository, holdDID)
|
||||
|
||||
store := NewProxyBlobStore(ctx)
|
||||
// Bypass PDS resolution (avoids network calls)
|
||||
userCtx.SetPDSForTest("test.handle", pdsEndpoint)
|
||||
|
||||
// Try a write operation that requires authentication
|
||||
testDigest := digest.FromString("test-content")
|
||||
_, err := store.Stat(context.Background(), testDigest)
|
||||
// Set up mock authorizer that allows access
|
||||
userCtx.SetAuthorizerForTest(auth.NewMockHoldAuthorizer())
|
||||
|
||||
// Should fail because no service token is available
|
||||
if err == nil {
|
||||
t.Error("Expected error when service token is empty")
|
||||
}
|
||||
// Set default hold DID for push resolution
|
||||
userCtx.SetDefaultHoldDIDForTest(holdDID)
|
||||
|
||||
// Error should indicate authentication issue
|
||||
if !strings.Contains(err.Error(), "UNAUTHORIZED") && !strings.Contains(err.Error(), "authentication") {
|
||||
t.Logf("Got error (acceptable): %v", err)
|
||||
}
|
||||
return userCtx
|
||||
}
|
||||
|
||||
// TestDoAuthenticatedRequest_BearerTokenInjection tests that Bearer tokens are added to requests
|
||||
func TestDoAuthenticatedRequest_BearerTokenInjection(t *testing.T) {
|
||||
// This test verifies the Bearer token injection logic
|
||||
|
||||
testToken := "test-bearer-token-xyz"
|
||||
|
||||
// Create a test server to verify the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
// Create ProxyBlobStore with service token in context (set by middleware)
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:bearer-test",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
ServiceToken: testToken, // Service token from middleware
|
||||
Refresher: nil,
|
||||
}
|
||||
|
||||
store := NewProxyBlobStore(ctx)
|
||||
|
||||
// Create request
|
||||
req, err := http.NewRequest(http.MethodGet, testServer.URL+"/test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
// Do authenticated request
|
||||
resp, err := store.doAuthenticatedRequest(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("doAuthenticatedRequest failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Verify Bearer token was added
|
||||
expectedHeader := "Bearer " + testToken
|
||||
if receivedAuthHeader != expectedHeader {
|
||||
t.Errorf("Expected Authorization header %s, got %s", expectedHeader, receivedAuthHeader)
|
||||
}
|
||||
// mockUserContextForProxyWithToken creates a mock UserContext with a pre-populated service token.
|
||||
func mockUserContextForProxyWithToken(did, holdDID, pdsEndpoint, repository, serviceToken string) *auth.UserContext {
|
||||
userCtx := mockUserContextForProxy(did, holdDID, pdsEndpoint, repository)
|
||||
userCtx.SetServiceTokenForTest(holdDID, serviceToken)
|
||||
return userCtx
|
||||
}
|
||||
|
||||
// TestDoAuthenticatedRequest_ErrorWhenTokenUnavailable tests that authentication failures return proper errors
|
||||
func TestDoAuthenticatedRequest_ErrorWhenTokenUnavailable(t *testing.T) {
|
||||
// Create test server (should not be called since auth fails first)
|
||||
called := false
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
// Create ProxyBlobStore without service token (middleware didn't set it)
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:fallback",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
ServiceToken: "", // No service token
|
||||
Refresher: nil,
|
||||
}
|
||||
|
||||
store := NewProxyBlobStore(ctx)
|
||||
|
||||
// Create request
|
||||
req, err := http.NewRequest(http.MethodGet, testServer.URL+"/test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
// Do authenticated request - should fail when no service token
|
||||
resp, err := store.doAuthenticatedRequest(context.Background(), req)
|
||||
if err == nil {
|
||||
t.Fatal("Expected doAuthenticatedRequest to fail when no service token is available")
|
||||
}
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// Verify error indicates authentication/authorization issue
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, "service token") && !strings.Contains(errStr, "UNAUTHORIZED") {
|
||||
t.Errorf("Expected service token or unauthorized error, got: %v", err)
|
||||
}
|
||||
|
||||
if called {
|
||||
t.Error("Expected request to NOT be made when authentication fails")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveHoldURL tests DID to URL conversion
|
||||
// TestResolveHoldURL tests DID to URL conversion (pure function)
|
||||
func TestResolveHoldURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -200,7 +99,7 @@ func TestResolveHoldURL(t *testing.T) {
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "did:web with http (TEST_MODE)",
|
||||
name: "did:web with http (localhost)",
|
||||
holdDID: "did:web:localhost:8080",
|
||||
expected: "http://localhost:8080",
|
||||
},
|
||||
@@ -228,7 +127,7 @@ func TestResolveHoldURL(t *testing.T) {
|
||||
|
||||
// TestServiceTokenCacheExpiry tests that expired cached tokens are not used
|
||||
func TestServiceTokenCacheExpiry(t *testing.T) {
|
||||
userDID := "did:plc:expiry"
|
||||
userDID := "did:plc:expiry-test"
|
||||
holdDID := "did:web:hold.example.com"
|
||||
|
||||
// Insert expired token
|
||||
@@ -272,20 +171,20 @@ func TestServiceTokenCacheKeyFormat(t *testing.T) {
|
||||
|
||||
// TestNewProxyBlobStore tests ProxyBlobStore creation
|
||||
func TestNewProxyBlobStore(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
}
|
||||
userCtx := mockUserContextForProxy(
|
||||
"did:plc:test",
|
||||
"did:web:hold.example.com",
|
||||
"https://pds.example.com",
|
||||
"test-repo",
|
||||
)
|
||||
|
||||
store := NewProxyBlobStore(ctx)
|
||||
store := NewProxyBlobStore(userCtx)
|
||||
|
||||
if store == nil {
|
||||
t.Fatal("Expected non-nil ProxyBlobStore")
|
||||
}
|
||||
|
||||
if store.ctx != ctx {
|
||||
if store.ctx != userCtx {
|
||||
t.Error("Expected context to be set")
|
||||
}
|
||||
|
||||
@@ -321,296 +220,55 @@ func BenchmarkServiceTokenCacheAccess(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompleteMultipartUpload_JSONFormat verifies the JSON request format sent to hold service
|
||||
// This test would have caught the "partNumber" vs "part_number" bug
|
||||
func TestCompleteMultipartUpload_JSONFormat(t *testing.T) {
|
||||
var capturedBody map[string]any
|
||||
// TestParseJWTExpiry tests JWT expiry parsing
|
||||
func TestParseJWTExpiry(t *testing.T) {
|
||||
// Create a JWT with known expiry
|
||||
futureTime := time.Now().Add(1 * time.Hour).Unix()
|
||||
testPayload := fmt.Sprintf(`{"exp":%d}`, futureTime)
|
||||
testToken := "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(testPayload) + ".signature"
|
||||
|
||||
// Mock hold service that captures the request body
|
||||
holdServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.URL.Path, atproto.HoldCompleteUpload) {
|
||||
t.Errorf("Wrong endpoint called: %s", r.URL.Path)
|
||||
}
|
||||
|
||||
// Capture request body
|
||||
var body map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Errorf("Failed to decode request body: %v", err)
|
||||
}
|
||||
capturedBody = body
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{}`))
|
||||
}))
|
||||
defer holdServer.Close()
|
||||
|
||||
// Create store with mocked hold URL
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
ServiceToken: "test-service-token", // Service token from middleware
|
||||
}
|
||||
store := NewProxyBlobStore(ctx)
|
||||
store.holdURL = holdServer.URL
|
||||
|
||||
// Call completeMultipartUpload
|
||||
parts := []CompletedPart{
|
||||
{PartNumber: 1, ETag: "etag-1"},
|
||||
{PartNumber: 2, ETag: "etag-2"},
|
||||
}
|
||||
err := store.completeMultipartUpload(context.Background(), "sha256:abc123", "upload-id-xyz", parts)
|
||||
expiry, err := auth.ParseJWTExpiry(testToken)
|
||||
if err != nil {
|
||||
t.Fatalf("completeMultipartUpload failed: %v", err)
|
||||
t.Fatalf("ParseJWTExpiry failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify JSON format
|
||||
if capturedBody == nil {
|
||||
t.Fatal("No request body was captured")
|
||||
}
|
||||
|
||||
// Check top-level fields
|
||||
if uploadID, ok := capturedBody["uploadId"].(string); !ok || uploadID != "upload-id-xyz" {
|
||||
t.Errorf("Expected uploadId='upload-id-xyz', got %v", capturedBody["uploadId"])
|
||||
}
|
||||
if digest, ok := capturedBody["digest"].(string); !ok || digest != "sha256:abc123" {
|
||||
t.Errorf("Expected digest='sha256:abc123', got %v", capturedBody["digest"])
|
||||
}
|
||||
|
||||
// Check parts array
|
||||
partsArray, ok := capturedBody["parts"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected parts to be array, got %T", capturedBody["parts"])
|
||||
}
|
||||
if len(partsArray) != 2 {
|
||||
t.Fatalf("Expected 2 parts, got %d", len(partsArray))
|
||||
}
|
||||
|
||||
// Verify first part has "part_number" (not "partNumber")
|
||||
part0, ok := partsArray[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected part to be object, got %T", partsArray[0])
|
||||
}
|
||||
|
||||
// THIS IS THE KEY CHECK - would have caught the bug
|
||||
if _, hasPartNumber := part0["partNumber"]; hasPartNumber {
|
||||
t.Error("Found 'partNumber' (camelCase) - should be 'part_number' (snake_case)")
|
||||
}
|
||||
if partNum, ok := part0["part_number"].(float64); !ok || int(partNum) != 1 {
|
||||
t.Errorf("Expected part_number=1, got %v", part0["part_number"])
|
||||
}
|
||||
if etag, ok := part0["etag"].(string); !ok || etag != "etag-1" {
|
||||
t.Errorf("Expected etag='etag-1', got %v", part0["etag"])
|
||||
// Verify expiry is close to what we set (within 1 second tolerance)
|
||||
expectedExpiry := time.Unix(futureTime, 0)
|
||||
diff := expiry.Sub(expectedExpiry)
|
||||
if diff < -time.Second || diff > time.Second {
|
||||
t.Errorf("Expiry mismatch: expected %v, got %v", expectedExpiry, expiry)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGet_UsesPresignedURLDirectly verifies that Get() doesn't add auth headers to presigned URLs
|
||||
// This test would have caught the presigned URL authentication bug
|
||||
func TestGet_UsesPresignedURLDirectly(t *testing.T) {
|
||||
blobData := []byte("test blob content")
|
||||
var s3ReceivedAuthHeader string
|
||||
|
||||
// Mock S3 server that rejects requests with Authorization header
|
||||
s3Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s3ReceivedAuthHeader = r.Header.Get("Authorization")
|
||||
|
||||
// Presigned URLs should NOT have Authorization header
|
||||
if s3ReceivedAuthHeader != "" {
|
||||
t.Errorf("S3 received Authorization header: %s (should be empty for presigned URLs)", s3ReceivedAuthHeader)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`<?xml version="1.0"?><Error><Code>SignatureDoesNotMatch</Code></Error>`))
|
||||
return
|
||||
}
|
||||
|
||||
// Return blob data
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(blobData)
|
||||
}))
|
||||
defer s3Server.Close()
|
||||
|
||||
// Mock hold service that returns presigned S3 URL
|
||||
holdServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Return presigned URL pointing to S3 server
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
resp := map[string]string{
|
||||
"url": s3Server.URL + "/blob?X-Amz-Signature=fake-signature",
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer holdServer.Close()
|
||||
|
||||
// Create store with service token in context
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
ServiceToken: "test-service-token", // Service token from middleware
|
||||
}
|
||||
store := NewProxyBlobStore(ctx)
|
||||
store.holdURL = holdServer.URL
|
||||
|
||||
// Call Get()
|
||||
dgst := digest.FromBytes(blobData)
|
||||
retrieved, err := store.Get(context.Background(), dgst)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify correct data was retrieved
|
||||
if string(retrieved) != string(blobData) {
|
||||
t.Errorf("Expected data=%s, got %s", string(blobData), string(retrieved))
|
||||
}
|
||||
|
||||
// Verify S3 received NO Authorization header
|
||||
if s3ReceivedAuthHeader != "" {
|
||||
t.Errorf("S3 should not receive Authorization header for presigned URLs, got: %s", s3ReceivedAuthHeader)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpen_UsesPresignedURLDirectly verifies that Open() doesn't add auth headers to presigned URLs
|
||||
// This test would have caught the presigned URL authentication bug
|
||||
func TestOpen_UsesPresignedURLDirectly(t *testing.T) {
|
||||
blobData := []byte("test blob stream content")
|
||||
var s3ReceivedAuthHeader string
|
||||
|
||||
// Mock S3 server
|
||||
s3Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s3ReceivedAuthHeader = r.Header.Get("Authorization")
|
||||
|
||||
// Presigned URLs should NOT have Authorization header
|
||||
if s3ReceivedAuthHeader != "" {
|
||||
t.Errorf("S3 received Authorization header: %s (should be empty)", s3ReceivedAuthHeader)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(blobData)
|
||||
}))
|
||||
defer s3Server.Close()
|
||||
|
||||
// Mock hold service
|
||||
holdServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"url": s3Server.URL + "/blob?X-Amz-Signature=fake",
|
||||
})
|
||||
}))
|
||||
defer holdServer.Close()
|
||||
|
||||
// Create store with service token in context
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
ServiceToken: "test-service-token", // Service token from middleware
|
||||
}
|
||||
store := NewProxyBlobStore(ctx)
|
||||
store.holdURL = holdServer.URL
|
||||
|
||||
// Call Open()
|
||||
dgst := digest.FromBytes(blobData)
|
||||
reader, err := store.Open(context.Background(), dgst)
|
||||
if err != nil {
|
||||
t.Fatalf("Open() failed: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Verify S3 received NO Authorization header
|
||||
if s3ReceivedAuthHeader != "" {
|
||||
t.Errorf("S3 should not receive Authorization header for presigned URLs, got: %s", s3ReceivedAuthHeader)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultipartEndpoints_CorrectURLs verifies all multipart XRPC endpoints use correct URLs
|
||||
// This would have caught the old com.atproto.repo.uploadBlob vs new io.atcr.hold.* endpoints
|
||||
func TestMultipartEndpoints_CorrectURLs(t *testing.T) {
|
||||
// TestParseJWTExpiry_InvalidToken tests error handling for invalid tokens
|
||||
func TestParseJWTExpiry_InvalidToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
testFunc func(*ProxyBlobStore) error
|
||||
expectedPath string
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{
|
||||
name: "startMultipartUpload",
|
||||
testFunc: func(store *ProxyBlobStore) error {
|
||||
_, err := store.startMultipartUpload(context.Background(), "sha256:test")
|
||||
return err
|
||||
},
|
||||
expectedPath: atproto.HoldInitiateUpload,
|
||||
},
|
||||
{
|
||||
name: "getPartUploadInfo",
|
||||
testFunc: func(store *ProxyBlobStore) error {
|
||||
_, err := store.getPartUploadInfo(context.Background(), "sha256:test", "upload-123", 1)
|
||||
return err
|
||||
},
|
||||
expectedPath: atproto.HoldGetPartUploadURL,
|
||||
},
|
||||
{
|
||||
name: "completeMultipartUpload",
|
||||
testFunc: func(store *ProxyBlobStore) error {
|
||||
parts := []CompletedPart{{PartNumber: 1, ETag: "etag1"}}
|
||||
return store.completeMultipartUpload(context.Background(), "sha256:test", "upload-123", parts)
|
||||
},
|
||||
expectedPath: atproto.HoldCompleteUpload,
|
||||
},
|
||||
{
|
||||
name: "abortMultipartUpload",
|
||||
testFunc: func(store *ProxyBlobStore) error {
|
||||
return store.abortMultipartUpload(context.Background(), "sha256:test", "upload-123")
|
||||
},
|
||||
expectedPath: atproto.HoldAbortUpload,
|
||||
},
|
||||
{"empty token", ""},
|
||||
{"single part", "header"},
|
||||
{"two parts", "header.payload"},
|
||||
{"invalid base64 payload", "header.!!!.signature"},
|
||||
{"missing exp claim", "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(`{"sub":"test"}`) + ".sig"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedPath string
|
||||
|
||||
// Mock hold service that captures request path
|
||||
holdServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedPath = r.URL.Path
|
||||
|
||||
// Return success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
resp := map[string]string{
|
||||
"uploadId": "test-upload-id",
|
||||
"url": "https://s3.example.com/presigned",
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer holdServer.Close()
|
||||
|
||||
// Create store with service token in context
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
ServiceToken: "test-service-token", // Service token from middleware
|
||||
}
|
||||
store := NewProxyBlobStore(ctx)
|
||||
store.holdURL = holdServer.URL
|
||||
|
||||
// Call the function
|
||||
_ = tt.testFunc(store) // Ignore error, we just care about the URL
|
||||
|
||||
// Verify correct endpoint was called
|
||||
if capturedPath != tt.expectedPath {
|
||||
t.Errorf("Expected endpoint %s, got %s", tt.expectedPath, capturedPath)
|
||||
}
|
||||
|
||||
// Verify it's NOT the old endpoint
|
||||
if strings.Contains(capturedPath, "com.atproto.repo.uploadBlob") {
|
||||
t.Error("Still using old com.atproto.repo.uploadBlob endpoint!")
|
||||
_, err := auth.ParseJWTExpiry(tt.token)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid token")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Tests for doAuthenticatedRequest, Get, Open, completeMultipartUpload, etc.
|
||||
// require complex dependency mocking (OAuth refresher, PDS resolution, HoldAuthorizer).
|
||||
// These should be tested at the integration level with proper infrastructure.
|
||||
//
|
||||
// The current unit tests cover:
|
||||
// - Global service token cache (auth.GetServiceToken, auth.SetServiceToken, etc.)
|
||||
// - URL resolution (atproto.ResolveHoldURL)
|
||||
// - JWT parsing (auth.ParseJWTExpiry)
|
||||
// - Store construction (NewProxyBlobStore)
|
||||
|
||||
@@ -6,94 +6,75 @@ package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log/slog"
|
||||
|
||||
"atcr.io/pkg/auth"
|
||||
"github.com/distribution/distribution/v3"
|
||||
"github.com/distribution/reference"
|
||||
)
|
||||
|
||||
// RoutingRepository routes manifests to ATProto and blobs to external hold service
|
||||
// The registry (AppView) is stateless and NEVER stores blobs locally
|
||||
// NOTE: A fresh instance is created per-request (see middleware/registry.go)
|
||||
// so no mutex is needed - each request has its own instance
|
||||
// RoutingRepository routes manifests to ATProto and blobs to external hold service.
|
||||
// The registry (AppView) is stateless and NEVER stores blobs locally.
|
||||
// A new instance is created per HTTP request - no caching or synchronization needed.
|
||||
type RoutingRepository struct {
|
||||
distribution.Repository
|
||||
Ctx *RegistryContext // All context and services (exported for token updates)
|
||||
manifestStore *ManifestStore // Manifest store instance (lazy-initialized)
|
||||
blobStore *ProxyBlobStore // Blob store instance (lazy-initialized)
|
||||
userCtx *auth.UserContext
|
||||
sqlDB *sql.DB
|
||||
}
|
||||
|
||||
// NewRoutingRepository creates a new routing repository
|
||||
func NewRoutingRepository(baseRepo distribution.Repository, ctx *RegistryContext) *RoutingRepository {
|
||||
func NewRoutingRepository(baseRepo distribution.Repository, userCtx *auth.UserContext, sqlDB *sql.DB) *RoutingRepository {
|
||||
return &RoutingRepository{
|
||||
Repository: baseRepo,
|
||||
Ctx: ctx,
|
||||
userCtx: userCtx,
|
||||
sqlDB: sqlDB,
|
||||
}
|
||||
}
|
||||
|
||||
// Manifests returns the ATProto-backed manifest service
|
||||
func (r *RoutingRepository) Manifests(ctx context.Context, options ...distribution.ManifestServiceOption) (distribution.ManifestService, error) {
|
||||
// Lazy-initialize manifest store (no mutex needed - one instance per request)
|
||||
if r.manifestStore == nil {
|
||||
// Ensure blob store is created first (needed for label extraction during push)
|
||||
blobStore := r.Blobs(ctx)
|
||||
r.manifestStore = NewManifestStore(r.Ctx, blobStore)
|
||||
}
|
||||
return r.manifestStore, nil
|
||||
// blobStore used to fetch labels from th
|
||||
blobStore := r.Blobs(ctx)
|
||||
return NewManifestStore(r.userCtx, blobStore, r.sqlDB), nil
|
||||
}
|
||||
|
||||
// Blobs returns a proxy blob store that routes to external hold service
|
||||
// The registry (AppView) NEVER stores blobs locally - all blobs go through hold service
|
||||
func (r *RoutingRepository) Blobs(ctx context.Context) distribution.BlobStore {
|
||||
// Return cached blob store if available (no mutex needed - one instance per request)
|
||||
if r.blobStore != nil {
|
||||
slog.Debug("Returning cached blob store", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository)
|
||||
return r.blobStore
|
||||
}
|
||||
|
||||
// Determine if this is a pull (GET/HEAD) or push (PUT/POST/etc) operation
|
||||
// Pull operations use the historical hold DID from the database (blobs are where they were pushed)
|
||||
// Push operations use the discovery-based hold DID from user's profile/default
|
||||
// This allows users to change their default hold and have new pushes go there
|
||||
isPull := false
|
||||
if method, ok := ctx.Value("http.request.method").(string); ok {
|
||||
isPull = method == "GET" || method == "HEAD"
|
||||
}
|
||||
|
||||
holdDID := r.Ctx.HoldDID // Default to discovery-based DID
|
||||
holdSource := "discovery"
|
||||
|
||||
// Only query database for pull operations
|
||||
if isPull && r.Ctx.Database != nil {
|
||||
// Query database for the latest manifest's hold DID
|
||||
if dbHoldDID, err := r.Ctx.Database.GetLatestHoldDIDForRepo(r.Ctx.DID, r.Ctx.Repository); err == nil && dbHoldDID != "" {
|
||||
// Use hold DID from database (pull case - use historical reference)
|
||||
holdDID = dbHoldDID
|
||||
holdSource = "database"
|
||||
slog.Debug("Using hold from database manifest (pull)", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", dbHoldDID)
|
||||
} else if err != nil {
|
||||
// Log error but don't fail - fall back to discovery-based DID
|
||||
slog.Warn("Failed to query database for hold DID", "component", "storage/blobs", "error", err)
|
||||
}
|
||||
// If dbHoldDID is empty (no manifests yet), fall through to use discovery-based DID
|
||||
// Resolve hold DID: pull uses DB lookup, push uses profile discovery
|
||||
holdDID, err := r.userCtx.ResolveHoldDID(ctx, r.sqlDB)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to resolve hold DID", "component", "storage/blobs", "error", err)
|
||||
holdDID = r.userCtx.TargetHoldDID
|
||||
}
|
||||
|
||||
if holdDID == "" {
|
||||
// This should never happen if middleware is configured correctly
|
||||
panic("hold DID not set in RegistryContext - ensure default_hold_did is configured in middleware")
|
||||
panic("hold DID not set - ensure default_hold_did is configured in middleware")
|
||||
}
|
||||
|
||||
slog.Debug("Using hold DID for blobs", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", holdDID, "source", holdSource)
|
||||
slog.Debug("Using hold DID for blobs", "component", "storage/blobs", "did", r.userCtx.TargetOwnerDID, "repo", r.userCtx.TargetRepo, "hold", holdDID, "action", r.userCtx.Action.String())
|
||||
|
||||
// Update context with the correct hold DID (may be from database or discovered)
|
||||
r.Ctx.HoldDID = holdDID
|
||||
|
||||
// Create and cache proxy blob store
|
||||
r.blobStore = NewProxyBlobStore(r.Ctx)
|
||||
return r.blobStore
|
||||
return NewProxyBlobStore(r.userCtx)
|
||||
}
|
||||
|
||||
// Tags returns the tag service
|
||||
// Tags are stored in ATProto as io.atcr.tag records
|
||||
func (r *RoutingRepository) Tags(ctx context.Context) distribution.TagService {
|
||||
return NewTagStore(r.Ctx.ATProtoClient, r.Ctx.Repository)
|
||||
return NewTagStore(r.userCtx.GetATProtoClient(), r.userCtx.TargetRepo)
|
||||
}
|
||||
|
||||
// Named returns a reference to the repository name.
|
||||
// If the base repository is set, it delegates to the base.
|
||||
// Otherwise, it constructs a name from the user context.
|
||||
func (r *RoutingRepository) Named() reference.Named {
|
||||
if r.Repository != nil {
|
||||
return r.Repository.Named()
|
||||
}
|
||||
// Construct from user context
|
||||
name, err := reference.WithName(r.userCtx.TargetRepo)
|
||||
if err != nil {
|
||||
// Fallback: return a simple reference
|
||||
name, _ = reference.WithName("unknown")
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
@@ -2,275 +2,117 @@ package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/distribution/distribution/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
)
|
||||
|
||||
// mockDatabase is a simple mock for testing
|
||||
type mockDatabase struct {
|
||||
holdDID string
|
||||
err error
|
||||
// mockUserContext creates a mock auth.UserContext for testing.
|
||||
// It sets up both the user identity and target info, and configures
|
||||
// test helpers to bypass network calls.
|
||||
func mockUserContext(did, authMethod, httpMethod, targetOwnerDID, targetOwnerHandle, targetOwnerPDS, targetRepo, targetHoldDID string) *auth.UserContext {
|
||||
userCtx := auth.NewUserContext(did, authMethod, httpMethod, nil)
|
||||
userCtx.SetTarget(targetOwnerDID, targetOwnerHandle, targetOwnerPDS, targetRepo, targetHoldDID)
|
||||
|
||||
// Bypass PDS resolution (avoids network calls)
|
||||
userCtx.SetPDSForTest(targetOwnerHandle, targetOwnerPDS)
|
||||
|
||||
// Set up mock authorizer that allows access
|
||||
userCtx.SetAuthorizerForTest(auth.NewMockHoldAuthorizer())
|
||||
|
||||
// Set default hold DID for push resolution
|
||||
userCtx.SetDefaultHoldDIDForTest(targetHoldDID)
|
||||
|
||||
return userCtx
|
||||
}
|
||||
|
||||
func (m *mockDatabase) IncrementPullCount(did, repository string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDatabase) IncrementPushCount(did, repository string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDatabase) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
|
||||
if m.err != nil {
|
||||
return "", m.err
|
||||
}
|
||||
return m.holdDID, nil
|
||||
// mockUserContextWithToken creates a mock UserContext with a pre-populated service token.
|
||||
func mockUserContextWithToken(did, authMethod, httpMethod, targetOwnerDID, targetOwnerHandle, targetOwnerPDS, targetRepo, targetHoldDID, serviceToken string) *auth.UserContext {
|
||||
userCtx := mockUserContext(did, authMethod, httpMethod, targetOwnerDID, targetOwnerHandle, targetOwnerPDS, targetRepo, targetHoldDID)
|
||||
userCtx.SetServiceTokenForTest(targetHoldDID, serviceToken)
|
||||
return userCtx
|
||||
}
|
||||
|
||||
func TestNewRoutingRepository(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "debian",
|
||||
HoldDID: "did:web:hold01.atcr.io",
|
||||
ATProtoClient: &atproto.Client{},
|
||||
userCtx := mockUserContext(
|
||||
"did:plc:test123", // authenticated user
|
||||
"oauth", // auth method
|
||||
"GET", // HTTP method
|
||||
"did:plc:test123", // target owner
|
||||
"test.handle", // target owner handle
|
||||
"https://pds.example.com", // target owner PDS
|
||||
"debian", // repository
|
||||
"did:web:hold01.atcr.io", // hold DID
|
||||
)
|
||||
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
|
||||
if repo.userCtx.TargetOwnerDID != "did:plc:test123" {
|
||||
t.Errorf("Expected TargetOwnerDID %q, got %q", "did:plc:test123", repo.userCtx.TargetOwnerDID)
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
|
||||
if repo.Ctx.DID != "did:plc:test123" {
|
||||
t.Errorf("Expected DID %q, got %q", "did:plc:test123", repo.Ctx.DID)
|
||||
if repo.userCtx.TargetRepo != "debian" {
|
||||
t.Errorf("Expected TargetRepo %q, got %q", "debian", repo.userCtx.TargetRepo)
|
||||
}
|
||||
|
||||
if repo.Ctx.Repository != "debian" {
|
||||
t.Errorf("Expected repository %q, got %q", "debian", repo.Ctx.Repository)
|
||||
}
|
||||
|
||||
if repo.manifestStore != nil {
|
||||
t.Error("Expected manifestStore to be nil initially")
|
||||
}
|
||||
|
||||
if repo.blobStore != nil {
|
||||
t.Error("Expected blobStore to be nil initially")
|
||||
if repo.userCtx.TargetHoldDID != "did:web:hold01.atcr.io" {
|
||||
t.Errorf("Expected TargetHoldDID %q, got %q", "did:web:hold01.atcr.io", repo.userCtx.TargetHoldDID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Manifests tests the Manifests() method
|
||||
func TestRoutingRepository_Manifests(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp",
|
||||
HoldDID: "did:web:hold01.atcr.io",
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
}
|
||||
userCtx := mockUserContext(
|
||||
"did:plc:test123",
|
||||
"oauth",
|
||||
"GET",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"did:web:hold01.atcr.io",
|
||||
)
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
manifestService, err := repo.Manifests(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, manifestService)
|
||||
|
||||
// Verify the manifest store is cached
|
||||
assert.NotNil(t, repo.manifestStore, "manifest store should be cached")
|
||||
|
||||
// Call again and verify we get the same instance
|
||||
manifestService2, err := repo.Manifests(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Same(t, manifestService, manifestService2, "should return cached manifest store")
|
||||
}
|
||||
|
||||
// TestRoutingRepository_ManifestStoreCaching tests that manifest store is cached
|
||||
func TestRoutingRepository_ManifestStoreCaching(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp",
|
||||
HoldDID: "did:web:hold01.atcr.io",
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
}
|
||||
// TestRoutingRepository_Blobs tests the Blobs() method
|
||||
func TestRoutingRepository_Blobs(t *testing.T) {
|
||||
userCtx := mockUserContext(
|
||||
"did:plc:test123",
|
||||
"oauth",
|
||||
"GET",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"did:web:hold01.atcr.io",
|
||||
)
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
|
||||
// First call creates the store
|
||||
store1, err := repo.Manifests(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, store1)
|
||||
|
||||
// Second call returns cached store
|
||||
store2, err := repo.Manifests(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Same(t, store1, store2, "should return cached manifest store instance")
|
||||
|
||||
// Verify internal cache
|
||||
assert.NotNil(t, repo.manifestStore)
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Blobs_PullUsesDatabase tests that GET and HEAD (pull) use database hold DID
|
||||
func TestRoutingRepository_Blobs_PullUsesDatabase(t *testing.T) {
|
||||
dbHoldDID := "did:web:database.hold.io"
|
||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
||||
|
||||
// Test both GET and HEAD as pull operations
|
||||
for _, method := range []string{"GET", "HEAD"} {
|
||||
// Reset context for each test
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp-" + method, // Unique repo to avoid caching
|
||||
HoldDID: discoveryHoldDID,
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
Database: &mockDatabase{holdDID: dbHoldDID},
|
||||
}
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
|
||||
pullCtx := context.WithValue(context.Background(), "http.request.method", method)
|
||||
blobStore := repo.Blobs(pullCtx)
|
||||
|
||||
assert.NotNil(t, blobStore)
|
||||
// Verify the hold DID was updated to use the database value for pull
|
||||
assert.Equal(t, dbHoldDID, repo.Ctx.HoldDID, "pull (%s) should use database hold DID", method)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Blobs_PushUsesDiscovery tests that push operations use discovery hold DID
|
||||
func TestRoutingRepository_Blobs_PushUsesDiscovery(t *testing.T) {
|
||||
dbHoldDID := "did:web:database.hold.io"
|
||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
method string
|
||||
}{
|
||||
{"PUT", "PUT"},
|
||||
{"POST", "POST"},
|
||||
// HEAD is now treated as pull (like GET) - see TestRoutingRepository_Blobs_Pull
|
||||
{"PATCH", "PATCH"},
|
||||
{"DELETE", "DELETE"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp-" + tc.method, // Unique repo to avoid caching
|
||||
HoldDID: discoveryHoldDID,
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
Database: &mockDatabase{holdDID: dbHoldDID},
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
|
||||
// Create context with push method
|
||||
pushCtx := context.WithValue(context.Background(), "http.request.method", tc.method)
|
||||
blobStore := repo.Blobs(pushCtx)
|
||||
|
||||
assert.NotNil(t, blobStore)
|
||||
// Verify the hold DID remains the discovery-based one for push operations
|
||||
assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "%s should use discovery hold DID, not database", tc.method)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Blobs_NoMethodUsesDiscovery tests that missing method defaults to discovery
|
||||
func TestRoutingRepository_Blobs_NoMethodUsesDiscovery(t *testing.T) {
|
||||
dbHoldDID := "did:web:database.hold.io"
|
||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
||||
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp-nomethod",
|
||||
HoldDID: discoveryHoldDID,
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
Database: &mockDatabase{holdDID: dbHoldDID},
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
|
||||
// Context without HTTP method (shouldn't happen in practice, but test defensive behavior)
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
blobStore := repo.Blobs(context.Background())
|
||||
|
||||
assert.NotNil(t, blobStore)
|
||||
// Without method, should default to discovery (safer for push scenarios)
|
||||
assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "missing method should use discovery hold DID")
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Blobs_WithoutDatabase tests blob store with discovery-based hold
|
||||
func TestRoutingRepository_Blobs_WithoutDatabase(t *testing.T) {
|
||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
||||
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:nocache456",
|
||||
Repository: "uncached-app",
|
||||
HoldDID: discoveryHoldDID,
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:nocache456", ""),
|
||||
Database: nil, // No database
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
blobStore := repo.Blobs(context.Background())
|
||||
|
||||
assert.NotNil(t, blobStore)
|
||||
// Verify the hold DID remains the discovery-based one
|
||||
assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "should use discovery-based hold DID")
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Blobs_DatabaseEmptyFallback tests fallback when database returns empty hold DID
|
||||
func TestRoutingRepository_Blobs_DatabaseEmptyFallback(t *testing.T) {
|
||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
||||
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "newapp",
|
||||
HoldDID: discoveryHoldDID,
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
Database: &mockDatabase{holdDID: ""}, // Empty string (no manifests yet)
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
blobStore := repo.Blobs(context.Background())
|
||||
|
||||
assert.NotNil(t, blobStore)
|
||||
// Verify the hold DID falls back to discovery-based
|
||||
assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "should fall back to discovery-based hold DID when database returns empty")
|
||||
}
|
||||
|
||||
// TestRoutingRepository_BlobStoreCaching tests that blob store is cached
|
||||
func TestRoutingRepository_BlobStoreCaching(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp",
|
||||
HoldDID: "did:web:hold01.atcr.io",
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
|
||||
// First call creates the store
|
||||
store1 := repo.Blobs(context.Background())
|
||||
assert.NotNil(t, store1)
|
||||
|
||||
// Second call returns cached store
|
||||
store2 := repo.Blobs(context.Background())
|
||||
assert.Same(t, store1, store2, "should return cached blob store instance")
|
||||
|
||||
// Verify internal cache
|
||||
assert.NotNil(t, repo.blobStore)
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Blobs_PanicOnEmptyHoldDID tests panic when hold DID is empty
|
||||
func TestRoutingRepository_Blobs_PanicOnEmptyHoldDID(t *testing.T) {
|
||||
// Use a unique DID/repo to ensure no cache entry exists
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:emptyholdtest999",
|
||||
Repository: "empty-hold-app",
|
||||
HoldDID: "", // Empty hold DID should panic
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:emptyholdtest999", ""),
|
||||
}
|
||||
// Create context without default hold and empty target hold
|
||||
userCtx := auth.NewUserContext("did:plc:emptyholdtest999", "oauth", "GET", nil)
|
||||
userCtx.SetTarget("did:plc:emptyholdtest999", "test.handle", "https://pds.example.com", "empty-hold-app", "")
|
||||
userCtx.SetPDSForTest("test.handle", "https://pds.example.com")
|
||||
userCtx.SetAuthorizerForTest(auth.NewMockHoldAuthorizer())
|
||||
// Intentionally NOT setting default hold DID
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
|
||||
// Should panic with empty hold DID
|
||||
assert.Panics(t, func() {
|
||||
@@ -280,106 +122,140 @@ func TestRoutingRepository_Blobs_PanicOnEmptyHoldDID(t *testing.T) {
|
||||
|
||||
// TestRoutingRepository_Tags tests the Tags() method
|
||||
func TestRoutingRepository_Tags(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp",
|
||||
HoldDID: "did:web:hold01.atcr.io",
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
}
|
||||
userCtx := mockUserContext(
|
||||
"did:plc:test123",
|
||||
"oauth",
|
||||
"GET",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"did:web:hold01.atcr.io",
|
||||
)
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
tagService := repo.Tags(context.Background())
|
||||
|
||||
assert.NotNil(t, tagService)
|
||||
|
||||
// Call again and verify we get a new instance (Tags() doesn't cache)
|
||||
// Call again and verify we get a fresh instance (no caching)
|
||||
tagService2 := repo.Tags(context.Background())
|
||||
assert.NotNil(t, tagService2)
|
||||
// Tags service is not cached, so each call creates a new instance
|
||||
}
|
||||
|
||||
// TestRoutingRepository_ConcurrentAccess tests concurrent access to cached stores
|
||||
func TestRoutingRepository_ConcurrentAccess(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp",
|
||||
HoldDID: "did:web:hold01.atcr.io",
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
// TestRoutingRepository_UserContext tests that UserContext fields are properly set
|
||||
func TestRoutingRepository_UserContext(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
httpMethod string
|
||||
expectedAction auth.RequestAction
|
||||
}{
|
||||
{"GET request is pull", "GET", auth.ActionPull},
|
||||
{"HEAD request is pull", "HEAD", auth.ActionPull},
|
||||
{"PUT request is push", "PUT", auth.ActionPush},
|
||||
{"POST request is push", "POST", auth.ActionPush},
|
||||
{"DELETE request is push", "DELETE", auth.ActionPush},
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
userCtx := mockUserContext(
|
||||
"did:plc:test123",
|
||||
"oauth",
|
||||
tc.httpMethod,
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"did:web:hold01.atcr.io",
|
||||
)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
|
||||
// Track all manifest stores returned
|
||||
manifestStores := make([]distribution.ManifestService, numGoroutines)
|
||||
blobStores := make([]distribution.BlobStore, numGoroutines)
|
||||
|
||||
// Concurrent access to Manifests()
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
store, err := repo.Manifests(context.Background())
|
||||
require.NoError(t, err)
|
||||
manifestStores[index] = store
|
||||
}(i)
|
||||
assert.Equal(t, tc.expectedAction, repo.userCtx.Action, "action should match HTTP method")
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all stores are non-nil (due to race conditions, they may not all be the same instance)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
assert.NotNil(t, manifestStores[i], "manifest store should not be nil")
|
||||
}
|
||||
|
||||
// After concurrent creation, subsequent calls should return the cached instance
|
||||
cachedStore, err := repo.Manifests(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, cachedStore)
|
||||
|
||||
// Concurrent access to Blobs()
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
blobStores[index] = repo.Blobs(context.Background())
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all stores are non-nil (due to race conditions, they may not all be the same instance)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
assert.NotNil(t, blobStores[i], "blob store should not be nil")
|
||||
}
|
||||
|
||||
// After concurrent creation, subsequent calls should return the cached instance
|
||||
cachedBlobStore := repo.Blobs(context.Background())
|
||||
assert.NotNil(t, cachedBlobStore)
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Blobs_PullPriority tests that database hold DID takes priority for pull (GET)
|
||||
func TestRoutingRepository_Blobs_PullPriority(t *testing.T) {
|
||||
dbHoldDID := "did:web:database.hold.io"
|
||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
||||
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp-priority",
|
||||
HoldDID: discoveryHoldDID, // Discovery-based hold
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
Database: &mockDatabase{holdDID: dbHoldDID}, // Database has a different hold DID
|
||||
// TestRoutingRepository_DifferentHoldDIDs tests routing with different hold DIDs
|
||||
func TestRoutingRepository_DifferentHoldDIDs(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
holdDID string
|
||||
}{
|
||||
{"did:web hold", "did:web:hold01.atcr.io"},
|
||||
{"did:web with port", "did:web:localhost:8080"},
|
||||
{"did:plc hold", "did:plc:xyz123"},
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
userCtx := mockUserContext(
|
||||
"did:plc:test123",
|
||||
"oauth",
|
||||
"PUT",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
tc.holdDID,
|
||||
)
|
||||
|
||||
// For pull (GET), database should take priority
|
||||
pullCtx := context.WithValue(context.Background(), "http.request.method", "GET")
|
||||
blobStore := repo.Blobs(pullCtx)
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
blobStore := repo.Blobs(context.Background())
|
||||
|
||||
assert.NotNil(t, blobStore)
|
||||
// Database hold DID should take priority over discovery for pull operations
|
||||
assert.Equal(t, dbHoldDID, repo.Ctx.HoldDID, "database hold DID should take priority over discovery for pull (GET)")
|
||||
assert.NotNil(t, blobStore, "should create blob store for %s", tc.holdDID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Named tests the Named() method
|
||||
func TestRoutingRepository_Named(t *testing.T) {
|
||||
userCtx := mockUserContext(
|
||||
"did:plc:test123",
|
||||
"oauth",
|
||||
"GET",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"did:web:hold01.atcr.io",
|
||||
)
|
||||
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
|
||||
// Named() returns a reference.Named from the base repository
|
||||
// Since baseRepo is nil, this tests our implementation handles that case
|
||||
named := repo.Named()
|
||||
|
||||
// With nil base, Named() should return a name constructed from context
|
||||
assert.NotNil(t, named)
|
||||
assert.Contains(t, named.Name(), "myapp")
|
||||
}
|
||||
|
||||
// TestATProtoResolveHoldURL tests DID to URL resolution
|
||||
func TestATProtoResolveHoldURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
holdDID string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "did:web simple domain",
|
||||
holdDID: "did:web:hold01.atcr.io",
|
||||
expected: "https://hold01.atcr.io",
|
||||
},
|
||||
{
|
||||
name: "did:web with port (localhost)",
|
||||
holdDID: "did:web:localhost:8080",
|
||||
expected: "http://localhost:8080",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := atproto.ResolveHoldURL(tt.holdDID)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,11 +5,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -18,6 +14,8 @@ import (
|
||||
type serviceTokenEntry struct {
|
||||
token string
|
||||
expiresAt time.Time
|
||||
err error
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// Global cache for service tokens (DID:HoldDID -> token)
|
||||
@@ -61,7 +59,7 @@ func SetServiceToken(did, holdDID, token string) error {
|
||||
cacheKey := did + ":" + holdDID
|
||||
|
||||
// Parse JWT to extract expiry (don't verify signature - we trust the PDS)
|
||||
expiry, err := parseJWTExpiry(token)
|
||||
expiry, err := ParseJWTExpiry(token)
|
||||
if err != nil {
|
||||
// If parsing fails, use default 50s TTL (conservative fallback)
|
||||
slog.Warn("Failed to parse JWT expiry, using default 50s", "error", err, "cacheKey", cacheKey)
|
||||
@@ -85,37 +83,6 @@ func SetServiceToken(did, holdDID, token string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseJWTExpiry extracts the expiry time from a JWT without verifying the signature
|
||||
// We trust tokens from the user's PDS, so signature verification isn't needed here
|
||||
// Manually decodes the JWT payload to avoid algorithm compatibility issues
|
||||
func parseJWTExpiry(tokenString string) (time.Time, error) {
|
||||
// JWT format: header.payload.signature
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return time.Time{}, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Decode the payload (second part)
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
}
|
||||
|
||||
// Parse the JSON payload
|
||||
var claims struct {
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return time.Time{}, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
if claims.Exp == 0 {
|
||||
return time.Time{}, fmt.Errorf("JWT missing exp claim")
|
||||
}
|
||||
|
||||
return time.Unix(claims.Exp, 0), nil
|
||||
}
|
||||
|
||||
// InvalidateServiceToken removes a service token from the cache
|
||||
// Used when we detect that a token is invalid or the user's session has expired
|
||||
func InvalidateServiceToken(did, holdDID string) {
|
||||
|
||||
80
pkg/auth/mock_authorizer.go
Normal file
80
pkg/auth/mock_authorizer.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
)
|
||||
|
||||
// MockHoldAuthorizer is a test double for HoldAuthorizer.
|
||||
// It allows tests to control the return values of authorization checks
|
||||
// without making network calls or querying a real PDS.
|
||||
type MockHoldAuthorizer struct {
|
||||
// Direct result control
|
||||
CanReadResult bool
|
||||
CanWriteResult bool
|
||||
CanAdminResult bool
|
||||
Error error
|
||||
|
||||
// Captain record to return (optional, for GetCaptainRecord)
|
||||
CaptainRecord *atproto.CaptainRecord
|
||||
|
||||
// Crew membership (optional, for IsCrewMember)
|
||||
IsCrewResult bool
|
||||
}
|
||||
|
||||
// NewMockHoldAuthorizer creates a MockHoldAuthorizer with sensible defaults.
|
||||
// By default, it allows all access (public hold, user is owner).
|
||||
func NewMockHoldAuthorizer() *MockHoldAuthorizer {
|
||||
return &MockHoldAuthorizer{
|
||||
CanReadResult: true,
|
||||
CanWriteResult: true,
|
||||
CanAdminResult: false,
|
||||
IsCrewResult: false,
|
||||
CaptainRecord: &atproto.CaptainRecord{
|
||||
Type: "io.atcr.hold.captain",
|
||||
Owner: "did:plc:mock-owner",
|
||||
Public: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CheckReadAccess returns the configured CanReadResult.
|
||||
func (m *MockHoldAuthorizer) CheckReadAccess(ctx context.Context, holdDID, userDID string) (bool, error) {
|
||||
if m.Error != nil {
|
||||
return false, m.Error
|
||||
}
|
||||
return m.CanReadResult, nil
|
||||
}
|
||||
|
||||
// CheckWriteAccess returns the configured CanWriteResult.
|
||||
func (m *MockHoldAuthorizer) CheckWriteAccess(ctx context.Context, holdDID, userDID string) (bool, error) {
|
||||
if m.Error != nil {
|
||||
return false, m.Error
|
||||
}
|
||||
return m.CanWriteResult, nil
|
||||
}
|
||||
|
||||
// GetCaptainRecord returns the configured CaptainRecord or a default.
|
||||
func (m *MockHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error) {
|
||||
if m.Error != nil {
|
||||
return nil, m.Error
|
||||
}
|
||||
if m.CaptainRecord != nil {
|
||||
return m.CaptainRecord, nil
|
||||
}
|
||||
// Return a default captain record
|
||||
return &atproto.CaptainRecord{
|
||||
Type: "io.atcr.hold.captain",
|
||||
Owner: "did:plc:mock-owner",
|
||||
Public: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IsCrewMember returns the configured IsCrewResult.
|
||||
func (m *MockHoldAuthorizer) IsCrewMember(ctx context.Context, holdDID, userDID string) (bool, error) {
|
||||
if m.Error != nil {
|
||||
return false, m.Error
|
||||
}
|
||||
return m.IsCrewResult, nil
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
@@ -44,299 +46,60 @@ func getErrorHint(apiErr *atclient.APIError) string {
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrFetchServiceToken gets a service token for hold authentication.
|
||||
// Checks cache first, then fetches from PDS with OAuth/DPoP if needed.
|
||||
// This is the canonical implementation used by both middleware and crew registration.
|
||||
//
|
||||
// IMPORTANT: Uses DoWithSession() to hold a per-DID lock through the entire PDS interaction.
|
||||
// This prevents DPoP nonce race conditions when multiple Docker layers upload concurrently.
|
||||
func GetOrFetchServiceToken(
|
||||
ctx context.Context,
|
||||
refresher *oauth.Refresher,
|
||||
did, holdDID, pdsEndpoint string,
|
||||
) (string, error) {
|
||||
if refresher == nil {
|
||||
return "", fmt.Errorf("refresher is nil (OAuth session required for service tokens)")
|
||||
// ParseJWTExpiry extracts the expiry time from a JWT without verifying the signature
|
||||
// We trust tokens from the user's PDS, so signature verification isn't needed here
|
||||
// Manually decodes the JWT payload to avoid algorithm compatibility issues
|
||||
func ParseJWTExpiry(tokenString string) (time.Time, error) {
|
||||
// JWT format: header.payload.signature
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return time.Time{}, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Check cache first to avoid unnecessary PDS calls on every request
|
||||
cachedToken, expiresAt := GetServiceToken(did, holdDID)
|
||||
|
||||
// Use cached token if it exists and has > 10s remaining
|
||||
if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
|
||||
slog.Debug("Using cached service token",
|
||||
"did", did,
|
||||
"expiresIn", time.Until(expiresAt).Round(time.Second))
|
||||
return cachedToken, nil
|
||||
}
|
||||
|
||||
// Cache miss or expiring soon - validate OAuth and get new service token
|
||||
if cachedToken == "" {
|
||||
slog.Debug("Service token cache miss, fetching new token", "did", did)
|
||||
} else {
|
||||
slog.Debug("Service token expiring soon, proactively renewing", "did", did)
|
||||
}
|
||||
|
||||
// Use DoWithSession to hold the lock through the entire PDS interaction.
|
||||
// This prevents DPoP nonce races when multiple goroutines try to fetch service tokens.
|
||||
var serviceToken string
|
||||
var fetchErr error
|
||||
|
||||
err := refresher.DoWithSession(ctx, did, func(session *indigo_oauth.ClientSession) error {
|
||||
// Double-check cache after acquiring lock - another goroutine may have
|
||||
// populated it while we were waiting (classic double-checked locking pattern)
|
||||
cachedToken, expiresAt := GetServiceToken(did, holdDID)
|
||||
if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
|
||||
slog.Debug("Service token cache hit after lock acquisition",
|
||||
"did", did,
|
||||
"expiresIn", time.Until(expiresAt).Round(time.Second))
|
||||
serviceToken = cachedToken
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cache still empty/expired - proceed with PDS call
|
||||
// Request 5-minute expiry (PDS may grant less)
|
||||
// exp must be absolute Unix timestamp, not relative duration
|
||||
// Note: OAuth scope includes #atcr_hold fragment, but service auth aud must be bare DID
|
||||
expiryTime := time.Now().Unix() + 300 // 5 minutes from now
|
||||
serviceAuthURL := fmt.Sprintf("%s%s?aud=%s&lxm=%s&exp=%d",
|
||||
pdsEndpoint,
|
||||
atproto.ServerGetServiceAuth,
|
||||
url.QueryEscape(holdDID),
|
||||
url.QueryEscape("com.atproto.repo.getRecord"),
|
||||
expiryTime,
|
||||
)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", serviceAuthURL, nil)
|
||||
if err != nil {
|
||||
fetchErr = fmt.Errorf("failed to create service auth request: %w", err)
|
||||
return fetchErr
|
||||
}
|
||||
|
||||
// Use OAuth session to authenticate to PDS (with DPoP)
|
||||
// The lock is held, so DPoP nonce negotiation is serialized per-DID
|
||||
resp, err := session.DoWithAuth(session.Client, req, "com.atproto.server.getServiceAuth")
|
||||
if err != nil {
|
||||
// Auth error - may indicate expired tokens or corrupted session
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
|
||||
// Inspect the error to extract detailed information from indigo's APIError
|
||||
var apiErr *atclient.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
// Log detailed API error information
|
||||
slog.Error("OAuth authentication failed during service token request",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"url", serviceAuthURL,
|
||||
"error", err,
|
||||
"httpStatus", apiErr.StatusCode,
|
||||
"errorName", apiErr.Name,
|
||||
"errorMessage", apiErr.Message,
|
||||
"hint", getErrorHint(apiErr))
|
||||
} else {
|
||||
// Fallback for non-API errors (network errors, etc.)
|
||||
slog.Error("OAuth authentication failed during service token request",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"url", serviceAuthURL,
|
||||
"error", err,
|
||||
"errorType", fmt.Sprintf("%T", err),
|
||||
"hint", "Network error or unexpected failure during OAuth request")
|
||||
}
|
||||
|
||||
fetchErr = fmt.Errorf("OAuth validation failed: %w", err)
|
||||
return fetchErr
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// Service auth failed
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
slog.Error("Service token request returned non-200 status",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"statusCode", resp.StatusCode,
|
||||
"responseBody", string(bodyBytes),
|
||||
"hint", "PDS rejected the service token request - check PDS logs for details")
|
||||
fetchErr = fmt.Errorf("service auth failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
return fetchErr
|
||||
}
|
||||
|
||||
// Parse response to get service token
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
fetchErr = fmt.Errorf("failed to decode service auth response: %w", err)
|
||||
return fetchErr
|
||||
}
|
||||
|
||||
if result.Token == "" {
|
||||
fetchErr = fmt.Errorf("empty token in service auth response")
|
||||
return fetchErr
|
||||
}
|
||||
|
||||
serviceToken = result.Token
|
||||
return nil
|
||||
})
|
||||
|
||||
// Decode the payload (second part)
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
// DoWithSession failed (session load or callback error)
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
|
||||
// Try to extract detailed error information
|
||||
var apiErr *atclient.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
slog.Error("Failed to get OAuth session for service token",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"error", err,
|
||||
"httpStatus", apiErr.StatusCode,
|
||||
"errorName", apiErr.Name,
|
||||
"errorMessage", apiErr.Message,
|
||||
"hint", getErrorHint(apiErr))
|
||||
} else if fetchErr == nil {
|
||||
// Session load failed (not a fetch error)
|
||||
slog.Error("Failed to get OAuth session for service token",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"error", err,
|
||||
"errorType", fmt.Sprintf("%T", err),
|
||||
"hint", "OAuth session not found in database or token refresh failed")
|
||||
}
|
||||
|
||||
// Delete the stale OAuth session to force re-authentication
|
||||
// This also invalidates the UI session automatically
|
||||
if delErr := refresher.DeleteSession(ctx, did); delErr != nil {
|
||||
slog.Warn("Failed to delete stale OAuth session",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"error", delErr)
|
||||
}
|
||||
|
||||
if fetchErr != nil {
|
||||
return "", fetchErr
|
||||
}
|
||||
return "", fmt.Errorf("failed to get OAuth session: %w", err)
|
||||
return time.Time{}, fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
}
|
||||
|
||||
// Cache the token (parses JWT to extract actual expiry)
|
||||
if err := SetServiceToken(did, holdDID, serviceToken); err != nil {
|
||||
slog.Warn("Failed to cache service token", "error", err, "did", did, "holdDID", holdDID)
|
||||
// Non-fatal - we have the token, just won't be cached
|
||||
// Parse the JSON payload
|
||||
var claims struct {
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return time.Time{}, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
slog.Debug("OAuth validation succeeded, service token obtained", "did", did)
|
||||
return serviceToken, nil
|
||||
if claims.Exp == 0 {
|
||||
return time.Time{}, fmt.Errorf("JWT missing exp claim")
|
||||
}
|
||||
|
||||
return time.Unix(claims.Exp, 0), nil
|
||||
}
|
||||
|
||||
// GetOrFetchServiceTokenWithAppPassword gets a service token using app-password Bearer authentication.
|
||||
// Used when auth method is app_password instead of OAuth.
|
||||
func GetOrFetchServiceTokenWithAppPassword(
|
||||
ctx context.Context,
|
||||
did, holdDID, pdsEndpoint string,
|
||||
) (string, error) {
|
||||
// Check cache first to avoid unnecessary PDS calls on every request
|
||||
cachedToken, expiresAt := GetServiceToken(did, holdDID)
|
||||
|
||||
// Use cached token if it exists and has > 10s remaining
|
||||
if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
|
||||
slog.Debug("Using cached service token (app-password)",
|
||||
"did", did,
|
||||
"expiresIn", time.Until(expiresAt).Round(time.Second))
|
||||
return cachedToken, nil
|
||||
}
|
||||
|
||||
// Cache miss or expiring soon - get app-password token and fetch new service token
|
||||
if cachedToken == "" {
|
||||
slog.Debug("Service token cache miss, fetching new token with app-password", "did", did)
|
||||
} else {
|
||||
slog.Debug("Service token expiring soon, proactively renewing with app-password", "did", did)
|
||||
}
|
||||
|
||||
// Get app-password access token from cache
|
||||
accessToken, ok := GetGlobalTokenCache().Get(did)
|
||||
if !ok {
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
slog.Error("No app-password access token found in cache",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"hint", "User must re-authenticate with docker login")
|
||||
return "", fmt.Errorf("no app-password access token available for DID %s", did)
|
||||
}
|
||||
|
||||
// Call com.atproto.server.getServiceAuth on the user's PDS with Bearer token
|
||||
// buildServiceAuthURL constructs the URL for com.atproto.server.getServiceAuth
|
||||
func buildServiceAuthURL(pdsEndpoint, holdDID string) string {
|
||||
// Request 5-minute expiry (PDS may grant less)
|
||||
// exp must be absolute Unix timestamp, not relative duration
|
||||
expiryTime := time.Now().Unix() + 300 // 5 minutes from now
|
||||
serviceAuthURL := fmt.Sprintf("%s%s?aud=%s&lxm=%s&exp=%d",
|
||||
return fmt.Sprintf("%s%s?aud=%s&lxm=%s&exp=%d",
|
||||
pdsEndpoint,
|
||||
atproto.ServerGetServiceAuth,
|
||||
url.QueryEscape(holdDID),
|
||||
url.QueryEscape("com.atproto.repo.getRecord"),
|
||||
expiryTime,
|
||||
)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", serviceAuthURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create service auth request: %w", err)
|
||||
}
|
||||
|
||||
// Set Bearer token authentication (app-password)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
// Make request with standard HTTP client
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
slog.Error("App-password service token request failed",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"error", err)
|
||||
return "", fmt.Errorf("failed to request service token: %w", err)
|
||||
}
|
||||
// parseServiceTokenResponse extracts the token from a service auth response
|
||||
func parseServiceTokenResponse(resp *http.Response) (string, error) {
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
// App-password token is invalid or expired - clear from cache
|
||||
GetGlobalTokenCache().Delete(did)
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
slog.Error("App-password token rejected by PDS",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"hint", "User must re-authenticate with docker login")
|
||||
return "", fmt.Errorf("app-password authentication failed: token expired or invalid")
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// Service auth failed
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
slog.Error("Service token request returned non-200 status (app-password)",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"statusCode", resp.StatusCode,
|
||||
"responseBody", string(bodyBytes))
|
||||
return "", fmt.Errorf("service auth failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Parse response to get service token
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
@@ -348,14 +111,190 @@ func GetOrFetchServiceTokenWithAppPassword(
|
||||
return "", fmt.Errorf("empty token in service auth response")
|
||||
}
|
||||
|
||||
serviceToken := result.Token
|
||||
return result.Token, nil
|
||||
}
|
||||
|
||||
// Cache the token (parses JWT to extract actual expiry)
|
||||
if err := SetServiceToken(did, holdDID, serviceToken); err != nil {
|
||||
slog.Warn("Failed to cache service token", "error", err, "did", did, "holdDID", holdDID)
|
||||
// Non-fatal - we have the token, just won't be cached
|
||||
// GetOrFetchServiceToken gets a service token for hold authentication.
|
||||
// Handles both OAuth/DPoP and app-password authentication based on authMethod.
|
||||
// Checks cache first, then fetches from PDS if needed.
|
||||
//
|
||||
// For OAuth: Uses DoWithSession() to hold a per-DID lock through the entire PDS interaction.
|
||||
// This prevents DPoP nonce race conditions when multiple Docker layers upload concurrently.
|
||||
//
|
||||
// For app-password: Uses Bearer token authentication without locking (no DPoP complexity).
|
||||
func GetOrFetchServiceToken(
|
||||
ctx context.Context,
|
||||
authMethod string,
|
||||
refresher *oauth.Refresher, // Required for OAuth, nil for app-password
|
||||
did, holdDID, pdsEndpoint string,
|
||||
) (string, error) {
|
||||
// Check cache first to avoid unnecessary PDS calls on every request
|
||||
cachedToken, expiresAt := GetServiceToken(did, holdDID)
|
||||
|
||||
// Use cached token if it exists and has > 10s remaining
|
||||
if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
|
||||
slog.Debug("Using cached service token",
|
||||
"did", did,
|
||||
"authMethod", authMethod,
|
||||
"expiresIn", time.Until(expiresAt).Round(time.Second))
|
||||
return cachedToken, nil
|
||||
}
|
||||
|
||||
slog.Debug("App-password validation succeeded, service token obtained", "did", did)
|
||||
// Cache miss or expiring soon - fetch new service token
|
||||
if cachedToken == "" {
|
||||
slog.Debug("Service token cache miss, fetching new token", "did", did, "authMethod", authMethod)
|
||||
} else {
|
||||
slog.Debug("Service token expiring soon, proactively renewing", "did", did, "authMethod", authMethod)
|
||||
}
|
||||
|
||||
var serviceToken string
|
||||
var err error
|
||||
|
||||
// Branch based on auth method
|
||||
if authMethod == AuthMethodOAuth {
|
||||
serviceToken, err = doOAuthFetch(ctx, refresher, did, holdDID, pdsEndpoint)
|
||||
// OAuth-specific cleanup: delete stale session on error
|
||||
if err != nil && refresher != nil {
|
||||
if delErr := refresher.DeleteSession(ctx, did); delErr != nil {
|
||||
slog.Warn("Failed to delete stale OAuth session",
|
||||
"component", "auth/servicetoken",
|
||||
"did", did,
|
||||
"error", delErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
serviceToken, err = doAppPasswordFetch(ctx, did, holdDID, pdsEndpoint)
|
||||
}
|
||||
|
||||
// Unified error handling
|
||||
if err != nil {
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
|
||||
var apiErr *atclient.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
slog.Error("Service token request failed",
|
||||
"component", "auth/servicetoken",
|
||||
"authMethod", authMethod,
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"error", err,
|
||||
"httpStatus", apiErr.StatusCode,
|
||||
"errorName", apiErr.Name,
|
||||
"errorMessage", apiErr.Message,
|
||||
"hint", getErrorHint(apiErr))
|
||||
} else {
|
||||
slog.Error("Service token request failed",
|
||||
"component", "auth/servicetoken",
|
||||
"authMethod", authMethod,
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"error", err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Cache the token (parses JWT to extract actual expiry)
|
||||
if cacheErr := SetServiceToken(did, holdDID, serviceToken); cacheErr != nil {
|
||||
slog.Warn("Failed to cache service token", "error", cacheErr, "did", did, "holdDID", holdDID)
|
||||
}
|
||||
|
||||
slog.Debug("Service token obtained", "did", did, "authMethod", authMethod)
|
||||
return serviceToken, nil
|
||||
}
|
||||
|
||||
// doOAuthFetch fetches a service token using OAuth/DPoP authentication.
|
||||
// Uses DoWithSession() for per-DID locking to prevent DPoP nonce races.
|
||||
// Returns (token, error) without logging - caller handles error logging.
|
||||
func doOAuthFetch(
|
||||
ctx context.Context,
|
||||
refresher *oauth.Refresher,
|
||||
did, holdDID, pdsEndpoint string,
|
||||
) (string, error) {
|
||||
if refresher == nil {
|
||||
return "", fmt.Errorf("refresher is nil (OAuth session required)")
|
||||
}
|
||||
|
||||
var serviceToken string
|
||||
var fetchErr error
|
||||
|
||||
err := refresher.DoWithSession(ctx, did, func(session *indigo_oauth.ClientSession) error {
|
||||
// Double-check cache after acquiring lock (double-checked locking pattern)
|
||||
cachedToken, expiresAt := GetServiceToken(did, holdDID)
|
||||
if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
|
||||
slog.Debug("Service token cache hit after lock acquisition",
|
||||
"did", did,
|
||||
"expiresIn", time.Until(expiresAt).Round(time.Second))
|
||||
serviceToken = cachedToken
|
||||
return nil
|
||||
}
|
||||
|
||||
serviceAuthURL := buildServiceAuthURL(pdsEndpoint, holdDID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", serviceAuthURL, nil)
|
||||
if err != nil {
|
||||
fetchErr = fmt.Errorf("failed to create request: %w", err)
|
||||
return fetchErr
|
||||
}
|
||||
|
||||
resp, err := session.DoWithAuth(session.Client, req, "com.atproto.server.getServiceAuth")
|
||||
if err != nil {
|
||||
fetchErr = fmt.Errorf("OAuth request failed: %w", err)
|
||||
return fetchErr
|
||||
}
|
||||
|
||||
token, parseErr := parseServiceTokenResponse(resp)
|
||||
if parseErr != nil {
|
||||
fetchErr = parseErr
|
||||
return fetchErr
|
||||
}
|
||||
|
||||
serviceToken = token
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if fetchErr != nil {
|
||||
return "", fetchErr
|
||||
}
|
||||
return "", fmt.Errorf("failed to get OAuth session: %w", err)
|
||||
}
|
||||
|
||||
return serviceToken, nil
|
||||
}
|
||||
|
||||
// doAppPasswordFetch fetches a service token using Bearer token authentication.
|
||||
// Returns (token, error) without logging - caller handles error logging.
|
||||
func doAppPasswordFetch(
|
||||
ctx context.Context,
|
||||
did, holdDID, pdsEndpoint string,
|
||||
) (string, error) {
|
||||
accessToken, ok := GetGlobalTokenCache().Get(did)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("no app-password access token available for DID %s", did)
|
||||
}
|
||||
|
||||
serviceAuthURL := buildServiceAuthURL(pdsEndpoint, holdDID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", serviceAuthURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
resp.Body.Close()
|
||||
// Clear stale app-password token
|
||||
GetGlobalTokenCache().Delete(did)
|
||||
return "", fmt.Errorf("app-password authentication failed: token expired or invalid")
|
||||
}
|
||||
|
||||
return parseServiceTokenResponse(resp)
|
||||
}
|
||||
|
||||
@@ -11,15 +11,15 @@ func TestGetOrFetchServiceToken_NilRefresher(t *testing.T) {
|
||||
holdDID := "did:web:hold.example.com"
|
||||
pdsEndpoint := "https://pds.example.com"
|
||||
|
||||
// Test with nil refresher - should return error
|
||||
_, err := GetOrFetchServiceToken(ctx, nil, did, holdDID, pdsEndpoint)
|
||||
// Test with nil refresher and OAuth auth method - should return error
|
||||
_, err := GetOrFetchServiceToken(ctx, AuthMethodOAuth, nil, did, holdDID, pdsEndpoint)
|
||||
if err == nil {
|
||||
t.Error("Expected error when refresher is nil")
|
||||
t.Error("Expected error when refresher is nil for OAuth")
|
||||
}
|
||||
|
||||
expectedErrMsg := "refresher is nil"
|
||||
if err.Error() != "refresher is nil (OAuth session required for service tokens)" {
|
||||
t.Errorf("Expected error message to contain %q, got %q", expectedErrMsg, err.Error())
|
||||
expectedErrMsg := "refresher is nil (OAuth session required)"
|
||||
if err.Error() != expectedErrMsg {
|
||||
t.Errorf("Expected error message %q, got %q", expectedErrMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
784
pkg/auth/usercontext.go
Normal file
784
pkg/auth/usercontext.go
Normal file
@@ -0,0 +1,784 @@
|
||||
// Package auth provides UserContext for managing authenticated user state
|
||||
// throughout request handling in the AppView.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/appview/db"
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
)
|
||||
|
||||
// Auth method constants (duplicated from token package to avoid import cycle)
|
||||
const (
|
||||
AuthMethodOAuth = "oauth"
|
||||
AuthMethodAppPassword = "app_password"
|
||||
)
|
||||
|
||||
// RequestAction represents the type of registry operation
|
||||
type RequestAction int
|
||||
|
||||
const (
|
||||
ActionUnknown RequestAction = iota
|
||||
ActionPull // GET/HEAD - reading from registry
|
||||
ActionPush // PUT/POST/DELETE - writing to registry
|
||||
ActionInspect // Metadata operations only
|
||||
)
|
||||
|
||||
func (a RequestAction) String() string {
|
||||
switch a {
|
||||
case ActionPull:
|
||||
return "pull"
|
||||
case ActionPush:
|
||||
return "push"
|
||||
case ActionInspect:
|
||||
return "inspect"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// HoldPermissions describes what the user can do on a specific hold
|
||||
type HoldPermissions struct {
|
||||
HoldDID string // Hold being checked
|
||||
IsOwner bool // User is captain of this hold
|
||||
IsCrew bool // User is a crew member
|
||||
IsPublic bool // Hold allows public reads
|
||||
CanRead bool // Computed: can user read blobs?
|
||||
CanWrite bool // Computed: can user write blobs?
|
||||
CanAdmin bool // Computed: can user manage crew?
|
||||
Permissions []string // Raw permissions from crew record
|
||||
}
|
||||
|
||||
// contextKey is unexported to prevent collisions
|
||||
type contextKey struct{}
|
||||
|
||||
// userContextKey is the context key for UserContext
|
||||
var userContextKey = contextKey{}
|
||||
|
||||
// userSetupCache tracks which users have had their profile/crew setup ensured
|
||||
var userSetupCache sync.Map // did -> time.Time
|
||||
|
||||
// userSetupTTL is how long to cache user setup status (1 hour)
|
||||
const userSetupTTL = 1 * time.Hour
|
||||
|
||||
// Dependencies bundles services needed by UserContext
|
||||
type Dependencies struct {
|
||||
Refresher *oauth.Refresher
|
||||
Authorizer HoldAuthorizer
|
||||
DefaultHoldDID string // AppView's default hold DID
|
||||
}
|
||||
|
||||
// UserContext encapsulates authenticated user state for a request.
|
||||
// Built early in the middleware chain and available throughout request processing.
|
||||
//
|
||||
// Two-phase initialization:
|
||||
// 1. Middleware phase: Identity is set (DID, authMethod, action)
|
||||
// 2. Repository() phase: Target is set via SetTarget() (owner, repo, holdDID)
|
||||
type UserContext struct {
|
||||
// === User Identity (set in middleware) ===
|
||||
DID string // User's DID (empty if unauthenticated)
|
||||
Handle string // User's handle (may be empty)
|
||||
PDSEndpoint string // User's PDS endpoint
|
||||
AuthMethod string // "oauth", "app_password", or ""
|
||||
IsAuthenticated bool
|
||||
|
||||
// === Request Info ===
|
||||
Action RequestAction
|
||||
HTTPMethod string
|
||||
|
||||
// === Target Info (set by SetTarget) ===
|
||||
TargetOwnerDID string // whose repo is being accessed
|
||||
TargetOwnerHandle string
|
||||
TargetOwnerPDS string
|
||||
TargetRepo string // image name (e.g., "quickslice")
|
||||
TargetHoldDID string // hold where blobs live/will live
|
||||
|
||||
// === Dependencies (injected) ===
|
||||
refresher *oauth.Refresher
|
||||
authorizer HoldAuthorizer
|
||||
defaultHoldDID string
|
||||
|
||||
// === Cached State (lazy-loaded) ===
|
||||
serviceTokens sync.Map // holdDID -> *serviceTokenEntry
|
||||
permissions sync.Map // holdDID -> *HoldPermissions
|
||||
pdsResolved bool
|
||||
pdsResolveErr error
|
||||
mu sync.Mutex // protects PDS resolution
|
||||
atprotoClient *atproto.Client
|
||||
atprotoClientOnce sync.Once
|
||||
}
|
||||
|
||||
// FromContext retrieves UserContext from context.
|
||||
// Returns nil if not present (unauthenticated or before middleware).
|
||||
func FromContext(ctx context.Context) *UserContext {
|
||||
uc, _ := ctx.Value(userContextKey).(*UserContext)
|
||||
return uc
|
||||
}
|
||||
|
||||
// WithUserContext adds UserContext to context
|
||||
func WithUserContext(ctx context.Context, uc *UserContext) context.Context {
|
||||
return context.WithValue(ctx, userContextKey, uc)
|
||||
}
|
||||
|
||||
// NewUserContext creates a UserContext from extracted JWT claims.
|
||||
// The deps parameter provides access to services needed for lazy operations.
|
||||
func NewUserContext(did, authMethod, httpMethod string, deps *Dependencies) *UserContext {
|
||||
action := ActionUnknown
|
||||
switch httpMethod {
|
||||
case "GET", "HEAD":
|
||||
action = ActionPull
|
||||
case "PUT", "POST", "PATCH", "DELETE":
|
||||
action = ActionPush
|
||||
}
|
||||
|
||||
var refresher *oauth.Refresher
|
||||
var authorizer HoldAuthorizer
|
||||
var defaultHoldDID string
|
||||
|
||||
if deps != nil {
|
||||
refresher = deps.Refresher
|
||||
authorizer = deps.Authorizer
|
||||
defaultHoldDID = deps.DefaultHoldDID
|
||||
}
|
||||
|
||||
return &UserContext{
|
||||
DID: did,
|
||||
AuthMethod: authMethod,
|
||||
IsAuthenticated: did != "",
|
||||
Action: action,
|
||||
HTTPMethod: httpMethod,
|
||||
refresher: refresher,
|
||||
authorizer: authorizer,
|
||||
defaultHoldDID: defaultHoldDID,
|
||||
}
|
||||
}
|
||||
|
||||
// SetPDS sets the user's PDS endpoint directly, bypassing network resolution.
|
||||
// Use when PDS is already known (e.g., from previous resolution or client).
|
||||
func (uc *UserContext) SetPDS(handle, pdsEndpoint string) {
|
||||
uc.mu.Lock()
|
||||
defer uc.mu.Unlock()
|
||||
uc.Handle = handle
|
||||
uc.PDSEndpoint = pdsEndpoint
|
||||
uc.pdsResolved = true
|
||||
uc.pdsResolveErr = nil
|
||||
}
|
||||
|
||||
// SetTarget sets the target repository information.
|
||||
// Called in Repository() after resolving the owner identity.
|
||||
func (uc *UserContext) SetTarget(ownerDID, ownerHandle, ownerPDS, repo, holdDID string) {
|
||||
uc.TargetOwnerDID = ownerDID
|
||||
uc.TargetOwnerHandle = ownerHandle
|
||||
uc.TargetOwnerPDS = ownerPDS
|
||||
uc.TargetRepo = repo
|
||||
uc.TargetHoldDID = holdDID
|
||||
}
|
||||
|
||||
// ResolvePDS resolves the user's PDS endpoint (lazy, cached).
|
||||
// Safe to call multiple times; resolution happens once.
|
||||
func (uc *UserContext) ResolvePDS(ctx context.Context) error {
|
||||
if !uc.IsAuthenticated {
|
||||
return nil // Nothing to resolve for anonymous users
|
||||
}
|
||||
|
||||
uc.mu.Lock()
|
||||
defer uc.mu.Unlock()
|
||||
|
||||
if uc.pdsResolved {
|
||||
return uc.pdsResolveErr
|
||||
}
|
||||
|
||||
_, handle, pds, err := atproto.ResolveIdentity(ctx, uc.DID)
|
||||
if err != nil {
|
||||
uc.pdsResolveErr = err
|
||||
uc.pdsResolved = true
|
||||
return err
|
||||
}
|
||||
|
||||
uc.Handle = handle
|
||||
uc.PDSEndpoint = pds
|
||||
uc.pdsResolved = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServiceToken returns a service token for the target hold.
|
||||
// Uses internal caching with sync.Once per holdDID.
|
||||
// Requires target to be set via SetTarget().
|
||||
func (uc *UserContext) GetServiceToken(ctx context.Context) (string, error) {
|
||||
if uc.TargetHoldDID == "" {
|
||||
return "", fmt.Errorf("target hold not set (call SetTarget first)")
|
||||
}
|
||||
return uc.GetServiceTokenForHold(ctx, uc.TargetHoldDID)
|
||||
}
|
||||
|
||||
// GetServiceTokenForHold returns a service token for an arbitrary hold.
|
||||
// Uses internal caching with sync.Once per holdDID.
|
||||
func (uc *UserContext) GetServiceTokenForHold(ctx context.Context, holdDID string) (string, error) {
|
||||
if !uc.IsAuthenticated {
|
||||
return "", fmt.Errorf("cannot get service token: user not authenticated")
|
||||
}
|
||||
|
||||
// Ensure PDS is resolved
|
||||
if err := uc.ResolvePDS(ctx); err != nil {
|
||||
return "", fmt.Errorf("failed to resolve PDS: %w", err)
|
||||
}
|
||||
|
||||
// Load or create cache entry
|
||||
entryVal, _ := uc.serviceTokens.LoadOrStore(holdDID, &serviceTokenEntry{})
|
||||
entry := entryVal.(*serviceTokenEntry)
|
||||
|
||||
entry.once.Do(func() {
|
||||
slog.Debug("Fetching service token",
|
||||
"component", "auth/context",
|
||||
"userDID", uc.DID,
|
||||
"holdDID", holdDID,
|
||||
"authMethod", uc.AuthMethod)
|
||||
|
||||
// Use unified service token function (handles both OAuth and app-password)
|
||||
serviceToken, err := GetOrFetchServiceToken(
|
||||
ctx, uc.AuthMethod, uc.refresher, uc.DID, holdDID, uc.PDSEndpoint,
|
||||
)
|
||||
|
||||
entry.token = serviceToken
|
||||
entry.err = err
|
||||
if err == nil {
|
||||
// Parse JWT to get expiry
|
||||
expiry, parseErr := ParseJWTExpiry(serviceToken)
|
||||
if parseErr == nil {
|
||||
entry.expiresAt = expiry.Add(-10 * time.Second) // Safety margin
|
||||
} else {
|
||||
entry.expiresAt = time.Now().Add(45 * time.Second) // Default fallback
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return entry.token, entry.err
|
||||
}
|
||||
|
||||
// CanRead checks if user can read blobs from target hold.
|
||||
// - Public hold: any user (even anonymous)
|
||||
// - Private hold: owner OR crew with blob:read/blob:write
|
||||
func (uc *UserContext) CanRead(ctx context.Context) (bool, error) {
|
||||
if uc.TargetHoldDID == "" {
|
||||
return false, fmt.Errorf("target hold not set (call SetTarget first)")
|
||||
}
|
||||
|
||||
if uc.authorizer == nil {
|
||||
return false, fmt.Errorf("authorizer not configured")
|
||||
}
|
||||
|
||||
return uc.authorizer.CheckReadAccess(ctx, uc.TargetHoldDID, uc.DID)
|
||||
}
|
||||
|
||||
// CanWrite checks if user can write blobs to target hold.
|
||||
// - Must be authenticated
|
||||
// - Must be owner OR crew with blob:write
|
||||
func (uc *UserContext) CanWrite(ctx context.Context) (bool, error) {
|
||||
if uc.TargetHoldDID == "" {
|
||||
return false, fmt.Errorf("target hold not set (call SetTarget first)")
|
||||
}
|
||||
|
||||
if !uc.IsAuthenticated {
|
||||
return false, nil // Anonymous writes never allowed
|
||||
}
|
||||
|
||||
if uc.authorizer == nil {
|
||||
return false, fmt.Errorf("authorizer not configured")
|
||||
}
|
||||
|
||||
return uc.authorizer.CheckWriteAccess(ctx, uc.TargetHoldDID, uc.DID)
|
||||
}
|
||||
|
||||
// GetPermissions returns detailed permissions for target hold.
|
||||
// Lazy-loaded and cached per holdDID.
|
||||
func (uc *UserContext) GetPermissions(ctx context.Context) (*HoldPermissions, error) {
|
||||
if uc.TargetHoldDID == "" {
|
||||
return nil, fmt.Errorf("target hold not set (call SetTarget first)")
|
||||
}
|
||||
return uc.GetPermissionsForHold(ctx, uc.TargetHoldDID)
|
||||
}
|
||||
|
||||
// GetPermissionsForHold returns detailed permissions for an arbitrary hold.
|
||||
// Lazy-loaded and cached per holdDID.
|
||||
func (uc *UserContext) GetPermissionsForHold(ctx context.Context, holdDID string) (*HoldPermissions, error) {
|
||||
// Check cache first
|
||||
if cached, ok := uc.permissions.Load(holdDID); ok {
|
||||
return cached.(*HoldPermissions), nil
|
||||
}
|
||||
|
||||
if uc.authorizer == nil {
|
||||
return nil, fmt.Errorf("authorizer not configured")
|
||||
}
|
||||
|
||||
// Build permissions by querying authorizer
|
||||
captain, err := uc.authorizer.GetCaptainRecord(ctx, holdDID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get captain record: %w", err)
|
||||
}
|
||||
|
||||
perms := &HoldPermissions{
|
||||
HoldDID: holdDID,
|
||||
IsPublic: captain.Public,
|
||||
IsOwner: uc.DID != "" && uc.DID == captain.Owner,
|
||||
}
|
||||
|
||||
// Check crew membership if authenticated and not owner
|
||||
if uc.IsAuthenticated && !perms.IsOwner {
|
||||
isCrew, crewErr := uc.authorizer.IsCrewMember(ctx, holdDID, uc.DID)
|
||||
if crewErr != nil {
|
||||
slog.Warn("Failed to check crew membership",
|
||||
"component", "auth/context",
|
||||
"holdDID", holdDID,
|
||||
"userDID", uc.DID,
|
||||
"error", crewErr)
|
||||
}
|
||||
perms.IsCrew = isCrew
|
||||
}
|
||||
|
||||
// Compute permissions based on role
|
||||
if perms.IsOwner {
|
||||
perms.CanRead = true
|
||||
perms.CanWrite = true
|
||||
perms.CanAdmin = true
|
||||
} else if perms.IsCrew {
|
||||
// Crew members can read and write (for now, all crew have blob:write)
|
||||
// TODO: Check specific permissions from crew record
|
||||
perms.CanRead = true
|
||||
perms.CanWrite = true
|
||||
perms.CanAdmin = false
|
||||
} else if perms.IsPublic {
|
||||
// Public hold - anyone can read
|
||||
perms.CanRead = true
|
||||
perms.CanWrite = false
|
||||
perms.CanAdmin = false
|
||||
} else if uc.IsAuthenticated {
|
||||
// Private hold, authenticated non-crew
|
||||
// Per permission matrix: cannot read private holds
|
||||
perms.CanRead = false
|
||||
perms.CanWrite = false
|
||||
perms.CanAdmin = false
|
||||
} else {
|
||||
// Anonymous on private hold
|
||||
perms.CanRead = false
|
||||
perms.CanWrite = false
|
||||
perms.CanAdmin = false
|
||||
}
|
||||
|
||||
// Cache and return
|
||||
uc.permissions.Store(holdDID, perms)
|
||||
return perms, nil
|
||||
}
|
||||
|
||||
// IsCrewMember checks if user is crew of target hold.
|
||||
func (uc *UserContext) IsCrewMember(ctx context.Context) (bool, error) {
|
||||
if uc.TargetHoldDID == "" {
|
||||
return false, fmt.Errorf("target hold not set (call SetTarget first)")
|
||||
}
|
||||
|
||||
if !uc.IsAuthenticated {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if uc.authorizer == nil {
|
||||
return false, fmt.Errorf("authorizer not configured")
|
||||
}
|
||||
|
||||
return uc.authorizer.IsCrewMember(ctx, uc.TargetHoldDID, uc.DID)
|
||||
}
|
||||
|
||||
// EnsureCrewMembership is a standalone function to register as crew on a hold.
|
||||
// Use this when you don't have a UserContext (e.g., OAuth callback).
|
||||
// This is best-effort and logs errors without failing.
|
||||
func EnsureCrewMembership(ctx context.Context, did, pdsEndpoint string, refresher *oauth.Refresher, holdDID string) {
|
||||
if holdDID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Only works with OAuth (refresher required) - app passwords can't get service tokens
|
||||
if refresher == nil {
|
||||
slog.Debug("skipping crew registration - no OAuth refresher (app password flow)", "holdDID", holdDID)
|
||||
return
|
||||
}
|
||||
|
||||
// Normalize URL to DID if needed
|
||||
if !atproto.IsDID(holdDID) {
|
||||
holdDID = atproto.ResolveHoldDIDFromURL(holdDID)
|
||||
if holdDID == "" {
|
||||
slog.Warn("failed to resolve hold DID", "defaultHold", holdDID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Get service token for the hold (OAuth only at this point)
|
||||
serviceToken, err := GetOrFetchServiceToken(ctx, AuthMethodOAuth, refresher, did, holdDID, pdsEndpoint)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get service token", "holdDID", holdDID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve hold DID to HTTP endpoint
|
||||
holdEndpoint := atproto.ResolveHoldURL(holdDID)
|
||||
if holdEndpoint == "" {
|
||||
slog.Warn("failed to resolve hold endpoint", "holdDID", holdDID)
|
||||
return
|
||||
}
|
||||
|
||||
// Call requestCrew endpoint
|
||||
if err := requestCrewMembership(ctx, holdEndpoint, serviceToken); err != nil {
|
||||
slog.Warn("failed to request crew membership", "holdDID", holdDID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("successfully registered as crew member", "holdDID", holdDID, "userDID", did)
|
||||
}
|
||||
|
||||
// ensureCrewMembership attempts to register as crew on target hold (UserContext method).
|
||||
// Called automatically during first push; idempotent.
|
||||
// This is a best-effort operation and logs errors without failing.
|
||||
// Requires SetTarget() to be called first.
|
||||
func (uc *UserContext) ensureCrewMembership(ctx context.Context) error {
|
||||
if uc.TargetHoldDID == "" {
|
||||
return fmt.Errorf("target hold not set (call SetTarget first)")
|
||||
}
|
||||
return uc.EnsureCrewMembershipForHold(ctx, uc.TargetHoldDID)
|
||||
}
|
||||
|
||||
// EnsureCrewMembershipForHold attempts to register as crew on the specified hold.
|
||||
// This is the core implementation that can be called with any holdDID.
|
||||
// Called automatically during first push; idempotent.
|
||||
// This is a best-effort operation and logs errors without failing.
|
||||
func (uc *UserContext) EnsureCrewMembershipForHold(ctx context.Context, holdDID string) error {
|
||||
if holdDID == "" {
|
||||
return nil // Nothing to do
|
||||
}
|
||||
|
||||
// Normalize URL to DID if needed
|
||||
if !atproto.IsDID(holdDID) {
|
||||
holdDID = atproto.ResolveHoldDIDFromURL(holdDID)
|
||||
if holdDID == "" {
|
||||
return fmt.Errorf("failed to resolve hold DID from URL")
|
||||
}
|
||||
}
|
||||
|
||||
if !uc.IsAuthenticated {
|
||||
return fmt.Errorf("cannot register as crew: user not authenticated")
|
||||
}
|
||||
|
||||
if uc.refresher == nil {
|
||||
return fmt.Errorf("cannot register as crew: OAuth session required")
|
||||
}
|
||||
|
||||
// Get service token for the hold
|
||||
serviceToken, err := uc.GetServiceTokenForHold(ctx, holdDID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service token: %w", err)
|
||||
}
|
||||
|
||||
// Resolve hold DID to HTTP endpoint
|
||||
holdEndpoint := atproto.ResolveHoldURL(holdDID)
|
||||
if holdEndpoint == "" {
|
||||
return fmt.Errorf("failed to resolve hold endpoint for %s", holdDID)
|
||||
}
|
||||
|
||||
// Call requestCrew endpoint
|
||||
return requestCrewMembership(ctx, holdEndpoint, serviceToken)
|
||||
}
|
||||
|
||||
// requestCrewMembership calls the hold's requestCrew endpoint
|
||||
// The endpoint handles all authorization and duplicate checking internally
|
||||
func requestCrewMembership(ctx context.Context, holdEndpoint, serviceToken string) error {
|
||||
// Add 5 second timeout to prevent hanging on offline holds
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
url := fmt.Sprintf("%s%s", holdEndpoint, atproto.HoldRequestCrew)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+serviceToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
// Read response body to capture actual error message from hold
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return fmt.Errorf("requestCrew failed with status %d (failed to read error body: %w)", resp.StatusCode, readErr)
|
||||
}
|
||||
return fmt.Errorf("requestCrew failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserClient returns an authenticated ATProto client for the user's own PDS.
|
||||
// Used for profile operations (reading/writing to user's own repo).
|
||||
// Returns nil if not authenticated or PDS not resolved.
|
||||
func (uc *UserContext) GetUserClient() *atproto.Client {
|
||||
if !uc.IsAuthenticated || uc.PDSEndpoint == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if uc.AuthMethod == AuthMethodOAuth && uc.refresher != nil {
|
||||
return atproto.NewClientWithSessionProvider(uc.PDSEndpoint, uc.DID, uc.refresher)
|
||||
} else if uc.AuthMethod == AuthMethodAppPassword {
|
||||
accessToken, _ := GetGlobalTokenCache().Get(uc.DID)
|
||||
return atproto.NewClient(uc.PDSEndpoint, uc.DID, accessToken)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnsureUserSetup ensures the user has a profile and crew membership.
|
||||
// Called once per user (cached for userSetupTTL). Runs in background - does not block.
|
||||
// Safe to call on every request.
|
||||
func (uc *UserContext) EnsureUserSetup() {
|
||||
if !uc.IsAuthenticated || uc.DID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Check cache - skip if recently set up
|
||||
if lastSetup, ok := userSetupCache.Load(uc.DID); ok {
|
||||
if time.Since(lastSetup.(time.Time)) < userSetupTTL {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Run in background to avoid blocking requests
|
||||
go func() {
|
||||
bgCtx := context.Background()
|
||||
|
||||
// 1. Ensure profile exists
|
||||
if client := uc.GetUserClient(); client != nil {
|
||||
uc.ensureProfile(bgCtx, client)
|
||||
}
|
||||
|
||||
// 2. Ensure crew membership on default hold
|
||||
if uc.defaultHoldDID != "" {
|
||||
EnsureCrewMembership(bgCtx, uc.DID, uc.PDSEndpoint, uc.refresher, uc.defaultHoldDID)
|
||||
}
|
||||
|
||||
// Mark as set up
|
||||
userSetupCache.Store(uc.DID, time.Now())
|
||||
slog.Debug("User setup complete",
|
||||
"component", "auth/usercontext",
|
||||
"did", uc.DID,
|
||||
"defaultHoldDID", uc.defaultHoldDID)
|
||||
}()
|
||||
}
|
||||
|
||||
// ensureProfile creates sailor profile if it doesn't exist.
|
||||
// Inline implementation to avoid circular import with storage package.
|
||||
func (uc *UserContext) ensureProfile(ctx context.Context, client *atproto.Client) {
|
||||
// Check if profile already exists
|
||||
profile, err := client.GetRecord(ctx, atproto.SailorProfileCollection, "self")
|
||||
if err == nil && profile != nil {
|
||||
return // Already exists
|
||||
}
|
||||
|
||||
// Create profile with default hold
|
||||
normalizedDID := ""
|
||||
if uc.defaultHoldDID != "" {
|
||||
normalizedDID = atproto.ResolveHoldDIDFromURL(uc.defaultHoldDID)
|
||||
}
|
||||
|
||||
newProfile := atproto.NewSailorProfileRecord(normalizedDID)
|
||||
if _, err := client.PutRecord(ctx, atproto.SailorProfileCollection, "self", newProfile); err != nil {
|
||||
slog.Warn("Failed to create sailor profile",
|
||||
"component", "auth/usercontext",
|
||||
"did", uc.DID,
|
||||
"error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("Created sailor profile",
|
||||
"component", "auth/usercontext",
|
||||
"did", uc.DID,
|
||||
"defaultHold", normalizedDID)
|
||||
}
|
||||
|
||||
// GetATProtoClient returns a cached ATProto client for the target owner's PDS.
|
||||
// Authenticated if user is owner, otherwise anonymous.
|
||||
// Cached per-request (uses sync.Once).
|
||||
func (uc *UserContext) GetATProtoClient() *atproto.Client {
|
||||
uc.atprotoClientOnce.Do(func() {
|
||||
if uc.TargetOwnerPDS == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// If puller is owner and authenticated, use authenticated client
|
||||
if uc.DID == uc.TargetOwnerDID && uc.IsAuthenticated {
|
||||
if uc.AuthMethod == AuthMethodOAuth && uc.refresher != nil {
|
||||
uc.atprotoClient = atproto.NewClientWithSessionProvider(uc.TargetOwnerPDS, uc.TargetOwnerDID, uc.refresher)
|
||||
return
|
||||
} else if uc.AuthMethod == AuthMethodAppPassword {
|
||||
accessToken, _ := GetGlobalTokenCache().Get(uc.TargetOwnerDID)
|
||||
uc.atprotoClient = atproto.NewClient(uc.TargetOwnerPDS, uc.TargetOwnerDID, accessToken)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Anonymous client for reads
|
||||
uc.atprotoClient = atproto.NewClient(uc.TargetOwnerPDS, uc.TargetOwnerDID, "")
|
||||
})
|
||||
return uc.atprotoClient
|
||||
}
|
||||
|
||||
// ResolveHoldDID finds the hold for the target repository.
|
||||
// - Pull: uses database lookup (historical from manifest)
|
||||
// - Push: uses discovery (sailor profile → default)
|
||||
//
|
||||
// Must be called after SetTarget() is called with at least TargetOwnerDID and TargetRepo set.
|
||||
// Updates TargetHoldDID on success.
|
||||
func (uc *UserContext) ResolveHoldDID(ctx context.Context, sqlDB *sql.DB) (string, error) {
|
||||
if uc.TargetOwnerDID == "" {
|
||||
return "", fmt.Errorf("target owner not set")
|
||||
}
|
||||
|
||||
var holdDID string
|
||||
var err error
|
||||
|
||||
switch uc.Action {
|
||||
case ActionPull:
|
||||
// For pulls, look up historical hold from database
|
||||
holdDID, err = uc.resolveHoldForPull(ctx, sqlDB)
|
||||
case ActionPush:
|
||||
// For pushes, discover hold from owner's profile
|
||||
holdDID, err = uc.resolveHoldForPush(ctx)
|
||||
default:
|
||||
// Default to push discovery
|
||||
holdDID, err = uc.resolveHoldForPush(ctx)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if holdDID == "" {
|
||||
return "", fmt.Errorf("no hold DID found for %s/%s", uc.TargetOwnerDID, uc.TargetRepo)
|
||||
}
|
||||
|
||||
uc.TargetHoldDID = holdDID
|
||||
return holdDID, nil
|
||||
}
|
||||
|
||||
// resolveHoldForPull looks up the hold from the database (historical reference)
|
||||
func (uc *UserContext) resolveHoldForPull(ctx context.Context, sqlDB *sql.DB) (string, error) {
|
||||
// If no database is available, fall back to discovery
|
||||
if sqlDB == nil {
|
||||
return uc.resolveHoldForPush(ctx)
|
||||
}
|
||||
|
||||
// Try database lookup first
|
||||
holdDID, err := db.GetLatestHoldDIDForRepo(sqlDB, uc.TargetOwnerDID, uc.TargetRepo)
|
||||
if err != nil {
|
||||
slog.Debug("Database lookup failed, falling back to discovery",
|
||||
"component", "auth/context",
|
||||
"ownerDID", uc.TargetOwnerDID,
|
||||
"repo", uc.TargetRepo,
|
||||
"error", err)
|
||||
return uc.resolveHoldForPush(ctx)
|
||||
}
|
||||
|
||||
if holdDID != "" {
|
||||
return holdDID, nil
|
||||
}
|
||||
|
||||
// No historical hold found, fall back to discovery
|
||||
return uc.resolveHoldForPush(ctx)
|
||||
}
|
||||
|
||||
// resolveHoldForPush discovers hold from owner's sailor profile or default
|
||||
func (uc *UserContext) resolveHoldForPush(ctx context.Context) (string, error) {
|
||||
// Create anonymous client to query owner's profile
|
||||
client := atproto.NewClient(uc.TargetOwnerPDS, uc.TargetOwnerDID, "")
|
||||
|
||||
// Try to get owner's sailor profile
|
||||
record, err := client.GetRecord(ctx, atproto.SailorProfileCollection, "self")
|
||||
if err == nil && record != nil {
|
||||
var profile atproto.SailorProfileRecord
|
||||
if jsonErr := json.Unmarshal(record.Value, &profile); jsonErr == nil {
|
||||
if profile.DefaultHold != "" {
|
||||
// Normalize to DID if needed
|
||||
holdDID := profile.DefaultHold
|
||||
if !atproto.IsDID(holdDID) {
|
||||
holdDID = atproto.ResolveHoldDIDFromURL(holdDID)
|
||||
}
|
||||
slog.Debug("Found hold from owner's profile",
|
||||
"component", "auth/context",
|
||||
"ownerDID", uc.TargetOwnerDID,
|
||||
"holdDID", holdDID)
|
||||
return holdDID, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default hold
|
||||
if uc.defaultHoldDID != "" {
|
||||
slog.Debug("Using default hold",
|
||||
"component", "auth/context",
|
||||
"ownerDID", uc.TargetOwnerDID,
|
||||
"defaultHoldDID", uc.defaultHoldDID)
|
||||
return uc.defaultHoldDID, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no hold configured for %s and no default hold set", uc.TargetOwnerDID)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Test Helper Methods
|
||||
// =============================================================================
|
||||
// These methods are designed to make UserContext testable by allowing tests
|
||||
// to bypass network-dependent code paths (PDS resolution, OAuth token fetching).
|
||||
// Only use these in tests - they are not intended for production use.
|
||||
|
||||
// SetPDSForTest sets the PDS endpoint directly, bypassing ResolvePDS network calls.
|
||||
// This allows tests to skip DID resolution which would make network requests.
|
||||
// Deprecated: Use SetPDS instead.
|
||||
func (uc *UserContext) SetPDSForTest(handle, pdsEndpoint string) {
|
||||
uc.SetPDS(handle, pdsEndpoint)
|
||||
}
|
||||
|
||||
// SetServiceTokenForTest pre-populates a service token for the given holdDID,
|
||||
// bypassing the sync.Once and OAuth/app-password fetching logic.
|
||||
// The token will appear as if it was already fetched and cached.
|
||||
func (uc *UserContext) SetServiceTokenForTest(holdDID, token string) {
|
||||
entry := &serviceTokenEntry{
|
||||
token: token,
|
||||
expiresAt: time.Now().Add(5 * time.Minute),
|
||||
err: nil,
|
||||
}
|
||||
// Mark the sync.Once as done so real fetch won't happen
|
||||
entry.once.Do(func() {})
|
||||
uc.serviceTokens.Store(holdDID, entry)
|
||||
}
|
||||
|
||||
// SetAuthorizerForTest sets the authorizer for permission checks.
|
||||
// Use with MockHoldAuthorizer to control CanRead/CanWrite behavior in tests.
|
||||
func (uc *UserContext) SetAuthorizerForTest(authorizer HoldAuthorizer) {
|
||||
uc.authorizer = authorizer
|
||||
}
|
||||
|
||||
// SetDefaultHoldDIDForTest sets the default hold DID for tests.
|
||||
// This is used as fallback when resolving hold for push operations.
|
||||
func (uc *UserContext) SetDefaultHoldDIDForTest(holdDID string) {
|
||||
uc.defaultHoldDID = holdDID
|
||||
}
|
||||
Reference in New Issue
Block a user