Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
31dc4b4f53 | ||
|
|
af99929aa3 | ||
|
|
7f2d780b0a | ||
|
|
8956568ed2 | ||
|
|
c1f2ae0f7a | ||
|
|
012a14c4ee | ||
|
|
4cda163099 | ||
|
|
41bcee4a59 | ||
|
|
24d6b49481 | ||
|
|
363c12e6bf | ||
|
|
2a60a47fd5 | ||
|
|
34c2b8b17c | ||
|
|
8d0cff63fb | ||
|
|
d11356cd18 | ||
|
|
79d1126726 | ||
|
|
8e31137c62 | ||
|
|
023efb05aa | ||
|
|
b18e4c3996 | ||
|
|
24b265bf12 | ||
|
|
e8e375639d | ||
|
|
5a208de4c9 | ||
|
|
104eb86c04 |
37
CLAUDE.md
37
CLAUDE.md
@@ -475,12 +475,47 @@ Lightweight standalone service for BYOS (Bring Your Own Storage) with embedded P
|
||||
|
||||
Read access:
|
||||
- **Public hold** (`HOLD_PUBLIC=true`): Anonymous + all authenticated users
|
||||
- **Private hold** (`HOLD_PUBLIC=false`): Requires authentication + crew membership with blob:read permission
|
||||
- **Private hold** (`HOLD_PUBLIC=false`): Requires authentication + crew membership with blob:read OR blob:write permission
|
||||
- **Note:** `blob:write` implicitly grants `blob:read` access (can't push without pulling)
|
||||
|
||||
Write access:
|
||||
- Hold owner OR crew members with blob:write permission
|
||||
- Verified via `io.atcr.hold.crew` records in hold's embedded PDS
|
||||
|
||||
**Permission Matrix:**
|
||||
|
||||
| User Type | Public Read | Private Read | Write | Crew Admin |
|
||||
|-----------|-------------|--------------|-------|------------|
|
||||
| Anonymous | Yes | No | No | No |
|
||||
| Owner (captain) | Yes | Yes | Yes | Yes (implied) |
|
||||
| Crew (blob:read only) | Yes | Yes | No | No |
|
||||
| Crew (blob:write only) | Yes | Yes* | Yes | No |
|
||||
| Crew (blob:read + blob:write) | Yes | Yes | Yes | No |
|
||||
| Crew (crew:admin) | Yes | Yes | Yes | Yes |
|
||||
| Authenticated non-crew | Yes | No | No | No |
|
||||
|
||||
*`blob:write` implicitly grants `blob:read` access
|
||||
|
||||
**Authorization Error Format:**
|
||||
|
||||
All authorization failures use consistent structured errors (`pkg/hold/pds/auth.go`):
|
||||
```
|
||||
access denied for [action]: [reason] (required: [permission(s)])
|
||||
```
|
||||
|
||||
Examples:
|
||||
- `access denied for blob:read: user is not a crew member (required: blob:read or blob:write)`
|
||||
- `access denied for blob:write: crew member lacks permission (required: blob:write)`
|
||||
- `access denied for crew:admin: user is not a crew member (required: crew:admin)`
|
||||
|
||||
**Shared Error Constants** (`pkg/hold/pds/auth.go`):
|
||||
- `ErrMissingAuthHeader` - Missing Authorization header
|
||||
- `ErrInvalidAuthFormat` - Invalid Authorization header format
|
||||
- `ErrInvalidAuthScheme` - Invalid scheme (expected Bearer or DPoP)
|
||||
- `ErrInvalidJWTFormat` - Malformed JWT
|
||||
- `ErrMissingISSClaim` / `ErrMissingSubClaim` - Missing JWT claims
|
||||
- `ErrTokenExpired` - Token has expired
|
||||
|
||||
**Embedded PDS Endpoints** (`pkg/hold/pds/xrpc.go`):
|
||||
|
||||
Standard ATProto sync endpoints:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Production build for ATCR AppView
|
||||
# Result: ~30MB scratch image with static binary
|
||||
FROM docker.io/golang:1.25.2-trixie AS builder
|
||||
FROM docker.io/golang:1.25.4-trixie AS builder
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
@@ -34,12 +34,12 @@ EXPOSE 5000
|
||||
LABEL org.opencontainers.image.title="ATCR AppView" \
|
||||
org.opencontainers.image.description="ATProto Container Registry - OCI-compliant registry using AT Protocol for manifest storage" \
|
||||
org.opencontainers.image.authors="ATCR Contributors" \
|
||||
org.opencontainers.image.source="https://tangled.org/@evan.jarrett.net/at-container-registry" \
|
||||
org.opencontainers.image.documentation="https://tangled.org/@evan.jarrett.net/at-container-registry" \
|
||||
org.opencontainers.image.source="https://tangled.org/evan.jarrett.net/at-container-registry" \
|
||||
org.opencontainers.image.documentation="https://tangled.org/evan.jarrett.net/at-container-registry" \
|
||||
org.opencontainers.image.licenses="MIT" \
|
||||
org.opencontainers.image.version="0.1.0" \
|
||||
io.atcr.icon="https://imgs.blue/evan.jarrett.net/1TpTNrRelfloN2emuWZDrWmPT0o93bAjEnozjD6UPgoVV9m4" \
|
||||
io.atcr.readme="https://tangled.org/@evan.jarrett.net/at-container-registry/raw/main/docs/appview.md"
|
||||
io.atcr.readme="https://tangled.org/evan.jarrett.net/at-container-registry/raw/main/docs/appview.md"
|
||||
|
||||
ENTRYPOINT ["/atcr-appview"]
|
||||
CMD ["serve"]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Development image with Air hot reload
|
||||
# Build: docker build -f Dockerfile.dev -t atcr-appview-dev .
|
||||
# Run: docker run -v $(pwd):/app -p 5000:5000 atcr-appview-dev
|
||||
FROM docker.io/golang:1.25.2-trixie
|
||||
FROM docker.io/golang:1.25.4-trixie
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM docker.io/golang:1.25.2-trixie AS builder
|
||||
FROM docker.io/golang:1.25.4-trixie AS builder
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
@@ -38,11 +38,11 @@ EXPOSE 8080
|
||||
LABEL org.opencontainers.image.title="ATCR Hold Service" \
|
||||
org.opencontainers.image.description="ATCR Hold Service - Bring Your Own Storage component for ATCR" \
|
||||
org.opencontainers.image.authors="ATCR Contributors" \
|
||||
org.opencontainers.image.source="https://tangled.org/@evan.jarrett.net/at-container-registry" \
|
||||
org.opencontainers.image.documentation="https://tangled.org/@evan.jarrett.net/at-container-registry" \
|
||||
org.opencontainers.image.source="https://tangled.org/evan.jarrett.net/at-container-registry" \
|
||||
org.opencontainers.image.documentation="https://tangled.org/evan.jarrett.net/at-container-registry" \
|
||||
org.opencontainers.image.licenses="MIT" \
|
||||
org.opencontainers.image.version="0.1.0" \
|
||||
io.atcr.icon="https://imgs.blue/evan.jarrett.net/1TpTOdtS60GdJWBYEqtK22y688jajbQ9a5kbYRFtwuqrkBAE" \
|
||||
io.atcr.readme="https://tangled.org/@evan.jarrett.net/at-container-registry/raw/main/docs/hold.md"
|
||||
io.atcr.readme="https://tangled.org/evan.jarrett.net/at-container-registry/raw/main/docs/hold.md"
|
||||
|
||||
ENTRYPOINT ["/atcr-hold"]
|
||||
|
||||
@@ -82,9 +82,8 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
slog.Info("Initializing hold health checker", "cache_ttl", cfg.Health.CacheTTL)
|
||||
healthChecker := holdhealth.NewChecker(cfg.Health.CacheTTL)
|
||||
|
||||
// Initialize README cache
|
||||
slog.Info("Initializing README cache", "cache_ttl", cfg.Health.ReadmeCacheTTL)
|
||||
readmeCache := readme.NewCache(uiDatabase, cfg.Health.ReadmeCacheTTL)
|
||||
// Initialize README fetcher for rendering repo page descriptions
|
||||
readmeFetcher := readme.NewFetcher()
|
||||
|
||||
// Start background health check worker
|
||||
startupDelay := 5 * time.Second // Wait for hold services to start (Docker compose)
|
||||
@@ -151,20 +150,15 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
middleware.SetGlobalRefresher(refresher)
|
||||
|
||||
// Set global database for pull/push metrics tracking
|
||||
metricsDB := db.NewMetricsDB(uiDatabase)
|
||||
middleware.SetGlobalDatabase(metricsDB)
|
||||
middleware.SetGlobalDatabase(uiDatabase)
|
||||
|
||||
// Create RemoteHoldAuthorizer for hold authorization with caching
|
||||
holdAuthorizer := auth.NewRemoteHoldAuthorizer(uiDatabase, testMode)
|
||||
middleware.SetGlobalAuthorizer(holdAuthorizer)
|
||||
slog.Info("Hold authorizer initialized with database caching")
|
||||
|
||||
// Set global readme cache for middleware
|
||||
middleware.SetGlobalReadmeCache(readmeCache)
|
||||
slog.Info("README cache initialized for manifest push refresh")
|
||||
|
||||
// Initialize Jetstream workers (background services before HTTP routes)
|
||||
initializeJetstream(uiDatabase, &cfg.Jetstream, defaultHoldDID, testMode)
|
||||
initializeJetstream(uiDatabase, &cfg.Jetstream, defaultHoldDID, testMode, refresher)
|
||||
|
||||
// Create main chi router
|
||||
mainRouter := chi.NewRouter()
|
||||
@@ -194,8 +188,9 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
BaseURL: baseURL,
|
||||
DeviceStore: deviceStore,
|
||||
HealthChecker: healthChecker,
|
||||
ReadmeCache: readmeCache,
|
||||
ReadmeFetcher: readmeFetcher,
|
||||
Templates: uiTemplates,
|
||||
DefaultHoldDID: defaultHoldDID,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -217,14 +212,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
// Create ATProto client with session provider (uses DoWithSession for DPoP nonce safety)
|
||||
client := atproto.NewClientWithSessionProvider(pdsEndpoint, did, refresher)
|
||||
|
||||
// Ensure sailor profile exists (creates with default hold if configured)
|
||||
slog.Debug("Ensuring profile exists", "component", "appview/callback", "did", did, "default_hold_did", defaultHoldDID)
|
||||
if err := storage.EnsureProfile(ctx, client, defaultHoldDID); err != nil {
|
||||
slog.Warn("Failed to ensure profile", "component", "appview/callback", "did", did, "error", err)
|
||||
// Continue anyway - profile creation is not critical for avatar fetch
|
||||
} else {
|
||||
slog.Debug("Profile ensured", "component", "appview/callback", "did", did)
|
||||
}
|
||||
// Note: Profile and crew setup now happen automatically via UserContext.EnsureUserSetup()
|
||||
|
||||
// Fetch user's profile record from PDS (contains blob references)
|
||||
profileRecord, err := client.GetProfileRecord(ctx, did)
|
||||
@@ -275,36 +263,23 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
return nil // Non-fatal
|
||||
}
|
||||
|
||||
var holdDID string
|
||||
if profile != nil && profile.DefaultHold != nil && *profile.DefaultHold != "" {
|
||||
defaultHold := *profile.DefaultHold
|
||||
// Migrate profile URL→DID if needed (legacy migration, crew registration now handled by UserContext)
|
||||
if profile != nil && profile.DefaultHold != "" {
|
||||
// Check if defaultHold is a URL (needs migration)
|
||||
if strings.HasPrefix(defaultHold, "http://") || strings.HasPrefix(defaultHold, "https://") {
|
||||
slog.Debug("Migrating hold URL to DID", "component", "appview/callback", "did", did, "hold_url", defaultHold)
|
||||
if strings.HasPrefix(profile.DefaultHold, "http://") || strings.HasPrefix(profile.DefaultHold, "https://") {
|
||||
slog.Debug("Migrating hold URL to DID", "component", "appview/callback", "did", did, "hold_url", profile.DefaultHold)
|
||||
|
||||
// Resolve URL to DID
|
||||
holdDID = atproto.ResolveHoldDIDFromURL(defaultHold)
|
||||
holdDID := atproto.ResolveHoldDIDFromURL(profile.DefaultHold)
|
||||
|
||||
// Update profile with DID
|
||||
profile.DefaultHold = &holdDID
|
||||
profile.DefaultHold = holdDID
|
||||
if err := storage.UpdateProfile(ctx, client, profile); err != nil {
|
||||
slog.Warn("Failed to update profile with hold DID", "component", "appview/callback", "did", did, "error", err)
|
||||
} else {
|
||||
slog.Debug("Updated profile with hold DID", "component", "appview/callback", "hold_did", holdDID)
|
||||
}
|
||||
} else {
|
||||
// Already a DID - use it
|
||||
holdDID = defaultHold
|
||||
}
|
||||
// Register crew regardless of migration (outside the migration block)
|
||||
// Run in background to avoid blocking OAuth callback if hold is offline
|
||||
// Use background context - don't inherit request context which gets canceled on response
|
||||
slog.Debug("Attempting crew registration", "component", "appview/callback", "did", did, "hold_did", holdDID)
|
||||
go func(client *atproto.Client, refresher *oauth.Refresher, holdDID string) {
|
||||
ctx := context.Background()
|
||||
storage.EnsureCrewMembership(ctx, client, refresher, holdDID)
|
||||
}(client, refresher, holdDID)
|
||||
|
||||
}
|
||||
|
||||
return nil // All errors are non-fatal, logged for debugging
|
||||
@@ -326,10 +301,19 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
ctx := context.Background()
|
||||
app := handlers.NewApp(ctx, cfg.Distribution)
|
||||
|
||||
// Wrap registry app with auth method extraction middleware
|
||||
// This extracts the auth method from the JWT and stores it in the request context
|
||||
// Wrap registry app with middleware chain:
|
||||
// 1. ExtractAuthMethod - extracts auth method from JWT and stores in context
|
||||
// 2. UserContextMiddleware - builds UserContext with identity, permissions, service tokens
|
||||
wrappedApp := middleware.ExtractAuthMethod(app)
|
||||
|
||||
// Create dependencies for UserContextMiddleware
|
||||
userContextDeps := &auth.Dependencies{
|
||||
Refresher: refresher,
|
||||
Authorizer: holdAuthorizer,
|
||||
DefaultHoldDID: defaultHoldDID,
|
||||
}
|
||||
wrappedApp = middleware.UserContextMiddleware(userContextDeps)(wrappedApp)
|
||||
|
||||
// Mount registry at /v2/
|
||||
mainRouter.Handle("/v2/*", wrappedApp)
|
||||
|
||||
@@ -398,6 +382,9 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
// Limit caching to allow scope changes to propagate quickly
|
||||
// PDS servers cache client metadata, so short max-age helps with updates
|
||||
w.Header().Set("Cache-Control", "public, max-age=300")
|
||||
if err := json.NewEncoder(w).Encode(metadataMap); err != nil {
|
||||
http.Error(w, "Failed to encode metadata", http.StatusInternalServerError)
|
||||
}
|
||||
@@ -415,23 +402,11 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
||||
// Prevents the flood of errors when a stale session is discovered during push
|
||||
tokenHandler.SetOAuthSessionValidator(refresher)
|
||||
|
||||
// Register token post-auth callback for profile management
|
||||
// This decouples the token package from AppView-specific dependencies
|
||||
// Register token post-auth callback
|
||||
// Note: Profile and crew setup now happen automatically via UserContext.EnsureUserSetup()
|
||||
tokenHandler.SetPostAuthCallback(func(ctx context.Context, did, handle, pdsEndpoint, accessToken string) error {
|
||||
slog.Debug("Token post-auth callback", "component", "appview/callback", "did", did)
|
||||
|
||||
// Create ATProto client with validated token
|
||||
atprotoClient := atproto.NewClient(pdsEndpoint, did, accessToken)
|
||||
|
||||
// Ensure profile exists (will create with default hold if not exists and default is configured)
|
||||
if err := storage.EnsureProfile(ctx, atprotoClient, defaultHoldDID); err != nil {
|
||||
// Log error but don't fail auth - profile management is not critical
|
||||
slog.Warn("Failed to ensure profile", "component", "appview/callback", "did", did, "error", err)
|
||||
} else {
|
||||
slog.Debug("Profile ensured with default hold", "component", "appview/callback", "did", did, "default_hold_did", defaultHoldDID)
|
||||
}
|
||||
|
||||
return nil // All errors are non-fatal
|
||||
return nil
|
||||
})
|
||||
|
||||
mainRouter.Get("/auth/token", tokenHandler.ServeHTTP)
|
||||
@@ -520,7 +495,7 @@ func createTokenIssuer(cfg *appview.Config) (*token.Issuer, error) {
|
||||
}
|
||||
|
||||
// initializeJetstream initializes the Jetstream workers for real-time events and backfill
|
||||
func initializeJetstream(database *sql.DB, jetstreamCfg *appview.JetstreamConfig, defaultHoldDID string, testMode bool) {
|
||||
func initializeJetstream(database *sql.DB, jetstreamCfg *appview.JetstreamConfig, defaultHoldDID string, testMode bool, refresher *oauth.Refresher) {
|
||||
// Start Jetstream worker
|
||||
jetstreamURL := jetstreamCfg.URL
|
||||
|
||||
@@ -544,7 +519,7 @@ func initializeJetstream(database *sql.DB, jetstreamCfg *appview.JetstreamConfig
|
||||
// Get relay endpoint for sync API (defaults to Bluesky's relay)
|
||||
relayEndpoint := jetstreamCfg.RelayEndpoint
|
||||
|
||||
backfillWorker, err := jetstream.NewBackfillWorker(database, relayEndpoint, defaultHoldDID, testMode)
|
||||
backfillWorker, err := jetstream.NewBackfillWorker(database, relayEndpoint, defaultHoldDID, testMode, refresher)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to create backfill worker", "component", "jetstream/backfill", "error", err)
|
||||
} else {
|
||||
|
||||
84
docs/HOLD_XRPC_ENDPOINTS.md
Normal file
84
docs/HOLD_XRPC_ENDPOINTS.md
Normal file
@@ -0,0 +1,84 @@
|
||||
# Hold Service XRPC Endpoints
|
||||
|
||||
This document lists all XRPC endpoints implemented in the Hold service (`pkg/hold/`).
|
||||
|
||||
## PDS Endpoints (`pkg/hold/pds/xrpc.go`)
|
||||
|
||||
### Public (No Auth Required)
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/xrpc/_health` | GET | Health check |
|
||||
| `/xrpc/com.atproto.server.describeServer` | GET | Server metadata |
|
||||
| `/xrpc/com.atproto.repo.describeRepo` | GET | Repository information |
|
||||
| `/xrpc/com.atproto.repo.getRecord` | GET | Retrieve a single record |
|
||||
| `/xrpc/com.atproto.repo.listRecords` | GET | List records in a collection (paginated) |
|
||||
| `/xrpc/com.atproto.sync.listRepos` | GET | List all repositories |
|
||||
| `/xrpc/com.atproto.sync.getRecord` | GET | Get record as CAR file |
|
||||
| `/xrpc/com.atproto.sync.getRepo` | GET | Full repository as CAR file |
|
||||
| `/xrpc/com.atproto.sync.getRepoStatus` | GET | Repository hosting status |
|
||||
| `/xrpc/com.atproto.sync.subscribeRepos` | GET | WebSocket firehose |
|
||||
| `/xrpc/com.atproto.identity.resolveHandle` | GET | Resolve handle to DID |
|
||||
| `/xrpc/app.bsky.actor.getProfile` | GET | Get actor profile |
|
||||
| `/xrpc/app.bsky.actor.getProfiles` | GET | Get multiple profiles |
|
||||
| `/.well-known/did.json` | GET | DID document |
|
||||
| `/.well-known/atproto-did` | GET | DID for handle resolution |
|
||||
|
||||
### Conditional Auth (based on captain.public)
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/xrpc/com.atproto.sync.getBlob` | GET/HEAD | Get blob (routes OCI vs ATProto) |
|
||||
|
||||
### Owner/Crew Admin Required
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/xrpc/com.atproto.repo.deleteRecord` | POST | Delete a record |
|
||||
| `/xrpc/com.atproto.repo.uploadBlob` | POST | Upload ATProto blob |
|
||||
|
||||
### DPoP Auth Required
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/xrpc/io.atcr.hold.requestCrew` | POST | Request crew membership |
|
||||
|
||||
---
|
||||
|
||||
## OCI Multipart Upload Endpoints (`pkg/hold/oci/xrpc.go`)
|
||||
|
||||
All require `blob:write` permission via service token:
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/xrpc/io.atcr.hold.initiateUpload` | POST | Start multipart upload |
|
||||
| `/xrpc/io.atcr.hold.getPartUploadUrl` | POST | Get presigned URL for part |
|
||||
| `/xrpc/io.atcr.hold.uploadPart` | PUT | Direct buffered part upload |
|
||||
| `/xrpc/io.atcr.hold.completeUpload` | POST | Finalize multipart upload |
|
||||
| `/xrpc/io.atcr.hold.abortUpload` | POST | Cancel multipart upload |
|
||||
| `/xrpc/io.atcr.hold.notifyManifest` | POST | Notify manifest push (creates layer records + optional Bluesky post) |
|
||||
|
||||
---
|
||||
|
||||
## Standard ATProto Endpoints (excluding io.atcr.hold.*)
|
||||
|
||||
| Endpoint |
|
||||
|----------|
|
||||
| /xrpc/_health |
|
||||
| /xrpc/com.atproto.server.describeServer |
|
||||
| /xrpc/com.atproto.repo.describeRepo |
|
||||
| /xrpc/com.atproto.repo.getRecord |
|
||||
| /xrpc/com.atproto.repo.listRecords |
|
||||
| /xrpc/com.atproto.repo.deleteRecord |
|
||||
| /xrpc/com.atproto.repo.uploadBlob |
|
||||
| /xrpc/com.atproto.sync.listRepos |
|
||||
| /xrpc/com.atproto.sync.getRecord |
|
||||
| /xrpc/com.atproto.sync.getRepo |
|
||||
| /xrpc/com.atproto.sync.getRepoStatus |
|
||||
| /xrpc/com.atproto.sync.getBlob |
|
||||
| /xrpc/com.atproto.sync.subscribeRepos |
|
||||
| /xrpc/com.atproto.identity.resolveHandle |
|
||||
| /xrpc/app.bsky.actor.getProfile |
|
||||
| /xrpc/app.bsky.actor.getProfiles |
|
||||
| /.well-known/did.json |
|
||||
| /.well-known/atproto-did |
|
||||
@@ -112,7 +112,6 @@ Several packages show decreased percentages despite improvements. This is due to
|
||||
|
||||
**Remaining gaps:**
|
||||
- `notifyHoldAboutManifest()` - 0% (background notification, less critical)
|
||||
- `refreshReadmeCache()` - 11.8% (UI feature, lower priority)
|
||||
|
||||
## Critical Priority: Core Registry Functionality
|
||||
|
||||
@@ -423,12 +422,12 @@ Embedded PDS implementation. Has good test coverage for critical parts, but supp
|
||||
|
||||
---
|
||||
|
||||
### 🟡 pkg/appview/readme (16.7% coverage)
|
||||
### 🟡 pkg/appview/readme (Partial coverage)
|
||||
|
||||
README fetching and caching. Less critical but still needs work.
|
||||
README rendering for repo page descriptions. The cache.go was removed as README content is now stored in `io.atcr.repo.page` records and synced via Jetstream.
|
||||
|
||||
#### cache.go (0% coverage)
|
||||
#### fetcher.go (📊 Partial coverage)
|
||||
- `RenderMarkdown()` - renders repo page description markdown
|
||||
|
||||
---
|
||||
|
||||
|
||||
399
docs/VALKEY_MIGRATION.md
Normal file
399
docs/VALKEY_MIGRATION.md
Normal file
@@ -0,0 +1,399 @@
|
||||
# Analysis: AppView SQL Database Usage
|
||||
|
||||
## Overview
|
||||
|
||||
The AppView uses SQLite with 19 tables. The key finding: **most data is a cache of ATProto records** that could theoretically be rebuilt from users' PDS instances.
|
||||
|
||||
## Data Categories
|
||||
|
||||
### 1. MUST PERSIST (Local State Only)
|
||||
|
||||
These tables contain data that **cannot be reconstructed** from external sources:
|
||||
|
||||
| Table | Purpose | Why It Must Persist |
|
||||
|-------|---------|---------------------|
|
||||
| `oauth_sessions` | OAuth tokens | Refresh tokens are stateful; losing them = users must re-auth |
|
||||
| `ui_sessions` | Web browser sessions | Session continuity for logged-in users |
|
||||
| `devices` | Approved devices + bcrypt secrets | User authorization decisions; secrets are one-way hashed |
|
||||
| `pending_device_auth` | In-flight auth flows | Short-lived (10min) but critical during auth |
|
||||
| `oauth_auth_requests` | OAuth flow state | Short-lived but required for auth completion |
|
||||
| `repository_stats` | Pull/push counts | **Locally tracked metrics** - not stored in ATProto |
|
||||
|
||||
### 2. CACHED FROM PDS (Rebuildable)
|
||||
|
||||
These tables are essentially a **read-through cache** of ATProto data:
|
||||
|
||||
| Table | Source | ATProto Collection |
|
||||
|-------|--------|-------------------|
|
||||
| `users` | User's PDS profile | `app.bsky.actor.profile` + DID document |
|
||||
| `manifests` | User's PDS | `io.atcr.manifest` records |
|
||||
| `tags` | User's PDS | `io.atcr.tag` records |
|
||||
| `layers` | Derived from manifests | Parsed from manifest content |
|
||||
| `manifest_references` | Derived from manifest lists | Parsed from multi-arch manifests |
|
||||
| `repository_annotations` | Manifest config blob | OCI annotations from config |
|
||||
| `repo_pages` | User's PDS | `io.atcr.repo.page` records |
|
||||
| `stars` | User's PDS | `io.atcr.sailor.star` records (synced via Jetstream) |
|
||||
| `hold_captain_records` | Hold's embedded PDS | `io.atcr.hold.captain` records |
|
||||
| `hold_crew_approvals` | Hold's embedded PDS | `io.atcr.hold.crew` records |
|
||||
| `hold_crew_denials` | Local authorization cache | Could re-check on demand |
|
||||
|
||||
### 3. OPERATIONAL
|
||||
|
||||
| Table | Purpose |
|
||||
|-------|---------|
|
||||
| `schema_migrations` | Migration tracking |
|
||||
| `firehose_cursor` | Jetstream position (can restart from 0) |
|
||||
|
||||
## Key Insights
|
||||
|
||||
### What's Actually Unique to AppView?
|
||||
|
||||
1. **Authentication state** - OAuth sessions, devices, UI sessions
|
||||
2. **Engagement metrics** - Pull/push counts (locally tracked, not in ATProto)
|
||||
|
||||
### What Could Be Eliminated?
|
||||
|
||||
If ATCR fully embraced the ATProto model:
|
||||
|
||||
1. **`users`** - Query PDS on demand (with caching)
|
||||
2. **`manifests`, `tags`, `layers`** - Query PDS on demand (with caching)
|
||||
3. **`repository_annotations`** - Fetch manifest config on demand
|
||||
4. **`repo_pages`** - Query PDS on demand
|
||||
5. **`hold_*` tables** - Query hold's PDS on demand
|
||||
|
||||
### Trade-offs
|
||||
|
||||
**Current approach (heavy caching):**
|
||||
- Fast queries for UI (search, browse, stats)
|
||||
- Offline resilience (PDS down doesn't break UI)
|
||||
- Complex sync logic (Jetstream consumer, backfill)
|
||||
- State can diverge from source of truth
|
||||
|
||||
**Lighter approach (query on demand):**
|
||||
- Always fresh data
|
||||
- Simpler codebase (no sync)
|
||||
- Slower queries (network round-trips)
|
||||
- Depends on PDS availability
|
||||
|
||||
## Current Limitation: No Cache-Miss Queries
|
||||
|
||||
**Finding:** There's no "query PDS on cache miss" logic. Users/manifests only enter the DB via:
|
||||
1. OAuth login (user authenticates)
|
||||
2. Jetstream events (firehose activity)
|
||||
|
||||
**Problem:** If someone visits `atcr.io/alice/myapp` before alice is indexed → 404
|
||||
|
||||
**Where this happens:**
|
||||
- `pkg/appview/handlers/repository.go:50-53`: If `db.GetUserByDID()` returns nil → 404
|
||||
- No fallback to `atproto.Client.ListRecords()` or similar
|
||||
|
||||
**This matters for Valkey migration:** If cache is ephemeral and restarts clear it, you need cache-miss logic to repopulate on demand. Otherwise:
|
||||
- Restart Valkey → all users/manifests gone
|
||||
- Wait for Jetstream to re-index OR implement cache-miss queries
|
||||
|
||||
**Cache-miss implementation design:**
|
||||
|
||||
Existing code to reuse: `pkg/appview/jetstream/processor.go:43-97` (`EnsureUser`)
|
||||
|
||||
```go
|
||||
// New: pkg/appview/cache/loader.go
|
||||
|
||||
type Loader struct {
|
||||
cache Cache // Valkey interface
|
||||
client *atproto.Client
|
||||
}
|
||||
|
||||
// GetUser with cache-miss fallback
|
||||
func (l *Loader) GetUser(ctx context.Context, did string) (*User, error) {
|
||||
// 1. Try cache
|
||||
if user := l.cache.GetUser(did); user != nil {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// 2. Cache miss - resolve identity (already queries network)
|
||||
_, handle, pdsEndpoint, err := atproto.ResolveIdentity(ctx, did)
|
||||
if err != nil {
|
||||
return nil, err // User doesn't exist in network
|
||||
}
|
||||
|
||||
// 3. Fetch profile for avatar
|
||||
client := atproto.NewClient(pdsEndpoint, "", "")
|
||||
profile, _ := client.GetProfileRecord(ctx, did)
|
||||
avatarURL := ""
|
||||
if profile != nil && profile.Avatar != nil {
|
||||
avatarURL = atproto.BlobCDNURL(did, profile.Avatar.Ref.Link)
|
||||
}
|
||||
|
||||
// 4. Cache and return
|
||||
user := &User{DID: did, Handle: handle, PDSEndpoint: pdsEndpoint, Avatar: avatarURL}
|
||||
l.cache.SetUser(user, 1*time.Hour)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetManifestsForRepo with cache-miss fallback
|
||||
func (l *Loader) GetManifestsForRepo(ctx context.Context, did, repo string) ([]Manifest, error) {
|
||||
cacheKey := fmt.Sprintf("manifests:%s:%s", did, repo)
|
||||
|
||||
// 1. Try cache
|
||||
if cached := l.cache.Get(cacheKey); cached != nil {
|
||||
return cached.([]Manifest), nil
|
||||
}
|
||||
|
||||
// 2. Cache miss - get user's PDS endpoint
|
||||
user, err := l.GetUser(ctx, did)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Query PDS for manifests
|
||||
client := atproto.NewClient(user.PDSEndpoint, "", "")
|
||||
records, _, err := client.ListRecordsForRepo(ctx, did, atproto.ManifestCollection, 100, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. Filter by repository and parse
|
||||
var manifests []Manifest
|
||||
for _, rec := range records {
|
||||
var m atproto.ManifestRecord
|
||||
if err := json.Unmarshal(rec.Value, &m); err != nil {
|
||||
continue
|
||||
}
|
||||
if m.Repository == repo {
|
||||
manifests = append(manifests, convertManifest(m))
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Cache and return
|
||||
l.cache.Set(cacheKey, manifests, 10*time.Minute)
|
||||
return manifests, nil
|
||||
}
|
||||
```
|
||||
|
||||
**Handler changes:**
|
||||
```go
|
||||
// Before (repository.go:45-53):
|
||||
owner, err := db.GetUserByDID(h.DB, did)
|
||||
if owner == nil {
|
||||
RenderNotFound(w, r, h.Templates, h.RegistryURL)
|
||||
return
|
||||
}
|
||||
|
||||
// After:
|
||||
owner, err := h.Loader.GetUser(r.Context(), did)
|
||||
if err != nil {
|
||||
RenderNotFound(w, r, h.Templates, h.RegistryURL)
|
||||
return
|
||||
}
|
||||
```
|
||||
|
||||
**Performance considerations:**
|
||||
- Cache hit: ~1ms (Valkey lookup)
|
||||
- Cache miss: ~200-500ms (PDS round-trip)
|
||||
- First request after restart: slower but correct
|
||||
- Jetstream still useful for proactive warming
|
||||
|
||||
---
|
||||
|
||||
## Proposed Architecture: Valkey + ATProto
|
||||
|
||||
### Goal
|
||||
Replace SQLite with Valkey (Redis-compatible) for ephemeral state, push remaining persistent data to ATProto.
|
||||
|
||||
### What goes to Valkey (ephemeral, TTL-based)
|
||||
|
||||
| Current Table | Valkey Key Pattern | TTL | Notes |
|
||||
|---------------|-------------------|-----|-------|
|
||||
| `oauth_sessions` | `oauth:{did}:{session_id}` | 90 days | Lost on restart = re-auth |
|
||||
| `ui_sessions` | `ui:{session_id}` | Session duration | Lost on restart = re-login |
|
||||
| `oauth_auth_requests` | `authreq:{state}` | 10 min | In-flight flows |
|
||||
| `pending_device_auth` | `pending:{device_code}` | 10 min | In-flight flows |
|
||||
| `firehose_cursor` | `cursor:jetstream` | None | Can restart from 0 |
|
||||
| All PDS cache tables | `cache:{collection}:{did}:{rkey}` | 10-60 min | Query PDS on miss |
|
||||
|
||||
**Benefits:**
|
||||
- Multi-instance ready (shared Valkey)
|
||||
- No schema migrations
|
||||
- Natural TTL expiry
|
||||
- Simpler code (no SQL)
|
||||
|
||||
### What could become ATProto records
|
||||
|
||||
| Current Table | Proposed Collection | Where Stored | Open Questions |
|
||||
|---------------|---------------------|--------------|----------------|
|
||||
| `devices` | `io.atcr.sailor.device` | User's PDS | Privacy: IP, user-agent sensitive? |
|
||||
| `repository_stats` | `io.atcr.repo.stats` | Hold's PDS or User's PDS | Who owns the stats? |
|
||||
|
||||
**Devices → Valkey:**
|
||||
- Move current device table to Valkey
|
||||
- Key: `device:{did}:{device_id}` → `{name, secret_hash, ip, user_agent, created_at, last_used}`
|
||||
- TTL: Long (1 year?) or no expiry
|
||||
- Device list: `devices:{did}` → Set of device IDs
|
||||
- Secret validation works the same, just different backend
|
||||
|
||||
**Service auth exploration (future):**
|
||||
The challenge with pure ATProto service auth is the AppView still needs the user's OAuth session to write manifests to their PDS. The current flow:
|
||||
1. User authenticates via OAuth → AppView gets OAuth tokens
|
||||
2. AppView issues registry JWT to credential helper
|
||||
3. Credential helper presents JWT on each push/pull
|
||||
4. AppView uses OAuth session to write to user's PDS
|
||||
|
||||
Service auth could work for the hold side (AppView → Hold), but not for the user's OAuth session.
|
||||
|
||||
**Repository stats → Hold's PDS:**
|
||||
|
||||
**Challenge discovered:** The hold's `getBlob` endpoint only receives `did` + `cid`, not the repository name.
|
||||
|
||||
Current flow (`proxy_blob_store.go:358-362`):
|
||||
```go
|
||||
xrpcURL := fmt.Sprintf("%s%s?did=%s&cid=%s&method=%s",
|
||||
p.holdURL, atproto.SyncGetBlob, p.ctx.DID, dgst.String(), operation)
|
||||
```
|
||||
|
||||
**Implementation options:**
|
||||
|
||||
**Option A: Add repository parameter to getBlob (recommended)**
|
||||
```go
|
||||
// Modified AppView call:
|
||||
xrpcURL := fmt.Sprintf("%s%s?did=%s&cid=%s&method=%s&repo=%s",
|
||||
p.holdURL, atproto.SyncGetBlob, p.ctx.DID, dgst.String(), operation, p.ctx.Repository)
|
||||
```
|
||||
|
||||
```go
|
||||
// Modified hold handler (xrpc.go:969):
|
||||
func (h *XRPCHandler) HandleGetBlob(w http.ResponseWriter, r *http.Request) {
|
||||
did := r.URL.Query().Get("did")
|
||||
cidOrDigest := r.URL.Query().Get("cid")
|
||||
repo := r.URL.Query().Get("repo") // NEW
|
||||
|
||||
// ... existing blob handling ...
|
||||
|
||||
// Increment stats if repo provided
|
||||
if repo != "" {
|
||||
go h.pds.IncrementPullCount(did, repo) // Async, non-blocking
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Stats record structure:**
|
||||
```
|
||||
Collection: io.atcr.hold.stats
|
||||
Rkey: base64(did:repository) // Deterministic, unique
|
||||
|
||||
{
|
||||
"$type": "io.atcr.hold.stats",
|
||||
"did": "did:plc:alice123",
|
||||
"repository": "myapp",
|
||||
"pullCount": 1542,
|
||||
"pushCount": 47,
|
||||
"lastPull": "2025-01-15T...",
|
||||
"lastPush": "2025-01-10T...",
|
||||
"createdAt": "2025-01-01T..."
|
||||
}
|
||||
```
|
||||
|
||||
**Hold-side implementation:**
|
||||
```go
|
||||
// New file: pkg/hold/pds/stats.go
|
||||
|
||||
func (p *HoldPDS) IncrementPullCount(ctx context.Context, did, repo string) error {
|
||||
rkey := statsRecordKey(did, repo)
|
||||
|
||||
// Get or create stats record
|
||||
stats, err := p.GetStatsRecord(ctx, rkey)
|
||||
if err != nil || stats == nil {
|
||||
stats = &atproto.StatsRecord{
|
||||
Type: atproto.StatsCollection,
|
||||
DID: did,
|
||||
Repository: repo,
|
||||
PullCount: 0,
|
||||
PushCount: 0,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Increment and update
|
||||
stats.PullCount++
|
||||
stats.LastPull = time.Now()
|
||||
|
||||
_, err = p.repomgr.UpdateRecord(ctx, p.uid, atproto.StatsCollection, rkey, stats)
|
||||
return err
|
||||
}
|
||||
```
|
||||
|
||||
**Query endpoint (new XRPC):**
|
||||
```
|
||||
GET /xrpc/io.atcr.hold.getStats?did={userDID}&repo={repository}
|
||||
→ Returns JSON: { pullCount, pushCount, lastPull, lastPush }
|
||||
|
||||
GET /xrpc/io.atcr.hold.listStats?did={userDID}
|
||||
→ Returns all stats for a user across all repos on this hold
|
||||
```
|
||||
|
||||
**AppView aggregation:**
|
||||
```go
|
||||
func (l *Loader) GetAggregatedStats(ctx context.Context, did, repo string) (*Stats, error) {
|
||||
// 1. Get all holds that have served this repo
|
||||
holdDIDs, _ := l.cache.GetHoldDIDsForRepo(did, repo)
|
||||
|
||||
// 2. Query each hold for stats
|
||||
var total Stats
|
||||
for _, holdDID := range holdDIDs {
|
||||
holdURL := resolveHoldDID(holdDID)
|
||||
stats, _ := queryHoldStats(ctx, holdURL, did, repo)
|
||||
total.PullCount += stats.PullCount
|
||||
total.PushCount += stats.PushCount
|
||||
}
|
||||
|
||||
return &total, nil
|
||||
}
|
||||
```
|
||||
|
||||
**Files to modify:**
|
||||
- `pkg/atproto/lexicon.go` - Add `StatsCollection` + `StatsRecord`
|
||||
- `pkg/hold/pds/stats.go` - New file for stats operations
|
||||
- `pkg/hold/pds/xrpc.go` - Add `repo` param to getBlob, add stats endpoints
|
||||
- `pkg/appview/storage/proxy_blob_store.go` - Pass repository to getBlob
|
||||
- `pkg/appview/cache/loader.go` - Aggregation logic
|
||||
|
||||
### Migration Path
|
||||
|
||||
**Phase 1: Add Valkey infrastructure**
|
||||
- Add Valkey client to AppView
|
||||
- Create store interfaces that abstract SQLite vs Valkey
|
||||
- Dual-write OAuth sessions to both
|
||||
|
||||
**Phase 2: Migrate sessions to Valkey**
|
||||
- OAuth sessions, UI sessions, auth requests, pending device auth
|
||||
- Remove SQLite session tables
|
||||
- Test: restart AppView, users get logged out (acceptable)
|
||||
|
||||
**Phase 3: Migrate devices to Valkey**
|
||||
- Move device store to Valkey
|
||||
- Same data structure, different backend
|
||||
- Consider device expiry policy
|
||||
|
||||
**Phase 4: Implement hold-side stats**
|
||||
- Add `io.atcr.hold.stats` collection to hold's embedded PDS
|
||||
- Hold increments stats on blob access
|
||||
- Add XRPC endpoint: `io.atcr.hold.getStats`
|
||||
|
||||
**Phase 5: AppView stats aggregation**
|
||||
- Track holdDids per repo in Valkey cache
|
||||
- Query holds for stats, aggregate
|
||||
- Cache aggregated stats with TTL
|
||||
|
||||
**Phase 6: Remove SQLite (optional)**
|
||||
- Keep SQLite as optional cache layer for UI queries
|
||||
- Or: Query PDS on demand with Valkey caching
|
||||
- Jetstream still useful for real-time updates
|
||||
|
||||
## Summary Table
|
||||
|
||||
| Category | Tables | % of Schema | Truly Persistent? |
|
||||
|----------|--------|-------------|-------------------|
|
||||
| Auth & Sessions + Metrics | 6 | 32% | Yes |
|
||||
| PDS Cache | 11 | 58% | No (rebuildable) |
|
||||
| Operational | 2 | 10% | No |
|
||||
|
||||
**~58% of the database is cached ATProto data that could be rebuilt from PDSes.**
|
||||
2
go.mod
2
go.mod
@@ -1,6 +1,6 @@
|
||||
module atcr.io
|
||||
|
||||
go 1.25.5
|
||||
go 1.25.4
|
||||
|
||||
require (
|
||||
github.com/aws/aws-sdk-go v1.55.5
|
||||
|
||||
21
lexicons/io/atcr/authFullApp.json
Normal file
21
lexicons/io/atcr/authFullApp.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"lexicon": 1,
|
||||
"id": "io.atcr.authFullApp",
|
||||
"defs": {
|
||||
"main": {
|
||||
"type": "permission-set",
|
||||
"title": "AT Container Registry",
|
||||
"title:langs": {},
|
||||
"detail": "Push and pull container images to the ATProto Container Registry. Includes creating and managing image manifests, tags, and repository settings.",
|
||||
"detail:langs": {},
|
||||
"permissions": [
|
||||
{
|
||||
"type": "permission",
|
||||
"resource": "repo",
|
||||
"action": ["create", "update", "delete"],
|
||||
"collection": ["io.atcr.manifest", "io.atcr.tag", "io.atcr.sailor.star", "io.atcr.sailor.profile", "io.atcr.repo.page"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -34,11 +34,13 @@
|
||||
},
|
||||
"region": {
|
||||
"type": "string",
|
||||
"description": "S3 region where blobs are stored"
|
||||
"description": "S3 region where blobs are stored",
|
||||
"maxLength": 64
|
||||
},
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"description": "Deployment provider (e.g., fly.io, aws, etc.)"
|
||||
"description": "Deployment provider (e.g., fly.io, aws, etc.)",
|
||||
"maxLength": 64
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,13 +18,15 @@
|
||||
"role": {
|
||||
"type": "string",
|
||||
"description": "Member's role in the hold",
|
||||
"knownValues": ["owner", "admin", "write", "read"]
|
||||
"knownValues": ["owner", "admin", "write", "read"],
|
||||
"maxLength": 32
|
||||
},
|
||||
"permissions": {
|
||||
"type": "array",
|
||||
"description": "Specific permissions granted to this member",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"type": "string",
|
||||
"maxLength": 64
|
||||
}
|
||||
},
|
||||
"addedAt": {
|
||||
|
||||
@@ -12,7 +12,8 @@
|
||||
"properties": {
|
||||
"digest": {
|
||||
"type": "string",
|
||||
"description": "Layer digest (e.g., sha256:abc123...)"
|
||||
"description": "Layer digest (e.g., sha256:abc123...)",
|
||||
"maxLength": 128
|
||||
},
|
||||
"size": {
|
||||
"type": "integer",
|
||||
@@ -20,11 +21,13 @@
|
||||
},
|
||||
"mediaType": {
|
||||
"type": "string",
|
||||
"description": "Media type (e.g., application/vnd.oci.image.layer.v1.tar+gzip)"
|
||||
"description": "Media type (e.g., application/vnd.oci.image.layer.v1.tar+gzip)",
|
||||
"maxLength": 128
|
||||
},
|
||||
"repository": {
|
||||
"type": "string",
|
||||
"description": "Repository this layer belongs to"
|
||||
"description": "Repository this layer belongs to",
|
||||
"maxLength": 255
|
||||
},
|
||||
"userDid": {
|
||||
"type": "string",
|
||||
|
||||
@@ -17,7 +17,8 @@
|
||||
},
|
||||
"digest": {
|
||||
"type": "string",
|
||||
"description": "Content digest (e.g., 'sha256:abc123...')"
|
||||
"description": "Content digest (e.g., 'sha256:abc123...')",
|
||||
"maxLength": 128
|
||||
},
|
||||
"holdDid": {
|
||||
"type": "string",
|
||||
@@ -37,7 +38,8 @@
|
||||
"application/vnd.docker.distribution.manifest.v2+json",
|
||||
"application/vnd.oci.image.index.v1+json",
|
||||
"application/vnd.docker.distribution.manifest.list.v2+json"
|
||||
]
|
||||
],
|
||||
"maxLength": 128
|
||||
},
|
||||
"schemaVersion": {
|
||||
"type": "integer",
|
||||
@@ -65,8 +67,8 @@
|
||||
"description": "Referenced manifests (for manifest lists/indexes)"
|
||||
},
|
||||
"annotations": {
|
||||
"type": "object",
|
||||
"description": "Optional metadata annotations"
|
||||
"type": "unknown",
|
||||
"description": "Optional OCI annotation metadata. Map of string keys to string values (e.g., org.opencontainers.image.title → 'My App')."
|
||||
},
|
||||
"subject": {
|
||||
"type": "ref",
|
||||
@@ -92,7 +94,8 @@
|
||||
"properties": {
|
||||
"mediaType": {
|
||||
"type": "string",
|
||||
"description": "MIME type of the blob"
|
||||
"description": "MIME type of the blob",
|
||||
"maxLength": 128
|
||||
},
|
||||
"size": {
|
||||
"type": "integer",
|
||||
@@ -100,7 +103,8 @@
|
||||
},
|
||||
"digest": {
|
||||
"type": "string",
|
||||
"description": "Content digest (e.g., 'sha256:...')"
|
||||
"description": "Content digest (e.g., 'sha256:...')",
|
||||
"maxLength": 128
|
||||
},
|
||||
"urls": {
|
||||
"type": "array",
|
||||
@@ -111,8 +115,8 @@
|
||||
"description": "Optional direct URLs to blob (for BYOS)"
|
||||
},
|
||||
"annotations": {
|
||||
"type": "object",
|
||||
"description": "Optional metadata"
|
||||
"type": "unknown",
|
||||
"description": "Optional OCI annotation metadata. Map of string keys to string values."
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -123,7 +127,8 @@
|
||||
"properties": {
|
||||
"mediaType": {
|
||||
"type": "string",
|
||||
"description": "Media type of the referenced manifest"
|
||||
"description": "Media type of the referenced manifest",
|
||||
"maxLength": 128
|
||||
},
|
||||
"size": {
|
||||
"type": "integer",
|
||||
@@ -131,7 +136,8 @@
|
||||
},
|
||||
"digest": {
|
||||
"type": "string",
|
||||
"description": "Content digest (e.g., 'sha256:...')"
|
||||
"description": "Content digest (e.g., 'sha256:...')",
|
||||
"maxLength": 128
|
||||
},
|
||||
"platform": {
|
||||
"type": "ref",
|
||||
@@ -139,8 +145,8 @@
|
||||
"description": "Platform information for this manifest"
|
||||
},
|
||||
"annotations": {
|
||||
"type": "object",
|
||||
"description": "Optional metadata"
|
||||
"type": "unknown",
|
||||
"description": "Optional OCI annotation metadata. Map of string keys to string values."
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -151,26 +157,31 @@
|
||||
"properties": {
|
||||
"architecture": {
|
||||
"type": "string",
|
||||
"description": "CPU architecture (e.g., 'amd64', 'arm64', 'arm')"
|
||||
"description": "CPU architecture (e.g., 'amd64', 'arm64', 'arm')",
|
||||
"maxLength": 32
|
||||
},
|
||||
"os": {
|
||||
"type": "string",
|
||||
"description": "Operating system (e.g., 'linux', 'windows', 'darwin')"
|
||||
"description": "Operating system (e.g., 'linux', 'windows', 'darwin')",
|
||||
"maxLength": 32
|
||||
},
|
||||
"osVersion": {
|
||||
"type": "string",
|
||||
"description": "Optional OS version"
|
||||
"description": "Optional OS version",
|
||||
"maxLength": 64
|
||||
},
|
||||
"osFeatures": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"type": "string",
|
||||
"maxLength": 64
|
||||
},
|
||||
"description": "Optional OS features"
|
||||
},
|
||||
"variant": {
|
||||
"type": "string",
|
||||
"description": "Optional CPU variant (e.g., 'v7' for ARM)"
|
||||
"description": "Optional CPU variant (e.g., 'v7' for ARM)",
|
||||
"maxLength": 32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
43
lexicons/io/atcr/repo/page.json
Normal file
43
lexicons/io/atcr/repo/page.json
Normal file
@@ -0,0 +1,43 @@
|
||||
{
|
||||
"lexicon": 1,
|
||||
"id": "io.atcr.repo.page",
|
||||
"defs": {
|
||||
"main": {
|
||||
"type": "record",
|
||||
"description": "Repository page metadata including description and avatar. Users can edit this directly in their PDS to customize their repository page.",
|
||||
"key": "any",
|
||||
"record": {
|
||||
"type": "object",
|
||||
"required": ["repository", "createdAt", "updatedAt"],
|
||||
"properties": {
|
||||
"repository": {
|
||||
"type": "string",
|
||||
"description": "The name of the repository (e.g., 'myapp'). Must match the rkey.",
|
||||
"maxLength": 256
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Markdown README/description content for the repository page.",
|
||||
"maxLength": 100000
|
||||
},
|
||||
"avatar": {
|
||||
"type": "blob",
|
||||
"description": "Repository avatar/icon image.",
|
||||
"accept": ["image/png", "image/jpeg", "image/webp"],
|
||||
"maxSize": 3000000
|
||||
},
|
||||
"createdAt": {
|
||||
"type": "string",
|
||||
"format": "datetime",
|
||||
"description": "Record creation timestamp"
|
||||
},
|
||||
"updatedAt": {
|
||||
"type": "string",
|
||||
"format": "datetime",
|
||||
"description": "Record last updated timestamp"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -27,7 +27,8 @@
|
||||
},
|
||||
"manifestDigest": {
|
||||
"type": "string",
|
||||
"description": "DEPRECATED: Digest of the manifest (e.g., 'sha256:...'). Kept for backward compatibility with old records. New records should use 'manifest' field instead."
|
||||
"description": "DEPRECATED: Digest of the manifest (e.g., 'sha256:...'). Kept for backward compatibility with old records. New records should use 'manifest' field instead.",
|
||||
"maxLength": 128
|
||||
},
|
||||
"createdAt": {
|
||||
"type": "string",
|
||||
|
||||
@@ -79,9 +79,6 @@ type HealthConfig struct {
|
||||
|
||||
// CheckInterval is the hold health check refresh interval (from env: ATCR_HEALTH_CHECK_INTERVAL, default: 15m)
|
||||
CheckInterval time.Duration `yaml:"check_interval"`
|
||||
|
||||
// ReadmeCacheTTL is the README cache TTL (from env: ATCR_README_CACHE_TTL, default: 1h)
|
||||
ReadmeCacheTTL time.Duration `yaml:"readme_cache_ttl"`
|
||||
}
|
||||
|
||||
// JetstreamConfig defines ATProto Jetstream settings
|
||||
@@ -165,7 +162,6 @@ func LoadConfigFromEnv() (*Config, error) {
|
||||
// Health and cache configuration
|
||||
cfg.Health.CacheTTL = getDurationOrDefault("ATCR_HEALTH_CACHE_TTL", 15*time.Minute)
|
||||
cfg.Health.CheckInterval = getDurationOrDefault("ATCR_HEALTH_CHECK_INTERVAL", 15*time.Minute)
|
||||
cfg.Health.ReadmeCacheTTL = getDurationOrDefault("ATCR_README_CACHE_TTL", 1*time.Hour)
|
||||
|
||||
// Jetstream configuration
|
||||
cfg.Jetstream.URL = getEnvOrDefault("JETSTREAM_URL", "wss://jetstream2.us-west.bsky.network/subscribe")
|
||||
|
||||
18
pkg/appview/db/migrations/0006_add_repo_pages.yaml
Normal file
18
pkg/appview/db/migrations/0006_add_repo_pages.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
description: Add repo_pages table and remove readme_cache
|
||||
query: |
|
||||
-- Create repo_pages table for storing repository page metadata
|
||||
-- This replaces readme_cache with PDS-synced data
|
||||
CREATE TABLE IF NOT EXISTS repo_pages (
|
||||
did TEXT NOT NULL,
|
||||
repository TEXT NOT NULL,
|
||||
description TEXT,
|
||||
avatar_cid TEXT,
|
||||
created_at TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP NOT NULL,
|
||||
PRIMARY KEY(did, repository),
|
||||
FOREIGN KEY(did) REFERENCES users(did) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_repo_pages_did ON repo_pages(did);
|
||||
|
||||
-- Drop readme_cache table (no longer needed)
|
||||
DROP TABLE IF EXISTS readme_cache;
|
||||
@@ -148,8 +148,9 @@ type PlatformInfo struct {
|
||||
// TagWithPlatforms extends Tag with platform information
|
||||
type TagWithPlatforms struct {
|
||||
Tag
|
||||
Platforms []PlatformInfo
|
||||
IsMultiArch bool
|
||||
Platforms []PlatformInfo
|
||||
IsMultiArch bool
|
||||
HasAttestations bool // true if manifest list contains attestation references
|
||||
}
|
||||
|
||||
// ManifestWithMetadata extends Manifest with tags and platform information
|
||||
|
||||
@@ -7,6 +7,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// BlobCDNURL returns the CDN URL for an ATProto blob
|
||||
// This is a local copy to avoid importing atproto (prevents circular dependencies)
|
||||
func BlobCDNURL(did, cid string) string {
|
||||
return fmt.Sprintf("https://imgs.blue/%s/%s", did, cid)
|
||||
}
|
||||
|
||||
// escapeLikePattern escapes SQL LIKE wildcards (%, _) and backslash for safe searching.
|
||||
// It also sanitizes the input to prevent injection attacks via special characters.
|
||||
func escapeLikePattern(s string) string {
|
||||
@@ -46,11 +52,13 @@ func GetRecentPushes(db *sql.DB, limit, offset int, userFilter string, currentUs
|
||||
COALESCE((SELECT COUNT(*) FROM stars WHERE owner_did = u.did AND repository = t.repository), 0),
|
||||
COALESCE((SELECT COUNT(*) FROM stars WHERE starrer_did = ? AND owner_did = u.did AND repository = t.repository), 0),
|
||||
t.created_at,
|
||||
m.hold_endpoint
|
||||
m.hold_endpoint,
|
||||
COALESCE(rp.avatar_cid, '')
|
||||
FROM tags t
|
||||
JOIN users u ON t.did = u.did
|
||||
JOIN manifests m ON t.did = m.did AND t.repository = m.repository AND t.digest = m.digest
|
||||
LEFT JOIN repository_stats rs ON t.did = rs.did AND t.repository = rs.repository
|
||||
LEFT JOIN repo_pages rp ON t.did = rp.did AND t.repository = rp.repository
|
||||
`
|
||||
|
||||
args := []any{currentUserDID}
|
||||
@@ -73,10 +81,15 @@ func GetRecentPushes(db *sql.DB, limit, offset int, userFilter string, currentUs
|
||||
for rows.Next() {
|
||||
var p Push
|
||||
var isStarredInt int
|
||||
if err := rows.Scan(&p.DID, &p.Handle, &p.Repository, &p.Tag, &p.Digest, &p.Title, &p.Description, &p.IconURL, &p.PullCount, &p.StarCount, &isStarredInt, &p.CreatedAt, &p.HoldEndpoint); err != nil {
|
||||
var avatarCID string
|
||||
if err := rows.Scan(&p.DID, &p.Handle, &p.Repository, &p.Tag, &p.Digest, &p.Title, &p.Description, &p.IconURL, &p.PullCount, &p.StarCount, &isStarredInt, &p.CreatedAt, &p.HoldEndpoint, &avatarCID); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
p.IsStarred = isStarredInt > 0
|
||||
// Prefer repo page avatar over annotation icon
|
||||
if avatarCID != "" {
|
||||
p.IconURL = BlobCDNURL(p.DID, avatarCID)
|
||||
}
|
||||
pushes = append(pushes, p)
|
||||
}
|
||||
|
||||
@@ -119,11 +132,13 @@ func SearchPushes(db *sql.DB, query string, limit, offset int, currentUserDID st
|
||||
COALESCE((SELECT COUNT(*) FROM stars WHERE owner_did = u.did AND repository = t.repository), 0),
|
||||
COALESCE((SELECT COUNT(*) FROM stars WHERE starrer_did = ? AND owner_did = u.did AND repository = t.repository), 0),
|
||||
t.created_at,
|
||||
m.hold_endpoint
|
||||
m.hold_endpoint,
|
||||
COALESCE(rp.avatar_cid, '')
|
||||
FROM tags t
|
||||
JOIN users u ON t.did = u.did
|
||||
JOIN manifests m ON t.did = m.did AND t.repository = m.repository AND t.digest = m.digest
|
||||
LEFT JOIN repository_stats rs ON t.did = rs.did AND t.repository = rs.repository
|
||||
LEFT JOIN repo_pages rp ON t.did = rp.did AND t.repository = rp.repository
|
||||
WHERE u.handle LIKE ? ESCAPE '\'
|
||||
OR u.did = ?
|
||||
OR t.repository LIKE ? ESCAPE '\'
|
||||
@@ -146,10 +161,15 @@ func SearchPushes(db *sql.DB, query string, limit, offset int, currentUserDID st
|
||||
for rows.Next() {
|
||||
var p Push
|
||||
var isStarredInt int
|
||||
if err := rows.Scan(&p.DID, &p.Handle, &p.Repository, &p.Tag, &p.Digest, &p.Title, &p.Description, &p.IconURL, &p.PullCount, &p.StarCount, &isStarredInt, &p.CreatedAt, &p.HoldEndpoint); err != nil {
|
||||
var avatarCID string
|
||||
if err := rows.Scan(&p.DID, &p.Handle, &p.Repository, &p.Tag, &p.Digest, &p.Title, &p.Description, &p.IconURL, &p.PullCount, &p.StarCount, &isStarredInt, &p.CreatedAt, &p.HoldEndpoint, &avatarCID); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
p.IsStarred = isStarredInt > 0
|
||||
// Prefer repo page avatar over annotation icon
|
||||
if avatarCID != "" {
|
||||
p.IconURL = BlobCDNURL(p.DID, avatarCID)
|
||||
}
|
||||
pushes = append(pushes, p)
|
||||
}
|
||||
|
||||
@@ -293,6 +313,12 @@ func GetUserRepositories(db *sql.DB, did string) ([]Repository, error) {
|
||||
r.IconURL = annotations["io.atcr.icon"]
|
||||
r.ReadmeURL = annotations["io.atcr.readme"]
|
||||
|
||||
// Check for repo page avatar (overrides annotation icon)
|
||||
repoPage, err := GetRepoPage(db, did, r.Name)
|
||||
if err == nil && repoPage != nil && repoPage.AvatarCID != "" {
|
||||
r.IconURL = BlobCDNURL(did, repoPage.AvatarCID)
|
||||
}
|
||||
|
||||
repos = append(repos, r)
|
||||
}
|
||||
|
||||
@@ -596,6 +622,7 @@ func DeleteTag(db *sql.DB, did, repository, tag string) error {
|
||||
// GetTagsWithPlatforms returns all tags for a repository with platform information
|
||||
// Only multi-arch tags (manifest lists) have platform info in manifest_references
|
||||
// Single-arch tags will have empty Platforms slice (platform is obvious for single-arch)
|
||||
// Attestation references (unknown/unknown platforms) are filtered out but tracked via HasAttestations
|
||||
func GetTagsWithPlatforms(db *sql.DB, did, repository string) ([]TagWithPlatforms, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
@@ -609,7 +636,8 @@ func GetTagsWithPlatforms(db *sql.DB, did, repository string) ([]TagWithPlatform
|
||||
COALESCE(mr.platform_os, '') as platform_os,
|
||||
COALESCE(mr.platform_architecture, '') as platform_architecture,
|
||||
COALESCE(mr.platform_variant, '') as platform_variant,
|
||||
COALESCE(mr.platform_os_version, '') as platform_os_version
|
||||
COALESCE(mr.platform_os_version, '') as platform_os_version,
|
||||
COALESCE(mr.is_attestation, 0) as is_attestation
|
||||
FROM tags t
|
||||
JOIN manifests m ON t.digest = m.digest AND t.did = m.did AND t.repository = m.repository
|
||||
LEFT JOIN manifest_references mr ON m.id = mr.manifest_id
|
||||
@@ -629,9 +657,10 @@ func GetTagsWithPlatforms(db *sql.DB, did, repository string) ([]TagWithPlatform
|
||||
for rows.Next() {
|
||||
var t Tag
|
||||
var mediaType, platformOS, platformArch, platformVariant, platformOSVersion string
|
||||
var isAttestation bool
|
||||
|
||||
if err := rows.Scan(&t.ID, &t.DID, &t.Repository, &t.Tag, &t.Digest, &t.CreatedAt,
|
||||
&mediaType, &platformOS, &platformArch, &platformVariant, &platformOSVersion); err != nil {
|
||||
&mediaType, &platformOS, &platformArch, &platformVariant, &platformOSVersion, &isAttestation); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -645,6 +674,13 @@ func GetTagsWithPlatforms(db *sql.DB, did, repository string) ([]TagWithPlatform
|
||||
tagOrder = append(tagOrder, tagKey)
|
||||
}
|
||||
|
||||
// Track if manifest list has attestations
|
||||
if isAttestation {
|
||||
tagMap[tagKey].HasAttestations = true
|
||||
// Skip attestation references in platform display
|
||||
continue
|
||||
}
|
||||
|
||||
// Add platform info if present (only for multi-arch manifest lists)
|
||||
if platformOS != "" || platformArch != "" {
|
||||
tagMap[tagKey].Platforms = append(tagMap[tagKey].Platforms, PlatformInfo{
|
||||
@@ -1598,31 +1634,6 @@ func parseTimestamp(s string) (time.Time, error) {
|
||||
return time.Time{}, fmt.Errorf("unable to parse timestamp: %s", s)
|
||||
}
|
||||
|
||||
// MetricsDB wraps a sql.DB and implements the metrics interface for middleware
|
||||
type MetricsDB struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewMetricsDB creates a new metrics database wrapper
|
||||
func NewMetricsDB(db *sql.DB) *MetricsDB {
|
||||
return &MetricsDB{db: db}
|
||||
}
|
||||
|
||||
// IncrementPullCount increments the pull count for a repository
|
||||
func (m *MetricsDB) IncrementPullCount(did, repository string) error {
|
||||
return IncrementPullCount(m.db, did, repository)
|
||||
}
|
||||
|
||||
// IncrementPushCount increments the push count for a repository
|
||||
func (m *MetricsDB) IncrementPushCount(did, repository string) error {
|
||||
return IncrementPushCount(m.db, did, repository)
|
||||
}
|
||||
|
||||
// GetLatestHoldDIDForRepo returns the hold DID from the most recent manifest for a repository
|
||||
func (m *MetricsDB) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
|
||||
return GetLatestHoldDIDForRepo(m.db, did, repository)
|
||||
}
|
||||
|
||||
// GetFeaturedRepositories fetches top repositories sorted by stars and pulls
|
||||
func GetFeaturedRepositories(db *sql.DB, limit int, currentUserDID string) ([]FeaturedRepository, error) {
|
||||
query := `
|
||||
@@ -1650,11 +1661,13 @@ func GetFeaturedRepositories(db *sql.DB, limit int, currentUserDID string) ([]Fe
|
||||
COALESCE((SELECT value FROM repository_annotations WHERE did = m.did AND repository = m.repository AND key = 'io.atcr.icon'), ''),
|
||||
rs.pull_count,
|
||||
rs.star_count,
|
||||
COALESCE((SELECT COUNT(*) FROM stars WHERE starrer_did = ? AND owner_did = m.did AND repository = m.repository), 0)
|
||||
COALESCE((SELECT COUNT(*) FROM stars WHERE starrer_did = ? AND owner_did = m.did AND repository = m.repository), 0),
|
||||
COALESCE(rp.avatar_cid, '')
|
||||
FROM latest_manifests lm
|
||||
JOIN manifests m ON lm.latest_id = m.id
|
||||
JOIN users u ON m.did = u.did
|
||||
JOIN repo_stats rs ON m.did = rs.did AND m.repository = rs.repository
|
||||
LEFT JOIN repo_pages rp ON m.did = rp.did AND m.repository = rp.repository
|
||||
ORDER BY rs.score DESC, rs.star_count DESC, rs.pull_count DESC, m.created_at DESC
|
||||
LIMIT ?
|
||||
`
|
||||
@@ -1669,15 +1682,88 @@ func GetFeaturedRepositories(db *sql.DB, limit int, currentUserDID string) ([]Fe
|
||||
for rows.Next() {
|
||||
var f FeaturedRepository
|
||||
var isStarredInt int
|
||||
var avatarCID string
|
||||
|
||||
if err := rows.Scan(&f.OwnerDID, &f.OwnerHandle, &f.Repository,
|
||||
&f.Title, &f.Description, &f.IconURL, &f.PullCount, &f.StarCount, &isStarredInt); err != nil {
|
||||
&f.Title, &f.Description, &f.IconURL, &f.PullCount, &f.StarCount, &isStarredInt, &avatarCID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f.IsStarred = isStarredInt > 0
|
||||
// Prefer repo page avatar over annotation icon
|
||||
if avatarCID != "" {
|
||||
f.IconURL = BlobCDNURL(f.OwnerDID, avatarCID)
|
||||
}
|
||||
|
||||
featured = append(featured, f)
|
||||
}
|
||||
|
||||
return featured, nil
|
||||
}
|
||||
|
||||
// RepoPage represents a repository page record cached from PDS
|
||||
type RepoPage struct {
|
||||
DID string
|
||||
Repository string
|
||||
Description string
|
||||
AvatarCID string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// UpsertRepoPage inserts or updates a repo page record
|
||||
func UpsertRepoPage(db *sql.DB, did, repository, description, avatarCID string, createdAt, updatedAt time.Time) error {
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO repo_pages (did, repository, description, avatar_cid, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(did, repository) DO UPDATE SET
|
||||
description = excluded.description,
|
||||
avatar_cid = excluded.avatar_cid,
|
||||
updated_at = excluded.updated_at
|
||||
`, did, repository, description, avatarCID, createdAt, updatedAt)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetRepoPage retrieves a repo page record
|
||||
func GetRepoPage(db *sql.DB, did, repository string) (*RepoPage, error) {
|
||||
var rp RepoPage
|
||||
err := db.QueryRow(`
|
||||
SELECT did, repository, description, avatar_cid, created_at, updated_at
|
||||
FROM repo_pages
|
||||
WHERE did = ? AND repository = ?
|
||||
`, did, repository).Scan(&rp.DID, &rp.Repository, &rp.Description, &rp.AvatarCID, &rp.CreatedAt, &rp.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rp, nil
|
||||
}
|
||||
|
||||
// DeleteRepoPage deletes a repo page record
|
||||
func DeleteRepoPage(db *sql.DB, did, repository string) error {
|
||||
_, err := db.Exec(`
|
||||
DELETE FROM repo_pages WHERE did = ? AND repository = ?
|
||||
`, did, repository)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetRepoPagesByDID returns all repo pages for a DID
|
||||
func GetRepoPagesByDID(db *sql.DB, did string) ([]RepoPage, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT did, repository, description, avatar_cid, created_at, updated_at
|
||||
FROM repo_pages
|
||||
WHERE did = ?
|
||||
`, did)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var pages []RepoPage
|
||||
for rows.Next() {
|
||||
var rp RepoPage
|
||||
if err := rows.Scan(&rp.DID, &rp.Repository, &rp.Description, &rp.AvatarCID, &rp.CreatedAt, &rp.UpdatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pages = append(pages, rp)
|
||||
}
|
||||
return pages, rows.Err()
|
||||
}
|
||||
|
||||
@@ -205,9 +205,14 @@ CREATE TABLE IF NOT EXISTS hold_crew_denials (
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_crew_denials_retry ON hold_crew_denials(next_retry_at);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS readme_cache (
|
||||
url TEXT PRIMARY KEY,
|
||||
html TEXT NOT NULL,
|
||||
fetched_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
CREATE TABLE IF NOT EXISTS repo_pages (
|
||||
did TEXT NOT NULL,
|
||||
repository TEXT NOT NULL,
|
||||
description TEXT,
|
||||
avatar_cid TEXT,
|
||||
created_at TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP NOT NULL,
|
||||
PRIMARY KEY(did, repository),
|
||||
FOREIGN KEY(did) REFERENCES users(did) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_readme_cache_fetched ON readme_cache(fetched_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_repo_pages_did ON repo_pages(did);
|
||||
|
||||
32
pkg/appview/handlers/errors.go
Normal file
32
pkg/appview/handlers/errors.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// NotFoundHandler handles 404 errors
|
||||
type NotFoundHandler struct {
|
||||
Templates *template.Template
|
||||
RegistryURL string
|
||||
}
|
||||
|
||||
func (h *NotFoundHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
RenderNotFound(w, r, h.Templates, h.RegistryURL)
|
||||
}
|
||||
|
||||
// RenderNotFound renders the 404 page template.
|
||||
// Use this from other handlers when a resource is not found.
|
||||
func RenderNotFound(w http.ResponseWriter, r *http.Request, templates *template.Template, registryURL string) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
|
||||
data := struct {
|
||||
PageData
|
||||
}{
|
||||
PageData: NewPageData(r, registryURL),
|
||||
}
|
||||
|
||||
if err := templates.ExecuteTemplate(w, "404", data); err != nil {
|
||||
http.Error(w, "Page not found", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
@@ -3,9 +3,12 @@ package handlers
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/appview/db"
|
||||
"atcr.io/pkg/appview/middleware"
|
||||
@@ -155,3 +158,114 @@ func (h *DeleteManifestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// UploadAvatarHandler handles uploading/updating a repository avatar
|
||||
type UploadAvatarHandler struct {
|
||||
DB *sql.DB
|
||||
Refresher *oauth.Refresher
|
||||
}
|
||||
|
||||
// validImageTypes are the allowed MIME types for avatars (matches lexicon)
|
||||
var validImageTypes = map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/webp": true,
|
||||
}
|
||||
|
||||
func (h *UploadAvatarHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
user := middleware.GetUser(r)
|
||||
if user == nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
repo := chi.URLParam(r, "repository")
|
||||
|
||||
// Parse multipart form (max 3MB to match lexicon maxSize)
|
||||
if err := r.ParseMultipartForm(3 << 20); err != nil {
|
||||
http.Error(w, "File too large (max 3MB)", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("avatar")
|
||||
if err != nil {
|
||||
http.Error(w, "No file provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Validate MIME type
|
||||
contentType := header.Header.Get("Content-Type")
|
||||
if !validImageTypes[contentType] {
|
||||
http.Error(w, "Invalid file type. Must be PNG, JPEG, or WebP", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Read file data
|
||||
data, err := io.ReadAll(io.LimitReader(file, 3<<20+1)) // Read up to 3MB + 1 byte
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read file", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if len(data) > 3<<20 {
|
||||
http.Error(w, "File too large (max 3MB)", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Create ATProto client with session provider (uses DoWithSession for DPoP nonce safety)
|
||||
pdsClient := atproto.NewClientWithSessionProvider(user.PDSEndpoint, user.DID, h.Refresher)
|
||||
|
||||
// Upload blob to PDS
|
||||
blobRef, err := pdsClient.UploadBlob(r.Context(), data, contentType)
|
||||
if err != nil {
|
||||
if handleOAuthError(r.Context(), h.Refresher, user.DID, err) {
|
||||
http.Error(w, "Authentication failed, please log in again", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
http.Error(w, fmt.Sprintf("Failed to upload image: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch existing repo page record to preserve description
|
||||
var existingDescription string
|
||||
var existingCreatedAt time.Time
|
||||
record, err := pdsClient.GetRecord(r.Context(), atproto.RepoPageCollection, repo)
|
||||
if err == nil {
|
||||
// Parse existing record to preserve description
|
||||
var existingRecord atproto.RepoPageRecord
|
||||
if jsonErr := json.Unmarshal(record.Value, &existingRecord); jsonErr == nil {
|
||||
existingDescription = existingRecord.Description
|
||||
existingCreatedAt = existingRecord.CreatedAt
|
||||
}
|
||||
} else if !errors.Is(err, atproto.ErrRecordNotFound) {
|
||||
// Some other error - check if OAuth error
|
||||
if handleOAuthError(r.Context(), h.Refresher, user.DID, err) {
|
||||
http.Error(w, "Authentication failed, please log in again", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
// Log but continue - we'll create a new record
|
||||
}
|
||||
|
||||
// Create updated repo page record
|
||||
repoPage := atproto.NewRepoPageRecord(repo, existingDescription, blobRef)
|
||||
// Preserve original createdAt if record existed
|
||||
if !existingCreatedAt.IsZero() {
|
||||
repoPage.CreatedAt = existingCreatedAt
|
||||
}
|
||||
|
||||
// Save record to PDS
|
||||
_, err = pdsClient.PutRecord(r.Context(), atproto.RepoPageCollection, repo, repoPage)
|
||||
if err != nil {
|
||||
if handleOAuthError(r.Context(), h.Refresher, user.DID, err) {
|
||||
http.Error(w, "Authentication failed, please log in again", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
http.Error(w, fmt.Sprintf("Failed to update repository page: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Return new avatar URL
|
||||
avatarURL := atproto.BlobCDNURL(user.DID, blobRef.Ref.Link)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"avatarURL": avatarURL})
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ type RepositoryPageHandler struct {
|
||||
Directory identity.Directory
|
||||
Refresher *oauth.Refresher
|
||||
HealthChecker *holdhealth.Checker
|
||||
ReadmeCache *readme.Cache
|
||||
ReadmeFetcher *readme.Fetcher // For rendering repo page descriptions
|
||||
}
|
||||
|
||||
func (h *RepositoryPageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -37,7 +37,7 @@ func (h *RepositoryPageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
|
||||
// Resolve identifier (handle or DID) to canonical DID and current handle
|
||||
did, resolvedHandle, _, err := atproto.ResolveIdentity(r.Context(), identifier)
|
||||
if err != nil {
|
||||
http.Error(w, "User not found", http.StatusNotFound)
|
||||
RenderNotFound(w, r, h.Templates, h.RegistryURL)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func (h *RepositoryPageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
if owner == nil {
|
||||
http.Error(w, "User not found", http.StatusNotFound)
|
||||
RenderNotFound(w, r, h.Templates, h.RegistryURL)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ func (h *RepositoryPageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
if len(tagsWithPlatforms) == 0 && len(manifests) == 0 {
|
||||
http.Error(w, "Repository not found", http.StatusNotFound)
|
||||
RenderNotFound(w, r, h.Templates, h.RegistryURL)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -190,19 +190,44 @@ func (h *RepositoryPageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
|
||||
isOwner = (user.DID == owner.DID)
|
||||
}
|
||||
|
||||
// Fetch README content if available
|
||||
// Fetch README content from repo page record or annotations
|
||||
var readmeHTML template.HTML
|
||||
if repo.ReadmeURL != "" && h.ReadmeCache != nil {
|
||||
// Fetch with timeout
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
html, err := h.ReadmeCache.Get(ctx, repo.ReadmeURL)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to fetch README", "url", repo.ReadmeURL, "error", err)
|
||||
// Continue without README on error
|
||||
} else {
|
||||
readmeHTML = template.HTML(html)
|
||||
// Try repo page record from database (synced from PDS via Jetstream)
|
||||
repoPage, err := db.GetRepoPage(h.DB, owner.DID, repository)
|
||||
if err == nil && repoPage != nil {
|
||||
// Use repo page avatar if present
|
||||
if repoPage.AvatarCID != "" {
|
||||
repo.IconURL = atproto.BlobCDNURL(owner.DID, repoPage.AvatarCID)
|
||||
}
|
||||
// Render description as markdown if present
|
||||
if repoPage.Description != "" && h.ReadmeFetcher != nil {
|
||||
html, err := h.ReadmeFetcher.RenderMarkdown([]byte(repoPage.Description))
|
||||
if err != nil {
|
||||
slog.Warn("Failed to render repo page description", "error", err)
|
||||
} else {
|
||||
readmeHTML = template.HTML(html)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fall back to fetching README from URL annotations if no description in repo page
|
||||
if readmeHTML == "" && h.ReadmeFetcher != nil {
|
||||
// Fall back to fetching from URL annotations
|
||||
readmeURL := repo.ReadmeURL
|
||||
if readmeURL == "" && repo.SourceURL != "" {
|
||||
// Try to derive README URL from source URL
|
||||
readmeURL = readme.DeriveReadmeURL(repo.SourceURL, "main")
|
||||
if readmeURL == "" {
|
||||
readmeURL = readme.DeriveReadmeURL(repo.SourceURL, "master")
|
||||
}
|
||||
}
|
||||
if readmeURL != "" {
|
||||
html, err := h.ReadmeFetcher.FetchAndRender(r.Context(), readmeURL)
|
||||
if err != nil {
|
||||
slog.Debug("Failed to fetch README from URL", "url", readmeURL, "error", err)
|
||||
} else {
|
||||
readmeHTML = template.HTML(html)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -62,9 +62,7 @@ func (h *SettingsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
data.Profile.Handle = user.Handle
|
||||
data.Profile.DID = user.DID
|
||||
data.Profile.PDSEndpoint = user.PDSEndpoint
|
||||
if profile.DefaultHold != nil {
|
||||
data.Profile.DefaultHold = *profile.DefaultHold
|
||||
}
|
||||
data.Profile.DefaultHold = profile.DefaultHold
|
||||
|
||||
if err := h.Templates.ExecuteTemplate(w, "settings", data); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
@@ -96,9 +94,8 @@ func (h *UpdateDefaultHoldHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ
|
||||
profile = atproto.NewSailorProfileRecord(holdEndpoint)
|
||||
} else {
|
||||
// Update existing profile
|
||||
profile.DefaultHold = &holdEndpoint
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
profile.UpdatedAt = &now
|
||||
profile.DefaultHold = holdEndpoint
|
||||
profile.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// Save profile
|
||||
|
||||
@@ -23,7 +23,7 @@ func (h *UserPageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Resolve identifier (handle or DID) to canonical DID and current handle
|
||||
did, resolvedHandle, pdsEndpoint, err := atproto.ResolveIdentity(r.Context(), identifier)
|
||||
if err != nil {
|
||||
http.Error(w, "User not found", http.StatusNotFound)
|
||||
RenderNotFound(w, r, h.Templates, h.RegistryURL)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -5,21 +5,26 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/appview/db"
|
||||
"atcr.io/pkg/appview/readme"
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
)
|
||||
|
||||
// BackfillWorker uses com.atproto.sync.listReposByCollection to backfill historical data
|
||||
type BackfillWorker struct {
|
||||
db *sql.DB
|
||||
client *atproto.Client
|
||||
processor *Processor // Shared processor for DB operations
|
||||
defaultHoldDID string // Default hold DID from AppView config (e.g., "did:web:hold01.atcr.io")
|
||||
testMode bool // If true, suppress warnings for external holds
|
||||
processor *Processor // Shared processor for DB operations
|
||||
defaultHoldDID string // Default hold DID from AppView config (e.g., "did:web:hold01.atcr.io")
|
||||
testMode bool // If true, suppress warnings for external holds
|
||||
refresher *oauth.Refresher // OAuth refresher for PDS writes (optional, can be nil)
|
||||
}
|
||||
|
||||
// BackfillState tracks backfill progress
|
||||
@@ -36,7 +41,8 @@ type BackfillState struct {
|
||||
// NewBackfillWorker creates a backfill worker using sync API
|
||||
// defaultHoldDID should be in format "did:web:hold01.atcr.io"
|
||||
// To find a hold's DID, visit: https://hold-url/.well-known/did.json
|
||||
func NewBackfillWorker(database *sql.DB, relayEndpoint, defaultHoldDID string, testMode bool) (*BackfillWorker, error) {
|
||||
// refresher is optional - if provided, backfill will try to update PDS records when fetching README content
|
||||
func NewBackfillWorker(database *sql.DB, relayEndpoint, defaultHoldDID string, testMode bool, refresher *oauth.Refresher) (*BackfillWorker, error) {
|
||||
// Create client for relay - used only for listReposByCollection
|
||||
client := atproto.NewClient(relayEndpoint, "", "")
|
||||
|
||||
@@ -46,6 +52,7 @@ func NewBackfillWorker(database *sql.DB, relayEndpoint, defaultHoldDID string, t
|
||||
processor: NewProcessor(database, false), // No cache for batch processing
|
||||
defaultHoldDID: defaultHoldDID,
|
||||
testMode: testMode,
|
||||
refresher: refresher,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -67,6 +74,7 @@ func (b *BackfillWorker) Start(ctx context.Context) error {
|
||||
atproto.TagCollection, // io.atcr.tag
|
||||
atproto.StarCollection, // io.atcr.sailor.star
|
||||
atproto.SailorProfileCollection, // io.atcr.sailor.profile
|
||||
atproto.RepoPageCollection, // io.atcr.repo.page
|
||||
}
|
||||
|
||||
for _, collection := range collections {
|
||||
@@ -164,12 +172,12 @@ func (b *BackfillWorker) backfillRepo(ctx context.Context, did, collection strin
|
||||
// Track what we found for deletion reconciliation
|
||||
switch collection {
|
||||
case atproto.ManifestCollection:
|
||||
var manifestRecord atproto.Manifest
|
||||
var manifestRecord atproto.ManifestRecord
|
||||
if err := json.Unmarshal(record.Value, &manifestRecord); err == nil {
|
||||
foundManifestDigests = append(foundManifestDigests, manifestRecord.Digest)
|
||||
}
|
||||
case atproto.TagCollection:
|
||||
var tagRecord atproto.Tag
|
||||
var tagRecord atproto.TagRecord
|
||||
if err := json.Unmarshal(record.Value, &tagRecord); err == nil {
|
||||
foundTags = append(foundTags, struct{ Repository, Tag string }{
|
||||
Repository: tagRecord.Repository,
|
||||
@@ -177,15 +185,10 @@ func (b *BackfillWorker) backfillRepo(ctx context.Context, did, collection strin
|
||||
})
|
||||
}
|
||||
case atproto.StarCollection:
|
||||
var starRecord atproto.SailorStar
|
||||
var starRecord atproto.StarRecord
|
||||
if err := json.Unmarshal(record.Value, &starRecord); err == nil {
|
||||
key := fmt.Sprintf("%s/%s", starRecord.Subject.Did, starRecord.Subject.Repository)
|
||||
// Parse CreatedAt string to time.Time
|
||||
createdAt, parseErr := time.Parse(time.RFC3339, starRecord.CreatedAt)
|
||||
if parseErr != nil {
|
||||
createdAt = time.Now()
|
||||
}
|
||||
foundStars[key] = createdAt
|
||||
key := fmt.Sprintf("%s/%s", starRecord.Subject.DID, starRecord.Subject.Repository)
|
||||
foundStars[key] = starRecord.CreatedAt
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,6 +225,13 @@ func (b *BackfillWorker) backfillRepo(ctx context.Context, did, collection strin
|
||||
}
|
||||
}
|
||||
|
||||
// After processing repo pages, fetch descriptions from external sources if empty
|
||||
if collection == atproto.RepoPageCollection {
|
||||
if err := b.reconcileRepoPageDescriptions(ctx, did, pdsEndpoint); err != nil {
|
||||
slog.Warn("Backfill failed to reconcile repo page descriptions", "did", did, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return recordCount, nil
|
||||
}
|
||||
|
||||
@@ -287,6 +297,9 @@ func (b *BackfillWorker) processRecord(ctx context.Context, did, collection stri
|
||||
return b.processor.ProcessStar(context.Background(), did, record.Value)
|
||||
case atproto.SailorProfileCollection:
|
||||
return b.processor.ProcessSailorProfile(ctx, did, record.Value, b.queryCaptainRecordWrapper)
|
||||
case atproto.RepoPageCollection:
|
||||
// rkey is extracted from the record URI, but for repo pages we use Repository field
|
||||
return b.processor.ProcessRepoPage(ctx, did, record.URI, record.Value, false)
|
||||
default:
|
||||
return fmt.Errorf("unsupported collection: %s", collection)
|
||||
}
|
||||
@@ -364,12 +377,240 @@ func (b *BackfillWorker) queryCaptainRecord(ctx context.Context, holdDID string)
|
||||
|
||||
// reconcileAnnotations ensures annotations come from the newest manifest in each repository
|
||||
// This fixes the out-of-order backfill issue where older manifests can overwrite newer annotations
|
||||
// NOTE: Currently disabled because the generated Manifest_Annotations type doesn't support
|
||||
// arbitrary key-value pairs. Would need to update lexicon schema with "unknown" type.
|
||||
func (b *BackfillWorker) reconcileAnnotations(ctx context.Context, did string, pdsClient *atproto.Client) error {
|
||||
// TODO: Re-enable once lexicon supports annotations as map[string]string
|
||||
// For now, skip annotation reconciliation as the generated type is an empty struct
|
||||
_ = did
|
||||
_ = pdsClient
|
||||
// Get all repositories for this DID
|
||||
repositories, err := db.GetRepositoriesForDID(b.db, did)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get repositories: %w", err)
|
||||
}
|
||||
|
||||
for _, repo := range repositories {
|
||||
// Find newest manifest for this repository
|
||||
newestManifest, err := db.GetNewestManifestForRepo(b.db, did, repo)
|
||||
if err != nil {
|
||||
slog.Warn("Backfill failed to get newest manifest for repo", "did", did, "repository", repo, "error", err)
|
||||
continue // Skip on error
|
||||
}
|
||||
|
||||
// Fetch the full manifest record from PDS using the digest as rkey
|
||||
rkey := strings.TrimPrefix(newestManifest.Digest, "sha256:")
|
||||
record, err := pdsClient.GetRecord(ctx, atproto.ManifestCollection, rkey)
|
||||
if err != nil {
|
||||
slog.Warn("Backfill failed to fetch manifest record for repo", "did", did, "repository", repo, "error", err)
|
||||
continue // Skip on error
|
||||
}
|
||||
|
||||
// Parse manifest record
|
||||
var manifestRecord atproto.ManifestRecord
|
||||
if err := json.Unmarshal(record.Value, &manifestRecord); err != nil {
|
||||
slog.Warn("Backfill failed to parse manifest record for repo", "did", did, "repository", repo, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Update annotations from newest manifest only
|
||||
if len(manifestRecord.Annotations) > 0 {
|
||||
// Filter out empty annotations
|
||||
hasData := false
|
||||
for _, value := range manifestRecord.Annotations {
|
||||
if value != "" {
|
||||
hasData = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasData {
|
||||
err = db.UpsertRepositoryAnnotations(b.db, did, repo, manifestRecord.Annotations)
|
||||
if err != nil {
|
||||
slog.Warn("Backfill failed to reconcile annotations for repo", "did", did, "repository", repo, "error", err)
|
||||
} else {
|
||||
slog.Info("Backfill reconciled annotations for repo from newest manifest", "did", did, "repository", repo, "digest", newestManifest.Digest)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reconcileRepoPageDescriptions fetches README content from external sources for repo pages with empty descriptions
|
||||
// If the user has an OAuth session, it updates the PDS record (source of truth)
|
||||
// Otherwise, it just stores the fetched content in the database
|
||||
func (b *BackfillWorker) reconcileRepoPageDescriptions(ctx context.Context, did, pdsEndpoint string) error {
|
||||
// Get all repo pages for this DID
|
||||
repoPages, err := db.GetRepoPagesByDID(b.db, did)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get repo pages: %w", err)
|
||||
}
|
||||
|
||||
for _, page := range repoPages {
|
||||
// Skip pages that already have a description
|
||||
if page.Description != "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get annotations from the repository's manifest
|
||||
annotations, err := db.GetRepositoryAnnotations(b.db, did, page.Repository)
|
||||
if err != nil {
|
||||
slog.Debug("Failed to get annotations for repo page", "did", did, "repository", page.Repository, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Try to fetch README content from external sources
|
||||
description := b.fetchReadmeContent(ctx, annotations)
|
||||
if description == "" {
|
||||
// No README content available, skip
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Info("Fetched README for repo page", "did", did, "repository", page.Repository, "descriptionLength", len(description))
|
||||
|
||||
// Try to update PDS if we have OAuth session
|
||||
pdsUpdated := false
|
||||
if b.refresher != nil {
|
||||
if err := b.updateRepoPageInPDS(ctx, did, pdsEndpoint, page.Repository, description, page.AvatarCID); err != nil {
|
||||
slog.Debug("Could not update repo page in PDS, falling back to DB-only", "did", did, "repository", page.Repository, "error", err)
|
||||
} else {
|
||||
pdsUpdated = true
|
||||
slog.Info("Updated repo page in PDS with fetched description", "did", did, "repository", page.Repository)
|
||||
}
|
||||
}
|
||||
|
||||
// Always update database with the fetched content
|
||||
if err := db.UpsertRepoPage(b.db, did, page.Repository, description, page.AvatarCID, page.CreatedAt, time.Now()); err != nil {
|
||||
slog.Warn("Failed to update repo page in database", "did", did, "repository", page.Repository, "error", err)
|
||||
} else if !pdsUpdated {
|
||||
slog.Info("Updated repo page in database (PDS not updated)", "did", did, "repository", page.Repository)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchReadmeContent attempts to fetch README content from external sources based on annotations
|
||||
// Priority: io.atcr.readme annotation > derived from org.opencontainers.image.source
|
||||
func (b *BackfillWorker) fetchReadmeContent(ctx context.Context, annotations map[string]string) string {
|
||||
// Create a context with timeout for README fetching
|
||||
fetchCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Priority 1: Direct README URL from io.atcr.readme annotation
|
||||
if readmeURL := annotations["io.atcr.readme"]; readmeURL != "" {
|
||||
content, err := b.fetchRawReadme(fetchCtx, readmeURL)
|
||||
if err != nil {
|
||||
slog.Debug("Failed to fetch README from io.atcr.readme annotation", "url", readmeURL, "error", err)
|
||||
} else if content != "" {
|
||||
return content
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 2: Derive README URL from org.opencontainers.image.source
|
||||
if sourceURL := annotations["org.opencontainers.image.source"]; sourceURL != "" {
|
||||
// Try main branch first, then master
|
||||
for _, branch := range []string{"main", "master"} {
|
||||
readmeURL := readme.DeriveReadmeURL(sourceURL, branch)
|
||||
if readmeURL == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
content, err := b.fetchRawReadme(fetchCtx, readmeURL)
|
||||
if err != nil {
|
||||
// Only log non-404 errors (404 is expected when trying main vs master)
|
||||
if !readme.Is404(err) {
|
||||
slog.Debug("Failed to fetch README from source URL", "url", readmeURL, "branch", branch, "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
return content
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// fetchRawReadme fetches raw markdown content from a URL
|
||||
func (b *BackfillWorker) fetchRawReadme(ctx context.Context, readmeURL string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", readmeURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "ATCR-Backfill-README-Fetcher/1.0")
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 5 {
|
||||
return fmt.Errorf("too many redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch URL: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Limit content size to 100KB
|
||||
limitedReader := io.LimitReader(resp.Body, 100*1024)
|
||||
content, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
return string(content), nil
|
||||
}
|
||||
|
||||
// updateRepoPageInPDS updates the repo page record in the user's PDS using OAuth
|
||||
func (b *BackfillWorker) updateRepoPageInPDS(ctx context.Context, did, pdsEndpoint, repository, description, avatarCID string) error {
|
||||
if b.refresher == nil {
|
||||
return fmt.Errorf("no OAuth refresher available")
|
||||
}
|
||||
|
||||
// Create ATProto client with session provider
|
||||
pdsClient := atproto.NewClientWithSessionProvider(pdsEndpoint, did, b.refresher)
|
||||
|
||||
// Get existing repo page record to preserve other fields
|
||||
existingRecord, err := pdsClient.GetRecord(ctx, atproto.RepoPageCollection, repository)
|
||||
var createdAt time.Time
|
||||
var avatarRef *atproto.ATProtoBlobRef
|
||||
|
||||
if err == nil && existingRecord != nil {
|
||||
// Parse existing record
|
||||
var existingPage atproto.RepoPageRecord
|
||||
if err := json.Unmarshal(existingRecord.Value, &existingPage); err == nil {
|
||||
createdAt = existingPage.CreatedAt
|
||||
avatarRef = existingPage.Avatar
|
||||
}
|
||||
}
|
||||
|
||||
if createdAt.IsZero() {
|
||||
createdAt = time.Now()
|
||||
}
|
||||
|
||||
// Create updated repo page record
|
||||
repoPage := &atproto.RepoPageRecord{
|
||||
Type: atproto.RepoPageCollection,
|
||||
Repository: repository,
|
||||
Description: description,
|
||||
Avatar: avatarRef,
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Write to PDS - this will use DoWithSession internally
|
||||
_, err = pdsClient.PutRecord(ctx, atproto.RepoPageCollection, repository, repoPage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write to PDS: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ func (p *Processor) EnsureUser(ctx context.Context, did string) error {
|
||||
// Returns the manifest ID for further processing (layers/references)
|
||||
func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData []byte) (int64, error) {
|
||||
// Unmarshal manifest record
|
||||
var manifestRecord atproto.Manifest
|
||||
var manifestRecord atproto.ManifestRecord
|
||||
if err := json.Unmarshal(recordData, &manifestRecord); err != nil {
|
||||
return 0, fmt.Errorf("failed to unmarshal manifest: %w", err)
|
||||
}
|
||||
@@ -110,19 +110,10 @@ func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData
|
||||
// Extract hold DID from manifest (with fallback for legacy manifests)
|
||||
// New manifests use holdDid field (DID format)
|
||||
// Old manifests use holdEndpoint field (URL format) - convert to DID
|
||||
var holdDID string
|
||||
if manifestRecord.HoldDid != nil && *manifestRecord.HoldDid != "" {
|
||||
holdDID = *manifestRecord.HoldDid
|
||||
} else if manifestRecord.HoldEndpoint != nil && *manifestRecord.HoldEndpoint != "" {
|
||||
holdDID := manifestRecord.HoldDID
|
||||
if holdDID == "" && manifestRecord.HoldEndpoint != "" {
|
||||
// Legacy manifest - convert URL to DID
|
||||
holdDID = atproto.ResolveHoldDIDFromURL(*manifestRecord.HoldEndpoint)
|
||||
}
|
||||
|
||||
// Parse CreatedAt string to time.Time
|
||||
createdAt, err := time.Parse(time.RFC3339, manifestRecord.CreatedAt)
|
||||
if err != nil {
|
||||
// Fall back to current time if parsing fails
|
||||
createdAt = time.Now()
|
||||
holdDID = atproto.ResolveHoldDIDFromURL(manifestRecord.HoldEndpoint)
|
||||
}
|
||||
|
||||
// Prepare manifest for insertion (WITHOUT annotation fields)
|
||||
@@ -131,9 +122,9 @@ func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData
|
||||
Repository: manifestRecord.Repository,
|
||||
Digest: manifestRecord.Digest,
|
||||
MediaType: manifestRecord.MediaType,
|
||||
SchemaVersion: int(manifestRecord.SchemaVersion),
|
||||
SchemaVersion: manifestRecord.SchemaVersion,
|
||||
HoldEndpoint: holdDID,
|
||||
CreatedAt: createdAt,
|
||||
CreatedAt: manifestRecord.CreatedAt,
|
||||
// Annotations removed - stored separately in repository_annotations table
|
||||
}
|
||||
|
||||
@@ -163,11 +154,24 @@ func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Repository annotations are currently disabled because the generated
|
||||
// Manifest_Annotations type doesn't support arbitrary key-value pairs.
|
||||
// The lexicon would need to use "unknown" type for annotations to support this.
|
||||
// TODO: Re-enable once lexicon supports annotations as map[string]string
|
||||
_ = manifestRecord.Annotations
|
||||
// Update repository annotations ONLY if manifest has at least one non-empty annotation
|
||||
if manifestRecord.Annotations != nil {
|
||||
hasData := false
|
||||
for _, value := range manifestRecord.Annotations {
|
||||
if value != "" {
|
||||
hasData = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasData {
|
||||
// Replace all annotations for this repository
|
||||
err = db.UpsertRepositoryAnnotations(p.db, did, manifestRecord.Repository, manifestRecord.Annotations)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to upsert annotations: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Insert manifest references or layers
|
||||
if isManifestList {
|
||||
@@ -180,19 +184,18 @@ func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData
|
||||
|
||||
if ref.Platform != nil {
|
||||
platformArch = ref.Platform.Architecture
|
||||
platformOS = ref.Platform.Os
|
||||
if ref.Platform.Variant != nil {
|
||||
platformVariant = *ref.Platform.Variant
|
||||
}
|
||||
if ref.Platform.OsVersion != nil {
|
||||
platformOSVersion = *ref.Platform.OsVersion
|
||||
}
|
||||
platformOS = ref.Platform.OS
|
||||
platformVariant = ref.Platform.Variant
|
||||
platformOSVersion = ref.Platform.OSVersion
|
||||
}
|
||||
|
||||
// Note: Attestation detection via annotations is currently disabled
|
||||
// because the generated Manifest_ManifestReference_Annotations type
|
||||
// doesn't support arbitrary key-value pairs.
|
||||
// Detect attestation manifests from annotations
|
||||
isAttestation := false
|
||||
if ref.Annotations != nil {
|
||||
if refType, ok := ref.Annotations["vnd.docker.reference.type"]; ok {
|
||||
isAttestation = refType == "attestation-manifest"
|
||||
}
|
||||
}
|
||||
|
||||
if err := db.InsertManifestReference(p.db, &db.ManifestReference{
|
||||
ManifestID: manifestID,
|
||||
@@ -232,7 +235,7 @@ func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData
|
||||
// ProcessTag processes a tag record and stores it in the database
|
||||
func (p *Processor) ProcessTag(ctx context.Context, did string, recordData []byte) error {
|
||||
// Unmarshal tag record
|
||||
var tagRecord atproto.Tag
|
||||
var tagRecord atproto.TagRecord
|
||||
if err := json.Unmarshal(recordData, &tagRecord); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal tag: %w", err)
|
||||
}
|
||||
@@ -242,27 +245,20 @@ func (p *Processor) ProcessTag(ctx context.Context, did string, recordData []byt
|
||||
return fmt.Errorf("failed to get manifest digest from tag record: %w", err)
|
||||
}
|
||||
|
||||
// Parse CreatedAt string to time.Time
|
||||
tagCreatedAt, err := time.Parse(time.RFC3339, tagRecord.CreatedAt)
|
||||
if err != nil {
|
||||
// Fall back to current time if parsing fails
|
||||
tagCreatedAt = time.Now()
|
||||
}
|
||||
|
||||
// Insert or update tag
|
||||
return db.UpsertTag(p.db, &db.Tag{
|
||||
DID: did,
|
||||
Repository: tagRecord.Repository,
|
||||
Tag: tagRecord.Tag,
|
||||
Digest: manifestDigest,
|
||||
CreatedAt: tagCreatedAt,
|
||||
CreatedAt: tagRecord.UpdatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
// ProcessStar processes a star record and stores it in the database
|
||||
func (p *Processor) ProcessStar(ctx context.Context, did string, recordData []byte) error {
|
||||
// Unmarshal star record
|
||||
var starRecord atproto.SailorStar
|
||||
var starRecord atproto.StarRecord
|
||||
if err := json.Unmarshal(recordData, &starRecord); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal star: %w", err)
|
||||
}
|
||||
@@ -270,33 +266,27 @@ func (p *Processor) ProcessStar(ctx context.Context, did string, recordData []by
|
||||
// The DID here is the starrer (user who starred)
|
||||
// The subject contains the owner DID and repository
|
||||
// Star count will be calculated on demand from the stars table
|
||||
// Parse the CreatedAt string to time.Time
|
||||
createdAt, err := time.Parse(time.RFC3339, starRecord.CreatedAt)
|
||||
if err != nil {
|
||||
// Fall back to current time if parsing fails
|
||||
createdAt = time.Now()
|
||||
}
|
||||
return db.UpsertStar(p.db, did, starRecord.Subject.Did, starRecord.Subject.Repository, createdAt)
|
||||
return db.UpsertStar(p.db, did, starRecord.Subject.DID, starRecord.Subject.Repository, starRecord.CreatedAt)
|
||||
}
|
||||
|
||||
// ProcessSailorProfile processes a sailor profile record
|
||||
// This is primarily used by backfill to cache captain records for holds
|
||||
func (p *Processor) ProcessSailorProfile(ctx context.Context, did string, recordData []byte, queryCaptainFn func(context.Context, string) error) error {
|
||||
// Unmarshal sailor profile record
|
||||
var profileRecord atproto.SailorProfile
|
||||
var profileRecord atproto.SailorProfileRecord
|
||||
if err := json.Unmarshal(recordData, &profileRecord); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal sailor profile: %w", err)
|
||||
}
|
||||
|
||||
// Skip if no default hold set
|
||||
if profileRecord.DefaultHold == nil || *profileRecord.DefaultHold == "" {
|
||||
if profileRecord.DefaultHold == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert hold URL/DID to canonical DID
|
||||
holdDID := atproto.ResolveHoldDIDFromURL(*profileRecord.DefaultHold)
|
||||
holdDID := atproto.ResolveHoldDIDFromURL(profileRecord.DefaultHold)
|
||||
if holdDID == "" {
|
||||
slog.Warn("Invalid hold reference in profile", "component", "processor", "did", did, "default_hold", *profileRecord.DefaultHold)
|
||||
slog.Warn("Invalid hold reference in profile", "component", "processor", "did", did, "default_hold", profileRecord.DefaultHold)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -309,6 +299,30 @@ func (p *Processor) ProcessSailorProfile(ctx context.Context, did string, record
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProcessRepoPage processes a repository page record
|
||||
// This is called when Jetstream receives a repo page create/update event
|
||||
func (p *Processor) ProcessRepoPage(ctx context.Context, did string, rkey string, recordData []byte, isDelete bool) error {
|
||||
if isDelete {
|
||||
// Delete the repo page from our cache
|
||||
return db.DeleteRepoPage(p.db, did, rkey)
|
||||
}
|
||||
|
||||
// Unmarshal repo page record
|
||||
var pageRecord atproto.RepoPageRecord
|
||||
if err := json.Unmarshal(recordData, &pageRecord); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal repo page: %w", err)
|
||||
}
|
||||
|
||||
// Extract avatar CID if present
|
||||
avatarCID := ""
|
||||
if pageRecord.Avatar != nil && pageRecord.Avatar.Ref.Link != "" {
|
||||
avatarCID = pageRecord.Avatar.Ref.Link
|
||||
}
|
||||
|
||||
// Upsert to database
|
||||
return db.UpsertRepoPage(p.db, did, pageRecord.Repository, pageRecord.Description, avatarCID, pageRecord.CreatedAt, pageRecord.UpdatedAt)
|
||||
}
|
||||
|
||||
// ProcessIdentity handles identity change events (handle updates)
|
||||
// This is called when Jetstream receives an identity event indicating a handle change.
|
||||
// The identity cache is invalidated to ensure the next lookup uses the new handle,
|
||||
|
||||
@@ -11,11 +11,6 @@ import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// ptrString returns a pointer to the given string
|
||||
func ptrString(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// setupTestDB creates an in-memory SQLite database for testing
|
||||
func setupTestDB(t *testing.T) *sql.DB {
|
||||
database, err := sql.Open("sqlite3", ":memory:")
|
||||
@@ -148,22 +143,28 @@ func TestProcessManifest_ImageManifest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test manifest record
|
||||
manifestRecord := &atproto.Manifest{
|
||||
manifestRecord := &atproto.ManifestRecord{
|
||||
Repository: "test-app",
|
||||
Digest: "sha256:abc123",
|
||||
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||
SchemaVersion: 2,
|
||||
HoldEndpoint: ptrString("did:web:hold01.atcr.io"),
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
Config: &atproto.Manifest_BlobReference{
|
||||
HoldEndpoint: "did:web:hold01.atcr.io",
|
||||
CreatedAt: time.Now(),
|
||||
Config: &atproto.BlobReference{
|
||||
Digest: "sha256:config123",
|
||||
Size: 1234,
|
||||
},
|
||||
Layers: []atproto.Manifest_BlobReference{
|
||||
Layers: []atproto.BlobReference{
|
||||
{Digest: "sha256:layer1", Size: 5000, MediaType: "application/vnd.oci.image.layer.v1.tar+gzip"},
|
||||
{Digest: "sha256:layer2", Size: 3000, MediaType: "application/vnd.oci.image.layer.v1.tar+gzip"},
|
||||
},
|
||||
// Annotations disabled - generated Manifest_Annotations is empty struct
|
||||
Annotations: map[string]string{
|
||||
"org.opencontainers.image.title": "Test App",
|
||||
"org.opencontainers.image.description": "A test application",
|
||||
"org.opencontainers.image.source": "https://github.com/test/app",
|
||||
"org.opencontainers.image.licenses": "MIT",
|
||||
"io.atcr.icon": "https://example.com/icon.png",
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to bytes for ProcessManifest
|
||||
@@ -192,8 +193,25 @@ func TestProcessManifest_ImageManifest(t *testing.T) {
|
||||
t.Errorf("Expected 1 manifest, got %d", count)
|
||||
}
|
||||
|
||||
// Note: Annotations verification disabled - generated Manifest_Annotations is empty struct
|
||||
// TODO: Re-enable when lexicon uses "unknown" type for annotations
|
||||
// Verify annotations were stored in repository_annotations table
|
||||
var title, source string
|
||||
err = database.QueryRow("SELECT value FROM repository_annotations WHERE did = ? AND repository = ? AND key = ?",
|
||||
"did:plc:test123", "test-app", "org.opencontainers.image.title").Scan(&title)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query title annotation: %v", err)
|
||||
}
|
||||
if title != "Test App" {
|
||||
t.Errorf("title = %q, want %q", title, "Test App")
|
||||
}
|
||||
|
||||
err = database.QueryRow("SELECT value FROM repository_annotations WHERE did = ? AND repository = ? AND key = ?",
|
||||
"did:plc:test123", "test-app", "org.opencontainers.image.source").Scan(&source)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query source annotation: %v", err)
|
||||
}
|
||||
if source != "https://github.com/test/app" {
|
||||
t.Errorf("source = %q, want %q", source, "https://github.com/test/app")
|
||||
}
|
||||
|
||||
// Verify layers were inserted
|
||||
var layerCount int
|
||||
@@ -224,31 +242,31 @@ func TestProcessManifest_ManifestList(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test manifest list record
|
||||
manifestRecord := &atproto.Manifest{
|
||||
manifestRecord := &atproto.ManifestRecord{
|
||||
Repository: "test-app",
|
||||
Digest: "sha256:list123",
|
||||
MediaType: "application/vnd.oci.image.index.v1+json",
|
||||
SchemaVersion: 2,
|
||||
HoldEndpoint: ptrString("did:web:hold01.atcr.io"),
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
Manifests: []atproto.Manifest_ManifestReference{
|
||||
HoldEndpoint: "did:web:hold01.atcr.io",
|
||||
CreatedAt: time.Now(),
|
||||
Manifests: []atproto.ManifestReference{
|
||||
{
|
||||
Digest: "sha256:amd64manifest",
|
||||
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||
Size: 1000,
|
||||
Platform: &atproto.Manifest_Platform{
|
||||
Platform: &atproto.Platform{
|
||||
Architecture: "amd64",
|
||||
Os: "linux",
|
||||
OS: "linux",
|
||||
},
|
||||
},
|
||||
{
|
||||
Digest: "sha256:arm64manifest",
|
||||
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||
Size: 1100,
|
||||
Platform: &atproto.Manifest_Platform{
|
||||
Platform: &atproto.Platform{
|
||||
Architecture: "arm64",
|
||||
Os: "linux",
|
||||
Variant: ptrString("v8"),
|
||||
OS: "linux",
|
||||
Variant: "v8",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -308,11 +326,11 @@ func TestProcessTag(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test tag record (using ManifestDigest field for simplicity)
|
||||
tagRecord := &atproto.Tag{
|
||||
tagRecord := &atproto.TagRecord{
|
||||
Repository: "test-app",
|
||||
Tag: "latest",
|
||||
ManifestDigest: ptrString("sha256:abc123"),
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
ManifestDigest: "sha256:abc123",
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Marshal to bytes for ProcessTag
|
||||
@@ -350,7 +368,7 @@ func TestProcessTag(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test upserting same tag with new digest
|
||||
tagRecord.ManifestDigest = ptrString("sha256:newdigest")
|
||||
tagRecord.ManifestDigest = "sha256:newdigest"
|
||||
recordBytes, err = json.Marshal(tagRecord)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal tag: %v", err)
|
||||
@@ -389,12 +407,12 @@ func TestProcessStar(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test star record
|
||||
starRecord := &atproto.SailorStar{
|
||||
Subject: atproto.SailorStar_Subject{
|
||||
Did: "did:plc:owner123",
|
||||
starRecord := &atproto.StarRecord{
|
||||
Subject: atproto.StarSubject{
|
||||
DID: "did:plc:owner123",
|
||||
Repository: "test-app",
|
||||
},
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Marshal to bytes for ProcessStar
|
||||
@@ -448,13 +466,13 @@ func TestProcessManifest_Duplicate(t *testing.T) {
|
||||
p := NewProcessor(database, false)
|
||||
ctx := context.Background()
|
||||
|
||||
manifestRecord := &atproto.Manifest{
|
||||
manifestRecord := &atproto.ManifestRecord{
|
||||
Repository: "test-app",
|
||||
Digest: "sha256:abc123",
|
||||
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||
SchemaVersion: 2,
|
||||
HoldEndpoint: ptrString("did:web:hold01.atcr.io"),
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
HoldEndpoint: "did:web:hold01.atcr.io",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Marshal to bytes for ProcessManifest
|
||||
@@ -500,13 +518,13 @@ func TestProcessManifest_EmptyAnnotations(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Manifest with nil annotations
|
||||
manifestRecord := &atproto.Manifest{
|
||||
manifestRecord := &atproto.ManifestRecord{
|
||||
Repository: "test-app",
|
||||
Digest: "sha256:abc123",
|
||||
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||
SchemaVersion: 2,
|
||||
HoldEndpoint: ptrString("did:web:hold01.atcr.io"),
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
HoldEndpoint: "did:web:hold01.atcr.io",
|
||||
CreatedAt: time.Now(),
|
||||
Annotations: nil,
|
||||
}
|
||||
|
||||
|
||||
@@ -61,9 +61,7 @@ func NewWorker(database *sql.DB, jetstreamURL string, startCursor int64) *Worker
|
||||
jetstreamURL: jetstreamURL,
|
||||
startCursor: startCursor,
|
||||
wantedCollections: []string{
|
||||
atproto.ManifestCollection, // io.atcr.manifest
|
||||
atproto.TagCollection, // io.atcr.tag
|
||||
atproto.StarCollection, // io.atcr.sailor.star
|
||||
"io.atcr.*", // Subscribe to all ATCR collections
|
||||
},
|
||||
processor: NewProcessor(database, true), // Use cache for live streaming
|
||||
}
|
||||
@@ -312,6 +310,9 @@ func (w *Worker) processMessage(message []byte) error {
|
||||
case atproto.StarCollection:
|
||||
slog.Info("Jetstream processing star event", "did", commit.DID, "operation", commit.Operation, "rkey", commit.RKey)
|
||||
return w.processStar(commit)
|
||||
case atproto.RepoPageCollection:
|
||||
slog.Info("Jetstream processing repo page event", "did", commit.DID, "operation", commit.Operation, "rkey", commit.RKey)
|
||||
return w.processRepoPage(commit)
|
||||
default:
|
||||
// Ignore other collections
|
||||
return nil
|
||||
@@ -436,6 +437,41 @@ func (w *Worker) processStar(commit *CommitEvent) error {
|
||||
return w.processor.ProcessStar(context.Background(), commit.DID, recordBytes)
|
||||
}
|
||||
|
||||
// processRepoPage processes a repo page commit event
|
||||
func (w *Worker) processRepoPage(commit *CommitEvent) error {
|
||||
// Resolve and upsert user with handle/PDS endpoint
|
||||
if err := w.processor.EnsureUser(context.Background(), commit.DID); err != nil {
|
||||
return fmt.Errorf("failed to ensure user: %w", err)
|
||||
}
|
||||
|
||||
isDelete := commit.Operation == "delete"
|
||||
|
||||
if isDelete {
|
||||
// Delete - rkey is the repository name
|
||||
slog.Info("Jetstream deleting repo page", "did", commit.DID, "repository", commit.RKey)
|
||||
if err := w.processor.ProcessRepoPage(context.Background(), commit.DID, commit.RKey, nil, true); err != nil {
|
||||
slog.Error("Jetstream ERROR deleting repo page", "error", err)
|
||||
return err
|
||||
}
|
||||
slog.Info("Jetstream successfully deleted repo page", "did", commit.DID, "repository", commit.RKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse repo page record
|
||||
if commit.Record == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Marshal map to bytes for processing
|
||||
recordBytes, err := json.Marshal(commit.Record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal record: %w", err)
|
||||
}
|
||||
|
||||
// Use shared processor for DB operations
|
||||
return w.processor.ProcessRepoPage(context.Background(), commit.DID, commit.RKey, recordBytes, false)
|
||||
}
|
||||
|
||||
// processIdentity processes an identity event (handle change)
|
||||
func (w *Worker) processIdentity(event *JetstreamEvent) error {
|
||||
if event.Identity == nil {
|
||||
|
||||
@@ -11,14 +11,32 @@ import (
|
||||
"net/url"
|
||||
|
||||
"atcr.io/pkg/appview/db"
|
||||
"atcr.io/pkg/auth"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const userKey contextKey = "user"
|
||||
|
||||
// WebAuthDeps contains dependencies for web auth middleware
|
||||
type WebAuthDeps struct {
|
||||
SessionStore *db.SessionStore
|
||||
Database *sql.DB
|
||||
Refresher *oauth.Refresher
|
||||
DefaultHoldDID string
|
||||
}
|
||||
|
||||
// RequireAuth is middleware that requires authentication
|
||||
func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) http.Handler {
|
||||
return RequireAuthWithDeps(WebAuthDeps{
|
||||
SessionStore: store,
|
||||
Database: database,
|
||||
})
|
||||
}
|
||||
|
||||
// RequireAuthWithDeps is middleware that requires authentication and creates UserContext
|
||||
func RequireAuthWithDeps(deps WebAuthDeps) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID, ok := getSessionID(r)
|
||||
@@ -32,7 +50,7 @@ func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) ht
|
||||
return
|
||||
}
|
||||
|
||||
sess, ok := store.Get(sessionID)
|
||||
sess, ok := deps.SessionStore.Get(sessionID)
|
||||
if !ok {
|
||||
// Build return URL with query parameters preserved
|
||||
returnTo := r.URL.Path
|
||||
@@ -44,7 +62,7 @@ func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) ht
|
||||
}
|
||||
|
||||
// Look up full user from database to get avatar
|
||||
user, err := db.GetUserByDID(database, sess.DID)
|
||||
user, err := db.GetUserByDID(deps.Database, sess.DID)
|
||||
if err != nil || user == nil {
|
||||
// Fallback to session data if DB lookup fails
|
||||
user = &db.User{
|
||||
@@ -54,7 +72,20 @@ func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) ht
|
||||
}
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), userKey, user)
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, userKey, user)
|
||||
|
||||
// Create UserContext for authenticated users (enables EnsureUserSetup)
|
||||
if deps.Refresher != nil {
|
||||
userCtx := auth.NewUserContext(sess.DID, auth.AuthMethodOAuth, r.Method, &auth.Dependencies{
|
||||
Refresher: deps.Refresher,
|
||||
DefaultHoldDID: deps.DefaultHoldDID,
|
||||
})
|
||||
userCtx.SetPDS(sess.Handle, sess.PDSEndpoint)
|
||||
userCtx.EnsureUserSetup()
|
||||
ctx = auth.WithUserContext(ctx, userCtx)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
@@ -62,13 +93,21 @@ func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) ht
|
||||
|
||||
// OptionalAuth is middleware that optionally includes user if authenticated
|
||||
func OptionalAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) http.Handler {
|
||||
return OptionalAuthWithDeps(WebAuthDeps{
|
||||
SessionStore: store,
|
||||
Database: database,
|
||||
})
|
||||
}
|
||||
|
||||
// OptionalAuthWithDeps is middleware that optionally includes user and UserContext if authenticated
|
||||
func OptionalAuthWithDeps(deps WebAuthDeps) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID, ok := getSessionID(r)
|
||||
if ok {
|
||||
if sess, ok := store.Get(sessionID); ok {
|
||||
if sess, ok := deps.SessionStore.Get(sessionID); ok {
|
||||
// Look up full user from database to get avatar
|
||||
user, err := db.GetUserByDID(database, sess.DID)
|
||||
user, err := db.GetUserByDID(deps.Database, sess.DID)
|
||||
if err != nil || user == nil {
|
||||
// Fallback to session data if DB lookup fails
|
||||
user = &db.User{
|
||||
@@ -77,7 +116,21 @@ func OptionalAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) h
|
||||
PDSEndpoint: sess.PDSEndpoint,
|
||||
}
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), userKey, user)
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, userKey, user)
|
||||
|
||||
// Create UserContext for authenticated users (enables EnsureUserSetup)
|
||||
if deps.Refresher != nil {
|
||||
userCtx := auth.NewUserContext(sess.DID, auth.AuthMethodOAuth, r.Method, &auth.Dependencies{
|
||||
Refresher: deps.Refresher,
|
||||
DefaultHoldDID: deps.DefaultHoldDID,
|
||||
})
|
||||
userCtx.SetPDS(sess.Handle, sess.PDSEndpoint)
|
||||
userCtx.EnsureUserSetup()
|
||||
ctx = auth.WithUserContext(ctx, userCtx)
|
||||
}
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,15 +2,13 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/distribution/distribution/v3"
|
||||
"github.com/distribution/distribution/v3/registry/api/errcode"
|
||||
registrymw "github.com/distribution/distribution/v3/registry/middleware/registry"
|
||||
"github.com/distribution/distribution/v3/registry/storage/driver"
|
||||
"github.com/distribution/reference"
|
||||
@@ -28,151 +26,16 @@ const holdDIDKey contextKey = "hold.did"
|
||||
// authMethodKey is the context key for storing auth method from JWT
|
||||
const authMethodKey contextKey = "auth.method"
|
||||
|
||||
// validationCacheEntry stores a validated service token with expiration
|
||||
type validationCacheEntry struct {
|
||||
serviceToken string
|
||||
validUntil time.Time
|
||||
err error // Cached error for fast-fail
|
||||
mu sync.Mutex // Per-entry lock to serialize cache population
|
||||
inFlight bool // True if another goroutine is fetching the token
|
||||
done chan struct{} // Closed when fetch completes
|
||||
}
|
||||
|
||||
// validationCache provides request-level caching for service tokens
|
||||
// This prevents concurrent layer uploads from racing on OAuth/DPoP requests
|
||||
type validationCache struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]*validationCacheEntry // key: "did:holdDID"
|
||||
}
|
||||
|
||||
// newValidationCache creates a new validation cache
|
||||
func newValidationCache() *validationCache {
|
||||
return &validationCache{
|
||||
entries: make(map[string]*validationCacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// getOrFetch retrieves a service token from cache or fetches it
|
||||
// Multiple concurrent requests for the same DID:holdDID will share the fetch operation
|
||||
func (vc *validationCache) getOrFetch(ctx context.Context, cacheKey string, fetchFunc func() (string, error)) (string, error) {
|
||||
// Fast path: check cache with read lock
|
||||
vc.mu.RLock()
|
||||
entry, exists := vc.entries[cacheKey]
|
||||
vc.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
// Entry exists, check if it's still valid
|
||||
entry.mu.Lock()
|
||||
|
||||
// If another goroutine is fetching, wait for it
|
||||
if entry.inFlight {
|
||||
done := entry.done
|
||||
entry.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Fetch completed, check result
|
||||
entry.mu.Lock()
|
||||
defer entry.mu.Unlock()
|
||||
|
||||
if entry.err != nil {
|
||||
return "", entry.err
|
||||
}
|
||||
if time.Now().Before(entry.validUntil) {
|
||||
return entry.serviceToken, nil
|
||||
}
|
||||
// Fall through to refetch
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
}
|
||||
} else {
|
||||
// Check if cached token is still valid
|
||||
if entry.err != nil && time.Now().Before(entry.validUntil) {
|
||||
// Return cached error (fast-fail)
|
||||
entry.mu.Unlock()
|
||||
return "", entry.err
|
||||
}
|
||||
if entry.err == nil && time.Now().Before(entry.validUntil) {
|
||||
// Return cached token
|
||||
token := entry.serviceToken
|
||||
entry.mu.Unlock()
|
||||
return token, nil
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: need to fetch token
|
||||
vc.mu.Lock()
|
||||
entry, exists = vc.entries[cacheKey]
|
||||
if !exists {
|
||||
// Create new entry
|
||||
entry = &validationCacheEntry{
|
||||
inFlight: true,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
vc.entries[cacheKey] = entry
|
||||
}
|
||||
vc.mu.Unlock()
|
||||
|
||||
// Lock the entry to perform fetch
|
||||
entry.mu.Lock()
|
||||
|
||||
// Double-check: another goroutine may have fetched while we waited
|
||||
if !entry.inFlight {
|
||||
if entry.err != nil && time.Now().Before(entry.validUntil) {
|
||||
err := entry.err
|
||||
entry.mu.Unlock()
|
||||
return "", err
|
||||
}
|
||||
if entry.err == nil && time.Now().Before(entry.validUntil) {
|
||||
token := entry.serviceToken
|
||||
entry.mu.Unlock()
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Mark as in-flight and create fresh done channel for this fetch
|
||||
// IMPORTANT: Always create a new channel - a closed channel is not nil
|
||||
entry.done = make(chan struct{})
|
||||
entry.inFlight = true
|
||||
done := entry.done
|
||||
entry.mu.Unlock()
|
||||
|
||||
// Perform the fetch (outside the lock to allow other operations)
|
||||
serviceToken, err := fetchFunc()
|
||||
|
||||
// Update the entry with result
|
||||
entry.mu.Lock()
|
||||
entry.inFlight = false
|
||||
|
||||
if err != nil {
|
||||
// Cache errors for 5 seconds (fast-fail for subsequent requests)
|
||||
entry.err = err
|
||||
entry.validUntil = time.Now().Add(5 * time.Second)
|
||||
entry.serviceToken = ""
|
||||
} else {
|
||||
// Cache token for 45 seconds (covers typical Docker push operation)
|
||||
entry.err = nil
|
||||
entry.serviceToken = serviceToken
|
||||
entry.validUntil = time.Now().Add(45 * time.Second)
|
||||
}
|
||||
|
||||
// Signal completion to waiting goroutines
|
||||
close(done)
|
||||
entry.mu.Unlock()
|
||||
|
||||
return serviceToken, err
|
||||
}
|
||||
// pullerDIDKey is the context key for storing the authenticated user's DID from JWT
|
||||
const pullerDIDKey contextKey = "puller.did"
|
||||
|
||||
// Global variables for initialization only
|
||||
// These are set by main.go during startup and copied into NamespaceResolver instances.
|
||||
// After initialization, request handling uses the NamespaceResolver's instance fields.
|
||||
var (
|
||||
globalRefresher *oauth.Refresher
|
||||
globalDatabase storage.DatabaseMetrics
|
||||
globalAuthorizer auth.HoldAuthorizer
|
||||
globalReadmeCache storage.ReadmeCache
|
||||
globalRefresher *oauth.Refresher
|
||||
globalDatabase *sql.DB
|
||||
globalAuthorizer auth.HoldAuthorizer
|
||||
)
|
||||
|
||||
// SetGlobalRefresher sets the OAuth refresher instance during initialization
|
||||
@@ -183,7 +46,7 @@ func SetGlobalRefresher(refresher *oauth.Refresher) {
|
||||
|
||||
// SetGlobalDatabase sets the database instance during initialization
|
||||
// Must be called before the registry starts serving requests
|
||||
func SetGlobalDatabase(database storage.DatabaseMetrics) {
|
||||
func SetGlobalDatabase(database *sql.DB) {
|
||||
globalDatabase = database
|
||||
}
|
||||
|
||||
@@ -193,12 +56,6 @@ func SetGlobalAuthorizer(authorizer auth.HoldAuthorizer) {
|
||||
globalAuthorizer = authorizer
|
||||
}
|
||||
|
||||
// SetGlobalReadmeCache sets the readme cache instance during initialization
|
||||
// Must be called before the registry starts serving requests
|
||||
func SetGlobalReadmeCache(readmeCache storage.ReadmeCache) {
|
||||
globalReadmeCache = readmeCache
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Register the name resolution middleware
|
||||
registrymw.Register("atproto-resolver", initATProtoResolver)
|
||||
@@ -207,14 +64,12 @@ func init() {
|
||||
// NamespaceResolver wraps a namespace and resolves names
|
||||
type NamespaceResolver struct {
|
||||
distribution.Namespace
|
||||
defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io")
|
||||
baseURL string // Base URL for error messages (e.g., "https://atcr.io")
|
||||
testMode bool // If true, fallback to default hold when user's hold is unreachable
|
||||
refresher *oauth.Refresher // OAuth session manager (copied from global on init)
|
||||
database storage.DatabaseMetrics // Metrics database (copied from global on init)
|
||||
authorizer auth.HoldAuthorizer // Hold authorization (copied from global on init)
|
||||
readmeCache storage.ReadmeCache // README cache (copied from global on init)
|
||||
validationCache *validationCache // Request-level service token cache
|
||||
defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io")
|
||||
baseURL string // Base URL for error messages (e.g., "https://atcr.io")
|
||||
testMode bool // If true, fallback to default hold when user's hold is unreachable
|
||||
refresher *oauth.Refresher // OAuth session manager (copied from global on init)
|
||||
sqlDB *sql.DB // Database for hold DID lookup and metrics (copied from global on init)
|
||||
authorizer auth.HoldAuthorizer // Hold authorization (copied from global on init)
|
||||
}
|
||||
|
||||
// initATProtoResolver initializes the name resolution middleware
|
||||
@@ -241,25 +96,16 @@ func initATProtoResolver(ctx context.Context, ns distribution.Namespace, _ drive
|
||||
// Copy shared services from globals into the instance
|
||||
// This avoids accessing globals during request handling
|
||||
return &NamespaceResolver{
|
||||
Namespace: ns,
|
||||
defaultHoldDID: defaultHoldDID,
|
||||
baseURL: baseURL,
|
||||
testMode: testMode,
|
||||
refresher: globalRefresher,
|
||||
database: globalDatabase,
|
||||
authorizer: globalAuthorizer,
|
||||
readmeCache: globalReadmeCache,
|
||||
validationCache: newValidationCache(),
|
||||
Namespace: ns,
|
||||
defaultHoldDID: defaultHoldDID,
|
||||
baseURL: baseURL,
|
||||
testMode: testMode,
|
||||
refresher: globalRefresher,
|
||||
sqlDB: globalDatabase,
|
||||
authorizer: globalAuthorizer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// authErrorMessage creates a user-friendly auth error with login URL
|
||||
func (nr *NamespaceResolver) authErrorMessage(message string) error {
|
||||
loginURL := fmt.Sprintf("%s/auth/oauth/login", nr.baseURL)
|
||||
fullMessage := fmt.Sprintf("%s - please re-authenticate at %s", message, loginURL)
|
||||
return errcode.ErrorCodeUnauthorized.WithMessage(fullMessage)
|
||||
}
|
||||
|
||||
// Repository resolves the repository name and delegates to underlying namespace
|
||||
// Handles names like:
|
||||
// - atcr.io/alice/myimage → resolve alice to DID
|
||||
@@ -293,99 +139,8 @@ func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Name
|
||||
}
|
||||
ctx = context.WithValue(ctx, holdDIDKey, holdDID)
|
||||
|
||||
// Auto-reconcile crew membership on first push/pull
|
||||
// This ensures users can push immediately after docker login without web sign-in
|
||||
// EnsureCrewMembership is best-effort and logs errors without failing the request
|
||||
// Run in background to avoid blocking registry operations if hold is offline
|
||||
if holdDID != "" && nr.refresher != nil {
|
||||
slog.Debug("Auto-reconciling crew membership", "component", "registry/middleware", "did", did, "hold_did", holdDID)
|
||||
client := atproto.NewClient(pdsEndpoint, did, "")
|
||||
go func(ctx context.Context, client *atproto.Client, refresher *oauth.Refresher, holdDID string) {
|
||||
storage.EnsureCrewMembership(ctx, client, refresher, holdDID)
|
||||
}(ctx, client, nr.refresher, holdDID)
|
||||
}
|
||||
|
||||
// Get service token for hold authentication (only if authenticated)
|
||||
// Use validation cache to prevent concurrent requests from racing on OAuth/DPoP
|
||||
// Route based on auth method from JWT token
|
||||
var serviceToken string
|
||||
authMethod, _ := ctx.Value(authMethodKey).(string)
|
||||
|
||||
// Only fetch service token if user is authenticated
|
||||
// Unauthenticated requests (like /v2/ ping) should not trigger token fetching
|
||||
if authMethod != "" {
|
||||
// Create cache key: "did:holdDID"
|
||||
cacheKey := fmt.Sprintf("%s:%s", did, holdDID)
|
||||
|
||||
// Fetch service token through validation cache
|
||||
// This ensures only ONE request per DID:holdDID pair fetches the token
|
||||
// Concurrent requests will wait for the first request to complete
|
||||
var fetchErr error
|
||||
serviceToken, fetchErr = nr.validationCache.getOrFetch(ctx, cacheKey, func() (string, error) {
|
||||
if authMethod == token.AuthMethodAppPassword {
|
||||
// App-password flow: use Bearer token authentication
|
||||
slog.Debug("Using app-password flow for service token",
|
||||
"component", "registry/middleware",
|
||||
"did", did,
|
||||
"cacheKey", cacheKey)
|
||||
|
||||
token, err := token.GetOrFetchServiceTokenWithAppPassword(ctx, did, holdDID, pdsEndpoint)
|
||||
if err != nil {
|
||||
slog.Error("Failed to get service token with app-password",
|
||||
"component", "registry/middleware",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"error", err)
|
||||
return "", err
|
||||
}
|
||||
return token, nil
|
||||
} else if nr.refresher != nil {
|
||||
// OAuth flow: use DPoP authentication
|
||||
slog.Debug("Using OAuth flow for service token",
|
||||
"component", "registry/middleware",
|
||||
"did", did,
|
||||
"cacheKey", cacheKey)
|
||||
|
||||
token, err := token.GetOrFetchServiceToken(ctx, nr.refresher, did, holdDID, pdsEndpoint)
|
||||
if err != nil {
|
||||
slog.Error("Failed to get service token with OAuth",
|
||||
"component", "registry/middleware",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"error", err)
|
||||
return "", err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
return "", fmt.Errorf("no authentication method available")
|
||||
})
|
||||
|
||||
// Handle errors from cached fetch
|
||||
if fetchErr != nil {
|
||||
errMsg := fetchErr.Error()
|
||||
|
||||
// Check for app-password specific errors
|
||||
if authMethod == token.AuthMethodAppPassword {
|
||||
if strings.Contains(errMsg, "expired or invalid") || strings.Contains(errMsg, "no app-password") {
|
||||
return nil, nr.authErrorMessage("App-password authentication failed. Please re-authenticate with: docker login")
|
||||
}
|
||||
}
|
||||
|
||||
// Check for OAuth specific errors
|
||||
if strings.Contains(errMsg, "OAuth session") || strings.Contains(errMsg, "OAuth validation") {
|
||||
return nil, nr.authErrorMessage("OAuth session expired or invalidated by PDS. Your session has been cleared")
|
||||
}
|
||||
|
||||
// Generic service token error
|
||||
return nil, nr.authErrorMessage(fmt.Sprintf("Failed to obtain storage credentials: %v", fetchErr))
|
||||
}
|
||||
} else {
|
||||
slog.Debug("Skipping service token fetch for unauthenticated request",
|
||||
"component", "registry/middleware",
|
||||
"did", did)
|
||||
}
|
||||
// Note: Profile and crew membership are now ensured in UserContextMiddleware
|
||||
// via EnsureUserSetup() - no need to call here
|
||||
|
||||
// Create a new reference with identity/image format
|
||||
// Use the identity (or DID) as the namespace to ensure canonical format
|
||||
@@ -402,74 +157,30 @@ func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Name
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get access token for PDS operations
|
||||
// Use auth method from JWT to determine client type:
|
||||
// - OAuth users: use session provider (DPoP-enabled)
|
||||
// - App-password users: use Basic Auth token cache
|
||||
var atprotoClient *atproto.Client
|
||||
|
||||
if authMethod == token.AuthMethodOAuth && nr.refresher != nil {
|
||||
// OAuth flow: use session provider for locked OAuth sessions
|
||||
// This prevents DPoP nonce race conditions during concurrent layer uploads
|
||||
slog.Debug("Creating ATProto client with OAuth session provider",
|
||||
"component", "registry/middleware",
|
||||
"did", did,
|
||||
"authMethod", authMethod)
|
||||
atprotoClient = atproto.NewClientWithSessionProvider(pdsEndpoint, did, nr.refresher)
|
||||
} else {
|
||||
// App-password flow (or fallback): use Basic Auth token cache
|
||||
accessToken, ok := auth.GetGlobalTokenCache().Get(did)
|
||||
if !ok {
|
||||
slog.Debug("No cached access token found for app-password auth",
|
||||
"component", "registry/middleware",
|
||||
"did", did,
|
||||
"authMethod", authMethod)
|
||||
accessToken = "" // Will fail on manifest push, but let it try
|
||||
} else {
|
||||
slog.Debug("Creating ATProto client with app-password",
|
||||
"component", "registry/middleware",
|
||||
"did", did,
|
||||
"authMethod", authMethod,
|
||||
"token_length", len(accessToken))
|
||||
}
|
||||
atprotoClient = atproto.NewClient(pdsEndpoint, did, accessToken)
|
||||
}
|
||||
|
||||
// IMPORTANT: Use only the image name (not identity/image) for ATProto storage
|
||||
// ATProto records are scoped to the user's DID, so we don't need the identity prefix
|
||||
// Example: "evan.jarrett.net/debian" -> store as "debian"
|
||||
repositoryName := imageName
|
||||
|
||||
// Default auth method to OAuth if not already set (backward compatibility with old tokens)
|
||||
if authMethod == "" {
|
||||
authMethod = token.AuthMethodOAuth
|
||||
// Get UserContext from request context (set by UserContextMiddleware)
|
||||
userCtx := auth.FromContext(ctx)
|
||||
if userCtx == nil {
|
||||
return nil, fmt.Errorf("UserContext not set in request context - ensure UserContextMiddleware is configured")
|
||||
}
|
||||
|
||||
// Set target repository info on UserContext
|
||||
// ATProtoClient is cached lazily via userCtx.GetATProtoClient()
|
||||
userCtx.SetTarget(did, handle, pdsEndpoint, repositoryName, holdDID)
|
||||
|
||||
// Create routing repository - routes manifests to ATProto, blobs to hold service
|
||||
// The registry is stateless - no local storage is used
|
||||
// Bundle all context into a single RegistryContext struct
|
||||
//
|
||||
// NOTE: We create a fresh RoutingRepository on every request (no caching) because:
|
||||
// 1. Each layer upload is a separate HTTP request (possibly different process)
|
||||
// 2. OAuth sessions can be refreshed/invalidated between requests
|
||||
// 3. The refresher already caches sessions efficiently (in-memory + DB)
|
||||
// 4. Caching the repository with a stale ATProtoClient causes refresh token errors
|
||||
registryCtx := &storage.RegistryContext{
|
||||
DID: did,
|
||||
Handle: handle,
|
||||
HoldDID: holdDID,
|
||||
PDSEndpoint: pdsEndpoint,
|
||||
Repository: repositoryName,
|
||||
ServiceToken: serviceToken, // Cached service token from middleware validation
|
||||
ATProtoClient: atprotoClient,
|
||||
AuthMethod: authMethod, // Auth method from JWT token
|
||||
Database: nr.database,
|
||||
Authorizer: nr.authorizer,
|
||||
Refresher: nr.refresher,
|
||||
ReadmeCache: nr.readmeCache,
|
||||
}
|
||||
|
||||
return storage.NewRoutingRepository(repo, registryCtx), nil
|
||||
// 4. ATProtoClient is now cached in UserContext via GetATProtoClient()
|
||||
return storage.NewRoutingRepository(repo, userCtx, nr.sqlDB), nil
|
||||
}
|
||||
|
||||
// Repositories delegates to underlying namespace
|
||||
@@ -490,8 +201,7 @@ func (nr *NamespaceResolver) BlobStatter() distribution.BlobStatter {
|
||||
// findHoldDID determines which hold DID to use for blob storage
|
||||
// Priority order:
|
||||
// 1. User's sailor profile defaultHold (if set)
|
||||
// 2. User's own hold record (io.atcr.hold)
|
||||
// 3. AppView's default hold DID
|
||||
// 2. AppView's default hold DID
|
||||
// Returns a hold DID (e.g., "did:web:hold01.atcr.io"), or empty string if none configured
|
||||
func (nr *NamespaceResolver) findHoldDID(ctx context.Context, did, pdsEndpoint string) string {
|
||||
// Create ATProto client (without auth - reading public records)
|
||||
@@ -504,22 +214,20 @@ func (nr *NamespaceResolver) findHoldDID(ctx context.Context, did, pdsEndpoint s
|
||||
slog.Warn("Failed to read profile", "did", did, "error", err)
|
||||
}
|
||||
|
||||
if profile != nil && profile.DefaultHold != nil && *profile.DefaultHold != "" {
|
||||
defaultHold := *profile.DefaultHold
|
||||
// Profile exists with defaultHold set
|
||||
// In test mode, verify it's reachable before using it
|
||||
if profile != nil && profile.DefaultHold != "" {
|
||||
// In test mode, verify the hold is reachable (fall back to default if not)
|
||||
// In production, trust the user's profile and return their hold
|
||||
if nr.testMode {
|
||||
if nr.isHoldReachable(ctx, defaultHold) {
|
||||
return defaultHold
|
||||
if nr.isHoldReachable(ctx, profile.DefaultHold) {
|
||||
return profile.DefaultHold
|
||||
}
|
||||
slog.Debug("User's defaultHold unreachable, falling back to default", "component", "registry/middleware/testmode", "default_hold", defaultHold)
|
||||
slog.Debug("User's defaultHold unreachable, falling back to default", "component", "registry/middleware/testmode", "default_hold", profile.DefaultHold)
|
||||
return nr.defaultHoldDID
|
||||
}
|
||||
return defaultHold
|
||||
return profile.DefaultHold
|
||||
}
|
||||
|
||||
// Profile doesn't exist or defaultHold is null/empty
|
||||
// Legacy io.atcr.hold records are no longer supported - use AppView default
|
||||
// No profile defaultHold - use AppView default
|
||||
return nr.defaultHoldDID
|
||||
}
|
||||
|
||||
@@ -542,10 +250,17 @@ func (nr *NamespaceResolver) isHoldReachable(ctx context.Context, holdDID string
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractAuthMethod is an HTTP middleware that extracts the auth method from the JWT Authorization header
|
||||
// and stores it in the request context for later use by the registry middleware
|
||||
// ExtractAuthMethod is an HTTP middleware that extracts the auth method and puller DID from the JWT Authorization header
|
||||
// and stores them in the request context for later use by the registry middleware.
|
||||
// Also stores the HTTP method for routing decisions (GET/HEAD = pull, PUT/POST = push).
|
||||
func ExtractAuthMethod(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Store HTTP method in context for routing decisions
|
||||
// This is used by routing_repository.go to distinguish pull (GET/HEAD) from push (PUT/POST)
|
||||
ctx = context.WithValue(ctx, "http.request.method", r.Method)
|
||||
|
||||
// Extract Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader != "" {
|
||||
@@ -558,15 +273,71 @@ func ExtractAuthMethod(next http.Handler) http.Handler {
|
||||
authMethod := token.ExtractAuthMethod(tokenString)
|
||||
if authMethod != "" {
|
||||
// Store in context for registry middleware
|
||||
ctx := context.WithValue(r.Context(), authMethodKey, authMethod)
|
||||
r = r.WithContext(ctx)
|
||||
slog.Debug("Extracted auth method from JWT",
|
||||
"component", "registry/middleware",
|
||||
"authMethod", authMethod)
|
||||
ctx = context.WithValue(ctx, authMethodKey, authMethod)
|
||||
}
|
||||
|
||||
// Extract puller DID (Subject) from JWT
|
||||
// This is the authenticated user's DID, used for service token requests
|
||||
pullerDID := token.ExtractSubject(tokenString)
|
||||
if pullerDID != "" {
|
||||
ctx = context.WithValue(ctx, pullerDIDKey, pullerDID)
|
||||
}
|
||||
|
||||
slog.Debug("Extracted auth info from JWT",
|
||||
"component", "registry/middleware",
|
||||
"authMethod", authMethod,
|
||||
"pullerDID", pullerDID,
|
||||
"httpMethod", r.Method)
|
||||
}
|
||||
}
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// UserContextMiddleware creates a UserContext from the extracted JWT claims
|
||||
// and stores it in the request context for use throughout request processing.
|
||||
// This middleware should be chained AFTER ExtractAuthMethod.
|
||||
func UserContextMiddleware(deps *auth.Dependencies) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Get values set by ExtractAuthMethod
|
||||
authMethod, _ := ctx.Value(authMethodKey).(string)
|
||||
pullerDID, _ := ctx.Value(pullerDIDKey).(string)
|
||||
|
||||
// Build UserContext with all dependencies
|
||||
userCtx := auth.NewUserContext(pullerDID, authMethod, r.Method, deps)
|
||||
|
||||
// Eagerly resolve user's PDS for authenticated users
|
||||
// This is a fast path that avoids lazy loading in most cases
|
||||
if userCtx.IsAuthenticated {
|
||||
if err := userCtx.ResolvePDS(ctx); err != nil {
|
||||
slog.Warn("Failed to resolve puller's PDS",
|
||||
"component", "registry/middleware",
|
||||
"did", pullerDID,
|
||||
"error", err)
|
||||
// Continue without PDS - will fail on service token request
|
||||
}
|
||||
|
||||
// Ensure user has profile and crew membership (runs in background, cached)
|
||||
userCtx.EnsureUserSetup()
|
||||
}
|
||||
|
||||
// Store UserContext in request context
|
||||
ctx = auth.WithUserContext(ctx, userCtx)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
slog.Debug("Created UserContext",
|
||||
"component", "registry/middleware",
|
||||
"isAuthenticated", userCtx.IsAuthenticated,
|
||||
"authMethod", userCtx.AuthMethod,
|
||||
"action", userCtx.Action.String(),
|
||||
"pullerDID", pullerDID)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,11 +67,6 @@ func TestSetGlobalAuthorizer(t *testing.T) {
|
||||
// If we get here without panic, test passes
|
||||
}
|
||||
|
||||
func TestSetGlobalReadmeCache(t *testing.T) {
|
||||
SetGlobalReadmeCache(nil)
|
||||
// If we get here without panic, test passes
|
||||
}
|
||||
|
||||
// TestInitATProtoResolver tests the initialization function
|
||||
func TestInitATProtoResolver(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
@@ -134,17 +129,6 @@ func TestInitATProtoResolver(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthErrorMessage tests the error message formatting
|
||||
func TestAuthErrorMessage(t *testing.T) {
|
||||
resolver := &NamespaceResolver{
|
||||
baseURL: "https://atcr.io",
|
||||
}
|
||||
|
||||
err := resolver.authErrorMessage("OAuth session expired")
|
||||
assert.Contains(t, err.Error(), "OAuth session expired")
|
||||
assert.Contains(t, err.Error(), "https://atcr.io/auth/oauth/login")
|
||||
}
|
||||
|
||||
// TestFindHoldDID_DefaultFallback tests default hold DID fallback
|
||||
func TestFindHoldDID_DefaultFallback(t *testing.T) {
|
||||
// Start a mock PDS server that returns 404 for profile and empty list for holds
|
||||
@@ -204,34 +188,9 @@ func TestFindHoldDID_SailorProfile(t *testing.T) {
|
||||
assert.Equal(t, "did:web:user.hold.io", holdDID, "should use sailor profile's defaultHold")
|
||||
}
|
||||
|
||||
// TestFindHoldDID_NoProfile tests fallback to default hold when no profile exists
|
||||
func TestFindHoldDID_NoProfile(t *testing.T) {
|
||||
// Start a mock PDS server that returns 404 for profile
|
||||
mockPDS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/xrpc/com.atproto.repo.getRecord" {
|
||||
// Profile not found
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer mockPDS.Close()
|
||||
|
||||
resolver := &NamespaceResolver{
|
||||
defaultHoldDID: "did:web:default.atcr.io",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
holdDID := resolver.findHoldDID(ctx, "did:plc:test123", mockPDS.URL)
|
||||
|
||||
// Should fall back to default hold DID when no profile exists
|
||||
// Note: Legacy io.atcr.hold records are no longer supported
|
||||
assert.Equal(t, "did:web:default.atcr.io", holdDID, "should fall back to default hold DID")
|
||||
}
|
||||
|
||||
// TestFindHoldDID_Priority tests that profile takes priority over default
|
||||
// TestFindHoldDID_Priority tests the priority order
|
||||
func TestFindHoldDID_Priority(t *testing.T) {
|
||||
// Start a mock PDS server that returns profile
|
||||
// Start a mock PDS server that returns both profile and hold records
|
||||
mockPDS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/xrpc/com.atproto.repo.getRecord" {
|
||||
// Return sailor profile with defaultHold (highest priority)
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
// Package readme provides README fetching, rendering, and caching functionality
|
||||
// for container repositories. It fetches markdown content from URLs, renders it
|
||||
// to sanitized HTML using GitHub-flavored markdown, and caches the results in
|
||||
// a database with configurable TTL.
|
||||
package readme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Cache stores rendered README HTML in the database
|
||||
type Cache struct {
|
||||
db *sql.DB
|
||||
fetcher *Fetcher
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// NewCache creates a new README cache
|
||||
func NewCache(db *sql.DB, ttl time.Duration) *Cache {
|
||||
if ttl == 0 {
|
||||
ttl = 1 * time.Hour // Default TTL
|
||||
}
|
||||
return &Cache{
|
||||
db: db,
|
||||
fetcher: NewFetcher(),
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a README from cache or fetches it
|
||||
func (c *Cache) Get(ctx context.Context, readmeURL string) (string, error) {
|
||||
// Try to get from cache
|
||||
html, fetchedAt, err := c.getFromDB(readmeURL)
|
||||
if err == nil {
|
||||
// Check if cache is still valid
|
||||
if time.Since(fetchedAt) < c.ttl {
|
||||
return html, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Cache miss or expired, fetch fresh content
|
||||
html, err = c.fetcher.FetchAndRender(ctx, readmeURL)
|
||||
if err != nil {
|
||||
// If fetch fails but we have stale cache, return it
|
||||
if html != "" {
|
||||
return html, nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Store in cache
|
||||
if err := c.storeInDB(readmeURL, html); err != nil {
|
||||
// Log error but don't fail - we have the content
|
||||
slog.Warn("Failed to cache README", "error", err)
|
||||
}
|
||||
|
||||
return html, nil
|
||||
}
|
||||
|
||||
// getFromDB retrieves cached README from database
|
||||
func (c *Cache) getFromDB(readmeURL string) (string, time.Time, error) {
|
||||
var html string
|
||||
var fetchedAt time.Time
|
||||
|
||||
err := c.db.QueryRow(`
|
||||
SELECT html, fetched_at
|
||||
FROM readme_cache
|
||||
WHERE url = ?
|
||||
`, readmeURL).Scan(&html, &fetchedAt)
|
||||
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
return html, fetchedAt, nil
|
||||
}
|
||||
|
||||
// storeInDB stores rendered README in database
|
||||
func (c *Cache) storeInDB(readmeURL, html string) error {
|
||||
_, err := c.db.Exec(`
|
||||
INSERT INTO readme_cache (url, html, fetched_at)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(url) DO UPDATE SET
|
||||
html = excluded.html,
|
||||
fetched_at = excluded.fetched_at
|
||||
`, readmeURL, html, time.Now())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate removes a README from the cache
|
||||
func (c *Cache) Invalidate(readmeURL string) error {
|
||||
_, err := c.db.Exec(`
|
||||
DELETE FROM readme_cache
|
||||
WHERE url = ?
|
||||
`, readmeURL)
|
||||
return err
|
||||
}
|
||||
|
||||
// Cleanup removes expired entries from the cache
|
||||
func (c *Cache) Cleanup() error {
|
||||
cutoff := time.Now().Add(-c.ttl * 2) // Keep for 2x TTL
|
||||
_, err := c.db.Exec(`
|
||||
DELETE FROM readme_cache
|
||||
WHERE fetched_at < ?
|
||||
`, cutoff)
|
||||
return err
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package readme
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestCache_Struct(t *testing.T) {
|
||||
// Simple struct test
|
||||
cache := &Cache{}
|
||||
if cache == nil {
|
||||
t.Error("Expected non-nil cache")
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add cache operation tests
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -180,6 +181,27 @@ func getBaseURL(u *url.URL) string {
|
||||
return fmt.Sprintf("%s://%s%s", u.Scheme, u.Host, path)
|
||||
}
|
||||
|
||||
// Is404 returns true if the error indicates a 404 Not Found response
|
||||
func Is404(err error) bool {
|
||||
return err != nil && strings.Contains(err.Error(), "unexpected status code: 404")
|
||||
}
|
||||
|
||||
// RenderMarkdown renders a markdown string to sanitized HTML
|
||||
// This is used for rendering repo page descriptions stored in the database
|
||||
func (f *Fetcher) RenderMarkdown(content []byte) (string, error) {
|
||||
// Render markdown to HTML (no base URL for repo page descriptions)
|
||||
return f.renderMarkdown(content, "")
|
||||
}
|
||||
|
||||
// Regex patterns for matching relative URLs that need rewriting
|
||||
// These match src="..." or href="..." where the URL is relative (not absolute, not data:, not #anchor)
|
||||
var (
|
||||
// Match src="filename" where filename doesn't start with http://, https://, //, /, #, data:, or mailto:
|
||||
relativeSrcPattern = regexp.MustCompile(`src="([^"/:][^"]*)"`)
|
||||
// Match href="filename" where filename doesn't start with http://, https://, //, /, #, data:, or mailto:
|
||||
relativeHrefPattern = regexp.MustCompile(`href="([^"/:][^"]*)"`)
|
||||
)
|
||||
|
||||
// rewriteRelativeURLs converts relative URLs to absolute URLs
|
||||
func rewriteRelativeURLs(html, baseURL string) string {
|
||||
if baseURL == "" {
|
||||
@@ -191,20 +213,51 @@ func rewriteRelativeURLs(html, baseURL string) string {
|
||||
return html
|
||||
}
|
||||
|
||||
// Simple string replacement for common patterns
|
||||
// This is a basic implementation - for production, consider using an HTML parser
|
||||
// Handle root-relative URLs (starting with /) first
|
||||
// Must be done before bare relative URLs to avoid double-processing
|
||||
if base.Scheme != "" && base.Host != "" {
|
||||
root := fmt.Sprintf("%s://%s/", base.Scheme, base.Host)
|
||||
// Replace src="/" and href="/" but not src="//" (protocol-relative URLs)
|
||||
html = strings.ReplaceAll(html, `src="/`, fmt.Sprintf(`src="%s`, root))
|
||||
html = strings.ReplaceAll(html, `href="/`, fmt.Sprintf(`href="%s`, root))
|
||||
}
|
||||
|
||||
// Handle explicit relative paths (./something and ../something)
|
||||
html = strings.ReplaceAll(html, `src="./`, fmt.Sprintf(`src="%s`, baseURL))
|
||||
html = strings.ReplaceAll(html, `href="./`, fmt.Sprintf(`href="%s`, baseURL))
|
||||
html = strings.ReplaceAll(html, `src="../`, fmt.Sprintf(`src="%s../`, baseURL))
|
||||
html = strings.ReplaceAll(html, `href="../`, fmt.Sprintf(`href="%s../`, baseURL))
|
||||
|
||||
// Handle root-relative URLs (starting with /)
|
||||
if base.Scheme != "" && base.Host != "" {
|
||||
root := fmt.Sprintf("%s://%s/", base.Scheme, base.Host)
|
||||
// Replace src="/" and href="/" but not src="//" (absolute URLs)
|
||||
html = strings.ReplaceAll(html, `src="/`, fmt.Sprintf(`src="%s`, root))
|
||||
html = strings.ReplaceAll(html, `href="/`, fmt.Sprintf(`href="%s`, root))
|
||||
}
|
||||
// Handle bare relative URLs (e.g., src="image.png" without ./ prefix)
|
||||
// Skip URLs that are already absolute (start with http://, https://, or //)
|
||||
// Skip anchors (#), data URLs (data:), and mailto links
|
||||
html = relativeSrcPattern.ReplaceAllStringFunc(html, func(match string) string {
|
||||
// Extract the URL from src="..."
|
||||
url := match[5 : len(match)-1] // Remove 'src="' and '"'
|
||||
|
||||
// Skip if already processed or is a special URL type
|
||||
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") ||
|
||||
strings.HasPrefix(url, "//") || strings.HasPrefix(url, "#") ||
|
||||
strings.HasPrefix(url, "data:") || strings.HasPrefix(url, "mailto:") {
|
||||
return match
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`src="%s%s"`, baseURL, url)
|
||||
})
|
||||
|
||||
html = relativeHrefPattern.ReplaceAllStringFunc(html, func(match string) string {
|
||||
// Extract the URL from href="..."
|
||||
url := match[6 : len(match)-1] // Remove 'href="' and '"'
|
||||
|
||||
// Skip if already processed or is a special URL type
|
||||
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") ||
|
||||
strings.HasPrefix(url, "//") || strings.HasPrefix(url, "#") ||
|
||||
strings.HasPrefix(url, "data:") || strings.HasPrefix(url, "mailto:") {
|
||||
return match
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`href="%s%s"`, baseURL, url)
|
||||
})
|
||||
|
||||
return html
|
||||
}
|
||||
|
||||
@@ -145,6 +145,48 @@ func TestRewriteRelativeURLs(t *testing.T) {
|
||||
baseURL: "https://example.com/docs/",
|
||||
expected: `<img src="https://example.com//cdn.example.com/image.png">`,
|
||||
},
|
||||
{
|
||||
name: "bare relative src (no ./ prefix)",
|
||||
html: `<img src="image.png">`,
|
||||
baseURL: "https://example.com/docs/",
|
||||
expected: `<img src="https://example.com/docs/image.png">`,
|
||||
},
|
||||
{
|
||||
name: "bare relative href (no ./ prefix)",
|
||||
html: `<a href="page.html">link</a>`,
|
||||
baseURL: "https://example.com/docs/",
|
||||
expected: `<a href="https://example.com/docs/page.html">link</a>`,
|
||||
},
|
||||
{
|
||||
name: "bare relative with path",
|
||||
html: `<img src="images/logo.png">`,
|
||||
baseURL: "https://example.com/docs/",
|
||||
expected: `<img src="https://example.com/docs/images/logo.png">`,
|
||||
},
|
||||
{
|
||||
name: "anchor links unchanged",
|
||||
html: `<a href="#section">link</a>`,
|
||||
baseURL: "https://example.com/docs/",
|
||||
expected: `<a href="#section">link</a>`,
|
||||
},
|
||||
{
|
||||
name: "data URLs unchanged",
|
||||
html: `<img src="data:image/png;base64,abc123">`,
|
||||
baseURL: "https://example.com/docs/",
|
||||
expected: `<img src="data:image/png;base64,abc123">`,
|
||||
},
|
||||
{
|
||||
name: "mailto links unchanged",
|
||||
html: `<a href="mailto:test@example.com">email</a>`,
|
||||
baseURL: "https://example.com/docs/",
|
||||
expected: `<a href="mailto:test@example.com">email</a>`,
|
||||
},
|
||||
{
|
||||
name: "mixed bare and prefixed relative URLs",
|
||||
html: `<img src="slices_and_lucy.png"><a href="./other.md">link</a>`,
|
||||
baseURL: "https://github.com/user/repo/blob/main/",
|
||||
expected: `<img src="https://github.com/user/repo/blob/main/slices_and_lucy.png"><a href="https://github.com/user/repo/blob/main/other.md">link</a>`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -157,4 +199,110 @@ func TestRewriteRelativeURLs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetcher_RenderMarkdown(t *testing.T) {
|
||||
fetcher := NewFetcher()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
wantContain string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple paragraph",
|
||||
content: "Hello, world!",
|
||||
wantContain: "<p>Hello, world!</p>",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "heading",
|
||||
content: "# My App",
|
||||
wantContain: "<h1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "bold text",
|
||||
content: "This is **bold** text.",
|
||||
wantContain: "<strong>bold</strong>",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "italic text",
|
||||
content: "This is *italic* text.",
|
||||
wantContain: "<em>italic</em>",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "code block",
|
||||
content: "```\ncode here\n```",
|
||||
wantContain: "<pre>",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "link",
|
||||
content: "[Link text](https://example.com)",
|
||||
wantContain: `href="https://example.com"`,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "image",
|
||||
content: "",
|
||||
wantContain: `src="https://example.com/image.png"`,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unordered list",
|
||||
content: "- Item 1\n- Item 2",
|
||||
wantContain: "<ul>",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "ordered list",
|
||||
content: "1. Item 1\n2. Item 2",
|
||||
wantContain: "<ol>",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
wantContain: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "complex markdown",
|
||||
content: "# Title\n\nA paragraph with **bold** and *italic* text.\n\n- List item 1\n- List item 2\n\n```go\nfunc main() {}\n```",
|
||||
wantContain: "<h1",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
html, err := fetcher.RenderMarkdown([]byte(tt.content))
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("RenderMarkdown() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && tt.wantContain != "" {
|
||||
if !containsSubstring(html, tt.wantContain) {
|
||||
t.Errorf("RenderMarkdown() = %q, want to contain %q", html, tt.wantContain)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func containsSubstring(s, substr string) bool {
|
||||
return len(substr) == 0 || (len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstringHelper(s, substr)))
|
||||
}
|
||||
|
||||
func containsSubstringHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: Add README fetching and caching tests
|
||||
|
||||
103
pkg/appview/readme/source.go
Normal file
103
pkg/appview/readme/source.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package readme
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Platform represents a supported Git hosting platform
|
||||
type Platform string
|
||||
|
||||
const (
|
||||
PlatformGitHub Platform = "github"
|
||||
PlatformGitLab Platform = "gitlab"
|
||||
PlatformTangled Platform = "tangled"
|
||||
)
|
||||
|
||||
// ParseSourceURL extracts platform, user, and repo from a source repository URL.
|
||||
// Returns ok=false if the URL is not a recognized pattern.
|
||||
func ParseSourceURL(sourceURL string) (platform Platform, user, repo string, ok bool) {
|
||||
if sourceURL == "" {
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(sourceURL)
|
||||
if err != nil {
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
// Normalize: remove trailing slash and .git suffix
|
||||
path := strings.TrimSuffix(parsed.Path, "/")
|
||||
path = strings.TrimSuffix(path, ".git")
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
|
||||
if path == "" {
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
host := strings.ToLower(parsed.Host)
|
||||
|
||||
switch {
|
||||
case host == "github.com":
|
||||
// GitHub: github.com/{user}/{repo}
|
||||
parts := strings.SplitN(path, "/", 3)
|
||||
if len(parts) < 2 || parts[0] == "" || parts[1] == "" {
|
||||
return "", "", "", false
|
||||
}
|
||||
return PlatformGitHub, parts[0], parts[1], true
|
||||
|
||||
case host == "gitlab.com":
|
||||
// GitLab: gitlab.com/{user}/{repo} or gitlab.com/{group}/{subgroup}/{repo}
|
||||
// For nested groups, user = everything except last part, repo = last part
|
||||
lastSlash := strings.LastIndex(path, "/")
|
||||
if lastSlash == -1 || lastSlash == 0 {
|
||||
return "", "", "", false
|
||||
}
|
||||
user = path[:lastSlash]
|
||||
repo = path[lastSlash+1:]
|
||||
if user == "" || repo == "" {
|
||||
return "", "", "", false
|
||||
}
|
||||
return PlatformGitLab, user, repo, true
|
||||
|
||||
case host == "tangled.org" || host == "tangled.sh":
|
||||
// Tangled: tangled.org/{user}/{repo} or tangled.sh/@{user}/{repo} (legacy)
|
||||
// Strip leading @ from user if present
|
||||
path = strings.TrimPrefix(path, "@")
|
||||
parts := strings.SplitN(path, "/", 3)
|
||||
if len(parts) < 2 || parts[0] == "" || parts[1] == "" {
|
||||
return "", "", "", false
|
||||
}
|
||||
return PlatformTangled, parts[0], parts[1], true
|
||||
|
||||
default:
|
||||
return "", "", "", false
|
||||
}
|
||||
}
|
||||
|
||||
// DeriveReadmeURL converts a source repository URL to a raw README URL.
|
||||
// Returns empty string if platform is not supported.
|
||||
func DeriveReadmeURL(sourceURL, branch string) string {
|
||||
platform, user, repo, ok := ParseSourceURL(sourceURL)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch platform {
|
||||
case PlatformGitHub:
|
||||
// https://raw.githubusercontent.com/{user}/{repo}/refs/heads/{branch}/README.md
|
||||
return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/refs/heads/%s/README.md", user, repo, branch)
|
||||
|
||||
case PlatformGitLab:
|
||||
// https://gitlab.com/{user}/{repo}/-/raw/{branch}/README.md
|
||||
return fmt.Sprintf("https://gitlab.com/%s/%s/-/raw/%s/README.md", user, repo, branch)
|
||||
|
||||
case PlatformTangled:
|
||||
// https://tangled.org/{user}/{repo}/raw/{branch}/README.md
|
||||
return fmt.Sprintf("https://tangled.org/%s/%s/raw/%s/README.md", user, repo, branch)
|
||||
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
241
pkg/appview/readme/source_test.go
Normal file
241
pkg/appview/readme/source_test.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package readme
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseSourceURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sourceURL string
|
||||
wantPlatform Platform
|
||||
wantUser string
|
||||
wantRepo string
|
||||
wantOK bool
|
||||
}{
|
||||
// GitHub
|
||||
{
|
||||
name: "github standard",
|
||||
sourceURL: "https://github.com/bigmoves/quickslice",
|
||||
wantPlatform: PlatformGitHub,
|
||||
wantUser: "bigmoves",
|
||||
wantRepo: "quickslice",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "github with .git suffix",
|
||||
sourceURL: "https://github.com/user/repo.git",
|
||||
wantPlatform: PlatformGitHub,
|
||||
wantUser: "user",
|
||||
wantRepo: "repo",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "github with trailing slash",
|
||||
sourceURL: "https://github.com/user/repo/",
|
||||
wantPlatform: PlatformGitHub,
|
||||
wantUser: "user",
|
||||
wantRepo: "repo",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "github with subpath (ignored)",
|
||||
sourceURL: "https://github.com/user/repo/tree/main",
|
||||
wantPlatform: PlatformGitHub,
|
||||
wantUser: "user",
|
||||
wantRepo: "repo",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "github user only",
|
||||
sourceURL: "https://github.com/user",
|
||||
wantOK: false,
|
||||
},
|
||||
|
||||
// GitLab
|
||||
{
|
||||
name: "gitlab standard",
|
||||
sourceURL: "https://gitlab.com/user/repo",
|
||||
wantPlatform: PlatformGitLab,
|
||||
wantUser: "user",
|
||||
wantRepo: "repo",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "gitlab nested groups",
|
||||
sourceURL: "https://gitlab.com/group/subgroup/repo",
|
||||
wantPlatform: PlatformGitLab,
|
||||
wantUser: "group/subgroup",
|
||||
wantRepo: "repo",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "gitlab deep nested groups",
|
||||
sourceURL: "https://gitlab.com/a/b/c/d/repo",
|
||||
wantPlatform: PlatformGitLab,
|
||||
wantUser: "a/b/c/d",
|
||||
wantRepo: "repo",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "gitlab with .git suffix",
|
||||
sourceURL: "https://gitlab.com/user/repo.git",
|
||||
wantPlatform: PlatformGitLab,
|
||||
wantUser: "user",
|
||||
wantRepo: "repo",
|
||||
wantOK: true,
|
||||
},
|
||||
|
||||
// Tangled
|
||||
{
|
||||
name: "tangled standard",
|
||||
sourceURL: "https://tangled.org/evan.jarrett.net/at-container-registry",
|
||||
wantPlatform: PlatformTangled,
|
||||
wantUser: "evan.jarrett.net",
|
||||
wantRepo: "at-container-registry",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "tangled with legacy @ prefix",
|
||||
sourceURL: "https://tangled.org/@evan.jarrett.net/at-container-registry",
|
||||
wantPlatform: PlatformTangled,
|
||||
wantUser: "evan.jarrett.net",
|
||||
wantRepo: "at-container-registry",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "tangled.sh domain",
|
||||
sourceURL: "https://tangled.sh/user/repo",
|
||||
wantPlatform: PlatformTangled,
|
||||
wantUser: "user",
|
||||
wantRepo: "repo",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "tangled with trailing slash",
|
||||
sourceURL: "https://tangled.org/user/repo/",
|
||||
wantPlatform: PlatformTangled,
|
||||
wantUser: "user",
|
||||
wantRepo: "repo",
|
||||
wantOK: true,
|
||||
},
|
||||
|
||||
// Unsupported / Invalid
|
||||
{
|
||||
name: "unsupported platform",
|
||||
sourceURL: "https://bitbucket.org/user/repo",
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "empty url",
|
||||
sourceURL: "",
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "invalid url",
|
||||
sourceURL: "not-a-url",
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "just host",
|
||||
sourceURL: "https://github.com",
|
||||
wantOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
platform, user, repo, ok := ParseSourceURL(tt.sourceURL)
|
||||
if ok != tt.wantOK {
|
||||
t.Errorf("ParseSourceURL(%q) ok = %v, want %v", tt.sourceURL, ok, tt.wantOK)
|
||||
return
|
||||
}
|
||||
if !tt.wantOK {
|
||||
return
|
||||
}
|
||||
if platform != tt.wantPlatform {
|
||||
t.Errorf("ParseSourceURL(%q) platform = %v, want %v", tt.sourceURL, platform, tt.wantPlatform)
|
||||
}
|
||||
if user != tt.wantUser {
|
||||
t.Errorf("ParseSourceURL(%q) user = %q, want %q", tt.sourceURL, user, tt.wantUser)
|
||||
}
|
||||
if repo != tt.wantRepo {
|
||||
t.Errorf("ParseSourceURL(%q) repo = %q, want %q", tt.sourceURL, repo, tt.wantRepo)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveReadmeURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sourceURL string
|
||||
branch string
|
||||
want string
|
||||
}{
|
||||
// GitHub
|
||||
{
|
||||
name: "github main",
|
||||
sourceURL: "https://github.com/bigmoves/quickslice",
|
||||
branch: "main",
|
||||
want: "https://raw.githubusercontent.com/bigmoves/quickslice/refs/heads/main/README.md",
|
||||
},
|
||||
{
|
||||
name: "github master",
|
||||
sourceURL: "https://github.com/user/repo",
|
||||
branch: "master",
|
||||
want: "https://raw.githubusercontent.com/user/repo/refs/heads/master/README.md",
|
||||
},
|
||||
|
||||
// GitLab
|
||||
{
|
||||
name: "gitlab main",
|
||||
sourceURL: "https://gitlab.com/user/repo",
|
||||
branch: "main",
|
||||
want: "https://gitlab.com/user/repo/-/raw/main/README.md",
|
||||
},
|
||||
{
|
||||
name: "gitlab nested groups",
|
||||
sourceURL: "https://gitlab.com/group/subgroup/repo",
|
||||
branch: "main",
|
||||
want: "https://gitlab.com/group/subgroup/repo/-/raw/main/README.md",
|
||||
},
|
||||
|
||||
// Tangled
|
||||
{
|
||||
name: "tangled main",
|
||||
sourceURL: "https://tangled.org/evan.jarrett.net/at-container-registry",
|
||||
branch: "main",
|
||||
want: "https://tangled.org/evan.jarrett.net/at-container-registry/raw/main/README.md",
|
||||
},
|
||||
{
|
||||
name: "tangled legacy @ prefix",
|
||||
sourceURL: "https://tangled.org/@user/repo",
|
||||
branch: "main",
|
||||
want: "https://tangled.org/user/repo/raw/main/README.md",
|
||||
},
|
||||
|
||||
// Unsupported
|
||||
{
|
||||
name: "unsupported platform",
|
||||
sourceURL: "https://bitbucket.org/user/repo",
|
||||
branch: "main",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty url",
|
||||
sourceURL: "",
|
||||
branch: "main",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := DeriveReadmeURL(tt.sourceURL, tt.branch)
|
||||
if got != tt.want {
|
||||
t.Errorf("DeriveReadmeURL(%q, %q) = %q, want %q", tt.sourceURL, tt.branch, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -27,8 +27,9 @@ type UIDependencies struct {
|
||||
BaseURL string
|
||||
DeviceStore *db.DeviceStore
|
||||
HealthChecker *holdhealth.Checker
|
||||
ReadmeCache *readme.Cache
|
||||
ReadmeFetcher *readme.Fetcher
|
||||
Templates *template.Template
|
||||
DefaultHoldDID string // For UserContext creation
|
||||
}
|
||||
|
||||
// RegisterUIRoutes registers all web UI and API routes on the provided router
|
||||
@@ -36,6 +37,14 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
// Extract trimmed registry URL for templates
|
||||
registryURL := trimRegistryURL(deps.BaseURL)
|
||||
|
||||
// Create web auth dependencies for middleware (enables UserContext in web routes)
|
||||
webAuthDeps := middleware.WebAuthDeps{
|
||||
SessionStore: deps.SessionStore,
|
||||
Database: deps.Database,
|
||||
Refresher: deps.Refresher,
|
||||
DefaultHoldDID: deps.DefaultHoldDID,
|
||||
}
|
||||
|
||||
// OAuth login routes (public)
|
||||
router.Get("/auth/oauth/login", (&uihandlers.LoginHandler{
|
||||
Templates: deps.Templates,
|
||||
@@ -45,7 +54,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
|
||||
// Public routes (with optional auth for navbar)
|
||||
// SECURITY: Public pages use read-only DB
|
||||
router.Get("/", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.HomeHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Templates: deps.Templates,
|
||||
@@ -53,7 +62,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
},
|
||||
).ServeHTTP)
|
||||
|
||||
router.Get("/api/recent-pushes", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/api/recent-pushes", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.RecentPushesHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Templates: deps.Templates,
|
||||
@@ -63,7 +72,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
).ServeHTTP)
|
||||
|
||||
// SECURITY: Search uses read-only DB to prevent writes and limit access to sensitive tables
|
||||
router.Get("/search", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/search", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.SearchHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Templates: deps.Templates,
|
||||
@@ -71,7 +80,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
},
|
||||
).ServeHTTP)
|
||||
|
||||
router.Get("/api/search-results", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/api/search-results", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.SearchResultsHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Templates: deps.Templates,
|
||||
@@ -80,7 +89,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
).ServeHTTP)
|
||||
|
||||
// Install page (public)
|
||||
router.Get("/install", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/install", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.InstallHandler{
|
||||
Templates: deps.Templates,
|
||||
RegistryURL: registryURL,
|
||||
@@ -88,7 +97,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
).ServeHTTP)
|
||||
|
||||
// API route for repository stats (public, read-only)
|
||||
router.Get("/api/stats/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/api/stats/{handle}/{repository}", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.GetStatsHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Directory: deps.OAuthClientApp.Dir,
|
||||
@@ -96,7 +105,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
).ServeHTTP)
|
||||
|
||||
// API routes for stars (require authentication)
|
||||
router.Post("/api/stars/{handle}/{repository}", middleware.RequireAuth(deps.SessionStore, deps.Database)(
|
||||
router.Post("/api/stars/{handle}/{repository}", middleware.RequireAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.StarRepositoryHandler{
|
||||
DB: deps.Database, // Needs write access
|
||||
Directory: deps.OAuthClientApp.Dir,
|
||||
@@ -104,7 +113,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
},
|
||||
).ServeHTTP)
|
||||
|
||||
router.Delete("/api/stars/{handle}/{repository}", middleware.RequireAuth(deps.SessionStore, deps.Database)(
|
||||
router.Delete("/api/stars/{handle}/{repository}", middleware.RequireAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.UnstarRepositoryHandler{
|
||||
DB: deps.Database, // Needs write access
|
||||
Directory: deps.OAuthClientApp.Dir,
|
||||
@@ -112,7 +121,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
},
|
||||
).ServeHTTP)
|
||||
|
||||
router.Get("/api/stars/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/api/stars/{handle}/{repository}", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.CheckStarHandler{
|
||||
DB: deps.ReadOnlyDB, // Read-only check
|
||||
Directory: deps.OAuthClientApp.Dir,
|
||||
@@ -121,7 +130,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
).ServeHTTP)
|
||||
|
||||
// Manifest detail API endpoint
|
||||
router.Get("/api/manifests/{handle}/{repository}/{digest}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/api/manifests/{handle}/{repository}/{digest}", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.ManifestDetailHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Directory: deps.OAuthClientApp.Dir,
|
||||
@@ -133,7 +142,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
HealthChecker: deps.HealthChecker,
|
||||
}).ServeHTTP)
|
||||
|
||||
router.Get("/u/{handle}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/u/{handle}", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.UserPageHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Templates: deps.Templates,
|
||||
@@ -152,7 +161,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
DB: deps.ReadOnlyDB,
|
||||
}).ServeHTTP)
|
||||
|
||||
router.Get("/r/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||
router.Get("/r/{handle}/{repository}", middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.RepositoryPageHandler{
|
||||
DB: deps.ReadOnlyDB,
|
||||
Templates: deps.Templates,
|
||||
@@ -160,13 +169,13 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
Directory: deps.OAuthClientApp.Dir,
|
||||
Refresher: deps.Refresher,
|
||||
HealthChecker: deps.HealthChecker,
|
||||
ReadmeCache: deps.ReadmeCache,
|
||||
ReadmeFetcher: deps.ReadmeFetcher,
|
||||
},
|
||||
).ServeHTTP)
|
||||
|
||||
// Authenticated routes
|
||||
router.Group(func(r chi.Router) {
|
||||
r.Use(middleware.RequireAuth(deps.SessionStore, deps.Database))
|
||||
r.Use(middleware.RequireAuthWithDeps(webAuthDeps))
|
||||
|
||||
r.Get("/settings", (&uihandlers.SettingsHandler{
|
||||
Templates: deps.Templates,
|
||||
@@ -188,6 +197,11 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
Refresher: deps.Refresher,
|
||||
}).ServeHTTP)
|
||||
|
||||
r.Post("/api/images/{repository}/avatar", (&uihandlers.UploadAvatarHandler{
|
||||
DB: deps.Database,
|
||||
Refresher: deps.Refresher,
|
||||
}).ServeHTTP)
|
||||
|
||||
// Device approval page (authenticated)
|
||||
r.Get("/device", (&uihandlers.DeviceApprovalPageHandler{
|
||||
Store: deps.DeviceStore,
|
||||
@@ -219,6 +233,14 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
||||
}
|
||||
router.Get("/auth/logout", logoutHandler.ServeHTTP)
|
||||
router.Post("/auth/logout", logoutHandler.ServeHTTP)
|
||||
|
||||
// Custom 404 handler
|
||||
router.NotFound(middleware.OptionalAuthWithDeps(webAuthDeps)(
|
||||
&uihandlers.NotFoundHandler{
|
||||
Templates: deps.Templates,
|
||||
RegistryURL: registryURL,
|
||||
},
|
||||
).ServeHTTP)
|
||||
}
|
||||
|
||||
// CORSMiddleware returns a middleware that sets CORS headers for API endpoints
|
||||
|
||||
@@ -38,6 +38,10 @@
|
||||
--version-badge-text: #7b1fa2;
|
||||
--version-badge-border: #ba68c8;
|
||||
|
||||
/* Attestation badge */
|
||||
--attestation-badge-bg: #d1fae5;
|
||||
--attestation-badge-text: #065f46;
|
||||
|
||||
/* Hero section colors */
|
||||
--hero-bg-start: #f8f9fa;
|
||||
--hero-bg-end: #e9ecef;
|
||||
@@ -90,6 +94,10 @@
|
||||
--version-badge-text: #ffffff;
|
||||
--version-badge-border: #ba68c8;
|
||||
|
||||
/* Attestation badge */
|
||||
--attestation-badge-bg: #065f46;
|
||||
--attestation-badge-text: #6ee7b7;
|
||||
|
||||
/* Hero section colors */
|
||||
--hero-bg-start: #2d2d2d;
|
||||
--hero-bg-end: #1a1a1a;
|
||||
@@ -109,7 +117,9 @@
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
|
||||
font-family:
|
||||
-apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue",
|
||||
Arial, sans-serif;
|
||||
background: var(--bg);
|
||||
color: var(--fg);
|
||||
line-height: 1.6;
|
||||
@@ -170,7 +180,7 @@ body {
|
||||
}
|
||||
|
||||
.nav-links a:hover {
|
||||
background:var(--secondary);
|
||||
background: var(--secondary);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
@@ -193,7 +203,7 @@ body {
|
||||
}
|
||||
|
||||
.user-menu-btn:hover {
|
||||
background:var(--secondary);
|
||||
background: var(--secondary);
|
||||
}
|
||||
|
||||
.user-avatar {
|
||||
@@ -266,7 +276,7 @@ body {
|
||||
position: absolute;
|
||||
top: calc(100% + 0.5rem);
|
||||
right: 0;
|
||||
background:var(--bg);
|
||||
background: var(--bg);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 8px;
|
||||
box-shadow: var(--shadow-lg);
|
||||
@@ -287,7 +297,7 @@ body {
|
||||
color: var(--fg);
|
||||
text-decoration: none;
|
||||
border: none;
|
||||
background:var(--bg);
|
||||
background: var(--bg);
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
font-size: 0.95rem;
|
||||
@@ -309,7 +319,10 @@ body {
|
||||
}
|
||||
|
||||
/* Buttons */
|
||||
button, .btn, .btn-primary, .btn-secondary {
|
||||
button,
|
||||
.btn,
|
||||
.btn-primary,
|
||||
.btn-secondary {
|
||||
padding: 0.5rem 1rem;
|
||||
background: var(--button-primary);
|
||||
color: var(--btn-text);
|
||||
@@ -322,7 +335,10 @@ button, .btn, .btn-primary, .btn-secondary {
|
||||
transition: opacity 0.2s;
|
||||
}
|
||||
|
||||
button:hover, .btn:hover, .btn-primary:hover, .btn-secondary:hover {
|
||||
button:hover,
|
||||
.btn:hover,
|
||||
.btn-primary:hover,
|
||||
.btn-secondary:hover {
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
@@ -393,12 +409,13 @@ button:hover, .btn:hover, .btn-primary:hover, .btn-secondary:hover {
|
||||
}
|
||||
|
||||
/* Cards */
|
||||
.push-card, .repository-card {
|
||||
.push-card,
|
||||
.repository-card {
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 8px;
|
||||
padding: 1rem;
|
||||
margin-bottom: 1rem;
|
||||
background:var(--bg);
|
||||
background: var(--bg);
|
||||
box-shadow: var(--shadow-sm);
|
||||
}
|
||||
|
||||
@@ -449,7 +466,7 @@ button:hover, .btn:hover, .btn-primary:hover, .btn-secondary:hover {
|
||||
}
|
||||
|
||||
.digest {
|
||||
font-family: 'Monaco', 'Courier New', monospace;
|
||||
font-family: "Monaco", "Courier New", monospace;
|
||||
font-size: 0.85rem;
|
||||
background: var(--code-bg);
|
||||
padding: 0.1rem 0.3rem;
|
||||
@@ -492,7 +509,7 @@ button:hover, .btn:hover, .btn-primary:hover, .btn-secondary:hover {
|
||||
}
|
||||
|
||||
.docker-command-text {
|
||||
font-family: 'Monaco', 'Courier New', monospace;
|
||||
font-family: "Monaco", "Courier New", monospace;
|
||||
font-size: 0.85rem;
|
||||
color: var(--fg);
|
||||
flex: 0 1 auto;
|
||||
@@ -510,7 +527,9 @@ button:hover, .btn:hover, .btn-primary:hover, .btn-secondary:hover {
|
||||
border-radius: 4px;
|
||||
opacity: 0;
|
||||
visibility: hidden;
|
||||
transition: opacity 0.2s, visibility 0.2s;
|
||||
transition:
|
||||
opacity 0.2s,
|
||||
visibility 0.2s;
|
||||
}
|
||||
|
||||
.docker-command:hover .copy-btn {
|
||||
@@ -752,7 +771,7 @@ a.license-badge:hover {
|
||||
}
|
||||
|
||||
.repo-stats {
|
||||
color:var(--border-dark);
|
||||
color: var(--border-dark);
|
||||
font-size: 0.9rem;
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
@@ -781,17 +800,20 @@ a.license-badge:hover {
|
||||
padding-top: 1rem;
|
||||
}
|
||||
|
||||
.tags-section, .manifests-section {
|
||||
.tags-section,
|
||||
.manifests-section {
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.tags-section h3, .manifests-section h3 {
|
||||
.tags-section h3,
|
||||
.manifests-section h3 {
|
||||
font-size: 1.1rem;
|
||||
margin-bottom: 0.5rem;
|
||||
color: var(--secondary);
|
||||
}
|
||||
|
||||
.tag-row, .manifest-row {
|
||||
.tag-row,
|
||||
.manifest-row {
|
||||
display: flex;
|
||||
gap: 1rem;
|
||||
align-items: center;
|
||||
@@ -799,7 +821,8 @@ a.license-badge:hover {
|
||||
border-bottom: 1px solid var(--border);
|
||||
}
|
||||
|
||||
.tag-row:last-child, .manifest-row:last-child {
|
||||
.tag-row:last-child,
|
||||
.manifest-row:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
@@ -821,7 +844,7 @@ a.license-badge:hover {
|
||||
}
|
||||
|
||||
.settings-section {
|
||||
background:var(--bg);
|
||||
background: var(--bg);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
@@ -918,7 +941,7 @@ a.license-badge:hover {
|
||||
padding: 1rem;
|
||||
border-radius: 4px;
|
||||
overflow-x: auto;
|
||||
font-family: 'Monaco', 'Courier New', monospace;
|
||||
font-family: "Monaco", "Courier New", monospace;
|
||||
font-size: 0.85rem;
|
||||
border: 1px solid var(--border);
|
||||
}
|
||||
@@ -1004,13 +1027,6 @@ a.license-badge:hover {
|
||||
margin: 1rem 0;
|
||||
}
|
||||
|
||||
/* Load More Button */
|
||||
.load-more {
|
||||
width: 100%;
|
||||
margin-top: 1rem;
|
||||
background: var(--secondary);
|
||||
}
|
||||
|
||||
/* Login Page */
|
||||
.login-page {
|
||||
max-width: 450px;
|
||||
@@ -1031,7 +1047,7 @@ a.license-badge:hover {
|
||||
}
|
||||
|
||||
.login-form {
|
||||
background:var(--bg);
|
||||
background: var(--bg);
|
||||
padding: 2rem;
|
||||
border-radius: 8px;
|
||||
border: 1px solid var(--border);
|
||||
@@ -1182,7 +1198,7 @@ a.license-badge:hover {
|
||||
}
|
||||
|
||||
.repository-header {
|
||||
background:var(--bg);
|
||||
background: var(--bg);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 8px;
|
||||
padding: 2rem;
|
||||
@@ -1220,6 +1236,35 @@ a.license-badge:hover {
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.repo-hero-icon-wrapper {
|
||||
position: relative;
|
||||
display: inline-block;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.avatar-upload-overlay {
|
||||
position: absolute;
|
||||
inset: 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background: rgba(0, 0, 0, 0.5);
|
||||
border-radius: 12px;
|
||||
opacity: 0;
|
||||
cursor: pointer;
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
|
||||
.avatar-upload-overlay i {
|
||||
color: white;
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
}
|
||||
|
||||
.repo-hero-icon-wrapper:hover .avatar-upload-overlay {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.repo-hero-info {
|
||||
flex: 1;
|
||||
}
|
||||
@@ -1290,7 +1335,7 @@ a.license-badge:hover {
|
||||
}
|
||||
|
||||
.star-btn.starred {
|
||||
border-color:var(--star);
|
||||
border-color: var(--star);
|
||||
background: var(--code-bg);
|
||||
}
|
||||
|
||||
@@ -1374,7 +1419,7 @@ a.license-badge:hover {
|
||||
}
|
||||
|
||||
.repo-section {
|
||||
background:var(--bg);
|
||||
background: var(--bg);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
@@ -1389,20 +1434,23 @@ a.license-badge:hover {
|
||||
border-bottom: 2px solid var(--border);
|
||||
}
|
||||
|
||||
.tags-list, .manifests-list {
|
||||
.tags-list,
|
||||
.manifests-list {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.tag-item, .manifest-item {
|
||||
.tag-item,
|
||||
.manifest-item {
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 6px;
|
||||
padding: 1rem;
|
||||
background: var(--hover-bg);
|
||||
}
|
||||
|
||||
.tag-item-header, .manifest-item-header {
|
||||
.tag-item-header,
|
||||
.manifest-item-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
@@ -1532,7 +1580,7 @@ a.license-badge:hover {
|
||||
color: var(--fg);
|
||||
border: 1px solid var(--border);
|
||||
white-space: nowrap;
|
||||
font-family: 'Monaco', 'Courier New', monospace;
|
||||
font-family: "Monaco", "Courier New", monospace;
|
||||
}
|
||||
|
||||
.platforms-inline {
|
||||
@@ -1570,20 +1618,21 @@ a.license-badge:hover {
|
||||
.badge-attestation {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.35rem;
|
||||
padding: 0.25rem 0.5rem;
|
||||
background: #f3e8ff;
|
||||
color: #7c3aed;
|
||||
border: 1px solid #c4b5fd;
|
||||
border-radius: 4px;
|
||||
font-size: 0.85rem;
|
||||
gap: 0.3rem;
|
||||
padding: 0.25rem 0.6rem;
|
||||
background: var(--attestation-badge-bg);
|
||||
color: var(--attestation-badge-text);
|
||||
border-radius: 12px;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 600;
|
||||
margin-left: 0.5rem;
|
||||
vertical-align: middle;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.badge-attestation .lucide {
|
||||
width: 0.9rem;
|
||||
height: 0.9rem;
|
||||
width: 0.75rem;
|
||||
height: 0.75rem;
|
||||
}
|
||||
|
||||
/* Featured Repositories Section */
|
||||
@@ -1736,7 +1785,11 @@ a.license-badge:hover {
|
||||
|
||||
/* Hero Section */
|
||||
.hero-section {
|
||||
background: linear-gradient(135deg, var(--hero-bg-start) 0%, var(--hero-bg-end) 100%);
|
||||
background: linear-gradient(
|
||||
135deg,
|
||||
var(--hero-bg-start) 0%,
|
||||
var(--hero-bg-end) 100%
|
||||
);
|
||||
padding: 4rem 2rem;
|
||||
border-bottom: 1px solid var(--border);
|
||||
}
|
||||
@@ -1801,7 +1854,7 @@ a.license-badge:hover {
|
||||
.terminal-content {
|
||||
padding: 1.5rem;
|
||||
margin: 0;
|
||||
font-family: 'Monaco', 'Courier New', monospace;
|
||||
font-family: "Monaco", "Courier New", monospace;
|
||||
font-size: 0.95rem;
|
||||
line-height: 1.8;
|
||||
color: var(--terminal-text);
|
||||
@@ -1957,7 +2010,7 @@ a.license-badge:hover {
|
||||
}
|
||||
|
||||
.code-block code {
|
||||
font-family: 'Monaco', 'Menlo', monospace;
|
||||
font-family: "Monaco", "Menlo", monospace;
|
||||
font-size: 0.9rem;
|
||||
line-height: 1.5;
|
||||
white-space: pre-wrap;
|
||||
@@ -2014,7 +2067,8 @@ a.license-badge:hover {
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.tag-row, .manifest-row {
|
||||
.tag-row,
|
||||
.manifest-row {
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
@@ -2103,7 +2157,7 @@ a.license-badge:hover {
|
||||
/* README and Repository Layout */
|
||||
.repo-content-layout {
|
||||
display: grid;
|
||||
grid-template-columns: 7fr 3fr;
|
||||
grid-template-columns: 6fr 4fr;
|
||||
gap: 2rem;
|
||||
margin-top: 2rem;
|
||||
}
|
||||
@@ -2214,7 +2268,8 @@ a.license-badge:hover {
|
||||
background: var(--code-bg);
|
||||
padding: 0.2rem 0.4rem;
|
||||
border-radius: 3px;
|
||||
font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, monospace;
|
||||
font-family:
|
||||
"SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
@@ -2318,3 +2373,59 @@ a.license-badge:hover {
|
||||
padding: 0.75rem;
|
||||
}
|
||||
}
|
||||
|
||||
/* 404 Error Page */
|
||||
.error-page {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
min-height: calc(100vh - 60px);
|
||||
text-align: center;
|
||||
padding: 2rem;
|
||||
}
|
||||
|
||||
.error-content {
|
||||
max-width: 480px;
|
||||
}
|
||||
|
||||
.error-icon {
|
||||
width: 80px;
|
||||
height: 80px;
|
||||
color: var(--secondary);
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.error-code {
|
||||
font-size: 8rem;
|
||||
font-weight: 700;
|
||||
color: var(--primary);
|
||||
line-height: 1;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.error-content h1 {
|
||||
font-size: 2rem;
|
||||
margin-bottom: 0.75rem;
|
||||
color: var(--fg);
|
||||
}
|
||||
|
||||
.error-content p {
|
||||
font-size: 1.125rem;
|
||||
color: var(--secondary);
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.error-code {
|
||||
font-size: 5rem;
|
||||
}
|
||||
|
||||
.error-icon {
|
||||
width: 60px;
|
||||
height: 60px;
|
||||
}
|
||||
|
||||
.error-content h1 {
|
||||
font-size: 1.5rem;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -434,6 +434,69 @@ function removeManifestElement(sanitizedId) {
|
||||
}
|
||||
}
|
||||
|
||||
// Upload repository avatar
|
||||
async function uploadAvatar(input, repository) {
|
||||
const file = input.files[0];
|
||||
if (!file) return;
|
||||
|
||||
// Client-side validation
|
||||
const validTypes = ['image/png', 'image/jpeg', 'image/webp'];
|
||||
if (!validTypes.includes(file.type)) {
|
||||
alert('Please select a PNG, JPEG, or WebP image');
|
||||
return;
|
||||
}
|
||||
if (file.size > 3 * 1024 * 1024) {
|
||||
alert('Image must be less than 3MB');
|
||||
return;
|
||||
}
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append('avatar', file);
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/images/${repository}/avatar`, {
|
||||
method: 'POST',
|
||||
credentials: 'include',
|
||||
body: formData
|
||||
});
|
||||
|
||||
if (response.status === 401) {
|
||||
window.location.href = '/auth/oauth/login';
|
||||
return;
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
throw new Error(error);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
// Update the avatar image on the page
|
||||
const wrapper = document.querySelector('.repo-hero-icon-wrapper');
|
||||
if (!wrapper) return;
|
||||
|
||||
const existingImg = wrapper.querySelector('.repo-hero-icon');
|
||||
const placeholder = wrapper.querySelector('.repo-hero-icon-placeholder');
|
||||
|
||||
if (existingImg) {
|
||||
existingImg.src = data.avatarURL;
|
||||
} else if (placeholder) {
|
||||
const newImg = document.createElement('img');
|
||||
newImg.src = data.avatarURL;
|
||||
newImg.alt = repository;
|
||||
newImg.className = 'repo-hero-icon';
|
||||
placeholder.replaceWith(newImg);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Error uploading avatar:', err);
|
||||
alert('Failed to upload avatar: ' + err.message);
|
||||
}
|
||||
|
||||
// Clear input so same file can be selected again
|
||||
input.value = '';
|
||||
}
|
||||
|
||||
// Close modal when clicking outside
|
||||
document.addEventListener('DOMContentLoaded', () => {
|
||||
const modal = document.getElementById('manifest-delete-modal');
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
)
|
||||
|
||||
// DatabaseMetrics interface for tracking pull/push counts and querying hold DIDs
|
||||
type DatabaseMetrics interface {
|
||||
IncrementPullCount(did, repository string) error
|
||||
IncrementPushCount(did, repository string) error
|
||||
GetLatestHoldDIDForRepo(did, repository string) (string, error)
|
||||
}
|
||||
|
||||
// ReadmeCache interface for README content caching
|
||||
type ReadmeCache interface {
|
||||
Get(ctx context.Context, url string) (string, error)
|
||||
Invalidate(url string) error
|
||||
}
|
||||
|
||||
// RegistryContext bundles all the context needed for registry operations
|
||||
// This includes both per-request data (DID, hold) and shared services
|
||||
type RegistryContext struct {
|
||||
// Per-request identity and routing information
|
||||
DID string // User's DID (e.g., "did:plc:abc123")
|
||||
Handle string // User's handle (e.g., "alice.bsky.social")
|
||||
HoldDID string // Hold service DID (e.g., "did:web:hold01.atcr.io")
|
||||
PDSEndpoint string // User's PDS endpoint URL
|
||||
Repository string // Image repository name (e.g., "debian")
|
||||
ServiceToken string // Service token for hold authentication (cached by middleware)
|
||||
ATProtoClient *atproto.Client // Authenticated ATProto client for this user
|
||||
AuthMethod string // Auth method used ("oauth" or "app_password")
|
||||
|
||||
// Shared services (same for all requests)
|
||||
Database DatabaseMetrics // Metrics tracking database
|
||||
Authorizer auth.HoldAuthorizer // Hold access authorization
|
||||
Refresher *oauth.Refresher // OAuth session manager
|
||||
ReadmeCache ReadmeCache // README content cache
|
||||
}
|
||||
@@ -1,146 +0,0 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
)
|
||||
|
||||
// Mock implementations for testing
|
||||
type mockDatabaseMetrics struct {
|
||||
mu sync.Mutex
|
||||
pullCount int
|
||||
pushCount int
|
||||
}
|
||||
|
||||
func (m *mockDatabaseMetrics) IncrementPullCount(did, repository string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.pullCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDatabaseMetrics) IncrementPushCount(did, repository string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.pushCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDatabaseMetrics) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
|
||||
// Return empty string for mock - tests can override if needed
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *mockDatabaseMetrics) getPullCount() int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.pullCount
|
||||
}
|
||||
|
||||
func (m *mockDatabaseMetrics) getPushCount() int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.pushCount
|
||||
}
|
||||
|
||||
type mockReadmeCache struct{}
|
||||
|
||||
func (m *mockReadmeCache) Get(ctx context.Context, url string) (string, error) {
|
||||
return "# Test README", nil
|
||||
}
|
||||
|
||||
func (m *mockReadmeCache) Invalidate(url string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockHoldAuthorizer struct{}
|
||||
|
||||
func (m *mockHoldAuthorizer) Authorize(holdDID, userDID, permission string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func TestRegistryContext_Fields(t *testing.T) {
|
||||
// Create a sample RegistryContext
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Handle: "alice.bsky.social",
|
||||
HoldDID: "did:web:hold01.atcr.io",
|
||||
PDSEndpoint: "https://bsky.social",
|
||||
Repository: "debian",
|
||||
ServiceToken: "test-token",
|
||||
ATProtoClient: &atproto.Client{
|
||||
// Mock client - would need proper initialization in real tests
|
||||
},
|
||||
Database: &mockDatabaseMetrics{},
|
||||
ReadmeCache: &mockReadmeCache{},
|
||||
}
|
||||
|
||||
// Verify fields are accessible
|
||||
if ctx.DID != "did:plc:test123" {
|
||||
t.Errorf("Expected DID %q, got %q", "did:plc:test123", ctx.DID)
|
||||
}
|
||||
if ctx.Handle != "alice.bsky.social" {
|
||||
t.Errorf("Expected Handle %q, got %q", "alice.bsky.social", ctx.Handle)
|
||||
}
|
||||
if ctx.HoldDID != "did:web:hold01.atcr.io" {
|
||||
t.Errorf("Expected HoldDID %q, got %q", "did:web:hold01.atcr.io", ctx.HoldDID)
|
||||
}
|
||||
if ctx.PDSEndpoint != "https://bsky.social" {
|
||||
t.Errorf("Expected PDSEndpoint %q, got %q", "https://bsky.social", ctx.PDSEndpoint)
|
||||
}
|
||||
if ctx.Repository != "debian" {
|
||||
t.Errorf("Expected Repository %q, got %q", "debian", ctx.Repository)
|
||||
}
|
||||
if ctx.ServiceToken != "test-token" {
|
||||
t.Errorf("Expected ServiceToken %q, got %q", "test-token", ctx.ServiceToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryContext_DatabaseInterface(t *testing.T) {
|
||||
db := &mockDatabaseMetrics{}
|
||||
ctx := &RegistryContext{
|
||||
Database: db,
|
||||
}
|
||||
|
||||
// Test that interface methods are callable
|
||||
err := ctx.Database.IncrementPullCount("did:plc:test", "repo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
err = ctx.Database.IncrementPushCount("did:plc:test", "repo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryContext_ReadmeCacheInterface(t *testing.T) {
|
||||
cache := &mockReadmeCache{}
|
||||
ctx := &RegistryContext{
|
||||
ReadmeCache: cache,
|
||||
}
|
||||
|
||||
// Test that interface methods are callable
|
||||
content, err := ctx.ReadmeCache.Get(context.Background(), "https://example.com/README.md")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if content != "# Test README" {
|
||||
t.Errorf("Expected content %q, got %q", "# Test README", content)
|
||||
}
|
||||
|
||||
err = ctx.ReadmeCache.Invalidate("https://example.com/README.md")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add more comprehensive tests:
|
||||
// - Test ATProtoClient integration
|
||||
// - Test OAuth Refresher integration
|
||||
// - Test HoldAuthorizer integration
|
||||
// - Test nil handling for optional fields
|
||||
// - Integration tests with real components
|
||||
@@ -1,93 +0,0 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
"atcr.io/pkg/auth/token"
|
||||
)
|
||||
|
||||
// EnsureCrewMembership attempts to register the user as a crew member on their default hold.
|
||||
// The hold's requestCrew endpoint handles all authorization logic (checking allowAllCrew, existing membership, etc).
|
||||
// This is best-effort and does not fail on errors.
|
||||
func EnsureCrewMembership(ctx context.Context, client *atproto.Client, refresher *oauth.Refresher, defaultHoldDID string) {
|
||||
if defaultHoldDID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Normalize URL to DID if needed
|
||||
holdDID := atproto.ResolveHoldDIDFromURL(defaultHoldDID)
|
||||
if holdDID == "" {
|
||||
slog.Warn("failed to resolve hold DID", "defaultHold", defaultHoldDID)
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve hold DID to HTTP endpoint
|
||||
holdEndpoint := atproto.ResolveHoldURL(holdDID)
|
||||
|
||||
// Get service token for the hold
|
||||
// Only works with OAuth (refresher required) - app passwords can't get service tokens
|
||||
if refresher == nil {
|
||||
slog.Debug("skipping crew registration - no OAuth refresher (app password flow)", "holdDID", holdDID)
|
||||
return
|
||||
}
|
||||
|
||||
// Wrap the refresher to match OAuthSessionRefresher interface
|
||||
serviceToken, err := token.GetOrFetchServiceToken(ctx, refresher, client.DID(), holdDID, client.PDSEndpoint())
|
||||
if err != nil {
|
||||
slog.Warn("failed to get service token", "holdDID", holdDID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Call requestCrew endpoint - it handles all the logic:
|
||||
// - Checks allowAllCrew flag
|
||||
// - Checks if already a crew member (returns success if so)
|
||||
// - Creates crew record if authorized
|
||||
if err := requestCrewMembership(ctx, holdEndpoint, serviceToken); err != nil {
|
||||
slog.Warn("failed to request crew membership", "holdDID", holdDID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("successfully registered as crew member", "holdDID", holdDID, "userDID", client.DID())
|
||||
}
|
||||
|
||||
// requestCrewMembership calls the hold's requestCrew endpoint
|
||||
// The endpoint handles all authorization and duplicate checking internally
|
||||
func requestCrewMembership(ctx context.Context, holdEndpoint, serviceToken string) error {
|
||||
// Add 5 second timeout to prevent hanging on offline holds
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
url := fmt.Sprintf("%s%s", holdEndpoint, atproto.HoldRequestCrew)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+serviceToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
// Read response body to capture actual error message from hold
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return fmt.Errorf("requestCrew failed with status %d (failed to read error body: %w)", resp.StatusCode, readErr)
|
||||
}
|
||||
return fmt.Errorf("requestCrew failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEnsureCrewMembership_EmptyHoldDID(t *testing.T) {
|
||||
// Test that empty hold DID returns early without error (best-effort function)
|
||||
EnsureCrewMembership(context.Background(), nil, nil, "")
|
||||
// If we get here without panic, test passes
|
||||
}
|
||||
|
||||
// TODO: Add comprehensive tests with HTTP client mocking
|
||||
@@ -3,6 +3,7 @@ package storage
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -10,9 +11,12 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/appview/db"
|
||||
"atcr.io/pkg/appview/readme"
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
"github.com/distribution/distribution/v3"
|
||||
"github.com/opencontainers/go-digest"
|
||||
)
|
||||
@@ -20,24 +24,24 @@ import (
|
||||
// ManifestStore implements distribution.ManifestService
|
||||
// It stores manifests in ATProto as records
|
||||
type ManifestStore struct {
|
||||
ctx *RegistryContext // Context with user/hold info
|
||||
mu sync.RWMutex // Protects lastFetchedHoldDID
|
||||
lastFetchedHoldDID string // Hold DID from most recently fetched manifest (for pull)
|
||||
ctx *auth.UserContext // User context with identity, target, permissions
|
||||
blobStore distribution.BlobStore // Blob store for fetching config during push
|
||||
sqlDB *sql.DB // Database for pull/push counts
|
||||
}
|
||||
|
||||
// NewManifestStore creates a new ATProto-backed manifest store
|
||||
func NewManifestStore(ctx *RegistryContext, blobStore distribution.BlobStore) *ManifestStore {
|
||||
func NewManifestStore(userCtx *auth.UserContext, blobStore distribution.BlobStore, sqlDB *sql.DB) *ManifestStore {
|
||||
return &ManifestStore{
|
||||
ctx: ctx,
|
||||
ctx: userCtx,
|
||||
blobStore: blobStore,
|
||||
sqlDB: sqlDB,
|
||||
}
|
||||
}
|
||||
|
||||
// Exists checks if a manifest exists by digest
|
||||
func (s *ManifestStore) Exists(ctx context.Context, dgst digest.Digest) (bool, error) {
|
||||
rkey := digestToRKey(dgst)
|
||||
_, err := s.ctx.ATProtoClient.GetRecord(ctx, atproto.ManifestCollection, rkey)
|
||||
_, err := s.ctx.GetATProtoClient().GetRecord(ctx, atproto.ManifestCollection, rkey)
|
||||
if err != nil {
|
||||
// If not found, return false without error
|
||||
if errors.Is(err, atproto.ErrRecordNotFound) {
|
||||
@@ -51,37 +55,24 @@ func (s *ManifestStore) Exists(ctx context.Context, dgst digest.Digest) (bool, e
|
||||
// Get retrieves a manifest by digest
|
||||
func (s *ManifestStore) Get(ctx context.Context, dgst digest.Digest, options ...distribution.ManifestServiceOption) (distribution.Manifest, error) {
|
||||
rkey := digestToRKey(dgst)
|
||||
record, err := s.ctx.ATProtoClient.GetRecord(ctx, atproto.ManifestCollection, rkey)
|
||||
record, err := s.ctx.GetATProtoClient().GetRecord(ctx, atproto.ManifestCollection, rkey)
|
||||
if err != nil {
|
||||
return nil, distribution.ErrManifestUnknownRevision{
|
||||
Name: s.ctx.Repository,
|
||||
Name: s.ctx.TargetRepo,
|
||||
Revision: dgst,
|
||||
}
|
||||
}
|
||||
|
||||
var manifestRecord atproto.Manifest
|
||||
var manifestRecord atproto.ManifestRecord
|
||||
if err := json.Unmarshal(record.Value, &manifestRecord); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal manifest record: %w", err)
|
||||
}
|
||||
|
||||
// Store the hold DID for subsequent blob requests during pull
|
||||
// Prefer HoldDid (new format) with fallback to HoldEndpoint (legacy URL format)
|
||||
// The routing repository will cache this for concurrent blob fetches
|
||||
s.mu.Lock()
|
||||
if manifestRecord.HoldDid != nil && *manifestRecord.HoldDid != "" {
|
||||
// New format: DID reference (preferred)
|
||||
s.lastFetchedHoldDID = *manifestRecord.HoldDid
|
||||
} else if manifestRecord.HoldEndpoint != nil && *manifestRecord.HoldEndpoint != "" {
|
||||
// Legacy format: URL reference - convert to DID
|
||||
s.lastFetchedHoldDID = atproto.ResolveHoldDIDFromURL(*manifestRecord.HoldEndpoint)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
var ociManifest []byte
|
||||
|
||||
// New records: Download blob from ATProto blob storage
|
||||
if manifestRecord.ManifestBlob != nil && manifestRecord.ManifestBlob.Ref.Defined() {
|
||||
ociManifest, err = s.ctx.ATProtoClient.GetBlob(ctx, manifestRecord.ManifestBlob.Ref.String())
|
||||
if manifestRecord.ManifestBlob != nil && manifestRecord.ManifestBlob.Ref.Link != "" {
|
||||
ociManifest, err = s.ctx.GetATProtoClient().GetBlob(ctx, manifestRecord.ManifestBlob.Ref.Link)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download manifest blob: %w", err)
|
||||
}
|
||||
@@ -89,12 +80,12 @@ func (s *ManifestStore) Get(ctx context.Context, dgst digest.Digest, options ...
|
||||
|
||||
// Track pull count (increment asynchronously to avoid blocking the response)
|
||||
// Only count GET requests (actual downloads), not HEAD requests (existence checks)
|
||||
if s.ctx.Database != nil {
|
||||
if s.sqlDB != nil {
|
||||
// Check HTTP method from context (distribution library stores it as "http.request.method")
|
||||
if method, ok := ctx.Value("http.request.method").(string); ok && method == "GET" {
|
||||
go func() {
|
||||
if err := s.ctx.Database.IncrementPullCount(s.ctx.DID, s.ctx.Repository); err != nil {
|
||||
slog.Warn("Failed to increment pull count", "did", s.ctx.DID, "repository", s.ctx.Repository, "error", err)
|
||||
if err := db.IncrementPullCount(s.sqlDB, s.ctx.TargetOwnerDID, s.ctx.TargetRepo); err != nil {
|
||||
slog.Warn("Failed to increment pull count", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -121,22 +112,20 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
|
||||
dgst := digest.FromBytes(payload)
|
||||
|
||||
// Upload manifest as blob to PDS
|
||||
blobRef, err := s.ctx.ATProtoClient.UploadBlob(ctx, payload, mediaType)
|
||||
blobRef, err := s.ctx.GetATProtoClient().UploadBlob(ctx, payload, mediaType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload manifest blob: %w", err)
|
||||
}
|
||||
|
||||
// Create manifest record with structured metadata
|
||||
manifestRecord, err := atproto.NewManifestRecord(s.ctx.Repository, dgst.String(), payload)
|
||||
manifestRecord, err := atproto.NewManifestRecord(s.ctx.TargetRepo, dgst.String(), payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create manifest record: %w", err)
|
||||
}
|
||||
|
||||
// Set the blob reference, hold DID, and hold endpoint
|
||||
manifestRecord.ManifestBlob = blobRef
|
||||
if s.ctx.HoldDID != "" {
|
||||
manifestRecord.HoldDid = &s.ctx.HoldDID // Primary reference (DID)
|
||||
}
|
||||
manifestRecord.HoldDID = s.ctx.TargetHoldDID // Primary reference (DID)
|
||||
|
||||
// Extract Dockerfile labels from config blob and add to annotations
|
||||
// Only for image manifests (not manifest lists which don't have config blobs)
|
||||
@@ -163,10 +152,10 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
|
||||
if !exists {
|
||||
platform := "unknown"
|
||||
if ref.Platform != nil {
|
||||
platform = fmt.Sprintf("%s/%s", ref.Platform.Os, ref.Platform.Architecture)
|
||||
platform = fmt.Sprintf("%s/%s", ref.Platform.OS, ref.Platform.Architecture)
|
||||
}
|
||||
slog.Warn("Manifest list references non-existent child manifest",
|
||||
"repository", s.ctx.Repository,
|
||||
"repository", s.ctx.TargetRepo,
|
||||
"missingDigest", ref.Digest,
|
||||
"platform", platform)
|
||||
return "", distribution.ErrManifestBlobUnknown{Digest: refDigest}
|
||||
@@ -174,24 +163,43 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Label extraction from config blob is currently disabled because the generated
|
||||
// Manifest_Annotations type doesn't support arbitrary keys. The lexicon schema would
|
||||
// need to use "unknown" type for annotations to support dynamic key-value pairs.
|
||||
// TODO: Update lexicon schema if label extraction is needed.
|
||||
_ = isManifestList // silence unused variable warning for now
|
||||
if !isManifestList && s.blobStore != nil && manifestRecord.Config != nil && manifestRecord.Config.Digest != "" {
|
||||
labels, err := s.extractConfigLabels(ctx, manifestRecord.Config.Digest)
|
||||
if err != nil {
|
||||
// Log error but don't fail the push - labels are optional
|
||||
slog.Warn("Failed to extract config labels", "error", err)
|
||||
} else if len(labels) > 0 {
|
||||
// Initialize annotations map if needed
|
||||
if manifestRecord.Annotations == nil {
|
||||
manifestRecord.Annotations = make(map[string]string)
|
||||
}
|
||||
|
||||
// Copy labels to annotations as fallback
|
||||
// Only set label values for keys NOT already in manifest annotations
|
||||
// This ensures explicit annotations take precedence over Dockerfile LABELs
|
||||
// (which may be inherited from base images)
|
||||
for key, value := range labels {
|
||||
if _, exists := manifestRecord.Annotations[key]; !exists {
|
||||
manifestRecord.Annotations[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("Merged labels from config blob", "labelsCount", len(labels), "annotationsCount", len(manifestRecord.Annotations))
|
||||
}
|
||||
}
|
||||
|
||||
// Store manifest record in ATProto
|
||||
rkey := digestToRKey(dgst)
|
||||
_, err = s.ctx.ATProtoClient.PutRecord(ctx, atproto.ManifestCollection, rkey, manifestRecord)
|
||||
_, err = s.ctx.GetATProtoClient().PutRecord(ctx, atproto.ManifestCollection, rkey, manifestRecord)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to store manifest record in ATProto: %w", err)
|
||||
}
|
||||
|
||||
// Track push count (increment asynchronously to avoid blocking the response)
|
||||
if s.ctx.Database != nil {
|
||||
if s.sqlDB != nil {
|
||||
go func() {
|
||||
if err := s.ctx.Database.IncrementPushCount(s.ctx.DID, s.ctx.Repository); err != nil {
|
||||
slog.Warn("Failed to increment push count", "did", s.ctx.DID, "repository", s.ctx.Repository, "error", err)
|
||||
if err := db.IncrementPushCount(s.sqlDB, s.ctx.TargetOwnerDID, s.ctx.TargetRepo); err != nil {
|
||||
slog.Warn("Failed to increment push count", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -201,9 +209,9 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
|
||||
for _, option := range options {
|
||||
if tagOpt, ok := option.(distribution.WithTagOption); ok {
|
||||
tag = tagOpt.Tag
|
||||
tagRecord := atproto.NewTagRecord(s.ctx.ATProtoClient.DID(), s.ctx.Repository, tag, dgst.String())
|
||||
tagRKey := atproto.RepositoryTagToRKey(s.ctx.Repository, tag)
|
||||
_, err = s.ctx.ATProtoClient.PutRecord(ctx, atproto.TagCollection, tagRKey, tagRecord)
|
||||
tagRecord := atproto.NewTagRecord(s.ctx.GetATProtoClient().DID(), s.ctx.TargetRepo, tag, dgst.String())
|
||||
tagRKey := atproto.RepositoryTagToRKey(s.ctx.TargetRepo, tag)
|
||||
_, err = s.ctx.GetATProtoClient().PutRecord(ctx, atproto.TagCollection, tagRKey, tagRecord)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to store tag in ATProto: %w", err)
|
||||
}
|
||||
@@ -212,28 +220,30 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
|
||||
|
||||
// Notify hold about manifest upload (for layer tracking and Bluesky posts)
|
||||
// Do this asynchronously to avoid blocking the push
|
||||
if tag != "" && s.ctx.ServiceToken != "" && s.ctx.Handle != "" {
|
||||
go func() {
|
||||
// Get service token before goroutine (requires context)
|
||||
serviceToken, _ := s.ctx.GetServiceToken(ctx)
|
||||
if tag != "" && serviceToken != "" && s.ctx.TargetOwnerHandle != "" {
|
||||
go func(serviceToken string) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("Panic in notifyHoldAboutManifest", "panic", r)
|
||||
}
|
||||
}()
|
||||
if err := s.notifyHoldAboutManifest(context.Background(), manifestRecord, tag, dgst.String()); err != nil {
|
||||
if err := s.notifyHoldAboutManifest(context.Background(), manifestRecord, tag, dgst.String(), serviceToken); err != nil {
|
||||
slog.Warn("Failed to notify hold about manifest", "error", err)
|
||||
}
|
||||
}()
|
||||
}(serviceToken)
|
||||
}
|
||||
|
||||
// Refresh README cache asynchronously if manifest has io.atcr.readme annotation
|
||||
// This ensures fresh README content is available on repository pages
|
||||
// Create or update repo page asynchronously if manifest has relevant annotations
|
||||
// This ensures repository metadata is synced to user's PDS
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("Panic in refreshReadmeCache", "panic", r)
|
||||
slog.Error("Panic in ensureRepoPage", "panic", r)
|
||||
}
|
||||
}()
|
||||
s.refreshReadmeCache(context.Background(), manifestRecord)
|
||||
s.ensureRepoPage(context.Background(), manifestRecord)
|
||||
}()
|
||||
|
||||
return dgst, nil
|
||||
@@ -242,7 +252,7 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
|
||||
// Delete removes a manifest
|
||||
func (s *ManifestStore) Delete(ctx context.Context, dgst digest.Digest) error {
|
||||
rkey := digestToRKey(dgst)
|
||||
return s.ctx.ATProtoClient.DeleteRecord(ctx, atproto.ManifestCollection, rkey)
|
||||
return s.ctx.GetATProtoClient().DeleteRecord(ctx, atproto.ManifestCollection, rkey)
|
||||
}
|
||||
|
||||
// digestToRKey converts a digest to an ATProto record key
|
||||
@@ -252,14 +262,6 @@ func digestToRKey(dgst digest.Digest) string {
|
||||
return dgst.Encoded()
|
||||
}
|
||||
|
||||
// GetLastFetchedHoldDID returns the hold DID from the most recently fetched manifest
|
||||
// This is used by the routing repository to cache the hold for blob requests
|
||||
func (s *ManifestStore) GetLastFetchedHoldDID() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.lastFetchedHoldDID
|
||||
}
|
||||
|
||||
// rawManifest is a simple implementation of distribution.Manifest
|
||||
type rawManifest struct {
|
||||
mediaType string
|
||||
@@ -305,18 +307,17 @@ func (s *ManifestStore) extractConfigLabels(ctx context.Context, configDigestStr
|
||||
|
||||
// notifyHoldAboutManifest notifies the hold service about a manifest upload
|
||||
// This enables the hold to create layer records and Bluesky posts
|
||||
func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRecord *atproto.Manifest, tag, manifestDigest string) error {
|
||||
// Skip if no service token configured (e.g., anonymous pulls)
|
||||
if s.ctx.ServiceToken == "" {
|
||||
func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRecord *atproto.ManifestRecord, tag, manifestDigest, serviceToken string) error {
|
||||
// Skip if no service token provided
|
||||
if serviceToken == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve hold DID to HTTP endpoint
|
||||
// For did:web, this is straightforward (e.g., did:web:hold01.atcr.io → https://hold01.atcr.io)
|
||||
holdEndpoint := atproto.ResolveHoldURL(s.ctx.HoldDID)
|
||||
holdEndpoint := atproto.ResolveHoldURL(s.ctx.TargetHoldDID)
|
||||
|
||||
// Use service token from middleware (already cached and validated)
|
||||
serviceToken := s.ctx.ServiceToken
|
||||
// Service token is passed in (already cached and validated)
|
||||
|
||||
// Build notification request
|
||||
manifestData := map[string]any{
|
||||
@@ -355,7 +356,7 @@ func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRec
|
||||
}
|
||||
if m.Platform != nil {
|
||||
mData["platform"] = map[string]any{
|
||||
"os": m.Platform.Os,
|
||||
"os": m.Platform.OS,
|
||||
"architecture": m.Platform.Architecture,
|
||||
}
|
||||
}
|
||||
@@ -365,10 +366,10 @@ func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRec
|
||||
}
|
||||
|
||||
notifyReq := map[string]any{
|
||||
"repository": s.ctx.Repository,
|
||||
"repository": s.ctx.TargetRepo,
|
||||
"tag": tag,
|
||||
"userDid": s.ctx.DID,
|
||||
"userHandle": s.ctx.Handle,
|
||||
"userDid": s.ctx.TargetOwnerDID,
|
||||
"userHandle": s.ctx.TargetOwnerHandle,
|
||||
"manifest": manifestData,
|
||||
}
|
||||
|
||||
@@ -406,24 +407,251 @@ func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRec
|
||||
// Parse response (optional logging)
|
||||
var notifyResp map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(¬ifyResp); err == nil {
|
||||
slog.Info("Hold notification successful", "repository", s.ctx.Repository, "tag", tag, "response", notifyResp)
|
||||
slog.Info("Hold notification successful", "repository", s.ctx.TargetRepo, "tag", tag, "response", notifyResp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshReadmeCache refreshes the README cache for this manifest if it has io.atcr.readme annotation
|
||||
// This should be called asynchronously after manifest push to keep README content fresh
|
||||
// NOTE: Currently disabled because the generated Manifest_Annotations type doesn't support
|
||||
// arbitrary key-value pairs. Would need to update lexicon schema with "unknown" type.
|
||||
func (s *ManifestStore) refreshReadmeCache(ctx context.Context, manifestRecord *atproto.Manifest) {
|
||||
// Skip if no README cache configured
|
||||
if s.ctx.ReadmeCache == nil {
|
||||
// ensureRepoPage creates or updates a repo page record in the user's PDS if needed
|
||||
// This syncs repository metadata from manifest annotations to the io.atcr.repo.page collection
|
||||
// Only creates a new record if one doesn't exist (doesn't overwrite user's custom content)
|
||||
func (s *ManifestStore) ensureRepoPage(ctx context.Context, manifestRecord *atproto.ManifestRecord) {
|
||||
// Check if repo page already exists (don't overwrite user's custom content)
|
||||
rkey := s.ctx.TargetRepo
|
||||
_, err := s.ctx.GetATProtoClient().GetRecord(ctx, atproto.RepoPageCollection, rkey)
|
||||
if err == nil {
|
||||
// Record already exists - don't overwrite
|
||||
slog.Debug("Repo page already exists, skipping creation", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Re-enable once lexicon supports annotations as map[string]string
|
||||
// The generated Manifest_Annotations is an empty struct that doesn't support map access.
|
||||
// For now, README cache refresh on push is disabled.
|
||||
_ = manifestRecord // silence unused variable warning
|
||||
// Only continue if it's a "not found" error - other errors mean we should skip
|
||||
if !errors.Is(err, atproto.ErrRecordNotFound) {
|
||||
slog.Warn("Failed to check for existing repo page", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Get annotations (may be nil if image has no OCI labels)
|
||||
annotations := manifestRecord.Annotations
|
||||
if annotations == nil {
|
||||
annotations = make(map[string]string)
|
||||
}
|
||||
|
||||
// Try to fetch README content from external sources
|
||||
// Priority: io.atcr.readme annotation > derived from org.opencontainers.image.source > org.opencontainers.image.description
|
||||
description := s.fetchReadmeContent(ctx, annotations)
|
||||
|
||||
// If no README content could be fetched, fall back to description annotation
|
||||
if description == "" {
|
||||
description = annotations["org.opencontainers.image.description"]
|
||||
}
|
||||
|
||||
// Try to fetch and upload icon from io.atcr.icon annotation
|
||||
var avatarRef *atproto.ATProtoBlobRef
|
||||
if iconURL := annotations["io.atcr.icon"]; iconURL != "" {
|
||||
avatarRef = s.fetchAndUploadIcon(ctx, iconURL)
|
||||
}
|
||||
|
||||
// Create new repo page record with description and optional avatar
|
||||
repoPage := atproto.NewRepoPageRecord(s.ctx.TargetRepo, description, avatarRef)
|
||||
|
||||
slog.Info("Creating repo page from manifest annotations", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "descriptionLength", len(description), "hasAvatar", avatarRef != nil)
|
||||
|
||||
_, err = s.ctx.GetATProtoClient().PutRecord(ctx, atproto.RepoPageCollection, rkey, repoPage)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to create repo page", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("Repo page created successfully", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo)
|
||||
}
|
||||
|
||||
// fetchReadmeContent attempts to fetch README content from external sources
|
||||
// Priority: io.atcr.readme annotation > derived from org.opencontainers.image.source
|
||||
// Returns the raw markdown content, or empty string if not available
|
||||
func (s *ManifestStore) fetchReadmeContent(ctx context.Context, annotations map[string]string) string {
|
||||
|
||||
// Create a context with timeout for README fetching (don't block push too long)
|
||||
fetchCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Priority 1: Direct README URL from io.atcr.readme annotation
|
||||
if readmeURL := annotations["io.atcr.readme"]; readmeURL != "" {
|
||||
content, err := s.fetchRawReadme(fetchCtx, readmeURL)
|
||||
if err != nil {
|
||||
slog.Debug("Failed to fetch README from io.atcr.readme annotation", "url", readmeURL, "error", err)
|
||||
} else if content != "" {
|
||||
slog.Info("Fetched README from io.atcr.readme annotation", "url", readmeURL, "length", len(content))
|
||||
return content
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 2: Derive README URL from org.opencontainers.image.source
|
||||
if sourceURL := annotations["org.opencontainers.image.source"]; sourceURL != "" {
|
||||
// Try main branch first, then master
|
||||
for _, branch := range []string{"main", "master"} {
|
||||
readmeURL := readme.DeriveReadmeURL(sourceURL, branch)
|
||||
if readmeURL == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
content, err := s.fetchRawReadme(fetchCtx, readmeURL)
|
||||
if err != nil {
|
||||
// Only log non-404 errors (404 is expected when trying main vs master)
|
||||
if !readme.Is404(err) {
|
||||
slog.Debug("Failed to fetch README from source URL", "url", readmeURL, "branch", branch, "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
slog.Info("Fetched README from source URL", "sourceURL", sourceURL, "branch", branch, "length", len(content))
|
||||
return content
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// fetchRawReadme fetches raw markdown content from a URL
|
||||
// Returns the raw markdown (not rendered HTML) for storage in the repo page record
|
||||
func (s *ManifestStore) fetchRawReadme(ctx context.Context, readmeURL string) (string, error) {
|
||||
// Use a simple HTTP client to fetch raw content
|
||||
// We want raw markdown, not rendered HTML (the Fetcher renders to HTML)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", readmeURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "ATCR-README-Fetcher/1.0")
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 5 {
|
||||
return fmt.Errorf("too many redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch URL: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Limit content size to 100KB (repo page description has 100KB limit in lexicon)
|
||||
limitedReader := io.LimitReader(resp.Body, 100*1024)
|
||||
content, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
return string(content), nil
|
||||
}
|
||||
|
||||
// fetchAndUploadIcon fetches an image from a URL and uploads it as a blob to the user's PDS
|
||||
// Returns the blob reference for use in the repo page record, or nil on error
|
||||
func (s *ManifestStore) fetchAndUploadIcon(ctx context.Context, iconURL string) *atproto.ATProtoBlobRef {
|
||||
// Create a context with timeout for icon fetching
|
||||
fetchCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Fetch the icon
|
||||
req, err := http.NewRequestWithContext(fetchCtx, "GET", iconURL, nil)
|
||||
if err != nil {
|
||||
slog.Debug("Failed to create icon request", "url", iconURL, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "ATCR-Icon-Fetcher/1.0")
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 5 {
|
||||
return fmt.Errorf("too many redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
slog.Debug("Failed to fetch icon", "url", iconURL, "error", err)
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
slog.Debug("Icon fetch returned non-OK status", "url", iconURL, "status", resp.StatusCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate content type - only allow images
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
mimeType := detectImageMimeType(contentType, iconURL)
|
||||
if mimeType == "" {
|
||||
slog.Debug("Icon has unsupported content type", "url", iconURL, "contentType", contentType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Limit icon size to 3MB (matching lexicon maxSize)
|
||||
limitedReader := io.LimitReader(resp.Body, 3*1024*1024)
|
||||
iconData, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
slog.Debug("Failed to read icon data", "url", iconURL, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(iconData) == 0 {
|
||||
slog.Debug("Icon data is empty", "url", iconURL)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Upload the icon as a blob to the user's PDS
|
||||
blobRef, err := s.ctx.GetATProtoClient().UploadBlob(ctx, iconData, mimeType)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to upload icon blob", "url", iconURL, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
slog.Info("Uploaded icon blob", "url", iconURL, "size", len(iconData), "mimeType", mimeType, "cid", blobRef.Ref.Link)
|
||||
return blobRef
|
||||
}
|
||||
|
||||
// detectImageMimeType determines the MIME type for an image
|
||||
// Uses Content-Type header first, then falls back to extension-based detection
|
||||
// Only allows types accepted by the lexicon: image/png, image/jpeg, image/webp
|
||||
func detectImageMimeType(contentType, url string) string {
|
||||
// Check Content-Type header first
|
||||
switch {
|
||||
case strings.HasPrefix(contentType, "image/png"):
|
||||
return "image/png"
|
||||
case strings.HasPrefix(contentType, "image/jpeg"):
|
||||
return "image/jpeg"
|
||||
case strings.HasPrefix(contentType, "image/webp"):
|
||||
return "image/webp"
|
||||
}
|
||||
|
||||
// Fall back to URL extension detection
|
||||
lowerURL := strings.ToLower(url)
|
||||
switch {
|
||||
case strings.HasSuffix(lowerURL, ".png"):
|
||||
return "image/png"
|
||||
case strings.HasSuffix(lowerURL, ".jpg"), strings.HasSuffix(lowerURL, ".jpeg"):
|
||||
return "image/jpeg"
|
||||
case strings.HasSuffix(lowerURL, ".webp"):
|
||||
return "image/webp"
|
||||
}
|
||||
|
||||
// Unknown or unsupported type - reject
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -8,15 +8,13 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
"github.com/distribution/distribution/v3"
|
||||
"github.com/opencontainers/go-digest"
|
||||
)
|
||||
|
||||
// mockDatabaseMetrics removed - using the one from context_test.go
|
||||
|
||||
// mockBlobStore is a minimal mock of distribution.BlobStore for testing
|
||||
type mockBlobStore struct {
|
||||
blobs map[digest.Digest][]byte
|
||||
@@ -72,16 +70,11 @@ func (m *mockBlobStore) Open(ctx context.Context, dgst digest.Digest) (io.ReadSe
|
||||
return nil, nil // Not needed for current tests
|
||||
}
|
||||
|
||||
// mockRegistryContext creates a mock RegistryContext for testing
|
||||
func mockRegistryContext(client *atproto.Client, repository, holdDID, did, handle string, database DatabaseMetrics) *RegistryContext {
|
||||
return &RegistryContext{
|
||||
ATProtoClient: client,
|
||||
Repository: repository,
|
||||
HoldDID: holdDID,
|
||||
DID: did,
|
||||
Handle: handle,
|
||||
Database: database,
|
||||
}
|
||||
// mockUserContextForManifest creates a mock auth.UserContext for manifest store testing
|
||||
func mockUserContextForManifest(pdsEndpoint, repository, holdDID, ownerDID, ownerHandle string) *auth.UserContext {
|
||||
userCtx := auth.NewUserContext(ownerDID, "oauth", "PUT", nil)
|
||||
userCtx.SetTarget(ownerDID, ownerHandle, pdsEndpoint, repository, holdDID)
|
||||
return userCtx
|
||||
}
|
||||
|
||||
// TestDigestToRKey tests digest to record key conversion
|
||||
@@ -115,82 +108,27 @@ func TestDigestToRKey(t *testing.T) {
|
||||
|
||||
// TestNewManifestStore tests creating a new manifest store
|
||||
func TestNewManifestStore(t *testing.T) {
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
blobStore := newMockBlobStore()
|
||||
db := &mockDatabaseMetrics{}
|
||||
userCtx := mockUserContextForManifest(
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:alice123",
|
||||
"alice.test",
|
||||
)
|
||||
store := NewManifestStore(userCtx, blobStore, nil)
|
||||
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:alice123", "alice.test", db)
|
||||
store := NewManifestStore(ctx, blobStore)
|
||||
|
||||
if store.ctx.Repository != "myapp" {
|
||||
t.Errorf("repository = %v, want myapp", store.ctx.Repository)
|
||||
if store.ctx.TargetRepo != "myapp" {
|
||||
t.Errorf("repository = %v, want myapp", store.ctx.TargetRepo)
|
||||
}
|
||||
if store.ctx.HoldDID != "did:web:hold.example.com" {
|
||||
t.Errorf("holdDID = %v, want did:web:hold.example.com", store.ctx.HoldDID)
|
||||
if store.ctx.TargetHoldDID != "did:web:hold.example.com" {
|
||||
t.Errorf("holdDID = %v, want did:web:hold.example.com", store.ctx.TargetHoldDID)
|
||||
}
|
||||
if store.ctx.DID != "did:plc:alice123" {
|
||||
t.Errorf("did = %v, want did:plc:alice123", store.ctx.DID)
|
||||
if store.ctx.TargetOwnerDID != "did:plc:alice123" {
|
||||
t.Errorf("did = %v, want did:plc:alice123", store.ctx.TargetOwnerDID)
|
||||
}
|
||||
if store.ctx.Handle != "alice.test" {
|
||||
t.Errorf("handle = %v, want alice.test", store.ctx.Handle)
|
||||
}
|
||||
}
|
||||
|
||||
// TestManifestStore_GetLastFetchedHoldDID tests tracking last fetched hold DID
|
||||
func TestManifestStore_GetLastFetchedHoldDID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
manifestHoldDID string
|
||||
manifestHoldURL string
|
||||
expectedLastFetched string
|
||||
}{
|
||||
{
|
||||
name: "prefers HoldDID",
|
||||
manifestHoldDID: "did:web:hold01.atcr.io",
|
||||
manifestHoldURL: "https://hold01.atcr.io",
|
||||
expectedLastFetched: "did:web:hold01.atcr.io",
|
||||
},
|
||||
{
|
||||
name: "falls back to HoldEndpoint URL conversion",
|
||||
manifestHoldDID: "",
|
||||
manifestHoldURL: "https://hold02.atcr.io",
|
||||
expectedLastFetched: "did:web:hold02.atcr.io",
|
||||
},
|
||||
{
|
||||
name: "empty hold references",
|
||||
manifestHoldDID: "",
|
||||
manifestHoldURL: "",
|
||||
expectedLastFetched: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
|
||||
// Simulate what happens in Get() when parsing a manifest record
|
||||
var manifestRecord atproto.Manifest
|
||||
if tt.manifestHoldDID != "" {
|
||||
manifestRecord.HoldDid = &tt.manifestHoldDID
|
||||
}
|
||||
if tt.manifestHoldURL != "" {
|
||||
manifestRecord.HoldEndpoint = &tt.manifestHoldURL
|
||||
}
|
||||
|
||||
// Mimic the hold DID extraction logic from Get()
|
||||
if manifestRecord.HoldDid != nil && *manifestRecord.HoldDid != "" {
|
||||
store.lastFetchedHoldDID = *manifestRecord.HoldDid
|
||||
} else if manifestRecord.HoldEndpoint != nil && *manifestRecord.HoldEndpoint != "" {
|
||||
store.lastFetchedHoldDID = atproto.ResolveHoldDIDFromURL(*manifestRecord.HoldEndpoint)
|
||||
}
|
||||
|
||||
got := store.GetLastFetchedHoldDID()
|
||||
if got != tt.expectedLastFetched {
|
||||
t.Errorf("GetLastFetchedHoldDID() = %v, want %v", got, tt.expectedLastFetched)
|
||||
}
|
||||
})
|
||||
if store.ctx.TargetOwnerHandle != "alice.test" {
|
||||
t.Errorf("handle = %v, want alice.test", store.ctx.TargetOwnerHandle)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -245,9 +183,14 @@ func TestExtractConfigLabels(t *testing.T) {
|
||||
blobStore.blobs[configDigest] = configData
|
||||
|
||||
// Create manifest store
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, blobStore)
|
||||
userCtx := mockUserContextForManifest(
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, blobStore, nil)
|
||||
|
||||
// Extract labels
|
||||
labels, err := store.extractConfigLabels(context.Background(), configDigest.String())
|
||||
@@ -285,9 +228,14 @@ func TestExtractConfigLabels_NoLabels(t *testing.T) {
|
||||
configDigest := digest.FromBytes(configData)
|
||||
blobStore.blobs[configDigest] = configData
|
||||
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, blobStore)
|
||||
userCtx := mockUserContextForManifest(
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, blobStore, nil)
|
||||
|
||||
labels, err := store.extractConfigLabels(context.Background(), configDigest.String())
|
||||
if err != nil {
|
||||
@@ -303,9 +251,14 @@ func TestExtractConfigLabels_NoLabels(t *testing.T) {
|
||||
// TestExtractConfigLabels_InvalidDigest tests error handling for invalid digest
|
||||
func TestExtractConfigLabels_InvalidDigest(t *testing.T) {
|
||||
blobStore := newMockBlobStore()
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, blobStore)
|
||||
userCtx := mockUserContextForManifest(
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, blobStore, nil)
|
||||
|
||||
_, err := store.extractConfigLabels(context.Background(), "invalid-digest")
|
||||
if err == nil {
|
||||
@@ -322,9 +275,14 @@ func TestExtractConfigLabels_InvalidJSON(t *testing.T) {
|
||||
configDigest := digest.FromBytes(configData)
|
||||
blobStore.blobs[configDigest] = configData
|
||||
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, blobStore)
|
||||
userCtx := mockUserContextForManifest(
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, blobStore, nil)
|
||||
|
||||
_, err := store.extractConfigLabels(context.Background(), configDigest.String())
|
||||
if err == nil {
|
||||
@@ -332,28 +290,18 @@ func TestExtractConfigLabels_InvalidJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestManifestStore_WithMetrics tests that metrics are tracked
|
||||
func TestManifestStore_WithMetrics(t *testing.T) {
|
||||
db := &mockDatabaseMetrics{}
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:alice123", "alice.test", db)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
// TestManifestStore_WithoutDatabase tests that nil database is acceptable
|
||||
func TestManifestStore_WithoutDatabase(t *testing.T) {
|
||||
userCtx := mockUserContextForManifest(
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:alice123",
|
||||
"alice.test",
|
||||
)
|
||||
store := NewManifestStore(userCtx, nil, nil)
|
||||
|
||||
if store.ctx.Database != db {
|
||||
t.Error("ManifestStore should store database reference")
|
||||
}
|
||||
|
||||
// Note: Actual metrics tracking happens in Put() and Get() which require
|
||||
// full mock setup. The important thing is that the database is wired up.
|
||||
}
|
||||
|
||||
// TestManifestStore_WithoutMetrics tests that nil database is acceptable
|
||||
func TestManifestStore_WithoutMetrics(t *testing.T) {
|
||||
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:alice123", "alice.test", nil)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
|
||||
if store.ctx.Database != nil {
|
||||
if store.sqlDB != nil {
|
||||
t.Error("ManifestStore should accept nil database")
|
||||
}
|
||||
}
|
||||
@@ -372,7 +320,7 @@ func TestManifestStore_Exists(t *testing.T) {
|
||||
name: "manifest exists",
|
||||
digest: "sha256:abc123",
|
||||
serverStatus: http.StatusOK,
|
||||
serverResp: `{"uri":"at://did:plc:test123/io.atcr.manifest/abc123","cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku","value":{}}`,
|
||||
serverResp: `{"uri":"at://did:plc:test123/io.atcr.manifest/abc123","cid":"bafytest","value":{}}`,
|
||||
wantExists: true,
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -403,9 +351,14 @@ func TestManifestStore_Exists(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
userCtx := mockUserContextForManifest(
|
||||
server.URL,
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, nil, nil)
|
||||
|
||||
exists, err := store.Exists(context.Background(), tt.digest)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -437,7 +390,7 @@ func TestManifestStore_Get(t *testing.T) {
|
||||
digest: "sha256:abc123",
|
||||
serverResp: `{
|
||||
"uri":"at://did:plc:test123/io.atcr.manifest/abc123",
|
||||
"cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku",
|
||||
"cid":"bafytest",
|
||||
"value":{
|
||||
"$type":"io.atcr.manifest",
|
||||
"repository":"myapp",
|
||||
@@ -447,7 +400,7 @@ func TestManifestStore_Get(t *testing.T) {
|
||||
"mediaType":"application/vnd.oci.image.manifest.v1+json",
|
||||
"manifestBlob":{
|
||||
"$type":"blob",
|
||||
"ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},
|
||||
"ref":{"$link":"bafytest"},
|
||||
"mimeType":"application/vnd.oci.image.manifest.v1+json",
|
||||
"size":100
|
||||
}
|
||||
@@ -481,9 +434,7 @@ func TestManifestStore_Get(t *testing.T) {
|
||||
"holdEndpoint":"https://hold02.atcr.io",
|
||||
"mediaType":"application/vnd.oci.image.manifest.v1+json",
|
||||
"manifestBlob":{
|
||||
"$type":"blob",
|
||||
"ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},
|
||||
"mimeType":"application/json",
|
||||
"ref":{"$link":"bafylegacy"},
|
||||
"size":100
|
||||
}
|
||||
}
|
||||
@@ -523,10 +474,14 @@ func TestManifestStore_Get(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
db := &mockDatabaseMetrics{}
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", db)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
userCtx := mockUserContextForManifest(
|
||||
server.URL,
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, nil, nil)
|
||||
|
||||
manifest, err := store.Get(context.Background(), tt.digest)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -547,148 +502,6 @@ func TestManifestStore_Get(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestManifestStore_Get_HoldDIDTracking tests that Get() stores the holdDID
|
||||
func TestManifestStore_Get_HoldDIDTracking(t *testing.T) {
|
||||
ociManifest := []byte(`{"schemaVersion":2}`)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
manifestResp string
|
||||
expectedHoldDID string
|
||||
}{
|
||||
{
|
||||
name: "tracks HoldDID from new format",
|
||||
manifestResp: `{
|
||||
"uri":"at://did:plc:test123/io.atcr.manifest/abc123",
|
||||
"value":{
|
||||
"$type":"io.atcr.manifest",
|
||||
"holdDid":"did:web:hold01.atcr.io",
|
||||
"holdEndpoint":"https://hold01.atcr.io",
|
||||
"mediaType":"application/vnd.oci.image.manifest.v1+json",
|
||||
"manifestBlob":{"$type":"blob","ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},"mimeType":"application/json","size":100}
|
||||
}
|
||||
}`,
|
||||
expectedHoldDID: "did:web:hold01.atcr.io",
|
||||
},
|
||||
{
|
||||
name: "tracks HoldDID from legacy HoldEndpoint",
|
||||
manifestResp: `{
|
||||
"uri":"at://did:plc:test123/io.atcr.manifest/abc123",
|
||||
"value":{
|
||||
"$type":"io.atcr.manifest",
|
||||
"holdEndpoint":"https://hold02.atcr.io",
|
||||
"mediaType":"application/vnd.oci.image.manifest.v1+json",
|
||||
"manifestBlob":{"$type":"blob","ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},"mimeType":"application/json","size":100}
|
||||
}
|
||||
}`,
|
||||
expectedHoldDID: "did:web:hold02.atcr.io",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == atproto.SyncGetBlob {
|
||||
w.Write(ociManifest)
|
||||
return
|
||||
}
|
||||
w.Write([]byte(tt.manifestResp))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
|
||||
_, err := store.Get(context.Background(), "sha256:abc123")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
|
||||
gotHoldDID := store.GetLastFetchedHoldDID()
|
||||
if gotHoldDID != tt.expectedHoldDID {
|
||||
t.Errorf("GetLastFetchedHoldDID() = %v, want %v", gotHoldDID, tt.expectedHoldDID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestManifestStore_Get_OnlyCountsGETRequests verifies that HEAD requests don't increment pull count
|
||||
func TestManifestStore_Get_OnlyCountsGETRequests(t *testing.T) {
|
||||
ociManifest := []byte(`{"schemaVersion":2}`)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
httpMethod string
|
||||
expectPullIncrement bool
|
||||
}{
|
||||
{
|
||||
name: "GET request increments pull count",
|
||||
httpMethod: "GET",
|
||||
expectPullIncrement: true,
|
||||
},
|
||||
{
|
||||
name: "HEAD request does not increment pull count",
|
||||
httpMethod: "HEAD",
|
||||
expectPullIncrement: false,
|
||||
},
|
||||
{
|
||||
name: "POST request does not increment pull count",
|
||||
httpMethod: "POST",
|
||||
expectPullIncrement: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == atproto.SyncGetBlob {
|
||||
w.Write(ociManifest)
|
||||
return
|
||||
}
|
||||
w.Write([]byte(`{
|
||||
"uri": "at://did:plc:test123/io.atcr.manifest/abc123",
|
||||
"value": {
|
||||
"$type":"io.atcr.manifest",
|
||||
"holdDid":"did:web:hold01.atcr.io",
|
||||
"mediaType":"application/vnd.oci.image.manifest.v1+json",
|
||||
"manifestBlob":{"$type":"blob","ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},"mimeType":"application/json","size":100}
|
||||
}
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
mockDB := &mockDatabaseMetrics{}
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold01.atcr.io", "did:plc:test123", "test.handle", mockDB)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
|
||||
// Create a context with the HTTP method stored (as distribution library does)
|
||||
testCtx := context.WithValue(context.Background(), "http.request.method", tt.httpMethod)
|
||||
|
||||
_, err := store.Get(testCtx, "sha256:abc123")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
|
||||
// Wait for async goroutine to complete (metrics are incremented asynchronously)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if tt.expectPullIncrement {
|
||||
// Check that IncrementPullCount was called
|
||||
if mockDB.getPullCount() == 0 {
|
||||
t.Error("Expected pull count to be incremented for GET request, but it wasn't")
|
||||
}
|
||||
} else {
|
||||
// Check that IncrementPullCount was NOT called
|
||||
if mockDB.getPullCount() > 0 {
|
||||
t.Errorf("Expected pull count NOT to be incremented for %s request, but it was (count=%d)", tt.httpMethod, mockDB.getPullCount())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestManifestStore_Put tests storing manifests
|
||||
func TestManifestStore_Put(t *testing.T) {
|
||||
ociManifest := []byte(`{
|
||||
@@ -760,7 +573,7 @@ func TestManifestStore_Put(t *testing.T) {
|
||||
// Handle uploadBlob
|
||||
if r.URL.Path == atproto.RepoUploadBlob {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},"mimeType":"application/json","size":100}}`))
|
||||
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafytest"},"mimeType":"application/json","size":100}}`))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -769,7 +582,7 @@ func TestManifestStore_Put(t *testing.T) {
|
||||
json.NewDecoder(r.Body).Decode(&lastBody)
|
||||
w.WriteHeader(tt.serverStatus)
|
||||
if tt.serverStatus == http.StatusOK {
|
||||
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/abc123","cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"}`))
|
||||
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/abc123","cid":"bafytest"}`))
|
||||
} else {
|
||||
w.Write([]byte(`{"error":"ServerError"}`))
|
||||
}
|
||||
@@ -780,10 +593,14 @@ func TestManifestStore_Put(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
db := &mockDatabaseMetrics{}
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", db)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
userCtx := mockUserContextForManifest(
|
||||
server.URL,
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, nil, nil)
|
||||
|
||||
dgst, err := store.Put(context.Background(), tt.manifest, tt.options...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -821,19 +638,24 @@ func TestManifestStore_Put_WithConfigLabels(t *testing.T) {
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == atproto.RepoUploadBlob {
|
||||
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},"size":100}}`))
|
||||
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafytest"},"size":100}}`))
|
||||
return
|
||||
}
|
||||
if r.URL.Path == atproto.RepoPutRecord {
|
||||
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/config123","cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"}`))
|
||||
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/config123","cid":"bafytest"}`))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", nil)
|
||||
userCtx := mockUserContextForManifest(
|
||||
server.URL,
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
|
||||
// Use config digest in manifest
|
||||
ociManifestWithConfig := []byte(`{
|
||||
@@ -848,7 +670,7 @@ func TestManifestStore_Put_WithConfigLabels(t *testing.T) {
|
||||
payload: ociManifestWithConfig,
|
||||
}
|
||||
|
||||
store := NewManifestStore(ctx, blobStore)
|
||||
store := NewManifestStore(userCtx, blobStore, nil)
|
||||
|
||||
_, err := store.Put(context.Background(), manifest)
|
||||
if err != nil {
|
||||
@@ -876,7 +698,7 @@ func TestManifestStore_Delete(t *testing.T) {
|
||||
name: "successful delete",
|
||||
digest: "sha256:abc123",
|
||||
serverStatus: http.StatusOK,
|
||||
serverResp: `{"commit":{"cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku","rev":"12345"}}`,
|
||||
serverResp: `{"commit":{"cid":"bafytest","rev":"12345"}}`,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
@@ -908,9 +730,14 @@ func TestManifestStore_Delete(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
userCtx := mockUserContextForManifest(
|
||||
server.URL,
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, nil, nil)
|
||||
|
||||
err := store.Delete(context.Background(), tt.digest)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -1033,7 +860,7 @@ func TestManifestStore_Put_ManifestListValidation(t *testing.T) {
|
||||
// Handle uploadBlob
|
||||
if r.URL.Path == atproto.RepoUploadBlob {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},"mimeType":"application/json","size":100}}`))
|
||||
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafytest"},"mimeType":"application/json","size":100}}`))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1045,7 +872,7 @@ func TestManifestStore_Put_ManifestListValidation(t *testing.T) {
|
||||
// If child should exist, return it; otherwise return RecordNotFound
|
||||
if tt.childExists || rkey == childDigest.Encoded() {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/` + rkey + `","cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku","value":{}}`))
|
||||
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/` + rkey + `","cid":"bafytest","value":{}}`))
|
||||
} else {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"RecordNotFound","message":"Record not found"}`))
|
||||
@@ -1056,7 +883,7 @@ func TestManifestStore_Put_ManifestListValidation(t *testing.T) {
|
||||
// Handle putRecord
|
||||
if r.URL.Path == atproto.RepoPutRecord {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/test123","cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"}`))
|
||||
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/test123","cid":"bafytest"}`))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1064,10 +891,14 @@ func TestManifestStore_Put_ManifestListValidation(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
db := &mockDatabaseMetrics{}
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", db)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
userCtx := mockUserContextForManifest(
|
||||
server.URL,
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, nil, nil)
|
||||
|
||||
manifest := &rawManifest{
|
||||
mediaType: "application/vnd.oci.image.index.v1+json",
|
||||
@@ -1117,14 +948,14 @@ func TestManifestStore_Put_ManifestListValidation_MultipleChildren(t *testing.T)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == atproto.RepoUploadBlob {
|
||||
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},"size":100}}`))
|
||||
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafytest"},"size":100}}`))
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == atproto.RepoGetRecord {
|
||||
rkey := r.URL.Query().Get("rkey")
|
||||
if existingManifests[rkey] {
|
||||
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/` + rkey + `","cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku","value":{}}`))
|
||||
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/` + rkey + `","cid":"bafytest","value":{}}`))
|
||||
} else {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"RecordNotFound"}`))
|
||||
@@ -1133,7 +964,7 @@ func TestManifestStore_Put_ManifestListValidation_MultipleChildren(t *testing.T)
|
||||
}
|
||||
|
||||
if r.URL.Path == atproto.RepoPutRecord {
|
||||
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/test123","cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"}`))
|
||||
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/test123","cid":"bafytest"}`))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1141,9 +972,14 @@ func TestManifestStore_Put_ManifestListValidation_MultipleChildren(t *testing.T)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
|
||||
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", nil)
|
||||
store := NewManifestStore(ctx, nil)
|
||||
userCtx := mockUserContextForManifest(
|
||||
server.URL,
|
||||
"myapp",
|
||||
"did:web:hold.example.com",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
)
|
||||
store := NewManifestStore(userCtx, nil, nil)
|
||||
|
||||
// Create manifest list with both children
|
||||
manifestList := []byte(`{
|
||||
|
||||
@@ -54,7 +54,7 @@ func EnsureProfile(ctx context.Context, client *atproto.Client, defaultHoldDID s
|
||||
// GetProfile retrieves the user's profile from their PDS
|
||||
// Returns nil if profile doesn't exist
|
||||
// Automatically migrates old URL-based defaultHold values to DIDs
|
||||
func GetProfile(ctx context.Context, client *atproto.Client) (*atproto.SailorProfile, error) {
|
||||
func GetProfile(ctx context.Context, client *atproto.Client) (*atproto.SailorProfileRecord, error) {
|
||||
record, err := client.GetRecord(ctx, atproto.SailorProfileCollection, ProfileRKey)
|
||||
if err != nil {
|
||||
// Check if it's a 404 (profile doesn't exist)
|
||||
@@ -65,17 +65,17 @@ func GetProfile(ctx context.Context, client *atproto.Client) (*atproto.SailorPro
|
||||
}
|
||||
|
||||
// Parse the profile record
|
||||
var profile atproto.SailorProfile
|
||||
var profile atproto.SailorProfileRecord
|
||||
if err := json.Unmarshal(record.Value, &profile); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse profile: %w", err)
|
||||
}
|
||||
|
||||
// Migrate old URL-based defaultHold to DID format
|
||||
// This ensures backward compatibility with profiles created before DID migration
|
||||
if profile.DefaultHold != nil && *profile.DefaultHold != "" && !atproto.IsDID(*profile.DefaultHold) {
|
||||
if profile.DefaultHold != "" && !atproto.IsDID(profile.DefaultHold) {
|
||||
// Convert URL to DID transparently
|
||||
migratedDID := atproto.ResolveHoldDIDFromURL(*profile.DefaultHold)
|
||||
profile.DefaultHold = &migratedDID
|
||||
migratedDID := atproto.ResolveHoldDIDFromURL(profile.DefaultHold)
|
||||
profile.DefaultHold = migratedDID
|
||||
|
||||
// Persist the migration to PDS in a background goroutine
|
||||
// Use a lock to ensure only one goroutine migrates this DID
|
||||
@@ -94,8 +94,7 @@ func GetProfile(ctx context.Context, client *atproto.Client) (*atproto.SailorPro
|
||||
defer cancel()
|
||||
|
||||
// Update the profile on the PDS
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
profile.UpdatedAt = &now
|
||||
profile.UpdatedAt = time.Now()
|
||||
if err := UpdateProfile(ctx, client, &profile); err != nil {
|
||||
slog.Warn("Failed to persist URL-to-DID migration", "component", "profile", "did", did, "error", err)
|
||||
} else {
|
||||
@@ -110,13 +109,12 @@ func GetProfile(ctx context.Context, client *atproto.Client) (*atproto.SailorPro
|
||||
|
||||
// UpdateProfile updates the user's profile
|
||||
// Normalizes defaultHold to DID format before saving
|
||||
func UpdateProfile(ctx context.Context, client *atproto.Client, profile *atproto.SailorProfile) error {
|
||||
func UpdateProfile(ctx context.Context, client *atproto.Client, profile *atproto.SailorProfileRecord) error {
|
||||
// Normalize defaultHold to DID if it's a URL
|
||||
// This ensures we always store DIDs, even if user provides a URL
|
||||
if profile.DefaultHold != nil && *profile.DefaultHold != "" && !atproto.IsDID(*profile.DefaultHold) {
|
||||
normalized := atproto.ResolveHoldDIDFromURL(*profile.DefaultHold)
|
||||
profile.DefaultHold = &normalized
|
||||
slog.Debug("Normalized defaultHold to DID", "component", "profile", "default_hold", normalized)
|
||||
if profile.DefaultHold != "" && !atproto.IsDID(profile.DefaultHold) {
|
||||
profile.DefaultHold = atproto.ResolveHoldDIDFromURL(profile.DefaultHold)
|
||||
slog.Debug("Normalized defaultHold to DID", "component", "profile", "default_hold", profile.DefaultHold)
|
||||
}
|
||||
|
||||
_, err := client.PutRecord(ctx, atproto.SailorProfileCollection, ProfileRKey, profile)
|
||||
|
||||
@@ -39,7 +39,7 @@ func TestEnsureProfile_Create(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var createdProfile *atproto.SailorProfile
|
||||
var createdProfile *atproto.SailorProfileRecord
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// First request: GetRecord (should 404)
|
||||
@@ -95,16 +95,12 @@ func TestEnsureProfile_Create(t *testing.T) {
|
||||
t.Fatal("Profile was not created")
|
||||
}
|
||||
|
||||
if createdProfile.LexiconTypeID != atproto.SailorProfileCollection {
|
||||
t.Errorf("LexiconTypeID = %v, want %v", createdProfile.LexiconTypeID, atproto.SailorProfileCollection)
|
||||
if createdProfile.Type != atproto.SailorProfileCollection {
|
||||
t.Errorf("Type = %v, want %v", createdProfile.Type, atproto.SailorProfileCollection)
|
||||
}
|
||||
|
||||
gotDefaultHold := ""
|
||||
if createdProfile.DefaultHold != nil {
|
||||
gotDefaultHold = *createdProfile.DefaultHold
|
||||
}
|
||||
if gotDefaultHold != tt.wantNormalized {
|
||||
t.Errorf("DefaultHold = %v, want %v", gotDefaultHold, tt.wantNormalized)
|
||||
if createdProfile.DefaultHold != tt.wantNormalized {
|
||||
t.Errorf("DefaultHold = %v, want %v", createdProfile.DefaultHold, tt.wantNormalized)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -158,7 +154,7 @@ func TestGetProfile(t *testing.T) {
|
||||
name string
|
||||
serverResponse string
|
||||
serverStatus int
|
||||
wantProfile *atproto.SailorProfile
|
||||
wantProfile *atproto.SailorProfileRecord
|
||||
wantNil bool
|
||||
wantErr bool
|
||||
expectMigration bool // Whether URL-to-DID migration should happen
|
||||
@@ -269,12 +265,8 @@ func TestGetProfile(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check that defaultHold is migrated to DID in returned profile
|
||||
gotDefaultHold := ""
|
||||
if profile.DefaultHold != nil {
|
||||
gotDefaultHold = *profile.DefaultHold
|
||||
}
|
||||
if gotDefaultHold != tt.expectedHoldDID {
|
||||
t.Errorf("DefaultHold = %v, want %v", gotDefaultHold, tt.expectedHoldDID)
|
||||
if profile.DefaultHold != tt.expectedHoldDID {
|
||||
t.Errorf("DefaultHold = %v, want %v", profile.DefaultHold, tt.expectedHoldDID)
|
||||
}
|
||||
|
||||
if tt.expectMigration {
|
||||
@@ -374,43 +366,44 @@ func TestGetProfile_MigrationLocking(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// testSailorProfile creates a test profile with the given default hold
|
||||
func testSailorProfile(defaultHold string) *atproto.SailorProfile {
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
profile := &atproto.SailorProfile{
|
||||
LexiconTypeID: atproto.SailorProfileCollection,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: &now,
|
||||
}
|
||||
if defaultHold != "" {
|
||||
profile.DefaultHold = &defaultHold
|
||||
}
|
||||
return profile
|
||||
}
|
||||
|
||||
// TestUpdateProfile tests updating a user's profile
|
||||
func TestUpdateProfile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profile *atproto.SailorProfile
|
||||
profile *atproto.SailorProfileRecord
|
||||
wantNormalized string // Expected defaultHold after normalization
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "update with DID",
|
||||
profile: testSailorProfile("did:web:hold02.atcr.io"),
|
||||
name: "update with DID",
|
||||
profile: &atproto.SailorProfileRecord{
|
||||
Type: atproto.SailorProfileCollection,
|
||||
DefaultHold: "did:web:hold02.atcr.io",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
wantNormalized: "did:web:hold02.atcr.io",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "update with URL - should normalize",
|
||||
profile: testSailorProfile("https://hold02.atcr.io"),
|
||||
name: "update with URL - should normalize",
|
||||
profile: &atproto.SailorProfileRecord{
|
||||
Type: atproto.SailorProfileCollection,
|
||||
DefaultHold: "https://hold02.atcr.io",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
wantNormalized: "did:web:hold02.atcr.io",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "clear default hold",
|
||||
profile: testSailorProfile(""),
|
||||
name: "clear default hold",
|
||||
profile: &atproto.SailorProfileRecord{
|
||||
Type: atproto.SailorProfileCollection,
|
||||
DefaultHold: "",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
wantNormalized: "",
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -461,12 +454,8 @@ func TestUpdateProfile(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify normalization also updated the profile object
|
||||
gotProfileHold := ""
|
||||
if tt.profile.DefaultHold != nil {
|
||||
gotProfileHold = *tt.profile.DefaultHold
|
||||
}
|
||||
if gotProfileHold != tt.wantNormalized {
|
||||
t.Errorf("profile.DefaultHold = %v, want %v (should be updated in-place)", gotProfileHold, tt.wantNormalized)
|
||||
if tt.profile.DefaultHold != tt.wantNormalized {
|
||||
t.Errorf("profile.DefaultHold = %v, want %v (should be updated in-place)", tt.profile.DefaultHold, tt.wantNormalized)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -550,8 +539,8 @@ func TestGetProfile_EmptyDefaultHold(t *testing.T) {
|
||||
t.Fatalf("GetProfile() error = %v", err)
|
||||
}
|
||||
|
||||
if profile.DefaultHold != nil && *profile.DefaultHold != "" {
|
||||
t.Errorf("DefaultHold = %v, want empty or nil", profile.DefaultHold)
|
||||
if profile.DefaultHold != "" {
|
||||
t.Errorf("DefaultHold = %v, want empty string", profile.DefaultHold)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -564,7 +553,12 @@ func TestUpdateProfile_ServerError(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
|
||||
profile := testSailorProfile("did:web:hold01.atcr.io")
|
||||
profile := &atproto.SailorProfileRecord{
|
||||
Type: atproto.SailorProfileCollection,
|
||||
DefaultHold: "did:web:hold01.atcr.io",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
err := UpdateProfile(context.Background(), client, profile)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
"github.com/distribution/distribution/v3"
|
||||
"github.com/distribution/distribution/v3/registry/api/errcode"
|
||||
"github.com/opencontainers/go-digest"
|
||||
@@ -32,20 +33,20 @@ var (
|
||||
|
||||
// ProxyBlobStore proxies blob requests to an external storage service
|
||||
type ProxyBlobStore struct {
|
||||
ctx *RegistryContext // All context and services
|
||||
holdURL string // Resolved HTTP URL for XRPC requests
|
||||
ctx *auth.UserContext // User context with identity, target, permissions
|
||||
holdURL string // Resolved HTTP URL for XRPC requests
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewProxyBlobStore creates a new proxy blob store
|
||||
func NewProxyBlobStore(ctx *RegistryContext) *ProxyBlobStore {
|
||||
func NewProxyBlobStore(userCtx *auth.UserContext) *ProxyBlobStore {
|
||||
// Resolve DID to URL once at construction time
|
||||
holdURL := atproto.ResolveHoldURL(ctx.HoldDID)
|
||||
holdURL := atproto.ResolveHoldURL(userCtx.TargetHoldDID)
|
||||
|
||||
slog.Debug("NewProxyBlobStore created", "component", "proxy_blob_store", "hold_did", ctx.HoldDID, "hold_url", holdURL, "user_did", ctx.DID, "repo", ctx.Repository)
|
||||
slog.Debug("NewProxyBlobStore created", "component", "proxy_blob_store", "hold_did", userCtx.TargetHoldDID, "hold_url", holdURL, "user_did", userCtx.TargetOwnerDID, "repo", userCtx.TargetRepo)
|
||||
|
||||
return &ProxyBlobStore{
|
||||
ctx: ctx,
|
||||
ctx: userCtx,
|
||||
holdURL: holdURL,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Minute, // Timeout for presigned URL requests and uploads
|
||||
@@ -61,32 +62,33 @@ func NewProxyBlobStore(ctx *RegistryContext) *ProxyBlobStore {
|
||||
}
|
||||
|
||||
// doAuthenticatedRequest performs an HTTP request with service token authentication
|
||||
// Uses the service token from middleware to authenticate requests to the hold service
|
||||
// Uses the service token from UserContext to authenticate requests to the hold service
|
||||
func (p *ProxyBlobStore) doAuthenticatedRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
// Use service token that middleware already validated and cached
|
||||
// Middleware fails fast with HTTP 401 if OAuth session is invalid
|
||||
if p.ctx.ServiceToken == "" {
|
||||
// Get service token from UserContext (lazy-loaded and cached per holdDID)
|
||||
serviceToken, err := p.ctx.GetServiceToken(ctx)
|
||||
if err != nil {
|
||||
slog.Error("Failed to get service token", "component", "proxy_blob_store", "did", p.ctx.DID, "error", err)
|
||||
return nil, fmt.Errorf("failed to get service token: %w", err)
|
||||
}
|
||||
if serviceToken == "" {
|
||||
// Should never happen - middleware validates OAuth before handlers run
|
||||
slog.Error("No service token in context", "component", "proxy_blob_store", "did", p.ctx.DID)
|
||||
return nil, fmt.Errorf("no service token available (middleware should have validated)")
|
||||
}
|
||||
|
||||
// Add Bearer token to Authorization header
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", p.ctx.ServiceToken))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", serviceToken))
|
||||
|
||||
return p.httpClient.Do(req)
|
||||
}
|
||||
|
||||
// checkReadAccess validates that the user has read access to blobs in this hold
|
||||
func (p *ProxyBlobStore) checkReadAccess(ctx context.Context) error {
|
||||
if p.ctx.Authorizer == nil {
|
||||
return nil // No authorization check if authorizer not configured
|
||||
}
|
||||
allowed, err := p.ctx.Authorizer.CheckReadAccess(ctx, p.ctx.HoldDID, p.ctx.DID)
|
||||
canRead, err := p.ctx.CanRead(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authorization check failed: %w", err)
|
||||
}
|
||||
if !allowed {
|
||||
if !canRead {
|
||||
// Return 403 Forbidden instead of masquerading as missing blob
|
||||
return errcode.ErrorCodeDenied.WithMessage("read access denied")
|
||||
}
|
||||
@@ -95,21 +97,17 @@ func (p *ProxyBlobStore) checkReadAccess(ctx context.Context) error {
|
||||
|
||||
// checkWriteAccess validates that the user has write access to blobs in this hold
|
||||
func (p *ProxyBlobStore) checkWriteAccess(ctx context.Context) error {
|
||||
if p.ctx.Authorizer == nil {
|
||||
return nil // No authorization check if authorizer not configured
|
||||
}
|
||||
|
||||
slog.Debug("Checking write access", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.HoldDID)
|
||||
allowed, err := p.ctx.Authorizer.CheckWriteAccess(ctx, p.ctx.HoldDID, p.ctx.DID)
|
||||
slog.Debug("Checking write access", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.TargetHoldDID)
|
||||
canWrite, err := p.ctx.CanWrite(ctx)
|
||||
if err != nil {
|
||||
slog.Error("Authorization check error", "component", "proxy_blob_store", "error", err)
|
||||
return fmt.Errorf("authorization check failed: %w", err)
|
||||
}
|
||||
if !allowed {
|
||||
slog.Warn("Write access denied", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.HoldDID)
|
||||
return errcode.ErrorCodeDenied.WithMessage(fmt.Sprintf("write access denied to hold %s", p.ctx.HoldDID))
|
||||
if !canWrite {
|
||||
slog.Warn("Write access denied", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.TargetHoldDID)
|
||||
return errcode.ErrorCodeDenied.WithMessage(fmt.Sprintf("write access denied to hold %s", p.ctx.TargetHoldDID))
|
||||
}
|
||||
slog.Debug("Write access allowed", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.HoldDID)
|
||||
slog.Debug("Write access allowed", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.TargetHoldDID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -356,10 +354,10 @@ func (p *ProxyBlobStore) Resume(ctx context.Context, id string) (distribution.Bl
|
||||
// getPresignedURL returns the XRPC endpoint URL for blob operations
|
||||
func (p *ProxyBlobStore) getPresignedURL(ctx context.Context, operation string, dgst digest.Digest) (string, error) {
|
||||
// Use XRPC endpoint: /xrpc/com.atproto.sync.getBlob?did={userDID}&cid={digest}
|
||||
// The 'did' parameter is the USER's DID (whose blob we're fetching), not the hold service DID
|
||||
// The 'did' parameter is the TARGET OWNER's DID (whose blob we're fetching), not the hold service DID
|
||||
// Per migration doc: hold accepts OCI digest directly as cid parameter (checks for sha256: prefix)
|
||||
xrpcURL := fmt.Sprintf("%s%s?did=%s&cid=%s&method=%s",
|
||||
p.holdURL, atproto.SyncGetBlob, p.ctx.DID, dgst.String(), operation)
|
||||
p.holdURL, atproto.SyncGetBlob, p.ctx.TargetOwnerDID, dgst.String(), operation)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", xrpcURL, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,46 +1,41 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth/token"
|
||||
"github.com/opencontainers/go-digest"
|
||||
"atcr.io/pkg/auth"
|
||||
)
|
||||
|
||||
// TestGetServiceToken_CachingLogic tests the token caching mechanism
|
||||
// TestGetServiceToken_CachingLogic tests the global service token caching mechanism
|
||||
// These tests use the global auth cache functions directly
|
||||
func TestGetServiceToken_CachingLogic(t *testing.T) {
|
||||
userDID := "did:plc:test"
|
||||
userDID := "did:plc:cache-test"
|
||||
holdDID := "did:web:hold.example.com"
|
||||
|
||||
// Test 1: Empty cache - invalidate any existing token
|
||||
token.InvalidateServiceToken(userDID, holdDID)
|
||||
cachedToken, _ := token.GetServiceToken(userDID, holdDID)
|
||||
auth.InvalidateServiceToken(userDID, holdDID)
|
||||
cachedToken, _ := auth.GetServiceToken(userDID, holdDID)
|
||||
if cachedToken != "" {
|
||||
t.Error("Expected empty cache at start")
|
||||
}
|
||||
|
||||
// Test 2: Insert token into cache
|
||||
// Create a JWT-like token with exp claim for testing
|
||||
// Format: header.payload.signature where payload has exp claim
|
||||
testPayload := fmt.Sprintf(`{"exp":%d}`, time.Now().Add(50*time.Second).Unix())
|
||||
testToken := "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(testPayload) + ".signature"
|
||||
|
||||
err := token.SetServiceToken(userDID, holdDID, testToken)
|
||||
err := auth.SetServiceToken(userDID, holdDID, testToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set service token: %v", err)
|
||||
}
|
||||
|
||||
// Test 3: Retrieve from cache
|
||||
cachedToken, expiresAt := token.GetServiceToken(userDID, holdDID)
|
||||
cachedToken, expiresAt := auth.GetServiceToken(userDID, holdDID)
|
||||
if cachedToken == "" {
|
||||
t.Fatal("Expected token to be in cache")
|
||||
}
|
||||
@@ -56,10 +51,10 @@ func TestGetServiceToken_CachingLogic(t *testing.T) {
|
||||
// Test 4: Expired token - GetServiceToken automatically removes it
|
||||
expiredPayload := fmt.Sprintf(`{"exp":%d}`, time.Now().Add(-1*time.Hour).Unix())
|
||||
expiredToken := "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(expiredPayload) + ".signature"
|
||||
token.SetServiceToken(userDID, holdDID, expiredToken)
|
||||
auth.SetServiceToken(userDID, holdDID, expiredToken)
|
||||
|
||||
// GetServiceToken should return empty string for expired token
|
||||
cachedToken, _ = token.GetServiceToken(userDID, holdDID)
|
||||
cachedToken, _ = auth.GetServiceToken(userDID, holdDID)
|
||||
if cachedToken != "" {
|
||||
t.Error("Expected expired token to be removed from cache")
|
||||
}
|
||||
@@ -70,129 +65,33 @@ func base64URLEncode(data string) string {
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString([]byte(data)), "=")
|
||||
}
|
||||
|
||||
// TestServiceToken_EmptyInContext tests that operations fail when service token is missing
|
||||
func TestServiceToken_EmptyInContext(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
ServiceToken: "", // No service token (middleware didn't set it)
|
||||
Refresher: nil,
|
||||
}
|
||||
// mockUserContextForProxy creates a mock auth.UserContext for proxy blob store testing.
|
||||
// It sets up both the user identity and target info, and configures test helpers
|
||||
// to bypass network calls.
|
||||
func mockUserContextForProxy(did, holdDID, pdsEndpoint, repository string) *auth.UserContext {
|
||||
userCtx := auth.NewUserContext(did, "oauth", "PUT", nil)
|
||||
userCtx.SetTarget(did, "test.handle", pdsEndpoint, repository, holdDID)
|
||||
|
||||
store := NewProxyBlobStore(ctx)
|
||||
// Bypass PDS resolution (avoids network calls)
|
||||
userCtx.SetPDSForTest("test.handle", pdsEndpoint)
|
||||
|
||||
// Try a write operation that requires authentication
|
||||
testDigest := digest.FromString("test-content")
|
||||
_, err := store.Stat(context.Background(), testDigest)
|
||||
// Set up mock authorizer that allows access
|
||||
userCtx.SetAuthorizerForTest(auth.NewMockHoldAuthorizer())
|
||||
|
||||
// Should fail because no service token is available
|
||||
if err == nil {
|
||||
t.Error("Expected error when service token is empty")
|
||||
}
|
||||
// Set default hold DID for push resolution
|
||||
userCtx.SetDefaultHoldDIDForTest(holdDID)
|
||||
|
||||
// Error should indicate authentication issue
|
||||
if !strings.Contains(err.Error(), "UNAUTHORIZED") && !strings.Contains(err.Error(), "authentication") {
|
||||
t.Logf("Got error (acceptable): %v", err)
|
||||
}
|
||||
return userCtx
|
||||
}
|
||||
|
||||
// TestDoAuthenticatedRequest_BearerTokenInjection tests that Bearer tokens are added to requests
|
||||
func TestDoAuthenticatedRequest_BearerTokenInjection(t *testing.T) {
|
||||
// This test verifies the Bearer token injection logic
|
||||
|
||||
testToken := "test-bearer-token-xyz"
|
||||
|
||||
// Create a test server to verify the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
// Create ProxyBlobStore with service token in context (set by middleware)
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:bearer-test",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
ServiceToken: testToken, // Service token from middleware
|
||||
Refresher: nil,
|
||||
}
|
||||
|
||||
store := NewProxyBlobStore(ctx)
|
||||
|
||||
// Create request
|
||||
req, err := http.NewRequest(http.MethodGet, testServer.URL+"/test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
// Do authenticated request
|
||||
resp, err := store.doAuthenticatedRequest(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("doAuthenticatedRequest failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Verify Bearer token was added
|
||||
expectedHeader := "Bearer " + testToken
|
||||
if receivedAuthHeader != expectedHeader {
|
||||
t.Errorf("Expected Authorization header %s, got %s", expectedHeader, receivedAuthHeader)
|
||||
}
|
||||
// mockUserContextForProxyWithToken creates a mock UserContext with a pre-populated service token.
|
||||
func mockUserContextForProxyWithToken(did, holdDID, pdsEndpoint, repository, serviceToken string) *auth.UserContext {
|
||||
userCtx := mockUserContextForProxy(did, holdDID, pdsEndpoint, repository)
|
||||
userCtx.SetServiceTokenForTest(holdDID, serviceToken)
|
||||
return userCtx
|
||||
}
|
||||
|
||||
// TestDoAuthenticatedRequest_ErrorWhenTokenUnavailable tests that authentication failures return proper errors
|
||||
func TestDoAuthenticatedRequest_ErrorWhenTokenUnavailable(t *testing.T) {
|
||||
// Create test server (should not be called since auth fails first)
|
||||
called := false
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
// Create ProxyBlobStore without service token (middleware didn't set it)
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:fallback",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
ServiceToken: "", // No service token
|
||||
Refresher: nil,
|
||||
}
|
||||
|
||||
store := NewProxyBlobStore(ctx)
|
||||
|
||||
// Create request
|
||||
req, err := http.NewRequest(http.MethodGet, testServer.URL+"/test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
// Do authenticated request - should fail when no service token
|
||||
resp, err := store.doAuthenticatedRequest(context.Background(), req)
|
||||
if err == nil {
|
||||
t.Fatal("Expected doAuthenticatedRequest to fail when no service token is available")
|
||||
}
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// Verify error indicates authentication/authorization issue
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, "service token") && !strings.Contains(errStr, "UNAUTHORIZED") {
|
||||
t.Errorf("Expected service token or unauthorized error, got: %v", err)
|
||||
}
|
||||
|
||||
if called {
|
||||
t.Error("Expected request to NOT be made when authentication fails")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveHoldURL tests DID to URL conversion
|
||||
// TestResolveHoldURL tests DID to URL conversion (pure function)
|
||||
func TestResolveHoldURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -200,7 +99,7 @@ func TestResolveHoldURL(t *testing.T) {
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "did:web with http (TEST_MODE)",
|
||||
name: "did:web with http (localhost)",
|
||||
holdDID: "did:web:localhost:8080",
|
||||
expected: "http://localhost:8080",
|
||||
},
|
||||
@@ -228,16 +127,16 @@ func TestResolveHoldURL(t *testing.T) {
|
||||
|
||||
// TestServiceTokenCacheExpiry tests that expired cached tokens are not used
|
||||
func TestServiceTokenCacheExpiry(t *testing.T) {
|
||||
userDID := "did:plc:expiry"
|
||||
userDID := "did:plc:expiry-test"
|
||||
holdDID := "did:web:hold.example.com"
|
||||
|
||||
// Insert expired token
|
||||
expiredPayload := fmt.Sprintf(`{"exp":%d}`, time.Now().Add(-1*time.Hour).Unix())
|
||||
expiredToken := "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(expiredPayload) + ".signature"
|
||||
token.SetServiceToken(userDID, holdDID, expiredToken)
|
||||
auth.SetServiceToken(userDID, holdDID, expiredToken)
|
||||
|
||||
// GetServiceToken should automatically remove expired tokens
|
||||
cachedToken, expiresAt := token.GetServiceToken(userDID, holdDID)
|
||||
cachedToken, expiresAt := auth.GetServiceToken(userDID, holdDID)
|
||||
|
||||
// Should return empty string for expired token
|
||||
if cachedToken != "" {
|
||||
@@ -272,20 +171,20 @@ func TestServiceTokenCacheKeyFormat(t *testing.T) {
|
||||
|
||||
// TestNewProxyBlobStore tests ProxyBlobStore creation
|
||||
func TestNewProxyBlobStore(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
}
|
||||
userCtx := mockUserContextForProxy(
|
||||
"did:plc:test",
|
||||
"did:web:hold.example.com",
|
||||
"https://pds.example.com",
|
||||
"test-repo",
|
||||
)
|
||||
|
||||
store := NewProxyBlobStore(ctx)
|
||||
store := NewProxyBlobStore(userCtx)
|
||||
|
||||
if store == nil {
|
||||
t.Fatal("Expected non-nil ProxyBlobStore")
|
||||
}
|
||||
|
||||
if store.ctx != ctx {
|
||||
if store.ctx != userCtx {
|
||||
t.Error("Expected context to be set")
|
||||
}
|
||||
|
||||
@@ -310,10 +209,10 @@ func BenchmarkServiceTokenCacheAccess(b *testing.B) {
|
||||
|
||||
testPayload := fmt.Sprintf(`{"exp":%d}`, time.Now().Add(50*time.Second).Unix())
|
||||
testTokenStr := "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(testPayload) + ".signature"
|
||||
token.SetServiceToken(userDID, holdDID, testTokenStr)
|
||||
auth.SetServiceToken(userDID, holdDID, testTokenStr)
|
||||
|
||||
for b.Loop() {
|
||||
cachedToken, expiresAt := token.GetServiceToken(userDID, holdDID)
|
||||
cachedToken, expiresAt := auth.GetServiceToken(userDID, holdDID)
|
||||
|
||||
if cachedToken == "" || time.Now().After(expiresAt) {
|
||||
b.Error("Cache miss in benchmark")
|
||||
@@ -321,296 +220,55 @@ func BenchmarkServiceTokenCacheAccess(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompleteMultipartUpload_JSONFormat verifies the JSON request format sent to hold service
|
||||
// This test would have caught the "partNumber" vs "part_number" bug
|
||||
func TestCompleteMultipartUpload_JSONFormat(t *testing.T) {
|
||||
var capturedBody map[string]any
|
||||
// TestParseJWTExpiry tests JWT expiry parsing
|
||||
func TestParseJWTExpiry(t *testing.T) {
|
||||
// Create a JWT with known expiry
|
||||
futureTime := time.Now().Add(1 * time.Hour).Unix()
|
||||
testPayload := fmt.Sprintf(`{"exp":%d}`, futureTime)
|
||||
testToken := "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(testPayload) + ".signature"
|
||||
|
||||
// Mock hold service that captures the request body
|
||||
holdServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.URL.Path, atproto.HoldCompleteUpload) {
|
||||
t.Errorf("Wrong endpoint called: %s", r.URL.Path)
|
||||
}
|
||||
|
||||
// Capture request body
|
||||
var body map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Errorf("Failed to decode request body: %v", err)
|
||||
}
|
||||
capturedBody = body
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{}`))
|
||||
}))
|
||||
defer holdServer.Close()
|
||||
|
||||
// Create store with mocked hold URL
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
ServiceToken: "test-service-token", // Service token from middleware
|
||||
}
|
||||
store := NewProxyBlobStore(ctx)
|
||||
store.holdURL = holdServer.URL
|
||||
|
||||
// Call completeMultipartUpload
|
||||
parts := []CompletedPart{
|
||||
{PartNumber: 1, ETag: "etag-1"},
|
||||
{PartNumber: 2, ETag: "etag-2"},
|
||||
}
|
||||
err := store.completeMultipartUpload(context.Background(), "sha256:abc123", "upload-id-xyz", parts)
|
||||
expiry, err := auth.ParseJWTExpiry(testToken)
|
||||
if err != nil {
|
||||
t.Fatalf("completeMultipartUpload failed: %v", err)
|
||||
t.Fatalf("ParseJWTExpiry failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify JSON format
|
||||
if capturedBody == nil {
|
||||
t.Fatal("No request body was captured")
|
||||
}
|
||||
|
||||
// Check top-level fields
|
||||
if uploadID, ok := capturedBody["uploadId"].(string); !ok || uploadID != "upload-id-xyz" {
|
||||
t.Errorf("Expected uploadId='upload-id-xyz', got %v", capturedBody["uploadId"])
|
||||
}
|
||||
if digest, ok := capturedBody["digest"].(string); !ok || digest != "sha256:abc123" {
|
||||
t.Errorf("Expected digest='sha256:abc123', got %v", capturedBody["digest"])
|
||||
}
|
||||
|
||||
// Check parts array
|
||||
partsArray, ok := capturedBody["parts"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected parts to be array, got %T", capturedBody["parts"])
|
||||
}
|
||||
if len(partsArray) != 2 {
|
||||
t.Fatalf("Expected 2 parts, got %d", len(partsArray))
|
||||
}
|
||||
|
||||
// Verify first part has "part_number" (not "partNumber")
|
||||
part0, ok := partsArray[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected part to be object, got %T", partsArray[0])
|
||||
}
|
||||
|
||||
// THIS IS THE KEY CHECK - would have caught the bug
|
||||
if _, hasPartNumber := part0["partNumber"]; hasPartNumber {
|
||||
t.Error("Found 'partNumber' (camelCase) - should be 'part_number' (snake_case)")
|
||||
}
|
||||
if partNum, ok := part0["part_number"].(float64); !ok || int(partNum) != 1 {
|
||||
t.Errorf("Expected part_number=1, got %v", part0["part_number"])
|
||||
}
|
||||
if etag, ok := part0["etag"].(string); !ok || etag != "etag-1" {
|
||||
t.Errorf("Expected etag='etag-1', got %v", part0["etag"])
|
||||
// Verify expiry is close to what we set (within 1 second tolerance)
|
||||
expectedExpiry := time.Unix(futureTime, 0)
|
||||
diff := expiry.Sub(expectedExpiry)
|
||||
if diff < -time.Second || diff > time.Second {
|
||||
t.Errorf("Expiry mismatch: expected %v, got %v", expectedExpiry, expiry)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGet_UsesPresignedURLDirectly verifies that Get() doesn't add auth headers to presigned URLs
|
||||
// This test would have caught the presigned URL authentication bug
|
||||
func TestGet_UsesPresignedURLDirectly(t *testing.T) {
|
||||
blobData := []byte("test blob content")
|
||||
var s3ReceivedAuthHeader string
|
||||
|
||||
// Mock S3 server that rejects requests with Authorization header
|
||||
s3Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s3ReceivedAuthHeader = r.Header.Get("Authorization")
|
||||
|
||||
// Presigned URLs should NOT have Authorization header
|
||||
if s3ReceivedAuthHeader != "" {
|
||||
t.Errorf("S3 received Authorization header: %s (should be empty for presigned URLs)", s3ReceivedAuthHeader)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`<?xml version="1.0"?><Error><Code>SignatureDoesNotMatch</Code></Error>`))
|
||||
return
|
||||
}
|
||||
|
||||
// Return blob data
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(blobData)
|
||||
}))
|
||||
defer s3Server.Close()
|
||||
|
||||
// Mock hold service that returns presigned S3 URL
|
||||
holdServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Return presigned URL pointing to S3 server
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
resp := map[string]string{
|
||||
"url": s3Server.URL + "/blob?X-Amz-Signature=fake-signature",
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer holdServer.Close()
|
||||
|
||||
// Create store with service token in context
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
ServiceToken: "test-service-token", // Service token from middleware
|
||||
}
|
||||
store := NewProxyBlobStore(ctx)
|
||||
store.holdURL = holdServer.URL
|
||||
|
||||
// Call Get()
|
||||
dgst := digest.FromBytes(blobData)
|
||||
retrieved, err := store.Get(context.Background(), dgst)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify correct data was retrieved
|
||||
if string(retrieved) != string(blobData) {
|
||||
t.Errorf("Expected data=%s, got %s", string(blobData), string(retrieved))
|
||||
}
|
||||
|
||||
// Verify S3 received NO Authorization header
|
||||
if s3ReceivedAuthHeader != "" {
|
||||
t.Errorf("S3 should not receive Authorization header for presigned URLs, got: %s", s3ReceivedAuthHeader)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpen_UsesPresignedURLDirectly verifies that Open() doesn't add auth headers to presigned URLs
|
||||
// This test would have caught the presigned URL authentication bug
|
||||
func TestOpen_UsesPresignedURLDirectly(t *testing.T) {
|
||||
blobData := []byte("test blob stream content")
|
||||
var s3ReceivedAuthHeader string
|
||||
|
||||
// Mock S3 server
|
||||
s3Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s3ReceivedAuthHeader = r.Header.Get("Authorization")
|
||||
|
||||
// Presigned URLs should NOT have Authorization header
|
||||
if s3ReceivedAuthHeader != "" {
|
||||
t.Errorf("S3 received Authorization header: %s (should be empty)", s3ReceivedAuthHeader)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(blobData)
|
||||
}))
|
||||
defer s3Server.Close()
|
||||
|
||||
// Mock hold service
|
||||
holdServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"url": s3Server.URL + "/blob?X-Amz-Signature=fake",
|
||||
})
|
||||
}))
|
||||
defer holdServer.Close()
|
||||
|
||||
// Create store with service token in context
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
ServiceToken: "test-service-token", // Service token from middleware
|
||||
}
|
||||
store := NewProxyBlobStore(ctx)
|
||||
store.holdURL = holdServer.URL
|
||||
|
||||
// Call Open()
|
||||
dgst := digest.FromBytes(blobData)
|
||||
reader, err := store.Open(context.Background(), dgst)
|
||||
if err != nil {
|
||||
t.Fatalf("Open() failed: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Verify S3 received NO Authorization header
|
||||
if s3ReceivedAuthHeader != "" {
|
||||
t.Errorf("S3 should not receive Authorization header for presigned URLs, got: %s", s3ReceivedAuthHeader)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultipartEndpoints_CorrectURLs verifies all multipart XRPC endpoints use correct URLs
|
||||
// This would have caught the old com.atproto.repo.uploadBlob vs new io.atcr.hold.* endpoints
|
||||
func TestMultipartEndpoints_CorrectURLs(t *testing.T) {
|
||||
// TestParseJWTExpiry_InvalidToken tests error handling for invalid tokens
|
||||
func TestParseJWTExpiry_InvalidToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
testFunc func(*ProxyBlobStore) error
|
||||
expectedPath string
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{
|
||||
name: "startMultipartUpload",
|
||||
testFunc: func(store *ProxyBlobStore) error {
|
||||
_, err := store.startMultipartUpload(context.Background(), "sha256:test")
|
||||
return err
|
||||
},
|
||||
expectedPath: atproto.HoldInitiateUpload,
|
||||
},
|
||||
{
|
||||
name: "getPartUploadInfo",
|
||||
testFunc: func(store *ProxyBlobStore) error {
|
||||
_, err := store.getPartUploadInfo(context.Background(), "sha256:test", "upload-123", 1)
|
||||
return err
|
||||
},
|
||||
expectedPath: atproto.HoldGetPartUploadURL,
|
||||
},
|
||||
{
|
||||
name: "completeMultipartUpload",
|
||||
testFunc: func(store *ProxyBlobStore) error {
|
||||
parts := []CompletedPart{{PartNumber: 1, ETag: "etag1"}}
|
||||
return store.completeMultipartUpload(context.Background(), "sha256:test", "upload-123", parts)
|
||||
},
|
||||
expectedPath: atproto.HoldCompleteUpload,
|
||||
},
|
||||
{
|
||||
name: "abortMultipartUpload",
|
||||
testFunc: func(store *ProxyBlobStore) error {
|
||||
return store.abortMultipartUpload(context.Background(), "sha256:test", "upload-123")
|
||||
},
|
||||
expectedPath: atproto.HoldAbortUpload,
|
||||
},
|
||||
{"empty token", ""},
|
||||
{"single part", "header"},
|
||||
{"two parts", "header.payload"},
|
||||
{"invalid base64 payload", "header.!!!.signature"},
|
||||
{"missing exp claim", "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(`{"sub":"test"}`) + ".sig"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedPath string
|
||||
|
||||
// Mock hold service that captures request path
|
||||
holdServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedPath = r.URL.Path
|
||||
|
||||
// Return success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
resp := map[string]string{
|
||||
"uploadId": "test-upload-id",
|
||||
"url": "https://s3.example.com/presigned",
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer holdServer.Close()
|
||||
|
||||
// Create store with service token in context
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test",
|
||||
HoldDID: "did:web:hold.example.com",
|
||||
PDSEndpoint: "https://pds.example.com",
|
||||
Repository: "test-repo",
|
||||
ServiceToken: "test-service-token", // Service token from middleware
|
||||
}
|
||||
store := NewProxyBlobStore(ctx)
|
||||
store.holdURL = holdServer.URL
|
||||
|
||||
// Call the function
|
||||
_ = tt.testFunc(store) // Ignore error, we just care about the URL
|
||||
|
||||
// Verify correct endpoint was called
|
||||
if capturedPath != tt.expectedPath {
|
||||
t.Errorf("Expected endpoint %s, got %s", tt.expectedPath, capturedPath)
|
||||
}
|
||||
|
||||
// Verify it's NOT the old endpoint
|
||||
if strings.Contains(capturedPath, "com.atproto.repo.uploadBlob") {
|
||||
t.Error("Still using old com.atproto.repo.uploadBlob endpoint!")
|
||||
_, err := auth.ParseJWTExpiry(tt.token)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid token")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Tests for doAuthenticatedRequest, Get, Open, completeMultipartUpload, etc.
|
||||
// require complex dependency mocking (OAuth refresher, PDS resolution, HoldAuthorizer).
|
||||
// These should be tested at the integration level with proper infrastructure.
|
||||
//
|
||||
// The current unit tests cover:
|
||||
// - Global service token cache (auth.GetServiceToken, auth.SetServiceToken, etc.)
|
||||
// - URL resolution (atproto.ResolveHoldURL)
|
||||
// - JWT parsing (auth.ParseJWTExpiry)
|
||||
// - Store construction (NewProxyBlobStore)
|
||||
|
||||
@@ -6,110 +6,75 @@ package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"atcr.io/pkg/auth"
|
||||
"github.com/distribution/distribution/v3"
|
||||
"github.com/distribution/reference"
|
||||
)
|
||||
|
||||
// RoutingRepository routes manifests to ATProto and blobs to external hold service
|
||||
// The registry (AppView) is stateless and NEVER stores blobs locally
|
||||
// RoutingRepository routes manifests to ATProto and blobs to external hold service.
|
||||
// The registry (AppView) is stateless and NEVER stores blobs locally.
|
||||
// A new instance is created per HTTP request - no caching or synchronization needed.
|
||||
type RoutingRepository struct {
|
||||
distribution.Repository
|
||||
Ctx *RegistryContext // All context and services (exported for token updates)
|
||||
mu sync.Mutex // Protects manifestStore and blobStore
|
||||
manifestStore *ManifestStore // Cached manifest store instance
|
||||
blobStore *ProxyBlobStore // Cached blob store instance
|
||||
userCtx *auth.UserContext
|
||||
sqlDB *sql.DB
|
||||
}
|
||||
|
||||
// NewRoutingRepository creates a new routing repository
|
||||
func NewRoutingRepository(baseRepo distribution.Repository, ctx *RegistryContext) *RoutingRepository {
|
||||
func NewRoutingRepository(baseRepo distribution.Repository, userCtx *auth.UserContext, sqlDB *sql.DB) *RoutingRepository {
|
||||
return &RoutingRepository{
|
||||
Repository: baseRepo,
|
||||
Ctx: ctx,
|
||||
userCtx: userCtx,
|
||||
sqlDB: sqlDB,
|
||||
}
|
||||
}
|
||||
|
||||
// Manifests returns the ATProto-backed manifest service
|
||||
func (r *RoutingRepository) Manifests(ctx context.Context, options ...distribution.ManifestServiceOption) (distribution.ManifestService, error) {
|
||||
r.mu.Lock()
|
||||
// Create or return cached manifest store
|
||||
if r.manifestStore == nil {
|
||||
// Ensure blob store is created first (needed for label extraction during push)
|
||||
// Release lock while calling Blobs to avoid deadlock
|
||||
r.mu.Unlock()
|
||||
blobStore := r.Blobs(ctx)
|
||||
r.mu.Lock()
|
||||
|
||||
// Double-check after reacquiring lock (another goroutine might have set it)
|
||||
if r.manifestStore == nil {
|
||||
r.manifestStore = NewManifestStore(r.Ctx, blobStore)
|
||||
}
|
||||
}
|
||||
manifestStore := r.manifestStore
|
||||
r.mu.Unlock()
|
||||
|
||||
return manifestStore, nil
|
||||
// blobStore used to fetch labels from th
|
||||
blobStore := r.Blobs(ctx)
|
||||
return NewManifestStore(r.userCtx, blobStore, r.sqlDB), nil
|
||||
}
|
||||
|
||||
// Blobs returns a proxy blob store that routes to external hold service
|
||||
// The registry (AppView) NEVER stores blobs locally - all blobs go through hold service
|
||||
func (r *RoutingRepository) Blobs(ctx context.Context) distribution.BlobStore {
|
||||
r.mu.Lock()
|
||||
// Return cached blob store if available
|
||||
if r.blobStore != nil {
|
||||
blobStore := r.blobStore
|
||||
r.mu.Unlock()
|
||||
slog.Debug("Returning cached blob store", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository)
|
||||
return blobStore
|
||||
}
|
||||
|
||||
// Determine if this is a pull (GET) or push (PUT/POST/HEAD/etc) operation
|
||||
// Pull operations use the historical hold DID from the database (blobs are where they were pushed)
|
||||
// Push operations use the discovery-based hold DID from user's profile/default
|
||||
// This allows users to change their default hold and have new pushes go there
|
||||
isPull := false
|
||||
if method, ok := ctx.Value("http.request.method").(string); ok {
|
||||
isPull = method == "GET"
|
||||
}
|
||||
|
||||
holdDID := r.Ctx.HoldDID // Default to discovery-based DID
|
||||
holdSource := "discovery"
|
||||
|
||||
// Only query database for pull operations
|
||||
if isPull && r.Ctx.Database != nil {
|
||||
// Query database for the latest manifest's hold DID
|
||||
if dbHoldDID, err := r.Ctx.Database.GetLatestHoldDIDForRepo(r.Ctx.DID, r.Ctx.Repository); err == nil && dbHoldDID != "" {
|
||||
// Use hold DID from database (pull case - use historical reference)
|
||||
holdDID = dbHoldDID
|
||||
holdSource = "database"
|
||||
slog.Debug("Using hold from database manifest (pull)", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", dbHoldDID)
|
||||
} else if err != nil {
|
||||
// Log error but don't fail - fall back to discovery-based DID
|
||||
slog.Warn("Failed to query database for hold DID", "component", "storage/blobs", "error", err)
|
||||
}
|
||||
// If dbHoldDID is empty (no manifests yet), fall through to use discovery-based DID
|
||||
// Resolve hold DID: pull uses DB lookup, push uses profile discovery
|
||||
holdDID, err := r.userCtx.ResolveHoldDID(ctx, r.sqlDB)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to resolve hold DID", "component", "storage/blobs", "error", err)
|
||||
holdDID = r.userCtx.TargetHoldDID
|
||||
}
|
||||
|
||||
if holdDID == "" {
|
||||
// This should never happen if middleware is configured correctly
|
||||
panic("hold DID not set in RegistryContext - ensure default_hold_did is configured in middleware")
|
||||
panic("hold DID not set - ensure default_hold_did is configured in middleware")
|
||||
}
|
||||
|
||||
slog.Debug("Using hold DID for blobs", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", holdDID, "source", holdSource)
|
||||
slog.Debug("Using hold DID for blobs", "component", "storage/blobs", "did", r.userCtx.TargetOwnerDID, "repo", r.userCtx.TargetRepo, "hold", holdDID, "action", r.userCtx.Action.String())
|
||||
|
||||
// Update context with the correct hold DID (may be from database or discovered)
|
||||
r.Ctx.HoldDID = holdDID
|
||||
|
||||
// Create and cache proxy blob store
|
||||
r.blobStore = NewProxyBlobStore(r.Ctx)
|
||||
blobStore := r.blobStore
|
||||
r.mu.Unlock()
|
||||
return blobStore
|
||||
return NewProxyBlobStore(r.userCtx)
|
||||
}
|
||||
|
||||
// Tags returns the tag service
|
||||
// Tags are stored in ATProto as io.atcr.tag records
|
||||
func (r *RoutingRepository) Tags(ctx context.Context) distribution.TagService {
|
||||
return NewTagStore(r.Ctx.ATProtoClient, r.Ctx.Repository)
|
||||
return NewTagStore(r.userCtx.GetATProtoClient(), r.userCtx.TargetRepo)
|
||||
}
|
||||
|
||||
// Named returns a reference to the repository name.
|
||||
// If the base repository is set, it delegates to the base.
|
||||
// Otherwise, it constructs a name from the user context.
|
||||
func (r *RoutingRepository) Named() reference.Named {
|
||||
if r.Repository != nil {
|
||||
return r.Repository.Named()
|
||||
}
|
||||
// Construct from user context
|
||||
name, err := reference.WithName(r.userCtx.TargetRepo)
|
||||
if err != nil {
|
||||
// Fallback: return a simple reference
|
||||
name, _ = reference.WithName("unknown")
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
@@ -2,273 +2,117 @@ package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/distribution/distribution/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
)
|
||||
|
||||
// mockDatabase is a simple mock for testing
|
||||
type mockDatabase struct {
|
||||
holdDID string
|
||||
err error
|
||||
// mockUserContext creates a mock auth.UserContext for testing.
|
||||
// It sets up both the user identity and target info, and configures
|
||||
// test helpers to bypass network calls.
|
||||
func mockUserContext(did, authMethod, httpMethod, targetOwnerDID, targetOwnerHandle, targetOwnerPDS, targetRepo, targetHoldDID string) *auth.UserContext {
|
||||
userCtx := auth.NewUserContext(did, authMethod, httpMethod, nil)
|
||||
userCtx.SetTarget(targetOwnerDID, targetOwnerHandle, targetOwnerPDS, targetRepo, targetHoldDID)
|
||||
|
||||
// Bypass PDS resolution (avoids network calls)
|
||||
userCtx.SetPDSForTest(targetOwnerHandle, targetOwnerPDS)
|
||||
|
||||
// Set up mock authorizer that allows access
|
||||
userCtx.SetAuthorizerForTest(auth.NewMockHoldAuthorizer())
|
||||
|
||||
// Set default hold DID for push resolution
|
||||
userCtx.SetDefaultHoldDIDForTest(targetHoldDID)
|
||||
|
||||
return userCtx
|
||||
}
|
||||
|
||||
func (m *mockDatabase) IncrementPullCount(did, repository string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDatabase) IncrementPushCount(did, repository string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDatabase) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
|
||||
if m.err != nil {
|
||||
return "", m.err
|
||||
}
|
||||
return m.holdDID, nil
|
||||
// mockUserContextWithToken creates a mock UserContext with a pre-populated service token.
|
||||
func mockUserContextWithToken(did, authMethod, httpMethod, targetOwnerDID, targetOwnerHandle, targetOwnerPDS, targetRepo, targetHoldDID, serviceToken string) *auth.UserContext {
|
||||
userCtx := mockUserContext(did, authMethod, httpMethod, targetOwnerDID, targetOwnerHandle, targetOwnerPDS, targetRepo, targetHoldDID)
|
||||
userCtx.SetServiceTokenForTest(targetHoldDID, serviceToken)
|
||||
return userCtx
|
||||
}
|
||||
|
||||
func TestNewRoutingRepository(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "debian",
|
||||
HoldDID: "did:web:hold01.atcr.io",
|
||||
ATProtoClient: &atproto.Client{},
|
||||
userCtx := mockUserContext(
|
||||
"did:plc:test123", // authenticated user
|
||||
"oauth", // auth method
|
||||
"GET", // HTTP method
|
||||
"did:plc:test123", // target owner
|
||||
"test.handle", // target owner handle
|
||||
"https://pds.example.com", // target owner PDS
|
||||
"debian", // repository
|
||||
"did:web:hold01.atcr.io", // hold DID
|
||||
)
|
||||
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
|
||||
if repo.userCtx.TargetOwnerDID != "did:plc:test123" {
|
||||
t.Errorf("Expected TargetOwnerDID %q, got %q", "did:plc:test123", repo.userCtx.TargetOwnerDID)
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
|
||||
if repo.Ctx.DID != "did:plc:test123" {
|
||||
t.Errorf("Expected DID %q, got %q", "did:plc:test123", repo.Ctx.DID)
|
||||
if repo.userCtx.TargetRepo != "debian" {
|
||||
t.Errorf("Expected TargetRepo %q, got %q", "debian", repo.userCtx.TargetRepo)
|
||||
}
|
||||
|
||||
if repo.Ctx.Repository != "debian" {
|
||||
t.Errorf("Expected repository %q, got %q", "debian", repo.Ctx.Repository)
|
||||
}
|
||||
|
||||
if repo.manifestStore != nil {
|
||||
t.Error("Expected manifestStore to be nil initially")
|
||||
}
|
||||
|
||||
if repo.blobStore != nil {
|
||||
t.Error("Expected blobStore to be nil initially")
|
||||
if repo.userCtx.TargetHoldDID != "did:web:hold01.atcr.io" {
|
||||
t.Errorf("Expected TargetHoldDID %q, got %q", "did:web:hold01.atcr.io", repo.userCtx.TargetHoldDID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Manifests tests the Manifests() method
|
||||
func TestRoutingRepository_Manifests(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp",
|
||||
HoldDID: "did:web:hold01.atcr.io",
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
}
|
||||
userCtx := mockUserContext(
|
||||
"did:plc:test123",
|
||||
"oauth",
|
||||
"GET",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"did:web:hold01.atcr.io",
|
||||
)
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
manifestService, err := repo.Manifests(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, manifestService)
|
||||
|
||||
// Verify the manifest store is cached
|
||||
assert.NotNil(t, repo.manifestStore, "manifest store should be cached")
|
||||
|
||||
// Call again and verify we get the same instance
|
||||
manifestService2, err := repo.Manifests(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Same(t, manifestService, manifestService2, "should return cached manifest store")
|
||||
}
|
||||
|
||||
// TestRoutingRepository_ManifestStoreCaching tests that manifest store is cached
|
||||
func TestRoutingRepository_ManifestStoreCaching(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp",
|
||||
HoldDID: "did:web:hold01.atcr.io",
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
}
|
||||
// TestRoutingRepository_Blobs tests the Blobs() method
|
||||
func TestRoutingRepository_Blobs(t *testing.T) {
|
||||
userCtx := mockUserContext(
|
||||
"did:plc:test123",
|
||||
"oauth",
|
||||
"GET",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"did:web:hold01.atcr.io",
|
||||
)
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
|
||||
// First call creates the store
|
||||
store1, err := repo.Manifests(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, store1)
|
||||
|
||||
// Second call returns cached store
|
||||
store2, err := repo.Manifests(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Same(t, store1, store2, "should return cached manifest store instance")
|
||||
|
||||
// Verify internal cache
|
||||
assert.NotNil(t, repo.manifestStore)
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Blobs_PullUsesDatabase tests that GET (pull) uses database hold DID
|
||||
func TestRoutingRepository_Blobs_PullUsesDatabase(t *testing.T) {
|
||||
dbHoldDID := "did:web:database.hold.io"
|
||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
||||
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp",
|
||||
HoldDID: discoveryHoldDID, // Discovery-based hold (should be overridden for pull)
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
Database: &mockDatabase{holdDID: dbHoldDID},
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
|
||||
// Create context with GET method (pull operation)
|
||||
pullCtx := context.WithValue(context.Background(), "http.request.method", "GET")
|
||||
blobStore := repo.Blobs(pullCtx)
|
||||
|
||||
assert.NotNil(t, blobStore)
|
||||
// Verify the hold DID was updated to use the database value for pull
|
||||
assert.Equal(t, dbHoldDID, repo.Ctx.HoldDID, "pull (GET) should use database hold DID")
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Blobs_PushUsesDiscovery tests that push operations use discovery hold DID
|
||||
func TestRoutingRepository_Blobs_PushUsesDiscovery(t *testing.T) {
|
||||
dbHoldDID := "did:web:database.hold.io"
|
||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
method string
|
||||
}{
|
||||
{"PUT", "PUT"},
|
||||
{"POST", "POST"},
|
||||
{"HEAD", "HEAD"},
|
||||
{"PATCH", "PATCH"},
|
||||
{"DELETE", "DELETE"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp-" + tc.method, // Unique repo to avoid caching
|
||||
HoldDID: discoveryHoldDID,
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
Database: &mockDatabase{holdDID: dbHoldDID},
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
|
||||
// Create context with push method
|
||||
pushCtx := context.WithValue(context.Background(), "http.request.method", tc.method)
|
||||
blobStore := repo.Blobs(pushCtx)
|
||||
|
||||
assert.NotNil(t, blobStore)
|
||||
// Verify the hold DID remains the discovery-based one for push operations
|
||||
assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "%s should use discovery hold DID, not database", tc.method)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Blobs_NoMethodUsesDiscovery tests that missing method defaults to discovery
|
||||
func TestRoutingRepository_Blobs_NoMethodUsesDiscovery(t *testing.T) {
|
||||
dbHoldDID := "did:web:database.hold.io"
|
||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
||||
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp-nomethod",
|
||||
HoldDID: discoveryHoldDID,
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
Database: &mockDatabase{holdDID: dbHoldDID},
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
|
||||
// Context without HTTP method (shouldn't happen in practice, but test defensive behavior)
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
blobStore := repo.Blobs(context.Background())
|
||||
|
||||
assert.NotNil(t, blobStore)
|
||||
// Without method, should default to discovery (safer for push scenarios)
|
||||
assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "missing method should use discovery hold DID")
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Blobs_WithoutDatabase tests blob store with discovery-based hold
|
||||
func TestRoutingRepository_Blobs_WithoutDatabase(t *testing.T) {
|
||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
||||
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:nocache456",
|
||||
Repository: "uncached-app",
|
||||
HoldDID: discoveryHoldDID,
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:nocache456", ""),
|
||||
Database: nil, // No database
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
blobStore := repo.Blobs(context.Background())
|
||||
|
||||
assert.NotNil(t, blobStore)
|
||||
// Verify the hold DID remains the discovery-based one
|
||||
assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "should use discovery-based hold DID")
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Blobs_DatabaseEmptyFallback tests fallback when database returns empty hold DID
|
||||
func TestRoutingRepository_Blobs_DatabaseEmptyFallback(t *testing.T) {
|
||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
||||
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "newapp",
|
||||
HoldDID: discoveryHoldDID,
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
Database: &mockDatabase{holdDID: ""}, // Empty string (no manifests yet)
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
blobStore := repo.Blobs(context.Background())
|
||||
|
||||
assert.NotNil(t, blobStore)
|
||||
// Verify the hold DID falls back to discovery-based
|
||||
assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "should fall back to discovery-based hold DID when database returns empty")
|
||||
}
|
||||
|
||||
// TestRoutingRepository_BlobStoreCaching tests that blob store is cached
|
||||
func TestRoutingRepository_BlobStoreCaching(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp",
|
||||
HoldDID: "did:web:hold01.atcr.io",
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
|
||||
// First call creates the store
|
||||
store1 := repo.Blobs(context.Background())
|
||||
assert.NotNil(t, store1)
|
||||
|
||||
// Second call returns cached store
|
||||
store2 := repo.Blobs(context.Background())
|
||||
assert.Same(t, store1, store2, "should return cached blob store instance")
|
||||
|
||||
// Verify internal cache
|
||||
assert.NotNil(t, repo.blobStore)
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Blobs_PanicOnEmptyHoldDID tests panic when hold DID is empty
|
||||
func TestRoutingRepository_Blobs_PanicOnEmptyHoldDID(t *testing.T) {
|
||||
// Use a unique DID/repo to ensure no cache entry exists
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:emptyholdtest999",
|
||||
Repository: "empty-hold-app",
|
||||
HoldDID: "", // Empty hold DID should panic
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:emptyholdtest999", ""),
|
||||
}
|
||||
// Create context without default hold and empty target hold
|
||||
userCtx := auth.NewUserContext("did:plc:emptyholdtest999", "oauth", "GET", nil)
|
||||
userCtx.SetTarget("did:plc:emptyholdtest999", "test.handle", "https://pds.example.com", "empty-hold-app", "")
|
||||
userCtx.SetPDSForTest("test.handle", "https://pds.example.com")
|
||||
userCtx.SetAuthorizerForTest(auth.NewMockHoldAuthorizer())
|
||||
// Intentionally NOT setting default hold DID
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
|
||||
// Should panic with empty hold DID
|
||||
assert.Panics(t, func() {
|
||||
@@ -278,106 +122,140 @@ func TestRoutingRepository_Blobs_PanicOnEmptyHoldDID(t *testing.T) {
|
||||
|
||||
// TestRoutingRepository_Tags tests the Tags() method
|
||||
func TestRoutingRepository_Tags(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp",
|
||||
HoldDID: "did:web:hold01.atcr.io",
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
}
|
||||
userCtx := mockUserContext(
|
||||
"did:plc:test123",
|
||||
"oauth",
|
||||
"GET",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"did:web:hold01.atcr.io",
|
||||
)
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
tagService := repo.Tags(context.Background())
|
||||
|
||||
assert.NotNil(t, tagService)
|
||||
|
||||
// Call again and verify we get a new instance (Tags() doesn't cache)
|
||||
// Call again and verify we get a fresh instance (no caching)
|
||||
tagService2 := repo.Tags(context.Background())
|
||||
assert.NotNil(t, tagService2)
|
||||
// Tags service is not cached, so each call creates a new instance
|
||||
}
|
||||
|
||||
// TestRoutingRepository_ConcurrentAccess tests concurrent access to cached stores
|
||||
func TestRoutingRepository_ConcurrentAccess(t *testing.T) {
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp",
|
||||
HoldDID: "did:web:hold01.atcr.io",
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
// TestRoutingRepository_UserContext tests that UserContext fields are properly set
|
||||
func TestRoutingRepository_UserContext(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
httpMethod string
|
||||
expectedAction auth.RequestAction
|
||||
}{
|
||||
{"GET request is pull", "GET", auth.ActionPull},
|
||||
{"HEAD request is pull", "HEAD", auth.ActionPull},
|
||||
{"PUT request is push", "PUT", auth.ActionPush},
|
||||
{"POST request is push", "POST", auth.ActionPush},
|
||||
{"DELETE request is push", "DELETE", auth.ActionPush},
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
userCtx := mockUserContext(
|
||||
"did:plc:test123",
|
||||
"oauth",
|
||||
tc.httpMethod,
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"did:web:hold01.atcr.io",
|
||||
)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
|
||||
// Track all manifest stores returned
|
||||
manifestStores := make([]distribution.ManifestService, numGoroutines)
|
||||
blobStores := make([]distribution.BlobStore, numGoroutines)
|
||||
|
||||
// Concurrent access to Manifests()
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
store, err := repo.Manifests(context.Background())
|
||||
require.NoError(t, err)
|
||||
manifestStores[index] = store
|
||||
}(i)
|
||||
assert.Equal(t, tc.expectedAction, repo.userCtx.Action, "action should match HTTP method")
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all stores are non-nil (due to race conditions, they may not all be the same instance)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
assert.NotNil(t, manifestStores[i], "manifest store should not be nil")
|
||||
}
|
||||
|
||||
// After concurrent creation, subsequent calls should return the cached instance
|
||||
cachedStore, err := repo.Manifests(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, cachedStore)
|
||||
|
||||
// Concurrent access to Blobs()
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
blobStores[index] = repo.Blobs(context.Background())
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all stores are non-nil (due to race conditions, they may not all be the same instance)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
assert.NotNil(t, blobStores[i], "blob store should not be nil")
|
||||
}
|
||||
|
||||
// After concurrent creation, subsequent calls should return the cached instance
|
||||
cachedBlobStore := repo.Blobs(context.Background())
|
||||
assert.NotNil(t, cachedBlobStore)
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Blobs_PullPriority tests that database hold DID takes priority for pull (GET)
|
||||
func TestRoutingRepository_Blobs_PullPriority(t *testing.T) {
|
||||
dbHoldDID := "did:web:database.hold.io"
|
||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
||||
|
||||
ctx := &RegistryContext{
|
||||
DID: "did:plc:test123",
|
||||
Repository: "myapp-priority",
|
||||
HoldDID: discoveryHoldDID, // Discovery-based hold
|
||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||
Database: &mockDatabase{holdDID: dbHoldDID}, // Database has a different hold DID
|
||||
// TestRoutingRepository_DifferentHoldDIDs tests routing with different hold DIDs
|
||||
func TestRoutingRepository_DifferentHoldDIDs(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
holdDID string
|
||||
}{
|
||||
{"did:web hold", "did:web:hold01.atcr.io"},
|
||||
{"did:web with port", "did:web:localhost:8080"},
|
||||
{"did:plc hold", "did:plc:xyz123"},
|
||||
}
|
||||
|
||||
repo := NewRoutingRepository(nil, ctx)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
userCtx := mockUserContext(
|
||||
"did:plc:test123",
|
||||
"oauth",
|
||||
"PUT",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
tc.holdDID,
|
||||
)
|
||||
|
||||
// For pull (GET), database should take priority
|
||||
pullCtx := context.WithValue(context.Background(), "http.request.method", "GET")
|
||||
blobStore := repo.Blobs(pullCtx)
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
blobStore := repo.Blobs(context.Background())
|
||||
|
||||
assert.NotNil(t, blobStore)
|
||||
// Database hold DID should take priority over discovery for pull operations
|
||||
assert.Equal(t, dbHoldDID, repo.Ctx.HoldDID, "database hold DID should take priority over discovery for pull (GET)")
|
||||
assert.NotNil(t, blobStore, "should create blob store for %s", tc.holdDID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoutingRepository_Named tests the Named() method
|
||||
func TestRoutingRepository_Named(t *testing.T) {
|
||||
userCtx := mockUserContext(
|
||||
"did:plc:test123",
|
||||
"oauth",
|
||||
"GET",
|
||||
"did:plc:test123",
|
||||
"test.handle",
|
||||
"https://pds.example.com",
|
||||
"myapp",
|
||||
"did:web:hold01.atcr.io",
|
||||
)
|
||||
|
||||
repo := NewRoutingRepository(nil, userCtx, nil)
|
||||
|
||||
// Named() returns a reference.Named from the base repository
|
||||
// Since baseRepo is nil, this tests our implementation handles that case
|
||||
named := repo.Named()
|
||||
|
||||
// With nil base, Named() should return a name constructed from context
|
||||
assert.NotNil(t, named)
|
||||
assert.Contains(t, named.Name(), "myapp")
|
||||
}
|
||||
|
||||
// TestATProtoResolveHoldURL tests DID to URL resolution
|
||||
func TestATProtoResolveHoldURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
holdDID string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "did:web simple domain",
|
||||
holdDID: "did:web:hold01.atcr.io",
|
||||
expected: "https://hold01.atcr.io",
|
||||
},
|
||||
{
|
||||
name: "did:web with port (localhost)",
|
||||
holdDID: "did:web:localhost:8080",
|
||||
expected: "http://localhost:8080",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := atproto.ResolveHoldURL(tt.holdDID)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func (s *TagStore) Get(ctx context.Context, tag string) (distribution.Descriptor
|
||||
return distribution.Descriptor{}, distribution.ErrTagUnknown{Tag: tag}
|
||||
}
|
||||
|
||||
var tagRecord atproto.Tag
|
||||
var tagRecord atproto.TagRecord
|
||||
if err := json.Unmarshal(record.Value, &tagRecord); err != nil {
|
||||
return distribution.Descriptor{}, fmt.Errorf("failed to unmarshal tag record: %w", err)
|
||||
}
|
||||
@@ -91,7 +91,7 @@ func (s *TagStore) All(ctx context.Context) ([]string, error) {
|
||||
|
||||
var tags []string
|
||||
for _, record := range records {
|
||||
var tagRecord atproto.Tag
|
||||
var tagRecord atproto.TagRecord
|
||||
if err := json.Unmarshal(record.Value, &tagRecord); err != nil {
|
||||
// Skip invalid records
|
||||
continue
|
||||
@@ -116,7 +116,7 @@ func (s *TagStore) Lookup(ctx context.Context, desc distribution.Descriptor) ([]
|
||||
|
||||
var tags []string
|
||||
for _, record := range records {
|
||||
var tagRecord atproto.Tag
|
||||
var tagRecord atproto.TagRecord
|
||||
if err := json.Unmarshal(record.Value, &tagRecord); err != nil {
|
||||
// Skip invalid records
|
||||
continue
|
||||
|
||||
@@ -229,7 +229,7 @@ func TestTagStore_Tag(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var sentTagRecord *atproto.Tag
|
||||
var sentTagRecord *atproto.TagRecord
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
@@ -254,7 +254,7 @@ func TestTagStore_Tag(t *testing.T) {
|
||||
// Parse and verify tag record
|
||||
recordData := body["record"].(map[string]any)
|
||||
recordBytes, _ := json.Marshal(recordData)
|
||||
var tagRecord atproto.Tag
|
||||
var tagRecord atproto.TagRecord
|
||||
json.Unmarshal(recordBytes, &tagRecord)
|
||||
sentTagRecord = &tagRecord
|
||||
|
||||
@@ -284,8 +284,8 @@ func TestTagStore_Tag(t *testing.T) {
|
||||
|
||||
if !tt.wantErr && sentTagRecord != nil {
|
||||
// Verify the tag record
|
||||
if sentTagRecord.LexiconTypeID != atproto.TagCollection {
|
||||
t.Errorf("LexiconTypeID = %v, want %v", sentTagRecord.LexiconTypeID, atproto.TagCollection)
|
||||
if sentTagRecord.Type != atproto.TagCollection {
|
||||
t.Errorf("Type = %v, want %v", sentTagRecord.Type, atproto.TagCollection)
|
||||
}
|
||||
if sentTagRecord.Repository != "myapp" {
|
||||
t.Errorf("Repository = %v, want myapp", sentTagRecord.Repository)
|
||||
@@ -295,11 +295,11 @@ func TestTagStore_Tag(t *testing.T) {
|
||||
}
|
||||
// New records should have manifest field
|
||||
expectedURI := atproto.BuildManifestURI("did:plc:test123", tt.digest.String())
|
||||
if sentTagRecord.Manifest == nil || *sentTagRecord.Manifest != expectedURI {
|
||||
if sentTagRecord.Manifest != expectedURI {
|
||||
t.Errorf("Manifest = %v, want %v", sentTagRecord.Manifest, expectedURI)
|
||||
}
|
||||
// New records should NOT have manifestDigest field
|
||||
if sentTagRecord.ManifestDigest != nil && *sentTagRecord.ManifestDigest != "" {
|
||||
if sentTagRecord.ManifestDigest != "" {
|
||||
t.Errorf("ManifestDigest should be empty for new records, got %v", sentTagRecord.ManifestDigest)
|
||||
}
|
||||
}
|
||||
|
||||
22
pkg/appview/templates/pages/404.html
Normal file
22
pkg/appview/templates/pages/404.html
Normal file
@@ -0,0 +1,22 @@
|
||||
{{ define "404" }}
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>404 - Lost at Sea | ATCR</title>
|
||||
{{ template "head" . }}
|
||||
</head>
|
||||
<body>
|
||||
{{ template "nav-simple" . }}
|
||||
<main class="error-page">
|
||||
<div class="error-content">
|
||||
<i data-lucide="anchor" class="error-icon"></i>
|
||||
<div class="error-code">404</div>
|
||||
<h1>Lost at Sea</h1>
|
||||
<p>The page you're looking for has drifted into uncharted waters.</p>
|
||||
<a href="/" class="btn btn-primary">Return to Port</a>
|
||||
</div>
|
||||
</main>
|
||||
<script>lucide.createIcons();</script>
|
||||
</body>
|
||||
</html>
|
||||
{{ end }}
|
||||
@@ -27,11 +27,20 @@
|
||||
<!-- Repository Header -->
|
||||
<div class="repository-header">
|
||||
<div class="repo-hero">
|
||||
{{ if .Repository.IconURL }}
|
||||
<img src="{{ .Repository.IconURL }}" alt="{{ .Repository.Name }}" class="repo-hero-icon">
|
||||
{{ else }}
|
||||
<div class="repo-hero-icon-placeholder">{{ firstChar .Repository.Name }}</div>
|
||||
{{ end }}
|
||||
<div class="repo-hero-icon-wrapper">
|
||||
{{ if .Repository.IconURL }}
|
||||
<img src="{{ .Repository.IconURL }}" alt="{{ .Repository.Name }}" class="repo-hero-icon">
|
||||
{{ else }}
|
||||
<div class="repo-hero-icon-placeholder">{{ firstChar .Repository.Name }}</div>
|
||||
{{ end }}
|
||||
{{ if $.IsOwner }}
|
||||
<label class="avatar-upload-overlay" for="avatar-upload">
|
||||
<i data-lucide="plus"></i>
|
||||
</label>
|
||||
<input type="file" id="avatar-upload" accept="image/png,image/jpeg,image/webp"
|
||||
onchange="uploadAvatar(this, '{{ .Repository.Name }}')" hidden>
|
||||
{{ end }}
|
||||
</div>
|
||||
<div class="repo-hero-info">
|
||||
<h1>
|
||||
<a href="/u/{{ .Owner.Handle }}" class="owner-link">{{ .Owner.Handle }}</a>
|
||||
@@ -130,6 +139,9 @@
|
||||
{{ if .IsMultiArch }}
|
||||
<span class="badge-multi">Multi-arch</span>
|
||||
{{ end }}
|
||||
{{ if .HasAttestations }}
|
||||
<span class="badge-attestation"><i data-lucide="shield-check"></i> Attestations</span>
|
||||
{{ end }}
|
||||
</div>
|
||||
<div style="display: flex; gap: 1rem; align-items: center;">
|
||||
<time class="tag-timestamp" datetime="{{ .Tag.CreatedAt.Format "2006-01-02T15:04:05Z07:00" }}">
|
||||
|
||||
@@ -44,15 +44,6 @@
|
||||
</div>
|
||||
{{ end }}
|
||||
|
||||
{{ if .HasMore }}
|
||||
<button class="load-more"
|
||||
hx-get="/api/recent-pushes?offset={{ .NextOffset }}"
|
||||
hx-target="#push-list"
|
||||
hx-swap="beforeend">
|
||||
Load More
|
||||
</button>
|
||||
{{ end }}
|
||||
|
||||
{{ if eq (len .Pushes) 0 }}
|
||||
<div class="empty-state">
|
||||
<p>No pushes yet. Start using ATCR by pushing your first image!</p>
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
package appview
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
)
|
||||
|
||||
func TestResolveHoldURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "DID with HTTPS domain",
|
||||
input: "did:web:hold.example.com",
|
||||
expected: "https://hold.example.com",
|
||||
},
|
||||
{
|
||||
name: "DID with HTTP and port (IP)",
|
||||
input: "did:web:172.28.0.3:8080",
|
||||
expected: "http://172.28.0.3:8080",
|
||||
},
|
||||
{
|
||||
name: "DID with HTTP and port (localhost)",
|
||||
input: "did:web:127.0.0.1:8080",
|
||||
expected: "http://127.0.0.1:8080",
|
||||
},
|
||||
{
|
||||
name: "DID with localhost",
|
||||
input: "did:web:localhost:8080",
|
||||
expected: "http://localhost:8080",
|
||||
},
|
||||
{
|
||||
name: "Already HTTPS URL (passthrough)",
|
||||
input: "https://hold.example.com",
|
||||
expected: "https://hold.example.com",
|
||||
},
|
||||
{
|
||||
name: "Already HTTP URL (passthrough)",
|
||||
input: "http://172.28.0.3:8080",
|
||||
expected: "http://172.28.0.3:8080",
|
||||
},
|
||||
{
|
||||
name: "Plain hostname (fallback to HTTPS)",
|
||||
input: "hold.example.com",
|
||||
expected: "https://hold.example.com",
|
||||
},
|
||||
{
|
||||
name: "DID with subdomain",
|
||||
input: "did:web:hold01.atcr.io",
|
||||
expected: "https://hold01.atcr.io",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := atproto.ResolveHoldURL(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ResolveHoldURL(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,8 +13,6 @@ import (
|
||||
|
||||
"github.com/bluesky-social/indigo/atproto/atclient"
|
||||
indigo_oauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
|
||||
lexutil "github.com/bluesky-social/indigo/lex/util"
|
||||
"github.com/ipfs/go-cid"
|
||||
)
|
||||
|
||||
// Sentinel errors
|
||||
@@ -303,7 +301,7 @@ type Link struct {
|
||||
}
|
||||
|
||||
// UploadBlob uploads binary data to the PDS and returns a blob reference
|
||||
func (c *Client) UploadBlob(ctx context.Context, data []byte, mimeType string) (*lexutil.LexBlob, error) {
|
||||
func (c *Client) UploadBlob(ctx context.Context, data []byte, mimeType string) (*ATProtoBlobRef, error) {
|
||||
// Use session provider (locked OAuth with DPoP) - prevents nonce races
|
||||
if c.sessionProvider != nil {
|
||||
var result struct {
|
||||
@@ -312,12 +310,15 @@ func (c *Client) UploadBlob(ctx context.Context, data []byte, mimeType string) (
|
||||
|
||||
err := c.sessionProvider.DoWithSession(ctx, c.did, func(session *indigo_oauth.ClientSession) error {
|
||||
apiClient := session.APIClient()
|
||||
// IMPORTANT: Use io.Reader for blob uploads
|
||||
// LexDo JSON-encodes []byte (base64), but streams io.Reader as raw bytes
|
||||
// Use the actual MIME type so PDS can validate against blob:image/* scope
|
||||
return apiClient.LexDo(ctx,
|
||||
"POST",
|
||||
mimeType,
|
||||
"com.atproto.repo.uploadBlob",
|
||||
nil,
|
||||
data,
|
||||
bytes.NewReader(data),
|
||||
&result,
|
||||
)
|
||||
})
|
||||
@@ -325,7 +326,7 @@ func (c *Client) UploadBlob(ctx context.Context, data []byte, mimeType string) (
|
||||
return nil, fmt.Errorf("uploadBlob failed: %w", err)
|
||||
}
|
||||
|
||||
return atProtoBlobRefToLexBlob(&result.Blob)
|
||||
return &result.Blob, nil
|
||||
}
|
||||
|
||||
// Basic Auth (app passwords)
|
||||
@@ -356,22 +357,7 @@ func (c *Client) UploadBlob(ctx context.Context, data []byte, mimeType string) (
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return atProtoBlobRefToLexBlob(&result.Blob)
|
||||
}
|
||||
|
||||
// atProtoBlobRefToLexBlob converts an ATProtoBlobRef to a lexutil.LexBlob
|
||||
func atProtoBlobRefToLexBlob(ref *ATProtoBlobRef) (*lexutil.LexBlob, error) {
|
||||
// Parse the CID string from the $link field
|
||||
c, err := cid.Decode(ref.Ref.Link)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse blob CID %q: %w", ref.Ref.Link, err)
|
||||
}
|
||||
|
||||
return &lexutil.LexBlob{
|
||||
Ref: lexutil.LexLink(c),
|
||||
MimeType: ref.MimeType,
|
||||
Size: ref.Size,
|
||||
}, nil
|
||||
return &result.Blob, nil
|
||||
}
|
||||
|
||||
// GetBlob downloads a blob by its CID from the PDS
|
||||
|
||||
@@ -386,11 +386,11 @@ func TestUploadBlob(t *testing.T) {
|
||||
t.Errorf("Content-Type = %v, want %v", r.Header.Get("Content-Type"), mimeType)
|
||||
}
|
||||
|
||||
// Send response - use a valid CIDv1 in base32 format
|
||||
// Send response
|
||||
response := `{
|
||||
"blob": {
|
||||
"$type": "blob",
|
||||
"ref": {"$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},
|
||||
"ref": {"$link": "bafytest123"},
|
||||
"mimeType": "application/octet-stream",
|
||||
"size": 17
|
||||
}
|
||||
@@ -406,14 +406,12 @@ func TestUploadBlob(t *testing.T) {
|
||||
t.Fatalf("UploadBlob() error = %v", err)
|
||||
}
|
||||
|
||||
if blobRef.MimeType != mimeType {
|
||||
t.Errorf("MimeType = %v, want %v", blobRef.MimeType, mimeType)
|
||||
if blobRef.Type != "blob" {
|
||||
t.Errorf("Type = %v, want blob", blobRef.Type)
|
||||
}
|
||||
|
||||
// LexBlob.Ref is a LexLink (cid.Cid alias), use .String() to get the CID string
|
||||
expectedCID := "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
|
||||
if blobRef.Ref.String() != expectedCID {
|
||||
t.Errorf("Ref.String() = %v, want %v", blobRef.Ref.String(), expectedCID)
|
||||
if blobRef.Ref.Link != "bafytest123" {
|
||||
t.Errorf("Ref.Link = %v, want bafytest123", blobRef.Ref.Link)
|
||||
}
|
||||
|
||||
if blobRef.Size != 17 {
|
||||
|
||||
@@ -3,194 +3,17 @@
|
||||
|
||||
package main
|
||||
|
||||
// Lexicon and CBOR Code Generator
|
||||
// CBOR Code Generator
|
||||
//
|
||||
// This generates:
|
||||
// 1. Go types from lexicon JSON files (via lex/lexgen library)
|
||||
// 2. CBOR marshaling code for ATProto records (via cbor-gen)
|
||||
// 3. Type registration for lexutil (register.go)
|
||||
// This generates optimized CBOR marshaling code for ATProto records.
|
||||
//
|
||||
// Usage:
|
||||
// go generate ./pkg/atproto/...
|
||||
//
|
||||
// Key insight: We use RegisterLexiconTypeID: false to avoid generating init()
|
||||
// blocks that require CBORMarshaler. This breaks the circular dependency between
|
||||
// lexgen and cbor-gen. See: https://github.com/bluesky-social/indigo/issues/931
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/bluesky-social/indigo/atproto/lexicon"
|
||||
"github.com/bluesky-social/indigo/lex/lexgen"
|
||||
"golang.org/x/tools/imports"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Find repo root
|
||||
repoRoot, err := findRepoRoot()
|
||||
if err != nil {
|
||||
fmt.Printf("failed to find repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
pkgDir := filepath.Join(repoRoot, "pkg/atproto")
|
||||
lexDir := filepath.Join(repoRoot, "lexicons")
|
||||
|
||||
// Step 0: Clean up old register.go to avoid conflicts
|
||||
// (It will be regenerated at the end)
|
||||
os.Remove(filepath.Join(pkgDir, "register.go"))
|
||||
|
||||
// Step 1: Load all lexicon schemas into catalog (for cross-references)
|
||||
fmt.Println("Loading lexicons...")
|
||||
cat := lexicon.NewBaseCatalog()
|
||||
if err := cat.LoadDirectory(lexDir); err != nil {
|
||||
fmt.Printf("failed to load lexicons: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Step 2: Generate Go code for each lexicon file
|
||||
fmt.Println("Running lexgen...")
|
||||
config := &lexgen.GenConfig{
|
||||
RegisterLexiconTypeID: false, // KEY: no init() blocks generated
|
||||
UnknownType: "map-string-any",
|
||||
WarningText: "Code generated by generate.go; DO NOT EDIT.",
|
||||
}
|
||||
|
||||
// Track generated types for register.go
|
||||
var registeredTypes []typeInfo
|
||||
|
||||
// Walk lexicon directory and generate code for each file
|
||||
err = filepath.Walk(lexDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.IsDir() || !strings.HasSuffix(path, ".json") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load and parse the schema file
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read %s: %w", path, err)
|
||||
}
|
||||
|
||||
var sf lexicon.SchemaFile
|
||||
if err := json.Unmarshal(data, &sf); err != nil {
|
||||
return fmt.Errorf("failed to parse %s: %w", path, err)
|
||||
}
|
||||
|
||||
if err := sf.FinishParse(); err != nil {
|
||||
return fmt.Errorf("failed to finish parse %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Flatten the schema
|
||||
flat, err := lexgen.FlattenSchemaFile(&sf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flatten schema %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Generate code
|
||||
var buf bytes.Buffer
|
||||
gen := &lexgen.CodeGenerator{
|
||||
Config: config,
|
||||
Lex: flat,
|
||||
Cat: &cat,
|
||||
Out: &buf,
|
||||
}
|
||||
|
||||
if err := gen.WriteLexicon(); err != nil {
|
||||
return fmt.Errorf("failed to generate code for %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Fix package name: lexgen generates "ioatcr" but we want "atproto"
|
||||
code := bytes.Replace(buf.Bytes(), []byte("package ioatcr"), []byte("package atproto"), 1)
|
||||
|
||||
// Format with goimports
|
||||
fileName := gen.FileName()
|
||||
formatted, err := imports.Process(fileName, code, nil)
|
||||
if err != nil {
|
||||
// Write unformatted for debugging
|
||||
outPath := filepath.Join(pkgDir, fileName)
|
||||
os.WriteFile(outPath+".broken", code, 0644)
|
||||
return fmt.Errorf("failed to format %s: %w (wrote to %s.broken)", fileName, err, outPath)
|
||||
}
|
||||
|
||||
// Write output file
|
||||
outPath := filepath.Join(pkgDir, fileName)
|
||||
if err := os.WriteFile(outPath, formatted, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write %s: %w", outPath, err)
|
||||
}
|
||||
|
||||
fmt.Printf(" Generated %s\n", fileName)
|
||||
|
||||
// Track type for registration - compute type name from NSID
|
||||
typeName := nsidToTypeName(sf.ID)
|
||||
registeredTypes = append(registeredTypes, typeInfo{
|
||||
NSID: sf.ID,
|
||||
TypeName: typeName,
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Printf("lexgen failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Step 3: Run cbor-gen via exec.Command
|
||||
// This must be a separate process so it can compile the freshly generated types
|
||||
fmt.Println("Running cbor-gen...")
|
||||
if err := runCborGen(repoRoot, pkgDir); err != nil {
|
||||
fmt.Printf("cbor-gen failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Step 4: Generate register.go
|
||||
fmt.Println("Generating register.go...")
|
||||
if err := generateRegisterFile(pkgDir, registeredTypes); err != nil {
|
||||
fmt.Printf("failed to generate register.go: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Println("Code generation complete!")
|
||||
}
|
||||
|
||||
type typeInfo struct {
|
||||
NSID string
|
||||
TypeName string
|
||||
}
|
||||
|
||||
// nsidToTypeName converts an NSID to a Go type name
|
||||
// io.atcr.manifest → Manifest
|
||||
// io.atcr.hold.captain → HoldCaptain
|
||||
// io.atcr.sailor.profile → SailorProfile
|
||||
func nsidToTypeName(nsid string) string {
|
||||
parts := strings.Split(nsid, ".")
|
||||
if len(parts) < 3 {
|
||||
return ""
|
||||
}
|
||||
// Skip the first two parts (authority, e.g., "io.atcr")
|
||||
// and capitalize each remaining part
|
||||
var result string
|
||||
for _, part := range parts[2:] {
|
||||
if len(part) > 0 {
|
||||
result += strings.ToUpper(part[:1]) + part[1:]
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func runCborGen(repoRoot, pkgDir string) error {
|
||||
// Create a temporary Go file that runs cbor-gen
|
||||
cborGenCode := `//go:build ignore
|
||||
|
||||
package main
|
||||
// This creates pkg/atproto/cbor_gen.go which should be committed to git.
|
||||
// Only re-run when you modify types in pkg/atproto/types.go
|
||||
//
|
||||
// The //go:generate directive is in lexicon.go
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -202,81 +25,14 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Generate map-style encoders for CrewRecord, CaptainRecord, LayerRecord, and TangledProfileRecord
|
||||
if err := cbg.WriteMapEncodersToFile("cbor_gen.go", "atproto",
|
||||
// Manifest types
|
||||
atproto.Manifest{},
|
||||
atproto.Manifest_BlobReference{},
|
||||
atproto.Manifest_ManifestReference{},
|
||||
atproto.Manifest_Platform{},
|
||||
atproto.Manifest_Annotations{},
|
||||
atproto.Manifest_BlobReference_Annotations{},
|
||||
atproto.Manifest_ManifestReference_Annotations{},
|
||||
// Tag
|
||||
atproto.Tag{},
|
||||
// Sailor types
|
||||
atproto.SailorProfile{},
|
||||
atproto.SailorStar{},
|
||||
atproto.SailorStar_Subject{},
|
||||
// Hold types
|
||||
atproto.HoldCaptain{},
|
||||
atproto.HoldCrew{},
|
||||
atproto.HoldLayer{},
|
||||
// External types
|
||||
atproto.CrewRecord{},
|
||||
atproto.CaptainRecord{},
|
||||
atproto.LayerRecord{},
|
||||
atproto.TangledProfileRecord{},
|
||||
); err != nil {
|
||||
fmt.Printf("cbor-gen failed: %v\n", err)
|
||||
fmt.Printf("Failed to generate CBOR encoders: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
// Write temp file
|
||||
tmpFile := filepath.Join(pkgDir, "cborgen_tmp.go")
|
||||
if err := os.WriteFile(tmpFile, []byte(cborGenCode), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write temp cbor-gen file: %w", err)
|
||||
}
|
||||
defer os.Remove(tmpFile)
|
||||
|
||||
// Run it
|
||||
cmd := exec.Command("go", "run", tmpFile)
|
||||
cmd.Dir = pkgDir
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func generateRegisterFile(pkgDir string, types []typeInfo) error {
|
||||
var buf bytes.Buffer
|
||||
|
||||
buf.WriteString("// Code generated by generate.go; DO NOT EDIT.\n\n")
|
||||
buf.WriteString("package atproto\n\n")
|
||||
buf.WriteString("import lexutil \"github.com/bluesky-social/indigo/lex/util\"\n\n")
|
||||
buf.WriteString("func init() {\n")
|
||||
|
||||
for _, t := range types {
|
||||
fmt.Fprintf(&buf, "\tlexutil.RegisterType(%q, &%s{})\n", t.NSID, t.TypeName)
|
||||
}
|
||||
|
||||
buf.WriteString("}\n")
|
||||
|
||||
outPath := filepath.Join(pkgDir, "register.go")
|
||||
return os.WriteFile(outPath, buf.Bytes(), 0644)
|
||||
}
|
||||
|
||||
func findRepoRoot() (string, error) {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
|
||||
return dir, nil
|
||||
}
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir {
|
||||
return "", fmt.Errorf("go.mod not found")
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
// Code generated by generate.go; DO NOT EDIT.
|
||||
|
||||
// Lexicon schema: io.atcr.hold.captain
|
||||
|
||||
package atproto
|
||||
|
||||
// Represents the hold's ownership and metadata. Stored as a singleton record at rkey 'self' in the hold's embedded PDS.
|
||||
type HoldCaptain struct {
|
||||
LexiconTypeID string `json:"$type" cborgen:"$type,const=io.atcr.hold.captain"`
|
||||
// allowAllCrew: Allow any authenticated user to register as crew
|
||||
AllowAllCrew bool `json:"allowAllCrew" cborgen:"allowAllCrew"`
|
||||
// deployedAt: RFC3339 timestamp of when the hold was deployed
|
||||
DeployedAt string `json:"deployedAt" cborgen:"deployedAt"`
|
||||
// enableBlueskyPosts: Enable Bluesky posts when manifests are pushed
|
||||
EnableBlueskyPosts bool `json:"enableBlueskyPosts" cborgen:"enableBlueskyPosts"`
|
||||
// owner: DID of the hold owner
|
||||
Owner string `json:"owner" cborgen:"owner"`
|
||||
// provider: Deployment provider (e.g., fly.io, aws, etc.)
|
||||
Provider *string `json:"provider,omitempty" cborgen:"provider,omitempty"`
|
||||
// public: Whether this hold allows public blob reads (pulls) without authentication
|
||||
Public bool `json:"public" cborgen:"public"`
|
||||
// region: S3 region where blobs are stored
|
||||
Region *string `json:"region,omitempty" cborgen:"region,omitempty"`
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
// Code generated by generate.go; DO NOT EDIT.
|
||||
|
||||
// Lexicon schema: io.atcr.hold.crew
|
||||
|
||||
package atproto
|
||||
|
||||
// Crew member in a hold's embedded PDS. Grants access permissions to push blobs to the hold. Stored in the hold's embedded PDS (one record per member).
|
||||
type HoldCrew struct {
|
||||
LexiconTypeID string `json:"$type" cborgen:"$type,const=io.atcr.hold.crew"`
|
||||
// addedAt: RFC3339 timestamp of when the member was added
|
||||
AddedAt string `json:"addedAt" cborgen:"addedAt"`
|
||||
// member: DID of the crew member
|
||||
Member string `json:"member" cborgen:"member"`
|
||||
// permissions: Specific permissions granted to this member
|
||||
Permissions []string `json:"permissions" cborgen:"permissions"`
|
||||
// role: Member's role in the hold
|
||||
Role string `json:"role" cborgen:"role"`
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
// Code generated by generate.go; DO NOT EDIT.
|
||||
|
||||
// Lexicon schema: io.atcr.hold.layer
|
||||
|
||||
package atproto
|
||||
|
||||
// Represents metadata about a container layer stored in the hold. Stored in the hold's embedded PDS for tracking and analytics.
|
||||
type HoldLayer struct {
|
||||
LexiconTypeID string `json:"$type" cborgen:"$type,const=io.atcr.hold.layer"`
|
||||
// createdAt: RFC3339 timestamp of when the layer was uploaded
|
||||
CreatedAt string `json:"createdAt" cborgen:"createdAt"`
|
||||
// digest: Layer digest (e.g., sha256:abc123...)
|
||||
Digest string `json:"digest" cborgen:"digest"`
|
||||
// mediaType: Media type (e.g., application/vnd.oci.image.layer.v1.tar+gzip)
|
||||
MediaType string `json:"mediaType" cborgen:"mediaType"`
|
||||
// repository: Repository this layer belongs to
|
||||
Repository string `json:"repository" cborgen:"repository"`
|
||||
// size: Size in bytes
|
||||
Size int64 `json:"size" cborgen:"size"`
|
||||
// userDid: DID of user who uploaded this layer
|
||||
UserDid string `json:"userDid" cborgen:"userDid"`
|
||||
// userHandle: Handle of user (for display purposes)
|
||||
UserHandle string `json:"userHandle" cborgen:"userHandle"`
|
||||
}
|
||||
@@ -18,9 +18,6 @@ const (
|
||||
// TagCollection is the collection name for image tags
|
||||
TagCollection = "io.atcr.tag"
|
||||
|
||||
// HoldCollection is the collection name for storage holds (BYOS)
|
||||
HoldCollection = "io.atcr.hold"
|
||||
|
||||
// HoldCrewCollection is the collection name for hold crew (membership) - LEGACY BYOS model
|
||||
// Stored in owner's PDS for BYOS holds
|
||||
HoldCrewCollection = "io.atcr.hold.crew"
|
||||
@@ -41,9 +38,6 @@ const (
|
||||
// TangledProfileCollection is the collection name for tangled profiles
|
||||
// Stored in hold's embedded PDS (singleton record at rkey "self")
|
||||
TangledProfileCollection = "sh.tangled.actor.profile"
|
||||
|
||||
// BskyPostCollection is the collection name for Bluesky posts
|
||||
BskyPostCollection = "app.bsky.feed.post"
|
||||
|
||||
// BskyPostCollection is the collection name for Bluesky posts
|
||||
BskyPostCollection = "app.bsky.feed.post"
|
||||
@@ -53,6 +47,10 @@ const (
|
||||
|
||||
// StarCollection is the collection name for repository stars
|
||||
StarCollection = "io.atcr.sailor.star"
|
||||
|
||||
// RepoPageCollection is the collection name for repository page metadata
|
||||
// Stored in user's PDS with rkey = repository name
|
||||
RepoPageCollection = "io.atcr.repo.page"
|
||||
)
|
||||
|
||||
// ManifestRecord represents a container image manifest stored in ATProto
|
||||
@@ -312,17 +310,6 @@ type HoldRecord struct {
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// NewHoldRecord creates a new hold record
|
||||
func NewHoldRecord(endpoint, owner string, public bool) *HoldRecord {
|
||||
return &HoldRecord{
|
||||
Type: HoldCollection,
|
||||
Endpoint: endpoint,
|
||||
Owner: owner,
|
||||
Public: public,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// SailorProfileRecord represents a user's profile with registry preferences
|
||||
// Stored in the user's PDS to configure default hold and other settings
|
||||
type SailorProfileRecord struct {
|
||||
@@ -353,6 +340,42 @@ func NewSailorProfileRecord(defaultHold string) *SailorProfileRecord {
|
||||
}
|
||||
}
|
||||
|
||||
// RepoPageRecord represents repository page metadata (description + avatar)
|
||||
// Stored in the user's PDS with rkey = repository name
|
||||
// Users can edit this directly in their PDS to customize their repository page
|
||||
type RepoPageRecord struct {
|
||||
// Type should be "io.atcr.repo.page"
|
||||
Type string `json:"$type"`
|
||||
|
||||
// Repository is the name of the repository (e.g., "myapp")
|
||||
Repository string `json:"repository"`
|
||||
|
||||
// Description is the markdown README/description content
|
||||
Description string `json:"description,omitempty"`
|
||||
|
||||
// Avatar is the repository avatar/icon blob reference
|
||||
Avatar *ATProtoBlobRef `json:"avatar,omitempty"`
|
||||
|
||||
// CreatedAt timestamp
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
|
||||
// UpdatedAt timestamp
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// NewRepoPageRecord creates a new repo page record
|
||||
func NewRepoPageRecord(repository, description string, avatar *ATProtoBlobRef) *RepoPageRecord {
|
||||
now := time.Now()
|
||||
return &RepoPageRecord{
|
||||
Type: RepoPageCollection,
|
||||
Repository: repository,
|
||||
Description: description,
|
||||
Avatar: avatar,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
// StarSubject represents the subject of a star (the repository being starred)
|
||||
type StarSubject struct {
|
||||
// DID is the DID of the repository owner
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
package atproto
|
||||
|
||||
// This file contains ATProto record types that are NOT generated from our lexicons.
|
||||
// These are either external schemas or special types that require manual definition.
|
||||
|
||||
// TangledProfileRecord represents a Tangled profile for the hold
|
||||
// Collection: sh.tangled.actor.profile (external schema - not controlled by ATCR)
|
||||
// Stored in hold's embedded PDS (singleton record at rkey "self")
|
||||
// Uses CBOR encoding for efficient storage in hold's carstore
|
||||
type TangledProfileRecord struct {
|
||||
Type string `json:"$type" cborgen:"$type"`
|
||||
Links []string `json:"links" cborgen:"links"`
|
||||
Stats []string `json:"stats" cborgen:"stats"`
|
||||
Bluesky bool `json:"bluesky" cborgen:"bluesky"`
|
||||
Location string `json:"location" cborgen:"location"`
|
||||
Description string `json:"description" cborgen:"description"`
|
||||
PinnedRepositories []string `json:"pinnedRepositories" cborgen:"pinnedRepositories"`
|
||||
}
|
||||
@@ -1,360 +0,0 @@
|
||||
package atproto
|
||||
|
||||
//go:generate go run generate.go
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Collection names for ATProto records
|
||||
const (
|
||||
// ManifestCollection is the collection name for container manifests
|
||||
ManifestCollection = "io.atcr.manifest"
|
||||
|
||||
// TagCollection is the collection name for image tags
|
||||
TagCollection = "io.atcr.tag"
|
||||
|
||||
// HoldCollection is the collection name for storage holds (BYOS) - LEGACY
|
||||
HoldCollection = "io.atcr.hold"
|
||||
|
||||
// HoldCrewCollection is the collection name for hold crew (membership) - LEGACY BYOS model
|
||||
// Stored in owner's PDS for BYOS holds
|
||||
HoldCrewCollection = "io.atcr.hold.crew"
|
||||
|
||||
// CaptainCollection is the collection name for captain records (hold ownership) - EMBEDDED PDS model
|
||||
// Stored in hold's embedded PDS (singleton record at rkey "self")
|
||||
CaptainCollection = "io.atcr.hold.captain"
|
||||
|
||||
// CrewCollection is the collection name for crew records (access control) - EMBEDDED PDS model
|
||||
// Stored in hold's embedded PDS (one record per member)
|
||||
// Note: Uses same collection name as HoldCrewCollection but stored in different PDS (hold's PDS vs owner's PDS)
|
||||
CrewCollection = "io.atcr.hold.crew"
|
||||
|
||||
// LayerCollection is the collection name for container layer metadata
|
||||
// Stored in hold's embedded PDS to track which layers are stored
|
||||
LayerCollection = "io.atcr.hold.layer"
|
||||
|
||||
// TangledProfileCollection is the collection name for tangled profiles
|
||||
// Stored in hold's embedded PDS (singleton record at rkey "self")
|
||||
TangledProfileCollection = "sh.tangled.actor.profile"
|
||||
|
||||
// BskyPostCollection is the collection name for Bluesky posts
|
||||
BskyPostCollection = "app.bsky.feed.post"
|
||||
|
||||
// SailorProfileCollection is the collection name for user profiles
|
||||
SailorProfileCollection = "io.atcr.sailor.profile"
|
||||
|
||||
// StarCollection is the collection name for repository stars
|
||||
StarCollection = "io.atcr.sailor.star"
|
||||
)
|
||||
|
||||
// NewManifestRecord creates a new manifest record from OCI manifest JSON
|
||||
func NewManifestRecord(repository, digest string, ociManifest []byte) (*Manifest, error) {
|
||||
// Parse the OCI manifest
|
||||
var ociData struct {
|
||||
SchemaVersion int `json:"schemaVersion"`
|
||||
MediaType string `json:"mediaType"`
|
||||
Config json.RawMessage `json:"config,omitempty"`
|
||||
Layers []json.RawMessage `json:"layers,omitempty"`
|
||||
Manifests []json.RawMessage `json:"manifests,omitempty"`
|
||||
Subject json.RawMessage `json:"subject,omitempty"`
|
||||
Annotations map[string]string `json:"annotations,omitempty"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(ociManifest, &ociData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Detect manifest type based on media type
|
||||
isManifestList := strings.Contains(ociData.MediaType, "manifest.list") ||
|
||||
strings.Contains(ociData.MediaType, "image.index")
|
||||
|
||||
// Validate: must have either (config+layers) OR (manifests), never both
|
||||
hasImageFields := len(ociData.Config) > 0 || len(ociData.Layers) > 0
|
||||
hasIndexFields := len(ociData.Manifests) > 0
|
||||
|
||||
if hasImageFields && hasIndexFields {
|
||||
return nil, fmt.Errorf("manifest cannot have both image fields (config/layers) and index fields (manifests)")
|
||||
}
|
||||
if !hasImageFields && !hasIndexFields {
|
||||
return nil, fmt.Errorf("manifest must have either image fields (config/layers) or index fields (manifests)")
|
||||
}
|
||||
|
||||
record := &Manifest{
|
||||
LexiconTypeID: ManifestCollection,
|
||||
Repository: repository,
|
||||
Digest: digest,
|
||||
MediaType: ociData.MediaType,
|
||||
SchemaVersion: int64(ociData.SchemaVersion),
|
||||
// ManifestBlob will be set by the caller after uploading to blob storage
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
// Handle annotations - Manifest_Annotations is an empty struct in generated code
|
||||
// We don't copy ociData.Annotations since the generated type doesn't support arbitrary keys
|
||||
|
||||
if isManifestList {
|
||||
// Parse manifest list/index
|
||||
record.Manifests = make([]Manifest_ManifestReference, len(ociData.Manifests))
|
||||
for i, m := range ociData.Manifests {
|
||||
var ref struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
Size int64 `json:"size"`
|
||||
Platform *Manifest_Platform `json:"platform,omitempty"`
|
||||
Annotations map[string]string `json:"annotations,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(m, &ref); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse manifest reference %d: %w", i, err)
|
||||
}
|
||||
record.Manifests[i] = Manifest_ManifestReference{
|
||||
MediaType: ref.MediaType,
|
||||
Digest: ref.Digest,
|
||||
Size: ref.Size,
|
||||
Platform: ref.Platform,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Parse image manifest
|
||||
if len(ociData.Config) > 0 {
|
||||
var config Manifest_BlobReference
|
||||
if err := json.Unmarshal(ociData.Config, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config: %w", err)
|
||||
}
|
||||
record.Config = &config
|
||||
}
|
||||
|
||||
// Parse layers
|
||||
record.Layers = make([]Manifest_BlobReference, len(ociData.Layers))
|
||||
for i, layer := range ociData.Layers {
|
||||
if err := json.Unmarshal(layer, &record.Layers[i]); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse layer %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse subject if present (works for both types)
|
||||
if len(ociData.Subject) > 0 {
|
||||
var subject Manifest_BlobReference
|
||||
if err := json.Unmarshal(ociData.Subject, &subject); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
record.Subject = &subject
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// NewTagRecord creates a new tag record with manifest AT-URI
|
||||
// did: The DID of the user (e.g., "did:plc:xyz123")
|
||||
// repository: The repository name (e.g., "myapp")
|
||||
// tag: The tag name (e.g., "latest", "v1.0.0")
|
||||
// manifestDigest: The manifest digest (e.g., "sha256:abc123...")
|
||||
func NewTagRecord(did, repository, tag, manifestDigest string) *Tag {
|
||||
// Build AT-URI for the manifest
|
||||
// Format: at://did:plc:xyz/io.atcr.manifest/<digest-without-sha256-prefix>
|
||||
manifestURI := BuildManifestURI(did, manifestDigest)
|
||||
|
||||
return &Tag{
|
||||
LexiconTypeID: TagCollection,
|
||||
Repository: repository,
|
||||
Tag: tag,
|
||||
Manifest: &manifestURI,
|
||||
// Note: ManifestDigest is not set for new records (only for backward compat with old records)
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// NewSailorProfileRecord creates a new sailor profile record
|
||||
func NewSailorProfileRecord(defaultHold string) *SailorProfile {
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
var holdPtr *string
|
||||
if defaultHold != "" {
|
||||
holdPtr = &defaultHold
|
||||
}
|
||||
return &SailorProfile{
|
||||
LexiconTypeID: SailorProfileCollection,
|
||||
DefaultHold: holdPtr,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: &now,
|
||||
}
|
||||
}
|
||||
|
||||
// NewStarRecord creates a new star record
|
||||
func NewStarRecord(ownerDID, repository string) *SailorStar {
|
||||
return &SailorStar{
|
||||
LexiconTypeID: StarCollection,
|
||||
Subject: SailorStar_Subject{
|
||||
Did: ownerDID,
|
||||
Repository: repository,
|
||||
},
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// NewLayerRecord creates a new layer record
|
||||
func NewLayerRecord(digest string, size int64, mediaType, repository, userDID, userHandle string) *HoldLayer {
|
||||
return &HoldLayer{
|
||||
LexiconTypeID: LayerCollection,
|
||||
Digest: digest,
|
||||
Size: size,
|
||||
MediaType: mediaType,
|
||||
Repository: repository,
|
||||
UserDid: userDID,
|
||||
UserHandle: userHandle,
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// StarRecordKey generates a record key for a star
|
||||
// Uses a simple hash to ensure uniqueness and prevent duplicate stars
|
||||
func StarRecordKey(ownerDID, repository string) string {
|
||||
// Use base64 encoding of "ownerDID/repository" as the record key
|
||||
// This is deterministic and prevents duplicate stars
|
||||
combined := ownerDID + "/" + repository
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(combined))
|
||||
}
|
||||
|
||||
// ParseStarRecordKey decodes a star record key back to ownerDID and repository
|
||||
func ParseStarRecordKey(rkey string) (ownerDID, repository string, err error) {
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(rkey)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to decode star rkey: %w", err)
|
||||
}
|
||||
|
||||
parts := strings.SplitN(string(decoded), "/", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", "", fmt.Errorf("invalid star rkey format: %s", string(decoded))
|
||||
}
|
||||
|
||||
return parts[0], parts[1], nil
|
||||
}
|
||||
|
||||
// ResolveHoldDIDFromURL converts a hold endpoint URL to a did:web DID
|
||||
// This ensures that different representations of the same hold are deduplicated:
|
||||
// - http://172.28.0.3:8080 → did:web:172.28.0.3:8080
|
||||
// - http://hold01.atcr.io → did:web:hold01.atcr.io
|
||||
// - https://hold01.atcr.io → did:web:hold01.atcr.io
|
||||
// - did:web:hold01.atcr.io → did:web:hold01.atcr.io (passthrough)
|
||||
func ResolveHoldDIDFromURL(holdURL string) string {
|
||||
// Handle empty URLs
|
||||
if holdURL == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// If already a DID, return as-is
|
||||
if IsDID(holdURL) {
|
||||
return holdURL
|
||||
}
|
||||
|
||||
// Parse URL to get hostname
|
||||
holdURL = strings.TrimPrefix(holdURL, "http://")
|
||||
holdURL = strings.TrimPrefix(holdURL, "https://")
|
||||
holdURL = strings.TrimSuffix(holdURL, "/")
|
||||
|
||||
// Extract hostname (remove path if present)
|
||||
parts := strings.Split(holdURL, "/")
|
||||
hostname := parts[0]
|
||||
|
||||
// Convert to did:web
|
||||
// did:web uses hostname directly (port included if non-standard)
|
||||
return "did:web:" + hostname
|
||||
}
|
||||
|
||||
// IsDID checks if a string is a DID (starts with "did:")
|
||||
func IsDID(s string) bool {
|
||||
return len(s) > 4 && s[:4] == "did:"
|
||||
}
|
||||
|
||||
// RepositoryTagToRKey converts a repository and tag to an ATProto record key
|
||||
// ATProto record keys must match: ^[a-zA-Z0-9._~-]{1,512}$
|
||||
func RepositoryTagToRKey(repository, tag string) string {
|
||||
// Combine repository and tag to create a unique key
|
||||
// Replace invalid characters: slashes become tildes (~)
|
||||
// We use tilde instead of dash to avoid ambiguity with repository names that contain hyphens
|
||||
key := fmt.Sprintf("%s_%s", repository, tag)
|
||||
|
||||
// Replace / with ~ (slash not allowed in rkeys, tilde is allowed and unlikely in repo names)
|
||||
key = strings.ReplaceAll(key, "/", "~")
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
// RKeyToRepositoryTag converts an ATProto record key back to repository and tag
|
||||
// This is the inverse of RepositoryTagToRKey
|
||||
// Note: If the tag contains underscores, this will split on the LAST underscore
|
||||
func RKeyToRepositoryTag(rkey string) (repository, tag string) {
|
||||
// Find the last underscore to split repository and tag
|
||||
lastUnderscore := strings.LastIndex(rkey, "_")
|
||||
if lastUnderscore == -1 {
|
||||
// No underscore found - treat entire string as tag with empty repository
|
||||
return "", rkey
|
||||
}
|
||||
|
||||
repository = rkey[:lastUnderscore]
|
||||
tag = rkey[lastUnderscore+1:]
|
||||
|
||||
// Convert tildes back to slashes in repository (tilde was used to encode slashes)
|
||||
repository = strings.ReplaceAll(repository, "~", "/")
|
||||
|
||||
return repository, tag
|
||||
}
|
||||
|
||||
// BuildManifestURI creates an AT-URI for a manifest record
|
||||
// did: The DID of the user (e.g., "did:plc:xyz123")
|
||||
// manifestDigest: The manifest digest (e.g., "sha256:abc123...")
|
||||
// Returns: AT-URI in format "at://did:plc:xyz/io.atcr.manifest/<digest-without-sha256-prefix>"
|
||||
func BuildManifestURI(did, manifestDigest string) string {
|
||||
// Remove the "sha256:" prefix from the digest to get the rkey
|
||||
rkey := strings.TrimPrefix(manifestDigest, "sha256:")
|
||||
return fmt.Sprintf("at://%s/%s/%s", did, ManifestCollection, rkey)
|
||||
}
|
||||
|
||||
// ParseManifestURI extracts the digest from a manifest AT-URI
|
||||
// manifestURI: AT-URI in format "at://did:plc:xyz/io.atcr.manifest/<digest-without-sha256-prefix>"
|
||||
// Returns: Full digest with "sha256:" prefix (e.g., "sha256:abc123...")
|
||||
func ParseManifestURI(manifestURI string) (string, error) {
|
||||
// Expected format: at://did:plc:xyz/io.atcr.manifest/<rkey>
|
||||
if !strings.HasPrefix(manifestURI, "at://") {
|
||||
return "", fmt.Errorf("invalid AT-URI format: must start with 'at://'")
|
||||
}
|
||||
|
||||
// Remove "at://" prefix
|
||||
remainder := strings.TrimPrefix(manifestURI, "at://")
|
||||
|
||||
// Split by "/"
|
||||
parts := strings.Split(remainder, "/")
|
||||
if len(parts) != 3 {
|
||||
return "", fmt.Errorf("invalid AT-URI format: expected 3 parts (did/collection/rkey), got %d", len(parts))
|
||||
}
|
||||
|
||||
// Validate collection
|
||||
if parts[1] != ManifestCollection {
|
||||
return "", fmt.Errorf("invalid AT-URI: expected collection %s, got %s", ManifestCollection, parts[1])
|
||||
}
|
||||
|
||||
// The rkey is the digest without the "sha256:" prefix
|
||||
// Add it back to get the full digest
|
||||
rkey := parts[2]
|
||||
return "sha256:" + rkey, nil
|
||||
}
|
||||
|
||||
// GetManifestDigest extracts the digest from a Tag, preferring the manifest field
|
||||
// Returns the digest with "sha256:" prefix (e.g., "sha256:abc123...")
|
||||
func (t *Tag) GetManifestDigest() (string, error) {
|
||||
// Prefer the new manifest field
|
||||
if t.Manifest != nil && *t.Manifest != "" {
|
||||
return ParseManifestURI(*t.Manifest)
|
||||
}
|
||||
|
||||
// Fall back to the legacy manifestDigest field
|
||||
if t.ManifestDigest != nil && *t.ManifestDigest != "" {
|
||||
return *t.ManifestDigest, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("tag record has neither manifest nor manifestDigest field")
|
||||
}
|
||||
@@ -104,7 +104,7 @@ func TestNewManifestRecord(t *testing.T) {
|
||||
digest string
|
||||
ociManifest string
|
||||
wantErr bool
|
||||
checkFunc func(*testing.T, *Manifest)
|
||||
checkFunc func(*testing.T, *ManifestRecord)
|
||||
}{
|
||||
{
|
||||
name: "valid OCI manifest",
|
||||
@@ -112,9 +112,9 @@ func TestNewManifestRecord(t *testing.T) {
|
||||
digest: "sha256:abc123",
|
||||
ociManifest: validOCIManifest,
|
||||
wantErr: false,
|
||||
checkFunc: func(t *testing.T, record *Manifest) {
|
||||
if record.LexiconTypeID != ManifestCollection {
|
||||
t.Errorf("LexiconTypeID = %v, want %v", record.LexiconTypeID, ManifestCollection)
|
||||
checkFunc: func(t *testing.T, record *ManifestRecord) {
|
||||
if record.Type != ManifestCollection {
|
||||
t.Errorf("Type = %v, want %v", record.Type, ManifestCollection)
|
||||
}
|
||||
if record.Repository != "myapp" {
|
||||
t.Errorf("Repository = %v, want myapp", record.Repository)
|
||||
@@ -143,9 +143,11 @@ func TestNewManifestRecord(t *testing.T) {
|
||||
if record.Layers[1].Digest != "sha256:layer2" {
|
||||
t.Errorf("Layers[1].Digest = %v, want sha256:layer2", record.Layers[1].Digest)
|
||||
}
|
||||
// Note: Annotations are not copied to generated type (empty struct)
|
||||
if record.CreatedAt == "" {
|
||||
t.Error("CreatedAt should not be empty")
|
||||
if record.Annotations["org.opencontainers.image.created"] != "2025-01-01T00:00:00Z" {
|
||||
t.Errorf("Annotations missing expected key")
|
||||
}
|
||||
if record.CreatedAt.IsZero() {
|
||||
t.Error("CreatedAt should not be zero")
|
||||
}
|
||||
if record.Subject != nil {
|
||||
t.Error("Subject should be nil")
|
||||
@@ -158,7 +160,7 @@ func TestNewManifestRecord(t *testing.T) {
|
||||
digest: "sha256:abc123",
|
||||
ociManifest: manifestWithSubject,
|
||||
wantErr: false,
|
||||
checkFunc: func(t *testing.T, record *Manifest) {
|
||||
checkFunc: func(t *testing.T, record *ManifestRecord) {
|
||||
if record.Subject == nil {
|
||||
t.Fatal("Subject should not be nil")
|
||||
}
|
||||
@@ -190,7 +192,7 @@ func TestNewManifestRecord(t *testing.T) {
|
||||
digest: "sha256:multiarch",
|
||||
ociManifest: manifestList,
|
||||
wantErr: false,
|
||||
checkFunc: func(t *testing.T, record *Manifest) {
|
||||
checkFunc: func(t *testing.T, record *ManifestRecord) {
|
||||
if record.MediaType != "application/vnd.oci.image.index.v1+json" {
|
||||
t.Errorf("MediaType = %v, want application/vnd.oci.image.index.v1+json", record.MediaType)
|
||||
}
|
||||
@@ -217,8 +219,8 @@ func TestNewManifestRecord(t *testing.T) {
|
||||
if record.Manifests[0].Platform.Architecture != "amd64" {
|
||||
t.Errorf("Platform.Architecture = %v, want amd64", record.Manifests[0].Platform.Architecture)
|
||||
}
|
||||
if record.Manifests[0].Platform.Os != "linux" {
|
||||
t.Errorf("Platform.Os = %v, want linux", record.Manifests[0].Platform.Os)
|
||||
if record.Manifests[0].Platform.OS != "linux" {
|
||||
t.Errorf("Platform.OS = %v, want linux", record.Manifests[0].Platform.OS)
|
||||
}
|
||||
|
||||
// Check second manifest (arm64)
|
||||
@@ -228,7 +230,7 @@ func TestNewManifestRecord(t *testing.T) {
|
||||
if record.Manifests[1].Platform.Architecture != "arm64" {
|
||||
t.Errorf("Platform.Architecture = %v, want arm64", record.Manifests[1].Platform.Architecture)
|
||||
}
|
||||
if record.Manifests[1].Platform.Variant == nil || *record.Manifests[1].Platform.Variant != "v8" {
|
||||
if record.Manifests[1].Platform.Variant != "v8" {
|
||||
t.Errorf("Platform.Variant = %v, want v8", record.Manifests[1].Platform.Variant)
|
||||
}
|
||||
},
|
||||
@@ -266,13 +268,12 @@ func TestNewManifestRecord(t *testing.T) {
|
||||
|
||||
func TestNewTagRecord(t *testing.T) {
|
||||
did := "did:plc:test123"
|
||||
// Truncate to second precision since RFC3339 doesn't have sub-second precision
|
||||
before := time.Now().Truncate(time.Second)
|
||||
before := time.Now()
|
||||
record := NewTagRecord(did, "myapp", "latest", "sha256:abc123")
|
||||
after := time.Now().Truncate(time.Second).Add(time.Second)
|
||||
after := time.Now()
|
||||
|
||||
if record.LexiconTypeID != TagCollection {
|
||||
t.Errorf("LexiconTypeID = %v, want %v", record.LexiconTypeID, TagCollection)
|
||||
if record.Type != TagCollection {
|
||||
t.Errorf("Type = %v, want %v", record.Type, TagCollection)
|
||||
}
|
||||
|
||||
if record.Repository != "myapp" {
|
||||
@@ -285,21 +286,17 @@ func TestNewTagRecord(t *testing.T) {
|
||||
|
||||
// New records should have manifest field (AT-URI)
|
||||
expectedURI := "at://did:plc:test123/io.atcr.manifest/abc123"
|
||||
if record.Manifest == nil || *record.Manifest != expectedURI {
|
||||
if record.Manifest != expectedURI {
|
||||
t.Errorf("Manifest = %v, want %v", record.Manifest, expectedURI)
|
||||
}
|
||||
|
||||
// New records should NOT have manifestDigest field
|
||||
if record.ManifestDigest != nil && *record.ManifestDigest != "" {
|
||||
t.Errorf("ManifestDigest should be nil for new records, got %v", record.ManifestDigest)
|
||||
if record.ManifestDigest != "" {
|
||||
t.Errorf("ManifestDigest should be empty for new records, got %v", record.ManifestDigest)
|
||||
}
|
||||
|
||||
createdAt, err := time.Parse(time.RFC3339, record.CreatedAt)
|
||||
if err != nil {
|
||||
t.Errorf("CreatedAt is not valid RFC3339: %v", err)
|
||||
}
|
||||
if createdAt.Before(before) || createdAt.After(after) {
|
||||
t.Errorf("CreatedAt = %v, want between %v and %v", createdAt, before, after)
|
||||
if record.UpdatedAt.Before(before) || record.UpdatedAt.After(after) {
|
||||
t.Errorf("UpdatedAt = %v, want between %v and %v", record.UpdatedAt, before, after)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -394,50 +391,47 @@ func TestParseManifestURI(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTagRecord_GetManifestDigest(t *testing.T) {
|
||||
manifestURI := "at://did:plc:test123/io.atcr.manifest/abc123"
|
||||
digestValue := "sha256:def456"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
record Tag
|
||||
record TagRecord
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "new record with manifest field",
|
||||
record: Tag{
|
||||
Manifest: &manifestURI,
|
||||
record: TagRecord{
|
||||
Manifest: "at://did:plc:test123/io.atcr.manifest/abc123",
|
||||
},
|
||||
want: "sha256:abc123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "old record with manifestDigest field",
|
||||
record: Tag{
|
||||
ManifestDigest: &digestValue,
|
||||
record: TagRecord{
|
||||
ManifestDigest: "sha256:def456",
|
||||
},
|
||||
want: "sha256:def456",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "prefers manifest over manifestDigest",
|
||||
record: Tag{
|
||||
Manifest: &manifestURI,
|
||||
ManifestDigest: &digestValue,
|
||||
record: TagRecord{
|
||||
Manifest: "at://did:plc:test123/io.atcr.manifest/abc123",
|
||||
ManifestDigest: "sha256:def456",
|
||||
},
|
||||
want: "sha256:abc123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no fields set",
|
||||
record: Tag{},
|
||||
record: TagRecord{},
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid manifest URI",
|
||||
record: Tag{
|
||||
Manifest: func() *string { s := "invalid-uri"; return &s }(),
|
||||
record: TagRecord{
|
||||
Manifest: "invalid-uri",
|
||||
},
|
||||
want: "",
|
||||
wantErr: true,
|
||||
@@ -458,8 +452,6 @@ func TestTagRecord_GetManifestDigest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewHoldRecord is removed - HoldRecord is no longer supported (legacy BYOS)
|
||||
|
||||
func TestNewSailorProfileRecord(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -481,72 +473,53 @@ func TestNewSailorProfileRecord(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Truncate to second precision since RFC3339 doesn't have sub-second precision
|
||||
before := time.Now().Truncate(time.Second)
|
||||
before := time.Now()
|
||||
record := NewSailorProfileRecord(tt.defaultHold)
|
||||
after := time.Now().Truncate(time.Second).Add(time.Second)
|
||||
after := time.Now()
|
||||
|
||||
if record.LexiconTypeID != SailorProfileCollection {
|
||||
t.Errorf("LexiconTypeID = %v, want %v", record.LexiconTypeID, SailorProfileCollection)
|
||||
if record.Type != SailorProfileCollection {
|
||||
t.Errorf("Type = %v, want %v", record.Type, SailorProfileCollection)
|
||||
}
|
||||
|
||||
if tt.defaultHold == "" {
|
||||
if record.DefaultHold != nil {
|
||||
t.Errorf("DefaultHold = %v, want nil", record.DefaultHold)
|
||||
}
|
||||
} else {
|
||||
if record.DefaultHold == nil || *record.DefaultHold != tt.defaultHold {
|
||||
t.Errorf("DefaultHold = %v, want %v", record.DefaultHold, tt.defaultHold)
|
||||
}
|
||||
if record.DefaultHold != tt.defaultHold {
|
||||
t.Errorf("DefaultHold = %v, want %v", record.DefaultHold, tt.defaultHold)
|
||||
}
|
||||
|
||||
createdAt, err := time.Parse(time.RFC3339, record.CreatedAt)
|
||||
if err != nil {
|
||||
t.Errorf("CreatedAt is not valid RFC3339: %v", err)
|
||||
}
|
||||
if createdAt.Before(before) || createdAt.After(after) {
|
||||
t.Errorf("CreatedAt = %v, want between %v and %v", createdAt, before, after)
|
||||
if record.CreatedAt.Before(before) || record.CreatedAt.After(after) {
|
||||
t.Errorf("CreatedAt = %v, want between %v and %v", record.CreatedAt, before, after)
|
||||
}
|
||||
|
||||
if record.UpdatedAt == nil {
|
||||
t.Error("UpdatedAt should not be nil")
|
||||
} else {
|
||||
updatedAt, err := time.Parse(time.RFC3339, *record.UpdatedAt)
|
||||
if err != nil {
|
||||
t.Errorf("UpdatedAt is not valid RFC3339: %v", err)
|
||||
}
|
||||
if updatedAt.Before(before) || updatedAt.After(after) {
|
||||
t.Errorf("UpdatedAt = %v, want between %v and %v", updatedAt, before, after)
|
||||
}
|
||||
if record.UpdatedAt.Before(before) || record.UpdatedAt.After(after) {
|
||||
t.Errorf("UpdatedAt = %v, want between %v and %v", record.UpdatedAt, before, after)
|
||||
}
|
||||
|
||||
// CreatedAt and UpdatedAt should be equal for new records
|
||||
if !record.CreatedAt.Equal(record.UpdatedAt) {
|
||||
t.Errorf("CreatedAt (%v) != UpdatedAt (%v)", record.CreatedAt, record.UpdatedAt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewStarRecord(t *testing.T) {
|
||||
// Truncate to second precision since RFC3339 doesn't have sub-second precision
|
||||
before := time.Now().Truncate(time.Second)
|
||||
before := time.Now()
|
||||
record := NewStarRecord("did:plc:alice123", "myapp")
|
||||
after := time.Now().Truncate(time.Second).Add(time.Second)
|
||||
after := time.Now()
|
||||
|
||||
if record.LexiconTypeID != StarCollection {
|
||||
t.Errorf("LexiconTypeID = %v, want %v", record.LexiconTypeID, StarCollection)
|
||||
if record.Type != StarCollection {
|
||||
t.Errorf("Type = %v, want %v", record.Type, StarCollection)
|
||||
}
|
||||
|
||||
if record.Subject.Did != "did:plc:alice123" {
|
||||
t.Errorf("Subject.Did = %v, want did:plc:alice123", record.Subject.Did)
|
||||
if record.Subject.DID != "did:plc:alice123" {
|
||||
t.Errorf("Subject.DID = %v, want did:plc:alice123", record.Subject.DID)
|
||||
}
|
||||
|
||||
if record.Subject.Repository != "myapp" {
|
||||
t.Errorf("Subject.Repository = %v, want myapp", record.Subject.Repository)
|
||||
}
|
||||
|
||||
createdAt, err := time.Parse(time.RFC3339, record.CreatedAt)
|
||||
if err != nil {
|
||||
t.Errorf("CreatedAt is not valid RFC3339: %v", err)
|
||||
}
|
||||
if createdAt.Before(before) || createdAt.After(after) {
|
||||
t.Errorf("CreatedAt = %v, want between %v and %v", createdAt, before, after)
|
||||
if record.CreatedAt.Before(before) || record.CreatedAt.After(after) {
|
||||
t.Errorf("CreatedAt = %v, want between %v and %v", record.CreatedAt, before, after)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -834,8 +807,7 @@ func TestManifestRecord_JSONSerialization(t *testing.T) {
|
||||
}
|
||||
|
||||
// Add hold DID
|
||||
holdDID := "did:web:hold01.atcr.io"
|
||||
record.HoldDid = &holdDID
|
||||
record.HoldDID = "did:web:hold01.atcr.io"
|
||||
|
||||
// Serialize to JSON
|
||||
jsonData, err := json.Marshal(record)
|
||||
@@ -844,14 +816,14 @@ func TestManifestRecord_JSONSerialization(t *testing.T) {
|
||||
}
|
||||
|
||||
// Deserialize from JSON
|
||||
var decoded Manifest
|
||||
var decoded ManifestRecord
|
||||
if err := json.Unmarshal(jsonData, &decoded); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify fields
|
||||
if decoded.LexiconTypeID != record.LexiconTypeID {
|
||||
t.Errorf("LexiconTypeID = %v, want %v", decoded.LexiconTypeID, record.LexiconTypeID)
|
||||
if decoded.Type != record.Type {
|
||||
t.Errorf("Type = %v, want %v", decoded.Type, record.Type)
|
||||
}
|
||||
if decoded.Repository != record.Repository {
|
||||
t.Errorf("Repository = %v, want %v", decoded.Repository, record.Repository)
|
||||
@@ -859,8 +831,8 @@ func TestManifestRecord_JSONSerialization(t *testing.T) {
|
||||
if decoded.Digest != record.Digest {
|
||||
t.Errorf("Digest = %v, want %v", decoded.Digest, record.Digest)
|
||||
}
|
||||
if decoded.HoldDid == nil || *decoded.HoldDid != *record.HoldDid {
|
||||
t.Errorf("HoldDid = %v, want %v", decoded.HoldDid, record.HoldDid)
|
||||
if decoded.HoldDID != record.HoldDID {
|
||||
t.Errorf("HoldDID = %v, want %v", decoded.HoldDID, record.HoldDID)
|
||||
}
|
||||
if decoded.Config.Digest != record.Config.Digest {
|
||||
t.Errorf("Config.Digest = %v, want %v", decoded.Config.Digest, record.Config.Digest)
|
||||
@@ -871,12 +843,14 @@ func TestManifestRecord_JSONSerialization(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBlobReference_JSONSerialization(t *testing.T) {
|
||||
blob := Manifest_BlobReference{
|
||||
blob := BlobReference{
|
||||
MediaType: "application/vnd.oci.image.layer.v1.tar+gzip",
|
||||
Digest: "sha256:abc123",
|
||||
Size: 12345,
|
||||
Urls: []string{"https://s3.example.com/blob"},
|
||||
// Note: Annotations is now an empty struct, not a map
|
||||
URLs: []string{"https://s3.example.com/blob"},
|
||||
Annotations: map[string]string{
|
||||
"key": "value",
|
||||
},
|
||||
}
|
||||
|
||||
// Serialize
|
||||
@@ -886,7 +860,7 @@ func TestBlobReference_JSONSerialization(t *testing.T) {
|
||||
}
|
||||
|
||||
// Deserialize
|
||||
var decoded Manifest_BlobReference
|
||||
var decoded BlobReference
|
||||
if err := json.Unmarshal(jsonData, &decoded); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
@@ -904,8 +878,8 @@ func TestBlobReference_JSONSerialization(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStarSubject_JSONSerialization(t *testing.T) {
|
||||
subject := SailorStar_Subject{
|
||||
Did: "did:plc:alice123",
|
||||
subject := StarSubject{
|
||||
DID: "did:plc:alice123",
|
||||
Repository: "myapp",
|
||||
}
|
||||
|
||||
@@ -916,14 +890,14 @@ func TestStarSubject_JSONSerialization(t *testing.T) {
|
||||
}
|
||||
|
||||
// Deserialize
|
||||
var decoded SailorStar_Subject
|
||||
var decoded StarSubject
|
||||
if err := json.Unmarshal(jsonData, &decoded); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify
|
||||
if decoded.Did != subject.Did {
|
||||
t.Errorf("Did = %v, want %v", decoded.Did, subject.Did)
|
||||
if decoded.DID != subject.DID {
|
||||
t.Errorf("DID = %v, want %v", decoded.DID, subject.DID)
|
||||
}
|
||||
if decoded.Repository != subject.Repository {
|
||||
t.Errorf("Repository = %v, want %v", decoded.Repository, subject.Repository)
|
||||
@@ -1170,8 +1144,8 @@ func TestNewLayerRecord(t *testing.T) {
|
||||
t.Fatal("NewLayerRecord() returned nil")
|
||||
}
|
||||
|
||||
if record.LexiconTypeID != LayerCollection {
|
||||
t.Errorf("LexiconTypeID = %q, want %q", record.LexiconTypeID, LayerCollection)
|
||||
if record.Type != LayerCollection {
|
||||
t.Errorf("Type = %q, want %q", record.Type, LayerCollection)
|
||||
}
|
||||
|
||||
if record.Digest != tt.digest {
|
||||
@@ -1190,8 +1164,8 @@ func TestNewLayerRecord(t *testing.T) {
|
||||
t.Errorf("Repository = %q, want %q", record.Repository, tt.repository)
|
||||
}
|
||||
|
||||
if record.UserDid != tt.userDID {
|
||||
t.Errorf("UserDid = %q, want %q", record.UserDid, tt.userDID)
|
||||
if record.UserDID != tt.userDID {
|
||||
t.Errorf("UserDID = %q, want %q", record.UserDID, tt.userDID)
|
||||
}
|
||||
|
||||
if record.UserHandle != tt.userHandle {
|
||||
@@ -1213,7 +1187,7 @@ func TestNewLayerRecord(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewLayerRecordJSON(t *testing.T) {
|
||||
// Test that HoldLayer can be marshaled/unmarshaled to/from JSON
|
||||
// Test that LayerRecord can be marshaled/unmarshaled to/from JSON
|
||||
record := NewLayerRecord(
|
||||
"sha256:abc123",
|
||||
1024,
|
||||
@@ -1230,14 +1204,14 @@ func TestNewLayerRecordJSON(t *testing.T) {
|
||||
}
|
||||
|
||||
// Unmarshal back
|
||||
var decoded HoldLayer
|
||||
var decoded LayerRecord
|
||||
if err := json.Unmarshal(jsonData, &decoded); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify fields match
|
||||
if decoded.LexiconTypeID != record.LexiconTypeID {
|
||||
t.Errorf("LexiconTypeID = %q, want %q", decoded.LexiconTypeID, record.LexiconTypeID)
|
||||
if decoded.Type != record.Type {
|
||||
t.Errorf("Type = %q, want %q", decoded.Type, record.Type)
|
||||
}
|
||||
if decoded.Digest != record.Digest {
|
||||
t.Errorf("Digest = %q, want %q", decoded.Digest, record.Digest)
|
||||
@@ -1251,8 +1225,8 @@ func TestNewLayerRecordJSON(t *testing.T) {
|
||||
if decoded.Repository != record.Repository {
|
||||
t.Errorf("Repository = %q, want %q", decoded.Repository, record.Repository)
|
||||
}
|
||||
if decoded.UserDid != record.UserDid {
|
||||
t.Errorf("UserDid = %q, want %q", decoded.UserDid, record.UserDid)
|
||||
if decoded.UserDID != record.UserDID {
|
||||
t.Errorf("UserDID = %q, want %q", decoded.UserDID, record.UserDID)
|
||||
}
|
||||
if decoded.UserHandle != record.UserHandle {
|
||||
t.Errorf("UserHandle = %q, want %q", decoded.UserHandle, record.UserHandle)
|
||||
@@ -1261,3 +1235,135 @@ func TestNewLayerRecordJSON(t *testing.T) {
|
||||
t.Errorf("CreatedAt = %q, want %q", decoded.CreatedAt, record.CreatedAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRepoPageRecord(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
repository string
|
||||
description string
|
||||
avatar *ATProtoBlobRef
|
||||
}{
|
||||
{
|
||||
name: "with description only",
|
||||
repository: "myapp",
|
||||
description: "# My App\n\nA cool container image.",
|
||||
avatar: nil,
|
||||
},
|
||||
{
|
||||
name: "with avatar only",
|
||||
repository: "another-app",
|
||||
description: "",
|
||||
avatar: &ATProtoBlobRef{
|
||||
Type: "blob",
|
||||
Ref: Link{Link: "bafyreiabc123"},
|
||||
MimeType: "image/png",
|
||||
Size: 1024,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with both description and avatar",
|
||||
repository: "full-app",
|
||||
description: "This is a full description.",
|
||||
avatar: &ATProtoBlobRef{
|
||||
Type: "blob",
|
||||
Ref: Link{Link: "bafyreiabc456"},
|
||||
MimeType: "image/jpeg",
|
||||
Size: 2048,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty values",
|
||||
repository: "",
|
||||
description: "",
|
||||
avatar: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
before := time.Now()
|
||||
record := NewRepoPageRecord(tt.repository, tt.description, tt.avatar)
|
||||
after := time.Now()
|
||||
|
||||
if record.Type != RepoPageCollection {
|
||||
t.Errorf("Type = %v, want %v", record.Type, RepoPageCollection)
|
||||
}
|
||||
|
||||
if record.Repository != tt.repository {
|
||||
t.Errorf("Repository = %v, want %v", record.Repository, tt.repository)
|
||||
}
|
||||
|
||||
if record.Description != tt.description {
|
||||
t.Errorf("Description = %v, want %v", record.Description, tt.description)
|
||||
}
|
||||
|
||||
if tt.avatar == nil && record.Avatar != nil {
|
||||
t.Error("Avatar should be nil")
|
||||
}
|
||||
|
||||
if tt.avatar != nil {
|
||||
if record.Avatar == nil {
|
||||
t.Fatal("Avatar should not be nil")
|
||||
}
|
||||
if record.Avatar.Ref.Link != tt.avatar.Ref.Link {
|
||||
t.Errorf("Avatar.Ref.Link = %v, want %v", record.Avatar.Ref.Link, tt.avatar.Ref.Link)
|
||||
}
|
||||
}
|
||||
|
||||
if record.CreatedAt.Before(before) || record.CreatedAt.After(after) {
|
||||
t.Errorf("CreatedAt = %v, want between %v and %v", record.CreatedAt, before, after)
|
||||
}
|
||||
|
||||
if record.UpdatedAt.Before(before) || record.UpdatedAt.After(after) {
|
||||
t.Errorf("UpdatedAt = %v, want between %v and %v", record.UpdatedAt, before, after)
|
||||
}
|
||||
|
||||
// CreatedAt and UpdatedAt should be equal for new records
|
||||
if !record.CreatedAt.Equal(record.UpdatedAt) {
|
||||
t.Errorf("CreatedAt (%v) != UpdatedAt (%v)", record.CreatedAt, record.UpdatedAt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepoPageRecord_JSONSerialization(t *testing.T) {
|
||||
record := NewRepoPageRecord(
|
||||
"myapp",
|
||||
"# My App\n\nA description with **markdown**.",
|
||||
&ATProtoBlobRef{
|
||||
Type: "blob",
|
||||
Ref: Link{Link: "bafyreiabc123"},
|
||||
MimeType: "image/png",
|
||||
Size: 1024,
|
||||
},
|
||||
)
|
||||
|
||||
// Serialize to JSON
|
||||
jsonData, err := json.Marshal(record)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Deserialize from JSON
|
||||
var decoded RepoPageRecord
|
||||
if err := json.Unmarshal(jsonData, &decoded); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify fields
|
||||
if decoded.Type != record.Type {
|
||||
t.Errorf("Type = %v, want %v", decoded.Type, record.Type)
|
||||
}
|
||||
if decoded.Repository != record.Repository {
|
||||
t.Errorf("Repository = %v, want %v", decoded.Repository, record.Repository)
|
||||
}
|
||||
if decoded.Description != record.Description {
|
||||
t.Errorf("Description = %v, want %v", decoded.Description, record.Description)
|
||||
}
|
||||
if decoded.Avatar == nil {
|
||||
t.Fatal("Avatar should not be nil")
|
||||
}
|
||||
if decoded.Avatar.Ref.Link != record.Avatar.Ref.Link {
|
||||
t.Errorf("Avatar.Ref.Link = %v, want %v", decoded.Avatar.Ref.Link, record.Avatar.Ref.Link)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
// Code generated by generate.go; DO NOT EDIT.
|
||||
|
||||
// Lexicon schema: io.atcr.manifest
|
||||
|
||||
package atproto
|
||||
|
||||
import (
|
||||
lexutil "github.com/bluesky-social/indigo/lex/util"
|
||||
)
|
||||
|
||||
// A container image manifest following OCI specification, stored in ATProto
|
||||
type Manifest struct {
|
||||
LexiconTypeID string `json:"$type" cborgen:"$type,const=io.atcr.manifest"`
|
||||
// annotations: Optional metadata annotations
|
||||
Annotations *Manifest_Annotations `json:"annotations,omitempty" cborgen:"annotations,omitempty"`
|
||||
// config: Reference to image configuration blob
|
||||
Config *Manifest_BlobReference `json:"config,omitempty" cborgen:"config,omitempty"`
|
||||
// createdAt: Record creation timestamp
|
||||
CreatedAt string `json:"createdAt" cborgen:"createdAt"`
|
||||
// digest: Content digest (e.g., 'sha256:abc123...')
|
||||
Digest string `json:"digest" cborgen:"digest"`
|
||||
// holdDid: DID of the hold service where blobs are stored (e.g., 'did:web:hold01.atcr.io'). Primary reference for hold resolution.
|
||||
HoldDid *string `json:"holdDid,omitempty" cborgen:"holdDid,omitempty"`
|
||||
// holdEndpoint: Hold service endpoint URL where blobs are stored. DEPRECATED: Use holdDid instead. Kept for backward compatibility.
|
||||
HoldEndpoint *string `json:"holdEndpoint,omitempty" cborgen:"holdEndpoint,omitempty"`
|
||||
// layers: Filesystem layers (for image manifests)
|
||||
Layers []Manifest_BlobReference `json:"layers,omitempty" cborgen:"layers,omitempty"`
|
||||
// manifestBlob: The full OCI manifest stored as a blob in ATProto.
|
||||
ManifestBlob *lexutil.LexBlob `json:"manifestBlob,omitempty" cborgen:"manifestBlob,omitempty"`
|
||||
// manifests: Referenced manifests (for manifest lists/indexes)
|
||||
Manifests []Manifest_ManifestReference `json:"manifests,omitempty" cborgen:"manifests,omitempty"`
|
||||
// mediaType: OCI media type
|
||||
MediaType string `json:"mediaType" cborgen:"mediaType"`
|
||||
// repository: Repository name (e.g., 'myapp'). Scoped to user's DID.
|
||||
Repository string `json:"repository" cborgen:"repository"`
|
||||
// schemaVersion: OCI schema version (typically 2)
|
||||
SchemaVersion int64 `json:"schemaVersion" cborgen:"schemaVersion"`
|
||||
// subject: Optional reference to another manifest (for attestations, signatures)
|
||||
Subject *Manifest_BlobReference `json:"subject,omitempty" cborgen:"subject,omitempty"`
|
||||
}
|
||||
|
||||
// Optional metadata annotations
|
||||
type Manifest_Annotations struct {
|
||||
}
|
||||
|
||||
// Manifest_BlobReference is a "blobReference" in the io.atcr.manifest schema.
|
||||
//
|
||||
// Reference to a blob stored in S3 or external storage
|
||||
type Manifest_BlobReference struct {
|
||||
LexiconTypeID string `json:"$type,omitempty" cborgen:"$type,const=io.atcr.manifest#blobReference,omitempty"`
|
||||
// annotations: Optional metadata
|
||||
Annotations *Manifest_BlobReference_Annotations `json:"annotations,omitempty" cborgen:"annotations,omitempty"`
|
||||
// digest: Content digest (e.g., 'sha256:...')
|
||||
Digest string `json:"digest" cborgen:"digest"`
|
||||
// mediaType: MIME type of the blob
|
||||
MediaType string `json:"mediaType" cborgen:"mediaType"`
|
||||
// size: Size in bytes
|
||||
Size int64 `json:"size" cborgen:"size"`
|
||||
// urls: Optional direct URLs to blob (for BYOS)
|
||||
Urls []string `json:"urls,omitempty" cborgen:"urls,omitempty"`
|
||||
}
|
||||
|
||||
// Optional metadata
|
||||
type Manifest_BlobReference_Annotations struct {
|
||||
}
|
||||
|
||||
// Manifest_ManifestReference is a "manifestReference" in the io.atcr.manifest schema.
|
||||
//
|
||||
// Reference to a manifest in a manifest list/index
|
||||
type Manifest_ManifestReference struct {
|
||||
LexiconTypeID string `json:"$type,omitempty" cborgen:"$type,const=io.atcr.manifest#manifestReference,omitempty"`
|
||||
// annotations: Optional metadata
|
||||
Annotations *Manifest_ManifestReference_Annotations `json:"annotations,omitempty" cborgen:"annotations,omitempty"`
|
||||
// digest: Content digest (e.g., 'sha256:...')
|
||||
Digest string `json:"digest" cborgen:"digest"`
|
||||
// mediaType: Media type of the referenced manifest
|
||||
MediaType string `json:"mediaType" cborgen:"mediaType"`
|
||||
// platform: Platform information for this manifest
|
||||
Platform *Manifest_Platform `json:"platform,omitempty" cborgen:"platform,omitempty"`
|
||||
// size: Size in bytes
|
||||
Size int64 `json:"size" cborgen:"size"`
|
||||
}
|
||||
|
||||
// Optional metadata
|
||||
type Manifest_ManifestReference_Annotations struct {
|
||||
}
|
||||
|
||||
// Manifest_Platform is a "platform" in the io.atcr.manifest schema.
|
||||
//
|
||||
// Platform information describing OS and architecture
|
||||
type Manifest_Platform struct {
|
||||
LexiconTypeID string `json:"$type,omitempty" cborgen:"$type,const=io.atcr.manifest#platform,omitempty"`
|
||||
// architecture: CPU architecture (e.g., 'amd64', 'arm64', 'arm')
|
||||
Architecture string `json:"architecture" cborgen:"architecture"`
|
||||
// os: Operating system (e.g., 'linux', 'windows', 'darwin')
|
||||
Os string `json:"os" cborgen:"os"`
|
||||
// osFeatures: Optional OS features
|
||||
OsFeatures []string `json:"osFeatures,omitempty" cborgen:"osFeatures,omitempty"`
|
||||
// osVersion: Optional OS version
|
||||
OsVersion *string `json:"osVersion,omitempty" cborgen:"osVersion,omitempty"`
|
||||
// variant: Optional CPU variant (e.g., 'v7' for ARM)
|
||||
Variant *string `json:"variant,omitempty" cborgen:"variant,omitempty"`
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
// Code generated by generate.go; DO NOT EDIT.
|
||||
|
||||
package atproto
|
||||
|
||||
import lexutil "github.com/bluesky-social/indigo/lex/util"
|
||||
|
||||
func init() {
|
||||
lexutil.RegisterType("io.atcr.hold.captain", &HoldCaptain{})
|
||||
lexutil.RegisterType("io.atcr.hold.crew", &HoldCrew{})
|
||||
lexutil.RegisterType("io.atcr.hold.layer", &HoldLayer{})
|
||||
lexutil.RegisterType("io.atcr.manifest", &Manifest{})
|
||||
lexutil.RegisterType("io.atcr.sailor.profile", &SailorProfile{})
|
||||
lexutil.RegisterType("io.atcr.sailor.star", &SailorStar{})
|
||||
lexutil.RegisterType("io.atcr.tag", &Tag{})
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
// Code generated by generate.go; DO NOT EDIT.
|
||||
|
||||
// Lexicon schema: io.atcr.sailor.profile
|
||||
|
||||
package atproto
|
||||
|
||||
// User profile for ATCR registry. Stores preferences like default hold for blob storage.
|
||||
type SailorProfile struct {
|
||||
LexiconTypeID string `json:"$type" cborgen:"$type,const=io.atcr.sailor.profile"`
|
||||
// createdAt: Profile creation timestamp
|
||||
CreatedAt string `json:"createdAt" cborgen:"createdAt"`
|
||||
// defaultHold: Default hold endpoint for blob storage. If null, user has opted out of defaults.
|
||||
DefaultHold *string `json:"defaultHold,omitempty" cborgen:"defaultHold,omitempty"`
|
||||
// updatedAt: Profile last updated timestamp
|
||||
UpdatedAt *string `json:"updatedAt,omitempty" cborgen:"updatedAt,omitempty"`
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
// Code generated by generate.go; DO NOT EDIT.
|
||||
|
||||
// Lexicon schema: io.atcr.sailor.star
|
||||
|
||||
package atproto
|
||||
|
||||
// A star (like) on a container image repository. Stored in the starrer's PDS, similar to Bluesky likes.
|
||||
type SailorStar struct {
|
||||
LexiconTypeID string `json:"$type" cborgen:"$type,const=io.atcr.sailor.star"`
|
||||
// createdAt: Star creation timestamp
|
||||
CreatedAt string `json:"createdAt" cborgen:"createdAt"`
|
||||
// subject: The repository being starred
|
||||
Subject SailorStar_Subject `json:"subject" cborgen:"subject"`
|
||||
}
|
||||
|
||||
// SailorStar_Subject is a "subject" in the io.atcr.sailor.star schema.
|
||||
//
|
||||
// Reference to a repository owned by a user
|
||||
type SailorStar_Subject struct {
|
||||
LexiconTypeID string `json:"$type,omitempty" cborgen:"$type,const=io.atcr.sailor.star#subject,omitempty"`
|
||||
// did: DID of the repository owner
|
||||
Did string `json:"did" cborgen:"did"`
|
||||
// repository: Repository name (e.g., 'myapp')
|
||||
Repository string `json:"repository" cborgen:"repository"`
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
// Code generated by generate.go; DO NOT EDIT.
|
||||
|
||||
// Lexicon schema: io.atcr.tag
|
||||
|
||||
package atproto
|
||||
|
||||
// A named tag pointing to a specific manifest digest
|
||||
type Tag struct {
|
||||
LexiconTypeID string `json:"$type" cborgen:"$type,const=io.atcr.tag"`
|
||||
// createdAt: Tag creation timestamp
|
||||
CreatedAt string `json:"createdAt" cborgen:"createdAt"`
|
||||
// manifest: AT-URI of the manifest this tag points to (e.g., 'at://did:plc:xyz/io.atcr.manifest/abc123'). Preferred over manifestDigest for new records.
|
||||
Manifest *string `json:"manifest,omitempty" cborgen:"manifest,omitempty"`
|
||||
// manifestDigest: DEPRECATED: Digest of the manifest (e.g., 'sha256:...'). Kept for backward compatibility with old records. New records should use 'manifest' field instead.
|
||||
ManifestDigest *string `json:"manifestDigest,omitempty" cborgen:"manifestDigest,omitempty"`
|
||||
// repository: Repository name (e.g., 'myapp'). Scoped to user's DID.
|
||||
Repository string `json:"repository" cborgen:"repository"`
|
||||
// tag: Tag name (e.g., 'latest', 'v1.0.0', '12-slim')
|
||||
Tag string `json:"tag" cborgen:"tag"`
|
||||
}
|
||||
@@ -2,14 +2,10 @@
|
||||
// Service tokens are JWTs issued by a user's PDS to authorize AppView to
|
||||
// act on their behalf when communicating with hold services. Tokens are
|
||||
// cached with automatic expiry parsing and 10-second safety margins.
|
||||
package token
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -18,6 +14,8 @@ import (
|
||||
type serviceTokenEntry struct {
|
||||
token string
|
||||
expiresAt time.Time
|
||||
err error
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// Global cache for service tokens (DID:HoldDID -> token)
|
||||
@@ -61,7 +59,7 @@ func SetServiceToken(did, holdDID, token string) error {
|
||||
cacheKey := did + ":" + holdDID
|
||||
|
||||
// Parse JWT to extract expiry (don't verify signature - we trust the PDS)
|
||||
expiry, err := parseJWTExpiry(token)
|
||||
expiry, err := ParseJWTExpiry(token)
|
||||
if err != nil {
|
||||
// If parsing fails, use default 50s TTL (conservative fallback)
|
||||
slog.Warn("Failed to parse JWT expiry, using default 50s", "error", err, "cacheKey", cacheKey)
|
||||
@@ -85,37 +83,6 @@ func SetServiceToken(did, holdDID, token string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseJWTExpiry extracts the expiry time from a JWT without verifying the signature
|
||||
// We trust tokens from the user's PDS, so signature verification isn't needed here
|
||||
// Manually decodes the JWT payload to avoid algorithm compatibility issues
|
||||
func parseJWTExpiry(tokenString string) (time.Time, error) {
|
||||
// JWT format: header.payload.signature
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return time.Time{}, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Decode the payload (second part)
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
}
|
||||
|
||||
// Parse the JSON payload
|
||||
var claims struct {
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return time.Time{}, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
if claims.Exp == 0 {
|
||||
return time.Time{}, fmt.Errorf("JWT missing exp claim")
|
||||
}
|
||||
|
||||
return time.Unix(claims.Exp, 0), nil
|
||||
}
|
||||
|
||||
// InvalidateServiceToken removes a service token from the cache
|
||||
// Used when we detect that a token is invalid or the user's session has expired
|
||||
func InvalidateServiceToken(did, holdDID string) {
|
||||
@@ -1,4 +1,4 @@
|
||||
package token
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -21,7 +21,7 @@ type HoldAuthorizer interface {
|
||||
|
||||
// GetCaptainRecord retrieves the captain record for a hold
|
||||
// Used to check public flag and allowAllCrew settings
|
||||
GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.HoldCaptain, error)
|
||||
GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error)
|
||||
|
||||
// IsCrewMember checks if userDID is a crew member of holdDID
|
||||
IsCrewMember(ctx context.Context, holdDID, userDID string) (bool, error)
|
||||
@@ -32,7 +32,7 @@ type HoldAuthorizer interface {
|
||||
// Read access rules:
|
||||
// - Public hold: allow anyone (even anonymous)
|
||||
// - Private hold: require authentication (any authenticated user)
|
||||
func CheckReadAccessWithCaptain(captain *atproto.HoldCaptain, userDID string) bool {
|
||||
func CheckReadAccessWithCaptain(captain *atproto.CaptainRecord, userDID string) bool {
|
||||
if captain.Public {
|
||||
// Public hold - allow anyone (even anonymous)
|
||||
return true
|
||||
@@ -55,7 +55,7 @@ func CheckReadAccessWithCaptain(captain *atproto.HoldCaptain, userDID string) bo
|
||||
// Write access rules:
|
||||
// - Must be authenticated
|
||||
// - Must be hold owner OR crew member
|
||||
func CheckWriteAccessWithCaptain(captain *atproto.HoldCaptain, userDID string, isCrew bool) bool {
|
||||
func CheckWriteAccessWithCaptain(captain *atproto.CaptainRecord, userDID string, isCrew bool) bool {
|
||||
slog.Debug("Checking write access", "userDID", userDID, "owner", captain.Owner, "isCrew", isCrew)
|
||||
|
||||
if userDID == "" {
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
func TestCheckReadAccessWithCaptain_PublicHold(t *testing.T) {
|
||||
captain := &atproto.HoldCaptain{
|
||||
captain := &atproto.CaptainRecord{
|
||||
Public: true,
|
||||
Owner: "did:plc:owner123",
|
||||
}
|
||||
@@ -26,7 +26,7 @@ func TestCheckReadAccessWithCaptain_PublicHold(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCheckReadAccessWithCaptain_PrivateHold(t *testing.T) {
|
||||
captain := &atproto.HoldCaptain{
|
||||
captain := &atproto.CaptainRecord{
|
||||
Public: false,
|
||||
Owner: "did:plc:owner123",
|
||||
}
|
||||
@@ -45,7 +45,7 @@ func TestCheckReadAccessWithCaptain_PrivateHold(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCheckWriteAccessWithCaptain_Owner(t *testing.T) {
|
||||
captain := &atproto.HoldCaptain{
|
||||
captain := &atproto.CaptainRecord{
|
||||
Public: false,
|
||||
Owner: "did:plc:owner123",
|
||||
}
|
||||
@@ -58,7 +58,7 @@ func TestCheckWriteAccessWithCaptain_Owner(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCheckWriteAccessWithCaptain_Crew(t *testing.T) {
|
||||
captain := &atproto.HoldCaptain{
|
||||
captain := &atproto.CaptainRecord{
|
||||
Public: false,
|
||||
Owner: "did:plc:owner123",
|
||||
}
|
||||
@@ -77,7 +77,7 @@ func TestCheckWriteAccessWithCaptain_Crew(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCheckWriteAccessWithCaptain_Anonymous(t *testing.T) {
|
||||
captain := &atproto.HoldCaptain{
|
||||
captain := &atproto.CaptainRecord{
|
||||
Public: false,
|
||||
Owner: "did:plc:owner123",
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func NewLocalHoldAuthorizerFromInterface(holdPDS any) HoldAuthorizer {
|
||||
}
|
||||
|
||||
// GetCaptainRecord retrieves the captain record from the hold's PDS
|
||||
func (a *LocalHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.HoldCaptain, error) {
|
||||
func (a *LocalHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error) {
|
||||
// Verify that the requested holdDID matches this hold
|
||||
if holdDID != a.pds.DID() {
|
||||
return nil, fmt.Errorf("holdDID mismatch: requested %s, this hold is %s", holdDID, a.pds.DID())
|
||||
@@ -47,7 +47,7 @@ func (a *LocalHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID stri
|
||||
return nil, fmt.Errorf("failed to get captain record: %w", err)
|
||||
}
|
||||
|
||||
// The PDS returns *atproto.HoldCaptain directly
|
||||
// The PDS returns *atproto.CaptainRecord directly now (after we update pds to use atproto types)
|
||||
return pdsCaptain, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -101,14 +101,14 @@ func (a *RemoteHoldAuthorizer) cleanupRecentDenials() {
|
||||
// 1. Check database cache
|
||||
// 2. If cache miss or expired, query hold's XRPC endpoint
|
||||
// 3. Update cache
|
||||
func (a *RemoteHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.HoldCaptain, error) {
|
||||
func (a *RemoteHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error) {
|
||||
// Try cache first
|
||||
if a.db != nil {
|
||||
cached, err := a.getCachedCaptainRecord(holdDID)
|
||||
if err == nil && cached != nil {
|
||||
// Cache hit - check if still valid
|
||||
if time.Since(cached.UpdatedAt) < a.cacheTTL {
|
||||
return cached.HoldCaptain, nil
|
||||
return cached.CaptainRecord, nil
|
||||
}
|
||||
// Cache expired - continue to fetch fresh data
|
||||
}
|
||||
@@ -133,7 +133,7 @@ func (a *RemoteHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID str
|
||||
|
||||
// captainRecordWithMeta includes UpdatedAt for cache management
|
||||
type captainRecordWithMeta struct {
|
||||
*atproto.HoldCaptain
|
||||
*atproto.CaptainRecord
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
@@ -145,7 +145,7 @@ func (a *RemoteHoldAuthorizer) getCachedCaptainRecord(holdDID string) (*captainR
|
||||
WHERE hold_did = ?
|
||||
`
|
||||
|
||||
var record atproto.HoldCaptain
|
||||
var record atproto.CaptainRecord
|
||||
var deployedAt, region, provider sql.NullString
|
||||
var updatedAt time.Time
|
||||
|
||||
@@ -172,20 +172,20 @@ func (a *RemoteHoldAuthorizer) getCachedCaptainRecord(holdDID string) (*captainR
|
||||
record.DeployedAt = deployedAt.String
|
||||
}
|
||||
if region.Valid {
|
||||
record.Region = ®ion.String
|
||||
record.Region = region.String
|
||||
}
|
||||
if provider.Valid {
|
||||
record.Provider = &provider.String
|
||||
record.Provider = provider.String
|
||||
}
|
||||
|
||||
return &captainRecordWithMeta{
|
||||
HoldCaptain: &record,
|
||||
UpdatedAt: updatedAt,
|
||||
CaptainRecord: &record,
|
||||
UpdatedAt: updatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// setCachedCaptainRecord stores a captain record in database cache
|
||||
func (a *RemoteHoldAuthorizer) setCachedCaptainRecord(holdDID string, record *atproto.HoldCaptain) error {
|
||||
func (a *RemoteHoldAuthorizer) setCachedCaptainRecord(holdDID string, record *atproto.CaptainRecord) error {
|
||||
query := `
|
||||
INSERT INTO hold_captain_records (
|
||||
hold_did, owner_did, public, allow_all_crew,
|
||||
@@ -207,8 +207,8 @@ func (a *RemoteHoldAuthorizer) setCachedCaptainRecord(holdDID string, record *at
|
||||
record.Public,
|
||||
record.AllowAllCrew,
|
||||
nullString(record.DeployedAt),
|
||||
nullStringPtr(record.Region),
|
||||
nullStringPtr(record.Provider),
|
||||
nullString(record.Region),
|
||||
nullString(record.Provider),
|
||||
time.Now(),
|
||||
)
|
||||
|
||||
@@ -216,7 +216,7 @@ func (a *RemoteHoldAuthorizer) setCachedCaptainRecord(holdDID string, record *at
|
||||
}
|
||||
|
||||
// fetchCaptainRecordFromXRPC queries the hold's XRPC endpoint for captain record
|
||||
func (a *RemoteHoldAuthorizer) fetchCaptainRecordFromXRPC(ctx context.Context, holdDID string) (*atproto.HoldCaptain, error) {
|
||||
func (a *RemoteHoldAuthorizer) fetchCaptainRecordFromXRPC(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error) {
|
||||
// Resolve DID to URL
|
||||
holdURL := atproto.ResolveHoldURL(holdDID)
|
||||
|
||||
@@ -261,20 +261,14 @@ func (a *RemoteHoldAuthorizer) fetchCaptainRecordFromXRPC(ctx context.Context, h
|
||||
}
|
||||
|
||||
// Convert to our type
|
||||
record := &atproto.HoldCaptain{
|
||||
LexiconTypeID: atproto.CaptainCollection,
|
||||
Owner: xrpcResp.Value.Owner,
|
||||
Public: xrpcResp.Value.Public,
|
||||
AllowAllCrew: xrpcResp.Value.AllowAllCrew,
|
||||
DeployedAt: xrpcResp.Value.DeployedAt,
|
||||
}
|
||||
|
||||
// Handle optional pointer fields
|
||||
if xrpcResp.Value.Region != "" {
|
||||
record.Region = &xrpcResp.Value.Region
|
||||
}
|
||||
if xrpcResp.Value.Provider != "" {
|
||||
record.Provider = &xrpcResp.Value.Provider
|
||||
record := &atproto.CaptainRecord{
|
||||
Type: atproto.CaptainCollection,
|
||||
Owner: xrpcResp.Value.Owner,
|
||||
Public: xrpcResp.Value.Public,
|
||||
AllowAllCrew: xrpcResp.Value.AllowAllCrew,
|
||||
DeployedAt: xrpcResp.Value.DeployedAt,
|
||||
Region: xrpcResp.Value.Region,
|
||||
Provider: xrpcResp.Value.Provider,
|
||||
}
|
||||
|
||||
return record, nil
|
||||
@@ -414,14 +408,6 @@ func nullString(s string) sql.NullString {
|
||||
return sql.NullString{String: s, Valid: true}
|
||||
}
|
||||
|
||||
// nullStringPtr converts a *string to sql.NullString
|
||||
func nullStringPtr(s *string) sql.NullString {
|
||||
if s == nil || *s == "" {
|
||||
return sql.NullString{Valid: false}
|
||||
}
|
||||
return sql.NullString{String: *s, Valid: true}
|
||||
}
|
||||
|
||||
// getCachedApproval checks if user has a cached crew approval
|
||||
func (a *RemoteHoldAuthorizer) getCachedApproval(holdDID, userDID string) (bool, error) {
|
||||
query := `
|
||||
|
||||
@@ -14,11 +14,6 @@ import (
|
||||
"atcr.io/pkg/atproto"
|
||||
)
|
||||
|
||||
// ptrString returns a pointer to the given string
|
||||
func ptrString(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func TestNewRemoteHoldAuthorizer(t *testing.T) {
|
||||
// Test with nil database (should still work)
|
||||
authorizer := NewRemoteHoldAuthorizer(nil, false)
|
||||
@@ -138,14 +133,14 @@ func TestGetCaptainRecord_CacheHit(t *testing.T) {
|
||||
holdDID := "did:web:hold01.atcr.io"
|
||||
|
||||
// Pre-populate cache with a captain record
|
||||
captainRecord := &atproto.HoldCaptain{
|
||||
LexiconTypeID: atproto.CaptainCollection,
|
||||
Owner: "did:plc:owner123",
|
||||
Public: true,
|
||||
AllowAllCrew: false,
|
||||
DeployedAt: "2025-10-28T00:00:00Z",
|
||||
Region: ptrString("us-east-1"),
|
||||
Provider: ptrString("fly.io"),
|
||||
captainRecord := &atproto.CaptainRecord{
|
||||
Type: atproto.CaptainCollection,
|
||||
Owner: "did:plc:owner123",
|
||||
Public: true,
|
||||
AllowAllCrew: false,
|
||||
DeployedAt: "2025-10-28T00:00:00Z",
|
||||
Region: "us-east-1",
|
||||
Provider: "fly.io",
|
||||
}
|
||||
|
||||
err := remote.setCachedCaptainRecord(holdDID, captainRecord)
|
||||
|
||||
80
pkg/auth/mock_authorizer.go
Normal file
80
pkg/auth/mock_authorizer.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
)
|
||||
|
||||
// MockHoldAuthorizer is a test double for HoldAuthorizer.
|
||||
// It allows tests to control the return values of authorization checks
|
||||
// without making network calls or querying a real PDS.
|
||||
type MockHoldAuthorizer struct {
|
||||
// Direct result control
|
||||
CanReadResult bool
|
||||
CanWriteResult bool
|
||||
CanAdminResult bool
|
||||
Error error
|
||||
|
||||
// Captain record to return (optional, for GetCaptainRecord)
|
||||
CaptainRecord *atproto.CaptainRecord
|
||||
|
||||
// Crew membership (optional, for IsCrewMember)
|
||||
IsCrewResult bool
|
||||
}
|
||||
|
||||
// NewMockHoldAuthorizer creates a MockHoldAuthorizer with sensible defaults.
|
||||
// By default, it allows all access (public hold, user is owner).
|
||||
func NewMockHoldAuthorizer() *MockHoldAuthorizer {
|
||||
return &MockHoldAuthorizer{
|
||||
CanReadResult: true,
|
||||
CanWriteResult: true,
|
||||
CanAdminResult: false,
|
||||
IsCrewResult: false,
|
||||
CaptainRecord: &atproto.CaptainRecord{
|
||||
Type: "io.atcr.hold.captain",
|
||||
Owner: "did:plc:mock-owner",
|
||||
Public: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CheckReadAccess returns the configured CanReadResult.
|
||||
func (m *MockHoldAuthorizer) CheckReadAccess(ctx context.Context, holdDID, userDID string) (bool, error) {
|
||||
if m.Error != nil {
|
||||
return false, m.Error
|
||||
}
|
||||
return m.CanReadResult, nil
|
||||
}
|
||||
|
||||
// CheckWriteAccess returns the configured CanWriteResult.
|
||||
func (m *MockHoldAuthorizer) CheckWriteAccess(ctx context.Context, holdDID, userDID string) (bool, error) {
|
||||
if m.Error != nil {
|
||||
return false, m.Error
|
||||
}
|
||||
return m.CanWriteResult, nil
|
||||
}
|
||||
|
||||
// GetCaptainRecord returns the configured CaptainRecord or a default.
|
||||
func (m *MockHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error) {
|
||||
if m.Error != nil {
|
||||
return nil, m.Error
|
||||
}
|
||||
if m.CaptainRecord != nil {
|
||||
return m.CaptainRecord, nil
|
||||
}
|
||||
// Return a default captain record
|
||||
return &atproto.CaptainRecord{
|
||||
Type: "io.atcr.hold.captain",
|
||||
Owner: "did:plc:mock-owner",
|
||||
Public: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IsCrewMember returns the configured IsCrewResult.
|
||||
func (m *MockHoldAuthorizer) IsCrewMember(ctx context.Context, holdDID, userDID string) (bool, error) {
|
||||
if m.Error != nil {
|
||||
return false, m.Error
|
||||
}
|
||||
return m.IsCrewResult, nil
|
||||
}
|
||||
@@ -72,11 +72,19 @@ func RedirectURI(baseURL string) string {
|
||||
return baseURL + "/auth/oauth/callback"
|
||||
}
|
||||
|
||||
// GetDefaultScopes returns the default OAuth scopes for ATCR registry operations
|
||||
// testMode determines whether to use transition:generic (test) or rpc scopes (production)
|
||||
// GetDefaultScopes returns the default OAuth scopes for ATCR registry operations.
|
||||
// Includes io.atcr.authFullApp permission-set plus individual scopes for PDS compatibility.
|
||||
// Blob scopes are listed explicitly (not supported in Lexicon permission-sets).
|
||||
func GetDefaultScopes(did string) []string {
|
||||
scopes := []string{
|
||||
return []string{
|
||||
"atproto",
|
||||
// Permission-set (for future PDS support)
|
||||
// See lexicons/io/atcr/authFullApp.json for definition
|
||||
// Uses "include:" prefix per ATProto permission spec
|
||||
"include:io.atcr.authFullApp",
|
||||
// com.atproto scopes must be separate (permission-sets are namespace-limited)
|
||||
"rpc:com.atproto.repo.getRecord?aud=*",
|
||||
// Blob scopes (not supported in Lexicon permission-sets)
|
||||
// Image manifest types (single-arch)
|
||||
"blob:application/vnd.oci.image.manifest.v1+json",
|
||||
"blob:application/vnd.docker.distribution.manifest.v2+json",
|
||||
@@ -85,19 +93,9 @@ func GetDefaultScopes(did string) []string {
|
||||
"blob:application/vnd.docker.distribution.manifest.list.v2+json",
|
||||
// OCI artifact manifests (for cosign signatures, SBOMs, attestations)
|
||||
"blob:application/vnd.cncf.oras.artifact.manifest.v1+json",
|
||||
// Used for service token validation on holds
|
||||
"rpc:com.atproto.repo.getRecord?aud=*",
|
||||
// Image avatars
|
||||
"blob:image/*",
|
||||
}
|
||||
|
||||
// Add repo scopes
|
||||
scopes = append(scopes,
|
||||
fmt.Sprintf("repo:%s", atproto.ManifestCollection),
|
||||
fmt.Sprintf("repo:%s", atproto.TagCollection),
|
||||
fmt.Sprintf("repo:%s", atproto.StarCollection),
|
||||
fmt.Sprintf("repo:%s", atproto.SailorProfileCollection),
|
||||
)
|
||||
|
||||
return scopes
|
||||
}
|
||||
|
||||
// ScopesMatch checks if two scope lists are equivalent (order-independent)
|
||||
@@ -225,6 +223,18 @@ func (r *Refresher) DoWithSession(ctx context.Context, did string, fn func(sessi
|
||||
// The session's PersistSessionCallback will save nonce updates to DB
|
||||
err = fn(session)
|
||||
|
||||
// If request failed with auth error, delete session to force re-auth
|
||||
if err != nil && isAuthError(err) {
|
||||
slog.Warn("Auth error detected, deleting session to force re-auth",
|
||||
"component", "oauth/refresher",
|
||||
"did", did,
|
||||
"error", err)
|
||||
// Don't hold the lock while deleting - release first
|
||||
mutex.Unlock()
|
||||
_ = r.DeleteSession(ctx, did)
|
||||
mutex.Lock() // Re-acquire for the deferred unlock
|
||||
}
|
||||
|
||||
slog.Debug("Released session lock for DoWithSession",
|
||||
"component", "oauth/refresher",
|
||||
"did", did,
|
||||
@@ -233,6 +243,19 @@ func (r *Refresher) DoWithSession(ctx context.Context, did string, fn func(sessi
|
||||
return err
|
||||
}
|
||||
|
||||
// isAuthError checks if an error looks like an OAuth/auth failure
|
||||
func isAuthError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
errStr := strings.ToLower(err.Error())
|
||||
return strings.Contains(errStr, "unauthorized") ||
|
||||
strings.Contains(errStr, "invalid_token") ||
|
||||
strings.Contains(errStr, "insufficient_scope") ||
|
||||
strings.Contains(errStr, "token expired") ||
|
||||
strings.Contains(errStr, "401")
|
||||
}
|
||||
|
||||
// resumeSession loads a session from storage
|
||||
func (r *Refresher) resumeSession(ctx context.Context, did string) (*oauth.ClientSession, error) {
|
||||
// Parse DID
|
||||
@@ -257,28 +280,15 @@ func (r *Refresher) resumeSession(ctx context.Context, did string) (*oauth.Clien
|
||||
return nil, fmt.Errorf("no session found for DID: %s", did)
|
||||
}
|
||||
|
||||
// Validate that session scopes match current desired scopes
|
||||
// Log scope differences for debugging, but don't delete session
|
||||
// The PDS will reject requests if scopes are insufficient
|
||||
// (Permission-sets get expanded by PDS, so exact matching doesn't work)
|
||||
desiredScopes := r.clientApp.Config.Scopes
|
||||
if !ScopesMatch(sessionData.Scopes, desiredScopes) {
|
||||
slog.Debug("Scope mismatch, deleting session",
|
||||
slog.Debug("Session scopes differ from desired (may be permission-set expansion)",
|
||||
"did", did,
|
||||
"storedScopes", sessionData.Scopes,
|
||||
"desiredScopes", desiredScopes)
|
||||
|
||||
// Delete the session from database since scopes have changed
|
||||
if err := r.clientApp.Store.DeleteSession(ctx, accountDID, sessionID); err != nil {
|
||||
slog.Warn("Failed to delete session with mismatched scopes", "error", err, "did", did)
|
||||
}
|
||||
|
||||
// Also invalidate UI sessions since OAuth is now invalid
|
||||
if r.uiSessionStore != nil {
|
||||
r.uiSessionStore.DeleteByDID(did)
|
||||
slog.Info("Invalidated UI sessions due to scope mismatch",
|
||||
"component", "oauth/refresher",
|
||||
"did", did)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("OAuth scopes changed, re-authentication required")
|
||||
}
|
||||
|
||||
// Resume session
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"github.com/bluesky-social/indigo/atproto/auth/oauth"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewClientApp(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
keyPath := tmpDir + "/oauth-key.bin"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
keyPath := t.TempDir() + "/oauth-key.bin"
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
baseURL := "http://localhost:5000"
|
||||
scopes := GetDefaultScopes("*")
|
||||
@@ -32,14 +27,8 @@ func TestNewClientApp(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewClientAppWithCustomScopes(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
keyPath := tmpDir + "/oauth-key.bin"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
keyPath := t.TempDir() + "/oauth-key.bin"
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
baseURL := "http://localhost:5000"
|
||||
scopes := []string{"atproto", "custom:scope"}
|
||||
@@ -128,13 +117,7 @@ func TestScopesMatch(t *testing.T) {
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestNewRefresher(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
scopes := GetDefaultScopes("*")
|
||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
||||
@@ -153,13 +136,7 @@ func TestNewRefresher(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRefresher_SetUISessionStore(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
scopes := GetDefaultScopes("*")
|
||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
||||
|
||||
@@ -26,11 +26,7 @@ func InteractiveFlowWithCallback(
|
||||
registerCallback func(handler http.HandlerFunc) error,
|
||||
displayAuthURL func(string) error,
|
||||
) (*InteractiveResult, error) {
|
||||
// Create temporary file store for this flow
|
||||
store, err := NewFileStore("/tmp/atcr-oauth-temp.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OAuth store: %w", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
// Create OAuth client app with custom scopes (or defaults if nil)
|
||||
// Interactive flows are typically for production use (credential helper, etc.)
|
||||
|
||||
@@ -2,6 +2,7 @@ package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/bluesky-social/indigo/atproto/auth/oauth"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -11,13 +12,7 @@ import (
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
// Create a basic OAuth app for testing
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
scopes := GetDefaultScopes("*")
|
||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
||||
@@ -36,13 +31,7 @@ func TestNewServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_SetRefresher(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
scopes := GetDefaultScopes("*")
|
||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
||||
@@ -60,13 +49,7 @@ func TestServer_SetRefresher(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_SetPostAuthCallback(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
scopes := GetDefaultScopes("*")
|
||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
||||
@@ -87,13 +70,7 @@ func TestServer_SetPostAuthCallback(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_SetUISessionStore(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
scopes := GetDefaultScopes("*")
|
||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
||||
@@ -151,13 +128,7 @@ func (m *mockRefresher) InvalidateSession(did string) {
|
||||
// ServeAuthorize tests
|
||||
|
||||
func TestServer_ServeAuthorize_MissingHandle(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
scopes := GetDefaultScopes("*")
|
||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
||||
@@ -179,13 +150,7 @@ func TestServer_ServeAuthorize_MissingHandle(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_ServeAuthorize_InvalidMethod(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
scopes := GetDefaultScopes("*")
|
||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
||||
@@ -209,13 +174,7 @@ func TestServer_ServeAuthorize_InvalidMethod(t *testing.T) {
|
||||
// ServeCallback tests
|
||||
|
||||
func TestServer_ServeCallback_InvalidMethod(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
scopes := GetDefaultScopes("*")
|
||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
||||
@@ -237,13 +196,7 @@ func TestServer_ServeCallback_InvalidMethod(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_ServeCallback_OAuthError(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
scopes := GetDefaultScopes("*")
|
||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
||||
@@ -270,13 +223,7 @@ func TestServer_ServeCallback_OAuthError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_ServeCallback_WithPostAuthCallback(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
scopes := GetDefaultScopes("*")
|
||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
||||
@@ -315,13 +262,7 @@ func TestServer_ServeCallback_UIFlow_SessionCreationLogic(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
scopes := GetDefaultScopes("*")
|
||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
||||
@@ -345,13 +286,7 @@ func TestServer_ServeCallback_UIFlow_SessionCreationLogic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_RenderError(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
scopes := GetDefaultScopes("*")
|
||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
||||
@@ -380,13 +315,7 @@ func TestServer_RenderError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_RenderRedirectToSettings(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
store := oauth.NewMemStore()
|
||||
|
||||
scopes := GetDefaultScopes("*")
|
||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
||||
|
||||
@@ -1,236 +0,0 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bluesky-social/indigo/atproto/auth/oauth"
|
||||
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||
)
|
||||
|
||||
// FileStore implements oauth.ClientAuthStore with file-based persistence
|
||||
type FileStore struct {
|
||||
path string
|
||||
sessions map[string]*oauth.ClientSessionData // Key: "did:sessionID"
|
||||
requests map[string]*oauth.AuthRequestData // Key: state
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// FileStoreData represents the JSON structure stored on disk
|
||||
type FileStoreData struct {
|
||||
Sessions map[string]*oauth.ClientSessionData `json:"sessions"`
|
||||
Requests map[string]*oauth.AuthRequestData `json:"requests"`
|
||||
}
|
||||
|
||||
// NewFileStore creates a new file-based OAuth store
|
||||
func NewFileStore(path string) (*FileStore, error) {
|
||||
store := &FileStore{
|
||||
path: path,
|
||||
sessions: make(map[string]*oauth.ClientSessionData),
|
||||
requests: make(map[string]*oauth.AuthRequestData),
|
||||
}
|
||||
|
||||
// Load existing data if file exists
|
||||
if err := store.load(); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("failed to load store: %w", err)
|
||||
}
|
||||
// File doesn't exist yet, that's ok
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// GetDefaultStorePath returns the default storage path for OAuth data
|
||||
func GetDefaultStorePath() (string, error) {
|
||||
// For AppView: /var/lib/atcr/oauth-sessions.json
|
||||
// For CLI tools: ~/.atcr/oauth-sessions.json
|
||||
|
||||
// Check if running as a service (has write access to /var/lib)
|
||||
servicePath := "/var/lib/atcr/oauth-sessions.json"
|
||||
if err := os.MkdirAll(filepath.Dir(servicePath), 0700); err == nil {
|
||||
// Can write to /var/lib, use service path
|
||||
return servicePath, nil
|
||||
}
|
||||
|
||||
// Fall back to user home directory
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
|
||||
atcrDir := filepath.Join(homeDir, ".atcr")
|
||||
if err := os.MkdirAll(atcrDir, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to create .atcr directory: %w", err)
|
||||
}
|
||||
|
||||
return filepath.Join(atcrDir, "oauth-sessions.json"), nil
|
||||
}
|
||||
|
||||
// GetSession retrieves a session by DID and session ID
|
||||
func (s *FileStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
key := makeSessionKey(did.String(), sessionID)
|
||||
session, ok := s.sessions[key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found: %s/%s", did, sessionID)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// SaveSession saves or updates a session (upsert)
|
||||
func (s *FileStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
key := makeSessionKey(sess.AccountDID.String(), sess.SessionID)
|
||||
s.sessions[key] = &sess
|
||||
|
||||
return s.save()
|
||||
}
|
||||
|
||||
// DeleteSession removes a session
|
||||
func (s *FileStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
key := makeSessionKey(did.String(), sessionID)
|
||||
delete(s.sessions, key)
|
||||
|
||||
return s.save()
|
||||
}
|
||||
|
||||
// GetAuthRequestInfo retrieves authentication request data by state
|
||||
func (s *FileStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
request, ok := s.requests[state]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("auth request not found: %s", state)
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
// SaveAuthRequestInfo saves authentication request data
|
||||
func (s *FileStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.requests[info.State] = &info
|
||||
|
||||
return s.save()
|
||||
}
|
||||
|
||||
// DeleteAuthRequestInfo removes authentication request data
|
||||
func (s *FileStore) DeleteAuthRequestInfo(ctx context.Context, state string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.requests, state)
|
||||
|
||||
return s.save()
|
||||
}
|
||||
|
||||
// CleanupExpired removes expired sessions and auth requests
|
||||
// Should be called periodically (e.g., every hour)
|
||||
func (s *FileStore) CleanupExpired() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
modified := false
|
||||
|
||||
// Clean up auth requests older than 10 minutes
|
||||
// (OAuth flows should complete quickly)
|
||||
for state := range s.requests {
|
||||
// Note: AuthRequestData doesn't have a timestamp in indigo's implementation
|
||||
// For now, we'll rely on the OAuth server's cleanup routine
|
||||
// or we could extend AuthRequestData with metadata
|
||||
_ = state // Placeholder for future expiration logic
|
||||
}
|
||||
|
||||
// Sessions don't have expiry in the data structure
|
||||
// Cleanup would need to be token-based (check token expiry)
|
||||
// For now, manual cleanup via DeleteSession
|
||||
_ = now
|
||||
|
||||
if modified {
|
||||
return s.save()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListSessions returns all stored sessions for debugging/management
|
||||
func (s *FileStore) ListSessions() map[string]*oauth.ClientSessionData {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Return a copy to prevent external modification
|
||||
result := make(map[string]*oauth.ClientSessionData)
|
||||
maps.Copy(result, s.sessions)
|
||||
return result
|
||||
}
|
||||
|
||||
// load reads data from disk
|
||||
func (s *FileStore) load() error {
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var storeData FileStoreData
|
||||
if err := json.Unmarshal(data, &storeData); err != nil {
|
||||
return fmt.Errorf("failed to parse store: %w", err)
|
||||
}
|
||||
|
||||
if storeData.Sessions != nil {
|
||||
s.sessions = storeData.Sessions
|
||||
}
|
||||
if storeData.Requests != nil {
|
||||
s.requests = storeData.Requests
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// save writes data to disk
|
||||
func (s *FileStore) save() error {
|
||||
storeData := FileStoreData{
|
||||
Sessions: s.sessions,
|
||||
Requests: s.requests,
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(storeData, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal store: %w", err)
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(s.path), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
// Write with restrictive permissions
|
||||
if err := os.WriteFile(s.path, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write store: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// makeSessionKey creates a composite key for session storage
|
||||
func makeSessionKey(did, sessionID string) string {
|
||||
return fmt.Sprintf("%s:%s", did, sessionID)
|
||||
}
|
||||
@@ -1,631 +0,0 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bluesky-social/indigo/atproto/auth/oauth"
|
||||
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||
)
|
||||
|
||||
func TestNewFileStore(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
if store == nil {
|
||||
t.Fatal("Expected non-nil store")
|
||||
}
|
||||
|
||||
if store.path != storePath {
|
||||
t.Errorf("Expected path %q, got %q", storePath, store.path)
|
||||
}
|
||||
|
||||
if store.sessions == nil {
|
||||
t.Error("Expected sessions map to be initialized")
|
||||
}
|
||||
|
||||
if store.requests == nil {
|
||||
t.Error("Expected requests map to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_LoadNonExistent(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/nonexistent.json"
|
||||
|
||||
// Should succeed even if file doesn't exist
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() should succeed with non-existent file, got error: %v", err)
|
||||
}
|
||||
|
||||
if store == nil {
|
||||
t.Fatal("Expected non-nil store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_LoadCorruptedFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/corrupted.json"
|
||||
|
||||
// Create corrupted JSON file
|
||||
if err := os.WriteFile(storePath, []byte("invalid json {{{"), 0600); err != nil {
|
||||
t.Fatalf("Failed to create corrupted file: %v", err)
|
||||
}
|
||||
|
||||
// Should fail to load corrupted file
|
||||
_, err := NewFileStore(storePath)
|
||||
if err == nil {
|
||||
t.Error("Expected error when loading corrupted file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_GetSession_NotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
did, _ := syntax.ParseDID("did:plc:test123")
|
||||
sessionID := "session123"
|
||||
|
||||
// Should return error for non-existent session
|
||||
session, err := store.GetSession(ctx, did, sessionID)
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent session")
|
||||
}
|
||||
if session != nil {
|
||||
t.Error("Expected nil session for non-existent entry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_SaveAndGetSession(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
did, _ := syntax.ParseDID("did:plc:alice123")
|
||||
|
||||
// Create test session
|
||||
sessionData := oauth.ClientSessionData{
|
||||
AccountDID: did,
|
||||
SessionID: "test-session-123",
|
||||
HostURL: "https://pds.example.com",
|
||||
Scopes: []string{"atproto", "blob:read"},
|
||||
}
|
||||
|
||||
// Save session
|
||||
if err := store.SaveSession(ctx, sessionData); err != nil {
|
||||
t.Fatalf("SaveSession() error = %v", err)
|
||||
}
|
||||
|
||||
// Retrieve session
|
||||
retrieved, err := store.GetSession(ctx, did, "test-session-123")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession() error = %v", err)
|
||||
}
|
||||
|
||||
if retrieved == nil {
|
||||
t.Fatal("Expected non-nil session")
|
||||
}
|
||||
|
||||
if retrieved.SessionID != sessionData.SessionID {
|
||||
t.Errorf("Expected sessionID %q, got %q", sessionData.SessionID, retrieved.SessionID)
|
||||
}
|
||||
|
||||
if retrieved.AccountDID.String() != did.String() {
|
||||
t.Errorf("Expected DID %q, got %q", did.String(), retrieved.AccountDID.String())
|
||||
}
|
||||
|
||||
if retrieved.HostURL != sessionData.HostURL {
|
||||
t.Errorf("Expected hostURL %q, got %q", sessionData.HostURL, retrieved.HostURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_UpdateSession(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
did, _ := syntax.ParseDID("did:plc:alice123")
|
||||
|
||||
// Save initial session
|
||||
sessionData := oauth.ClientSessionData{
|
||||
AccountDID: did,
|
||||
SessionID: "test-session-123",
|
||||
HostURL: "https://pds.example.com",
|
||||
Scopes: []string{"atproto"},
|
||||
}
|
||||
|
||||
if err := store.SaveSession(ctx, sessionData); err != nil {
|
||||
t.Fatalf("SaveSession() error = %v", err)
|
||||
}
|
||||
|
||||
// Update session with new scopes
|
||||
sessionData.Scopes = []string{"atproto", "blob:read", "blob:write"}
|
||||
if err := store.SaveSession(ctx, sessionData); err != nil {
|
||||
t.Fatalf("SaveSession() (update) error = %v", err)
|
||||
}
|
||||
|
||||
// Retrieve updated session
|
||||
retrieved, err := store.GetSession(ctx, did, "test-session-123")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession() error = %v", err)
|
||||
}
|
||||
|
||||
if len(retrieved.Scopes) != 3 {
|
||||
t.Errorf("Expected 3 scopes, got %d", len(retrieved.Scopes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_DeleteSession(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
did, _ := syntax.ParseDID("did:plc:alice123")
|
||||
|
||||
// Save session
|
||||
sessionData := oauth.ClientSessionData{
|
||||
AccountDID: did,
|
||||
SessionID: "test-session-123",
|
||||
HostURL: "https://pds.example.com",
|
||||
}
|
||||
|
||||
if err := store.SaveSession(ctx, sessionData); err != nil {
|
||||
t.Fatalf("SaveSession() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it exists
|
||||
if _, err := store.GetSession(ctx, did, "test-session-123"); err != nil {
|
||||
t.Fatalf("GetSession() should succeed before delete, got error: %v", err)
|
||||
}
|
||||
|
||||
// Delete session
|
||||
if err := store.DeleteSession(ctx, did, "test-session-123"); err != nil {
|
||||
t.Fatalf("DeleteSession() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it's gone
|
||||
_, err = store.GetSession(ctx, did, "test-session-123")
|
||||
if err == nil {
|
||||
t.Error("Expected error after deleting session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_DeleteNonExistentSession(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
did, _ := syntax.ParseDID("did:plc:alice123")
|
||||
|
||||
// Delete non-existent session should not error
|
||||
if err := store.DeleteSession(ctx, did, "nonexistent"); err != nil {
|
||||
t.Errorf("DeleteSession() on non-existent session should not error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_SaveAndGetAuthRequestInfo(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test auth request
|
||||
did, _ := syntax.ParseDID("did:plc:alice123")
|
||||
authRequest := oauth.AuthRequestData{
|
||||
State: "test-state-123",
|
||||
AuthServerURL: "https://pds.example.com",
|
||||
AccountDID: &did,
|
||||
Scopes: []string{"atproto", "blob:read"},
|
||||
RequestURI: "urn:ietf:params:oauth:request_uri:test123",
|
||||
AuthServerTokenEndpoint: "https://pds.example.com/oauth/token",
|
||||
}
|
||||
|
||||
// Save auth request
|
||||
if err := store.SaveAuthRequestInfo(ctx, authRequest); err != nil {
|
||||
t.Fatalf("SaveAuthRequestInfo() error = %v", err)
|
||||
}
|
||||
|
||||
// Retrieve auth request
|
||||
retrieved, err := store.GetAuthRequestInfo(ctx, "test-state-123")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAuthRequestInfo() error = %v", err)
|
||||
}
|
||||
|
||||
if retrieved == nil {
|
||||
t.Fatal("Expected non-nil auth request")
|
||||
}
|
||||
|
||||
if retrieved.State != authRequest.State {
|
||||
t.Errorf("Expected state %q, got %q", authRequest.State, retrieved.State)
|
||||
}
|
||||
|
||||
if retrieved.AuthServerURL != authRequest.AuthServerURL {
|
||||
t.Errorf("Expected authServerURL %q, got %q", authRequest.AuthServerURL, retrieved.AuthServerURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_GetAuthRequestInfo_NotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Should return error for non-existent request
|
||||
_, err = store.GetAuthRequestInfo(ctx, "nonexistent-state")
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent auth request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_DeleteAuthRequestInfo(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Save auth request
|
||||
authRequest := oauth.AuthRequestData{
|
||||
State: "test-state-123",
|
||||
AuthServerURL: "https://pds.example.com",
|
||||
}
|
||||
|
||||
if err := store.SaveAuthRequestInfo(ctx, authRequest); err != nil {
|
||||
t.Fatalf("SaveAuthRequestInfo() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it exists
|
||||
if _, err := store.GetAuthRequestInfo(ctx, "test-state-123"); err != nil {
|
||||
t.Fatalf("GetAuthRequestInfo() should succeed before delete, got error: %v", err)
|
||||
}
|
||||
|
||||
// Delete auth request
|
||||
if err := store.DeleteAuthRequestInfo(ctx, "test-state-123"); err != nil {
|
||||
t.Fatalf("DeleteAuthRequestInfo() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it's gone
|
||||
_, err = store.GetAuthRequestInfo(ctx, "test-state-123")
|
||||
if err == nil {
|
||||
t.Error("Expected error after deleting auth request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_ListSessions(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initially empty
|
||||
sessions := store.ListSessions()
|
||||
if len(sessions) != 0 {
|
||||
t.Errorf("Expected 0 sessions, got %d", len(sessions))
|
||||
}
|
||||
|
||||
// Add multiple sessions
|
||||
did1, _ := syntax.ParseDID("did:plc:alice123")
|
||||
did2, _ := syntax.ParseDID("did:plc:bob456")
|
||||
|
||||
session1 := oauth.ClientSessionData{
|
||||
AccountDID: did1,
|
||||
SessionID: "session-1",
|
||||
HostURL: "https://pds1.example.com",
|
||||
}
|
||||
|
||||
session2 := oauth.ClientSessionData{
|
||||
AccountDID: did2,
|
||||
SessionID: "session-2",
|
||||
HostURL: "https://pds2.example.com",
|
||||
}
|
||||
|
||||
if err := store.SaveSession(ctx, session1); err != nil {
|
||||
t.Fatalf("SaveSession() error = %v", err)
|
||||
}
|
||||
|
||||
if err := store.SaveSession(ctx, session2); err != nil {
|
||||
t.Fatalf("SaveSession() error = %v", err)
|
||||
}
|
||||
|
||||
// List sessions
|
||||
sessions = store.ListSessions()
|
||||
if len(sessions) != 2 {
|
||||
t.Errorf("Expected 2 sessions, got %d", len(sessions))
|
||||
}
|
||||
|
||||
// Verify we got both sessions
|
||||
key1 := makeSessionKey(did1.String(), "session-1")
|
||||
key2 := makeSessionKey(did2.String(), "session-2")
|
||||
|
||||
if sessions[key1] == nil {
|
||||
t.Error("Expected session1 in list")
|
||||
}
|
||||
|
||||
if sessions[key2] == nil {
|
||||
t.Error("Expected session2 in list")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_Persistence_Across_Instances(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
ctx := context.Background()
|
||||
did, _ := syntax.ParseDID("did:plc:alice123")
|
||||
|
||||
// Create first store and save data
|
||||
store1, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionData := oauth.ClientSessionData{
|
||||
AccountDID: did,
|
||||
SessionID: "persistent-session",
|
||||
HostURL: "https://pds.example.com",
|
||||
}
|
||||
|
||||
if err := store1.SaveSession(ctx, sessionData); err != nil {
|
||||
t.Fatalf("SaveSession() error = %v", err)
|
||||
}
|
||||
|
||||
authRequest := oauth.AuthRequestData{
|
||||
State: "persistent-state",
|
||||
AuthServerURL: "https://pds.example.com",
|
||||
}
|
||||
|
||||
if err := store1.SaveAuthRequestInfo(ctx, authRequest); err != nil {
|
||||
t.Fatalf("SaveAuthRequestInfo() error = %v", err)
|
||||
}
|
||||
|
||||
// Create second store from same file
|
||||
store2, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Second NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify session persisted
|
||||
retrievedSession, err := store2.GetSession(ctx, did, "persistent-session")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession() from second store error = %v", err)
|
||||
}
|
||||
|
||||
if retrievedSession.SessionID != "persistent-session" {
|
||||
t.Errorf("Expected persistent session ID, got %q", retrievedSession.SessionID)
|
||||
}
|
||||
|
||||
// Verify auth request persisted
|
||||
retrievedAuth, err := store2.GetAuthRequestInfo(ctx, "persistent-state")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAuthRequestInfo() from second store error = %v", err)
|
||||
}
|
||||
|
||||
if retrievedAuth.State != "persistent-state" {
|
||||
t.Errorf("Expected persistent state, got %q", retrievedAuth.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_FileSecurity(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
did, _ := syntax.ParseDID("did:plc:alice123")
|
||||
|
||||
// Save some data to trigger file creation
|
||||
sessionData := oauth.ClientSessionData{
|
||||
AccountDID: did,
|
||||
SessionID: "test-session",
|
||||
HostURL: "https://pds.example.com",
|
||||
}
|
||||
|
||||
if err := store.SaveSession(ctx, sessionData); err != nil {
|
||||
t.Fatalf("SaveSession() error = %v", err)
|
||||
}
|
||||
|
||||
// Check file permissions (should be 0600)
|
||||
info, err := os.Stat(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to stat file: %v", err)
|
||||
}
|
||||
|
||||
mode := info.Mode()
|
||||
if mode.Perm() != 0600 {
|
||||
t.Errorf("Expected file permissions 0600, got %o", mode.Perm())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_JSONFormat(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
did, _ := syntax.ParseDID("did:plc:alice123")
|
||||
|
||||
// Save data
|
||||
sessionData := oauth.ClientSessionData{
|
||||
AccountDID: did,
|
||||
SessionID: "test-session",
|
||||
HostURL: "https://pds.example.com",
|
||||
}
|
||||
|
||||
if err := store.SaveSession(ctx, sessionData); err != nil {
|
||||
t.Fatalf("SaveSession() error = %v", err)
|
||||
}
|
||||
|
||||
// Read and verify JSON format
|
||||
data, err := os.ReadFile(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read file: %v", err)
|
||||
}
|
||||
|
||||
var storeData FileStoreData
|
||||
if err := json.Unmarshal(data, &storeData); err != nil {
|
||||
t.Fatalf("Failed to parse JSON: %v", err)
|
||||
}
|
||||
|
||||
if storeData.Sessions == nil {
|
||||
t.Error("Expected sessions in JSON")
|
||||
}
|
||||
|
||||
if storeData.Requests == nil {
|
||||
t.Error("Expected requests in JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_CleanupExpired(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
// CleanupExpired should not error even with no data
|
||||
if err := store.CleanupExpired(); err != nil {
|
||||
t.Errorf("CleanupExpired() error = %v", err)
|
||||
}
|
||||
|
||||
// Note: Current implementation doesn't actually clean anything
|
||||
// since AuthRequestData and ClientSessionData don't have expiry timestamps
|
||||
// This test verifies the method doesn't panic
|
||||
}
|
||||
|
||||
func TestGetDefaultStorePath(t *testing.T) {
|
||||
path, err := GetDefaultStorePath()
|
||||
if err != nil {
|
||||
t.Fatalf("GetDefaultStorePath() error = %v", err)
|
||||
}
|
||||
|
||||
if path == "" {
|
||||
t.Fatal("Expected non-empty path")
|
||||
}
|
||||
|
||||
// Path should either be /var/lib/atcr or ~/.atcr
|
||||
// We can't assert exact path since it depends on permissions
|
||||
t.Logf("Default store path: %s", path)
|
||||
}
|
||||
|
||||
func TestMakeSessionKey(t *testing.T) {
|
||||
did := "did:plc:alice123"
|
||||
sessionID := "session-456"
|
||||
|
||||
key := makeSessionKey(did, sessionID)
|
||||
expected := "did:plc:alice123:session-456"
|
||||
|
||||
if key != expected {
|
||||
t.Errorf("Expected key %q, got %q", expected, key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_ConcurrentAccess(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storePath := tmpDir + "/oauth-test.json"
|
||||
|
||||
store, err := NewFileStore(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileStore() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Run concurrent operations
|
||||
done := make(chan bool)
|
||||
|
||||
// Writer goroutine
|
||||
go func() {
|
||||
for i := 0; i < 10; i++ {
|
||||
did, _ := syntax.ParseDID("did:plc:alice123")
|
||||
sessionData := oauth.ClientSessionData{
|
||||
AccountDID: did,
|
||||
SessionID: "session-1",
|
||||
HostURL: "https://pds.example.com",
|
||||
}
|
||||
store.SaveSession(ctx, sessionData)
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Reader goroutine
|
||||
go func() {
|
||||
for i := 0; i < 10; i++ {
|
||||
did, _ := syntax.ParseDID("did:plc:alice123")
|
||||
store.GetSession(ctx, did, "session-1")
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Wait for both goroutines
|
||||
<-done
|
||||
<-done
|
||||
|
||||
// If we got here without panicking, the locking works
|
||||
t.Log("Concurrent access test passed")
|
||||
}
|
||||
300
pkg/auth/servicetoken.go
Normal file
300
pkg/auth/servicetoken.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
"github.com/bluesky-social/indigo/atproto/atclient"
|
||||
indigo_oauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
|
||||
)
|
||||
|
||||
// getErrorHint provides context-specific troubleshooting hints based on API error type
|
||||
func getErrorHint(apiErr *atclient.APIError) string {
|
||||
switch apiErr.Name {
|
||||
case "use_dpop_nonce":
|
||||
return "DPoP nonce mismatch - indigo library should automatically retry with new nonce. If this persists, check for concurrent request issues or PDS session corruption."
|
||||
case "invalid_client":
|
||||
if apiErr.Message != "" && apiErr.Message == "Validation of \"client_assertion\" failed: \"iat\" claim timestamp check failed (it should be in the past)" {
|
||||
return "JWT timestamp validation failed - system clock on AppView may be ahead of PDS clock. Check NTP sync with: timedatectl status"
|
||||
}
|
||||
return "OAuth client authentication failed - check client key configuration and PDS OAuth server status"
|
||||
case "invalid_token", "invalid_grant":
|
||||
return "OAuth tokens expired or invalidated - user will need to re-authenticate via OAuth flow"
|
||||
case "server_error":
|
||||
if apiErr.StatusCode == 500 {
|
||||
return "PDS returned internal server error - this may occur after repeated DPoP nonce failures or other PDS-side issues. Check PDS logs for root cause."
|
||||
}
|
||||
return "PDS server error - check PDS health and logs"
|
||||
case "invalid_dpop_proof":
|
||||
return "DPoP proof validation failed - check system clock sync and DPoP key configuration"
|
||||
default:
|
||||
if apiErr.StatusCode == 401 || apiErr.StatusCode == 403 {
|
||||
return "Authentication/authorization failed - OAuth session may be expired or revoked"
|
||||
}
|
||||
return "PDS rejected the request - see errorName and errorMessage for details"
|
||||
}
|
||||
}
|
||||
|
||||
// ParseJWTExpiry extracts the expiry time from a JWT without verifying the signature
|
||||
// We trust tokens from the user's PDS, so signature verification isn't needed here
|
||||
// Manually decodes the JWT payload to avoid algorithm compatibility issues
|
||||
func ParseJWTExpiry(tokenString string) (time.Time, error) {
|
||||
// JWT format: header.payload.signature
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return time.Time{}, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Decode the payload (second part)
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
}
|
||||
|
||||
// Parse the JSON payload
|
||||
var claims struct {
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return time.Time{}, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
if claims.Exp == 0 {
|
||||
return time.Time{}, fmt.Errorf("JWT missing exp claim")
|
||||
}
|
||||
|
||||
return time.Unix(claims.Exp, 0), nil
|
||||
}
|
||||
|
||||
// buildServiceAuthURL constructs the URL for com.atproto.server.getServiceAuth
|
||||
func buildServiceAuthURL(pdsEndpoint, holdDID string) string {
|
||||
// Request 5-minute expiry (PDS may grant less)
|
||||
// exp must be absolute Unix timestamp, not relative duration
|
||||
expiryTime := time.Now().Unix() + 300 // 5 minutes from now
|
||||
return fmt.Sprintf("%s%s?aud=%s&lxm=%s&exp=%d",
|
||||
pdsEndpoint,
|
||||
atproto.ServerGetServiceAuth,
|
||||
url.QueryEscape(holdDID),
|
||||
url.QueryEscape("com.atproto.repo.getRecord"),
|
||||
expiryTime,
|
||||
)
|
||||
}
|
||||
|
||||
// parseServiceTokenResponse extracts the token from a service auth response
|
||||
func parseServiceTokenResponse(resp *http.Response) (string, error) {
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("service auth failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to decode service auth response: %w", err)
|
||||
}
|
||||
|
||||
if result.Token == "" {
|
||||
return "", fmt.Errorf("empty token in service auth response")
|
||||
}
|
||||
|
||||
return result.Token, nil
|
||||
}
|
||||
|
||||
// GetOrFetchServiceToken gets a service token for hold authentication.
|
||||
// Handles both OAuth/DPoP and app-password authentication based on authMethod.
|
||||
// Checks cache first, then fetches from PDS if needed.
|
||||
//
|
||||
// For OAuth: Uses DoWithSession() to hold a per-DID lock through the entire PDS interaction.
|
||||
// This prevents DPoP nonce race conditions when multiple Docker layers upload concurrently.
|
||||
//
|
||||
// For app-password: Uses Bearer token authentication without locking (no DPoP complexity).
|
||||
func GetOrFetchServiceToken(
|
||||
ctx context.Context,
|
||||
authMethod string,
|
||||
refresher *oauth.Refresher, // Required for OAuth, nil for app-password
|
||||
did, holdDID, pdsEndpoint string,
|
||||
) (string, error) {
|
||||
// Check cache first to avoid unnecessary PDS calls on every request
|
||||
cachedToken, expiresAt := GetServiceToken(did, holdDID)
|
||||
|
||||
// Use cached token if it exists and has > 10s remaining
|
||||
if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
|
||||
slog.Debug("Using cached service token",
|
||||
"did", did,
|
||||
"authMethod", authMethod,
|
||||
"expiresIn", time.Until(expiresAt).Round(time.Second))
|
||||
return cachedToken, nil
|
||||
}
|
||||
|
||||
// Cache miss or expiring soon - fetch new service token
|
||||
if cachedToken == "" {
|
||||
slog.Debug("Service token cache miss, fetching new token", "did", did, "authMethod", authMethod)
|
||||
} else {
|
||||
slog.Debug("Service token expiring soon, proactively renewing", "did", did, "authMethod", authMethod)
|
||||
}
|
||||
|
||||
var serviceToken string
|
||||
var err error
|
||||
|
||||
// Branch based on auth method
|
||||
if authMethod == AuthMethodOAuth {
|
||||
serviceToken, err = doOAuthFetch(ctx, refresher, did, holdDID, pdsEndpoint)
|
||||
// OAuth-specific cleanup: delete stale session on error
|
||||
if err != nil && refresher != nil {
|
||||
if delErr := refresher.DeleteSession(ctx, did); delErr != nil {
|
||||
slog.Warn("Failed to delete stale OAuth session",
|
||||
"component", "auth/servicetoken",
|
||||
"did", did,
|
||||
"error", delErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
serviceToken, err = doAppPasswordFetch(ctx, did, holdDID, pdsEndpoint)
|
||||
}
|
||||
|
||||
// Unified error handling
|
||||
if err != nil {
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
|
||||
var apiErr *atclient.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
slog.Error("Service token request failed",
|
||||
"component", "auth/servicetoken",
|
||||
"authMethod", authMethod,
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"error", err,
|
||||
"httpStatus", apiErr.StatusCode,
|
||||
"errorName", apiErr.Name,
|
||||
"errorMessage", apiErr.Message,
|
||||
"hint", getErrorHint(apiErr))
|
||||
} else {
|
||||
slog.Error("Service token request failed",
|
||||
"component", "auth/servicetoken",
|
||||
"authMethod", authMethod,
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"error", err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Cache the token (parses JWT to extract actual expiry)
|
||||
if cacheErr := SetServiceToken(did, holdDID, serviceToken); cacheErr != nil {
|
||||
slog.Warn("Failed to cache service token", "error", cacheErr, "did", did, "holdDID", holdDID)
|
||||
}
|
||||
|
||||
slog.Debug("Service token obtained", "did", did, "authMethod", authMethod)
|
||||
return serviceToken, nil
|
||||
}
|
||||
|
||||
// doOAuthFetch fetches a service token using OAuth/DPoP authentication.
|
||||
// Uses DoWithSession() for per-DID locking to prevent DPoP nonce races.
|
||||
// Returns (token, error) without logging - caller handles error logging.
|
||||
func doOAuthFetch(
|
||||
ctx context.Context,
|
||||
refresher *oauth.Refresher,
|
||||
did, holdDID, pdsEndpoint string,
|
||||
) (string, error) {
|
||||
if refresher == nil {
|
||||
return "", fmt.Errorf("refresher is nil (OAuth session required)")
|
||||
}
|
||||
|
||||
var serviceToken string
|
||||
var fetchErr error
|
||||
|
||||
err := refresher.DoWithSession(ctx, did, func(session *indigo_oauth.ClientSession) error {
|
||||
// Double-check cache after acquiring lock (double-checked locking pattern)
|
||||
cachedToken, expiresAt := GetServiceToken(did, holdDID)
|
||||
if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
|
||||
slog.Debug("Service token cache hit after lock acquisition",
|
||||
"did", did,
|
||||
"expiresIn", time.Until(expiresAt).Round(time.Second))
|
||||
serviceToken = cachedToken
|
||||
return nil
|
||||
}
|
||||
|
||||
serviceAuthURL := buildServiceAuthURL(pdsEndpoint, holdDID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", serviceAuthURL, nil)
|
||||
if err != nil {
|
||||
fetchErr = fmt.Errorf("failed to create request: %w", err)
|
||||
return fetchErr
|
||||
}
|
||||
|
||||
resp, err := session.DoWithAuth(session.Client, req, "com.atproto.server.getServiceAuth")
|
||||
if err != nil {
|
||||
fetchErr = fmt.Errorf("OAuth request failed: %w", err)
|
||||
return fetchErr
|
||||
}
|
||||
|
||||
token, parseErr := parseServiceTokenResponse(resp)
|
||||
if parseErr != nil {
|
||||
fetchErr = parseErr
|
||||
return fetchErr
|
||||
}
|
||||
|
||||
serviceToken = token
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if fetchErr != nil {
|
||||
return "", fetchErr
|
||||
}
|
||||
return "", fmt.Errorf("failed to get OAuth session: %w", err)
|
||||
}
|
||||
|
||||
return serviceToken, nil
|
||||
}
|
||||
|
||||
// doAppPasswordFetch fetches a service token using Bearer token authentication.
|
||||
// Returns (token, error) without logging - caller handles error logging.
|
||||
func doAppPasswordFetch(
|
||||
ctx context.Context,
|
||||
did, holdDID, pdsEndpoint string,
|
||||
) (string, error) {
|
||||
accessToken, ok := GetGlobalTokenCache().Get(did)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("no app-password access token available for DID %s", did)
|
||||
}
|
||||
|
||||
serviceAuthURL := buildServiceAuthURL(pdsEndpoint, holdDID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", serviceAuthURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
resp.Body.Close()
|
||||
// Clear stale app-password token
|
||||
GetGlobalTokenCache().Delete(did)
|
||||
return "", fmt.Errorf("app-password authentication failed: token expired or invalid")
|
||||
}
|
||||
|
||||
return parseServiceTokenResponse(resp)
|
||||
}
|
||||
27
pkg/auth/servicetoken_test.go
Normal file
27
pkg/auth/servicetoken_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetOrFetchServiceToken_NilRefresher(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
did := "did:plc:test123"
|
||||
holdDID := "did:web:hold.example.com"
|
||||
pdsEndpoint := "https://pds.example.com"
|
||||
|
||||
// Test with nil refresher and OAuth auth method - should return error
|
||||
_, err := GetOrFetchServiceToken(ctx, AuthMethodOAuth, nil, did, holdDID, pdsEndpoint)
|
||||
if err == nil {
|
||||
t.Error("Expected error when refresher is nil for OAuth")
|
||||
}
|
||||
|
||||
expectedErrMsg := "refresher is nil (OAuth session required)"
|
||||
if err.Error() != expectedErrMsg {
|
||||
t.Errorf("Expected error message %q, got %q", expectedErrMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Full tests with mocked OAuth refresher and HTTP client will be added
|
||||
// in the comprehensive test implementation phase
|
||||
@@ -56,3 +56,22 @@ func ExtractAuthMethod(tokenString string) string {
|
||||
|
||||
return claims.AuthMethod
|
||||
}
|
||||
|
||||
// ExtractSubject parses a JWT token string and extracts the Subject claim (the user's DID)
|
||||
// Returns the subject or empty string if not found or token is invalid
|
||||
// This does NOT validate the token - it only parses it to extract the claim
|
||||
func ExtractSubject(tokenString string) string {
|
||||
// Parse token without validation (we only need the claims, validation is done by distribution library)
|
||||
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
|
||||
token, _, err := parser.ParseUnverified(tokenString, &Claims{})
|
||||
if err != nil {
|
||||
return "" // Invalid token format
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok {
|
||||
return "" // Wrong claims type
|
||||
}
|
||||
|
||||
return claims.Subject
|
||||
}
|
||||
|
||||
@@ -1,362 +0,0 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
"github.com/bluesky-social/indigo/atproto/atclient"
|
||||
indigo_oauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
|
||||
)
|
||||
|
||||
// getErrorHint provides context-specific troubleshooting hints based on API error type
|
||||
func getErrorHint(apiErr *atclient.APIError) string {
|
||||
switch apiErr.Name {
|
||||
case "use_dpop_nonce":
|
||||
return "DPoP nonce mismatch - indigo library should automatically retry with new nonce. If this persists, check for concurrent request issues or PDS session corruption."
|
||||
case "invalid_client":
|
||||
if apiErr.Message != "" && apiErr.Message == "Validation of \"client_assertion\" failed: \"iat\" claim timestamp check failed (it should be in the past)" {
|
||||
return "JWT timestamp validation failed - system clock on AppView may be ahead of PDS clock. Check NTP sync with: timedatectl status"
|
||||
}
|
||||
return "OAuth client authentication failed - check client key configuration and PDS OAuth server status"
|
||||
case "invalid_token", "invalid_grant":
|
||||
return "OAuth tokens expired or invalidated - user will need to re-authenticate via OAuth flow"
|
||||
case "server_error":
|
||||
if apiErr.StatusCode == 500 {
|
||||
return "PDS returned internal server error - this may occur after repeated DPoP nonce failures or other PDS-side issues. Check PDS logs for root cause."
|
||||
}
|
||||
return "PDS server error - check PDS health and logs"
|
||||
case "invalid_dpop_proof":
|
||||
return "DPoP proof validation failed - check system clock sync and DPoP key configuration"
|
||||
default:
|
||||
if apiErr.StatusCode == 401 || apiErr.StatusCode == 403 {
|
||||
return "Authentication/authorization failed - OAuth session may be expired or revoked"
|
||||
}
|
||||
return "PDS rejected the request - see errorName and errorMessage for details"
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrFetchServiceToken gets a service token for hold authentication.
|
||||
// Checks cache first, then fetches from PDS with OAuth/DPoP if needed.
|
||||
// This is the canonical implementation used by both middleware and crew registration.
|
||||
//
|
||||
// IMPORTANT: Uses DoWithSession() to hold a per-DID lock through the entire PDS interaction.
|
||||
// This prevents DPoP nonce race conditions when multiple Docker layers upload concurrently.
|
||||
func GetOrFetchServiceToken(
|
||||
ctx context.Context,
|
||||
refresher *oauth.Refresher,
|
||||
did, holdDID, pdsEndpoint string,
|
||||
) (string, error) {
|
||||
if refresher == nil {
|
||||
return "", fmt.Errorf("refresher is nil (OAuth session required for service tokens)")
|
||||
}
|
||||
|
||||
// Check cache first to avoid unnecessary PDS calls on every request
|
||||
cachedToken, expiresAt := GetServiceToken(did, holdDID)
|
||||
|
||||
// Use cached token if it exists and has > 10s remaining
|
||||
if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
|
||||
slog.Debug("Using cached service token",
|
||||
"did", did,
|
||||
"expiresIn", time.Until(expiresAt).Round(time.Second))
|
||||
return cachedToken, nil
|
||||
}
|
||||
|
||||
// Cache miss or expiring soon - validate OAuth and get new service token
|
||||
if cachedToken == "" {
|
||||
slog.Debug("Service token cache miss, fetching new token", "did", did)
|
||||
} else {
|
||||
slog.Debug("Service token expiring soon, proactively renewing", "did", did)
|
||||
}
|
||||
|
||||
// Use DoWithSession to hold the lock through the entire PDS interaction.
|
||||
// This prevents DPoP nonce races when multiple goroutines try to fetch service tokens.
|
||||
var serviceToken string
|
||||
var fetchErr error
|
||||
|
||||
err := refresher.DoWithSession(ctx, did, func(session *indigo_oauth.ClientSession) error {
|
||||
// Double-check cache after acquiring lock - another goroutine may have
|
||||
// populated it while we were waiting (classic double-checked locking pattern)
|
||||
cachedToken, expiresAt := GetServiceToken(did, holdDID)
|
||||
if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
|
||||
slog.Debug("Service token cache hit after lock acquisition",
|
||||
"did", did,
|
||||
"expiresIn", time.Until(expiresAt).Round(time.Second))
|
||||
serviceToken = cachedToken
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cache still empty/expired - proceed with PDS call
|
||||
// Request 5-minute expiry (PDS may grant less)
|
||||
// exp must be absolute Unix timestamp, not relative duration
|
||||
// Note: OAuth scope includes #atcr_hold fragment, but service auth aud must be bare DID
|
||||
expiryTime := time.Now().Unix() + 300 // 5 minutes from now
|
||||
serviceAuthURL := fmt.Sprintf("%s%s?aud=%s&lxm=%s&exp=%d",
|
||||
pdsEndpoint,
|
||||
atproto.ServerGetServiceAuth,
|
||||
url.QueryEscape(holdDID),
|
||||
url.QueryEscape("com.atproto.repo.getRecord"),
|
||||
expiryTime,
|
||||
)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", serviceAuthURL, nil)
|
||||
if err != nil {
|
||||
fetchErr = fmt.Errorf("failed to create service auth request: %w", err)
|
||||
return fetchErr
|
||||
}
|
||||
|
||||
// Use OAuth session to authenticate to PDS (with DPoP)
|
||||
// The lock is held, so DPoP nonce negotiation is serialized per-DID
|
||||
resp, err := session.DoWithAuth(session.Client, req, "com.atproto.server.getServiceAuth")
|
||||
if err != nil {
|
||||
// Auth error - may indicate expired tokens or corrupted session
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
|
||||
// Inspect the error to extract detailed information from indigo's APIError
|
||||
var apiErr *atclient.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
// Log detailed API error information
|
||||
slog.Error("OAuth authentication failed during service token request",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"url", serviceAuthURL,
|
||||
"error", err,
|
||||
"httpStatus", apiErr.StatusCode,
|
||||
"errorName", apiErr.Name,
|
||||
"errorMessage", apiErr.Message,
|
||||
"hint", getErrorHint(apiErr))
|
||||
} else {
|
||||
// Fallback for non-API errors (network errors, etc.)
|
||||
slog.Error("OAuth authentication failed during service token request",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"url", serviceAuthURL,
|
||||
"error", err,
|
||||
"errorType", fmt.Sprintf("%T", err),
|
||||
"hint", "Network error or unexpected failure during OAuth request")
|
||||
}
|
||||
|
||||
fetchErr = fmt.Errorf("OAuth validation failed: %w", err)
|
||||
return fetchErr
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// Service auth failed
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
slog.Error("Service token request returned non-200 status",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"statusCode", resp.StatusCode,
|
||||
"responseBody", string(bodyBytes),
|
||||
"hint", "PDS rejected the service token request - check PDS logs for details")
|
||||
fetchErr = fmt.Errorf("service auth failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
return fetchErr
|
||||
}
|
||||
|
||||
// Parse response to get service token
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
fetchErr = fmt.Errorf("failed to decode service auth response: %w", err)
|
||||
return fetchErr
|
||||
}
|
||||
|
||||
if result.Token == "" {
|
||||
fetchErr = fmt.Errorf("empty token in service auth response")
|
||||
return fetchErr
|
||||
}
|
||||
|
||||
serviceToken = result.Token
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// DoWithSession failed (session load or callback error)
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
|
||||
// Try to extract detailed error information
|
||||
var apiErr *atclient.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
slog.Error("Failed to get OAuth session for service token",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"error", err,
|
||||
"httpStatus", apiErr.StatusCode,
|
||||
"errorName", apiErr.Name,
|
||||
"errorMessage", apiErr.Message,
|
||||
"hint", getErrorHint(apiErr))
|
||||
} else if fetchErr == nil {
|
||||
// Session load failed (not a fetch error)
|
||||
slog.Error("Failed to get OAuth session for service token",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"error", err,
|
||||
"errorType", fmt.Sprintf("%T", err),
|
||||
"hint", "OAuth session not found in database or token refresh failed")
|
||||
}
|
||||
|
||||
// Delete the stale OAuth session to force re-authentication
|
||||
// This also invalidates the UI session automatically
|
||||
if delErr := refresher.DeleteSession(ctx, did); delErr != nil {
|
||||
slog.Warn("Failed to delete stale OAuth session",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"error", delErr)
|
||||
}
|
||||
|
||||
if fetchErr != nil {
|
||||
return "", fetchErr
|
||||
}
|
||||
return "", fmt.Errorf("failed to get OAuth session: %w", err)
|
||||
}
|
||||
|
||||
// Cache the token (parses JWT to extract actual expiry)
|
||||
if err := SetServiceToken(did, holdDID, serviceToken); err != nil {
|
||||
slog.Warn("Failed to cache service token", "error", err, "did", did, "holdDID", holdDID)
|
||||
// Non-fatal - we have the token, just won't be cached
|
||||
}
|
||||
|
||||
slog.Debug("OAuth validation succeeded, service token obtained", "did", did)
|
||||
return serviceToken, nil
|
||||
}
|
||||
|
||||
// GetOrFetchServiceTokenWithAppPassword gets a service token using app-password Bearer authentication.
|
||||
// Used when auth method is app_password instead of OAuth.
|
||||
func GetOrFetchServiceTokenWithAppPassword(
|
||||
ctx context.Context,
|
||||
did, holdDID, pdsEndpoint string,
|
||||
) (string, error) {
|
||||
// Check cache first to avoid unnecessary PDS calls on every request
|
||||
cachedToken, expiresAt := GetServiceToken(did, holdDID)
|
||||
|
||||
// Use cached token if it exists and has > 10s remaining
|
||||
if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
|
||||
slog.Debug("Using cached service token (app-password)",
|
||||
"did", did,
|
||||
"expiresIn", time.Until(expiresAt).Round(time.Second))
|
||||
return cachedToken, nil
|
||||
}
|
||||
|
||||
// Cache miss or expiring soon - get app-password token and fetch new service token
|
||||
if cachedToken == "" {
|
||||
slog.Debug("Service token cache miss, fetching new token with app-password", "did", did)
|
||||
} else {
|
||||
slog.Debug("Service token expiring soon, proactively renewing with app-password", "did", did)
|
||||
}
|
||||
|
||||
// Get app-password access token from cache
|
||||
accessToken, ok := auth.GetGlobalTokenCache().Get(did)
|
||||
if !ok {
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
slog.Error("No app-password access token found in cache",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"hint", "User must re-authenticate with docker login")
|
||||
return "", fmt.Errorf("no app-password access token available for DID %s", did)
|
||||
}
|
||||
|
||||
// Call com.atproto.server.getServiceAuth on the user's PDS with Bearer token
|
||||
// Request 5-minute expiry (PDS may grant less)
|
||||
// exp must be absolute Unix timestamp, not relative duration
|
||||
expiryTime := time.Now().Unix() + 300 // 5 minutes from now
|
||||
serviceAuthURL := fmt.Sprintf("%s%s?aud=%s&lxm=%s&exp=%d",
|
||||
pdsEndpoint,
|
||||
atproto.ServerGetServiceAuth,
|
||||
url.QueryEscape(holdDID),
|
||||
url.QueryEscape("com.atproto.repo.getRecord"),
|
||||
expiryTime,
|
||||
)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", serviceAuthURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create service auth request: %w", err)
|
||||
}
|
||||
|
||||
// Set Bearer token authentication (app-password)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
// Make request with standard HTTP client
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
slog.Error("App-password service token request failed",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"error", err)
|
||||
return "", fmt.Errorf("failed to request service token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
// App-password token is invalid or expired - clear from cache
|
||||
auth.GetGlobalTokenCache().Delete(did)
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
slog.Error("App-password token rejected by PDS",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"hint", "User must re-authenticate with docker login")
|
||||
return "", fmt.Errorf("app-password authentication failed: token expired or invalid")
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// Service auth failed
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
InvalidateServiceToken(did, holdDID)
|
||||
slog.Error("Service token request returned non-200 status (app-password)",
|
||||
"component", "token/servicetoken",
|
||||
"did", did,
|
||||
"holdDID", holdDID,
|
||||
"pdsEndpoint", pdsEndpoint,
|
||||
"statusCode", resp.StatusCode,
|
||||
"responseBody", string(bodyBytes))
|
||||
return "", fmt.Errorf("service auth failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Parse response to get service token
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to decode service auth response: %w", err)
|
||||
}
|
||||
|
||||
if result.Token == "" {
|
||||
return "", fmt.Errorf("empty token in service auth response")
|
||||
}
|
||||
|
||||
serviceToken := result.Token
|
||||
|
||||
// Cache the token (parses JWT to extract actual expiry)
|
||||
if err := SetServiceToken(did, holdDID, serviceToken); err != nil {
|
||||
slog.Warn("Failed to cache service token", "error", err, "did", did, "holdDID", holdDID)
|
||||
// Non-fatal - we have the token, just won't be cached
|
||||
}
|
||||
|
||||
slog.Debug("App-password validation succeeded, service token obtained", "did", did)
|
||||
return serviceToken, nil
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetOrFetchServiceToken_NilRefresher(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
did := "did:plc:test123"
|
||||
holdDID := "did:web:hold.example.com"
|
||||
pdsEndpoint := "https://pds.example.com"
|
||||
|
||||
// Test with nil refresher - should return error
|
||||
_, err := GetOrFetchServiceToken(ctx, nil, did, holdDID, pdsEndpoint)
|
||||
if err == nil {
|
||||
t.Error("Expected error when refresher is nil")
|
||||
}
|
||||
|
||||
expectedErrMsg := "refresher is nil"
|
||||
if err.Error() != "refresher is nil (OAuth session required for service tokens)" {
|
||||
t.Errorf("Expected error message to contain %q, got %q", expectedErrMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Full tests with mocked OAuth refresher and HTTP client will be added
|
||||
// in the comprehensive test implementation phase
|
||||
784
pkg/auth/usercontext.go
Normal file
784
pkg/auth/usercontext.go
Normal file
@@ -0,0 +1,784 @@
|
||||
// Package auth provides UserContext for managing authenticated user state
|
||||
// throughout request handling in the AppView.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/appview/db"
|
||||
"atcr.io/pkg/atproto"
|
||||
"atcr.io/pkg/auth/oauth"
|
||||
)
|
||||
|
||||
// Auth method constants (duplicated from token package to avoid import cycle)
|
||||
const (
|
||||
AuthMethodOAuth = "oauth"
|
||||
AuthMethodAppPassword = "app_password"
|
||||
)
|
||||
|
||||
// RequestAction represents the type of registry operation
|
||||
type RequestAction int
|
||||
|
||||
const (
|
||||
ActionUnknown RequestAction = iota
|
||||
ActionPull // GET/HEAD - reading from registry
|
||||
ActionPush // PUT/POST/DELETE - writing to registry
|
||||
ActionInspect // Metadata operations only
|
||||
)
|
||||
|
||||
func (a RequestAction) String() string {
|
||||
switch a {
|
||||
case ActionPull:
|
||||
return "pull"
|
||||
case ActionPush:
|
||||
return "push"
|
||||
case ActionInspect:
|
||||
return "inspect"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// HoldPermissions describes what the user can do on a specific hold
|
||||
type HoldPermissions struct {
|
||||
HoldDID string // Hold being checked
|
||||
IsOwner bool // User is captain of this hold
|
||||
IsCrew bool // User is a crew member
|
||||
IsPublic bool // Hold allows public reads
|
||||
CanRead bool // Computed: can user read blobs?
|
||||
CanWrite bool // Computed: can user write blobs?
|
||||
CanAdmin bool // Computed: can user manage crew?
|
||||
Permissions []string // Raw permissions from crew record
|
||||
}
|
||||
|
||||
// contextKey is unexported to prevent collisions
|
||||
type contextKey struct{}
|
||||
|
||||
// userContextKey is the context key for UserContext
|
||||
var userContextKey = contextKey{}
|
||||
|
||||
// userSetupCache tracks which users have had their profile/crew setup ensured
|
||||
var userSetupCache sync.Map // did -> time.Time
|
||||
|
||||
// userSetupTTL is how long to cache user setup status (1 hour)
|
||||
const userSetupTTL = 1 * time.Hour
|
||||
|
||||
// Dependencies bundles services needed by UserContext
|
||||
type Dependencies struct {
|
||||
Refresher *oauth.Refresher
|
||||
Authorizer HoldAuthorizer
|
||||
DefaultHoldDID string // AppView's default hold DID
|
||||
}
|
||||
|
||||
// UserContext encapsulates authenticated user state for a request.
|
||||
// Built early in the middleware chain and available throughout request processing.
|
||||
//
|
||||
// Two-phase initialization:
|
||||
// 1. Middleware phase: Identity is set (DID, authMethod, action)
|
||||
// 2. Repository() phase: Target is set via SetTarget() (owner, repo, holdDID)
|
||||
type UserContext struct {
|
||||
// === User Identity (set in middleware) ===
|
||||
DID string // User's DID (empty if unauthenticated)
|
||||
Handle string // User's handle (may be empty)
|
||||
PDSEndpoint string // User's PDS endpoint
|
||||
AuthMethod string // "oauth", "app_password", or ""
|
||||
IsAuthenticated bool
|
||||
|
||||
// === Request Info ===
|
||||
Action RequestAction
|
||||
HTTPMethod string
|
||||
|
||||
// === Target Info (set by SetTarget) ===
|
||||
TargetOwnerDID string // whose repo is being accessed
|
||||
TargetOwnerHandle string
|
||||
TargetOwnerPDS string
|
||||
TargetRepo string // image name (e.g., "quickslice")
|
||||
TargetHoldDID string // hold where blobs live/will live
|
||||
|
||||
// === Dependencies (injected) ===
|
||||
refresher *oauth.Refresher
|
||||
authorizer HoldAuthorizer
|
||||
defaultHoldDID string
|
||||
|
||||
// === Cached State (lazy-loaded) ===
|
||||
serviceTokens sync.Map // holdDID -> *serviceTokenEntry
|
||||
permissions sync.Map // holdDID -> *HoldPermissions
|
||||
pdsResolved bool
|
||||
pdsResolveErr error
|
||||
mu sync.Mutex // protects PDS resolution
|
||||
atprotoClient *atproto.Client
|
||||
atprotoClientOnce sync.Once
|
||||
}
|
||||
|
||||
// FromContext retrieves UserContext from context.
|
||||
// Returns nil if not present (unauthenticated or before middleware).
|
||||
func FromContext(ctx context.Context) *UserContext {
|
||||
uc, _ := ctx.Value(userContextKey).(*UserContext)
|
||||
return uc
|
||||
}
|
||||
|
||||
// WithUserContext adds UserContext to context
|
||||
func WithUserContext(ctx context.Context, uc *UserContext) context.Context {
|
||||
return context.WithValue(ctx, userContextKey, uc)
|
||||
}
|
||||
|
||||
// NewUserContext creates a UserContext from extracted JWT claims.
|
||||
// The deps parameter provides access to services needed for lazy operations.
|
||||
func NewUserContext(did, authMethod, httpMethod string, deps *Dependencies) *UserContext {
|
||||
action := ActionUnknown
|
||||
switch httpMethod {
|
||||
case "GET", "HEAD":
|
||||
action = ActionPull
|
||||
case "PUT", "POST", "PATCH", "DELETE":
|
||||
action = ActionPush
|
||||
}
|
||||
|
||||
var refresher *oauth.Refresher
|
||||
var authorizer HoldAuthorizer
|
||||
var defaultHoldDID string
|
||||
|
||||
if deps != nil {
|
||||
refresher = deps.Refresher
|
||||
authorizer = deps.Authorizer
|
||||
defaultHoldDID = deps.DefaultHoldDID
|
||||
}
|
||||
|
||||
return &UserContext{
|
||||
DID: did,
|
||||
AuthMethod: authMethod,
|
||||
IsAuthenticated: did != "",
|
||||
Action: action,
|
||||
HTTPMethod: httpMethod,
|
||||
refresher: refresher,
|
||||
authorizer: authorizer,
|
||||
defaultHoldDID: defaultHoldDID,
|
||||
}
|
||||
}
|
||||
|
||||
// SetPDS sets the user's PDS endpoint directly, bypassing network resolution.
|
||||
// Use when PDS is already known (e.g., from previous resolution or client).
|
||||
func (uc *UserContext) SetPDS(handle, pdsEndpoint string) {
|
||||
uc.mu.Lock()
|
||||
defer uc.mu.Unlock()
|
||||
uc.Handle = handle
|
||||
uc.PDSEndpoint = pdsEndpoint
|
||||
uc.pdsResolved = true
|
||||
uc.pdsResolveErr = nil
|
||||
}
|
||||
|
||||
// SetTarget sets the target repository information.
|
||||
// Called in Repository() after resolving the owner identity.
|
||||
func (uc *UserContext) SetTarget(ownerDID, ownerHandle, ownerPDS, repo, holdDID string) {
|
||||
uc.TargetOwnerDID = ownerDID
|
||||
uc.TargetOwnerHandle = ownerHandle
|
||||
uc.TargetOwnerPDS = ownerPDS
|
||||
uc.TargetRepo = repo
|
||||
uc.TargetHoldDID = holdDID
|
||||
}
|
||||
|
||||
// ResolvePDS resolves the user's PDS endpoint (lazy, cached).
|
||||
// Safe to call multiple times; resolution happens once.
|
||||
func (uc *UserContext) ResolvePDS(ctx context.Context) error {
|
||||
if !uc.IsAuthenticated {
|
||||
return nil // Nothing to resolve for anonymous users
|
||||
}
|
||||
|
||||
uc.mu.Lock()
|
||||
defer uc.mu.Unlock()
|
||||
|
||||
if uc.pdsResolved {
|
||||
return uc.pdsResolveErr
|
||||
}
|
||||
|
||||
_, handle, pds, err := atproto.ResolveIdentity(ctx, uc.DID)
|
||||
if err != nil {
|
||||
uc.pdsResolveErr = err
|
||||
uc.pdsResolved = true
|
||||
return err
|
||||
}
|
||||
|
||||
uc.Handle = handle
|
||||
uc.PDSEndpoint = pds
|
||||
uc.pdsResolved = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServiceToken returns a service token for the target hold.
|
||||
// Uses internal caching with sync.Once per holdDID.
|
||||
// Requires target to be set via SetTarget().
|
||||
func (uc *UserContext) GetServiceToken(ctx context.Context) (string, error) {
|
||||
if uc.TargetHoldDID == "" {
|
||||
return "", fmt.Errorf("target hold not set (call SetTarget first)")
|
||||
}
|
||||
return uc.GetServiceTokenForHold(ctx, uc.TargetHoldDID)
|
||||
}
|
||||
|
||||
// GetServiceTokenForHold returns a service token for an arbitrary hold.
|
||||
// Uses internal caching with sync.Once per holdDID.
|
||||
func (uc *UserContext) GetServiceTokenForHold(ctx context.Context, holdDID string) (string, error) {
|
||||
if !uc.IsAuthenticated {
|
||||
return "", fmt.Errorf("cannot get service token: user not authenticated")
|
||||
}
|
||||
|
||||
// Ensure PDS is resolved
|
||||
if err := uc.ResolvePDS(ctx); err != nil {
|
||||
return "", fmt.Errorf("failed to resolve PDS: %w", err)
|
||||
}
|
||||
|
||||
// Load or create cache entry
|
||||
entryVal, _ := uc.serviceTokens.LoadOrStore(holdDID, &serviceTokenEntry{})
|
||||
entry := entryVal.(*serviceTokenEntry)
|
||||
|
||||
entry.once.Do(func() {
|
||||
slog.Debug("Fetching service token",
|
||||
"component", "auth/context",
|
||||
"userDID", uc.DID,
|
||||
"holdDID", holdDID,
|
||||
"authMethod", uc.AuthMethod)
|
||||
|
||||
// Use unified service token function (handles both OAuth and app-password)
|
||||
serviceToken, err := GetOrFetchServiceToken(
|
||||
ctx, uc.AuthMethod, uc.refresher, uc.DID, holdDID, uc.PDSEndpoint,
|
||||
)
|
||||
|
||||
entry.token = serviceToken
|
||||
entry.err = err
|
||||
if err == nil {
|
||||
// Parse JWT to get expiry
|
||||
expiry, parseErr := ParseJWTExpiry(serviceToken)
|
||||
if parseErr == nil {
|
||||
entry.expiresAt = expiry.Add(-10 * time.Second) // Safety margin
|
||||
} else {
|
||||
entry.expiresAt = time.Now().Add(45 * time.Second) // Default fallback
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return entry.token, entry.err
|
||||
}
|
||||
|
||||
// CanRead checks if user can read blobs from target hold.
|
||||
// - Public hold: any user (even anonymous)
|
||||
// - Private hold: owner OR crew with blob:read/blob:write
|
||||
func (uc *UserContext) CanRead(ctx context.Context) (bool, error) {
|
||||
if uc.TargetHoldDID == "" {
|
||||
return false, fmt.Errorf("target hold not set (call SetTarget first)")
|
||||
}
|
||||
|
||||
if uc.authorizer == nil {
|
||||
return false, fmt.Errorf("authorizer not configured")
|
||||
}
|
||||
|
||||
return uc.authorizer.CheckReadAccess(ctx, uc.TargetHoldDID, uc.DID)
|
||||
}
|
||||
|
||||
// CanWrite checks if user can write blobs to target hold.
|
||||
// - Must be authenticated
|
||||
// - Must be owner OR crew with blob:write
|
||||
func (uc *UserContext) CanWrite(ctx context.Context) (bool, error) {
|
||||
if uc.TargetHoldDID == "" {
|
||||
return false, fmt.Errorf("target hold not set (call SetTarget first)")
|
||||
}
|
||||
|
||||
if !uc.IsAuthenticated {
|
||||
return false, nil // Anonymous writes never allowed
|
||||
}
|
||||
|
||||
if uc.authorizer == nil {
|
||||
return false, fmt.Errorf("authorizer not configured")
|
||||
}
|
||||
|
||||
return uc.authorizer.CheckWriteAccess(ctx, uc.TargetHoldDID, uc.DID)
|
||||
}
|
||||
|
||||
// GetPermissions returns detailed permissions for target hold.
|
||||
// Lazy-loaded and cached per holdDID.
|
||||
func (uc *UserContext) GetPermissions(ctx context.Context) (*HoldPermissions, error) {
|
||||
if uc.TargetHoldDID == "" {
|
||||
return nil, fmt.Errorf("target hold not set (call SetTarget first)")
|
||||
}
|
||||
return uc.GetPermissionsForHold(ctx, uc.TargetHoldDID)
|
||||
}
|
||||
|
||||
// GetPermissionsForHold returns detailed permissions for an arbitrary hold.
|
||||
// Lazy-loaded and cached per holdDID.
|
||||
func (uc *UserContext) GetPermissionsForHold(ctx context.Context, holdDID string) (*HoldPermissions, error) {
|
||||
// Check cache first
|
||||
if cached, ok := uc.permissions.Load(holdDID); ok {
|
||||
return cached.(*HoldPermissions), nil
|
||||
}
|
||||
|
||||
if uc.authorizer == nil {
|
||||
return nil, fmt.Errorf("authorizer not configured")
|
||||
}
|
||||
|
||||
// Build permissions by querying authorizer
|
||||
captain, err := uc.authorizer.GetCaptainRecord(ctx, holdDID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get captain record: %w", err)
|
||||
}
|
||||
|
||||
perms := &HoldPermissions{
|
||||
HoldDID: holdDID,
|
||||
IsPublic: captain.Public,
|
||||
IsOwner: uc.DID != "" && uc.DID == captain.Owner,
|
||||
}
|
||||
|
||||
// Check crew membership if authenticated and not owner
|
||||
if uc.IsAuthenticated && !perms.IsOwner {
|
||||
isCrew, crewErr := uc.authorizer.IsCrewMember(ctx, holdDID, uc.DID)
|
||||
if crewErr != nil {
|
||||
slog.Warn("Failed to check crew membership",
|
||||
"component", "auth/context",
|
||||
"holdDID", holdDID,
|
||||
"userDID", uc.DID,
|
||||
"error", crewErr)
|
||||
}
|
||||
perms.IsCrew = isCrew
|
||||
}
|
||||
|
||||
// Compute permissions based on role
|
||||
if perms.IsOwner {
|
||||
perms.CanRead = true
|
||||
perms.CanWrite = true
|
||||
perms.CanAdmin = true
|
||||
} else if perms.IsCrew {
|
||||
// Crew members can read and write (for now, all crew have blob:write)
|
||||
// TODO: Check specific permissions from crew record
|
||||
perms.CanRead = true
|
||||
perms.CanWrite = true
|
||||
perms.CanAdmin = false
|
||||
} else if perms.IsPublic {
|
||||
// Public hold - anyone can read
|
||||
perms.CanRead = true
|
||||
perms.CanWrite = false
|
||||
perms.CanAdmin = false
|
||||
} else if uc.IsAuthenticated {
|
||||
// Private hold, authenticated non-crew
|
||||
// Per permission matrix: cannot read private holds
|
||||
perms.CanRead = false
|
||||
perms.CanWrite = false
|
||||
perms.CanAdmin = false
|
||||
} else {
|
||||
// Anonymous on private hold
|
||||
perms.CanRead = false
|
||||
perms.CanWrite = false
|
||||
perms.CanAdmin = false
|
||||
}
|
||||
|
||||
// Cache and return
|
||||
uc.permissions.Store(holdDID, perms)
|
||||
return perms, nil
|
||||
}
|
||||
|
||||
// IsCrewMember checks if user is crew of target hold.
|
||||
func (uc *UserContext) IsCrewMember(ctx context.Context) (bool, error) {
|
||||
if uc.TargetHoldDID == "" {
|
||||
return false, fmt.Errorf("target hold not set (call SetTarget first)")
|
||||
}
|
||||
|
||||
if !uc.IsAuthenticated {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if uc.authorizer == nil {
|
||||
return false, fmt.Errorf("authorizer not configured")
|
||||
}
|
||||
|
||||
return uc.authorizer.IsCrewMember(ctx, uc.TargetHoldDID, uc.DID)
|
||||
}
|
||||
|
||||
// EnsureCrewMembership is a standalone function to register as crew on a hold.
|
||||
// Use this when you don't have a UserContext (e.g., OAuth callback).
|
||||
// This is best-effort and logs errors without failing.
|
||||
func EnsureCrewMembership(ctx context.Context, did, pdsEndpoint string, refresher *oauth.Refresher, holdDID string) {
|
||||
if holdDID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Only works with OAuth (refresher required) - app passwords can't get service tokens
|
||||
if refresher == nil {
|
||||
slog.Debug("skipping crew registration - no OAuth refresher (app password flow)", "holdDID", holdDID)
|
||||
return
|
||||
}
|
||||
|
||||
// Normalize URL to DID if needed
|
||||
if !atproto.IsDID(holdDID) {
|
||||
holdDID = atproto.ResolveHoldDIDFromURL(holdDID)
|
||||
if holdDID == "" {
|
||||
slog.Warn("failed to resolve hold DID", "defaultHold", holdDID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Get service token for the hold (OAuth only at this point)
|
||||
serviceToken, err := GetOrFetchServiceToken(ctx, AuthMethodOAuth, refresher, did, holdDID, pdsEndpoint)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get service token", "holdDID", holdDID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve hold DID to HTTP endpoint
|
||||
holdEndpoint := atproto.ResolveHoldURL(holdDID)
|
||||
if holdEndpoint == "" {
|
||||
slog.Warn("failed to resolve hold endpoint", "holdDID", holdDID)
|
||||
return
|
||||
}
|
||||
|
||||
// Call requestCrew endpoint
|
||||
if err := requestCrewMembership(ctx, holdEndpoint, serviceToken); err != nil {
|
||||
slog.Warn("failed to request crew membership", "holdDID", holdDID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("successfully registered as crew member", "holdDID", holdDID, "userDID", did)
|
||||
}
|
||||
|
||||
// ensureCrewMembership attempts to register as crew on target hold (UserContext method).
|
||||
// Called automatically during first push; idempotent.
|
||||
// This is a best-effort operation and logs errors without failing.
|
||||
// Requires SetTarget() to be called first.
|
||||
func (uc *UserContext) ensureCrewMembership(ctx context.Context) error {
|
||||
if uc.TargetHoldDID == "" {
|
||||
return fmt.Errorf("target hold not set (call SetTarget first)")
|
||||
}
|
||||
return uc.EnsureCrewMembershipForHold(ctx, uc.TargetHoldDID)
|
||||
}
|
||||
|
||||
// EnsureCrewMembershipForHold attempts to register as crew on the specified hold.
|
||||
// This is the core implementation that can be called with any holdDID.
|
||||
// Called automatically during first push; idempotent.
|
||||
// This is a best-effort operation and logs errors without failing.
|
||||
func (uc *UserContext) EnsureCrewMembershipForHold(ctx context.Context, holdDID string) error {
|
||||
if holdDID == "" {
|
||||
return nil // Nothing to do
|
||||
}
|
||||
|
||||
// Normalize URL to DID if needed
|
||||
if !atproto.IsDID(holdDID) {
|
||||
holdDID = atproto.ResolveHoldDIDFromURL(holdDID)
|
||||
if holdDID == "" {
|
||||
return fmt.Errorf("failed to resolve hold DID from URL")
|
||||
}
|
||||
}
|
||||
|
||||
if !uc.IsAuthenticated {
|
||||
return fmt.Errorf("cannot register as crew: user not authenticated")
|
||||
}
|
||||
|
||||
if uc.refresher == nil {
|
||||
return fmt.Errorf("cannot register as crew: OAuth session required")
|
||||
}
|
||||
|
||||
// Get service token for the hold
|
||||
serviceToken, err := uc.GetServiceTokenForHold(ctx, holdDID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service token: %w", err)
|
||||
}
|
||||
|
||||
// Resolve hold DID to HTTP endpoint
|
||||
holdEndpoint := atproto.ResolveHoldURL(holdDID)
|
||||
if holdEndpoint == "" {
|
||||
return fmt.Errorf("failed to resolve hold endpoint for %s", holdDID)
|
||||
}
|
||||
|
||||
// Call requestCrew endpoint
|
||||
return requestCrewMembership(ctx, holdEndpoint, serviceToken)
|
||||
}
|
||||
|
||||
// requestCrewMembership calls the hold's requestCrew endpoint
|
||||
// The endpoint handles all authorization and duplicate checking internally
|
||||
func requestCrewMembership(ctx context.Context, holdEndpoint, serviceToken string) error {
|
||||
// Add 5 second timeout to prevent hanging on offline holds
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
url := fmt.Sprintf("%s%s", holdEndpoint, atproto.HoldRequestCrew)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+serviceToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
// Read response body to capture actual error message from hold
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return fmt.Errorf("requestCrew failed with status %d (failed to read error body: %w)", resp.StatusCode, readErr)
|
||||
}
|
||||
return fmt.Errorf("requestCrew failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserClient returns an authenticated ATProto client for the user's own PDS.
|
||||
// Used for profile operations (reading/writing to user's own repo).
|
||||
// Returns nil if not authenticated or PDS not resolved.
|
||||
func (uc *UserContext) GetUserClient() *atproto.Client {
|
||||
if !uc.IsAuthenticated || uc.PDSEndpoint == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if uc.AuthMethod == AuthMethodOAuth && uc.refresher != nil {
|
||||
return atproto.NewClientWithSessionProvider(uc.PDSEndpoint, uc.DID, uc.refresher)
|
||||
} else if uc.AuthMethod == AuthMethodAppPassword {
|
||||
accessToken, _ := GetGlobalTokenCache().Get(uc.DID)
|
||||
return atproto.NewClient(uc.PDSEndpoint, uc.DID, accessToken)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnsureUserSetup ensures the user has a profile and crew membership.
|
||||
// Called once per user (cached for userSetupTTL). Runs in background - does not block.
|
||||
// Safe to call on every request.
|
||||
func (uc *UserContext) EnsureUserSetup() {
|
||||
if !uc.IsAuthenticated || uc.DID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Check cache - skip if recently set up
|
||||
if lastSetup, ok := userSetupCache.Load(uc.DID); ok {
|
||||
if time.Since(lastSetup.(time.Time)) < userSetupTTL {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Run in background to avoid blocking requests
|
||||
go func() {
|
||||
bgCtx := context.Background()
|
||||
|
||||
// 1. Ensure profile exists
|
||||
if client := uc.GetUserClient(); client != nil {
|
||||
uc.ensureProfile(bgCtx, client)
|
||||
}
|
||||
|
||||
// 2. Ensure crew membership on default hold
|
||||
if uc.defaultHoldDID != "" {
|
||||
EnsureCrewMembership(bgCtx, uc.DID, uc.PDSEndpoint, uc.refresher, uc.defaultHoldDID)
|
||||
}
|
||||
|
||||
// Mark as set up
|
||||
userSetupCache.Store(uc.DID, time.Now())
|
||||
slog.Debug("User setup complete",
|
||||
"component", "auth/usercontext",
|
||||
"did", uc.DID,
|
||||
"defaultHoldDID", uc.defaultHoldDID)
|
||||
}()
|
||||
}
|
||||
|
||||
// ensureProfile creates sailor profile if it doesn't exist.
|
||||
// Inline implementation to avoid circular import with storage package.
|
||||
func (uc *UserContext) ensureProfile(ctx context.Context, client *atproto.Client) {
|
||||
// Check if profile already exists
|
||||
profile, err := client.GetRecord(ctx, atproto.SailorProfileCollection, "self")
|
||||
if err == nil && profile != nil {
|
||||
return // Already exists
|
||||
}
|
||||
|
||||
// Create profile with default hold
|
||||
normalizedDID := ""
|
||||
if uc.defaultHoldDID != "" {
|
||||
normalizedDID = atproto.ResolveHoldDIDFromURL(uc.defaultHoldDID)
|
||||
}
|
||||
|
||||
newProfile := atproto.NewSailorProfileRecord(normalizedDID)
|
||||
if _, err := client.PutRecord(ctx, atproto.SailorProfileCollection, "self", newProfile); err != nil {
|
||||
slog.Warn("Failed to create sailor profile",
|
||||
"component", "auth/usercontext",
|
||||
"did", uc.DID,
|
||||
"error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("Created sailor profile",
|
||||
"component", "auth/usercontext",
|
||||
"did", uc.DID,
|
||||
"defaultHold", normalizedDID)
|
||||
}
|
||||
|
||||
// GetATProtoClient returns a cached ATProto client for the target owner's PDS.
|
||||
// Authenticated if user is owner, otherwise anonymous.
|
||||
// Cached per-request (uses sync.Once).
|
||||
func (uc *UserContext) GetATProtoClient() *atproto.Client {
|
||||
uc.atprotoClientOnce.Do(func() {
|
||||
if uc.TargetOwnerPDS == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// If puller is owner and authenticated, use authenticated client
|
||||
if uc.DID == uc.TargetOwnerDID && uc.IsAuthenticated {
|
||||
if uc.AuthMethod == AuthMethodOAuth && uc.refresher != nil {
|
||||
uc.atprotoClient = atproto.NewClientWithSessionProvider(uc.TargetOwnerPDS, uc.TargetOwnerDID, uc.refresher)
|
||||
return
|
||||
} else if uc.AuthMethod == AuthMethodAppPassword {
|
||||
accessToken, _ := GetGlobalTokenCache().Get(uc.TargetOwnerDID)
|
||||
uc.atprotoClient = atproto.NewClient(uc.TargetOwnerPDS, uc.TargetOwnerDID, accessToken)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Anonymous client for reads
|
||||
uc.atprotoClient = atproto.NewClient(uc.TargetOwnerPDS, uc.TargetOwnerDID, "")
|
||||
})
|
||||
return uc.atprotoClient
|
||||
}
|
||||
|
||||
// ResolveHoldDID finds the hold for the target repository.
|
||||
// - Pull: uses database lookup (historical from manifest)
|
||||
// - Push: uses discovery (sailor profile → default)
|
||||
//
|
||||
// Must be called after SetTarget() is called with at least TargetOwnerDID and TargetRepo set.
|
||||
// Updates TargetHoldDID on success.
|
||||
func (uc *UserContext) ResolveHoldDID(ctx context.Context, sqlDB *sql.DB) (string, error) {
|
||||
if uc.TargetOwnerDID == "" {
|
||||
return "", fmt.Errorf("target owner not set")
|
||||
}
|
||||
|
||||
var holdDID string
|
||||
var err error
|
||||
|
||||
switch uc.Action {
|
||||
case ActionPull:
|
||||
// For pulls, look up historical hold from database
|
||||
holdDID, err = uc.resolveHoldForPull(ctx, sqlDB)
|
||||
case ActionPush:
|
||||
// For pushes, discover hold from owner's profile
|
||||
holdDID, err = uc.resolveHoldForPush(ctx)
|
||||
default:
|
||||
// Default to push discovery
|
||||
holdDID, err = uc.resolveHoldForPush(ctx)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if holdDID == "" {
|
||||
return "", fmt.Errorf("no hold DID found for %s/%s", uc.TargetOwnerDID, uc.TargetRepo)
|
||||
}
|
||||
|
||||
uc.TargetHoldDID = holdDID
|
||||
return holdDID, nil
|
||||
}
|
||||
|
||||
// resolveHoldForPull looks up the hold from the database (historical reference)
|
||||
func (uc *UserContext) resolveHoldForPull(ctx context.Context, sqlDB *sql.DB) (string, error) {
|
||||
// If no database is available, fall back to discovery
|
||||
if sqlDB == nil {
|
||||
return uc.resolveHoldForPush(ctx)
|
||||
}
|
||||
|
||||
// Try database lookup first
|
||||
holdDID, err := db.GetLatestHoldDIDForRepo(sqlDB, uc.TargetOwnerDID, uc.TargetRepo)
|
||||
if err != nil {
|
||||
slog.Debug("Database lookup failed, falling back to discovery",
|
||||
"component", "auth/context",
|
||||
"ownerDID", uc.TargetOwnerDID,
|
||||
"repo", uc.TargetRepo,
|
||||
"error", err)
|
||||
return uc.resolveHoldForPush(ctx)
|
||||
}
|
||||
|
||||
if holdDID != "" {
|
||||
return holdDID, nil
|
||||
}
|
||||
|
||||
// No historical hold found, fall back to discovery
|
||||
return uc.resolveHoldForPush(ctx)
|
||||
}
|
||||
|
||||
// resolveHoldForPush discovers hold from owner's sailor profile or default
|
||||
func (uc *UserContext) resolveHoldForPush(ctx context.Context) (string, error) {
|
||||
// Create anonymous client to query owner's profile
|
||||
client := atproto.NewClient(uc.TargetOwnerPDS, uc.TargetOwnerDID, "")
|
||||
|
||||
// Try to get owner's sailor profile
|
||||
record, err := client.GetRecord(ctx, atproto.SailorProfileCollection, "self")
|
||||
if err == nil && record != nil {
|
||||
var profile atproto.SailorProfileRecord
|
||||
if jsonErr := json.Unmarshal(record.Value, &profile); jsonErr == nil {
|
||||
if profile.DefaultHold != "" {
|
||||
// Normalize to DID if needed
|
||||
holdDID := profile.DefaultHold
|
||||
if !atproto.IsDID(holdDID) {
|
||||
holdDID = atproto.ResolveHoldDIDFromURL(holdDID)
|
||||
}
|
||||
slog.Debug("Found hold from owner's profile",
|
||||
"component", "auth/context",
|
||||
"ownerDID", uc.TargetOwnerDID,
|
||||
"holdDID", holdDID)
|
||||
return holdDID, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default hold
|
||||
if uc.defaultHoldDID != "" {
|
||||
slog.Debug("Using default hold",
|
||||
"component", "auth/context",
|
||||
"ownerDID", uc.TargetOwnerDID,
|
||||
"defaultHoldDID", uc.defaultHoldDID)
|
||||
return uc.defaultHoldDID, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no hold configured for %s and no default hold set", uc.TargetOwnerDID)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Test Helper Methods
|
||||
// =============================================================================
|
||||
// These methods are designed to make UserContext testable by allowing tests
|
||||
// to bypass network-dependent code paths (PDS resolution, OAuth token fetching).
|
||||
// Only use these in tests - they are not intended for production use.
|
||||
|
||||
// SetPDSForTest sets the PDS endpoint directly, bypassing ResolvePDS network calls.
|
||||
// This allows tests to skip DID resolution which would make network requests.
|
||||
// Deprecated: Use SetPDS instead.
|
||||
func (uc *UserContext) SetPDSForTest(handle, pdsEndpoint string) {
|
||||
uc.SetPDS(handle, pdsEndpoint)
|
||||
}
|
||||
|
||||
// SetServiceTokenForTest pre-populates a service token for the given holdDID,
|
||||
// bypassing the sync.Once and OAuth/app-password fetching logic.
|
||||
// The token will appear as if it was already fetched and cached.
|
||||
func (uc *UserContext) SetServiceTokenForTest(holdDID, token string) {
|
||||
entry := &serviceTokenEntry{
|
||||
token: token,
|
||||
expiresAt: time.Now().Add(5 * time.Minute),
|
||||
err: nil,
|
||||
}
|
||||
// Mark the sync.Once as done so real fetch won't happen
|
||||
entry.once.Do(func() {})
|
||||
uc.serviceTokens.Store(holdDID, entry)
|
||||
}
|
||||
|
||||
// SetAuthorizerForTest sets the authorizer for permission checks.
|
||||
// Use with MockHoldAuthorizer to control CanRead/CanWrite behavior in tests.
|
||||
func (uc *UserContext) SetAuthorizerForTest(authorizer HoldAuthorizer) {
|
||||
uc.authorizer = authorizer
|
||||
}
|
||||
|
||||
// SetDefaultHoldDIDForTest sets the default hold DID for tests.
|
||||
// This is used as fallback when resolving hold for push operations.
|
||||
func (uc *UserContext) SetDefaultHoldDIDForTest(holdDID string) {
|
||||
uc.defaultHoldDID = holdDID
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -18,6 +19,44 @@ import (
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// Authentication errors
|
||||
var (
|
||||
ErrMissingAuthHeader = errors.New("missing Authorization header")
|
||||
ErrInvalidAuthFormat = errors.New("invalid Authorization header format")
|
||||
ErrInvalidAuthScheme = errors.New("invalid authorization scheme: expected 'Bearer' or 'DPoP'")
|
||||
ErrMissingToken = errors.New("missing token")
|
||||
ErrMissingDPoPHeader = errors.New("missing DPoP header")
|
||||
)
|
||||
|
||||
// JWT validation errors
|
||||
var (
|
||||
ErrInvalidJWTFormat = errors.New("invalid JWT format: expected header.payload.signature")
|
||||
ErrMissingISSClaim = errors.New("missing 'iss' claim in token")
|
||||
ErrMissingSubClaim = errors.New("missing 'sub' claim in token")
|
||||
ErrTokenExpired = errors.New("token has expired")
|
||||
)
|
||||
|
||||
// AuthError provides structured authorization error information
|
||||
type AuthError struct {
|
||||
Action string // The action being attempted: "blob:read", "blob:write", "crew:admin"
|
||||
Reason string // Why access was denied
|
||||
Required []string // What permission(s) would grant access
|
||||
}
|
||||
|
||||
func (e *AuthError) Error() string {
|
||||
return fmt.Sprintf("access denied for %s: %s (required: %s)",
|
||||
e.Action, e.Reason, strings.Join(e.Required, " or "))
|
||||
}
|
||||
|
||||
// NewAuthError creates a new AuthError
|
||||
func NewAuthError(action, reason string, required ...string) *AuthError {
|
||||
return &AuthError{
|
||||
Action: action,
|
||||
Reason: reason,
|
||||
Required: required,
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPClient interface allows injecting a custom HTTP client for testing
|
||||
type HTTPClient interface {
|
||||
Do(*http.Request) (*http.Response, error)
|
||||
@@ -44,13 +83,13 @@ func ValidateDPoPRequest(r *http.Request, httpClient HTTPClient) (*ValidatedUser
|
||||
// Extract Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return nil, fmt.Errorf("missing Authorization header")
|
||||
return nil, ErrMissingAuthHeader
|
||||
}
|
||||
|
||||
// Check for DPoP authorization scheme
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid Authorization header format")
|
||||
return nil, ErrInvalidAuthFormat
|
||||
}
|
||||
|
||||
if parts[0] != "DPoP" {
|
||||
@@ -59,13 +98,13 @@ func ValidateDPoPRequest(r *http.Request, httpClient HTTPClient) (*ValidatedUser
|
||||
|
||||
accessToken := parts[1]
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("missing access token")
|
||||
return nil, ErrMissingToken
|
||||
}
|
||||
|
||||
// Extract DPoP header
|
||||
dpopProof := r.Header.Get("DPoP")
|
||||
if dpopProof == "" {
|
||||
return nil, fmt.Errorf("missing DPoP header")
|
||||
return nil, ErrMissingDPoPHeader
|
||||
}
|
||||
|
||||
// TODO: We could verify the DPoP proof locally (signature, HTM, HTU, etc.)
|
||||
@@ -109,7 +148,7 @@ func extractDIDFromToken(token string) (string, string, error) {
|
||||
// JWT format: header.payload.signature
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return "", "", fmt.Errorf("invalid JWT format")
|
||||
return "", "", ErrInvalidJWTFormat
|
||||
}
|
||||
|
||||
// Decode payload (base64url)
|
||||
@@ -129,11 +168,11 @@ func extractDIDFromToken(token string) (string, string, error) {
|
||||
}
|
||||
|
||||
if claims.Sub == "" {
|
||||
return "", "", fmt.Errorf("missing sub claim (DID)")
|
||||
return "", "", ErrMissingSubClaim
|
||||
}
|
||||
|
||||
if claims.Iss == "" {
|
||||
return "", "", fmt.Errorf("missing iss claim (PDS)")
|
||||
return "", "", ErrMissingISSClaim
|
||||
}
|
||||
|
||||
return claims.Sub, claims.Iss, nil
|
||||
@@ -216,7 +255,7 @@ func ValidateOwnerOrCrewAdmin(r *http.Request, pds *HoldPDS, httpClient HTTPClie
|
||||
return nil, fmt.Errorf("DPoP authentication failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing or invalid Authorization header (expected Bearer or DPoP)")
|
||||
return nil, ErrInvalidAuthScheme
|
||||
}
|
||||
|
||||
// Get captain record to check owner
|
||||
@@ -243,12 +282,12 @@ func ValidateOwnerOrCrewAdmin(r *http.Request, pds *HoldPDS, httpClient HTTPClie
|
||||
return user, nil
|
||||
}
|
||||
// User is crew but doesn't have admin permission
|
||||
return nil, fmt.Errorf("crew member lacks required 'crew:admin' permission")
|
||||
return nil, NewAuthError("crew:admin", "crew member lacks permission", "crew:admin")
|
||||
}
|
||||
}
|
||||
|
||||
// User is neither owner nor authorized crew
|
||||
return nil, fmt.Errorf("user is not authorized (must be hold owner or crew admin)")
|
||||
return nil, NewAuthError("crew:admin", "user is not a crew member", "crew:admin")
|
||||
}
|
||||
|
||||
// ValidateBlobWriteAccess validates that the request has valid authentication
|
||||
@@ -276,7 +315,7 @@ func ValidateBlobWriteAccess(r *http.Request, pds *HoldPDS, httpClient HTTPClien
|
||||
return nil, fmt.Errorf("DPoP authentication failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing or invalid Authorization header (expected Bearer or DPoP)")
|
||||
return nil, ErrInvalidAuthScheme
|
||||
}
|
||||
|
||||
// Get captain record to check owner and public settings
|
||||
@@ -303,17 +342,18 @@ func ValidateBlobWriteAccess(r *http.Request, pds *HoldPDS, httpClient HTTPClien
|
||||
return user, nil
|
||||
}
|
||||
// User is crew but doesn't have write permission
|
||||
return nil, fmt.Errorf("crew member lacks required 'blob:write' permission")
|
||||
return nil, NewAuthError("blob:write", "crew member lacks permission", "blob:write")
|
||||
}
|
||||
}
|
||||
|
||||
// User is neither owner nor authorized crew
|
||||
return nil, fmt.Errorf("user is not authorized for blob write (must be hold owner or crew with blob:write permission)")
|
||||
return nil, NewAuthError("blob:write", "user is not a crew member", "blob:write")
|
||||
}
|
||||
|
||||
// ValidateBlobReadAccess validates that the request has read access to blobs
|
||||
// If captain.public = true: No auth required (returns nil user to indicate public access)
|
||||
// If captain.public = false: Requires valid DPoP + OAuth and (captain OR crew with blob:read permission).
|
||||
// If captain.public = false: Requires valid DPoP + OAuth and (captain OR crew with blob:read or blob:write permission).
|
||||
// Note: blob:write implicitly grants blob:read access.
|
||||
// The httpClient parameter is optional and defaults to http.DefaultClient if nil.
|
||||
func ValidateBlobReadAccess(r *http.Request, pds *HoldPDS, httpClient HTTPClient) (*ValidatedUser, error) {
|
||||
// Get captain record to check public setting
|
||||
@@ -344,7 +384,7 @@ func ValidateBlobReadAccess(r *http.Request, pds *HoldPDS, httpClient HTTPClient
|
||||
return nil, fmt.Errorf("DPoP authentication failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing or invalid Authorization header (expected Bearer or DPoP)")
|
||||
return nil, ErrInvalidAuthScheme
|
||||
}
|
||||
|
||||
// Check if user is the owner (always has read access)
|
||||
@@ -352,7 +392,8 @@ func ValidateBlobReadAccess(r *http.Request, pds *HoldPDS, httpClient HTTPClient
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Check if user is crew with blob:read permission
|
||||
// Check if user is crew with blob:read or blob:write permission
|
||||
// Note: blob:write implicitly grants blob:read access
|
||||
crew, err := pds.ListCrewMembers(r.Context())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check crew membership: %w", err)
|
||||
@@ -360,17 +401,19 @@ func ValidateBlobReadAccess(r *http.Request, pds *HoldPDS, httpClient HTTPClient
|
||||
|
||||
for _, member := range crew {
|
||||
if member.Record.Member == user.DID {
|
||||
// Check if this crew member has blob:read permission
|
||||
if slices.Contains(member.Record.Permissions, "blob:read") {
|
||||
// Check if this crew member has blob:read or blob:write permission
|
||||
// blob:write implicitly grants read access (can't push without pulling)
|
||||
if slices.Contains(member.Record.Permissions, "blob:read") ||
|
||||
slices.Contains(member.Record.Permissions, "blob:write") {
|
||||
return user, nil
|
||||
}
|
||||
// User is crew but doesn't have read permission
|
||||
return nil, fmt.Errorf("crew member lacks required 'blob:read' permission")
|
||||
// User is crew but doesn't have read or write permission
|
||||
return nil, NewAuthError("blob:read", "crew member lacks permission", "blob:read", "blob:write")
|
||||
}
|
||||
}
|
||||
|
||||
// User is neither owner nor authorized crew
|
||||
return nil, fmt.Errorf("user is not authorized for blob read (must be hold owner or crew with blob:read permission)")
|
||||
return nil, NewAuthError("blob:read", "user is not a crew member", "blob:read", "blob:write")
|
||||
}
|
||||
|
||||
// ServiceTokenClaims represents the claims in a service token JWT
|
||||
@@ -385,13 +428,13 @@ func ValidateServiceToken(r *http.Request, holdDID string, httpClient HTTPClient
|
||||
// Extract Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return nil, fmt.Errorf("missing Authorization header")
|
||||
return nil, ErrMissingAuthHeader
|
||||
}
|
||||
|
||||
// Check for Bearer authorization scheme
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid Authorization header format")
|
||||
return nil, ErrInvalidAuthFormat
|
||||
}
|
||||
|
||||
if parts[0] != "Bearer" {
|
||||
@@ -400,7 +443,7 @@ func ValidateServiceToken(r *http.Request, holdDID string, httpClient HTTPClient
|
||||
|
||||
tokenString := parts[1]
|
||||
if tokenString == "" {
|
||||
return nil, fmt.Errorf("missing token")
|
||||
return nil, ErrMissingToken
|
||||
}
|
||||
|
||||
slog.Debug("Validating service token", "holdDID", holdDID)
|
||||
@@ -409,7 +452,7 @@ func ValidateServiceToken(r *http.Request, holdDID string, httpClient HTTPClient
|
||||
// Split token: header.payload.signature
|
||||
tokenParts := strings.Split(tokenString, ".")
|
||||
if len(tokenParts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format")
|
||||
return nil, ErrInvalidJWTFormat
|
||||
}
|
||||
|
||||
// Decode payload (second part) to extract claims
|
||||
@@ -427,7 +470,7 @@ func ValidateServiceToken(r *http.Request, holdDID string, httpClient HTTPClient
|
||||
// Get issuer (user DID)
|
||||
issuerDID := claims.Issuer
|
||||
if issuerDID == "" {
|
||||
return nil, fmt.Errorf("missing iss claim")
|
||||
return nil, ErrMissingISSClaim
|
||||
}
|
||||
|
||||
// Verify audience matches this hold service
|
||||
@@ -445,7 +488,7 @@ func ValidateServiceToken(r *http.Request, holdDID string, httpClient HTTPClient
|
||||
return nil, fmt.Errorf("failed to get expiration: %w", err)
|
||||
}
|
||||
if exp != nil && time.Now().After(exp.Time) {
|
||||
return nil, fmt.Errorf("token has expired")
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
|
||||
// Verify JWT signature using ATProto's secp256k1 crypto
|
||||
|
||||
@@ -771,6 +771,116 @@ func TestValidateBlobReadAccess_PrivateHold(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateBlobReadAccess_BlobWriteImpliesRead tests that blob:write grants read access
|
||||
func TestValidateBlobReadAccess_BlobWriteImpliesRead(t *testing.T) {
|
||||
ownerDID := "did:plc:owner123"
|
||||
|
||||
pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, false, false)
|
||||
|
||||
// Verify captain record has public=false (private hold)
|
||||
_, captain, err := pds.GetCaptainRecord(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get captain record: %v", err)
|
||||
}
|
||||
|
||||
if captain.Public {
|
||||
t.Error("Expected public=false for captain record")
|
||||
}
|
||||
|
||||
// Add crew member with ONLY blob:write permission (no blob:read)
|
||||
writerDID := "did:plc:writer123"
|
||||
_, err = pds.AddCrewMember(ctx, writerDID, "writer", []string{"blob:write"})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add crew writer: %v", err)
|
||||
}
|
||||
|
||||
mockClient := &mockPDSClient{}
|
||||
|
||||
// Test writer (has only blob:write permission) can read
|
||||
t.Run("crew with blob:write can read", func(t *testing.T) {
|
||||
dpopHelper, err := NewDPoPTestHelper(writerDID, "https://test-pds.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create DPoP helper: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
if err := dpopHelper.AddDPoPToRequest(req); err != nil {
|
||||
t.Fatalf("Failed to add DPoP to request: %v", err)
|
||||
}
|
||||
|
||||
// This should SUCCEED because blob:write implies blob:read
|
||||
user, err := ValidateBlobReadAccess(req, pds, mockClient)
|
||||
if err != nil {
|
||||
t.Errorf("Expected blob:write to grant read access, got error: %v", err)
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
t.Error("Expected user to be returned for valid read access")
|
||||
} else if user.DID != writerDID {
|
||||
t.Errorf("Expected user DID %s, got %s", writerDID, user.DID)
|
||||
}
|
||||
})
|
||||
|
||||
// Also verify that crew with only blob:read still works
|
||||
t.Run("crew with blob:read can read", func(t *testing.T) {
|
||||
readerDID := "did:plc:reader123"
|
||||
_, err = pds.AddCrewMember(ctx, readerDID, "reader", []string{"blob:read"})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add crew reader: %v", err)
|
||||
}
|
||||
|
||||
dpopHelper, err := NewDPoPTestHelper(readerDID, "https://test-pds.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create DPoP helper: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
if err := dpopHelper.AddDPoPToRequest(req); err != nil {
|
||||
t.Fatalf("Failed to add DPoP to request: %v", err)
|
||||
}
|
||||
|
||||
user, err := ValidateBlobReadAccess(req, pds, mockClient)
|
||||
if err != nil {
|
||||
t.Errorf("Expected blob:read to grant read access, got error: %v", err)
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
t.Error("Expected user to be returned for valid read access")
|
||||
} else if user.DID != readerDID {
|
||||
t.Errorf("Expected user DID %s, got %s", readerDID, user.DID)
|
||||
}
|
||||
})
|
||||
|
||||
// Verify crew with neither permission cannot read
|
||||
t.Run("crew without read or write cannot read", func(t *testing.T) {
|
||||
noPermDID := "did:plc:noperm123"
|
||||
_, err = pds.AddCrewMember(ctx, noPermDID, "noperm", []string{"crew:admin"})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add crew member: %v", err)
|
||||
}
|
||||
|
||||
dpopHelper, err := NewDPoPTestHelper(noPermDID, "https://test-pds.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create DPoP helper: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
if err := dpopHelper.AddDPoPToRequest(req); err != nil {
|
||||
t.Fatalf("Failed to add DPoP to request: %v", err)
|
||||
}
|
||||
|
||||
_, err = ValidateBlobReadAccess(req, pds, mockClient)
|
||||
if err == nil {
|
||||
t.Error("Expected error for crew without read or write permission")
|
||||
}
|
||||
|
||||
// Verify error message format
|
||||
if !strings.Contains(err.Error(), "access denied for blob:read") {
|
||||
t.Errorf("Expected structured error message, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestValidateOwnerOrCrewAdmin tests admin permission checking
|
||||
func TestValidateOwnerOrCrewAdmin(t *testing.T) {
|
||||
ownerDID := "did:plc:owner123"
|
||||
|
||||
@@ -18,8 +18,8 @@ const (
|
||||
// CreateCaptainRecord creates the captain record for the hold (first-time only).
|
||||
// This will FAIL if the captain record already exists. Use UpdateCaptainRecord to modify.
|
||||
func (p *HoldPDS) CreateCaptainRecord(ctx context.Context, ownerDID string, public bool, allowAllCrew bool, enableBlueskyPosts bool) (cid.Cid, error) {
|
||||
captainRecord := &atproto.HoldCaptain{
|
||||
LexiconTypeID: atproto.CaptainCollection,
|
||||
captainRecord := &atproto.CaptainRecord{
|
||||
Type: atproto.CaptainCollection,
|
||||
Owner: ownerDID,
|
||||
Public: public,
|
||||
AllowAllCrew: allowAllCrew,
|
||||
@@ -40,7 +40,7 @@ func (p *HoldPDS) CreateCaptainRecord(ctx context.Context, ownerDID string, publ
|
||||
}
|
||||
|
||||
// GetCaptainRecord retrieves the captain record
|
||||
func (p *HoldPDS) GetCaptainRecord(ctx context.Context) (cid.Cid, *atproto.HoldCaptain, error) {
|
||||
func (p *HoldPDS) GetCaptainRecord(ctx context.Context) (cid.Cid, *atproto.CaptainRecord, error) {
|
||||
// Use repomgr.GetRecord - our types are registered in init()
|
||||
// so it will automatically unmarshal to the concrete type
|
||||
recordCID, val, err := p.repomgr.GetRecord(ctx, p.uid, atproto.CaptainCollection, CaptainRkey, cid.Undef)
|
||||
@@ -49,7 +49,7 @@ func (p *HoldPDS) GetCaptainRecord(ctx context.Context) (cid.Cid, *atproto.HoldC
|
||||
}
|
||||
|
||||
// Type assert to our concrete type
|
||||
captainRecord, ok := val.(*atproto.HoldCaptain)
|
||||
captainRecord, ok := val.(*atproto.CaptainRecord)
|
||||
if !ok {
|
||||
return cid.Undef, nil, fmt.Errorf("unexpected type for captain record: %T", val)
|
||||
}
|
||||
|
||||
@@ -12,11 +12,6 @@ import (
|
||||
"atcr.io/pkg/atproto"
|
||||
)
|
||||
|
||||
// ptrString returns a pointer to the given string
|
||||
func ptrString(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// setupTestPDS creates a test PDS instance in a temporary directory
|
||||
// It initializes the repo but does NOT create captain/crew records
|
||||
// Tests should call Bootstrap or create records as needed
|
||||
@@ -151,8 +146,8 @@ func TestCreateCaptainRecord(t *testing.T) {
|
||||
if captain.EnableBlueskyPosts != tt.enableBlueskyPosts {
|
||||
t.Errorf("Expected enableBlueskyPosts=%v, got %v", tt.enableBlueskyPosts, captain.EnableBlueskyPosts)
|
||||
}
|
||||
if captain.LexiconTypeID != atproto.CaptainCollection {
|
||||
t.Errorf("Expected type %s, got %s", atproto.CaptainCollection, captain.LexiconTypeID)
|
||||
if captain.Type != atproto.CaptainCollection {
|
||||
t.Errorf("Expected type %s, got %s", atproto.CaptainCollection, captain.Type)
|
||||
}
|
||||
if captain.DeployedAt == "" {
|
||||
t.Error("Expected deployedAt to be set")
|
||||
@@ -327,40 +322,40 @@ func TestUpdateCaptainRecord_NotFound(t *testing.T) {
|
||||
func TestCaptainRecord_CBORRoundtrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
record *atproto.HoldCaptain
|
||||
record *atproto.CaptainRecord
|
||||
}{
|
||||
{
|
||||
name: "Basic captain",
|
||||
record: &atproto.HoldCaptain{
|
||||
LexiconTypeID: atproto.CaptainCollection,
|
||||
Owner: "did:plc:alice123",
|
||||
Public: true,
|
||||
AllowAllCrew: false,
|
||||
DeployedAt: "2025-10-16T12:00:00Z",
|
||||
record: &atproto.CaptainRecord{
|
||||
Type: atproto.CaptainCollection,
|
||||
Owner: "did:plc:alice123",
|
||||
Public: true,
|
||||
AllowAllCrew: false,
|
||||
DeployedAt: "2025-10-16T12:00:00Z",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Captain with optional fields",
|
||||
record: &atproto.HoldCaptain{
|
||||
LexiconTypeID: atproto.CaptainCollection,
|
||||
Owner: "did:plc:bob456",
|
||||
Public: false,
|
||||
AllowAllCrew: true,
|
||||
DeployedAt: "2025-10-16T12:00:00Z",
|
||||
Region: ptrString("us-west-2"),
|
||||
Provider: ptrString("fly.io"),
|
||||
record: &atproto.CaptainRecord{
|
||||
Type: atproto.CaptainCollection,
|
||||
Owner: "did:plc:bob456",
|
||||
Public: false,
|
||||
AllowAllCrew: true,
|
||||
DeployedAt: "2025-10-16T12:00:00Z",
|
||||
Region: "us-west-2",
|
||||
Provider: "fly.io",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Captain with empty optional fields",
|
||||
record: &atproto.HoldCaptain{
|
||||
LexiconTypeID: atproto.CaptainCollection,
|
||||
Owner: "did:plc:charlie789",
|
||||
Public: true,
|
||||
AllowAllCrew: true,
|
||||
DeployedAt: "2025-10-16T12:00:00Z",
|
||||
Region: ptrString(""),
|
||||
Provider: ptrString(""),
|
||||
record: &atproto.CaptainRecord{
|
||||
Type: atproto.CaptainCollection,
|
||||
Owner: "did:plc:charlie789",
|
||||
Public: true,
|
||||
AllowAllCrew: true,
|
||||
DeployedAt: "2025-10-16T12:00:00Z",
|
||||
Region: "",
|
||||
Provider: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -380,15 +375,15 @@ func TestCaptainRecord_CBORRoundtrip(t *testing.T) {
|
||||
}
|
||||
|
||||
// Unmarshal from CBOR
|
||||
var decoded atproto.HoldCaptain
|
||||
var decoded atproto.CaptainRecord
|
||||
err = decoded.UnmarshalCBOR(bytes.NewReader(cborBytes))
|
||||
if err != nil {
|
||||
t.Fatalf("UnmarshalCBOR failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify all fields match
|
||||
if decoded.LexiconTypeID != tt.record.LexiconTypeID {
|
||||
t.Errorf("LexiconTypeID mismatch: expected %s, got %s", tt.record.LexiconTypeID, decoded.LexiconTypeID)
|
||||
if decoded.Type != tt.record.Type {
|
||||
t.Errorf("Type mismatch: expected %s, got %s", tt.record.Type, decoded.Type)
|
||||
}
|
||||
if decoded.Owner != tt.record.Owner {
|
||||
t.Errorf("Owner mismatch: expected %s, got %s", tt.record.Owner, decoded.Owner)
|
||||
@@ -402,17 +397,11 @@ func TestCaptainRecord_CBORRoundtrip(t *testing.T) {
|
||||
if decoded.DeployedAt != tt.record.DeployedAt {
|
||||
t.Errorf("DeployedAt mismatch: expected %s, got %s", tt.record.DeployedAt, decoded.DeployedAt)
|
||||
}
|
||||
// Compare Region pointers (may be nil)
|
||||
if (decoded.Region == nil) != (tt.record.Region == nil) {
|
||||
t.Errorf("Region nil mismatch: expected %v, got %v", tt.record.Region, decoded.Region)
|
||||
} else if decoded.Region != nil && *decoded.Region != *tt.record.Region {
|
||||
t.Errorf("Region mismatch: expected %q, got %q", *tt.record.Region, *decoded.Region)
|
||||
if decoded.Region != tt.record.Region {
|
||||
t.Errorf("Region mismatch: expected %s, got %s", tt.record.Region, decoded.Region)
|
||||
}
|
||||
// Compare Provider pointers (may be nil)
|
||||
if (decoded.Provider == nil) != (tt.record.Provider == nil) {
|
||||
t.Errorf("Provider nil mismatch: expected %v, got %v", tt.record.Provider, decoded.Provider)
|
||||
} else if decoded.Provider != nil && *decoded.Provider != *tt.record.Provider {
|
||||
t.Errorf("Provider mismatch: expected %q, got %q", *tt.record.Provider, *decoded.Provider)
|
||||
if decoded.Provider != tt.record.Provider {
|
||||
t.Errorf("Provider mismatch: expected %s, got %s", tt.record.Provider, decoded.Provider)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user