big scary refactor. sync enable_bluesky_posts with captain record. implement oauth logout handler. implement crew assignment to hold. this caused a lot of circular dependencies and needed to move functions around in order to fix
This commit is contained in:
@@ -89,7 +89,7 @@ HOLD_DATABASE_DIR=/var/lib/atcr-hold
|
||||
|
||||
# Enable Bluesky posts when users push container images (default: false)
|
||||
# When enabled, the hold's embedded PDS will create posts announcing image pushes
|
||||
# Can be overridden per-hold via the captain record's enableManifestPosts field
|
||||
# Synced to captain record's enableBlueskyPosts field on startup
|
||||
# HOLD_BLUESKY_POSTS_ENABLED=false
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
@@ -9,15 +9,19 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||
"github.com/distribution/distribution/v3/configuration"
|
||||
"github.com/distribution/distribution/v3/registry"
|
||||
"github.com/distribution/distribution/v3/registry/handlers"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"atcr.io/pkg/appview/middleware"
|
||||
"atcr.io/pkg/appview/storage"
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
"atcr.io/pkg/auth/token"
|
||||
@@ -206,7 +210,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println("README cache initialized for manifest push refresh")
|
||||
|
||||
// Initialize UI routes with OAuth app, refresher, device store, health checker, and readme cache
|
||||
uiTemplates, uiRouter := initializeUIRoutes(uiDatabase, uiReadOnlyDB, uiSessionStore, oauthApp, refresher, baseURL, deviceStore, defaultHoldDID, healthChecker, readmeCache)
|
||||
uiTemplates, uiRouter := initializeUIRoutes(uiDatabase, uiReadOnlyDB, uiSessionStore, oauthApp, oauthStore, refresher, baseURL, deviceStore, defaultHoldDID, healthChecker, readmeCache)
|
||||
|
||||
// Create OAuth server
|
||||
oauthServer := oauth.NewServer(oauthApp)
|
||||
@@ -216,15 +220,120 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
if uiSessionStore != nil {
|
||||
oauthServer.SetUISessionStore(uiSessionStore)
|
||||
}
|
||||
// Connect database for user avatar management
|
||||
oauthServer.SetDatabase(uiDatabase)
|
||||
|
||||
// Set default hold DID on OAuth server (extracted earlier)
|
||||
// This is used to create sailor profiles on first login
|
||||
if defaultHoldDID != "" {
|
||||
oauthServer.SetDefaultHoldDID(defaultHoldDID)
|
||||
fmt.Printf("OAuth server will create profiles with default hold: %s\n", defaultHoldDID)
|
||||
}
|
||||
// Register OAuth post-auth callback for AppView business logic
|
||||
// This decouples the OAuth package from AppView-specific dependencies
|
||||
oauthServer.SetPostAuthCallback(func(ctx context.Context, did, handle, pdsEndpoint, sessionID string) error {
|
||||
fmt.Printf("DEBUG [appview/callback]: OAuth post-auth callback for DID=%s\n", did)
|
||||
|
||||
// Parse DID for session resume
|
||||
didParsed, err := syntax.ParseDID(did)
|
||||
if err != nil {
|
||||
fmt.Printf("WARNING [appview/callback]: Failed to parse DID %s: %v\n", did, err)
|
||||
return nil // Non-fatal
|
||||
}
|
||||
|
||||
// Resume OAuth session to get authenticated client
|
||||
session, err := oauthApp.ResumeSession(ctx, didParsed, sessionID)
|
||||
if err != nil {
|
||||
fmt.Printf("WARNING [appview/callback]: Failed to resume session for DID=%s: %v\n", did, err)
|
||||
// Fallback: update user without avatar
|
||||
_ = db.UpsertUser(uiDatabase, &db.User{
|
||||
DID: did,
|
||||
Handle: handle,
|
||||
PDSEndpoint: pdsEndpoint,
|
||||
Avatar: "",
|
||||
LastSeen: time.Now(),
|
||||
})
|
||||
return nil // Non-fatal
|
||||
}
|
||||
|
||||
// Create authenticated atproto client using the indigo session's API client
|
||||
client := atproto.NewClientWithIndigoClient(pdsEndpoint, did, session.APIClient())
|
||||
|
||||
// Ensure sailor profile exists (creates with default hold if configured)
|
||||
fmt.Printf("DEBUG [appview/callback]: Ensuring profile exists for %s (defaultHold=%s)\n", did, defaultHoldDID)
|
||||
if err := storage.EnsureProfile(ctx, client, defaultHoldDID); err != nil {
|
||||
fmt.Printf("WARNING [appview/callback]: Failed to ensure profile for %s: %v\n", did, err)
|
||||
// Continue anyway - profile creation is not critical for avatar fetch
|
||||
} else {
|
||||
fmt.Printf("DEBUG [appview/callback]: Profile ensured for %s\n", did)
|
||||
}
|
||||
|
||||
// Fetch user's profile record from PDS (contains blob references)
|
||||
profileRecord, err := client.GetProfileRecord(ctx, did)
|
||||
if err != nil {
|
||||
fmt.Printf("WARNING [appview/callback]: Failed to fetch profile record for DID=%s: %v\n", did, err)
|
||||
// Still update user without avatar
|
||||
_ = db.UpsertUser(uiDatabase, &db.User{
|
||||
DID: did,
|
||||
Handle: handle,
|
||||
PDSEndpoint: pdsEndpoint,
|
||||
Avatar: "",
|
||||
LastSeen: time.Now(),
|
||||
})
|
||||
return nil // Non-fatal
|
||||
}
|
||||
|
||||
// Construct avatar URL from blob CID using imgs.blue CDN
|
||||
var avatarURL string
|
||||
if profileRecord.Avatar != nil && profileRecord.Avatar.Ref.Link != "" {
|
||||
avatarURL = atproto.BlobCDNURL(did, profileRecord.Avatar.Ref.Link)
|
||||
fmt.Printf("DEBUG [appview/callback]: Constructed avatar URL: %s\n", avatarURL)
|
||||
}
|
||||
|
||||
// Store user with avatar in database
|
||||
err = db.UpsertUser(uiDatabase, &db.User{
|
||||
DID: did,
|
||||
Handle: handle,
|
||||
PDSEndpoint: pdsEndpoint,
|
||||
Avatar: avatarURL,
|
||||
LastSeen: time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Printf("WARNING [appview/callback]: Failed to store user in database: %v\n", err)
|
||||
return nil // Non-fatal
|
||||
}
|
||||
|
||||
fmt.Printf("DEBUG [appview/callback]: Stored user with avatar for DID=%s\n", did)
|
||||
|
||||
// Migrate profile URL→DID if needed
|
||||
profile, err := storage.GetProfile(ctx, client)
|
||||
if err != nil {
|
||||
fmt.Printf("WARNING [appview/callback]: Failed to get profile for %s: %v\n", did, err)
|
||||
return nil // Non-fatal
|
||||
}
|
||||
|
||||
var holdDID string
|
||||
if profile != nil && profile.DefaultHold != "" {
|
||||
// Check if defaultHold is a URL (needs migration)
|
||||
if strings.HasPrefix(profile.DefaultHold, "http://") || strings.HasPrefix(profile.DefaultHold, "https://") {
|
||||
fmt.Printf("DEBUG [appview/callback]: Migrating hold URL to DID for %s: %s\n", did, profile.DefaultHold)
|
||||
|
||||
// Resolve URL to DID
|
||||
holdDID := atproto.ResolveHoldDIDFromURL(profile.DefaultHold)
|
||||
|
||||
// Update profile with DID
|
||||
profile.DefaultHold = holdDID
|
||||
if err := storage.UpdateProfile(ctx, client, profile); err != nil {
|
||||
fmt.Printf("WARNING [appview/callback]: Failed to update profile with hold DID for %s: %v\n", did, err)
|
||||
} else {
|
||||
fmt.Printf("DEBUG [appview/callback]: Updated profile with hold DID: %s\n", holdDID)
|
||||
}
|
||||
fmt.Printf("DEBUG [oauth/server]: Attempting crew registration for %s at hold %s\n", did, holdDID)
|
||||
storage.EnsureCrewMembership(ctx, client, refresher, holdDID)
|
||||
} else {
|
||||
// Already a DID - use it
|
||||
holdDID = profile.DefaultHold
|
||||
}
|
||||
// Register crew regardless of migration (outside the migration block)
|
||||
fmt.Printf("DEBUG [appview/callback]: Attempting crew registration for %s at hold %s\n", did, holdDID)
|
||||
storage.EnsureCrewMembership(ctx, client, refresher, holdDID)
|
||||
|
||||
}
|
||||
|
||||
return nil // All errors are non-fatal, logged for debugging
|
||||
})
|
||||
|
||||
// Initialize auth keys and create token issuer
|
||||
var issuer *token.Issuer
|
||||
@@ -284,8 +393,27 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
// Mount auth endpoints if enabled
|
||||
if issuer != nil {
|
||||
// Basic Auth token endpoint (supports device secrets and app passwords)
|
||||
// Reuse defaultHoldDID extracted earlier
|
||||
tokenHandler := token.NewHandler(issuer, deviceStore, defaultHoldDID)
|
||||
tokenHandler := token.NewHandler(issuer, deviceStore)
|
||||
|
||||
// Register token post-auth callback for profile management
|
||||
// This decouples the token package from AppView-specific dependencies
|
||||
tokenHandler.SetPostAuthCallback(func(ctx context.Context, did, handle, pdsEndpoint, accessToken string) error {
|
||||
fmt.Printf("DEBUG [appview/callback]: Token post-auth callback for DID=%s\n", did)
|
||||
|
||||
// Create ATProto client with validated token
|
||||
atprotoClient := atproto.NewClient(pdsEndpoint, did, accessToken)
|
||||
|
||||
// Ensure profile exists (will create with default hold if not exists and default is configured)
|
||||
if err := storage.EnsureProfile(ctx, atprotoClient, defaultHoldDID); err != nil {
|
||||
// Log error but don't fail auth - profile management is not critical
|
||||
fmt.Printf("WARNING [appview/callback]: Failed to ensure profile for %s: %v\n", did, err)
|
||||
} else {
|
||||
fmt.Printf("DEBUG [appview/callback]: Profile ensured for %s with default hold %s\n", did, defaultHoldDID)
|
||||
}
|
||||
|
||||
return nil // All errors are non-fatal
|
||||
})
|
||||
|
||||
tokenHandler.RegisterRoutes(mux)
|
||||
|
||||
// Device authorization endpoints (public)
|
||||
@@ -401,7 +529,7 @@ func createTokenIssuer(config *configuration.Configuration) (*token.Issuer, erro
|
||||
// readOnlyDB: read-only connection for public queries (search, user pages, etc.)
|
||||
// defaultHoldDID: DID of the default hold service (e.g., "did:web:hold01.atcr.io")
|
||||
// healthChecker: hold endpoint health checker
|
||||
func initializeUIRoutes(database *sql.DB, readOnlyDB *sql.DB, sessionStore *db.SessionStore, oauthApp *oauth.App, refresher *oauth.Refresher, baseURL string, deviceStore *db.DeviceStore, defaultHoldDID string, healthChecker *holdhealth.Checker, readmeCache *readme.Cache) (*template.Template, *mux.Router) {
|
||||
func initializeUIRoutes(database *sql.DB, readOnlyDB *sql.DB, sessionStore *db.SessionStore, oauthApp *oauth.App, oauthStore *db.OAuthStore, refresher *oauth.Refresher, baseURL string, deviceStore *db.DeviceStore, defaultHoldDID string, healthChecker *holdhealth.Checker, readmeCache *readme.Cache) (*template.Template, *mux.Router) {
|
||||
// Check if UI is enabled
|
||||
uiEnabled := os.Getenv("ATCR_UI_ENABLED")
|
||||
if uiEnabled == "false" {
|
||||
@@ -582,12 +710,12 @@ func initializeUIRoutes(database *sql.DB, readOnlyDB *sql.DB, sessionStore *db.S
|
||||
}).Methods("DELETE")
|
||||
|
||||
// Logout endpoint (supports both GET and POST)
|
||||
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||
if sessionID, ok := db.GetSessionID(r); ok {
|
||||
sessionStore.Delete(sessionID)
|
||||
}
|
||||
db.ClearCookie(w)
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
// Properly revokes OAuth tokens on PDS side before clearing local session
|
||||
router.Handle("/auth/logout", &uihandlers.LogoutHandler{
|
||||
OAuthApp: oauthApp,
|
||||
Refresher: refresher,
|
||||
SessionStore: sessionStore,
|
||||
OAuthStore: oauthStore,
|
||||
}).Methods("GET", "POST")
|
||||
|
||||
// Start Jetstream worker
|
||||
|
||||
@@ -694,7 +694,7 @@ func TestHandleNotifyManifest(t *testing.T) {
|
||||
```bash
|
||||
# Enable/disable Bluesky manifest posting (default: false)
|
||||
# When enabled, hold will create Bluesky posts when users push images
|
||||
# Can be overridden per-hold via captain record's enableManifestPosts field
|
||||
# Synced to captain record's enableBlueskyPosts field on startup
|
||||
HOLD_BLUESKY_POSTS_ENABLED=false
|
||||
```
|
||||
|
||||
@@ -702,20 +702,21 @@ HOLD_BLUESKY_POSTS_ENABLED=false
|
||||
|
||||
### Feature Flags
|
||||
|
||||
**Captain Record Override:**
|
||||
The hold's captain record includes an `enableManifestPosts` field that overrides the environment variable:
|
||||
**Captain Record Sync:**
|
||||
The hold's captain record includes an `enableBlueskyPosts` field that is synchronized with the environment variable on startup:
|
||||
|
||||
```go
|
||||
type CaptainRecord struct {
|
||||
// ... other fields ...
|
||||
EnableManifestPosts bool `json:"enableManifestPosts" cborgen:"enableManifestPosts"`
|
||||
EnableBlueskyPosts bool `json:"enableBlueskyPosts" cborgen:"enableBlueskyPosts"`
|
||||
}
|
||||
```
|
||||
|
||||
**Precedence (highest to lowest):**
|
||||
1. Captain record `enableManifestPosts` field (if set)
|
||||
2. `HOLD_BLUESKY_POSTS_ENABLED` environment variable
|
||||
3. Default: `false` (opt-in feature)
|
||||
**How it works:**
|
||||
1. On startup, Bootstrap reads `HOLD_BLUESKY_POSTS_ENABLED` environment variable
|
||||
2. Creates or updates the captain record to match the env var setting
|
||||
3. At runtime, the code reads from the captain record (which reflects the env var)
|
||||
4. To change the setting, update the env var and restart the hold
|
||||
|
||||
**Rationale:**
|
||||
- Default off for backward compatibility and privacy
|
||||
|
||||
@@ -38,15 +38,17 @@ func LoadConfigFromEnv() (*configuration.Configuration, error) {
|
||||
// Storage (fake in-memory placeholder - all real storage is proxied)
|
||||
config.Storage = buildStorageConfig()
|
||||
|
||||
// Get base URL for error messages and auth config
|
||||
baseURL := GetBaseURL(httpConfig.Addr)
|
||||
|
||||
// Middleware (ATProto resolver)
|
||||
defaultHoldDID := os.Getenv("ATCR_DEFAULT_HOLD_DID")
|
||||
if defaultHoldDID == "" {
|
||||
return nil, fmt.Errorf("ATCR_DEFAULT_HOLD_DID is required")
|
||||
}
|
||||
config.Middleware = buildMiddlewareConfig(defaultHoldDID)
|
||||
config.Middleware = buildMiddlewareConfig(defaultHoldDID, baseURL)
|
||||
|
||||
// Auth
|
||||
baseURL := GetBaseURL(httpConfig.Addr)
|
||||
authConfig, err := buildAuthConfig(baseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build auth config: %w", err)
|
||||
@@ -128,7 +130,7 @@ func buildStorageConfig() configuration.Storage {
|
||||
}
|
||||
|
||||
// buildMiddlewareConfig creates middleware configuration
|
||||
func buildMiddlewareConfig(defaultHoldDID string) map[string][]configuration.Middleware {
|
||||
func buildMiddlewareConfig(defaultHoldDID string, baseURL string) map[string][]configuration.Middleware {
|
||||
// Check test mode
|
||||
testMode := os.Getenv("TEST_MODE") == "true"
|
||||
|
||||
@@ -139,6 +141,7 @@ func buildMiddlewareConfig(defaultHoldDID string) map[string][]configuration.Mid
|
||||
Options: configuration.Parameters{
|
||||
"default_hold_did": defaultHoldDID,
|
||||
"test_mode": testMode,
|
||||
"base_url": baseURL,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -368,6 +368,7 @@ func TestBuildMiddlewareConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
defaultHoldDID string
|
||||
baseURL string
|
||||
testMode bool
|
||||
setTestMode bool
|
||||
wantTestMode bool
|
||||
@@ -375,12 +376,14 @@ func TestBuildMiddlewareConfig(t *testing.T) {
|
||||
{
|
||||
name: "normal mode",
|
||||
defaultHoldDID: "did:web:hold01.atcr.io",
|
||||
baseURL: "https://atcr.io",
|
||||
setTestMode: false,
|
||||
wantTestMode: false,
|
||||
},
|
||||
{
|
||||
name: "test mode enabled",
|
||||
defaultHoldDID: "did:web:hold01.atcr.io",
|
||||
baseURL: "https://atcr.io",
|
||||
testMode: true,
|
||||
setTestMode: true,
|
||||
wantTestMode: true,
|
||||
@@ -395,7 +398,7 @@ func TestBuildMiddlewareConfig(t *testing.T) {
|
||||
os.Unsetenv("TEST_MODE")
|
||||
}
|
||||
|
||||
got := buildMiddlewareConfig(tt.defaultHoldDID)
|
||||
got := buildMiddlewareConfig(tt.defaultHoldDID, tt.baseURL)
|
||||
|
||||
registryMW, ok := got["registry"]
|
||||
if !ok {
|
||||
@@ -415,6 +418,10 @@ func TestBuildMiddlewareConfig(t *testing.T) {
|
||||
t.Errorf("default_hold_did = %v, want %v", mw.Options["default_hold_did"], tt.defaultHoldDID)
|
||||
}
|
||||
|
||||
if mw.Options["base_url"] != tt.baseURL {
|
||||
t.Errorf("base_url = %v, want %v", mw.Options["base_url"], tt.baseURL)
|
||||
}
|
||||
|
||||
if mw.Options["test_mode"] != tt.wantTestMode {
|
||||
t.Errorf("test_mode = %v, want %v", mw.Options["test_mode"], tt.wantTestMode)
|
||||
}
|
||||
|
||||
67
pkg/appview/handlers/logout.go
Normal file
67
pkg/appview/handlers/logout.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"atcr.io/pkg/appview/db"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||
)
|
||||
|
||||
// LogoutHandler handles user logout with proper OAuth token revocation
|
||||
type LogoutHandler struct {
|
||||
OAuthApp *oauth.App
|
||||
Refresher *oauth.Refresher
|
||||
SessionStore *db.SessionStore
|
||||
OAuthStore *db.OAuthStore
|
||||
}
|
||||
|
||||
func (h *LogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Get UI session ID from cookie
|
||||
uiSessionID, hasSession := db.GetSessionID(r)
|
||||
if !hasSession {
|
||||
// No session to logout from, just redirect
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Get UI session to extract OAuth session ID and user info
|
||||
uiSession, ok := h.SessionStore.Get(uiSessionID)
|
||||
if ok && uiSession != nil && uiSession.DID != "" {
|
||||
// Parse DID for OAuth logout
|
||||
did, err := syntax.ParseDID(uiSession.DID)
|
||||
if err != nil {
|
||||
fmt.Printf("WARNING [logout]: Failed to parse DID %s: %v\n", uiSession.DID, err)
|
||||
} else {
|
||||
// Attempt to revoke OAuth tokens on PDS side
|
||||
if uiSession.OAuthSessionID != "" {
|
||||
// Call indigo's Logout to revoke tokens on PDS
|
||||
if err := h.OAuthApp.GetClientApp().Logout(r.Context(), did, uiSession.OAuthSessionID); err != nil {
|
||||
// Log error but don't block logout - best effort revocation
|
||||
fmt.Printf("WARNING [logout]: Failed to revoke OAuth tokens for %s on PDS: %v\n", uiSession.DID, err)
|
||||
} else {
|
||||
fmt.Printf("INFO [logout]: Successfully revoked OAuth tokens for %s on PDS\n", uiSession.DID)
|
||||
}
|
||||
|
||||
// Invalidate refresher cache to clear local access tokens
|
||||
h.Refresher.InvalidateSession(uiSession.DID)
|
||||
fmt.Printf("INFO [logout]: Invalidated local OAuth cache for %s\n", uiSession.DID)
|
||||
|
||||
// Delete OAuth session from database (cleanup, might already be done by Logout)
|
||||
if err := h.OAuthStore.DeleteSession(r.Context(), did, uiSession.OAuthSessionID); err != nil {
|
||||
fmt.Printf("WARNING [logout]: Failed to delete OAuth session from database: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("WARNING [logout]: No OAuth session ID found for user %s\n", uiSession.DID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Always delete UI session and clear cookie, even if OAuth revocation failed
|
||||
h.SessionStore.Delete(uiSessionID)
|
||||
db.ClearCookie(w)
|
||||
|
||||
// Redirect to home page
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/appview/middleware"
|
||||
"atcr.io/pkg/appview/storage"
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
)
|
||||
@@ -41,7 +42,7 @@ func (h *SettingsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
client := atproto.NewClientWithIndigoClient(user.PDSEndpoint, user.DID, apiClient)
|
||||
|
||||
// Fetch sailor profile
|
||||
profile, err := atproto.GetProfile(r.Context(), client)
|
||||
profile, err := storage.GetProfile(r.Context(), client)
|
||||
if err != nil {
|
||||
// Error fetching profile - log out user
|
||||
fmt.Printf("WARNING [settings]: Failed to fetch profile for %s: %v - logging out\n", user.DID, err)
|
||||
@@ -111,7 +112,7 @@ func (h *UpdateDefaultHoldHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ
|
||||
client := atproto.NewClientWithIndigoClient(user.PDSEndpoint, user.DID, apiClient)
|
||||
|
||||
// Fetch existing profile or create new one
|
||||
profile, err := atproto.GetProfile(r.Context(), client)
|
||||
profile, err := storage.GetProfile(r.Context(), client)
|
||||
if err != nil || profile == nil {
|
||||
// Profile doesn't exist, create new one
|
||||
profile = atproto.NewSailorProfileRecord(holdEndpoint)
|
||||
@@ -122,7 +123,7 @@ func (h *UpdateDefaultHoldHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
|
||||
// Save profile
|
||||
if err := atproto.UpdateProfile(r.Context(), client, profile); err != nil {
|
||||
if err := storage.UpdateProfile(r.Context(), client, profile); err != nil {
|
||||
http.Error(w, "Failed to update profile: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/appview"
|
||||
"atcr.io/pkg/atproto"
|
||||
)
|
||||
|
||||
// HealthStatus represents the health status of a hold endpoint
|
||||
@@ -53,7 +53,7 @@ func (c *Checker) CheckHealth(ctx context.Context, endpoint string) (bool, error
|
||||
// Convert DID to HTTP URL if needed
|
||||
// did:web:hold.example.com → https://hold.example.com
|
||||
// https://hold.example.com → https://hold.example.com (passthrough)
|
||||
httpURL := appview.ResolveHoldURL(endpoint)
|
||||
httpURL := atproto.ResolveHoldURL(endpoint)
|
||||
|
||||
// Build health check URL
|
||||
healthURL := httpURL + "/xrpc/_health"
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||
|
||||
"atcr.io/pkg/appview"
|
||||
"atcr.io/pkg/appview/db"
|
||||
"atcr.io/pkg/atproto"
|
||||
)
|
||||
@@ -327,7 +326,7 @@ func (b *BackfillWorker) queryCaptainRecord(ctx context.Context, holdDID string)
|
||||
}
|
||||
|
||||
// Resolve hold DID to URL
|
||||
holdURL := appview.ResolveHoldURL(holdDID)
|
||||
holdURL := atproto.ResolveHoldURL(holdDID)
|
||||
|
||||
// Create client for hold's PDS
|
||||
holdClient := atproto.NewClient(holdURL, holdDID, "")
|
||||
|
||||
@@ -4,12 +4,8 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bluesky-social/indigo/atproto/identity"
|
||||
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||
@@ -73,6 +69,7 @@ type NamespaceResolver struct {
|
||||
distribution.Namespace
|
||||
directory identity.Directory
|
||||
defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io")
|
||||
baseURL string // Base URL for error messages (e.g., "https://atcr.io")
|
||||
testMode bool // If true, fallback to default hold when user's hold is unreachable
|
||||
repositories sync.Map // Cache of RoutingRepository instances by key (did:reponame)
|
||||
refresher *oauth.Refresher // OAuth session manager (copied from global on init)
|
||||
@@ -93,6 +90,12 @@ func initATProtoResolver(ctx context.Context, ns distribution.Namespace, _ drive
|
||||
defaultHoldDID = holdDID
|
||||
}
|
||||
|
||||
// Get base URL from config (for error messages)
|
||||
baseURL := ""
|
||||
if url, ok := options["base_url"].(string); ok {
|
||||
baseURL = url
|
||||
}
|
||||
|
||||
// Check test mode from options (passed via env var)
|
||||
testMode := false
|
||||
if tm, ok := options["test_mode"].(bool); ok {
|
||||
@@ -105,6 +108,7 @@ func initATProtoResolver(ctx context.Context, ns distribution.Namespace, _ drive
|
||||
Namespace: ns,
|
||||
directory: directory,
|
||||
defaultHoldDID: defaultHoldDID,
|
||||
baseURL: baseURL,
|
||||
testMode: testMode,
|
||||
refresher: globalRefresher,
|
||||
database: globalDatabase,
|
||||
@@ -113,6 +117,13 @@ func initATProtoResolver(ctx context.Context, ns distribution.Namespace, _ drive
|
||||
}, nil
|
||||
}
|
||||
|
||||
// authErrorMessage creates a user-friendly auth error with login URL
|
||||
func (nr *NamespaceResolver) authErrorMessage(message string) error {
|
||||
loginURL := fmt.Sprintf("%s/auth/oauth/login", nr.baseURL)
|
||||
fullMessage := fmt.Sprintf("%s - please re-authenticate at %s", message, loginURL)
|
||||
return errcode.ErrorCodeUnauthorized.WithMessage(fullMessage)
|
||||
}
|
||||
|
||||
// Repository resolves the repository name and delegates to underlying namespace
|
||||
// Handles names like:
|
||||
// - atcr.io/alice/myimage → resolve alice to DID
|
||||
@@ -160,99 +171,14 @@ func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Name
|
||||
ctx = context.WithValue(ctx, holdDIDKey, holdDID)
|
||||
|
||||
// Get service token for hold authentication
|
||||
// Check cache first to avoid unnecessary PDS calls on every request
|
||||
var serviceToken string
|
||||
if nr.refresher != nil {
|
||||
cachedToken, expiresAt := token.GetServiceToken(did, holdDID)
|
||||
|
||||
// Use cached token if it exists and has > 10s remaining
|
||||
if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
|
||||
fmt.Printf("DEBUG [registry/middleware]: Using cached service token for DID=%s (expires in %v)\n",
|
||||
did, time.Until(expiresAt).Round(time.Second))
|
||||
serviceToken = cachedToken
|
||||
} else {
|
||||
// Cache miss or expiring soon - validate OAuth and get new service token
|
||||
if cachedToken == "" {
|
||||
fmt.Printf("DEBUG [registry/middleware]: Cache miss, fetching service token for DID=%s\n", did)
|
||||
} else {
|
||||
fmt.Printf("DEBUG [registry/middleware]: Token expiring soon, proactively renewing for DID=%s\n", did)
|
||||
}
|
||||
|
||||
session, err := nr.refresher.GetSession(ctx, did)
|
||||
if err != nil {
|
||||
// OAuth session unavailable - fail fast with proper auth error
|
||||
nr.refresher.InvalidateSession(did)
|
||||
token.InvalidateServiceToken(did, holdDID)
|
||||
fmt.Printf("ERROR [registry/middleware]: Failed to get OAuth session for DID=%s: %v\n", did, err)
|
||||
fmt.Printf("ERROR [registry/middleware]: User needs to re-authenticate via credential helper\n")
|
||||
return nil, errcode.ErrorCodeUnauthorized.WithDetail("OAuth session expired - please re-authenticate")
|
||||
}
|
||||
|
||||
// Call com.atproto.server.getServiceAuth on the user's PDS
|
||||
// Request 5-minute expiry (PDS may grant less)
|
||||
// exp must be absolute Unix timestamp, not relative duration
|
||||
// Note: OAuth scope includes #atcr_hold fragment, but service auth aud must be bare DID
|
||||
expiryTime := time.Now().Unix() + 300 // 5 minutes from now
|
||||
serviceAuthURL := fmt.Sprintf("%s%s?aud=%s&lxm=%s&exp=%d",
|
||||
pdsEndpoint,
|
||||
atproto.ServerGetServiceAuth,
|
||||
url.QueryEscape(holdDID),
|
||||
url.QueryEscape("com.atproto.repo.getRecord"),
|
||||
expiryTime,
|
||||
)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", serviceAuthURL, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("ERROR [registry/middleware]: Failed to create service auth request: %v\n", err)
|
||||
return nil, errcode.ErrorCodeUnauthorized.WithDetail("OAuth session validation failed")
|
||||
}
|
||||
|
||||
// Use OAuth session to authenticate to PDS (with DPoP)
|
||||
resp, err := session.DoWithAuth(session.Client, req, "com.atproto.server.getServiceAuth")
|
||||
if err != nil {
|
||||
// Invalidate session on auth errors (may indicate corrupted session or expired tokens)
|
||||
nr.refresher.InvalidateSession(did)
|
||||
token.InvalidateServiceToken(did, holdDID)
|
||||
fmt.Printf("ERROR [registry/middleware]: OAuth validation failed for DID=%s: %v\n", did, err)
|
||||
fmt.Printf("ERROR [registry/middleware]: User needs to re-authenticate via credential helper\n")
|
||||
return nil, errcode.ErrorCodeUnauthorized.WithDetail("OAuth session expired - please re-authenticate")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// Invalidate session on auth failures
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
nr.refresher.InvalidateSession(did)
|
||||
token.InvalidateServiceToken(did, holdDID)
|
||||
fmt.Printf("ERROR [registry/middleware]: OAuth validation failed for DID=%s: status %d, body: %s\n",
|
||||
did, resp.StatusCode, string(bodyBytes))
|
||||
fmt.Printf("ERROR [registry/middleware]: User needs to re-authenticate via credential helper\n")
|
||||
return nil, errcode.ErrorCodeUnauthorized.WithDetail("OAuth session expired - please re-authenticate")
|
||||
}
|
||||
|
||||
// Parse response to get service token
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
fmt.Printf("ERROR [registry/middleware]: Failed to decode service auth response: %v\n", err)
|
||||
return nil, errcode.ErrorCodeUnauthorized.WithDetail("OAuth session validation failed")
|
||||
}
|
||||
|
||||
if result.Token == "" {
|
||||
fmt.Printf("ERROR [registry/middleware]: Empty token in service auth response\n")
|
||||
return nil, errcode.ErrorCodeUnauthorized.WithDetail("OAuth session validation failed")
|
||||
}
|
||||
|
||||
serviceToken = result.Token
|
||||
|
||||
// Cache the token (parses JWT to extract actual expiry)
|
||||
if err := token.SetServiceToken(did, holdDID, serviceToken); err != nil {
|
||||
fmt.Printf("WARN [registry/middleware]: Failed to cache service token: %v\n", err)
|
||||
// Non-fatal - we have the token, just won't be cached
|
||||
}
|
||||
|
||||
fmt.Printf("DEBUG [registry/middleware]: OAuth validation succeeded for DID=%s\n", did)
|
||||
var err error
|
||||
serviceToken, err = token.GetOrFetchServiceToken(ctx, nr.refresher, did, holdDID, pdsEndpoint)
|
||||
if err != nil {
|
||||
fmt.Printf("ERROR [registry/middleware]: Failed to get service token for DID=%s: %v\n", did, err)
|
||||
fmt.Printf("ERROR [registry/middleware]: User needs to re-authenticate via credential helper\n")
|
||||
return nil, nr.authErrorMessage("OAuth session expired")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -366,7 +292,7 @@ func (nr *NamespaceResolver) findHoldDID(ctx context.Context, did, pdsEndpoint s
|
||||
client := atproto.NewClient(pdsEndpoint, did, "")
|
||||
|
||||
// Check for sailor profile
|
||||
profile, err := atproto.GetProfile(ctx, client)
|
||||
profile, err := storage.GetProfile(ctx, client)
|
||||
if err != nil {
|
||||
// Error reading profile (not a 404) - log and continue
|
||||
fmt.Printf("WARNING: failed to read profile for %s: %v\n", did, err)
|
||||
|
||||
82
pkg/appview/storage/crew.go
Normal file
82
pkg/appview/storage/crew.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
"atcr.io/pkg/auth/token"
|
||||
)
|
||||
|
||||
// EnsureCrewMembership attempts to register the user as a crew member on their default hold.
|
||||
// The hold's requestCrew endpoint handles all authorization logic (checking allowAllCrew, existing membership, etc).
|
||||
// This is best-effort and does not fail on errors.
|
||||
func EnsureCrewMembership(ctx context.Context, client *atproto.Client, refresher *oauth.Refresher, defaultHoldDID string) {
|
||||
if defaultHoldDID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Normalize URL to DID if needed
|
||||
holdDID := atproto.ResolveHoldDIDFromURL(defaultHoldDID)
|
||||
if holdDID == "" {
|
||||
slog.Warn("failed to resolve hold DID", "defaultHold", defaultHoldDID)
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve hold DID to HTTP endpoint
|
||||
holdEndpoint := atproto.ResolveHoldURL(holdDID)
|
||||
|
||||
// Get service token for the hold
|
||||
// Only works with OAuth (refresher required) - app passwords can't get service tokens
|
||||
if refresher == nil {
|
||||
slog.Debug("skipping crew registration - no OAuth refresher (app password flow)", "holdDID", holdDID)
|
||||
return
|
||||
}
|
||||
|
||||
// Wrap the refresher to match OAuthSessionRefresher interface
|
||||
serviceToken, err := token.GetOrFetchServiceToken(ctx, refresher, client.DID(), holdDID, client.PDSEndpoint())
|
||||
if err != nil {
|
||||
slog.Warn("failed to get service token", "holdDID", holdDID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Call requestCrew endpoint - it handles all the logic:
|
||||
// - Checks allowAllCrew flag
|
||||
// - Checks if already a crew member (returns success if so)
|
||||
// - Creates crew record if authorized
|
||||
if err := requestCrewMembership(ctx, holdEndpoint, serviceToken); err != nil {
|
||||
slog.Warn("failed to request crew membership", "holdDID", holdDID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("successfully registered as crew member", "holdDID", holdDID, "userDID", client.DID())
|
||||
}
|
||||
|
||||
// requestCrewMembership calls the hold's requestCrew endpoint
|
||||
// The endpoint handles all authorization and duplicate checking internally
|
||||
func requestCrewMembership(ctx context.Context, holdEndpoint, serviceToken string) error {
|
||||
url := fmt.Sprintf("%s%s", holdEndpoint, atproto.HoldRequestCrew)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+serviceToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return fmt.Errorf("requestCrew failed with status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package atproto
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
)
|
||||
|
||||
// ProfileRKey is always "self" per lexicon
|
||||
@@ -21,9 +23,9 @@ var migrationLocks sync.Map
|
||||
// If defaultHoldDID is provided, creates profile with that default (or empty if not provided)
|
||||
// Expected format: "did:web:hold01.atcr.io"
|
||||
// Normalizes URLs to DIDs for consistency (for backward compatibility)
|
||||
func EnsureProfile(ctx context.Context, client *Client, defaultHoldDID string) error {
|
||||
func EnsureProfile(ctx context.Context, client *atproto.Client, defaultHoldDID string) error {
|
||||
// Check if profile already exists
|
||||
profile, err := client.GetRecord(ctx, SailorProfileCollection, ProfileRKey)
|
||||
profile, err := client.GetRecord(ctx, atproto.SailorProfileCollection, ProfileRKey)
|
||||
if err == nil && profile != nil {
|
||||
// Profile exists, nothing to do
|
||||
return nil
|
||||
@@ -33,13 +35,13 @@ func EnsureProfile(ctx context.Context, client *Client, defaultHoldDID string) e
|
||||
// This ensures we store DIDs consistently in new profiles
|
||||
normalizedDID := ""
|
||||
if defaultHoldDID != "" {
|
||||
normalizedDID = ResolveHoldDIDFromURL(defaultHoldDID)
|
||||
normalizedDID = atproto.ResolveHoldDIDFromURL(defaultHoldDID)
|
||||
}
|
||||
|
||||
// Profile doesn't exist - create it
|
||||
newProfile := NewSailorProfileRecord(normalizedDID)
|
||||
newProfile := atproto.NewSailorProfileRecord(normalizedDID)
|
||||
|
||||
_, err = client.PutRecord(ctx, SailorProfileCollection, ProfileRKey, newProfile)
|
||||
_, err = client.PutRecord(ctx, atproto.SailorProfileCollection, ProfileRKey, newProfile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create sailor profile: %w", err)
|
||||
}
|
||||
@@ -51,32 +53,32 @@ func EnsureProfile(ctx context.Context, client *Client, defaultHoldDID string) e
|
||||
// GetProfile retrieves the user's profile from their PDS
|
||||
// Returns nil if profile doesn't exist
|
||||
// Automatically migrates old URL-based defaultHold values to DIDs
|
||||
func GetProfile(ctx context.Context, client *Client) (*SailorProfileRecord, error) {
|
||||
record, err := client.GetRecord(ctx, SailorProfileCollection, ProfileRKey)
|
||||
func GetProfile(ctx context.Context, client *atproto.Client) (*atproto.SailorProfileRecord, error) {
|
||||
record, err := client.GetRecord(ctx, atproto.SailorProfileCollection, ProfileRKey)
|
||||
if err != nil {
|
||||
// Check if it's a 404 (profile doesn't exist)
|
||||
if errors.Is(err, ErrRecordNotFound) {
|
||||
if errors.Is(err, atproto.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get profile: %w", err)
|
||||
}
|
||||
|
||||
// Parse the profile record
|
||||
var profile SailorProfileRecord
|
||||
var profile atproto.SailorProfileRecord
|
||||
if err := json.Unmarshal(record.Value, &profile); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse profile: %w", err)
|
||||
}
|
||||
|
||||
// Migrate old URL-based defaultHold to DID format
|
||||
// This ensures backward compatibility with profiles created before DID migration
|
||||
if profile.DefaultHold != "" && !isDID(profile.DefaultHold) {
|
||||
if profile.DefaultHold != "" && !atproto.IsDID(profile.DefaultHold) {
|
||||
// Convert URL to DID transparently
|
||||
migratedDID := ResolveHoldDIDFromURL(profile.DefaultHold)
|
||||
migratedDID := atproto.ResolveHoldDIDFromURL(profile.DefaultHold)
|
||||
profile.DefaultHold = migratedDID
|
||||
|
||||
// Persist the migration to PDS in a background goroutine
|
||||
// Use a lock to ensure only one goroutine migrates this DID
|
||||
did := client.did
|
||||
did := client.DID()
|
||||
if _, loaded := migrationLocks.LoadOrStore(did, true); !loaded {
|
||||
// We got the lock - launch goroutine to persist the migration
|
||||
go func() {
|
||||
@@ -106,15 +108,15 @@ func GetProfile(ctx context.Context, client *Client) (*SailorProfileRecord, erro
|
||||
|
||||
// UpdateProfile updates the user's profile
|
||||
// Normalizes defaultHold to DID format before saving
|
||||
func UpdateProfile(ctx context.Context, client *Client, profile *SailorProfileRecord) error {
|
||||
func UpdateProfile(ctx context.Context, client *atproto.Client, profile *atproto.SailorProfileRecord) error {
|
||||
// Normalize defaultHold to DID if it's a URL
|
||||
// This ensures we always store DIDs, even if user provides a URL
|
||||
if profile.DefaultHold != "" && !isDID(profile.DefaultHold) {
|
||||
profile.DefaultHold = ResolveHoldDIDFromURL(profile.DefaultHold)
|
||||
if profile.DefaultHold != "" && !atproto.IsDID(profile.DefaultHold) {
|
||||
profile.DefaultHold = atproto.ResolveHoldDIDFromURL(profile.DefaultHold)
|
||||
fmt.Printf("DEBUG [profile]: Normalized defaultHold to DID: %s\n", profile.DefaultHold)
|
||||
}
|
||||
|
||||
_, err := client.PutRecord(ctx, SailorProfileCollection, ProfileRKey, profile)
|
||||
_, err := client.PutRecord(ctx, atproto.SailorProfileCollection, ProfileRKey, profile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update profile: %w", err)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package atproto
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
)
|
||||
|
||||
// TestEnsureProfile_Create tests creating a new profile when one doesn't exist
|
||||
@@ -37,7 +39,7 @@ func TestEnsureProfile_Create(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var createdProfile *SailorProfileRecord
|
||||
var createdProfile *atproto.SailorProfileRecord
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// First request: GetRecord (should 404)
|
||||
@@ -53,8 +55,8 @@ func TestEnsureProfile_Create(t *testing.T) {
|
||||
|
||||
// Verify profile data
|
||||
recordData := body["record"].(map[string]any)
|
||||
if recordData["$type"] != SailorProfileCollection {
|
||||
t.Errorf("$type = %v, want %v", recordData["$type"], SailorProfileCollection)
|
||||
if recordData["$type"] != atproto.SailorProfileCollection {
|
||||
t.Errorf("$type = %v, want %v", recordData["$type"], atproto.SailorProfileCollection)
|
||||
}
|
||||
|
||||
// Check defaultHold normalization
|
||||
@@ -81,7 +83,7 @@ func TestEnsureProfile_Create(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
err := EnsureProfile(context.Background(), client, tt.defaultHoldDID)
|
||||
|
||||
if err != nil {
|
||||
@@ -93,8 +95,8 @@ func TestEnsureProfile_Create(t *testing.T) {
|
||||
t.Fatal("Profile was not created")
|
||||
}
|
||||
|
||||
if createdProfile.Type != SailorProfileCollection {
|
||||
t.Errorf("Type = %v, want %v", createdProfile.Type, SailorProfileCollection)
|
||||
if createdProfile.Type != atproto.SailorProfileCollection {
|
||||
t.Errorf("Type = %v, want %v", createdProfile.Type, atproto.SailorProfileCollection)
|
||||
}
|
||||
|
||||
if createdProfile.DefaultHold != tt.wantNormalized {
|
||||
@@ -134,7 +136,7 @@ func TestEnsureProfile_Exists(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
err := EnsureProfile(context.Background(), client, "did:web:hold01.atcr.io")
|
||||
|
||||
if err != nil {
|
||||
@@ -152,7 +154,7 @@ func TestGetProfile(t *testing.T) {
|
||||
name string
|
||||
serverResponse string
|
||||
serverStatus int
|
||||
wantProfile *SailorProfileRecord
|
||||
wantProfile *atproto.SailorProfileRecord
|
||||
wantNil bool
|
||||
wantErr bool
|
||||
expectMigration bool // Whether URL-to-DID migration should happen
|
||||
@@ -239,7 +241,7 @@ func TestGetProfile(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
profile, err := GetProfile(context.Background(), client)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -326,7 +328,7 @@ func TestGetProfile_MigrationLocking(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
|
||||
// Make 5 concurrent GetProfile calls
|
||||
var wg sync.WaitGroup
|
||||
@@ -360,14 +362,14 @@ func TestGetProfile_MigrationLocking(t *testing.T) {
|
||||
func TestUpdateProfile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profile *SailorProfileRecord
|
||||
profile *atproto.SailorProfileRecord
|
||||
wantNormalized string // Expected defaultHold after normalization
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "update with DID",
|
||||
profile: &SailorProfileRecord{
|
||||
Type: SailorProfileCollection,
|
||||
profile: &atproto.SailorProfileRecord{
|
||||
Type: atproto.SailorProfileCollection,
|
||||
DefaultHold: "did:web:hold02.atcr.io",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
@@ -377,8 +379,8 @@ func TestUpdateProfile(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "update with URL - should normalize",
|
||||
profile: &SailorProfileRecord{
|
||||
Type: SailorProfileCollection,
|
||||
profile: &atproto.SailorProfileRecord{
|
||||
Type: atproto.SailorProfileCollection,
|
||||
DefaultHold: "https://hold02.atcr.io",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
@@ -388,8 +390,8 @@ func TestUpdateProfile(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "clear default hold",
|
||||
profile: &SailorProfileRecord{
|
||||
Type: SailorProfileCollection,
|
||||
profile: &atproto.SailorProfileRecord{
|
||||
Type: atproto.SailorProfileCollection,
|
||||
DefaultHold: "",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
@@ -422,7 +424,7 @@ func TestUpdateProfile(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
err := UpdateProfile(context.Background(), client, tt.profile)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -477,7 +479,7 @@ func TestEnsureProfile_Error(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
err := EnsureProfile(context.Background(), client, "did:web:hold01.atcr.io")
|
||||
|
||||
if err == nil {
|
||||
@@ -497,7 +499,7 @@ func TestGetProfile_InvalidJSON(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
_, err := GetProfile(context.Background(), client)
|
||||
|
||||
if err == nil {
|
||||
@@ -522,7 +524,7 @@ func TestGetProfile_EmptyDefaultHold(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
profile, err := GetProfile(context.Background(), client)
|
||||
|
||||
if err != nil {
|
||||
@@ -542,9 +544,9 @@ func TestUpdateProfile_ServerError(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
profile := &SailorProfileRecord{
|
||||
Type: SailorProfileCollection,
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
profile := &atproto.SailorProfileRecord{
|
||||
Type: atproto.SailorProfileCollection,
|
||||
DefaultHold: "did:web:hold01.atcr.io",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/appview"
|
||||
"atcr.io/pkg/atproto"
|
||||
"github.com/distribution/distribution/v3"
|
||||
"github.com/distribution/distribution/v3/registry/api/errcode"
|
||||
@@ -40,7 +39,7 @@ type ProxyBlobStore struct {
|
||||
// NewProxyBlobStore creates a new proxy blob store
|
||||
func NewProxyBlobStore(ctx *RegistryContext) *ProxyBlobStore {
|
||||
// Resolve DID to URL once at construction time
|
||||
holdURL := appview.ResolveHoldURL(ctx.HoldDID)
|
||||
holdURL := atproto.ResolveHoldURL(ctx.HoldDID)
|
||||
|
||||
fmt.Printf("DEBUG [proxy_blob_store]: NewProxyBlobStore created with holdDID=%s, holdURL=%s, userDID=%s, repo=%s\n",
|
||||
ctx.HoldDID, holdURL, ctx.DID, ctx.Repository)
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/appview"
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth/token"
|
||||
"github.com/opencontainers/go-digest"
|
||||
@@ -219,7 +218,7 @@ func TestResolveHoldURL(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := appview.ResolveHoldURL(tt.holdDID)
|
||||
result := atproto.ResolveHoldURL(tt.holdDID)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package appview
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
)
|
||||
|
||||
func TestResolveHoldURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -52,7 +56,7 @@ func TestResolveHoldURL(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ResolveHoldURL(tt.input)
|
||||
result := atproto.ResolveHoldURL(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ResolveHoldURL(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
|
||||
@@ -467,19 +467,19 @@ func (t *CaptainRecord) MarshalCBOR(w io.Writer) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// t.EnableManifestPosts (bool) (bool)
|
||||
if len("enableManifestPosts") > 8192 {
|
||||
return xerrors.Errorf("Value in field \"enableManifestPosts\" was too long")
|
||||
// t.EnableBlueskyPosts (bool) (bool)
|
||||
if len("enableBlueskyPosts") > 8192 {
|
||||
return xerrors.Errorf("Value in field \"enableBlueskyPosts\" was too long")
|
||||
}
|
||||
|
||||
if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("enableManifestPosts"))); err != nil {
|
||||
if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("enableBlueskyPosts"))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := cw.WriteString(string("enableManifestPosts")); err != nil {
|
||||
if _, err := cw.WriteString(string("enableBlueskyPosts")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := cbg.WriteBool(w, t.EnableManifestPosts); err != nil {
|
||||
if err := cbg.WriteBool(w, t.EnableBlueskyPosts); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -617,8 +617,8 @@ func (t *CaptainRecord) UnmarshalCBOR(r io.Reader) (err error) {
|
||||
default:
|
||||
return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra)
|
||||
}
|
||||
// t.EnableManifestPosts (bool) (bool)
|
||||
case "enableManifestPosts":
|
||||
// t.EnableBlueskyPosts (bool) (bool)
|
||||
case "enableBlueskyPosts":
|
||||
|
||||
maj, extra, err = cr.ReadHeader()
|
||||
if err != nil {
|
||||
@@ -629,9 +629,9 @@ func (t *CaptainRecord) UnmarshalCBOR(r io.Reader) (err error) {
|
||||
}
|
||||
switch extra {
|
||||
case 20:
|
||||
t.EnableManifestPosts = false
|
||||
t.EnableBlueskyPosts = false
|
||||
case 21:
|
||||
t.EnableManifestPosts = true
|
||||
t.EnableBlueskyPosts = true
|
||||
default:
|
||||
return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra)
|
||||
}
|
||||
|
||||
@@ -666,3 +666,7 @@ func (c *Client) FetchDIDDocument(ctx context.Context, didDocURL string) (*DIDDo
|
||||
func (c *Client) DID() string {
|
||||
return c.did
|
||||
}
|
||||
|
||||
func (c *Client) PDSEndpoint() string {
|
||||
return c.pdsEndpoint
|
||||
}
|
||||
|
||||
@@ -434,7 +434,7 @@ func ResolveHoldDIDFromURL(holdURL string) string {
|
||||
}
|
||||
|
||||
// isDID checks if a string is a DID (starts with "did:")
|
||||
func isDID(s string) bool {
|
||||
func IsDID(s string) bool {
|
||||
return len(s) > 4 && s[:4] == "did:"
|
||||
}
|
||||
|
||||
@@ -536,14 +536,14 @@ func (t *TagRecord) GetManifestDigest() (string, error) {
|
||||
// Stored in the hold's embedded PDS to identify the hold owner and settings
|
||||
// Uses CBOR encoding for efficient storage in hold's carstore
|
||||
type CaptainRecord struct {
|
||||
Type string `json:"$type" cborgen:"$type"`
|
||||
Owner string `json:"owner" cborgen:"owner"` // DID of hold owner
|
||||
Public bool `json:"public" cborgen:"public"` // Public read access
|
||||
AllowAllCrew bool `json:"allowAllCrew" cborgen:"allowAllCrew"` // Allow any authenticated user to register as crew
|
||||
EnableManifestPosts bool `json:"enableManifestPosts" cborgen:"enableManifestPosts"` // Enable Bluesky posts when manifests are pushed (overrides env var)
|
||||
DeployedAt string `json:"deployedAt" cborgen:"deployedAt"` // RFC3339 timestamp
|
||||
Region string `json:"region,omitempty" cborgen:"region,omitempty"` // S3 region (optional)
|
||||
Provider string `json:"provider,omitempty" cborgen:"provider,omitempty"` // Deployment provider (optional)
|
||||
Type string `json:"$type" cborgen:"$type"`
|
||||
Owner string `json:"owner" cborgen:"owner"` // DID of hold owner
|
||||
Public bool `json:"public" cborgen:"public"` // Public read access
|
||||
AllowAllCrew bool `json:"allowAllCrew" cborgen:"allowAllCrew"` // Allow any authenticated user to register as crew
|
||||
EnableBlueskyPosts bool `json:"enableBlueskyPosts" cborgen:"enableBlueskyPosts"` // Enable Bluesky posts when manifests are pushed (overrides env var)
|
||||
DeployedAt string `json:"deployedAt" cborgen:"deployedAt"` // RFC3339 timestamp
|
||||
Region string `json:"region,omitempty" cborgen:"region,omitempty"` // S3 region (optional)
|
||||
Provider string `json:"provider,omitempty" cborgen:"provider,omitempty"` // Deployment provider (optional)
|
||||
}
|
||||
|
||||
// CrewRecord represents a crew member in the hold
|
||||
|
||||
@@ -824,9 +824,9 @@ func TestIsDID(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isDID(tt.s)
|
||||
got := IsDID(tt.s)
|
||||
if got != tt.want {
|
||||
t.Errorf("isDID() = %v, want %v", got, tt.want)
|
||||
t.Errorf("IsDID() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package appview
|
||||
package atproto
|
||||
|
||||
import "strings"
|
||||
|
||||
@@ -14,8 +14,8 @@ func ResolveHoldURL(holdIdentifier string) string {
|
||||
}
|
||||
|
||||
// If it's a DID, convert to URL
|
||||
if strings.HasPrefix(holdIdentifier, "did:web:") {
|
||||
hostname := strings.TrimPrefix(holdIdentifier, "did:web:")
|
||||
if after, ok := strings.CutPrefix(holdIdentifier, "did:web:"); ok {
|
||||
hostname := after
|
||||
|
||||
// Use HTTP for localhost/IP addresses with ports, HTTPS for domains
|
||||
if strings.Contains(hostname, ":") ||
|
||||
@@ -2,17 +2,11 @@ package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/appview/db"
|
||||
"atcr.io/pkg/atproto"
|
||||
indigooauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
|
||||
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||
)
|
||||
|
||||
// UISessionStore is the interface for UI session management
|
||||
@@ -23,13 +17,18 @@ type UserStore interface {
|
||||
UpsertUser(did, handle, pdsEndpoint, avatar string) error
|
||||
}
|
||||
|
||||
// PostAuthCallback is called after successful OAuth authentication.
|
||||
// Parameters: ctx, did, handle, pdsEndpoint, sessionID
|
||||
// This allows AppView to perform business logic (profile creation, avatar fetch, etc.)
|
||||
// without coupling the OAuth package to AppView-specific dependencies.
|
||||
type PostAuthCallback func(ctx context.Context, did, handle, pdsEndpoint, sessionID string) error
|
||||
|
||||
// Server handles OAuth authorization for the AppView
|
||||
type Server struct {
|
||||
app *App
|
||||
refresher *Refresher
|
||||
uiSessionStore UISessionStore
|
||||
db *sql.DB
|
||||
defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io")
|
||||
app *App
|
||||
refresher *Refresher
|
||||
uiSessionStore UISessionStore
|
||||
postAuthCallback PostAuthCallback
|
||||
}
|
||||
|
||||
// NewServer creates a new OAuth server
|
||||
@@ -39,13 +38,6 @@ func NewServer(app *App) *Server {
|
||||
}
|
||||
}
|
||||
|
||||
// SetDefaultHoldDID sets the default hold DID for profile creation
|
||||
// Expected format: "did:web:hold01.atcr.io"
|
||||
// To find a hold's DID, visit: https://hold-url/.well-known/did.json
|
||||
func (s *Server) SetDefaultHoldDID(did string) {
|
||||
s.defaultHoldDID = did
|
||||
}
|
||||
|
||||
// SetRefresher sets the refresher for invalidating session cache
|
||||
func (s *Server) SetRefresher(refresher *Refresher) {
|
||||
s.refresher = refresher
|
||||
@@ -56,9 +48,10 @@ func (s *Server) SetUISessionStore(store UISessionStore) {
|
||||
s.uiSessionStore = store
|
||||
}
|
||||
|
||||
// SetDatabase sets the database for user management
|
||||
func (s *Server) SetDatabase(db *sql.DB) {
|
||||
s.db = db
|
||||
// SetPostAuthCallback sets the callback to be invoked after successful OAuth authentication
|
||||
// This allows AppView to inject business logic without coupling the OAuth package
|
||||
func (s *Server) SetPostAuthCallback(callback PostAuthCallback) {
|
||||
s.postAuthCallback = callback
|
||||
}
|
||||
|
||||
// ServeAuthorize handles GET /auth/oauth/authorize
|
||||
@@ -140,9 +133,12 @@ func (s *Server) ServeCallback(w http.ResponseWriter, r *http.Request) {
|
||||
handle = did // Fallback to DID if resolution fails
|
||||
}
|
||||
|
||||
// Fetch user's Bluesky profile (including avatar) and store in database
|
||||
if s.db != nil {
|
||||
s.fetchAndStoreAvatar(r.Context(), did, sessionID, handle, sessionData.HostURL)
|
||||
// Call post-auth callback for AppView business logic (profile, avatar, etc.)
|
||||
if s.postAuthCallback != nil {
|
||||
if err := s.postAuthCallback(r.Context(), did, handle, sessionData.HostURL, sessionID); err != nil {
|
||||
// Log error but don't fail OAuth flow - business logic is non-critical
|
||||
fmt.Printf("WARNING [oauth/server]: Post-auth callback failed for DID=%s: %v\n", did, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is a UI login (has oauth_return_to cookie)
|
||||
@@ -241,127 +237,6 @@ func (s *Server) renderError(w http.ResponseWriter, message string) {
|
||||
}
|
||||
}
|
||||
|
||||
// fetchAndStoreAvatar fetches the user's Bluesky profile and stores avatar in database
|
||||
func (s *Server) fetchAndStoreAvatar(ctx context.Context, did, sessionID, handle, pdsEndpoint string) {
|
||||
fmt.Printf("DEBUG [oauth/server]: Fetching avatar for DID=%s from PDS=%s\n", did, pdsEndpoint)
|
||||
|
||||
// Parse DID for session resume
|
||||
didParsed, err := syntax.ParseDID(did)
|
||||
if err != nil {
|
||||
fmt.Printf("WARNING [oauth/server]: Failed to parse DID %s: %v\n", did, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Resume OAuth session to get authenticated client
|
||||
session, err := s.app.ResumeSession(ctx, didParsed, sessionID)
|
||||
if err != nil {
|
||||
fmt.Printf("WARNING [oauth/server]: Failed to resume session for DID=%s: %v\n", did, err)
|
||||
// Fallback: update user without avatar
|
||||
_ = db.UpsertUser(s.db, &db.User{
|
||||
DID: did,
|
||||
Handle: handle,
|
||||
PDSEndpoint: pdsEndpoint,
|
||||
Avatar: "",
|
||||
LastSeen: time.Now(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create authenticated atproto client using the indigo session's API client
|
||||
client := atproto.NewClientWithIndigoClient(pdsEndpoint, did, session.APIClient())
|
||||
|
||||
// Ensure sailor profile exists (creates with default hold if configured, or empty profile if not)
|
||||
fmt.Printf("DEBUG [oauth/server]: Ensuring profile exists for %s (defaultHold=%s)\n", did, s.defaultHoldDID)
|
||||
if err := atproto.EnsureProfile(ctx, client, s.defaultHoldDID); err != nil {
|
||||
fmt.Printf("WARNING [oauth/server]: Failed to ensure profile for %s: %v\n", did, err)
|
||||
// Continue anyway - profile creation is not critical for avatar fetch
|
||||
} else {
|
||||
fmt.Printf("DEBUG [oauth/server]: Profile ensured for %s\n", did)
|
||||
}
|
||||
|
||||
// Fetch user's profile record from PDS (contains blob references)
|
||||
profileRecord, err := client.GetProfileRecord(ctx, did)
|
||||
if err != nil {
|
||||
fmt.Printf("WARNING [oauth/server]: Failed to fetch profile record for DID=%s: %v\n", did, err)
|
||||
// Still update user without avatar
|
||||
_ = db.UpsertUser(s.db, &db.User{
|
||||
DID: did,
|
||||
Handle: handle,
|
||||
PDSEndpoint: pdsEndpoint,
|
||||
Avatar: "",
|
||||
LastSeen: time.Now(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Construct avatar URL from blob CID using imgs.blue CDN
|
||||
var avatarURL string
|
||||
if profileRecord.Avatar != nil && profileRecord.Avatar.Ref.Link != "" {
|
||||
avatarURL = atproto.BlobCDNURL(did, profileRecord.Avatar.Ref.Link)
|
||||
fmt.Printf("DEBUG [oauth/server]: Constructed avatar URL: %s\n", avatarURL)
|
||||
}
|
||||
|
||||
// Store user with avatar in database
|
||||
err = db.UpsertUser(s.db, &db.User{
|
||||
DID: did,
|
||||
Handle: handle,
|
||||
PDSEndpoint: pdsEndpoint,
|
||||
Avatar: avatarURL,
|
||||
LastSeen: time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Printf("WARNING [oauth/server]: Failed to store user in database: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("DEBUG [oauth/server]: Stored user with avatar for DID=%s\n", did)
|
||||
|
||||
// Handle profile migration and crew registration
|
||||
s.migrateProfileAndRegisterCrew(ctx, client, did, session)
|
||||
}
|
||||
|
||||
// migrateProfileAndRegisterCrew handles URL→DID migration and crew registration
|
||||
func (s *Server) migrateProfileAndRegisterCrew(ctx context.Context, client *atproto.Client, did string, session *indigooauth.ClientSession) {
|
||||
// Get user's sailor profile
|
||||
profile, err := atproto.GetProfile(ctx, client)
|
||||
if err != nil {
|
||||
fmt.Printf("WARNING [oauth/server]: Failed to get profile for %s: %v\n", did, err)
|
||||
return
|
||||
}
|
||||
|
||||
if profile == nil || profile.DefaultHold == "" {
|
||||
// No profile or no default hold configured
|
||||
return
|
||||
}
|
||||
|
||||
// Check if defaultHold is a URL (needs migration)
|
||||
var holdDID string
|
||||
if strings.HasPrefix(profile.DefaultHold, "http://") || strings.HasPrefix(profile.DefaultHold, "https://") {
|
||||
fmt.Printf("DEBUG [oauth/server]: Migrating hold URL to DID for %s: %s\n", did, profile.DefaultHold)
|
||||
|
||||
// Resolve URL to DID
|
||||
holdDID = atproto.ResolveHoldDIDFromURL(profile.DefaultHold)
|
||||
|
||||
// Update profile with DID
|
||||
profile.DefaultHold = holdDID
|
||||
if err := atproto.UpdateProfile(ctx, client, profile); err != nil {
|
||||
fmt.Printf("WARNING [oauth/server]: Failed to update profile with hold DID for %s: %v\n", did, err)
|
||||
// Continue anyway - crew registration might still work
|
||||
} else {
|
||||
fmt.Printf("DEBUG [oauth/server]: Updated profile with hold DID: %s\n", holdDID)
|
||||
}
|
||||
} else {
|
||||
// Already a DID
|
||||
holdDID = profile.DefaultHold
|
||||
}
|
||||
|
||||
// TODO: Request crew membership at the hold
|
||||
// This requires understanding how to make authenticated HTTP requests with indigo's ClientSession
|
||||
// For now, crew registration will happen on first push when appview validates access
|
||||
fmt.Printf("DEBUG [oauth/server]: Skipping crew registration for now - will happen on first push. Hold DID: %s\n", holdDID)
|
||||
_ = session // TODO: use session for crew registration
|
||||
}
|
||||
|
||||
// HTML templates
|
||||
|
||||
const redirectToSettingsTemplate = `
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -11,30 +12,38 @@ import (
|
||||
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||
|
||||
"atcr.io/pkg/appview/db"
|
||||
mainAtproto "atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
)
|
||||
|
||||
// PostAuthCallback is called after successful Basic Auth authentication.
|
||||
// Parameters: ctx, did, handle, pdsEndpoint, accessToken
|
||||
// This allows AppView to perform business logic (profile creation, etc.)
|
||||
// without coupling the token package to AppView-specific dependencies.
|
||||
type PostAuthCallback func(ctx context.Context, did, handle, pdsEndpoint, accessToken string) error
|
||||
|
||||
// Handler handles /auth/token requests
|
||||
type Handler struct {
|
||||
issuer *Issuer
|
||||
validator *auth.SessionValidator
|
||||
deviceStore *db.DeviceStore // For validating device secrets
|
||||
defaultHoldDID string
|
||||
issuer *Issuer
|
||||
validator *auth.SessionValidator
|
||||
deviceStore *db.DeviceStore // For validating device secrets
|
||||
postAuthCallback PostAuthCallback
|
||||
}
|
||||
|
||||
// NewHandler creates a new token handler
|
||||
// defaultHoldDID should be in format "did:web:hold01.atcr.io"
|
||||
// To find a hold's DID, visit: https://hold-url/.well-known/did.json
|
||||
func NewHandler(issuer *Issuer, deviceStore *db.DeviceStore, defaultHoldDID string) *Handler {
|
||||
func NewHandler(issuer *Issuer, deviceStore *db.DeviceStore) *Handler {
|
||||
return &Handler{
|
||||
issuer: issuer,
|
||||
validator: auth.NewSessionValidator(),
|
||||
deviceStore: deviceStore,
|
||||
defaultHoldDID: defaultHoldDID,
|
||||
issuer: issuer,
|
||||
validator: auth.NewSessionValidator(),
|
||||
deviceStore: deviceStore,
|
||||
}
|
||||
}
|
||||
|
||||
// SetPostAuthCallback sets the callback to be invoked after successful Basic Auth authentication
|
||||
// This allows AppView to inject business logic without coupling the token package
|
||||
func (h *Handler) SetPostAuthCallback(callback PostAuthCallback) {
|
||||
h.postAuthCallback = callback
|
||||
}
|
||||
|
||||
// TokenResponse represents the response from /auth/token
|
||||
type TokenResponse struct {
|
||||
Token string `json:"token,omitempty"` // Legacy field
|
||||
@@ -142,25 +151,23 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
auth.GetGlobalTokenCache().Set(did, accessToken, 2*time.Hour)
|
||||
fmt.Printf("DEBUG [token/handler]: Cached access token for DID=%s\n", did)
|
||||
|
||||
// Ensure user profile exists (creates with default hold if needed)
|
||||
// Resolve PDS endpoint for profile management
|
||||
directory := identity.DefaultDirectory()
|
||||
atID, err := syntax.ParseAtIdentifier(username)
|
||||
if err == nil {
|
||||
ident, err := directory.Lookup(r.Context(), *atID)
|
||||
if err != nil {
|
||||
// Log error but don't fail auth - profile management is not critical
|
||||
fmt.Printf("WARNING: failed to resolve PDS for profile management: %v\n", err)
|
||||
} else {
|
||||
pdsEndpoint := ident.PDSEndpoint()
|
||||
if pdsEndpoint != "" {
|
||||
// Create ATProto client with validated token
|
||||
atprotoClient := mainAtproto.NewClient(pdsEndpoint, did, accessToken)
|
||||
|
||||
// Ensure profile exists (will create with default hold if not exists and default is configured)
|
||||
if err := mainAtproto.EnsureProfile(r.Context(), atprotoClient, h.defaultHoldDID); err != nil {
|
||||
// Log error but don't fail auth - profile management is not critical
|
||||
fmt.Printf("WARNING: failed to ensure profile for %s: %v\n", did, err)
|
||||
// Call post-auth callback for AppView business logic (profile management, etc.)
|
||||
if h.postAuthCallback != nil {
|
||||
// Resolve PDS endpoint for callback
|
||||
directory := identity.DefaultDirectory()
|
||||
atID, err := syntax.ParseAtIdentifier(username)
|
||||
if err == nil {
|
||||
ident, err := directory.Lookup(r.Context(), *atID)
|
||||
if err != nil {
|
||||
// Log error but don't fail auth - profile management is not critical
|
||||
fmt.Printf("WARNING: failed to resolve PDS for callback: %v\n", err)
|
||||
} else {
|
||||
pdsEndpoint := ident.PDSEndpoint()
|
||||
if pdsEndpoint != "" {
|
||||
if err := h.postAuthCallback(r.Context(), did, handle, pdsEndpoint, accessToken); err != nil {
|
||||
// Log error but don't fail auth - business logic is non-critical
|
||||
fmt.Printf("WARNING: post-auth callback failed for DID=%s: %v\n", did, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
111
pkg/auth/token/servicetoken.go
Normal file
111
pkg/auth/token/servicetoken.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
)
|
||||
|
||||
// GetOrFetchServiceToken gets a service token for hold authentication.
|
||||
// Checks cache first, then fetches from PDS with OAuth/DPoP if needed.
|
||||
// This is the canonical implementation used by both middleware and crew registration.
|
||||
func GetOrFetchServiceToken(
|
||||
ctx context.Context,
|
||||
refresher *oauth.Refresher,
|
||||
did, holdDID, pdsEndpoint string,
|
||||
) (string, error) {
|
||||
if refresher == nil {
|
||||
return "", fmt.Errorf("refresher is nil (OAuth session required for service tokens)")
|
||||
}
|
||||
|
||||
// Check cache first to avoid unnecessary PDS calls on every request
|
||||
cachedToken, expiresAt := GetServiceToken(did, holdDID)
|
||||
|
||||
// Use cached token if it exists and has > 10s remaining
|
||||
if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
|
||||
fmt.Printf("DEBUG [atproto/servicetoken]: Using cached service token for DID=%s (expires in %v)\n",
|
||||
did, time.Until(expiresAt).Round(time.Second))
|
||||
return cachedToken, nil
|
||||
}
|
||||
|
||||
// Cache miss or expiring soon - validate OAuth and get new service token
|
||||
if cachedToken == "" {
|
||||
fmt.Printf("DEBUG [atproto/servicetoken]: Cache miss, fetching service token for DID=%s\n", did)
|
||||
} else {
|
||||
fmt.Printf("DEBUG [atproto/servicetoken]: Token expiring soon, proactively renewing for DID=%s\n", did)
|
||||
}
|
||||
|
||||
session, err := refresher.GetSession(ctx, did)
|
||||
if err != nil {
|
||||
// OAuth session unavailable - invalidate and fail
|
||||
refresher.InvalidateSession(did)
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
return "", fmt.Errorf("failed to get OAuth session: %w", err)
|
||||
}
|
||||
|
||||
// Call com.atproto.server.getServiceAuth on the user's PDS
|
||||
// Request 5-minute expiry (PDS may grant less)
|
||||
// exp must be absolute Unix timestamp, not relative duration
|
||||
// Note: OAuth scope includes #atcr_hold fragment, but service auth aud must be bare DID
|
||||
expiryTime := time.Now().Unix() + 300 // 5 minutes from now
|
||||
serviceAuthURL := fmt.Sprintf("%s%s?aud=%s&lxm=%s&exp=%d",
|
||||
pdsEndpoint,
|
||||
atproto.ServerGetServiceAuth,
|
||||
url.QueryEscape(holdDID),
|
||||
url.QueryEscape("com.atproto.repo.getRecord"),
|
||||
expiryTime,
|
||||
)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", serviceAuthURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create service auth request: %w", err)
|
||||
}
|
||||
|
||||
// Use OAuth session to authenticate to PDS (with DPoP)
|
||||
resp, err := session.DoWithAuth(session.Client, req, "com.atproto.server.getServiceAuth")
|
||||
if err != nil {
|
||||
// Invalidate session on auth errors (may indicate corrupted session or expired tokens)
|
||||
refresher.InvalidateSession(did)
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
return "", fmt.Errorf("OAuth validation failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// Invalidate session on auth failures
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
refresher.InvalidateSession(did)
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
return "", fmt.Errorf("service auth failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Parse response to get service token
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to decode service auth response: %w", err)
|
||||
}
|
||||
|
||||
if result.Token == "" {
|
||||
return "", fmt.Errorf("empty token in service auth response")
|
||||
}
|
||||
|
||||
serviceToken := result.Token
|
||||
|
||||
// Cache the token (parses JWT to extract actual expiry)
|
||||
if err := SetServiceToken(did, holdDID, serviceToken); err != nil {
|
||||
fmt.Printf("WARN [atproto/servicetoken]: Failed to cache service token: %v\n", err)
|
||||
// Non-fatal - we have the token, just won't be cached
|
||||
}
|
||||
|
||||
fmt.Printf("DEBUG [atproto/servicetoken]: OAuth validation succeeded for DID=%s\n", did)
|
||||
return serviceToken, nil
|
||||
}
|
||||
@@ -40,7 +40,7 @@ type RegistrationConfig struct {
|
||||
|
||||
// EnableBlueskyPosts controls whether to create Bluesky posts for manifest uploads (from env: HOLD_BLUESKY_POSTS_ENABLED)
|
||||
// If true, creates posts when users push images
|
||||
// Can be overridden per-hold via captain record's enableManifestPosts field
|
||||
// Synced to captain record's enableBlueskyPosts field on startup
|
||||
EnableBlueskyPosts bool `yaml:"enable_bluesky_posts"`
|
||||
}
|
||||
|
||||
|
||||
@@ -244,9 +244,15 @@ func (h *XRPCHandler) HandleNotifyManifest(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
|
||||
// Check if manifest posts are enabled
|
||||
// Controlled by HOLD_BLUESKY_POSTS_ENABLED environment variable
|
||||
// TODO: Override with captain record enableManifestPosts field if set
|
||||
postsEnabled := h.enableBlueskyPosts
|
||||
// Read from captain record (which is synced with HOLD_BLUESKY_POSTS_ENABLED env var)
|
||||
postsEnabled := false
|
||||
_, captain, err := h.pds.GetCaptainRecord(ctx)
|
||||
if err == nil {
|
||||
postsEnabled = captain.EnableBlueskyPosts
|
||||
} else {
|
||||
// Fallback to env var if captain record doesn't exist (shouldn't happen in normal operation)
|
||||
postsEnabled = h.enableBlueskyPosts
|
||||
}
|
||||
|
||||
// Create layer records for each blob
|
||||
layersCreated := 0
|
||||
|
||||
@@ -742,7 +742,7 @@ func TestValidateBlobReadAccess_PrivateHold(t *testing.T) {
|
||||
pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, false, false)
|
||||
|
||||
// Update captain to be private
|
||||
_, err := pds.UpdateCaptainRecord(ctx, false, false)
|
||||
_, err := pds.UpdateCaptainRecord(ctx, false, false, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update captain record: %v", err)
|
||||
}
|
||||
|
||||
@@ -16,13 +16,14 @@ const (
|
||||
|
||||
// CreateCaptainRecord creates the captain record for the hold (first-time only).
|
||||
// This will FAIL if the captain record already exists. Use UpdateCaptainRecord to modify.
|
||||
func (p *HoldPDS) CreateCaptainRecord(ctx context.Context, ownerDID string, public bool, allowAllCrew bool) (cid.Cid, error) {
|
||||
func (p *HoldPDS) CreateCaptainRecord(ctx context.Context, ownerDID string, public bool, allowAllCrew bool, enableBlueskyPosts bool) (cid.Cid, error) {
|
||||
captainRecord := &atproto.CaptainRecord{
|
||||
Type: atproto.CaptainCollection,
|
||||
Owner: ownerDID,
|
||||
Public: public,
|
||||
AllowAllCrew: allowAllCrew,
|
||||
DeployedAt: time.Now().Format(time.RFC3339),
|
||||
Type: atproto.CaptainCollection,
|
||||
Owner: ownerDID,
|
||||
Public: public,
|
||||
AllowAllCrew: allowAllCrew,
|
||||
EnableBlueskyPosts: enableBlueskyPosts,
|
||||
DeployedAt: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
// Use repomgr.PutRecord - creates with explicit rkey, fails if already exists
|
||||
@@ -53,8 +54,8 @@ func (p *HoldPDS) GetCaptainRecord(ctx context.Context) (cid.Cid, *atproto.Capta
|
||||
return recordCID, captainRecord, nil
|
||||
}
|
||||
|
||||
// UpdateCaptainRecord updates the captain record (e.g., to change public/allowAllCrew settings)
|
||||
func (p *HoldPDS) UpdateCaptainRecord(ctx context.Context, public bool, allowAllCrew bool) (cid.Cid, error) {
|
||||
// UpdateCaptainRecord updates the captain record (e.g., to change public/allowAllCrew/enableBlueskyPosts settings)
|
||||
func (p *HoldPDS) UpdateCaptainRecord(ctx context.Context, public bool, allowAllCrew bool, enableBlueskyPosts bool) (cid.Cid, error) {
|
||||
// Get existing record to preserve other fields
|
||||
_, existing, err := p.GetCaptainRecord(ctx)
|
||||
if err != nil {
|
||||
@@ -64,6 +65,7 @@ func (p *HoldPDS) UpdateCaptainRecord(ctx context.Context, public bool, allowAll
|
||||
// Update the fields
|
||||
existing.Public = public
|
||||
existing.AllowAllCrew = allowAllCrew
|
||||
existing.EnableBlueskyPosts = enableBlueskyPosts
|
||||
|
||||
recordCID, err := p.repomgr.UpdateRecord(ctx, p.uid, atproto.CaptainCollection, CaptainRkey, existing)
|
||||
if err != nil {
|
||||
|
||||
@@ -71,34 +71,39 @@ func setupTestPDSWithBootstrap(t *testing.T, ownerDID string, public, allowAllCr
|
||||
// TestCreateCaptainRecord tests creating a captain record with various settings
|
||||
func TestCreateCaptainRecord(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ownerDID string
|
||||
public bool
|
||||
allowAllCrew bool
|
||||
name string
|
||||
ownerDID string
|
||||
public bool
|
||||
allowAllCrew bool
|
||||
enableBlueskyPosts bool
|
||||
}{
|
||||
{
|
||||
name: "Private hold, no all-crew",
|
||||
ownerDID: "did:plc:alice123",
|
||||
public: false,
|
||||
allowAllCrew: false,
|
||||
name: "Private hold, no all-crew",
|
||||
ownerDID: "did:plc:alice123",
|
||||
public: false,
|
||||
allowAllCrew: false,
|
||||
enableBlueskyPosts: false,
|
||||
},
|
||||
{
|
||||
name: "Public hold, no all-crew",
|
||||
ownerDID: "did:plc:bob456",
|
||||
public: true,
|
||||
allowAllCrew: false,
|
||||
name: "Public hold, no all-crew",
|
||||
ownerDID: "did:plc:bob456",
|
||||
public: true,
|
||||
allowAllCrew: false,
|
||||
enableBlueskyPosts: true,
|
||||
},
|
||||
{
|
||||
name: "Public hold, allow all crew",
|
||||
ownerDID: "did:plc:charlie789",
|
||||
public: true,
|
||||
allowAllCrew: true,
|
||||
name: "Public hold, allow all crew",
|
||||
ownerDID: "did:plc:charlie789",
|
||||
public: true,
|
||||
allowAllCrew: true,
|
||||
enableBlueskyPosts: false,
|
||||
},
|
||||
{
|
||||
name: "Private hold, allow all crew",
|
||||
ownerDID: "did:plc:dave012",
|
||||
public: false,
|
||||
allowAllCrew: true,
|
||||
name: "Private hold, allow all crew",
|
||||
ownerDID: "did:plc:dave012",
|
||||
public: false,
|
||||
allowAllCrew: true,
|
||||
enableBlueskyPosts: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -109,7 +114,7 @@ func TestCreateCaptainRecord(t *testing.T) {
|
||||
defer pds.Close()
|
||||
|
||||
// Create captain record
|
||||
recordCID, err := pds.CreateCaptainRecord(ctx, tt.ownerDID, tt.public, tt.allowAllCrew)
|
||||
recordCID, err := pds.CreateCaptainRecord(ctx, tt.ownerDID, tt.public, tt.allowAllCrew, tt.enableBlueskyPosts)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCaptainRecord failed: %v", err)
|
||||
}
|
||||
@@ -138,6 +143,9 @@ func TestCreateCaptainRecord(t *testing.T) {
|
||||
if captain.AllowAllCrew != tt.allowAllCrew {
|
||||
t.Errorf("Expected allowAllCrew=%v, got %v", tt.allowAllCrew, captain.AllowAllCrew)
|
||||
}
|
||||
if captain.EnableBlueskyPosts != tt.enableBlueskyPosts {
|
||||
t.Errorf("Expected enableBlueskyPosts=%v, got %v", tt.enableBlueskyPosts, captain.EnableBlueskyPosts)
|
||||
}
|
||||
if captain.Type != atproto.CaptainCollection {
|
||||
t.Errorf("Expected type %s, got %s", atproto.CaptainCollection, captain.Type)
|
||||
}
|
||||
@@ -156,7 +164,7 @@ func TestGetCaptainRecord(t *testing.T) {
|
||||
ownerDID := "did:plc:alice123"
|
||||
|
||||
// Create captain record
|
||||
createdCID, err := pds.CreateCaptainRecord(ctx, ownerDID, true, false)
|
||||
createdCID, err := pds.CreateCaptainRecord(ctx, ownerDID, true, false, false)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCaptainRecord failed: %v", err)
|
||||
}
|
||||
@@ -212,8 +220,8 @@ func TestUpdateCaptainRecord(t *testing.T) {
|
||||
|
||||
ownerDID := "did:plc:alice123"
|
||||
|
||||
// Create initial captain record (public=false, allowAllCrew=false)
|
||||
_, err := pds.CreateCaptainRecord(ctx, ownerDID, false, false)
|
||||
// Create initial captain record (public=false, allowAllCrew=false, enableBlueskyPosts=false)
|
||||
_, err := pds.CreateCaptainRecord(ctx, ownerDID, false, false, false)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCaptainRecord failed: %v", err)
|
||||
}
|
||||
@@ -231,9 +239,12 @@ func TestUpdateCaptainRecord(t *testing.T) {
|
||||
if captain1.AllowAllCrew {
|
||||
t.Error("Expected initial allowAllCrew=false")
|
||||
}
|
||||
if captain1.EnableBlueskyPosts {
|
||||
t.Error("Expected initial enableBlueskyPosts=false")
|
||||
}
|
||||
|
||||
// Update to public=true, allowAllCrew=true
|
||||
updatedCID, err := pds.UpdateCaptainRecord(ctx, true, true)
|
||||
// Update to public=true, allowAllCrew=true, enableBlueskyPosts=true
|
||||
updatedCID, err := pds.UpdateCaptainRecord(ctx, true, true, true)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateCaptainRecord failed: %v", err)
|
||||
}
|
||||
@@ -260,14 +271,17 @@ func TestUpdateCaptainRecord(t *testing.T) {
|
||||
if !captain2.AllowAllCrew {
|
||||
t.Error("Expected allowAllCrew=true after update")
|
||||
}
|
||||
if !captain2.EnableBlueskyPosts {
|
||||
t.Error("Expected enableBlueskyPosts=true after update")
|
||||
}
|
||||
|
||||
// Verify owner didn't change
|
||||
if captain2.Owner != ownerDID {
|
||||
t.Errorf("Expected owner to remain %s, got %s", ownerDID, captain2.Owner)
|
||||
}
|
||||
|
||||
// Update again to different values (public=true, allowAllCrew=false)
|
||||
_, err = pds.UpdateCaptainRecord(ctx, true, false)
|
||||
// Update again to different values (public=true, allowAllCrew=false, enableBlueskyPosts=false)
|
||||
_, err = pds.UpdateCaptainRecord(ctx, true, false, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Second UpdateCaptainRecord failed: %v", err)
|
||||
}
|
||||
@@ -292,7 +306,7 @@ func TestUpdateCaptainRecord_NotFound(t *testing.T) {
|
||||
defer pds.Close()
|
||||
|
||||
// Try to update captain record before creating one
|
||||
_, err := pds.UpdateCaptainRecord(ctx, true, true)
|
||||
_, err := pds.UpdateCaptainRecord(ctx, true, true, true)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when updating non-existent captain record")
|
||||
}
|
||||
|
||||
@@ -155,12 +155,12 @@ func (p *HoldPDS) Bootstrap(ctx context.Context, storageDriver driver.StorageDri
|
||||
}
|
||||
|
||||
// Create captain record (hold ownership and settings)
|
||||
_, err = p.CreateCaptainRecord(ctx, ownerDID, public, allowAllCrew)
|
||||
_, err = p.CreateCaptainRecord(ctx, ownerDID, public, allowAllCrew, p.enableBlueskyPosts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create captain record: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("✅ Created captain record (public=%v, allowAllCrew=%v)\n", public, allowAllCrew)
|
||||
fmt.Printf("✅ Created captain record (public=%v, allowAllCrew=%v, enableBlueskyPosts=%v)\n", public, allowAllCrew, p.enableBlueskyPosts)
|
||||
|
||||
// Add hold owner as first crew member with admin role
|
||||
_, err = p.AddCrewMember(ctx, ownerDID, "admin", []string{"blob:read", "blob:write", "crew:admin"})
|
||||
@@ -169,6 +169,24 @@ func (p *HoldPDS) Bootstrap(ctx context.Context, storageDriver driver.StorageDri
|
||||
}
|
||||
|
||||
fmt.Printf("✅ Added %s as hold admin\n", ownerDID)
|
||||
} else {
|
||||
// Captain record exists, check if we need to sync settings from env vars
|
||||
_, existingCaptain, err := p.GetCaptainRecord(ctx)
|
||||
if err == nil {
|
||||
// Check if any settings need updating
|
||||
needsUpdate := existingCaptain.Public != public ||
|
||||
existingCaptain.AllowAllCrew != allowAllCrew ||
|
||||
existingCaptain.EnableBlueskyPosts != p.enableBlueskyPosts
|
||||
|
||||
if needsUpdate {
|
||||
// Update captain record to match env vars
|
||||
_, err = p.UpdateCaptainRecord(ctx, public, allowAllCrew, p.enableBlueskyPosts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update captain record: %w", err)
|
||||
}
|
||||
fmt.Printf("✅ Synced captain record with env vars (public=%v, allowAllCrew=%v, enableBlueskyPosts=%v)\n", public, allowAllCrew, p.enableBlueskyPosts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create Bluesky profile record (idempotent - check if exists first)
|
||||
|
||||
@@ -560,7 +560,7 @@ func TestBootstrap_CaptainWithoutCrew(t *testing.T) {
|
||||
|
||||
// Create captain record WITHOUT crew (unusual state)
|
||||
ownerDID := "did:plc:alice123"
|
||||
_, err = pds.CreateCaptainRecord(ctx, ownerDID, true, false)
|
||||
_, err = pds.CreateCaptainRecord(ctx, ownerDID, true, false, false)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCaptainRecord failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -1199,7 +1199,7 @@ func TestHandleRequestCrew(t *testing.T) {
|
||||
handler, ctx := setupTestXRPCHandler(t)
|
||||
|
||||
// Update captain record to allow all crew
|
||||
_, err := handler.pds.UpdateCaptainRecord(ctx, true, true) // public=true, allowAllCrew=true
|
||||
_, err := handler.pds.UpdateCaptainRecord(ctx, true, true, false) // public=true, allowAllCrew=true, enableBlueskyPosts=false
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update captain record: %v", err)
|
||||
}
|
||||
@@ -1243,7 +1243,7 @@ func TestHandleRequestCrew_AllowAllCrewDisabled(t *testing.T) {
|
||||
|
||||
// Captain record was created with allowAllCrew=false in setupTestXRPCHandler
|
||||
// Update to make sure it's false
|
||||
_, err := handler.pds.UpdateCaptainRecord(ctx, true, false) // public=true, allowAllCrew=false
|
||||
_, err := handler.pds.UpdateCaptainRecord(ctx, true, false, false) // public=true, allowAllCrew=false, enableBlueskyPosts=false
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update captain record: %v", err)
|
||||
}
|
||||
@@ -1715,7 +1715,7 @@ func TestHandleGetBlob_CORSHeaders(t *testing.T) {
|
||||
handler, _, ctx := setupTestXRPCHandlerWithBlobs(t)
|
||||
|
||||
// Make hold public
|
||||
_, err := handler.pds.UpdateCaptainRecord(ctx, true, false)
|
||||
_, err := handler.pds.UpdateCaptainRecord(ctx, true, false, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update captain: %v", err)
|
||||
}
|
||||
|
||||
@@ -1,54 +1,187 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Configuration
|
||||
SOURCE_REGISTRY="ghcr.io/evanjarrett/hsm-secrets-operator"
|
||||
TARGET_REGISTRY="atcr.io/evan.jarrett.net/hsm-secrets-operator"
|
||||
TAG="latest"
|
||||
# Usage function
|
||||
usage() {
|
||||
echo "Usage: $0 <source-image> [target-image]"
|
||||
echo ""
|
||||
echo "Examples:"
|
||||
echo " $0 ghcr.io/evanjarrett/myapp:latest"
|
||||
echo " $0 ghcr.io/evanjarrett/myapp:latest atcr.io/evan.jarrett.net/myapp:latest"
|
||||
echo ""
|
||||
echo "If target-image is not specified, it will use atcr.io/<username>/<repo>:<tag>"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Image digests
|
||||
AMD64_DIGEST="sha256:274284a623810cf07c5b4735628832751926b7d192863681d5af1b4137f44254"
|
||||
ARM64_DIGEST="sha256:b57929fd100033092766aad1c7e747deef9b1e3206756c11d0d7a7af74daedff"
|
||||
# Check arguments
|
||||
if [ $# -lt 1 ]; then
|
||||
usage
|
||||
fi
|
||||
|
||||
echo "=== Migrating multi-arch image from GHCR to ATCR ==="
|
||||
echo "Source: ${SOURCE_REGISTRY}"
|
||||
echo "Target: ${TARGET_REGISTRY}:${TAG}"
|
||||
SOURCE_IMAGE="$1"
|
||||
TARGET_IMAGE="${2:-}"
|
||||
|
||||
# Parse source image to extract components
|
||||
# Format: [registry/]repository[:tag|@digest]
|
||||
parse_image_ref() {
|
||||
local ref="$1"
|
||||
local registry=""
|
||||
local repository=""
|
||||
local tag="latest"
|
||||
|
||||
# Remove digest if present (we'll fetch the manifest-list)
|
||||
ref="${ref%@*}"
|
||||
|
||||
# Extract tag
|
||||
if [[ "$ref" == *:* ]]; then
|
||||
tag="${ref##*:}"
|
||||
ref="${ref%:*}"
|
||||
fi
|
||||
|
||||
# Extract registry and repository
|
||||
if [[ "$ref" == */*/* ]]; then
|
||||
# Has registry
|
||||
registry="${ref%%/*}"
|
||||
repository="${ref#*/}"
|
||||
else
|
||||
# No registry, assume Docker Hub
|
||||
registry="docker.io"
|
||||
repository="$ref"
|
||||
fi
|
||||
|
||||
echo "$registry" "$repository" "$tag"
|
||||
}
|
||||
|
||||
# Parse source image
|
||||
read -r SOURCE_REGISTRY SOURCE_REPO SOURCE_TAG <<< "$(parse_image_ref "$SOURCE_IMAGE")"
|
||||
|
||||
# If no target specified, auto-generate it
|
||||
if [ -z "$TARGET_IMAGE" ]; then
|
||||
# Extract just the repo name (last component)
|
||||
REPO_NAME="${SOURCE_REPO##*/}"
|
||||
# Try to extract username from source
|
||||
if [[ "$SOURCE_REPO" == */* ]]; then
|
||||
USERNAME="${SOURCE_REPO%/*}"
|
||||
USERNAME="${USERNAME##*/}"
|
||||
else
|
||||
USERNAME="default"
|
||||
fi
|
||||
TARGET_IMAGE="atcr.io/${USERNAME}/${REPO_NAME}:${SOURCE_TAG}"
|
||||
fi
|
||||
|
||||
# Parse target image
|
||||
read -r TARGET_REGISTRY TARGET_REPO TARGET_TAG <<< "$(parse_image_ref "$TARGET_IMAGE")"
|
||||
|
||||
echo "=== Migrating multi-arch image ==="
|
||||
echo "Source: ${SOURCE_REGISTRY}/${SOURCE_REPO}:${SOURCE_TAG}"
|
||||
echo "Target: ${TARGET_REGISTRY}/${TARGET_REPO}:${TARGET_TAG}"
|
||||
echo ""
|
||||
|
||||
# Tag and push amd64 image
|
||||
echo ">>> Tagging and pushing amd64 image..."
|
||||
docker tag "${SOURCE_REGISTRY}@${AMD64_DIGEST}" "${TARGET_REGISTRY}:${TAG}-amd64"
|
||||
docker push "${TARGET_REGISTRY}:${TAG}-amd64"
|
||||
# Full source reference
|
||||
SOURCE_REF="${SOURCE_REGISTRY}/${SOURCE_REPO}:${SOURCE_TAG}"
|
||||
TARGET_REF="${TARGET_REGISTRY}/${TARGET_REPO}:${TARGET_TAG}"
|
||||
|
||||
# Fetch the manifest list
|
||||
echo ">>> Fetching manifest list from source..."
|
||||
MANIFEST_JSON=$(docker manifest inspect "$SOURCE_REF" 2>/dev/null || {
|
||||
echo "Error: Failed to fetch manifest list. This may not be a multi-arch image."
|
||||
echo "Trying as single-arch image..."
|
||||
|
||||
# Try pulling as single image
|
||||
docker pull "$SOURCE_REF"
|
||||
docker tag "$SOURCE_REF" "$TARGET_REF"
|
||||
docker push "$TARGET_REF"
|
||||
echo "=== Migration complete (single-arch) ==="
|
||||
exit 0
|
||||
})
|
||||
|
||||
# Check if this is a manifest list
|
||||
MEDIA_TYPE=$(echo "$MANIFEST_JSON" | jq -r '.mediaType // .schemaVersion')
|
||||
if [[ ! "$MEDIA_TYPE" =~ "manifest.list" ]] && [[ ! "$MEDIA_TYPE" =~ "index" ]]; then
|
||||
echo "Warning: Source appears to be a single-arch image, not a manifest list."
|
||||
docker pull "$SOURCE_REF"
|
||||
docker tag "$SOURCE_REF" "$TARGET_REF"
|
||||
docker push "$TARGET_REF"
|
||||
echo "=== Migration complete (single-arch) ==="
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Found multi-arch manifest list"
|
||||
echo ""
|
||||
|
||||
# Tag and push arm64 image
|
||||
echo ">>> Tagging and pushing arm64 image..."
|
||||
docker tag "${SOURCE_REGISTRY}@${ARM64_DIGEST}" "${TARGET_REGISTRY}:${TAG}-arm64"
|
||||
docker push "${TARGET_REGISTRY}:${TAG}-arm64"
|
||||
echo ""
|
||||
# Extract platform information and digests
|
||||
PLATFORMS=$(echo "$MANIFEST_JSON" | jq -r '.manifests[] | "\(.platform.os)|\(.platform.architecture)|\(.platform.variant // "")|\(.digest)"')
|
||||
|
||||
# Create multi-arch manifest using the pushed tags
|
||||
# Arrays to store pushed images for manifest creation
|
||||
declare -a PUSHED_IMAGES
|
||||
declare -a PLATFORM_INFO
|
||||
|
||||
# Process each platform
|
||||
while IFS='|' read -r os arch variant digest; do
|
||||
# Create platform tag (e.g., "linux-amd64" or "linux-arm-v7")
|
||||
PLATFORM_TAG="${os}-${arch}"
|
||||
if [ -n "$variant" ]; then
|
||||
PLATFORM_TAG="${PLATFORM_TAG}-${variant}"
|
||||
fi
|
||||
|
||||
echo ">>> Processing ${os}/${arch}${variant:+/$variant}..."
|
||||
echo " Digest: $digest"
|
||||
|
||||
# Pull by digest
|
||||
echo " Pulling image..."
|
||||
docker pull "${SOURCE_REGISTRY}/${SOURCE_REPO}@${digest}"
|
||||
|
||||
# Tag for target
|
||||
TARGET_PLATFORM_REF="${TARGET_REGISTRY}/${TARGET_REPO}:${TARGET_TAG}-${PLATFORM_TAG}"
|
||||
echo " Tagging as: ${TARGET_PLATFORM_REF}"
|
||||
docker tag "${SOURCE_REGISTRY}/${SOURCE_REPO}@${digest}" "${TARGET_PLATFORM_REF}"
|
||||
|
||||
# Push platform-specific image
|
||||
echo " Pushing..."
|
||||
docker push "${TARGET_PLATFORM_REF}"
|
||||
|
||||
# Store for manifest creation
|
||||
PUSHED_IMAGES+=("${TARGET_PLATFORM_REF}")
|
||||
PLATFORM_INFO+=("${os}|${arch}|${variant}")
|
||||
|
||||
echo ""
|
||||
done <<< "$PLATFORMS"
|
||||
|
||||
# Create multi-arch manifest
|
||||
echo ">>> Creating multi-arch manifest..."
|
||||
docker manifest create "${TARGET_REGISTRY}:${TAG}" \
|
||||
--amend "${TARGET_REGISTRY}:${TAG}-amd64" \
|
||||
--amend "${TARGET_REGISTRY}:${TAG}-arm64"
|
||||
MANIFEST_CREATE_CMD="docker manifest create ${TARGET_REF}"
|
||||
for img in "${PUSHED_IMAGES[@]}"; do
|
||||
MANIFEST_CREATE_CMD="${MANIFEST_CREATE_CMD} --amend ${img}"
|
||||
done
|
||||
|
||||
eval "$MANIFEST_CREATE_CMD"
|
||||
echo ""
|
||||
|
||||
# Annotate the manifest with platform information
|
||||
# Annotate each platform
|
||||
echo ">>> Annotating manifest with platform information..."
|
||||
docker manifest annotate "${TARGET_REGISTRY}:${TAG}" \
|
||||
"${TARGET_REGISTRY}:${TAG}-amd64" \
|
||||
--os linux --arch amd64
|
||||
for i in "${!PUSHED_IMAGES[@]}"; do
|
||||
IFS='|' read -r os arch variant <<< "${PLATFORM_INFO[$i]}"
|
||||
|
||||
docker manifest annotate "${TARGET_REGISTRY}:${TAG}" \
|
||||
"${TARGET_REGISTRY}:${TAG}-arm64" \
|
||||
--os linux --arch arm64
|
||||
ANNOTATE_CMD="docker manifest annotate ${TARGET_REF} ${PUSHED_IMAGES[$i]} --os ${os} --arch ${arch}"
|
||||
if [ -n "$variant" ]; then
|
||||
ANNOTATE_CMD="${ANNOTATE_CMD} --variant ${variant}"
|
||||
fi
|
||||
|
||||
echo " Annotating ${os}/${arch}${variant:+/$variant}..."
|
||||
eval "$ANNOTATE_CMD"
|
||||
done
|
||||
echo ""
|
||||
|
||||
# Push the manifest list
|
||||
echo ">>> Pushing multi-arch manifest..."
|
||||
docker manifest push "${TARGET_REGISTRY}:${TAG}"
|
||||
docker manifest push "${TARGET_REF}"
|
||||
echo ""
|
||||
|
||||
echo "=== Migration complete! ==="
|
||||
echo "You can now pull: docker pull ${TARGET_REGISTRY}:${TAG}"
|
||||
echo "You can now pull: docker pull ${TARGET_REF}"
|
||||
echo ""
|
||||
echo "Migrated platforms:"
|
||||
for i in "${!PLATFORM_INFO[@]}"; do
|
||||
IFS='|' read -r os arch variant <<< "${PLATFORM_INFO[$i]}"
|
||||
echo " - ${os}/${arch}${variant:+/$variant}"
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user