22 Commits

Author SHA1 Message Date
Evan Jarrett
31dc4b4f53 major refactor to implement usercontext 2025-12-29 17:02:07 -06:00
Evan Jarrett
af99929aa3 remove old test file 2025-12-29 17:01:48 -06:00
Evan Jarrett
7f2d780b0a move packages out of token that are not related to docker jwt token 2025-12-29 16:57:14 -06:00
Evan Jarrett
8956568ed2 remove unused filestore. replace it with memstore for tests 2025-12-29 16:51:08 -06:00
Evan Jarrett
c1f2ae0f7a fix scope mismatch? 2025-12-26 17:41:38 -06:00
Evan Jarrett
012a14c4ee try fix permission scope again 2025-12-26 17:13:19 -06:00
Evan Jarrett
4cda163099 add back individual scopes 2025-12-26 17:05:51 -06:00
Evan Jarrett
41bcee4a59 try new permission sets 2025-12-26 16:51:49 -06:00
Evan Jarrett
24d6b49481 clean up unused locks 2025-12-26 09:48:25 -06:00
Evan Jarrett
363c12e6bf remove unused function 2025-12-26 09:37:57 -06:00
Evan Jarrett
2a60a47fd5 fix issues pulling other users images. fix labels taking priority over annotations. fix various auth errors 2025-12-23 16:20:52 -06:00
Evan Jarrett
34c2b8b17c add a cache-control header to metadata page 2025-12-22 21:01:28 -06:00
Evan Jarrett
8d0cff63fb add 404 page 2025-12-22 12:43:18 -06:00
Evan Jarrett
d11356cd18 more improvements on repo page rendering. allow for repo avatar image uploads (requires new scopes) 2025-12-21 21:51:44 -06:00
Evan Jarrett
79d1126726 better handling for io.atcr.repo.page 2025-12-20 21:50:09 -06:00
Evan Jarrett
8e31137c62 better logic for relative urls 2025-12-20 16:48:08 -06:00
Evan Jarrett
023efb05aa add in the lexicon json 2025-12-20 16:32:55 -06:00
Evan Jarrett
b18e4c3996 implement io.atcr.repo.page. try and fetch from github,gitlab,tangled README.md files if source exists. 2025-12-20 16:32:41 -06:00
Evan Jarrett
24b265bf12 try and fetch from github/gitlab/tangled READMEs 2025-12-20 16:00:15 -06:00
Evan Jarrett
e8e375639d lexicon validation fix 2025-12-20 11:30:08 -06:00
Evan Jarrett
5a208de4c9 add attestation badge to tags 2025-12-20 11:00:24 -06:00
Evan Jarrett
104eb86c04 fix go version 2025-12-20 10:49:37 -06:00
106 changed files with 5572 additions and 8420 deletions

View File

@@ -475,12 +475,47 @@ Lightweight standalone service for BYOS (Bring Your Own Storage) with embedded P
Read access:
- **Public hold** (`HOLD_PUBLIC=true`): Anonymous + all authenticated users
- **Private hold** (`HOLD_PUBLIC=false`): Requires authentication + crew membership with blob:read permission
- **Private hold** (`HOLD_PUBLIC=false`): Requires authentication + crew membership with blob:read OR blob:write permission
- **Note:** `blob:write` implicitly grants `blob:read` access (can't push without pulling)
Write access:
- Hold owner OR crew members with blob:write permission
- Verified via `io.atcr.hold.crew` records in hold's embedded PDS
**Permission Matrix:**
| User Type | Public Read | Private Read | Write | Crew Admin |
|-----------|-------------|--------------|-------|------------|
| Anonymous | Yes | No | No | No |
| Owner (captain) | Yes | Yes | Yes | Yes (implied) |
| Crew (blob:read only) | Yes | Yes | No | No |
| Crew (blob:write only) | Yes | Yes* | Yes | No |
| Crew (blob:read + blob:write) | Yes | Yes | Yes | No |
| Crew (crew:admin) | Yes | Yes | Yes | Yes |
| Authenticated non-crew | Yes | No | No | No |
*`blob:write` implicitly grants `blob:read` access
**Authorization Error Format:**
All authorization failures use consistent structured errors (`pkg/hold/pds/auth.go`):
```
access denied for [action]: [reason] (required: [permission(s)])
```
Examples:
- `access denied for blob:read: user is not a crew member (required: blob:read or blob:write)`
- `access denied for blob:write: crew member lacks permission (required: blob:write)`
- `access denied for crew:admin: user is not a crew member (required: crew:admin)`
**Shared Error Constants** (`pkg/hold/pds/auth.go`):
- `ErrMissingAuthHeader` - Missing Authorization header
- `ErrInvalidAuthFormat` - Invalid Authorization header format
- `ErrInvalidAuthScheme` - Invalid scheme (expected Bearer or DPoP)
- `ErrInvalidJWTFormat` - Malformed JWT
- `ErrMissingISSClaim` / `ErrMissingSubClaim` - Missing JWT claims
- `ErrTokenExpired` - Token has expired
**Embedded PDS Endpoints** (`pkg/hold/pds/xrpc.go`):
Standard ATProto sync endpoints:

View File

@@ -1,6 +1,6 @@
# Production build for ATCR AppView
# Result: ~30MB scratch image with static binary
FROM docker.io/golang:1.25.2-trixie AS builder
FROM docker.io/golang:1.25.4-trixie AS builder
ENV DEBIAN_FRONTEND=noninteractive
@@ -34,12 +34,12 @@ EXPOSE 5000
LABEL org.opencontainers.image.title="ATCR AppView" \
org.opencontainers.image.description="ATProto Container Registry - OCI-compliant registry using AT Protocol for manifest storage" \
org.opencontainers.image.authors="ATCR Contributors" \
org.opencontainers.image.source="https://tangled.org/@evan.jarrett.net/at-container-registry" \
org.opencontainers.image.documentation="https://tangled.org/@evan.jarrett.net/at-container-registry" \
org.opencontainers.image.source="https://tangled.org/evan.jarrett.net/at-container-registry" \
org.opencontainers.image.documentation="https://tangled.org/evan.jarrett.net/at-container-registry" \
org.opencontainers.image.licenses="MIT" \
org.opencontainers.image.version="0.1.0" \
io.atcr.icon="https://imgs.blue/evan.jarrett.net/1TpTNrRelfloN2emuWZDrWmPT0o93bAjEnozjD6UPgoVV9m4" \
io.atcr.readme="https://tangled.org/@evan.jarrett.net/at-container-registry/raw/main/docs/appview.md"
io.atcr.readme="https://tangled.org/evan.jarrett.net/at-container-registry/raw/main/docs/appview.md"
ENTRYPOINT ["/atcr-appview"]
CMD ["serve"]

View File

@@ -1,7 +1,7 @@
# Development image with Air hot reload
# Build: docker build -f Dockerfile.dev -t atcr-appview-dev .
# Run: docker run -v $(pwd):/app -p 5000:5000 atcr-appview-dev
FROM docker.io/golang:1.25.2-trixie
FROM docker.io/golang:1.25.4-trixie
ENV DEBIAN_FRONTEND=noninteractive

View File

@@ -1,4 +1,4 @@
FROM docker.io/golang:1.25.2-trixie AS builder
FROM docker.io/golang:1.25.4-trixie AS builder
ENV DEBIAN_FRONTEND=noninteractive
@@ -38,11 +38,11 @@ EXPOSE 8080
LABEL org.opencontainers.image.title="ATCR Hold Service" \
org.opencontainers.image.description="ATCR Hold Service - Bring Your Own Storage component for ATCR" \
org.opencontainers.image.authors="ATCR Contributors" \
org.opencontainers.image.source="https://tangled.org/@evan.jarrett.net/at-container-registry" \
org.opencontainers.image.documentation="https://tangled.org/@evan.jarrett.net/at-container-registry" \
org.opencontainers.image.source="https://tangled.org/evan.jarrett.net/at-container-registry" \
org.opencontainers.image.documentation="https://tangled.org/evan.jarrett.net/at-container-registry" \
org.opencontainers.image.licenses="MIT" \
org.opencontainers.image.version="0.1.0" \
io.atcr.icon="https://imgs.blue/evan.jarrett.net/1TpTOdtS60GdJWBYEqtK22y688jajbQ9a5kbYRFtwuqrkBAE" \
io.atcr.readme="https://tangled.org/@evan.jarrett.net/at-container-registry/raw/main/docs/hold.md"
io.atcr.readme="https://tangled.org/evan.jarrett.net/at-container-registry/raw/main/docs/hold.md"
ENTRYPOINT ["/atcr-hold"]

View File

@@ -82,9 +82,8 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
slog.Info("Initializing hold health checker", "cache_ttl", cfg.Health.CacheTTL)
healthChecker := holdhealth.NewChecker(cfg.Health.CacheTTL)
// Initialize README cache
slog.Info("Initializing README cache", "cache_ttl", cfg.Health.ReadmeCacheTTL)
readmeCache := readme.NewCache(uiDatabase, cfg.Health.ReadmeCacheTTL)
// Initialize README fetcher for rendering repo page descriptions
readmeFetcher := readme.NewFetcher()
// Start background health check worker
startupDelay := 5 * time.Second // Wait for hold services to start (Docker compose)
@@ -151,20 +150,15 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
middleware.SetGlobalRefresher(refresher)
// Set global database for pull/push metrics tracking
metricsDB := db.NewMetricsDB(uiDatabase)
middleware.SetGlobalDatabase(metricsDB)
middleware.SetGlobalDatabase(uiDatabase)
// Create RemoteHoldAuthorizer for hold authorization with caching
holdAuthorizer := auth.NewRemoteHoldAuthorizer(uiDatabase, testMode)
middleware.SetGlobalAuthorizer(holdAuthorizer)
slog.Info("Hold authorizer initialized with database caching")
// Set global readme cache for middleware
middleware.SetGlobalReadmeCache(readmeCache)
slog.Info("README cache initialized for manifest push refresh")
// Initialize Jetstream workers (background services before HTTP routes)
initializeJetstream(uiDatabase, &cfg.Jetstream, defaultHoldDID, testMode)
initializeJetstream(uiDatabase, &cfg.Jetstream, defaultHoldDID, testMode, refresher)
// Create main chi router
mainRouter := chi.NewRouter()
@@ -194,8 +188,9 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
BaseURL: baseURL,
DeviceStore: deviceStore,
HealthChecker: healthChecker,
ReadmeCache: readmeCache,
ReadmeFetcher: readmeFetcher,
Templates: uiTemplates,
DefaultHoldDID: defaultHoldDID,
})
}
}
@@ -217,14 +212,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
// Create ATProto client with session provider (uses DoWithSession for DPoP nonce safety)
client := atproto.NewClientWithSessionProvider(pdsEndpoint, did, refresher)
// Ensure sailor profile exists (creates with default hold if configured)
slog.Debug("Ensuring profile exists", "component", "appview/callback", "did", did, "default_hold_did", defaultHoldDID)
if err := storage.EnsureProfile(ctx, client, defaultHoldDID); err != nil {
slog.Warn("Failed to ensure profile", "component", "appview/callback", "did", did, "error", err)
// Continue anyway - profile creation is not critical for avatar fetch
} else {
slog.Debug("Profile ensured", "component", "appview/callback", "did", did)
}
// Note: Profile and crew setup now happen automatically via UserContext.EnsureUserSetup()
// Fetch user's profile record from PDS (contains blob references)
profileRecord, err := client.GetProfileRecord(ctx, did)
@@ -275,36 +263,23 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
return nil // Non-fatal
}
var holdDID string
if profile != nil && profile.DefaultHold != nil && *profile.DefaultHold != "" {
defaultHold := *profile.DefaultHold
// Migrate profile URL→DID if needed (legacy migration, crew registration now handled by UserContext)
if profile != nil && profile.DefaultHold != "" {
// Check if defaultHold is a URL (needs migration)
if strings.HasPrefix(defaultHold, "http://") || strings.HasPrefix(defaultHold, "https://") {
slog.Debug("Migrating hold URL to DID", "component", "appview/callback", "did", did, "hold_url", defaultHold)
if strings.HasPrefix(profile.DefaultHold, "http://") || strings.HasPrefix(profile.DefaultHold, "https://") {
slog.Debug("Migrating hold URL to DID", "component", "appview/callback", "did", did, "hold_url", profile.DefaultHold)
// Resolve URL to DID
holdDID = atproto.ResolveHoldDIDFromURL(defaultHold)
holdDID := atproto.ResolveHoldDIDFromURL(profile.DefaultHold)
// Update profile with DID
profile.DefaultHold = &holdDID
profile.DefaultHold = holdDID
if err := storage.UpdateProfile(ctx, client, profile); err != nil {
slog.Warn("Failed to update profile with hold DID", "component", "appview/callback", "did", did, "error", err)
} else {
slog.Debug("Updated profile with hold DID", "component", "appview/callback", "hold_did", holdDID)
}
} else {
// Already a DID - use it
holdDID = defaultHold
}
// Register crew regardless of migration (outside the migration block)
// Run in background to avoid blocking OAuth callback if hold is offline
// Use background context - don't inherit request context which gets canceled on response
slog.Debug("Attempting crew registration", "component", "appview/callback", "did", did, "hold_did", holdDID)
go func(client *atproto.Client, refresher *oauth.Refresher, holdDID string) {
ctx := context.Background()
storage.EnsureCrewMembership(ctx, client, refresher, holdDID)
}(client, refresher, holdDID)
}
return nil // All errors are non-fatal, logged for debugging
@@ -326,10 +301,19 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
ctx := context.Background()
app := handlers.NewApp(ctx, cfg.Distribution)
// Wrap registry app with auth method extraction middleware
// This extracts the auth method from the JWT and stores it in the request context
// Wrap registry app with middleware chain:
// 1. ExtractAuthMethod - extracts auth method from JWT and stores in context
// 2. UserContextMiddleware - builds UserContext with identity, permissions, service tokens
wrappedApp := middleware.ExtractAuthMethod(app)
// Create dependencies for UserContextMiddleware
userContextDeps := &auth.Dependencies{
Refresher: refresher,
Authorizer: holdAuthorizer,
DefaultHoldDID: defaultHoldDID,
}
wrappedApp = middleware.UserContextMiddleware(userContextDeps)(wrappedApp)
// Mount registry at /v2/
mainRouter.Handle("/v2/*", wrappedApp)
@@ -398,6 +382,9 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
// Limit caching to allow scope changes to propagate quickly
// PDS servers cache client metadata, so short max-age helps with updates
w.Header().Set("Cache-Control", "public, max-age=300")
if err := json.NewEncoder(w).Encode(metadataMap); err != nil {
http.Error(w, "Failed to encode metadata", http.StatusInternalServerError)
}
@@ -415,23 +402,11 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
// Prevents the flood of errors when a stale session is discovered during push
tokenHandler.SetOAuthSessionValidator(refresher)
// Register token post-auth callback for profile management
// This decouples the token package from AppView-specific dependencies
// Register token post-auth callback
// Note: Profile and crew setup now happen automatically via UserContext.EnsureUserSetup()
tokenHandler.SetPostAuthCallback(func(ctx context.Context, did, handle, pdsEndpoint, accessToken string) error {
slog.Debug("Token post-auth callback", "component", "appview/callback", "did", did)
// Create ATProto client with validated token
atprotoClient := atproto.NewClient(pdsEndpoint, did, accessToken)
// Ensure profile exists (will create with default hold if not exists and default is configured)
if err := storage.EnsureProfile(ctx, atprotoClient, defaultHoldDID); err != nil {
// Log error but don't fail auth - profile management is not critical
slog.Warn("Failed to ensure profile", "component", "appview/callback", "did", did, "error", err)
} else {
slog.Debug("Profile ensured with default hold", "component", "appview/callback", "did", did, "default_hold_did", defaultHoldDID)
}
return nil // All errors are non-fatal
return nil
})
mainRouter.Get("/auth/token", tokenHandler.ServeHTTP)
@@ -520,7 +495,7 @@ func createTokenIssuer(cfg *appview.Config) (*token.Issuer, error) {
}
// initializeJetstream initializes the Jetstream workers for real-time events and backfill
func initializeJetstream(database *sql.DB, jetstreamCfg *appview.JetstreamConfig, defaultHoldDID string, testMode bool) {
func initializeJetstream(database *sql.DB, jetstreamCfg *appview.JetstreamConfig, defaultHoldDID string, testMode bool, refresher *oauth.Refresher) {
// Start Jetstream worker
jetstreamURL := jetstreamCfg.URL
@@ -544,7 +519,7 @@ func initializeJetstream(database *sql.DB, jetstreamCfg *appview.JetstreamConfig
// Get relay endpoint for sync API (defaults to Bluesky's relay)
relayEndpoint := jetstreamCfg.RelayEndpoint
backfillWorker, err := jetstream.NewBackfillWorker(database, relayEndpoint, defaultHoldDID, testMode)
backfillWorker, err := jetstream.NewBackfillWorker(database, relayEndpoint, defaultHoldDID, testMode, refresher)
if err != nil {
slog.Warn("Failed to create backfill worker", "component", "jetstream/backfill", "error", err)
} else {

View File

@@ -0,0 +1,84 @@
# Hold Service XRPC Endpoints
This document lists all XRPC endpoints implemented in the Hold service (`pkg/hold/`).
## PDS Endpoints (`pkg/hold/pds/xrpc.go`)
### Public (No Auth Required)
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/xrpc/_health` | GET | Health check |
| `/xrpc/com.atproto.server.describeServer` | GET | Server metadata |
| `/xrpc/com.atproto.repo.describeRepo` | GET | Repository information |
| `/xrpc/com.atproto.repo.getRecord` | GET | Retrieve a single record |
| `/xrpc/com.atproto.repo.listRecords` | GET | List records in a collection (paginated) |
| `/xrpc/com.atproto.sync.listRepos` | GET | List all repositories |
| `/xrpc/com.atproto.sync.getRecord` | GET | Get record as CAR file |
| `/xrpc/com.atproto.sync.getRepo` | GET | Full repository as CAR file |
| `/xrpc/com.atproto.sync.getRepoStatus` | GET | Repository hosting status |
| `/xrpc/com.atproto.sync.subscribeRepos` | GET | WebSocket firehose |
| `/xrpc/com.atproto.identity.resolveHandle` | GET | Resolve handle to DID |
| `/xrpc/app.bsky.actor.getProfile` | GET | Get actor profile |
| `/xrpc/app.bsky.actor.getProfiles` | GET | Get multiple profiles |
| `/.well-known/did.json` | GET | DID document |
| `/.well-known/atproto-did` | GET | DID for handle resolution |
### Conditional Auth (based on captain.public)
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/xrpc/com.atproto.sync.getBlob` | GET/HEAD | Get blob (routes OCI vs ATProto) |
### Owner/Crew Admin Required
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/xrpc/com.atproto.repo.deleteRecord` | POST | Delete a record |
| `/xrpc/com.atproto.repo.uploadBlob` | POST | Upload ATProto blob |
### DPoP Auth Required
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/xrpc/io.atcr.hold.requestCrew` | POST | Request crew membership |
---
## OCI Multipart Upload Endpoints (`pkg/hold/oci/xrpc.go`)
All require `blob:write` permission via service token:
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/xrpc/io.atcr.hold.initiateUpload` | POST | Start multipart upload |
| `/xrpc/io.atcr.hold.getPartUploadUrl` | POST | Get presigned URL for part |
| `/xrpc/io.atcr.hold.uploadPart` | PUT | Direct buffered part upload |
| `/xrpc/io.atcr.hold.completeUpload` | POST | Finalize multipart upload |
| `/xrpc/io.atcr.hold.abortUpload` | POST | Cancel multipart upload |
| `/xrpc/io.atcr.hold.notifyManifest` | POST | Notify manifest push (creates layer records + optional Bluesky post) |
---
## Standard ATProto Endpoints (excluding io.atcr.hold.*)
| Endpoint |
|----------|
| /xrpc/_health |
| /xrpc/com.atproto.server.describeServer |
| /xrpc/com.atproto.repo.describeRepo |
| /xrpc/com.atproto.repo.getRecord |
| /xrpc/com.atproto.repo.listRecords |
| /xrpc/com.atproto.repo.deleteRecord |
| /xrpc/com.atproto.repo.uploadBlob |
| /xrpc/com.atproto.sync.listRepos |
| /xrpc/com.atproto.sync.getRecord |
| /xrpc/com.atproto.sync.getRepo |
| /xrpc/com.atproto.sync.getRepoStatus |
| /xrpc/com.atproto.sync.getBlob |
| /xrpc/com.atproto.sync.subscribeRepos |
| /xrpc/com.atproto.identity.resolveHandle |
| /xrpc/app.bsky.actor.getProfile |
| /xrpc/app.bsky.actor.getProfiles |
| /.well-known/did.json |
| /.well-known/atproto-did |

View File

@@ -112,7 +112,6 @@ Several packages show decreased percentages despite improvements. This is due to
**Remaining gaps:**
- `notifyHoldAboutManifest()` - 0% (background notification, less critical)
- `refreshReadmeCache()` - 11.8% (UI feature, lower priority)
## Critical Priority: Core Registry Functionality
@@ -423,12 +422,12 @@ Embedded PDS implementation. Has good test coverage for critical parts, but supp
---
### 🟡 pkg/appview/readme (16.7% coverage)
### 🟡 pkg/appview/readme (Partial coverage)
README fetching and caching. Less critical but still needs work.
README rendering for repo page descriptions. The cache.go was removed as README content is now stored in `io.atcr.repo.page` records and synced via Jetstream.
#### cache.go (0% coverage)
#### fetcher.go (📊 Partial coverage)
- `RenderMarkdown()` - renders repo page description markdown
---

399
docs/VALKEY_MIGRATION.md Normal file
View File

@@ -0,0 +1,399 @@
# Analysis: AppView SQL Database Usage
## Overview
The AppView uses SQLite with 19 tables. The key finding: **most data is a cache of ATProto records** that could theoretically be rebuilt from users' PDS instances.
## Data Categories
### 1. MUST PERSIST (Local State Only)
These tables contain data that **cannot be reconstructed** from external sources:
| Table | Purpose | Why It Must Persist |
|-------|---------|---------------------|
| `oauth_sessions` | OAuth tokens | Refresh tokens are stateful; losing them = users must re-auth |
| `ui_sessions` | Web browser sessions | Session continuity for logged-in users |
| `devices` | Approved devices + bcrypt secrets | User authorization decisions; secrets are one-way hashed |
| `pending_device_auth` | In-flight auth flows | Short-lived (10min) but critical during auth |
| `oauth_auth_requests` | OAuth flow state | Short-lived but required for auth completion |
| `repository_stats` | Pull/push counts | **Locally tracked metrics** - not stored in ATProto |
### 2. CACHED FROM PDS (Rebuildable)
These tables are essentially a **read-through cache** of ATProto data:
| Table | Source | ATProto Collection |
|-------|--------|-------------------|
| `users` | User's PDS profile | `app.bsky.actor.profile` + DID document |
| `manifests` | User's PDS | `io.atcr.manifest` records |
| `tags` | User's PDS | `io.atcr.tag` records |
| `layers` | Derived from manifests | Parsed from manifest content |
| `manifest_references` | Derived from manifest lists | Parsed from multi-arch manifests |
| `repository_annotations` | Manifest config blob | OCI annotations from config |
| `repo_pages` | User's PDS | `io.atcr.repo.page` records |
| `stars` | User's PDS | `io.atcr.sailor.star` records (synced via Jetstream) |
| `hold_captain_records` | Hold's embedded PDS | `io.atcr.hold.captain` records |
| `hold_crew_approvals` | Hold's embedded PDS | `io.atcr.hold.crew` records |
| `hold_crew_denials` | Local authorization cache | Could re-check on demand |
### 3. OPERATIONAL
| Table | Purpose |
|-------|---------|
| `schema_migrations` | Migration tracking |
| `firehose_cursor` | Jetstream position (can restart from 0) |
## Key Insights
### What's Actually Unique to AppView?
1. **Authentication state** - OAuth sessions, devices, UI sessions
2. **Engagement metrics** - Pull/push counts (locally tracked, not in ATProto)
### What Could Be Eliminated?
If ATCR fully embraced the ATProto model:
1. **`users`** - Query PDS on demand (with caching)
2. **`manifests`, `tags`, `layers`** - Query PDS on demand (with caching)
3. **`repository_annotations`** - Fetch manifest config on demand
4. **`repo_pages`** - Query PDS on demand
5. **`hold_*` tables** - Query hold's PDS on demand
### Trade-offs
**Current approach (heavy caching):**
- Fast queries for UI (search, browse, stats)
- Offline resilience (PDS down doesn't break UI)
- Complex sync logic (Jetstream consumer, backfill)
- State can diverge from source of truth
**Lighter approach (query on demand):**
- Always fresh data
- Simpler codebase (no sync)
- Slower queries (network round-trips)
- Depends on PDS availability
## Current Limitation: No Cache-Miss Queries
**Finding:** There's no "query PDS on cache miss" logic. Users/manifests only enter the DB via:
1. OAuth login (user authenticates)
2. Jetstream events (firehose activity)
**Problem:** If someone visits `atcr.io/alice/myapp` before alice is indexed → 404
**Where this happens:**
- `pkg/appview/handlers/repository.go:50-53`: If `db.GetUserByDID()` returns nil → 404
- No fallback to `atproto.Client.ListRecords()` or similar
**This matters for Valkey migration:** If cache is ephemeral and restarts clear it, you need cache-miss logic to repopulate on demand. Otherwise:
- Restart Valkey → all users/manifests gone
- Wait for Jetstream to re-index OR implement cache-miss queries
**Cache-miss implementation design:**
Existing code to reuse: `pkg/appview/jetstream/processor.go:43-97` (`EnsureUser`)
```go
// New: pkg/appview/cache/loader.go
type Loader struct {
cache Cache // Valkey interface
client *atproto.Client
}
// GetUser with cache-miss fallback
func (l *Loader) GetUser(ctx context.Context, did string) (*User, error) {
// 1. Try cache
if user := l.cache.GetUser(did); user != nil {
return user, nil
}
// 2. Cache miss - resolve identity (already queries network)
_, handle, pdsEndpoint, err := atproto.ResolveIdentity(ctx, did)
if err != nil {
return nil, err // User doesn't exist in network
}
// 3. Fetch profile for avatar
client := atproto.NewClient(pdsEndpoint, "", "")
profile, _ := client.GetProfileRecord(ctx, did)
avatarURL := ""
if profile != nil && profile.Avatar != nil {
avatarURL = atproto.BlobCDNURL(did, profile.Avatar.Ref.Link)
}
// 4. Cache and return
user := &User{DID: did, Handle: handle, PDSEndpoint: pdsEndpoint, Avatar: avatarURL}
l.cache.SetUser(user, 1*time.Hour)
return user, nil
}
// GetManifestsForRepo with cache-miss fallback
func (l *Loader) GetManifestsForRepo(ctx context.Context, did, repo string) ([]Manifest, error) {
cacheKey := fmt.Sprintf("manifests:%s:%s", did, repo)
// 1. Try cache
if cached := l.cache.Get(cacheKey); cached != nil {
return cached.([]Manifest), nil
}
// 2. Cache miss - get user's PDS endpoint
user, err := l.GetUser(ctx, did)
if err != nil {
return nil, err
}
// 3. Query PDS for manifests
client := atproto.NewClient(user.PDSEndpoint, "", "")
records, _, err := client.ListRecordsForRepo(ctx, did, atproto.ManifestCollection, 100, "")
if err != nil {
return nil, err
}
// 4. Filter by repository and parse
var manifests []Manifest
for _, rec := range records {
var m atproto.ManifestRecord
if err := json.Unmarshal(rec.Value, &m); err != nil {
continue
}
if m.Repository == repo {
manifests = append(manifests, convertManifest(m))
}
}
// 5. Cache and return
l.cache.Set(cacheKey, manifests, 10*time.Minute)
return manifests, nil
}
```
**Handler changes:**
```go
// Before (repository.go:45-53):
owner, err := db.GetUserByDID(h.DB, did)
if owner == nil {
RenderNotFound(w, r, h.Templates, h.RegistryURL)
return
}
// After:
owner, err := h.Loader.GetUser(r.Context(), did)
if err != nil {
RenderNotFound(w, r, h.Templates, h.RegistryURL)
return
}
```
**Performance considerations:**
- Cache hit: ~1ms (Valkey lookup)
- Cache miss: ~200-500ms (PDS round-trip)
- First request after restart: slower but correct
- Jetstream still useful for proactive warming
---
## Proposed Architecture: Valkey + ATProto
### Goal
Replace SQLite with Valkey (Redis-compatible) for ephemeral state, push remaining persistent data to ATProto.
### What goes to Valkey (ephemeral, TTL-based)
| Current Table | Valkey Key Pattern | TTL | Notes |
|---------------|-------------------|-----|-------|
| `oauth_sessions` | `oauth:{did}:{session_id}` | 90 days | Lost on restart = re-auth |
| `ui_sessions` | `ui:{session_id}` | Session duration | Lost on restart = re-login |
| `oauth_auth_requests` | `authreq:{state}` | 10 min | In-flight flows |
| `pending_device_auth` | `pending:{device_code}` | 10 min | In-flight flows |
| `firehose_cursor` | `cursor:jetstream` | None | Can restart from 0 |
| All PDS cache tables | `cache:{collection}:{did}:{rkey}` | 10-60 min | Query PDS on miss |
**Benefits:**
- Multi-instance ready (shared Valkey)
- No schema migrations
- Natural TTL expiry
- Simpler code (no SQL)
### What could become ATProto records
| Current Table | Proposed Collection | Where Stored | Open Questions |
|---------------|---------------------|--------------|----------------|
| `devices` | `io.atcr.sailor.device` | User's PDS | Privacy: IP, user-agent sensitive? |
| `repository_stats` | `io.atcr.repo.stats` | Hold's PDS or User's PDS | Who owns the stats? |
**Devices → Valkey:**
- Move current device table to Valkey
- Key: `device:{did}:{device_id}` → `{name, secret_hash, ip, user_agent, created_at, last_used}`
- TTL: Long (1 year?) or no expiry
- Device list: `devices:{did}` → Set of device IDs
- Secret validation works the same, just different backend
**Service auth exploration (future):**
The challenge with pure ATProto service auth is the AppView still needs the user's OAuth session to write manifests to their PDS. The current flow:
1. User authenticates via OAuth → AppView gets OAuth tokens
2. AppView issues registry JWT to credential helper
3. Credential helper presents JWT on each push/pull
4. AppView uses OAuth session to write to user's PDS
Service auth could work for the hold side (AppView → Hold), but not for the user's OAuth session.
**Repository stats → Hold's PDS:**
**Challenge discovered:** The hold's `getBlob` endpoint only receives `did` + `cid`, not the repository name.
Current flow (`proxy_blob_store.go:358-362`):
```go
xrpcURL := fmt.Sprintf("%s%s?did=%s&cid=%s&method=%s",
p.holdURL, atproto.SyncGetBlob, p.ctx.DID, dgst.String(), operation)
```
**Implementation options:**
**Option A: Add repository parameter to getBlob (recommended)**
```go
// Modified AppView call:
xrpcURL := fmt.Sprintf("%s%s?did=%s&cid=%s&method=%s&repo=%s",
p.holdURL, atproto.SyncGetBlob, p.ctx.DID, dgst.String(), operation, p.ctx.Repository)
```
```go
// Modified hold handler (xrpc.go:969):
func (h *XRPCHandler) HandleGetBlob(w http.ResponseWriter, r *http.Request) {
did := r.URL.Query().Get("did")
cidOrDigest := r.URL.Query().Get("cid")
repo := r.URL.Query().Get("repo") // NEW
// ... existing blob handling ...
// Increment stats if repo provided
if repo != "" {
go h.pds.IncrementPullCount(did, repo) // Async, non-blocking
}
}
```
**Stats record structure:**
```
Collection: io.atcr.hold.stats
Rkey: base64(did:repository) // Deterministic, unique
{
"$type": "io.atcr.hold.stats",
"did": "did:plc:alice123",
"repository": "myapp",
"pullCount": 1542,
"pushCount": 47,
"lastPull": "2025-01-15T...",
"lastPush": "2025-01-10T...",
"createdAt": "2025-01-01T..."
}
```
**Hold-side implementation:**
```go
// New file: pkg/hold/pds/stats.go
func (p *HoldPDS) IncrementPullCount(ctx context.Context, did, repo string) error {
rkey := statsRecordKey(did, repo)
// Get or create stats record
stats, err := p.GetStatsRecord(ctx, rkey)
if err != nil || stats == nil {
stats = &atproto.StatsRecord{
Type: atproto.StatsCollection,
DID: did,
Repository: repo,
PullCount: 0,
PushCount: 0,
CreatedAt: time.Now(),
}
}
// Increment and update
stats.PullCount++
stats.LastPull = time.Now()
_, err = p.repomgr.UpdateRecord(ctx, p.uid, atproto.StatsCollection, rkey, stats)
return err
}
```
**Query endpoint (new XRPC):**
```
GET /xrpc/io.atcr.hold.getStats?did={userDID}&repo={repository}
→ Returns JSON: { pullCount, pushCount, lastPull, lastPush }
GET /xrpc/io.atcr.hold.listStats?did={userDID}
→ Returns all stats for a user across all repos on this hold
```
**AppView aggregation:**
```go
// GetAggregatedStats sums pull/push counters for (did, repo) across every hold
// known to have served it. Aggregation is best-effort: holds that cannot be
// queried contribute nothing rather than failing the whole request.
func (l *Loader) GetAggregatedStats(ctx context.Context, did, repo string) (*Stats, error) {
	// 1. Get all holds that have served this repo
	holdDIDs, _ := l.cache.GetHoldDIDsForRepo(did, repo)
	// 2. Query each hold for stats and accumulate.
	var total Stats
	for _, holdDID := range holdDIDs {
		holdURL := resolveHoldDID(holdDID)
		stats, err := queryHoldStats(ctx, holdURL, did, repo)
		if err != nil {
			// The original discarded this error and then read stats.PullCount,
			// which would panic on an unreachable hold if stats is a pointer.
			continue
		}
		total.PullCount += stats.PullCount
		total.PushCount += stats.PushCount
	}
	return &total, nil
}
```
**Files to modify:**
- `pkg/atproto/lexicon.go` - Add `StatsCollection` + `StatsRecord`
- `pkg/hold/pds/stats.go` - New file for stats operations
- `pkg/hold/pds/xrpc.go` - Add `repo` param to getBlob, add stats endpoints
- `pkg/appview/storage/proxy_blob_store.go` - Pass repository to getBlob
- `pkg/appview/cache/loader.go` - Aggregation logic
### Migration Path
**Phase 1: Add Valkey infrastructure**
- Add Valkey client to AppView
- Create store interfaces that abstract SQLite vs Valkey
- Dual-write OAuth sessions to both
**Phase 2: Migrate sessions to Valkey**
- OAuth sessions, UI sessions, auth requests, pending device auth
- Remove SQLite session tables
- Test: restart AppView, users get logged out (acceptable)
**Phase 3: Migrate devices to Valkey**
- Move device store to Valkey
- Same data structure, different backend
- Consider device expiry policy
**Phase 4: Implement hold-side stats**
- Add `io.atcr.hold.stats` collection to hold's embedded PDS
- Hold increments stats on blob access
- Add XRPC endpoint: `io.atcr.hold.getStats`
**Phase 5: AppView stats aggregation**
- Track holdDids per repo in Valkey cache
- Query holds for stats, aggregate
- Cache aggregated stats with TTL
**Phase 6: Remove SQLite (optional)**
- Keep SQLite as optional cache layer for UI queries
- Or: Query PDS on demand with Valkey caching
- Jetstream still useful for real-time updates
## Summary Table
| Category | Tables | % of Schema | Truly Persistent? |
|----------|--------|-------------|-------------------|
| Auth & Sessions + Metrics | 6 | 32% | Yes |
| PDS Cache | 11 | 58% | No (rebuildable) |
| Operational | 2 | 10% | No |
**~58% of the database is cached ATProto data that could be rebuilt from PDSes.**

2
go.mod
View File

@@ -1,6 +1,6 @@
module atcr.io
go 1.25.5
go 1.25.4
require (
github.com/aws/aws-sdk-go v1.55.5

View File

@@ -0,0 +1,21 @@
{
"lexicon": 1,
"id": "io.atcr.authFullApp",
"defs": {
"main": {
"type": "permission-set",
"title": "AT Container Registry",
"title:langs": {},
"detail": "Push and pull container images to the ATProto Container Registry. Includes creating and managing image manifests, tags, and repository settings.",
"detail:langs": {},
"permissions": [
{
"type": "permission",
"resource": "repo",
"action": ["create", "update", "delete"],
"collection": ["io.atcr.manifest", "io.atcr.tag", "io.atcr.sailor.star", "io.atcr.sailor.profile", "io.atcr.repo.page"]
}
]
}
}
}

View File

@@ -34,11 +34,13 @@
},
"region": {
"type": "string",
"description": "S3 region where blobs are stored"
"description": "S3 region where blobs are stored",
"maxLength": 64
},
"provider": {
"type": "string",
"description": "Deployment provider (e.g., fly.io, aws, etc.)"
"description": "Deployment provider (e.g., fly.io, aws, etc.)",
"maxLength": 64
}
}
}

View File

@@ -18,13 +18,15 @@
"role": {
"type": "string",
"description": "Member's role in the hold",
"knownValues": ["owner", "admin", "write", "read"]
"knownValues": ["owner", "admin", "write", "read"],
"maxLength": 32
},
"permissions": {
"type": "array",
"description": "Specific permissions granted to this member",
"items": {
"type": "string"
"type": "string",
"maxLength": 64
}
},
"addedAt": {

View File

@@ -12,7 +12,8 @@
"properties": {
"digest": {
"type": "string",
"description": "Layer digest (e.g., sha256:abc123...)"
"description": "Layer digest (e.g., sha256:abc123...)",
"maxLength": 128
},
"size": {
"type": "integer",
@@ -20,11 +21,13 @@
},
"mediaType": {
"type": "string",
"description": "Media type (e.g., application/vnd.oci.image.layer.v1.tar+gzip)"
"description": "Media type (e.g., application/vnd.oci.image.layer.v1.tar+gzip)",
"maxLength": 128
},
"repository": {
"type": "string",
"description": "Repository this layer belongs to"
"description": "Repository this layer belongs to",
"maxLength": 255
},
"userDid": {
"type": "string",

View File

@@ -17,7 +17,8 @@
},
"digest": {
"type": "string",
"description": "Content digest (e.g., 'sha256:abc123...')"
"description": "Content digest (e.g., 'sha256:abc123...')",
"maxLength": 128
},
"holdDid": {
"type": "string",
@@ -37,7 +38,8 @@
"application/vnd.docker.distribution.manifest.v2+json",
"application/vnd.oci.image.index.v1+json",
"application/vnd.docker.distribution.manifest.list.v2+json"
]
],
"maxLength": 128
},
"schemaVersion": {
"type": "integer",
@@ -65,8 +67,8 @@
"description": "Referenced manifests (for manifest lists/indexes)"
},
"annotations": {
"type": "object",
"description": "Optional metadata annotations"
"type": "unknown",
"description": "Optional OCI annotation metadata. Map of string keys to string values (e.g., org.opencontainers.image.title → 'My App')."
},
"subject": {
"type": "ref",
@@ -92,7 +94,8 @@
"properties": {
"mediaType": {
"type": "string",
"description": "MIME type of the blob"
"description": "MIME type of the blob",
"maxLength": 128
},
"size": {
"type": "integer",
@@ -100,7 +103,8 @@
},
"digest": {
"type": "string",
"description": "Content digest (e.g., 'sha256:...')"
"description": "Content digest (e.g., 'sha256:...')",
"maxLength": 128
},
"urls": {
"type": "array",
@@ -111,8 +115,8 @@
"description": "Optional direct URLs to blob (for BYOS)"
},
"annotations": {
"type": "object",
"description": "Optional metadata"
"type": "unknown",
"description": "Optional OCI annotation metadata. Map of string keys to string values."
}
}
},
@@ -123,7 +127,8 @@
"properties": {
"mediaType": {
"type": "string",
"description": "Media type of the referenced manifest"
"description": "Media type of the referenced manifest",
"maxLength": 128
},
"size": {
"type": "integer",
@@ -131,7 +136,8 @@
},
"digest": {
"type": "string",
"description": "Content digest (e.g., 'sha256:...')"
"description": "Content digest (e.g., 'sha256:...')",
"maxLength": 128
},
"platform": {
"type": "ref",
@@ -139,8 +145,8 @@
"description": "Platform information for this manifest"
},
"annotations": {
"type": "object",
"description": "Optional metadata"
"type": "unknown",
"description": "Optional OCI annotation metadata. Map of string keys to string values."
}
}
},
@@ -151,26 +157,31 @@
"properties": {
"architecture": {
"type": "string",
"description": "CPU architecture (e.g., 'amd64', 'arm64', 'arm')"
"description": "CPU architecture (e.g., 'amd64', 'arm64', 'arm')",
"maxLength": 32
},
"os": {
"type": "string",
"description": "Operating system (e.g., 'linux', 'windows', 'darwin')"
"description": "Operating system (e.g., 'linux', 'windows', 'darwin')",
"maxLength": 32
},
"osVersion": {
"type": "string",
"description": "Optional OS version"
"description": "Optional OS version",
"maxLength": 64
},
"osFeatures": {
"type": "array",
"items": {
"type": "string"
"type": "string",
"maxLength": 64
},
"description": "Optional OS features"
},
"variant": {
"type": "string",
"description": "Optional CPU variant (e.g., 'v7' for ARM)"
"description": "Optional CPU variant (e.g., 'v7' for ARM)",
"maxLength": 32
}
}
}

View File

@@ -0,0 +1,43 @@
{
"lexicon": 1,
"id": "io.atcr.repo.page",
"defs": {
"main": {
"type": "record",
"description": "Repository page metadata including description and avatar. Users can edit this directly in their PDS to customize their repository page.",
"key": "any",
"record": {
"type": "object",
"required": ["repository", "createdAt", "updatedAt"],
"properties": {
"repository": {
"type": "string",
"description": "The name of the repository (e.g., 'myapp'). Must match the rkey.",
"maxLength": 256
},
"description": {
"type": "string",
"description": "Markdown README/description content for the repository page.",
"maxLength": 100000
},
"avatar": {
"type": "blob",
"description": "Repository avatar/icon image.",
"accept": ["image/png", "image/jpeg", "image/webp"],
"maxSize": 3000000
},
"createdAt": {
"type": "string",
"format": "datetime",
"description": "Record creation timestamp"
},
"updatedAt": {
"type": "string",
"format": "datetime",
"description": "Record last updated timestamp"
}
}
}
}
}
}

View File

@@ -27,7 +27,8 @@
},
"manifestDigest": {
"type": "string",
"description": "DEPRECATED: Digest of the manifest (e.g., 'sha256:...'). Kept for backward compatibility with old records. New records should use 'manifest' field instead."
"description": "DEPRECATED: Digest of the manifest (e.g., 'sha256:...'). Kept for backward compatibility with old records. New records should use 'manifest' field instead.",
"maxLength": 128
},
"createdAt": {
"type": "string",

View File

@@ -79,9 +79,6 @@ type HealthConfig struct {
// CheckInterval is the hold health check refresh interval (from env: ATCR_HEALTH_CHECK_INTERVAL, default: 15m)
CheckInterval time.Duration `yaml:"check_interval"`
// ReadmeCacheTTL is the README cache TTL (from env: ATCR_README_CACHE_TTL, default: 1h)
ReadmeCacheTTL time.Duration `yaml:"readme_cache_ttl"`
}
// JetstreamConfig defines ATProto Jetstream settings
@@ -165,7 +162,6 @@ func LoadConfigFromEnv() (*Config, error) {
// Health and cache configuration
cfg.Health.CacheTTL = getDurationOrDefault("ATCR_HEALTH_CACHE_TTL", 15*time.Minute)
cfg.Health.CheckInterval = getDurationOrDefault("ATCR_HEALTH_CHECK_INTERVAL", 15*time.Minute)
cfg.Health.ReadmeCacheTTL = getDurationOrDefault("ATCR_README_CACHE_TTL", 1*time.Hour)
// Jetstream configuration
cfg.Jetstream.URL = getEnvOrDefault("JETSTREAM_URL", "wss://jetstream2.us-west.bsky.network/subscribe")

View File

@@ -0,0 +1,18 @@
description: Add repo_pages table and remove readme_cache
query: |
-- Create repo_pages table for storing repository page metadata
-- This replaces readme_cache with PDS-synced data
CREATE TABLE IF NOT EXISTS repo_pages (
did TEXT NOT NULL,
repository TEXT NOT NULL,
description TEXT,
avatar_cid TEXT,
created_at TIMESTAMP NOT NULL,
updated_at TIMESTAMP NOT NULL,
PRIMARY KEY(did, repository),
FOREIGN KEY(did) REFERENCES users(did) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_repo_pages_did ON repo_pages(did);
-- Drop readme_cache table (no longer needed)
DROP TABLE IF EXISTS readme_cache;

View File

@@ -148,8 +148,9 @@ type PlatformInfo struct {
// TagWithPlatforms extends Tag with platform information
type TagWithPlatforms struct {
Tag
Platforms []PlatformInfo
IsMultiArch bool
Platforms []PlatformInfo
IsMultiArch bool
HasAttestations bool // true if manifest list contains attestation references
}
// ManifestWithMetadata extends Manifest with tags and platform information

View File

@@ -7,6 +7,12 @@ import (
"time"
)
// BlobCDNURL returns the CDN URL for an ATProto blob
// This is a local copy to avoid importing atproto (prevents circular dependencies)
func BlobCDNURL(did, cid string) string {
return fmt.Sprintf("https://imgs.blue/%s/%s", did, cid)
}
// escapeLikePattern escapes SQL LIKE wildcards (%, _) and backslash for safe searching.
// It also sanitizes the input to prevent injection attacks via special characters.
func escapeLikePattern(s string) string {
@@ -46,11 +52,13 @@ func GetRecentPushes(db *sql.DB, limit, offset int, userFilter string, currentUs
COALESCE((SELECT COUNT(*) FROM stars WHERE owner_did = u.did AND repository = t.repository), 0),
COALESCE((SELECT COUNT(*) FROM stars WHERE starrer_did = ? AND owner_did = u.did AND repository = t.repository), 0),
t.created_at,
m.hold_endpoint
m.hold_endpoint,
COALESCE(rp.avatar_cid, '')
FROM tags t
JOIN users u ON t.did = u.did
JOIN manifests m ON t.did = m.did AND t.repository = m.repository AND t.digest = m.digest
LEFT JOIN repository_stats rs ON t.did = rs.did AND t.repository = rs.repository
LEFT JOIN repo_pages rp ON t.did = rp.did AND t.repository = rp.repository
`
args := []any{currentUserDID}
@@ -73,10 +81,15 @@ func GetRecentPushes(db *sql.DB, limit, offset int, userFilter string, currentUs
for rows.Next() {
var p Push
var isStarredInt int
if err := rows.Scan(&p.DID, &p.Handle, &p.Repository, &p.Tag, &p.Digest, &p.Title, &p.Description, &p.IconURL, &p.PullCount, &p.StarCount, &isStarredInt, &p.CreatedAt, &p.HoldEndpoint); err != nil {
var avatarCID string
if err := rows.Scan(&p.DID, &p.Handle, &p.Repository, &p.Tag, &p.Digest, &p.Title, &p.Description, &p.IconURL, &p.PullCount, &p.StarCount, &isStarredInt, &p.CreatedAt, &p.HoldEndpoint, &avatarCID); err != nil {
return nil, 0, err
}
p.IsStarred = isStarredInt > 0
// Prefer repo page avatar over annotation icon
if avatarCID != "" {
p.IconURL = BlobCDNURL(p.DID, avatarCID)
}
pushes = append(pushes, p)
}
@@ -119,11 +132,13 @@ func SearchPushes(db *sql.DB, query string, limit, offset int, currentUserDID st
COALESCE((SELECT COUNT(*) FROM stars WHERE owner_did = u.did AND repository = t.repository), 0),
COALESCE((SELECT COUNT(*) FROM stars WHERE starrer_did = ? AND owner_did = u.did AND repository = t.repository), 0),
t.created_at,
m.hold_endpoint
m.hold_endpoint,
COALESCE(rp.avatar_cid, '')
FROM tags t
JOIN users u ON t.did = u.did
JOIN manifests m ON t.did = m.did AND t.repository = m.repository AND t.digest = m.digest
LEFT JOIN repository_stats rs ON t.did = rs.did AND t.repository = rs.repository
LEFT JOIN repo_pages rp ON t.did = rp.did AND t.repository = rp.repository
WHERE u.handle LIKE ? ESCAPE '\'
OR u.did = ?
OR t.repository LIKE ? ESCAPE '\'
@@ -146,10 +161,15 @@ func SearchPushes(db *sql.DB, query string, limit, offset int, currentUserDID st
for rows.Next() {
var p Push
var isStarredInt int
if err := rows.Scan(&p.DID, &p.Handle, &p.Repository, &p.Tag, &p.Digest, &p.Title, &p.Description, &p.IconURL, &p.PullCount, &p.StarCount, &isStarredInt, &p.CreatedAt, &p.HoldEndpoint); err != nil {
var avatarCID string
if err := rows.Scan(&p.DID, &p.Handle, &p.Repository, &p.Tag, &p.Digest, &p.Title, &p.Description, &p.IconURL, &p.PullCount, &p.StarCount, &isStarredInt, &p.CreatedAt, &p.HoldEndpoint, &avatarCID); err != nil {
return nil, 0, err
}
p.IsStarred = isStarredInt > 0
// Prefer repo page avatar over annotation icon
if avatarCID != "" {
p.IconURL = BlobCDNURL(p.DID, avatarCID)
}
pushes = append(pushes, p)
}
@@ -293,6 +313,12 @@ func GetUserRepositories(db *sql.DB, did string) ([]Repository, error) {
r.IconURL = annotations["io.atcr.icon"]
r.ReadmeURL = annotations["io.atcr.readme"]
// Check for repo page avatar (overrides annotation icon)
repoPage, err := GetRepoPage(db, did, r.Name)
if err == nil && repoPage != nil && repoPage.AvatarCID != "" {
r.IconURL = BlobCDNURL(did, repoPage.AvatarCID)
}
repos = append(repos, r)
}
@@ -596,6 +622,7 @@ func DeleteTag(db *sql.DB, did, repository, tag string) error {
// GetTagsWithPlatforms returns all tags for a repository with platform information
// Only multi-arch tags (manifest lists) have platform info in manifest_references
// Single-arch tags will have empty Platforms slice (platform is obvious for single-arch)
// Attestation references (unknown/unknown platforms) are filtered out but tracked via HasAttestations
func GetTagsWithPlatforms(db *sql.DB, did, repository string) ([]TagWithPlatforms, error) {
rows, err := db.Query(`
SELECT
@@ -609,7 +636,8 @@ func GetTagsWithPlatforms(db *sql.DB, did, repository string) ([]TagWithPlatform
COALESCE(mr.platform_os, '') as platform_os,
COALESCE(mr.platform_architecture, '') as platform_architecture,
COALESCE(mr.platform_variant, '') as platform_variant,
COALESCE(mr.platform_os_version, '') as platform_os_version
COALESCE(mr.platform_os_version, '') as platform_os_version,
COALESCE(mr.is_attestation, 0) as is_attestation
FROM tags t
JOIN manifests m ON t.digest = m.digest AND t.did = m.did AND t.repository = m.repository
LEFT JOIN manifest_references mr ON m.id = mr.manifest_id
@@ -629,9 +657,10 @@ func GetTagsWithPlatforms(db *sql.DB, did, repository string) ([]TagWithPlatform
for rows.Next() {
var t Tag
var mediaType, platformOS, platformArch, platformVariant, platformOSVersion string
var isAttestation bool
if err := rows.Scan(&t.ID, &t.DID, &t.Repository, &t.Tag, &t.Digest, &t.CreatedAt,
&mediaType, &platformOS, &platformArch, &platformVariant, &platformOSVersion); err != nil {
&mediaType, &platformOS, &platformArch, &platformVariant, &platformOSVersion, &isAttestation); err != nil {
return nil, err
}
@@ -645,6 +674,13 @@ func GetTagsWithPlatforms(db *sql.DB, did, repository string) ([]TagWithPlatform
tagOrder = append(tagOrder, tagKey)
}
// Track if manifest list has attestations
if isAttestation {
tagMap[tagKey].HasAttestations = true
// Skip attestation references in platform display
continue
}
// Add platform info if present (only for multi-arch manifest lists)
if platformOS != "" || platformArch != "" {
tagMap[tagKey].Platforms = append(tagMap[tagKey].Platforms, PlatformInfo{
@@ -1598,31 +1634,6 @@ func parseTimestamp(s string) (time.Time, error) {
return time.Time{}, fmt.Errorf("unable to parse timestamp: %s", s)
}
// MetricsDB wraps a sql.DB and implements the metrics interface for middleware
type MetricsDB struct {
db *sql.DB
}
// NewMetricsDB creates a new metrics database wrapper
func NewMetricsDB(db *sql.DB) *MetricsDB {
return &MetricsDB{db: db}
}
// IncrementPullCount increments the pull count for a repository
func (m *MetricsDB) IncrementPullCount(did, repository string) error {
return IncrementPullCount(m.db, did, repository)
}
// IncrementPushCount increments the push count for a repository
func (m *MetricsDB) IncrementPushCount(did, repository string) error {
return IncrementPushCount(m.db, did, repository)
}
// GetLatestHoldDIDForRepo returns the hold DID from the most recent manifest for a repository
func (m *MetricsDB) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
return GetLatestHoldDIDForRepo(m.db, did, repository)
}
// GetFeaturedRepositories fetches top repositories sorted by stars and pulls
func GetFeaturedRepositories(db *sql.DB, limit int, currentUserDID string) ([]FeaturedRepository, error) {
query := `
@@ -1650,11 +1661,13 @@ func GetFeaturedRepositories(db *sql.DB, limit int, currentUserDID string) ([]Fe
COALESCE((SELECT value FROM repository_annotations WHERE did = m.did AND repository = m.repository AND key = 'io.atcr.icon'), ''),
rs.pull_count,
rs.star_count,
COALESCE((SELECT COUNT(*) FROM stars WHERE starrer_did = ? AND owner_did = m.did AND repository = m.repository), 0)
COALESCE((SELECT COUNT(*) FROM stars WHERE starrer_did = ? AND owner_did = m.did AND repository = m.repository), 0),
COALESCE(rp.avatar_cid, '')
FROM latest_manifests lm
JOIN manifests m ON lm.latest_id = m.id
JOIN users u ON m.did = u.did
JOIN repo_stats rs ON m.did = rs.did AND m.repository = rs.repository
LEFT JOIN repo_pages rp ON m.did = rp.did AND m.repository = rp.repository
ORDER BY rs.score DESC, rs.star_count DESC, rs.pull_count DESC, m.created_at DESC
LIMIT ?
`
@@ -1669,15 +1682,88 @@ func GetFeaturedRepositories(db *sql.DB, limit int, currentUserDID string) ([]Fe
for rows.Next() {
var f FeaturedRepository
var isStarredInt int
var avatarCID string
if err := rows.Scan(&f.OwnerDID, &f.OwnerHandle, &f.Repository,
&f.Title, &f.Description, &f.IconURL, &f.PullCount, &f.StarCount, &isStarredInt); err != nil {
&f.Title, &f.Description, &f.IconURL, &f.PullCount, &f.StarCount, &isStarredInt, &avatarCID); err != nil {
return nil, err
}
f.IsStarred = isStarredInt > 0
// Prefer repo page avatar over annotation icon
if avatarCID != "" {
f.IconURL = BlobCDNURL(f.OwnerDID, avatarCID)
}
featured = append(featured, f)
}
return featured, nil
}
// RepoPage represents a repository page record cached from PDS
type RepoPage struct {
DID string
Repository string
Description string
AvatarCID string
CreatedAt time.Time
UpdatedAt time.Time
}
// UpsertRepoPage inserts or updates a repo page record
func UpsertRepoPage(db *sql.DB, did, repository, description, avatarCID string, createdAt, updatedAt time.Time) error {
_, err := db.Exec(`
INSERT INTO repo_pages (did, repository, description, avatar_cid, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(did, repository) DO UPDATE SET
description = excluded.description,
avatar_cid = excluded.avatar_cid,
updated_at = excluded.updated_at
`, did, repository, description, avatarCID, createdAt, updatedAt)
return err
}
// GetRepoPage retrieves a repo page record
func GetRepoPage(db *sql.DB, did, repository string) (*RepoPage, error) {
var rp RepoPage
err := db.QueryRow(`
SELECT did, repository, description, avatar_cid, created_at, updated_at
FROM repo_pages
WHERE did = ? AND repository = ?
`, did, repository).Scan(&rp.DID, &rp.Repository, &rp.Description, &rp.AvatarCID, &rp.CreatedAt, &rp.UpdatedAt)
if err != nil {
return nil, err
}
return &rp, nil
}
// DeleteRepoPage deletes a repo page record
func DeleteRepoPage(db *sql.DB, did, repository string) error {
_, err := db.Exec(`
DELETE FROM repo_pages WHERE did = ? AND repository = ?
`, did, repository)
return err
}
// GetRepoPagesByDID returns all repo pages for a DID
func GetRepoPagesByDID(db *sql.DB, did string) ([]RepoPage, error) {
rows, err := db.Query(`
SELECT did, repository, description, avatar_cid, created_at, updated_at
FROM repo_pages
WHERE did = ?
`, did)
if err != nil {
return nil, err
}
defer rows.Close()
var pages []RepoPage
for rows.Next() {
var rp RepoPage
if err := rows.Scan(&rp.DID, &rp.Repository, &rp.Description, &rp.AvatarCID, &rp.CreatedAt, &rp.UpdatedAt); err != nil {
return nil, err
}
pages = append(pages, rp)
}
return pages, rows.Err()
}

View File

@@ -205,9 +205,14 @@ CREATE TABLE IF NOT EXISTS hold_crew_denials (
);
CREATE INDEX IF NOT EXISTS idx_crew_denials_retry ON hold_crew_denials(next_retry_at);
CREATE TABLE IF NOT EXISTS readme_cache (
url TEXT PRIMARY KEY,
html TEXT NOT NULL,
fetched_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
CREATE TABLE IF NOT EXISTS repo_pages (
did TEXT NOT NULL,
repository TEXT NOT NULL,
description TEXT,
avatar_cid TEXT,
created_at TIMESTAMP NOT NULL,
updated_at TIMESTAMP NOT NULL,
PRIMARY KEY(did, repository),
FOREIGN KEY(did) REFERENCES users(did) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_readme_cache_fetched ON readme_cache(fetched_at);
CREATE INDEX IF NOT EXISTS idx_repo_pages_did ON repo_pages(did);

View File

@@ -0,0 +1,32 @@
package handlers
import (
"bytes"
"html/template"
"net/http"
)
// NotFoundHandler handles 404 errors. It implements http.Handler so it can be
// registered as a router's not-found handler.
type NotFoundHandler struct {
// Templates is the parsed template set; RenderNotFound executes its "404" template.
Templates *template.Template
// RegistryURL is passed through to NewPageData when building page data.
RegistryURL string
}
// ServeHTTP renders the shared 404 page for any unmatched request.
func (h *NotFoundHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
RenderNotFound(w, r, h.Templates, h.RegistryURL)
}
// RenderNotFound renders the 404 page template.
// Use this from other handlers when a resource is not found.
//
// The template is rendered into a buffer before any bytes or headers are
// written: the original called WriteHeader first, so a template failure
// produced a superfluous-WriteHeader call from http.Error and appended
// "Page not found" to a partially written HTML body.
func RenderNotFound(w http.ResponseWriter, r *http.Request, templates *template.Template, registryURL string) {
	data := struct {
		PageData
	}{
		PageData: NewPageData(r, registryURL),
	}
	var buf bytes.Buffer
	if err := templates.ExecuteTemplate(&buf, "404", data); err != nil {
		// Template failure: fall back to a plain-text 404.
		http.Error(w, "Page not found", http.StatusNotFound)
		return
	}
	w.WriteHeader(http.StatusNotFound)
	buf.WriteTo(w)
}

View File

@@ -3,9 +3,12 @@ package handlers
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"atcr.io/pkg/appview/db"
"atcr.io/pkg/appview/middleware"
@@ -155,3 +158,114 @@ func (h *DeleteManifestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
w.WriteHeader(http.StatusOK)
}
// UploadAvatarHandler handles uploading/updating a repository avatar
type UploadAvatarHandler struct {
// DB is the appview database handle. Not referenced in ServeHTTP below —
// presumably kept for parity with sibling handlers; TODO confirm it is needed.
DB *sql.DB
// Refresher provides OAuth session refresh for authenticated PDS writes.
Refresher *oauth.Refresher
}
// validImageTypes are the allowed MIME types for avatars (matches lexicon)
// Keep in sync with the io.atcr.repo.page lexicon's avatar "accept" list
// (image/png, image/jpeg, image/webp).
var validImageTypes = map[string]bool{
"image/png": true,
"image/jpeg": true,
"image/webp": true,
}
// ServeHTTP uploads a repository avatar for the authenticated user: it
// validates the multipart upload, stores the image as a blob on the user's
// PDS, then writes/updates the io.atcr.repo.page record to reference the new
// blob — preserving any existing description and createdAt. On success it
// responds with JSON: {"avatarURL": "<CDN URL>"}.
func (h *UploadAvatarHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Require an authenticated user from the request context.
user := middleware.GetUser(r)
if user == nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
repo := chi.URLParam(r, "repository")
// Parse multipart form (max 3MB to match lexicon maxSize)
// NOTE(review): ParseMultipartForm's argument is a max-*memory* threshold,
// not a request-size cap; the hard 3MB limit is actually enforced by the
// LimitReader check below, so this error message can be misleading.
if err := r.ParseMultipartForm(3 << 20); err != nil {
http.Error(w, "File too large (max 3MB)", http.StatusBadRequest)
return
}
file, header, err := r.FormFile("avatar")
if err != nil {
http.Error(w, "No file provided", http.StatusBadRequest)
return
}
defer file.Close()
// Validate MIME type
// NOTE(review): this Content-Type is client-supplied part metadata, not
// sniffed from the bytes — confirm whether server-side content sniffing is
// required before trusting it.
contentType := header.Header.Get("Content-Type")
if !validImageTypes[contentType] {
http.Error(w, "Invalid file type. Must be PNG, JPEG, or WebP", http.StatusBadRequest)
return
}
// Read file data
// Read one byte past the 3MB cap so an oversized file is detectable below.
data, err := io.ReadAll(io.LimitReader(file, 3<<20+1)) // Read up to 3MB + 1 byte
if err != nil {
http.Error(w, "Failed to read file", http.StatusInternalServerError)
return
}
if len(data) > 3<<20 {
http.Error(w, "File too large (max 3MB)", http.StatusBadRequest)
return
}
// Create ATProto client with session provider (uses DoWithSession for DPoP nonce safety)
pdsClient := atproto.NewClientWithSessionProvider(user.PDSEndpoint, user.DID, h.Refresher)
// Upload blob to PDS
blobRef, err := pdsClient.UploadBlob(r.Context(), data, contentType)
if err != nil {
// handleOAuthError reports whether this failure was auth-related; if so,
// the user must re-authenticate rather than retry.
if handleOAuthError(r.Context(), h.Refresher, user.DID, err) {
http.Error(w, "Authentication failed, please log in again", http.StatusUnauthorized)
return
}
http.Error(w, fmt.Sprintf("Failed to upload image: %v", err), http.StatusInternalServerError)
return
}
// Fetch existing repo page record to preserve description
var existingDescription string
var existingCreatedAt time.Time
record, err := pdsClient.GetRecord(r.Context(), atproto.RepoPageCollection, repo)
if err == nil {
// Parse existing record to preserve description
var existingRecord atproto.RepoPageRecord
if jsonErr := json.Unmarshal(record.Value, &existingRecord); jsonErr == nil {
existingDescription = existingRecord.Description
existingCreatedAt = existingRecord.CreatedAt
}
} else if !errors.Is(err, atproto.ErrRecordNotFound) {
// Some other error - check if OAuth error
if handleOAuthError(r.Context(), h.Refresher, user.DID, err) {
http.Error(w, "Authentication failed, please log in again", http.StatusUnauthorized)
return
}
// Log but continue - we'll create a new record
// NOTE(review): the comment above says "log" but nothing is logged here;
// add a slog call or drop the comment.
}
// Create updated repo page record
repoPage := atproto.NewRepoPageRecord(repo, existingDescription, blobRef)
// Preserve original createdAt if record existed
if !existingCreatedAt.IsZero() {
repoPage.CreatedAt = existingCreatedAt
}
// Save record to PDS
_, err = pdsClient.PutRecord(r.Context(), atproto.RepoPageCollection, repo, repoPage)
if err != nil {
if handleOAuthError(r.Context(), h.Refresher, user.DID, err) {
http.Error(w, "Authentication failed, please log in again", http.StatusUnauthorized)
return
}
http.Error(w, fmt.Sprintf("Failed to update repository page: %v", err), http.StatusInternalServerError)
return
}
// Return new avatar URL
avatarURL := atproto.BlobCDNURL(user.DID, blobRef.Ref.Link)
w.Header().Set("Content-Type", "application/json")
// NOTE(review): Encode error is ignored; the response is already committed
// by this point, so there is little to do beyond logging it.
json.NewEncoder(w).Encode(map[string]string{"avatarURL": avatarURL})
}

View File

@@ -27,7 +27,7 @@ type RepositoryPageHandler struct {
Directory identity.Directory
Refresher *oauth.Refresher
HealthChecker *holdhealth.Checker
ReadmeCache *readme.Cache
ReadmeFetcher *readme.Fetcher // For rendering repo page descriptions
}
func (h *RepositoryPageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -37,7 +37,7 @@ func (h *RepositoryPageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
// Resolve identifier (handle or DID) to canonical DID and current handle
did, resolvedHandle, _, err := atproto.ResolveIdentity(r.Context(), identifier)
if err != nil {
http.Error(w, "User not found", http.StatusNotFound)
RenderNotFound(w, r, h.Templates, h.RegistryURL)
return
}
@@ -48,7 +48,7 @@ func (h *RepositoryPageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
return
}
if owner == nil {
http.Error(w, "User not found", http.StatusNotFound)
RenderNotFound(w, r, h.Templates, h.RegistryURL)
return
}
@@ -136,7 +136,7 @@ func (h *RepositoryPageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
}
if len(tagsWithPlatforms) == 0 && len(manifests) == 0 {
http.Error(w, "Repository not found", http.StatusNotFound)
RenderNotFound(w, r, h.Templates, h.RegistryURL)
return
}
@@ -190,19 +190,44 @@ func (h *RepositoryPageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
isOwner = (user.DID == owner.DID)
}
// Fetch README content if available
// Fetch README content from repo page record or annotations
var readmeHTML template.HTML
if repo.ReadmeURL != "" && h.ReadmeCache != nil {
// Fetch with timeout
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
defer cancel()
html, err := h.ReadmeCache.Get(ctx, repo.ReadmeURL)
if err != nil {
slog.Warn("Failed to fetch README", "url", repo.ReadmeURL, "error", err)
// Continue without README on error
} else {
readmeHTML = template.HTML(html)
// Try repo page record from database (synced from PDS via Jetstream)
repoPage, err := db.GetRepoPage(h.DB, owner.DID, repository)
if err == nil && repoPage != nil {
// Use repo page avatar if present
if repoPage.AvatarCID != "" {
repo.IconURL = atproto.BlobCDNURL(owner.DID, repoPage.AvatarCID)
}
// Render description as markdown if present
if repoPage.Description != "" && h.ReadmeFetcher != nil {
html, err := h.ReadmeFetcher.RenderMarkdown([]byte(repoPage.Description))
if err != nil {
slog.Warn("Failed to render repo page description", "error", err)
} else {
readmeHTML = template.HTML(html)
}
}
}
// Fall back to fetching README from URL annotations if no description in repo page
if readmeHTML == "" && h.ReadmeFetcher != nil {
// Fall back to fetching from URL annotations
readmeURL := repo.ReadmeURL
if readmeURL == "" && repo.SourceURL != "" {
// Try to derive README URL from source URL
readmeURL = readme.DeriveReadmeURL(repo.SourceURL, "main")
if readmeURL == "" {
readmeURL = readme.DeriveReadmeURL(repo.SourceURL, "master")
}
}
if readmeURL != "" {
html, err := h.ReadmeFetcher.FetchAndRender(r.Context(), readmeURL)
if err != nil {
slog.Debug("Failed to fetch README from URL", "url", readmeURL, "error", err)
} else {
readmeHTML = template.HTML(html)
}
}
}

View File

@@ -62,9 +62,7 @@ func (h *SettingsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
data.Profile.Handle = user.Handle
data.Profile.DID = user.DID
data.Profile.PDSEndpoint = user.PDSEndpoint
if profile.DefaultHold != nil {
data.Profile.DefaultHold = *profile.DefaultHold
}
data.Profile.DefaultHold = profile.DefaultHold
if err := h.Templates.ExecuteTemplate(w, "settings", data); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -96,9 +94,8 @@ func (h *UpdateDefaultHoldHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ
profile = atproto.NewSailorProfileRecord(holdEndpoint)
} else {
// Update existing profile
profile.DefaultHold = &holdEndpoint
now := time.Now().Format(time.RFC3339)
profile.UpdatedAt = &now
profile.DefaultHold = holdEndpoint
profile.UpdatedAt = time.Now()
}
// Save profile

View File

@@ -23,7 +23,7 @@ func (h *UserPageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Resolve identifier (handle or DID) to canonical DID and current handle
did, resolvedHandle, pdsEndpoint, err := atproto.ResolveIdentity(r.Context(), identifier)
if err != nil {
http.Error(w, "User not found", http.StatusNotFound)
RenderNotFound(w, r, h.Templates, h.RegistryURL)
return
}

View File

@@ -5,21 +5,26 @@ import (
"database/sql"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"
"atcr.io/pkg/appview/db"
"atcr.io/pkg/appview/readme"
"atcr.io/pkg/atproto"
"atcr.io/pkg/auth/oauth"
)
// BackfillWorker uses com.atproto.sync.listReposByCollection to backfill historical data
type BackfillWorker struct {
db *sql.DB
client *atproto.Client
processor *Processor // Shared processor for DB operations
defaultHoldDID string // Default hold DID from AppView config (e.g., "did:web:hold01.atcr.io")
testMode bool // If true, suppress warnings for external holds
processor *Processor // Shared processor for DB operations
defaultHoldDID string // Default hold DID from AppView config (e.g., "did:web:hold01.atcr.io")
testMode bool // If true, suppress warnings for external holds
refresher *oauth.Refresher // OAuth refresher for PDS writes (optional, can be nil)
}
// BackfillState tracks backfill progress
@@ -36,7 +41,8 @@ type BackfillState struct {
// NewBackfillWorker creates a backfill worker using sync API
// defaultHoldDID should be in format "did:web:hold01.atcr.io"
// To find a hold's DID, visit: https://hold-url/.well-known/did.json
func NewBackfillWorker(database *sql.DB, relayEndpoint, defaultHoldDID string, testMode bool) (*BackfillWorker, error) {
// refresher is optional - if provided, backfill will try to update PDS records when fetching README content
func NewBackfillWorker(database *sql.DB, relayEndpoint, defaultHoldDID string, testMode bool, refresher *oauth.Refresher) (*BackfillWorker, error) {
// Create client for relay - used only for listReposByCollection
client := atproto.NewClient(relayEndpoint, "", "")
@@ -46,6 +52,7 @@ func NewBackfillWorker(database *sql.DB, relayEndpoint, defaultHoldDID string, t
processor: NewProcessor(database, false), // No cache for batch processing
defaultHoldDID: defaultHoldDID,
testMode: testMode,
refresher: refresher,
}, nil
}
@@ -67,6 +74,7 @@ func (b *BackfillWorker) Start(ctx context.Context) error {
atproto.TagCollection, // io.atcr.tag
atproto.StarCollection, // io.atcr.sailor.star
atproto.SailorProfileCollection, // io.atcr.sailor.profile
atproto.RepoPageCollection, // io.atcr.repo.page
}
for _, collection := range collections {
@@ -164,12 +172,12 @@ func (b *BackfillWorker) backfillRepo(ctx context.Context, did, collection strin
// Track what we found for deletion reconciliation
switch collection {
case atproto.ManifestCollection:
var manifestRecord atproto.Manifest
var manifestRecord atproto.ManifestRecord
if err := json.Unmarshal(record.Value, &manifestRecord); err == nil {
foundManifestDigests = append(foundManifestDigests, manifestRecord.Digest)
}
case atproto.TagCollection:
var tagRecord atproto.Tag
var tagRecord atproto.TagRecord
if err := json.Unmarshal(record.Value, &tagRecord); err == nil {
foundTags = append(foundTags, struct{ Repository, Tag string }{
Repository: tagRecord.Repository,
@@ -177,15 +185,10 @@ func (b *BackfillWorker) backfillRepo(ctx context.Context, did, collection strin
})
}
case atproto.StarCollection:
var starRecord atproto.SailorStar
var starRecord atproto.StarRecord
if err := json.Unmarshal(record.Value, &starRecord); err == nil {
key := fmt.Sprintf("%s/%s", starRecord.Subject.Did, starRecord.Subject.Repository)
// Parse CreatedAt string to time.Time
createdAt, parseErr := time.Parse(time.RFC3339, starRecord.CreatedAt)
if parseErr != nil {
createdAt = time.Now()
}
foundStars[key] = createdAt
key := fmt.Sprintf("%s/%s", starRecord.Subject.DID, starRecord.Subject.Repository)
foundStars[key] = starRecord.CreatedAt
}
}
@@ -222,6 +225,13 @@ func (b *BackfillWorker) backfillRepo(ctx context.Context, did, collection strin
}
}
// After processing repo pages, fetch descriptions from external sources if empty
if collection == atproto.RepoPageCollection {
if err := b.reconcileRepoPageDescriptions(ctx, did, pdsEndpoint); err != nil {
slog.Warn("Backfill failed to reconcile repo page descriptions", "did", did, "error", err)
}
}
return recordCount, nil
}
@@ -287,6 +297,9 @@ func (b *BackfillWorker) processRecord(ctx context.Context, did, collection stri
return b.processor.ProcessStar(context.Background(), did, record.Value)
case atproto.SailorProfileCollection:
return b.processor.ProcessSailorProfile(ctx, did, record.Value, b.queryCaptainRecordWrapper)
case atproto.RepoPageCollection:
// rkey is extracted from the record URI, but for repo pages we use Repository field
return b.processor.ProcessRepoPage(ctx, did, record.URI, record.Value, false)
default:
return fmt.Errorf("unsupported collection: %s", collection)
}
@@ -364,12 +377,240 @@ func (b *BackfillWorker) queryCaptainRecord(ctx context.Context, holdDID string)
// reconcileAnnotations ensures annotations come from the newest manifest in each repository.
// This fixes the out-of-order backfill issue where older manifests can overwrite newer
// annotations during a backfill pass. Per-repository failures are logged and skipped so
// one bad repo cannot abort reconciliation for the rest of the user's repositories.
func (b *BackfillWorker) reconcileAnnotations(ctx context.Context, did string, pdsClient *atproto.Client) error {
	// Get all repositories for this DID
	repositories, err := db.GetRepositoriesForDID(b.db, did)
	if err != nil {
		return fmt.Errorf("failed to get repositories: %w", err)
	}
	for _, repo := range repositories {
		// Find newest manifest for this repository
		newestManifest, err := db.GetNewestManifestForRepo(b.db, did, repo)
		if err != nil {
			slog.Warn("Backfill failed to get newest manifest for repo", "did", did, "repository", repo, "error", err)
			continue // Skip on error
		}
		// Fetch the full manifest record from PDS using the digest as rkey
		// (rkey convention: the digest with its "sha256:" prefix stripped).
		rkey := strings.TrimPrefix(newestManifest.Digest, "sha256:")
		record, err := pdsClient.GetRecord(ctx, atproto.ManifestCollection, rkey)
		if err != nil {
			slog.Warn("Backfill failed to fetch manifest record for repo", "did", did, "repository", repo, "error", err)
			continue // Skip on error
		}
		// Parse manifest record
		var manifestRecord atproto.ManifestRecord
		if err := json.Unmarshal(record.Value, &manifestRecord); err != nil {
			slog.Warn("Backfill failed to parse manifest record for repo", "did", did, "repository", repo, "error", err)
			continue
		}
		// Update annotations from newest manifest only; skip when every value
		// is empty so we don't clobber existing annotations with blanks.
		if len(manifestRecord.Annotations) > 0 {
			hasData := false
			for _, value := range manifestRecord.Annotations {
				if value != "" {
					hasData = true
					break
				}
			}
			if hasData {
				err = db.UpsertRepositoryAnnotations(b.db, did, repo, manifestRecord.Annotations)
				if err != nil {
					slog.Warn("Backfill failed to reconcile annotations for repo", "did", did, "repository", repo, "error", err)
				} else {
					slog.Info("Backfill reconciled annotations for repo from newest manifest", "did", did, "repository", repo, "digest", newestManifest.Digest)
				}
			}
		}
	}
	return nil
}
// reconcileRepoPageDescriptions fetches README content from external sources for repo
// pages whose description is still empty. When an OAuth refresher is available the PDS
// record (the source of truth) is updated first; the local DB cache is refreshed either way.
func (b *BackfillWorker) reconcileRepoPageDescriptions(ctx context.Context, did, pdsEndpoint string) error {
	// Load every cached repo page for this user.
	pages, err := db.GetRepoPagesByDID(b.db, did)
	if err != nil {
		return fmt.Errorf("failed to get repo pages: %w", err)
	}
	for _, rp := range pages {
		if rp.Description != "" {
			// Already has a description; nothing to reconcile.
			continue
		}
		// Annotations may carry README hints (io.atcr.readme / image source URL).
		annotations, annErr := db.GetRepositoryAnnotations(b.db, did, rp.Repository)
		if annErr != nil {
			slog.Debug("Failed to get annotations for repo page", "did", did, "repository", rp.Repository, "error", annErr)
			continue
		}
		content := b.fetchReadmeContent(ctx, annotations)
		if content == "" {
			continue // nothing fetched; leave the page untouched
		}
		slog.Info("Fetched README for repo page", "did", did, "repository", rp.Repository, "descriptionLength", len(content))
		// Best-effort PDS write; failures fall back to DB-only.
		pdsOK := false
		if b.refresher != nil {
			if updErr := b.updateRepoPageInPDS(ctx, did, pdsEndpoint, rp.Repository, content, rp.AvatarCID); updErr != nil {
				slog.Debug("Could not update repo page in PDS, falling back to DB-only", "did", did, "repository", rp.Repository, "error", updErr)
			} else {
				pdsOK = true
				slog.Info("Updated repo page in PDS with fetched description", "did", did, "repository", rp.Repository)
			}
		}
		// The local cache is refreshed regardless of the PDS outcome.
		if dbErr := db.UpsertRepoPage(b.db, did, rp.Repository, content, rp.AvatarCID, rp.CreatedAt, time.Now()); dbErr != nil {
			slog.Warn("Failed to update repo page in database", "did", did, "repository", rp.Repository, "error", dbErr)
		} else if !pdsOK {
			slog.Info("Updated repo page in database (PDS not updated)", "did", did, "repository", rp.Repository)
		}
	}
	return nil
}
// fetchReadmeContent attempts to fetch README content from external sources based on
// the repository's annotations. Priority order: the explicit io.atcr.readme annotation,
// then a README URL derived from org.opencontainers.image.source (main, then master).
// Returns "" when nothing usable could be fetched.
func (b *BackfillWorker) fetchReadmeContent(ctx context.Context, annotations map[string]string) string {
	// Bound total fetch time so backfill cannot stall on a slow host.
	fetchCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	// Priority 1: an explicit README URL supplied via the io.atcr.readme annotation.
	if direct := annotations["io.atcr.readme"]; direct != "" {
		content, err := b.fetchRawReadme(fetchCtx, direct)
		switch {
		case err != nil:
			slog.Debug("Failed to fetch README from io.atcr.readme annotation", "url", direct, "error", err)
		case content != "":
			return content
		}
	}
	// Priority 2: derive a raw README URL from the image's source repository.
	src := annotations["org.opencontainers.image.source"]
	if src == "" {
		return ""
	}
	for _, branch := range []string{"main", "master"} {
		candidate := readme.DeriveReadmeURL(src, branch)
		if candidate == "" {
			continue
		}
		content, err := b.fetchRawReadme(fetchCtx, candidate)
		if err != nil {
			// A 404 is expected while probing main vs master; only log other failures.
			if !readme.Is404(err) {
				slog.Debug("Failed to fetch README from source URL", "url", candidate, "branch", branch, "error", err)
			}
			continue
		}
		if content != "" {
			return content
		}
	}
	return ""
}
// fetchRawReadme downloads raw markdown from readmeURL. The request is bounded by
// both ctx and a 10-second client timeout, follows at most five redirects, and the
// response body is capped at 100KB so a huge README cannot balloon memory.
func (b *BackfillWorker) fetchRawReadme(ctx context.Context, readmeURL string) (string, error) {
	httpClient := &http.Client{
		Timeout: 10 * time.Second,
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if len(via) >= 5 {
				return fmt.Errorf("too many redirects")
			}
			return nil
		},
	}
	req, err := http.NewRequestWithContext(ctx, "GET", readmeURL, nil)
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("User-Agent", "ATCR-Backfill-README-Fetcher/1.0")
	resp, err := httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("failed to fetch URL: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("status %d", resp.StatusCode)
	}
	// Read at most 100KB of the body; anything beyond is silently dropped.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 100*1024))
	if err != nil {
		return "", fmt.Errorf("failed to read response body: %w", err)
	}
	return string(body), nil
}
// updateRepoPageInPDS updates the repo page record in the user's PDS using OAuth.
// It preserves createdAt and the avatar blob ref from any existing record before
// writing the new description. Fails fast when no OAuth refresher is configured.
//
// NOTE(review): the avatarCID parameter is not used here — the avatar blob ref is
// re-read from the existing PDS record instead; confirm callers still need to pass it.
func (b *BackfillWorker) updateRepoPageInPDS(ctx context.Context, did, pdsEndpoint, repository, description, avatarCID string) error {
	if b.refresher == nil {
		return fmt.Errorf("no OAuth refresher available")
	}
	// OAuth-backed client bound to the user's PDS.
	pdsClient := atproto.NewClientWithSessionProvider(pdsEndpoint, did, b.refresher)
	// Carry over fields from the existing record, when one exists and parses.
	var (
		createdAt time.Time
		avatarRef *atproto.ATProtoBlobRef
	)
	if existing, getErr := pdsClient.GetRecord(ctx, atproto.RepoPageCollection, repository); getErr == nil && existing != nil {
		var prior atproto.RepoPageRecord
		if unmarshalErr := json.Unmarshal(existing.Value, &prior); unmarshalErr == nil {
			createdAt = prior.CreatedAt
			avatarRef = prior.Avatar
		}
	}
	if createdAt.IsZero() {
		createdAt = time.Now()
	}
	// Build the replacement record with the freshly fetched description.
	updated := &atproto.RepoPageRecord{
		Type:        atproto.RepoPageCollection,
		Repository:  repository,
		Description: description,
		Avatar:      avatarRef,
		CreatedAt:   createdAt,
		UpdatedAt:   time.Now(),
	}
	// PutRecord performs the authenticated write (uses DoWithSession internally).
	if _, putErr := pdsClient.PutRecord(ctx, atproto.RepoPageCollection, repository, updated); putErr != nil {
		return fmt.Errorf("failed to write to PDS: %w", putErr)
	}
	return nil
}

View File

@@ -100,7 +100,7 @@ func (p *Processor) EnsureUser(ctx context.Context, did string) error {
// Returns the manifest ID for further processing (layers/references)
func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData []byte) (int64, error) {
// Unmarshal manifest record
var manifestRecord atproto.Manifest
var manifestRecord atproto.ManifestRecord
if err := json.Unmarshal(recordData, &manifestRecord); err != nil {
return 0, fmt.Errorf("failed to unmarshal manifest: %w", err)
}
@@ -110,19 +110,10 @@ func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData
// Extract hold DID from manifest (with fallback for legacy manifests)
// New manifests use holdDid field (DID format)
// Old manifests use holdEndpoint field (URL format) - convert to DID
var holdDID string
if manifestRecord.HoldDid != nil && *manifestRecord.HoldDid != "" {
holdDID = *manifestRecord.HoldDid
} else if manifestRecord.HoldEndpoint != nil && *manifestRecord.HoldEndpoint != "" {
holdDID := manifestRecord.HoldDID
if holdDID == "" && manifestRecord.HoldEndpoint != "" {
// Legacy manifest - convert URL to DID
holdDID = atproto.ResolveHoldDIDFromURL(*manifestRecord.HoldEndpoint)
}
// Parse CreatedAt string to time.Time
createdAt, err := time.Parse(time.RFC3339, manifestRecord.CreatedAt)
if err != nil {
// Fall back to current time if parsing fails
createdAt = time.Now()
holdDID = atproto.ResolveHoldDIDFromURL(manifestRecord.HoldEndpoint)
}
// Prepare manifest for insertion (WITHOUT annotation fields)
@@ -131,9 +122,9 @@ func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData
Repository: manifestRecord.Repository,
Digest: manifestRecord.Digest,
MediaType: manifestRecord.MediaType,
SchemaVersion: int(manifestRecord.SchemaVersion),
SchemaVersion: manifestRecord.SchemaVersion,
HoldEndpoint: holdDID,
CreatedAt: createdAt,
CreatedAt: manifestRecord.CreatedAt,
// Annotations removed - stored separately in repository_annotations table
}
@@ -163,11 +154,24 @@ func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData
}
}
// Note: Repository annotations are currently disabled because the generated
// Manifest_Annotations type doesn't support arbitrary key-value pairs.
// The lexicon would need to use "unknown" type for annotations to support this.
// TODO: Re-enable once lexicon supports annotations as map[string]string
_ = manifestRecord.Annotations
// Update repository annotations ONLY if manifest has at least one non-empty annotation
if manifestRecord.Annotations != nil {
hasData := false
for _, value := range manifestRecord.Annotations {
if value != "" {
hasData = true
break
}
}
if hasData {
// Replace all annotations for this repository
err = db.UpsertRepositoryAnnotations(p.db, did, manifestRecord.Repository, manifestRecord.Annotations)
if err != nil {
return 0, fmt.Errorf("failed to upsert annotations: %w", err)
}
}
}
// Insert manifest references or layers
if isManifestList {
@@ -180,19 +184,18 @@ func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData
if ref.Platform != nil {
platformArch = ref.Platform.Architecture
platformOS = ref.Platform.Os
if ref.Platform.Variant != nil {
platformVariant = *ref.Platform.Variant
}
if ref.Platform.OsVersion != nil {
platformOSVersion = *ref.Platform.OsVersion
}
platformOS = ref.Platform.OS
platformVariant = ref.Platform.Variant
platformOSVersion = ref.Platform.OSVersion
}
// Note: Attestation detection via annotations is currently disabled
// because the generated Manifest_ManifestReference_Annotations type
// doesn't support arbitrary key-value pairs.
// Detect attestation manifests from annotations
isAttestation := false
if ref.Annotations != nil {
if refType, ok := ref.Annotations["vnd.docker.reference.type"]; ok {
isAttestation = refType == "attestation-manifest"
}
}
if err := db.InsertManifestReference(p.db, &db.ManifestReference{
ManifestID: manifestID,
@@ -232,7 +235,7 @@ func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData
// ProcessTag processes a tag record and stores it in the database
func (p *Processor) ProcessTag(ctx context.Context, did string, recordData []byte) error {
// Unmarshal tag record
var tagRecord atproto.Tag
var tagRecord atproto.TagRecord
if err := json.Unmarshal(recordData, &tagRecord); err != nil {
return fmt.Errorf("failed to unmarshal tag: %w", err)
}
@@ -242,27 +245,20 @@ func (p *Processor) ProcessTag(ctx context.Context, did string, recordData []byt
return fmt.Errorf("failed to get manifest digest from tag record: %w", err)
}
// Parse CreatedAt string to time.Time
tagCreatedAt, err := time.Parse(time.RFC3339, tagRecord.CreatedAt)
if err != nil {
// Fall back to current time if parsing fails
tagCreatedAt = time.Now()
}
// Insert or update tag
return db.UpsertTag(p.db, &db.Tag{
DID: did,
Repository: tagRecord.Repository,
Tag: tagRecord.Tag,
Digest: manifestDigest,
CreatedAt: tagCreatedAt,
CreatedAt: tagRecord.UpdatedAt,
})
}
// ProcessStar processes a star record and stores it in the database
func (p *Processor) ProcessStar(ctx context.Context, did string, recordData []byte) error {
// Unmarshal star record
var starRecord atproto.SailorStar
var starRecord atproto.StarRecord
if err := json.Unmarshal(recordData, &starRecord); err != nil {
return fmt.Errorf("failed to unmarshal star: %w", err)
}
@@ -270,33 +266,27 @@ func (p *Processor) ProcessStar(ctx context.Context, did string, recordData []by
// The DID here is the starrer (user who starred)
// The subject contains the owner DID and repository
// Star count will be calculated on demand from the stars table
// Parse the CreatedAt string to time.Time
createdAt, err := time.Parse(time.RFC3339, starRecord.CreatedAt)
if err != nil {
// Fall back to current time if parsing fails
createdAt = time.Now()
}
return db.UpsertStar(p.db, did, starRecord.Subject.Did, starRecord.Subject.Repository, createdAt)
return db.UpsertStar(p.db, did, starRecord.Subject.DID, starRecord.Subject.Repository, starRecord.CreatedAt)
}
// ProcessSailorProfile processes a sailor profile record
// This is primarily used by backfill to cache captain records for holds
func (p *Processor) ProcessSailorProfile(ctx context.Context, did string, recordData []byte, queryCaptainFn func(context.Context, string) error) error {
// Unmarshal sailor profile record
var profileRecord atproto.SailorProfile
var profileRecord atproto.SailorProfileRecord
if err := json.Unmarshal(recordData, &profileRecord); err != nil {
return fmt.Errorf("failed to unmarshal sailor profile: %w", err)
}
// Skip if no default hold set
if profileRecord.DefaultHold == nil || *profileRecord.DefaultHold == "" {
if profileRecord.DefaultHold == "" {
return nil
}
// Convert hold URL/DID to canonical DID
holdDID := atproto.ResolveHoldDIDFromURL(*profileRecord.DefaultHold)
holdDID := atproto.ResolveHoldDIDFromURL(profileRecord.DefaultHold)
if holdDID == "" {
slog.Warn("Invalid hold reference in profile", "component", "processor", "did", did, "default_hold", *profileRecord.DefaultHold)
slog.Warn("Invalid hold reference in profile", "component", "processor", "did", did, "default_hold", profileRecord.DefaultHold)
return nil
}
@@ -309,6 +299,30 @@ func (p *Processor) ProcessSailorProfile(ctx context.Context, did string, record
return nil
}
// ProcessRepoPage processes a repository page record from a Jetstream create,
// update, or delete event and mirrors it into the local database cache.
// For deletes, rkey is the repository name and recordData is ignored.
func (p *Processor) ProcessRepoPage(ctx context.Context, did string, rkey string, recordData []byte, isDelete bool) error {
	// Deletions simply drop the cached page.
	if isDelete {
		return db.DeleteRepoPage(p.db, did, rkey)
	}
	var page atproto.RepoPageRecord
	if err := json.Unmarshal(recordData, &page); err != nil {
		return fmt.Errorf("failed to unmarshal repo page: %w", err)
	}
	// Pull the avatar CID out of the blob ref when one is attached.
	var avatarCID string
	if page.Avatar != nil {
		avatarCID = page.Avatar.Ref.Link
	}
	return db.UpsertRepoPage(p.db, did, page.Repository, page.Description, avatarCID, page.CreatedAt, page.UpdatedAt)
}
// ProcessIdentity handles identity change events (handle updates)
// This is called when Jetstream receives an identity event indicating a handle change.
// The identity cache is invalidated to ensure the next lookup uses the new handle,

View File

@@ -11,11 +11,6 @@ import (
_ "github.com/mattn/go-sqlite3"
)
// ptrString returns a pointer to the given string
func ptrString(s string) *string {
return &s
}
// setupTestDB creates an in-memory SQLite database for testing
func setupTestDB(t *testing.T) *sql.DB {
database, err := sql.Open("sqlite3", ":memory:")
@@ -148,22 +143,28 @@ func TestProcessManifest_ImageManifest(t *testing.T) {
ctx := context.Background()
// Create test manifest record
manifestRecord := &atproto.Manifest{
manifestRecord := &atproto.ManifestRecord{
Repository: "test-app",
Digest: "sha256:abc123",
MediaType: "application/vnd.oci.image.manifest.v1+json",
SchemaVersion: 2,
HoldEndpoint: ptrString("did:web:hold01.atcr.io"),
CreatedAt: time.Now().Format(time.RFC3339),
Config: &atproto.Manifest_BlobReference{
HoldEndpoint: "did:web:hold01.atcr.io",
CreatedAt: time.Now(),
Config: &atproto.BlobReference{
Digest: "sha256:config123",
Size: 1234,
},
Layers: []atproto.Manifest_BlobReference{
Layers: []atproto.BlobReference{
{Digest: "sha256:layer1", Size: 5000, MediaType: "application/vnd.oci.image.layer.v1.tar+gzip"},
{Digest: "sha256:layer2", Size: 3000, MediaType: "application/vnd.oci.image.layer.v1.tar+gzip"},
},
// Annotations disabled - generated Manifest_Annotations is empty struct
Annotations: map[string]string{
"org.opencontainers.image.title": "Test App",
"org.opencontainers.image.description": "A test application",
"org.opencontainers.image.source": "https://github.com/test/app",
"org.opencontainers.image.licenses": "MIT",
"io.atcr.icon": "https://example.com/icon.png",
},
}
// Marshal to bytes for ProcessManifest
@@ -192,8 +193,25 @@ func TestProcessManifest_ImageManifest(t *testing.T) {
t.Errorf("Expected 1 manifest, got %d", count)
}
// Note: Annotations verification disabled - generated Manifest_Annotations is empty struct
// TODO: Re-enable when lexicon uses "unknown" type for annotations
// Verify annotations were stored in repository_annotations table
var title, source string
err = database.QueryRow("SELECT value FROM repository_annotations WHERE did = ? AND repository = ? AND key = ?",
"did:plc:test123", "test-app", "org.opencontainers.image.title").Scan(&title)
if err != nil {
t.Fatalf("Failed to query title annotation: %v", err)
}
if title != "Test App" {
t.Errorf("title = %q, want %q", title, "Test App")
}
err = database.QueryRow("SELECT value FROM repository_annotations WHERE did = ? AND repository = ? AND key = ?",
"did:plc:test123", "test-app", "org.opencontainers.image.source").Scan(&source)
if err != nil {
t.Fatalf("Failed to query source annotation: %v", err)
}
if source != "https://github.com/test/app" {
t.Errorf("source = %q, want %q", source, "https://github.com/test/app")
}
// Verify layers were inserted
var layerCount int
@@ -224,31 +242,31 @@ func TestProcessManifest_ManifestList(t *testing.T) {
ctx := context.Background()
// Create test manifest list record
manifestRecord := &atproto.Manifest{
manifestRecord := &atproto.ManifestRecord{
Repository: "test-app",
Digest: "sha256:list123",
MediaType: "application/vnd.oci.image.index.v1+json",
SchemaVersion: 2,
HoldEndpoint: ptrString("did:web:hold01.atcr.io"),
CreatedAt: time.Now().Format(time.RFC3339),
Manifests: []atproto.Manifest_ManifestReference{
HoldEndpoint: "did:web:hold01.atcr.io",
CreatedAt: time.Now(),
Manifests: []atproto.ManifestReference{
{
Digest: "sha256:amd64manifest",
MediaType: "application/vnd.oci.image.manifest.v1+json",
Size: 1000,
Platform: &atproto.Manifest_Platform{
Platform: &atproto.Platform{
Architecture: "amd64",
Os: "linux",
OS: "linux",
},
},
{
Digest: "sha256:arm64manifest",
MediaType: "application/vnd.oci.image.manifest.v1+json",
Size: 1100,
Platform: &atproto.Manifest_Platform{
Platform: &atproto.Platform{
Architecture: "arm64",
Os: "linux",
Variant: ptrString("v8"),
OS: "linux",
Variant: "v8",
},
},
},
@@ -308,11 +326,11 @@ func TestProcessTag(t *testing.T) {
ctx := context.Background()
// Create test tag record (using ManifestDigest field for simplicity)
tagRecord := &atproto.Tag{
tagRecord := &atproto.TagRecord{
Repository: "test-app",
Tag: "latest",
ManifestDigest: ptrString("sha256:abc123"),
CreatedAt: time.Now().Format(time.RFC3339),
ManifestDigest: "sha256:abc123",
UpdatedAt: time.Now(),
}
// Marshal to bytes for ProcessTag
@@ -350,7 +368,7 @@ func TestProcessTag(t *testing.T) {
}
// Test upserting same tag with new digest
tagRecord.ManifestDigest = ptrString("sha256:newdigest")
tagRecord.ManifestDigest = "sha256:newdigest"
recordBytes, err = json.Marshal(tagRecord)
if err != nil {
t.Fatalf("Failed to marshal tag: %v", err)
@@ -389,12 +407,12 @@ func TestProcessStar(t *testing.T) {
ctx := context.Background()
// Create test star record
starRecord := &atproto.SailorStar{
Subject: atproto.SailorStar_Subject{
Did: "did:plc:owner123",
starRecord := &atproto.StarRecord{
Subject: atproto.StarSubject{
DID: "did:plc:owner123",
Repository: "test-app",
},
CreatedAt: time.Now().Format(time.RFC3339),
CreatedAt: time.Now(),
}
// Marshal to bytes for ProcessStar
@@ -448,13 +466,13 @@ func TestProcessManifest_Duplicate(t *testing.T) {
p := NewProcessor(database, false)
ctx := context.Background()
manifestRecord := &atproto.Manifest{
manifestRecord := &atproto.ManifestRecord{
Repository: "test-app",
Digest: "sha256:abc123",
MediaType: "application/vnd.oci.image.manifest.v1+json",
SchemaVersion: 2,
HoldEndpoint: ptrString("did:web:hold01.atcr.io"),
CreatedAt: time.Now().Format(time.RFC3339),
HoldEndpoint: "did:web:hold01.atcr.io",
CreatedAt: time.Now(),
}
// Marshal to bytes for ProcessManifest
@@ -500,13 +518,13 @@ func TestProcessManifest_EmptyAnnotations(t *testing.T) {
ctx := context.Background()
// Manifest with nil annotations
manifestRecord := &atproto.Manifest{
manifestRecord := &atproto.ManifestRecord{
Repository: "test-app",
Digest: "sha256:abc123",
MediaType: "application/vnd.oci.image.manifest.v1+json",
SchemaVersion: 2,
HoldEndpoint: ptrString("did:web:hold01.atcr.io"),
CreatedAt: time.Now().Format(time.RFC3339),
HoldEndpoint: "did:web:hold01.atcr.io",
CreatedAt: time.Now(),
Annotations: nil,
}

View File

@@ -61,9 +61,7 @@ func NewWorker(database *sql.DB, jetstreamURL string, startCursor int64) *Worker
jetstreamURL: jetstreamURL,
startCursor: startCursor,
wantedCollections: []string{
atproto.ManifestCollection, // io.atcr.manifest
atproto.TagCollection, // io.atcr.tag
atproto.StarCollection, // io.atcr.sailor.star
"io.atcr.*", // Subscribe to all ATCR collections
},
processor: NewProcessor(database, true), // Use cache for live streaming
}
@@ -312,6 +310,9 @@ func (w *Worker) processMessage(message []byte) error {
case atproto.StarCollection:
slog.Info("Jetstream processing star event", "did", commit.DID, "operation", commit.Operation, "rkey", commit.RKey)
return w.processStar(commit)
case atproto.RepoPageCollection:
slog.Info("Jetstream processing repo page event", "did", commit.DID, "operation", commit.Operation, "rkey", commit.RKey)
return w.processRepoPage(commit)
default:
// Ignore other collections
return nil
@@ -436,6 +437,41 @@ func (w *Worker) processStar(commit *CommitEvent) error {
return w.processor.ProcessStar(context.Background(), commit.DID, recordBytes)
}
// processRepoPage handles a repo page commit event from Jetstream, ensuring the
// user exists and then delegating the create/update/delete to the shared processor.
func (w *Worker) processRepoPage(commit *CommitEvent) error {
	ctx := context.Background()
	// Make sure the user row (handle / PDS endpoint) exists before touching pages.
	if err := w.processor.EnsureUser(ctx, commit.DID); err != nil {
		return fmt.Errorf("failed to ensure user: %w", err)
	}
	if commit.Operation == "delete" {
		// For deletes the rkey is the repository name.
		slog.Info("Jetstream deleting repo page", "did", commit.DID, "repository", commit.RKey)
		if err := w.processor.ProcessRepoPage(ctx, commit.DID, commit.RKey, nil, true); err != nil {
			slog.Error("Jetstream ERROR deleting repo page", "error", err)
			return err
		}
		slog.Info("Jetstream successfully deleted repo page", "did", commit.DID, "repository", commit.RKey)
		return nil
	}
	// Creates/updates need a record payload; silently ignore events without one.
	if commit.Record == nil {
		return nil
	}
	// Re-marshal the decoded record map back to bytes for the shared processor.
	recordBytes, err := json.Marshal(commit.Record)
	if err != nil {
		return fmt.Errorf("failed to marshal record: %w", err)
	}
	return w.processor.ProcessRepoPage(ctx, commit.DID, commit.RKey, recordBytes, false)
}
// processIdentity processes an identity event (handle change)
func (w *Worker) processIdentity(event *JetstreamEvent) error {
if event.Identity == nil {

View File

@@ -11,14 +11,32 @@ import (
"net/url"
"atcr.io/pkg/appview/db"
"atcr.io/pkg/auth"
"atcr.io/pkg/auth/oauth"
)
type contextKey string
const userKey contextKey = "user"
// WebAuthDeps contains dependencies for web auth middleware
type WebAuthDeps struct {
SessionStore *db.SessionStore
Database *sql.DB
Refresher *oauth.Refresher
DefaultHoldDID string
}
// RequireAuth is middleware that requires authentication
func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) http.Handler {
return RequireAuthWithDeps(WebAuthDeps{
SessionStore: store,
Database: database,
})
}
// RequireAuthWithDeps is middleware that requires authentication and creates UserContext
func RequireAuthWithDeps(deps WebAuthDeps) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sessionID, ok := getSessionID(r)
@@ -32,7 +50,7 @@ func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) ht
return
}
sess, ok := store.Get(sessionID)
sess, ok := deps.SessionStore.Get(sessionID)
if !ok {
// Build return URL with query parameters preserved
returnTo := r.URL.Path
@@ -44,7 +62,7 @@ func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) ht
}
// Look up full user from database to get avatar
user, err := db.GetUserByDID(database, sess.DID)
user, err := db.GetUserByDID(deps.Database, sess.DID)
if err != nil || user == nil {
// Fallback to session data if DB lookup fails
user = &db.User{
@@ -54,7 +72,20 @@ func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) ht
}
}
ctx := context.WithValue(r.Context(), userKey, user)
ctx := r.Context()
ctx = context.WithValue(ctx, userKey, user)
// Create UserContext for authenticated users (enables EnsureUserSetup)
if deps.Refresher != nil {
userCtx := auth.NewUserContext(sess.DID, auth.AuthMethodOAuth, r.Method, &auth.Dependencies{
Refresher: deps.Refresher,
DefaultHoldDID: deps.DefaultHoldDID,
})
userCtx.SetPDS(sess.Handle, sess.PDSEndpoint)
userCtx.EnsureUserSetup()
ctx = auth.WithUserContext(ctx, userCtx)
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
@@ -62,13 +93,21 @@ func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) ht
// OptionalAuth is middleware that optionally includes user if authenticated
func OptionalAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) http.Handler {
return OptionalAuthWithDeps(WebAuthDeps{
SessionStore: store,
Database: database,
})
}
// OptionalAuthWithDeps is middleware that optionally includes user and UserContext if authenticated
func OptionalAuthWithDeps(deps WebAuthDeps) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sessionID, ok := getSessionID(r)
if ok {
if sess, ok := store.Get(sessionID); ok {
if sess, ok := deps.SessionStore.Get(sessionID); ok {
// Look up full user from database to get avatar
user, err := db.GetUserByDID(database, sess.DID)
user, err := db.GetUserByDID(deps.Database, sess.DID)
if err != nil || user == nil {
// Fallback to session data if DB lookup fails
user = &db.User{
@@ -77,7 +116,21 @@ func OptionalAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) h
PDSEndpoint: sess.PDSEndpoint,
}
}
ctx := context.WithValue(r.Context(), userKey, user)
ctx := r.Context()
ctx = context.WithValue(ctx, userKey, user)
// Create UserContext for authenticated users (enables EnsureUserSetup)
if deps.Refresher != nil {
userCtx := auth.NewUserContext(sess.DID, auth.AuthMethodOAuth, r.Method, &auth.Dependencies{
Refresher: deps.Refresher,
DefaultHoldDID: deps.DefaultHoldDID,
})
userCtx.SetPDS(sess.Handle, sess.PDSEndpoint)
userCtx.EnsureUserSetup()
ctx = auth.WithUserContext(ctx, userCtx)
}
r = r.WithContext(ctx)
}
}

View File

@@ -2,15 +2,13 @@ package middleware
import (
"context"
"database/sql"
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/registry/api/errcode"
registrymw "github.com/distribution/distribution/v3/registry/middleware/registry"
"github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/reference"
@@ -28,151 +26,16 @@ const holdDIDKey contextKey = "hold.did"
// authMethodKey is the context key for storing auth method from JWT
const authMethodKey contextKey = "auth.method"
// validationCacheEntry stores a validated service token with expiration
type validationCacheEntry struct {
serviceToken string
validUntil time.Time
err error // Cached error for fast-fail
mu sync.Mutex // Per-entry lock to serialize cache population
inFlight bool // True if another goroutine is fetching the token
done chan struct{} // Closed when fetch completes
}
// validationCache provides request-level caching for service tokens
// This prevents concurrent layer uploads from racing on OAuth/DPoP requests
type validationCache struct {
mu sync.RWMutex
entries map[string]*validationCacheEntry // key: "did:holdDID"
}
// newValidationCache creates a new validation cache
func newValidationCache() *validationCache {
return &validationCache{
entries: make(map[string]*validationCacheEntry),
}
}
// getOrFetch retrieves a service token from cache or fetches it
// Multiple concurrent requests for the same DID:holdDID will share the fetch operation
func (vc *validationCache) getOrFetch(ctx context.Context, cacheKey string, fetchFunc func() (string, error)) (string, error) {
// Fast path: check cache with read lock
vc.mu.RLock()
entry, exists := vc.entries[cacheKey]
vc.mu.RUnlock()
if exists {
// Entry exists, check if it's still valid
entry.mu.Lock()
// If another goroutine is fetching, wait for it
if entry.inFlight {
done := entry.done
entry.mu.Unlock()
select {
case <-done:
// Fetch completed, check result
entry.mu.Lock()
defer entry.mu.Unlock()
if entry.err != nil {
return "", entry.err
}
if time.Now().Before(entry.validUntil) {
return entry.serviceToken, nil
}
// Fall through to refetch
case <-ctx.Done():
return "", ctx.Err()
}
} else {
// Check if cached token is still valid
if entry.err != nil && time.Now().Before(entry.validUntil) {
// Return cached error (fast-fail)
entry.mu.Unlock()
return "", entry.err
}
if entry.err == nil && time.Now().Before(entry.validUntil) {
// Return cached token
token := entry.serviceToken
entry.mu.Unlock()
return token, nil
}
entry.mu.Unlock()
}
}
// Slow path: need to fetch token
vc.mu.Lock()
entry, exists = vc.entries[cacheKey]
if !exists {
// Create new entry
entry = &validationCacheEntry{
inFlight: true,
done: make(chan struct{}),
}
vc.entries[cacheKey] = entry
}
vc.mu.Unlock()
// Lock the entry to perform fetch
entry.mu.Lock()
// Double-check: another goroutine may have fetched while we waited
if !entry.inFlight {
if entry.err != nil && time.Now().Before(entry.validUntil) {
err := entry.err
entry.mu.Unlock()
return "", err
}
if entry.err == nil && time.Now().Before(entry.validUntil) {
token := entry.serviceToken
entry.mu.Unlock()
return token, nil
}
}
// Mark as in-flight and create fresh done channel for this fetch
// IMPORTANT: Always create a new channel - a closed channel is not nil
entry.done = make(chan struct{})
entry.inFlight = true
done := entry.done
entry.mu.Unlock()
// Perform the fetch (outside the lock to allow other operations)
serviceToken, err := fetchFunc()
// Update the entry with result
entry.mu.Lock()
entry.inFlight = false
if err != nil {
// Cache errors for 5 seconds (fast-fail for subsequent requests)
entry.err = err
entry.validUntil = time.Now().Add(5 * time.Second)
entry.serviceToken = ""
} else {
// Cache token for 45 seconds (covers typical Docker push operation)
entry.err = nil
entry.serviceToken = serviceToken
entry.validUntil = time.Now().Add(45 * time.Second)
}
// Signal completion to waiting goroutines
close(done)
entry.mu.Unlock()
return serviceToken, err
}
// pullerDIDKey is the context key for storing the authenticated user's DID
// (the JWT subject). Set by ExtractAuthMethod and read by UserContextMiddleware.
const pullerDIDKey contextKey = "puller.did"
// Global variables for initialization only
// These are set by main.go during startup and copied into NamespaceResolver instances.
// After initialization, request handling uses the NamespaceResolver's instance fields.
var (
globalRefresher *oauth.Refresher
globalDatabase storage.DatabaseMetrics
globalAuthorizer auth.HoldAuthorizer
globalReadmeCache storage.ReadmeCache
globalRefresher *oauth.Refresher
globalDatabase *sql.DB
globalAuthorizer auth.HoldAuthorizer
)
// SetGlobalRefresher sets the OAuth refresher instance during initialization
@@ -183,7 +46,7 @@ func SetGlobalRefresher(refresher *oauth.Refresher) {
// SetGlobalDatabase sets the database instance during initialization
// Must be called before the registry starts serving requests
func SetGlobalDatabase(database storage.DatabaseMetrics) {
func SetGlobalDatabase(database *sql.DB) {
globalDatabase = database
}
@@ -193,12 +56,6 @@ func SetGlobalAuthorizer(authorizer auth.HoldAuthorizer) {
globalAuthorizer = authorizer
}
// SetGlobalReadmeCache sets the readme cache instance during initialization
// Must be called before the registry starts serving requests
func SetGlobalReadmeCache(readmeCache storage.ReadmeCache) {
globalReadmeCache = readmeCache
}
func init() {
// Register the name resolution middleware
registrymw.Register("atproto-resolver", initATProtoResolver)
@@ -207,14 +64,12 @@ func init() {
// NamespaceResolver wraps a namespace and resolves names
type NamespaceResolver struct {
distribution.Namespace
defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io")
baseURL string // Base URL for error messages (e.g., "https://atcr.io")
testMode bool // If true, fallback to default hold when user's hold is unreachable
refresher *oauth.Refresher // OAuth session manager (copied from global on init)
database storage.DatabaseMetrics // Metrics database (copied from global on init)
authorizer auth.HoldAuthorizer // Hold authorization (copied from global on init)
readmeCache storage.ReadmeCache // README cache (copied from global on init)
validationCache *validationCache // Request-level service token cache
defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io")
baseURL string // Base URL for error messages (e.g., "https://atcr.io")
testMode bool // If true, fallback to default hold when user's hold is unreachable
refresher *oauth.Refresher // OAuth session manager (copied from global on init)
sqlDB *sql.DB // Database for hold DID lookup and metrics (copied from global on init)
authorizer auth.HoldAuthorizer // Hold authorization (copied from global on init)
}
// initATProtoResolver initializes the name resolution middleware
@@ -241,25 +96,16 @@ func initATProtoResolver(ctx context.Context, ns distribution.Namespace, _ drive
// Copy shared services from globals into the instance
// This avoids accessing globals during request handling
return &NamespaceResolver{
Namespace: ns,
defaultHoldDID: defaultHoldDID,
baseURL: baseURL,
testMode: testMode,
refresher: globalRefresher,
database: globalDatabase,
authorizer: globalAuthorizer,
readmeCache: globalReadmeCache,
validationCache: newValidationCache(),
Namespace: ns,
defaultHoldDID: defaultHoldDID,
baseURL: baseURL,
testMode: testMode,
refresher: globalRefresher,
sqlDB: globalDatabase,
authorizer: globalAuthorizer,
}, nil
}
// authErrorMessage creates a user-friendly auth error that includes the
// login URL derived from the resolver's base URL.
func (nr *NamespaceResolver) authErrorMessage(message string) error {
	msg := fmt.Sprintf("%s - please re-authenticate at %s/auth/oauth/login", message, nr.baseURL)
	return errcode.ErrorCodeUnauthorized.WithMessage(msg)
}
// Repository resolves the repository name and delegates to underlying namespace
// Handles names like:
// - atcr.io/alice/myimage → resolve alice to DID
@@ -293,99 +139,8 @@ func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Name
}
ctx = context.WithValue(ctx, holdDIDKey, holdDID)
// Auto-reconcile crew membership on first push/pull
// This ensures users can push immediately after docker login without web sign-in
// EnsureCrewMembership is best-effort and logs errors without failing the request
// Run in background to avoid blocking registry operations if hold is offline
if holdDID != "" && nr.refresher != nil {
slog.Debug("Auto-reconciling crew membership", "component", "registry/middleware", "did", did, "hold_did", holdDID)
client := atproto.NewClient(pdsEndpoint, did, "")
go func(ctx context.Context, client *atproto.Client, refresher *oauth.Refresher, holdDID string) {
storage.EnsureCrewMembership(ctx, client, refresher, holdDID)
}(ctx, client, nr.refresher, holdDID)
}
// Get service token for hold authentication (only if authenticated)
// Use validation cache to prevent concurrent requests from racing on OAuth/DPoP
// Route based on auth method from JWT token
var serviceToken string
authMethod, _ := ctx.Value(authMethodKey).(string)
// Only fetch service token if user is authenticated
// Unauthenticated requests (like /v2/ ping) should not trigger token fetching
if authMethod != "" {
// Create cache key: "did:holdDID"
cacheKey := fmt.Sprintf("%s:%s", did, holdDID)
// Fetch service token through validation cache
// This ensures only ONE request per DID:holdDID pair fetches the token
// Concurrent requests will wait for the first request to complete
var fetchErr error
serviceToken, fetchErr = nr.validationCache.getOrFetch(ctx, cacheKey, func() (string, error) {
if authMethod == token.AuthMethodAppPassword {
// App-password flow: use Bearer token authentication
slog.Debug("Using app-password flow for service token",
"component", "registry/middleware",
"did", did,
"cacheKey", cacheKey)
token, err := token.GetOrFetchServiceTokenWithAppPassword(ctx, did, holdDID, pdsEndpoint)
if err != nil {
slog.Error("Failed to get service token with app-password",
"component", "registry/middleware",
"did", did,
"holdDID", holdDID,
"pdsEndpoint", pdsEndpoint,
"error", err)
return "", err
}
return token, nil
} else if nr.refresher != nil {
// OAuth flow: use DPoP authentication
slog.Debug("Using OAuth flow for service token",
"component", "registry/middleware",
"did", did,
"cacheKey", cacheKey)
token, err := token.GetOrFetchServiceToken(ctx, nr.refresher, did, holdDID, pdsEndpoint)
if err != nil {
slog.Error("Failed to get service token with OAuth",
"component", "registry/middleware",
"did", did,
"holdDID", holdDID,
"pdsEndpoint", pdsEndpoint,
"error", err)
return "", err
}
return token, nil
}
return "", fmt.Errorf("no authentication method available")
})
// Handle errors from cached fetch
if fetchErr != nil {
errMsg := fetchErr.Error()
// Check for app-password specific errors
if authMethod == token.AuthMethodAppPassword {
if strings.Contains(errMsg, "expired or invalid") || strings.Contains(errMsg, "no app-password") {
return nil, nr.authErrorMessage("App-password authentication failed. Please re-authenticate with: docker login")
}
}
// Check for OAuth specific errors
if strings.Contains(errMsg, "OAuth session") || strings.Contains(errMsg, "OAuth validation") {
return nil, nr.authErrorMessage("OAuth session expired or invalidated by PDS. Your session has been cleared")
}
// Generic service token error
return nil, nr.authErrorMessage(fmt.Sprintf("Failed to obtain storage credentials: %v", fetchErr))
}
} else {
slog.Debug("Skipping service token fetch for unauthenticated request",
"component", "registry/middleware",
"did", did)
}
// Note: Profile and crew membership are now ensured in UserContextMiddleware
// via EnsureUserSetup() - no need to call here
// Create a new reference with identity/image format
// Use the identity (or DID) as the namespace to ensure canonical format
@@ -402,74 +157,30 @@ func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Name
return nil, err
}
// Get access token for PDS operations
// Use auth method from JWT to determine client type:
// - OAuth users: use session provider (DPoP-enabled)
// - App-password users: use Basic Auth token cache
var atprotoClient *atproto.Client
if authMethod == token.AuthMethodOAuth && nr.refresher != nil {
// OAuth flow: use session provider for locked OAuth sessions
// This prevents DPoP nonce race conditions during concurrent layer uploads
slog.Debug("Creating ATProto client with OAuth session provider",
"component", "registry/middleware",
"did", did,
"authMethod", authMethod)
atprotoClient = atproto.NewClientWithSessionProvider(pdsEndpoint, did, nr.refresher)
} else {
// App-password flow (or fallback): use Basic Auth token cache
accessToken, ok := auth.GetGlobalTokenCache().Get(did)
if !ok {
slog.Debug("No cached access token found for app-password auth",
"component", "registry/middleware",
"did", did,
"authMethod", authMethod)
accessToken = "" // Will fail on manifest push, but let it try
} else {
slog.Debug("Creating ATProto client with app-password",
"component", "registry/middleware",
"did", did,
"authMethod", authMethod,
"token_length", len(accessToken))
}
atprotoClient = atproto.NewClient(pdsEndpoint, did, accessToken)
}
// IMPORTANT: Use only the image name (not identity/image) for ATProto storage
// ATProto records are scoped to the user's DID, so we don't need the identity prefix
// Example: "evan.jarrett.net/debian" -> store as "debian"
repositoryName := imageName
// Default auth method to OAuth if not already set (backward compatibility with old tokens)
if authMethod == "" {
authMethod = token.AuthMethodOAuth
// Get UserContext from request context (set by UserContextMiddleware)
userCtx := auth.FromContext(ctx)
if userCtx == nil {
return nil, fmt.Errorf("UserContext not set in request context - ensure UserContextMiddleware is configured")
}
// Set target repository info on UserContext
// ATProtoClient is cached lazily via userCtx.GetATProtoClient()
userCtx.SetTarget(did, handle, pdsEndpoint, repositoryName, holdDID)
// Create routing repository - routes manifests to ATProto, blobs to hold service
// The registry is stateless - no local storage is used
// Bundle all context into a single RegistryContext struct
//
// NOTE: We create a fresh RoutingRepository on every request (no caching) because:
// 1. Each layer upload is a separate HTTP request (possibly different process)
// 2. OAuth sessions can be refreshed/invalidated between requests
// 3. The refresher already caches sessions efficiently (in-memory + DB)
// 4. Caching the repository with a stale ATProtoClient causes refresh token errors
registryCtx := &storage.RegistryContext{
DID: did,
Handle: handle,
HoldDID: holdDID,
PDSEndpoint: pdsEndpoint,
Repository: repositoryName,
ServiceToken: serviceToken, // Cached service token from middleware validation
ATProtoClient: atprotoClient,
AuthMethod: authMethod, // Auth method from JWT token
Database: nr.database,
Authorizer: nr.authorizer,
Refresher: nr.refresher,
ReadmeCache: nr.readmeCache,
}
return storage.NewRoutingRepository(repo, registryCtx), nil
// 4. ATProtoClient is now cached in UserContext via GetATProtoClient()
return storage.NewRoutingRepository(repo, userCtx, nr.sqlDB), nil
}
// Repositories delegates to underlying namespace
@@ -490,8 +201,7 @@ func (nr *NamespaceResolver) BlobStatter() distribution.BlobStatter {
// findHoldDID determines which hold DID to use for blob storage
// Priority order:
// 1. User's sailor profile defaultHold (if set)
// 2. User's own hold record (io.atcr.hold)
// 3. AppView's default hold DID
// 2. AppView's default hold DID
// Returns a hold DID (e.g., "did:web:hold01.atcr.io"), or empty string if none configured
func (nr *NamespaceResolver) findHoldDID(ctx context.Context, did, pdsEndpoint string) string {
// Create ATProto client (without auth - reading public records)
@@ -504,22 +214,20 @@ func (nr *NamespaceResolver) findHoldDID(ctx context.Context, did, pdsEndpoint s
slog.Warn("Failed to read profile", "did", did, "error", err)
}
if profile != nil && profile.DefaultHold != nil && *profile.DefaultHold != "" {
defaultHold := *profile.DefaultHold
// Profile exists with defaultHold set
// In test mode, verify it's reachable before using it
if profile != nil && profile.DefaultHold != "" {
// In test mode, verify the hold is reachable (fall back to default if not)
// In production, trust the user's profile and return their hold
if nr.testMode {
if nr.isHoldReachable(ctx, defaultHold) {
return defaultHold
if nr.isHoldReachable(ctx, profile.DefaultHold) {
return profile.DefaultHold
}
slog.Debug("User's defaultHold unreachable, falling back to default", "component", "registry/middleware/testmode", "default_hold", defaultHold)
slog.Debug("User's defaultHold unreachable, falling back to default", "component", "registry/middleware/testmode", "default_hold", profile.DefaultHold)
return nr.defaultHoldDID
}
return defaultHold
return profile.DefaultHold
}
// Profile doesn't exist or defaultHold is null/empty
// Legacy io.atcr.hold records are no longer supported - use AppView default
// No profile defaultHold - use AppView default
return nr.defaultHoldDID
}
@@ -542,10 +250,17 @@ func (nr *NamespaceResolver) isHoldReachable(ctx context.Context, holdDID string
return false
}
// ExtractAuthMethod is an HTTP middleware that extracts the auth method from the JWT Authorization header
// and stores it in the request context for later use by the registry middleware
// ExtractAuthMethod is an HTTP middleware that extracts the auth method and puller DID from the JWT Authorization header
// and stores them in the request context for later use by the registry middleware.
// Also stores the HTTP method for routing decisions (GET/HEAD = pull, PUT/POST = push).
func ExtractAuthMethod(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Store HTTP method in context for routing decisions
// This is used by routing_repository.go to distinguish pull (GET/HEAD) from push (PUT/POST)
ctx = context.WithValue(ctx, "http.request.method", r.Method)
// Extract Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader != "" {
@@ -558,15 +273,71 @@ func ExtractAuthMethod(next http.Handler) http.Handler {
authMethod := token.ExtractAuthMethod(tokenString)
if authMethod != "" {
// Store in context for registry middleware
ctx := context.WithValue(r.Context(), authMethodKey, authMethod)
r = r.WithContext(ctx)
slog.Debug("Extracted auth method from JWT",
"component", "registry/middleware",
"authMethod", authMethod)
ctx = context.WithValue(ctx, authMethodKey, authMethod)
}
// Extract puller DID (Subject) from JWT
// This is the authenticated user's DID, used for service token requests
pullerDID := token.ExtractSubject(tokenString)
if pullerDID != "" {
ctx = context.WithValue(ctx, pullerDIDKey, pullerDID)
}
slog.Debug("Extracted auth info from JWT",
"component", "registry/middleware",
"authMethod", authMethod,
"pullerDID", pullerDID,
"httpMethod", r.Method)
}
}
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
})
}
// UserContextMiddleware creates a UserContext from the extracted JWT claims
// and stores it in the request context for use throughout request processing.
// This middleware should be chained AFTER ExtractAuthMethod, which populates
// authMethodKey and pullerDIDKey in the context.
func UserContextMiddleware(deps *auth.Dependencies) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		handler := func(w http.ResponseWriter, r *http.Request) {
			reqCtx := r.Context()

			// Values populated by ExtractAuthMethod (empty if unauthenticated).
			method, _ := reqCtx.Value(authMethodKey).(string)
			did, _ := reqCtx.Value(pullerDIDKey).(string)

			// Build UserContext with all dependencies.
			uc := auth.NewUserContext(did, method, r.Method, deps)

			if uc.IsAuthenticated {
				// Eagerly resolve the user's PDS (fast path; avoids lazy
				// loading later). Failure is non-fatal here - the service
				// token request will surface the problem.
				if err := uc.ResolvePDS(reqCtx); err != nil {
					slog.Warn("Failed to resolve puller's PDS",
						"component", "registry/middleware",
						"did", did,
						"error", err)
				}
				// Ensure user has profile and crew membership (runs in background, cached)
				uc.EnsureUserSetup()
			}

			slog.Debug("Created UserContext",
				"component", "registry/middleware",
				"isAuthenticated", uc.IsAuthenticated,
				"authMethod", uc.AuthMethod,
				"action", uc.Action.String(),
				"pullerDID", did)

			// Stash the UserContext and continue down the chain.
			next.ServeHTTP(w, r.WithContext(auth.WithUserContext(reqCtx, uc)))
		}
		return http.HandlerFunc(handler)
	}
}

View File

@@ -67,11 +67,6 @@ func TestSetGlobalAuthorizer(t *testing.T) {
// If we get here without panic, test passes
}
// TestSetGlobalReadmeCache is a smoke test: registering a nil readme cache
// must be a safe no-op (the setter must not panic on nil).
func TestSetGlobalReadmeCache(t *testing.T) {
	SetGlobalReadmeCache(nil)
	// If we get here without panic, test passes
}
// TestInitATProtoResolver tests the initialization function
func TestInitATProtoResolver(t *testing.T) {
ctx := context.Background()
@@ -134,17 +129,6 @@ func TestInitATProtoResolver(t *testing.T) {
}
}
// TestAuthErrorMessage tests the error message formatting.
// The formatted error must carry both the original message and the login URL
// derived from the resolver's baseURL.
func TestAuthErrorMessage(t *testing.T) {
	resolver := &NamespaceResolver{
		baseURL: "https://atcr.io",
	}
	err := resolver.authErrorMessage("OAuth session expired")
	assert.Contains(t, err.Error(), "OAuth session expired")
	assert.Contains(t, err.Error(), "https://atcr.io/auth/oauth/login")
}
// TestFindHoldDID_DefaultFallback tests default hold DID fallback
func TestFindHoldDID_DefaultFallback(t *testing.T) {
// Start a mock PDS server that returns 404 for profile and empty list for holds
@@ -204,34 +188,9 @@ func TestFindHoldDID_SailorProfile(t *testing.T) {
assert.Equal(t, "did:web:user.hold.io", holdDID, "should use sailor profile's defaultHold")
}
// TestFindHoldDID_NoProfile tests fallback to default hold when no profile exists.
func TestFindHoldDID_NoProfile(t *testing.T) {
	// Start a mock PDS server that returns 404 for profile lookups (and for
	// every other record path), simulating a user with no sailor profile.
	mockPDS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/xrpc/com.atproto.repo.getRecord" {
			// Profile not found
			w.WriteHeader(http.StatusNotFound)
			return
		}
		w.WriteHeader(http.StatusNotFound)
	}))
	defer mockPDS.Close()

	resolver := &NamespaceResolver{
		defaultHoldDID: "did:web:default.atcr.io",
	}

	ctx := context.Background()
	holdDID := resolver.findHoldDID(ctx, "did:plc:test123", mockPDS.URL)

	// Should fall back to default hold DID when no profile exists
	// Note: Legacy io.atcr.hold records are no longer supported
	assert.Equal(t, "did:web:default.atcr.io", holdDID, "should fall back to default hold DID")
}
// TestFindHoldDID_Priority tests that profile takes priority over default
// TestFindHoldDID_Priority tests the priority order
func TestFindHoldDID_Priority(t *testing.T) {
// Start a mock PDS server that returns profile
// Start a mock PDS server that returns both profile and hold records
mockPDS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/xrpc/com.atproto.repo.getRecord" {
// Return sailor profile with defaultHold (highest priority)

View File

@@ -1,111 +0,0 @@
// Package readme provides README fetching, rendering, and caching functionality
// for container repositories. It fetches markdown content from URLs, renders it
// to sanitized HTML using GitHub-flavored markdown, and caches the results in
// a database with configurable TTL.
package readme
import (
"context"
"database/sql"
"log/slog"
"time"
)
// Cache stores rendered README HTML in the database
type Cache struct {
	db      *sql.DB       // backing store for the readme_cache table
	fetcher *Fetcher      // fetches markdown from remote URLs and renders it
	ttl     time.Duration // how long a cached entry is considered fresh
}

// NewCache creates a new README cache.
// A non-positive ttl falls back to the 1 hour default.
func NewCache(db *sql.DB, ttl time.Duration) *Cache {
	if ttl <= 0 {
		ttl = 1 * time.Hour // Default TTL
	}
	return &Cache{
		db:      db,
		fetcher: NewFetcher(),
		ttl:     ttl,
	}
}

// Get retrieves a README from cache or fetches it.
// If the fetch fails but a stale cached copy exists, the stale copy is served
// rather than returning an error.
func (c *Cache) Get(ctx context.Context, readmeURL string) (string, error) {
	// Try to get from cache
	cached, fetchedAt, cacheErr := c.getFromDB(readmeURL)
	if cacheErr == nil && time.Since(fetchedAt) < c.ttl {
		// Cache hit and still fresh
		return cached, nil
	}

	// Cache miss or expired, fetch fresh content.
	// NOTE: the fetch result must not overwrite the cached value - previously
	// the stale copy was clobbered here, so the stale-fallback below could
	// never fire.
	html, err := c.fetcher.FetchAndRender(ctx, readmeURL)
	if err != nil {
		// If fetch fails but we have a stale cached copy, return it
		if cacheErr == nil && cached != "" {
			return cached, nil
		}
		return "", err
	}

	// Store in cache
	if storeErr := c.storeInDB(readmeURL, html); storeErr != nil {
		// Log error but don't fail - we have the content
		slog.Warn("Failed to cache README", "error", storeErr)
	}

	return html, nil
}

// getFromDB retrieves a cached README and its fetch timestamp from the database.
// Returns sql.ErrNoRows (wrapped by QueryRow) when the URL has never been cached.
func (c *Cache) getFromDB(readmeURL string) (string, time.Time, error) {
	var html string
	var fetchedAt time.Time

	err := c.db.QueryRow(`
		SELECT html, fetched_at
		FROM readme_cache
		WHERE url = ?
	`, readmeURL).Scan(&html, &fetchedAt)
	if err != nil {
		return "", time.Time{}, err
	}

	return html, fetchedAt, nil
}

// storeInDB upserts the rendered README for a URL, refreshing fetched_at.
func (c *Cache) storeInDB(readmeURL, html string) error {
	_, err := c.db.Exec(`
		INSERT INTO readme_cache (url, html, fetched_at)
		VALUES (?, ?, ?)
		ON CONFLICT(url) DO UPDATE SET
			html = excluded.html,
			fetched_at = excluded.fetched_at
	`, readmeURL, html, time.Now())
	return err
}

// Invalidate removes a README from the cache.
func (c *Cache) Invalidate(readmeURL string) error {
	_, err := c.db.Exec(`
		DELETE FROM readme_cache
		WHERE url = ?
	`, readmeURL)
	return err
}

// Cleanup removes expired entries from the cache.
// Entries are kept for 2x TTL so that Get can serve a stale copy when a
// refresh fetch fails.
func (c *Cache) Cleanup() error {
	cutoff := time.Now().Add(-c.ttl * 2) // Keep for 2x TTL
	_, err := c.db.Exec(`
		DELETE FROM readme_cache
		WHERE fetched_at < ?
	`, cutoff)
	return err
}

View File

@@ -1,13 +0,0 @@
package readme
import (
	"testing"
	"time"
)
func TestCache_Struct(t *testing.T) {
// Simple struct test
cache := &Cache{}
if cache == nil {
t.Error("Expected non-nil cache")
}
}
// TODO: Add cache operation tests

View File

@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
@@ -180,6 +181,27 @@ func getBaseURL(u *url.URL) string {
return fmt.Sprintf("%s://%s%s", u.Scheme, u.Host, path)
}
// Is404 returns true if the error indicates a 404 Not Found response
func Is404(err error) bool {
return err != nil && strings.Contains(err.Error(), "unexpected status code: 404")
}
// RenderMarkdown renders a markdown string to sanitized HTML.
// This is used for rendering repo page descriptions stored in the database.
func (f *Fetcher) RenderMarkdown(content []byte) (string, error) {
	// The base URL is deliberately empty: repo page descriptions have no
	// source location, and an empty base URL disables relative-URL rewriting.
	return f.renderMarkdown(content, "")
}
// Regex patterns for matching candidate relative URLs that may need rewriting.
// NOTE: the leading character class only rules out values whose FIRST char is
// '"', '/' or ':' (root-relative and protocol-relative URLs). Absolute URLs
// (http://, https://), anchors (#...), data: and mailto: links can still match
// these patterns; they are filtered out in the replacement callbacks that use
// them.
var (
	// Candidate src="..." attributes whose value does not start with '/' or ':'
	relativeSrcPattern = regexp.MustCompile(`src="([^"/:][^"]*)"`)
	// Candidate href="..." attributes whose value does not start with '/' or ':'
	relativeHrefPattern = regexp.MustCompile(`href="([^"/:][^"]*)"`)
)
// rewriteRelativeURLs converts relative URLs to absolute URLs
func rewriteRelativeURLs(html, baseURL string) string {
if baseURL == "" {
@@ -191,20 +213,51 @@ func rewriteRelativeURLs(html, baseURL string) string {
return html
}
// Simple string replacement for common patterns
// This is a basic implementation - for production, consider using an HTML parser
// Handle root-relative URLs (starting with /) first
// Must be done before bare relative URLs to avoid double-processing
if base.Scheme != "" && base.Host != "" {
root := fmt.Sprintf("%s://%s/", base.Scheme, base.Host)
// Replace src="/" and href="/" but not src="//" (protocol-relative URLs)
html = strings.ReplaceAll(html, `src="/`, fmt.Sprintf(`src="%s`, root))
html = strings.ReplaceAll(html, `href="/`, fmt.Sprintf(`href="%s`, root))
}
// Handle explicit relative paths (./something and ../something)
html = strings.ReplaceAll(html, `src="./`, fmt.Sprintf(`src="%s`, baseURL))
html = strings.ReplaceAll(html, `href="./`, fmt.Sprintf(`href="%s`, baseURL))
html = strings.ReplaceAll(html, `src="../`, fmt.Sprintf(`src="%s../`, baseURL))
html = strings.ReplaceAll(html, `href="../`, fmt.Sprintf(`href="%s../`, baseURL))
// Handle root-relative URLs (starting with /)
if base.Scheme != "" && base.Host != "" {
root := fmt.Sprintf("%s://%s/", base.Scheme, base.Host)
// Replace src="/" and href="/" but not src="//" (absolute URLs)
html = strings.ReplaceAll(html, `src="/`, fmt.Sprintf(`src="%s`, root))
html = strings.ReplaceAll(html, `href="/`, fmt.Sprintf(`href="%s`, root))
}
// Handle bare relative URLs (e.g., src="image.png" without ./ prefix)
// Skip URLs that are already absolute (start with http://, https://, or //)
// Skip anchors (#), data URLs (data:), and mailto links
html = relativeSrcPattern.ReplaceAllStringFunc(html, func(match string) string {
// Extract the URL from src="..."
url := match[5 : len(match)-1] // Remove 'src="' and '"'
// Skip if already processed or is a special URL type
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") ||
strings.HasPrefix(url, "//") || strings.HasPrefix(url, "#") ||
strings.HasPrefix(url, "data:") || strings.HasPrefix(url, "mailto:") {
return match
}
return fmt.Sprintf(`src="%s%s"`, baseURL, url)
})
html = relativeHrefPattern.ReplaceAllStringFunc(html, func(match string) string {
// Extract the URL from href="..."
url := match[6 : len(match)-1] // Remove 'href="' and '"'
// Skip if already processed or is a special URL type
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") ||
strings.HasPrefix(url, "//") || strings.HasPrefix(url, "#") ||
strings.HasPrefix(url, "data:") || strings.HasPrefix(url, "mailto:") {
return match
}
return fmt.Sprintf(`href="%s%s"`, baseURL, url)
})
return html
}

View File

@@ -145,6 +145,48 @@ func TestRewriteRelativeURLs(t *testing.T) {
baseURL: "https://example.com/docs/",
expected: `<img src="https://example.com//cdn.example.com/image.png">`,
},
{
name: "bare relative src (no ./ prefix)",
html: `<img src="image.png">`,
baseURL: "https://example.com/docs/",
expected: `<img src="https://example.com/docs/image.png">`,
},
{
name: "bare relative href (no ./ prefix)",
html: `<a href="page.html">link</a>`,
baseURL: "https://example.com/docs/",
expected: `<a href="https://example.com/docs/page.html">link</a>`,
},
{
name: "bare relative with path",
html: `<img src="images/logo.png">`,
baseURL: "https://example.com/docs/",
expected: `<img src="https://example.com/docs/images/logo.png">`,
},
{
name: "anchor links unchanged",
html: `<a href="#section">link</a>`,
baseURL: "https://example.com/docs/",
expected: `<a href="#section">link</a>`,
},
{
name: "data URLs unchanged",
html: `<img src="data:image/png;base64,abc123">`,
baseURL: "https://example.com/docs/",
expected: `<img src="data:image/png;base64,abc123">`,
},
{
name: "mailto links unchanged",
html: `<a href="mailto:test@example.com">email</a>`,
baseURL: "https://example.com/docs/",
expected: `<a href="mailto:test@example.com">email</a>`,
},
{
name: "mixed bare and prefixed relative URLs",
html: `<img src="slices_and_lucy.png"><a href="./other.md">link</a>`,
baseURL: "https://github.com/user/repo/blob/main/",
expected: `<img src="https://github.com/user/repo/blob/main/slices_and_lucy.png"><a href="https://github.com/user/repo/blob/main/other.md">link</a>`,
},
}
for _, tt := range tests {
@@ -157,4 +199,110 @@ func TestRewriteRelativeURLs(t *testing.T) {
}
}
// TestFetcher_RenderMarkdown exercises Fetcher.RenderMarkdown against common
// Markdown constructs and checks that each rendered result contains an
// expected HTML fragment. Substring checks keep the test independent of the
// renderer's exact markup (attributes, classes, wrapping elements).
func TestFetcher_RenderMarkdown(t *testing.T) {
	fetcher := NewFetcher()
	tests := []struct {
		name        string // subtest name
		content     string // Markdown source passed to RenderMarkdown
		wantContain string // fragment expected in the HTML output; "" skips the check
		wantErr     bool   // whether RenderMarkdown should return an error
	}{
		{
			name:        "simple paragraph",
			content:     "Hello, world!",
			wantContain: "<p>Hello, world!</p>",
			wantErr:     false,
		},
		{
			// Only the opening "<h1" is checked since renderers commonly
			// add id/anchor attributes to headings.
			name:        "heading",
			content:     "# My App",
			wantContain: "<h1",
			wantErr:     false,
		},
		{
			name:        "bold text",
			content:     "This is **bold** text.",
			wantContain: "<strong>bold</strong>",
			wantErr:     false,
		},
		{
			name:        "italic text",
			content:     "This is *italic* text.",
			wantContain: "<em>italic</em>",
			wantErr:     false,
		},
		{
			name:        "code block",
			content:     "```\ncode here\n```",
			wantContain: "<pre>",
			wantErr:     false,
		},
		{
			name:        "link",
			content:     "[Link text](https://example.com)",
			wantContain: `href="https://example.com"`,
			wantErr:     false,
		},
		{
			name:        "image",
			content:     "![Alt text](https://example.com/image.png)",
			wantContain: `src="https://example.com/image.png"`,
			wantErr:     false,
		},
		{
			name:        "unordered list",
			content:     "- Item 1\n- Item 2",
			wantContain: "<ul>",
			wantErr:     false,
		},
		{
			name:        "ordered list",
			content:     "1. Item 1\n2. Item 2",
			wantContain: "<ol>",
			wantErr:     false,
		},
		{
			// Empty input must not error; output content is not asserted.
			name:        "empty content",
			content:     "",
			wantContain: "",
			wantErr:     false,
		},
		{
			name:        "complex markdown",
			content:     "# Title\n\nA paragraph with **bold** and *italic* text.\n\n- List item 1\n- List item 2\n\n```go\nfunc main() {}\n```",
			wantContain: "<h1",
			wantErr:     false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			html, err := fetcher.RenderMarkdown([]byte(tt.content))
			if (err != nil) != tt.wantErr {
				t.Errorf("RenderMarkdown() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			// Only inspect the output when no error was expected and a
			// fragment was specified (the empty-content case asserts nothing).
			if !tt.wantErr && tt.wantContain != "" {
				if !containsSubstring(html, tt.wantContain) {
					t.Errorf("RenderMarkdown() = %q, want to contain %q", html, tt.wantContain)
				}
			}
		})
	}
}
// containsSubstring reports whether substr occurs within s.
// The helper already handles both boundary cases on its own — an empty
// substr matches trivially, and a substr longer than s never matches — so
// the previous chain of length/equality pre-checks was redundant and has
// been removed.
func containsSubstring(s, substr string) bool {
	return containsSubstringHelper(s, substr)
}

// containsSubstringHelper scans s for substr by brute force.
// When substr is empty the first iteration compares s[0:0] == "" and
// returns true; when substr is longer than s the loop bound is negative
// and the body never runs, returning false.
func containsSubstringHelper(s, substr string) bool {
	for i := 0; i <= len(s)-len(substr); i++ {
		if s[i:i+len(substr)] == substr {
			return true
		}
	}
	return false
}
// TODO: Add README fetching and caching tests

View File

@@ -0,0 +1,103 @@
package readme
import (
"fmt"
"net/url"
"strings"
)
// Platform represents a supported Git hosting platform
type Platform string

const (
	// PlatformGitHub identifies repositories hosted on github.com.
	PlatformGitHub Platform = "github"
	// PlatformGitLab identifies repositories hosted on gitlab.com.
	PlatformGitLab Platform = "gitlab"
	// PlatformTangled identifies repositories hosted on tangled.org / tangled.sh.
	PlatformTangled Platform = "tangled"
)

// ParseSourceURL extracts platform, user, and repo from a source repository URL.
// Returns ok=false if the URL is not a recognized pattern.
func ParseSourceURL(sourceURL string) (platform Platform, user, repo string, ok bool) {
	if sourceURL == "" {
		return "", "", "", false
	}
	parsed, err := url.Parse(sourceURL)
	if err != nil {
		return "", "", "", false
	}
	// Normalize: remove trailing slash and .git suffix
	path := strings.TrimSuffix(parsed.Path, "/")
	path = strings.TrimSuffix(path, ".git")
	path = strings.TrimPrefix(path, "/")
	if path == "" {
		return "", "", "", false
	}
	// Hostname() strips any explicit port (e.g. github.com:443), which the
	// previous Host comparison would have rejected.
	host := strings.ToLower(parsed.Hostname())
	switch host {
	case "github.com":
		// GitHub: github.com/{user}/{repo}; extra path segments
		// (e.g. /tree/main) are ignored by the 3-way SplitN.
		parts := strings.SplitN(path, "/", 3)
		if len(parts) < 2 || parts[0] == "" || parts[1] == "" {
			return "", "", "", false
		}
		return PlatformGitHub, parts[0], parts[1], true
	case "gitlab.com":
		// GitLab: gitlab.com/{user}/{repo} or gitlab.com/{group}/{subgroup}/{repo}
		// For nested groups, user = everything except last part, repo = last part
		lastSlash := strings.LastIndex(path, "/")
		if lastSlash == -1 || lastSlash == 0 {
			return "", "", "", false
		}
		user = path[:lastSlash]
		repo = path[lastSlash+1:]
		if user == "" || repo == "" {
			return "", "", "", false
		}
		return PlatformGitLab, user, repo, true
	case "tangled.org", "tangled.sh":
		// Tangled: tangled.org/{user}/{repo} or tangled.sh/@{user}/{repo} (legacy)
		// Strip leading @ from user if present
		path = strings.TrimPrefix(path, "@")
		parts := strings.SplitN(path, "/", 3)
		if len(parts) < 2 || parts[0] == "" || parts[1] == "" {
			return "", "", "", false
		}
		return PlatformTangled, parts[0], parts[1], true
	default:
		return "", "", "", false
	}
}
// DeriveReadmeURL converts a source repository URL to a raw README URL.
// Returns empty string if the URL is unparseable or the platform is not
// supported.
func DeriveReadmeURL(sourceURL, branch string) string {
	platform, owner, name, ok := ParseSourceURL(sourceURL)
	if !ok {
		return ""
	}
	// Each platform's raw README location; every template is filled with
	// owner, repo name, and branch in that order.
	var format string
	switch platform {
	case PlatformGitHub:
		format = "https://raw.githubusercontent.com/%s/%s/refs/heads/%s/README.md"
	case PlatformGitLab:
		format = "https://gitlab.com/%s/%s/-/raw/%s/README.md"
	case PlatformTangled:
		format = "https://tangled.org/%s/%s/raw/%s/README.md"
	default:
		return ""
	}
	return fmt.Sprintf(format, owner, name, branch)
}

View File

@@ -0,0 +1,241 @@
package readme
import (
"testing"
)
// TestParseSourceURL is a table-driven test covering the URL patterns
// ParseSourceURL recognizes (GitHub, GitLab including nested groups,
// Tangled with and without the legacy @ prefix) plus rejection cases.
// When wantOK is false, the platform/user/repo fields are not checked.
func TestParseSourceURL(t *testing.T) {
	tests := []struct {
		name         string   // subtest name
		sourceURL    string   // input URL
		wantPlatform Platform // expected platform (only checked when wantOK)
		wantUser     string   // expected user/owner (only checked when wantOK)
		wantRepo     string   // expected repo name (only checked when wantOK)
		wantOK       bool     // whether the URL should be recognized
	}{
		// GitHub
		{
			name:         "github standard",
			sourceURL:    "https://github.com/bigmoves/quickslice",
			wantPlatform: PlatformGitHub,
			wantUser:     "bigmoves",
			wantRepo:     "quickslice",
			wantOK:       true,
		},
		{
			name:         "github with .git suffix",
			sourceURL:    "https://github.com/user/repo.git",
			wantPlatform: PlatformGitHub,
			wantUser:     "user",
			wantRepo:     "repo",
			wantOK:       true,
		},
		{
			name:         "github with trailing slash",
			sourceURL:    "https://github.com/user/repo/",
			wantPlatform: PlatformGitHub,
			wantUser:     "user",
			wantRepo:     "repo",
			wantOK:       true,
		},
		{
			// Anything past /{user}/{repo} (e.g. /tree/main) is dropped.
			name:         "github with subpath (ignored)",
			sourceURL:    "https://github.com/user/repo/tree/main",
			wantPlatform: PlatformGitHub,
			wantUser:     "user",
			wantRepo:     "repo",
			wantOK:       true,
		},
		{
			name:      "github user only",
			sourceURL: "https://github.com/user",
			wantOK:    false,
		},
		// GitLab
		{
			name:         "gitlab standard",
			sourceURL:    "https://gitlab.com/user/repo",
			wantPlatform: PlatformGitLab,
			wantUser:     "user",
			wantRepo:     "repo",
			wantOK:       true,
		},
		{
			// For nested groups the full group path becomes "user".
			name:         "gitlab nested groups",
			sourceURL:    "https://gitlab.com/group/subgroup/repo",
			wantPlatform: PlatformGitLab,
			wantUser:     "group/subgroup",
			wantRepo:     "repo",
			wantOK:       true,
		},
		{
			name:         "gitlab deep nested groups",
			sourceURL:    "https://gitlab.com/a/b/c/d/repo",
			wantPlatform: PlatformGitLab,
			wantUser:     "a/b/c/d",
			wantRepo:     "repo",
			wantOK:       true,
		},
		{
			name:         "gitlab with .git suffix",
			sourceURL:    "https://gitlab.com/user/repo.git",
			wantPlatform: PlatformGitLab,
			wantUser:     "user",
			wantRepo:     "repo",
			wantOK:       true,
		},
		// Tangled
		{
			name:         "tangled standard",
			sourceURL:    "https://tangled.org/evan.jarrett.net/at-container-registry",
			wantPlatform: PlatformTangled,
			wantUser:     "evan.jarrett.net",
			wantRepo:     "at-container-registry",
			wantOK:       true,
		},
		{
			// The legacy @-prefixed handle form must parse identically.
			name:         "tangled with legacy @ prefix",
			sourceURL:    "https://tangled.org/@evan.jarrett.net/at-container-registry",
			wantPlatform: PlatformTangled,
			wantUser:     "evan.jarrett.net",
			wantRepo:     "at-container-registry",
			wantOK:       true,
		},
		{
			name:         "tangled.sh domain",
			sourceURL:    "https://tangled.sh/user/repo",
			wantPlatform: PlatformTangled,
			wantUser:     "user",
			wantRepo:     "repo",
			wantOK:       true,
		},
		{
			name:         "tangled with trailing slash",
			sourceURL:    "https://tangled.org/user/repo/",
			wantPlatform: PlatformTangled,
			wantUser:     "user",
			wantRepo:     "repo",
			wantOK:       true,
		},
		// Unsupported / Invalid
		{
			name:      "unsupported platform",
			sourceURL: "https://bitbucket.org/user/repo",
			wantOK:    false,
		},
		{
			name:      "empty url",
			sourceURL: "",
			wantOK:    false,
		},
		{
			name:      "invalid url",
			sourceURL: "not-a-url",
			wantOK:    false,
		},
		{
			name:      "just host",
			sourceURL: "https://github.com",
			wantOK:    false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			platform, user, repo, ok := ParseSourceURL(tt.sourceURL)
			if ok != tt.wantOK {
				t.Errorf("ParseSourceURL(%q) ok = %v, want %v", tt.sourceURL, ok, tt.wantOK)
				return
			}
			// Rejection cases carry no expected platform/user/repo.
			if !tt.wantOK {
				return
			}
			if platform != tt.wantPlatform {
				t.Errorf("ParseSourceURL(%q) platform = %v, want %v", tt.sourceURL, platform, tt.wantPlatform)
			}
			if user != tt.wantUser {
				t.Errorf("ParseSourceURL(%q) user = %q, want %q", tt.sourceURL, user, tt.wantUser)
			}
			if repo != tt.wantRepo {
				t.Errorf("ParseSourceURL(%q) repo = %q, want %q", tt.sourceURL, repo, tt.wantRepo)
			}
		})
	}
}
// TestDeriveReadmeURL checks the raw-README URL produced for each supported
// platform/branch combination, and that unsupported or empty source URLs
// yield an empty string.
func TestDeriveReadmeURL(t *testing.T) {
	tests := []struct {
		name      string // subtest name
		sourceURL string // repository source URL
		branch    string // branch to build the raw URL for
		want      string // expected raw README URL ("" = unsupported)
	}{
		// GitHub
		{
			name:      "github main",
			sourceURL: "https://github.com/bigmoves/quickslice",
			branch:    "main",
			want:      "https://raw.githubusercontent.com/bigmoves/quickslice/refs/heads/main/README.md",
		},
		{
			name:      "github master",
			sourceURL: "https://github.com/user/repo",
			branch:    "master",
			want:      "https://raw.githubusercontent.com/user/repo/refs/heads/master/README.md",
		},
		// GitLab
		{
			name:      "gitlab main",
			sourceURL: "https://gitlab.com/user/repo",
			branch:    "main",
			want:      "https://gitlab.com/user/repo/-/raw/main/README.md",
		},
		{
			// Nested group path must be preserved verbatim in the raw URL.
			name:      "gitlab nested groups",
			sourceURL: "https://gitlab.com/group/subgroup/repo",
			branch:    "main",
			want:      "https://gitlab.com/group/subgroup/repo/-/raw/main/README.md",
		},
		// Tangled
		{
			name:      "tangled main",
			sourceURL: "https://tangled.org/evan.jarrett.net/at-container-registry",
			branch:    "main",
			want:      "https://tangled.org/evan.jarrett.net/at-container-registry/raw/main/README.md",
		},
		{
			// Legacy @ prefix is stripped before building the raw URL.
			name:      "tangled legacy @ prefix",
			sourceURL: "https://tangled.org/@user/repo",
			branch:    "main",
			want:      "https://tangled.org/user/repo/raw/main/README.md",
		},
		// Unsupported
		{
			name:      "unsupported platform",
			sourceURL: "https://bitbucket.org/user/repo",
			branch:    "main",
			want:      "",
		},
		{
			name:      "empty url",
			sourceURL: "",
			branch:    "main",
			want:      "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := DeriveReadmeURL(tt.sourceURL, tt.branch)
			if got != tt.want {
				t.Errorf("DeriveReadmeURL(%q, %q) = %q, want %q", tt.sourceURL, tt.branch, got, tt.want)
			}
		})
	}
}

View File

@@ -27,8 +27,9 @@ type UIDependencies struct {
BaseURL string
DeviceStore *db.DeviceStore
HealthChecker *holdhealth.Checker
ReadmeCache *readme.Cache
ReadmeFetcher *readme.Fetcher
Templates *template.Template
DefaultHoldDID string // For UserContext creation
}
// RegisterUIRoutes registers all web UI and API routes on the provided router
@@ -36,6 +37,14 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
// Extract trimmed registry URL for templates
registryURL := trimRegistryURL(deps.BaseURL)
// Create web auth dependencies for middleware (enables UserContext in web routes)
webAuthDeps := middleware.WebAuthDeps{
SessionStore: deps.SessionStore,
Database: deps.Database,
Refresher: deps.Refresher,
DefaultHoldDID: deps.DefaultHoldDID,
}
// OAuth login routes (public)
router.Get("/auth/oauth/login", (&uihandlers.LoginHandler{
Templates: deps.Templates,
@@ -45,7 +54,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
// Public routes (with optional auth for navbar)
// SECURITY: Public pages use read-only DB
router.Get("/", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
router.Get("/", middleware.OptionalAuthWithDeps(webAuthDeps)(
&uihandlers.HomeHandler{
DB: deps.ReadOnlyDB,
Templates: deps.Templates,
@@ -53,7 +62,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
},
).ServeHTTP)
router.Get("/api/recent-pushes", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
router.Get("/api/recent-pushes", middleware.OptionalAuthWithDeps(webAuthDeps)(
&uihandlers.RecentPushesHandler{
DB: deps.ReadOnlyDB,
Templates: deps.Templates,
@@ -63,7 +72,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
).ServeHTTP)
// SECURITY: Search uses read-only DB to prevent writes and limit access to sensitive tables
router.Get("/search", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
router.Get("/search", middleware.OptionalAuthWithDeps(webAuthDeps)(
&uihandlers.SearchHandler{
DB: deps.ReadOnlyDB,
Templates: deps.Templates,
@@ -71,7 +80,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
},
).ServeHTTP)
router.Get("/api/search-results", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
router.Get("/api/search-results", middleware.OptionalAuthWithDeps(webAuthDeps)(
&uihandlers.SearchResultsHandler{
DB: deps.ReadOnlyDB,
Templates: deps.Templates,
@@ -80,7 +89,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
).ServeHTTP)
// Install page (public)
router.Get("/install", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
router.Get("/install", middleware.OptionalAuthWithDeps(webAuthDeps)(
&uihandlers.InstallHandler{
Templates: deps.Templates,
RegistryURL: registryURL,
@@ -88,7 +97,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
).ServeHTTP)
// API route for repository stats (public, read-only)
router.Get("/api/stats/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
router.Get("/api/stats/{handle}/{repository}", middleware.OptionalAuthWithDeps(webAuthDeps)(
&uihandlers.GetStatsHandler{
DB: deps.ReadOnlyDB,
Directory: deps.OAuthClientApp.Dir,
@@ -96,7 +105,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
).ServeHTTP)
// API routes for stars (require authentication)
router.Post("/api/stars/{handle}/{repository}", middleware.RequireAuth(deps.SessionStore, deps.Database)(
router.Post("/api/stars/{handle}/{repository}", middleware.RequireAuthWithDeps(webAuthDeps)(
&uihandlers.StarRepositoryHandler{
DB: deps.Database, // Needs write access
Directory: deps.OAuthClientApp.Dir,
@@ -104,7 +113,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
},
).ServeHTTP)
router.Delete("/api/stars/{handle}/{repository}", middleware.RequireAuth(deps.SessionStore, deps.Database)(
router.Delete("/api/stars/{handle}/{repository}", middleware.RequireAuthWithDeps(webAuthDeps)(
&uihandlers.UnstarRepositoryHandler{
DB: deps.Database, // Needs write access
Directory: deps.OAuthClientApp.Dir,
@@ -112,7 +121,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
},
).ServeHTTP)
router.Get("/api/stars/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
router.Get("/api/stars/{handle}/{repository}", middleware.OptionalAuthWithDeps(webAuthDeps)(
&uihandlers.CheckStarHandler{
DB: deps.ReadOnlyDB, // Read-only check
Directory: deps.OAuthClientApp.Dir,
@@ -121,7 +130,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
).ServeHTTP)
// Manifest detail API endpoint
router.Get("/api/manifests/{handle}/{repository}/{digest}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
router.Get("/api/manifests/{handle}/{repository}/{digest}", middleware.OptionalAuthWithDeps(webAuthDeps)(
&uihandlers.ManifestDetailHandler{
DB: deps.ReadOnlyDB,
Directory: deps.OAuthClientApp.Dir,
@@ -133,7 +142,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
HealthChecker: deps.HealthChecker,
}).ServeHTTP)
router.Get("/u/{handle}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
router.Get("/u/{handle}", middleware.OptionalAuthWithDeps(webAuthDeps)(
&uihandlers.UserPageHandler{
DB: deps.ReadOnlyDB,
Templates: deps.Templates,
@@ -152,7 +161,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
DB: deps.ReadOnlyDB,
}).ServeHTTP)
router.Get("/r/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
router.Get("/r/{handle}/{repository}", middleware.OptionalAuthWithDeps(webAuthDeps)(
&uihandlers.RepositoryPageHandler{
DB: deps.ReadOnlyDB,
Templates: deps.Templates,
@@ -160,13 +169,13 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
Directory: deps.OAuthClientApp.Dir,
Refresher: deps.Refresher,
HealthChecker: deps.HealthChecker,
ReadmeCache: deps.ReadmeCache,
ReadmeFetcher: deps.ReadmeFetcher,
},
).ServeHTTP)
// Authenticated routes
router.Group(func(r chi.Router) {
r.Use(middleware.RequireAuth(deps.SessionStore, deps.Database))
r.Use(middleware.RequireAuthWithDeps(webAuthDeps))
r.Get("/settings", (&uihandlers.SettingsHandler{
Templates: deps.Templates,
@@ -188,6 +197,11 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
Refresher: deps.Refresher,
}).ServeHTTP)
r.Post("/api/images/{repository}/avatar", (&uihandlers.UploadAvatarHandler{
DB: deps.Database,
Refresher: deps.Refresher,
}).ServeHTTP)
// Device approval page (authenticated)
r.Get("/device", (&uihandlers.DeviceApprovalPageHandler{
Store: deps.DeviceStore,
@@ -219,6 +233,14 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
}
router.Get("/auth/logout", logoutHandler.ServeHTTP)
router.Post("/auth/logout", logoutHandler.ServeHTTP)
// Custom 404 handler
router.NotFound(middleware.OptionalAuthWithDeps(webAuthDeps)(
&uihandlers.NotFoundHandler{
Templates: deps.Templates,
RegistryURL: registryURL,
},
).ServeHTTP)
}
// CORSMiddleware returns a middleware that sets CORS headers for API endpoints

View File

@@ -38,6 +38,10 @@
--version-badge-text: #7b1fa2;
--version-badge-border: #ba68c8;
/* Attestation badge */
--attestation-badge-bg: #d1fae5;
--attestation-badge-text: #065f46;
/* Hero section colors */
--hero-bg-start: #f8f9fa;
--hero-bg-end: #e9ecef;
@@ -90,6 +94,10 @@
--version-badge-text: #ffffff;
--version-badge-border: #ba68c8;
/* Attestation badge */
--attestation-badge-bg: #065f46;
--attestation-badge-text: #6ee7b7;
/* Hero section colors */
--hero-bg-start: #2d2d2d;
--hero-bg-end: #1a1a1a;
@@ -109,7 +117,9 @@
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
font-family:
-apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue",
Arial, sans-serif;
background: var(--bg);
color: var(--fg);
line-height: 1.6;
@@ -170,7 +180,7 @@ body {
}
.nav-links a:hover {
background:var(--secondary);
background: var(--secondary);
border-radius: 4px;
}
@@ -193,7 +203,7 @@ body {
}
.user-menu-btn:hover {
background:var(--secondary);
background: var(--secondary);
}
.user-avatar {
@@ -266,7 +276,7 @@ body {
position: absolute;
top: calc(100% + 0.5rem);
right: 0;
background:var(--bg);
background: var(--bg);
border: 1px solid var(--border);
border-radius: 8px;
box-shadow: var(--shadow-lg);
@@ -287,7 +297,7 @@ body {
color: var(--fg);
text-decoration: none;
border: none;
background:var(--bg);
background: var(--bg);
cursor: pointer;
transition: background 0.2s;
font-size: 0.95rem;
@@ -309,7 +319,10 @@ body {
}
/* Buttons */
button, .btn, .btn-primary, .btn-secondary {
button,
.btn,
.btn-primary,
.btn-secondary {
padding: 0.5rem 1rem;
background: var(--button-primary);
color: var(--btn-text);
@@ -322,7 +335,10 @@ button, .btn, .btn-primary, .btn-secondary {
transition: opacity 0.2s;
}
button:hover, .btn:hover, .btn-primary:hover, .btn-secondary:hover {
button:hover,
.btn:hover,
.btn-primary:hover,
.btn-secondary:hover {
opacity: 0.9;
}
@@ -393,12 +409,13 @@ button:hover, .btn:hover, .btn-primary:hover, .btn-secondary:hover {
}
/* Cards */
.push-card, .repository-card {
.push-card,
.repository-card {
border: 1px solid var(--border);
border-radius: 8px;
padding: 1rem;
margin-bottom: 1rem;
background:var(--bg);
background: var(--bg);
box-shadow: var(--shadow-sm);
}
@@ -449,7 +466,7 @@ button:hover, .btn:hover, .btn-primary:hover, .btn-secondary:hover {
}
.digest {
font-family: 'Monaco', 'Courier New', monospace;
font-family: "Monaco", "Courier New", monospace;
font-size: 0.85rem;
background: var(--code-bg);
padding: 0.1rem 0.3rem;
@@ -492,7 +509,7 @@ button:hover, .btn:hover, .btn-primary:hover, .btn-secondary:hover {
}
.docker-command-text {
font-family: 'Monaco', 'Courier New', monospace;
font-family: "Monaco", "Courier New", monospace;
font-size: 0.85rem;
color: var(--fg);
flex: 0 1 auto;
@@ -510,7 +527,9 @@ button:hover, .btn:hover, .btn-primary:hover, .btn-secondary:hover {
border-radius: 4px;
opacity: 0;
visibility: hidden;
transition: opacity 0.2s, visibility 0.2s;
transition:
opacity 0.2s,
visibility 0.2s;
}
.docker-command:hover .copy-btn {
@@ -752,7 +771,7 @@ a.license-badge:hover {
}
.repo-stats {
color:var(--border-dark);
color: var(--border-dark);
font-size: 0.9rem;
display: flex;
gap: 0.5rem;
@@ -781,17 +800,20 @@ a.license-badge:hover {
padding-top: 1rem;
}
.tags-section, .manifests-section {
.tags-section,
.manifests-section {
margin-bottom: 1.5rem;
}
.tags-section h3, .manifests-section h3 {
.tags-section h3,
.manifests-section h3 {
font-size: 1.1rem;
margin-bottom: 0.5rem;
color: var(--secondary);
}
.tag-row, .manifest-row {
.tag-row,
.manifest-row {
display: flex;
gap: 1rem;
align-items: center;
@@ -799,7 +821,8 @@ a.license-badge:hover {
border-bottom: 1px solid var(--border);
}
.tag-row:last-child, .manifest-row:last-child {
.tag-row:last-child,
.manifest-row:last-child {
border-bottom: none;
}
@@ -821,7 +844,7 @@ a.license-badge:hover {
}
.settings-section {
background:var(--bg);
background: var(--bg);
border: 1px solid var(--border);
border-radius: 8px;
padding: 1.5rem;
@@ -918,7 +941,7 @@ a.license-badge:hover {
padding: 1rem;
border-radius: 4px;
overflow-x: auto;
font-family: 'Monaco', 'Courier New', monospace;
font-family: "Monaco", "Courier New", monospace;
font-size: 0.85rem;
border: 1px solid var(--border);
}
@@ -1004,13 +1027,6 @@ a.license-badge:hover {
margin: 1rem 0;
}
/* Load More Button */
.load-more {
width: 100%;
margin-top: 1rem;
background: var(--secondary);
}
/* Login Page */
.login-page {
max-width: 450px;
@@ -1031,7 +1047,7 @@ a.license-badge:hover {
}
.login-form {
background:var(--bg);
background: var(--bg);
padding: 2rem;
border-radius: 8px;
border: 1px solid var(--border);
@@ -1182,7 +1198,7 @@ a.license-badge:hover {
}
.repository-header {
background:var(--bg);
background: var(--bg);
border: 1px solid var(--border);
border-radius: 8px;
padding: 2rem;
@@ -1220,6 +1236,35 @@ a.license-badge:hover {
flex-shrink: 0;
}
.repo-hero-icon-wrapper {
position: relative;
display: inline-block;
flex-shrink: 0;
}
.avatar-upload-overlay {
position: absolute;
inset: 0;
display: flex;
align-items: center;
justify-content: center;
background: rgba(0, 0, 0, 0.5);
border-radius: 12px;
opacity: 0;
cursor: pointer;
transition: opacity 0.2s ease;
}
.avatar-upload-overlay i {
color: white;
width: 24px;
height: 24px;
}
.repo-hero-icon-wrapper:hover .avatar-upload-overlay {
opacity: 1;
}
.repo-hero-info {
flex: 1;
}
@@ -1290,7 +1335,7 @@ a.license-badge:hover {
}
.star-btn.starred {
border-color:var(--star);
border-color: var(--star);
background: var(--code-bg);
}
@@ -1374,7 +1419,7 @@ a.license-badge:hover {
}
.repo-section {
background:var(--bg);
background: var(--bg);
border: 1px solid var(--border);
border-radius: 8px;
padding: 1.5rem;
@@ -1389,20 +1434,23 @@ a.license-badge:hover {
border-bottom: 2px solid var(--border);
}
.tags-list, .manifests-list {
.tags-list,
.manifests-list {
display: flex;
flex-direction: column;
gap: 1rem;
}
.tag-item, .manifest-item {
.tag-item,
.manifest-item {
border: 1px solid var(--border);
border-radius: 6px;
padding: 1rem;
background: var(--hover-bg);
}
.tag-item-header, .manifest-item-header {
.tag-item-header,
.manifest-item-header {
display: flex;
justify-content: space-between;
align-items: center;
@@ -1532,7 +1580,7 @@ a.license-badge:hover {
color: var(--fg);
border: 1px solid var(--border);
white-space: nowrap;
font-family: 'Monaco', 'Courier New', monospace;
font-family: "Monaco", "Courier New", monospace;
}
.platforms-inline {
@@ -1570,20 +1618,21 @@ a.license-badge:hover {
.badge-attestation {
display: inline-flex;
align-items: center;
gap: 0.35rem;
padding: 0.25rem 0.5rem;
background: #f3e8ff;
color: #7c3aed;
border: 1px solid #c4b5fd;
border-radius: 4px;
font-size: 0.85rem;
gap: 0.3rem;
padding: 0.25rem 0.6rem;
background: var(--attestation-badge-bg);
color: var(--attestation-badge-text);
border-radius: 12px;
font-size: 0.75rem;
font-weight: 600;
margin-left: 0.5rem;
vertical-align: middle;
white-space: nowrap;
}
.badge-attestation .lucide {
width: 0.9rem;
height: 0.9rem;
width: 0.75rem;
height: 0.75rem;
}
/* Featured Repositories Section */
@@ -1736,7 +1785,11 @@ a.license-badge:hover {
/* Hero Section */
.hero-section {
background: linear-gradient(135deg, var(--hero-bg-start) 0%, var(--hero-bg-end) 100%);
background: linear-gradient(
135deg,
var(--hero-bg-start) 0%,
var(--hero-bg-end) 100%
);
padding: 4rem 2rem;
border-bottom: 1px solid var(--border);
}
@@ -1801,7 +1854,7 @@ a.license-badge:hover {
.terminal-content {
padding: 1.5rem;
margin: 0;
font-family: 'Monaco', 'Courier New', monospace;
font-family: "Monaco", "Courier New", monospace;
font-size: 0.95rem;
line-height: 1.8;
color: var(--terminal-text);
@@ -1957,7 +2010,7 @@ a.license-badge:hover {
}
.code-block code {
font-family: 'Monaco', 'Menlo', monospace;
font-family: "Monaco", "Menlo", monospace;
font-size: 0.9rem;
line-height: 1.5;
white-space: pre-wrap;
@@ -2014,7 +2067,8 @@ a.license-badge:hover {
flex-wrap: wrap;
}
.tag-row, .manifest-row {
.tag-row,
.manifest-row {
flex-wrap: wrap;
}
@@ -2103,7 +2157,7 @@ a.license-badge:hover {
/* README and Repository Layout */
.repo-content-layout {
display: grid;
grid-template-columns: 7fr 3fr;
grid-template-columns: 6fr 4fr;
gap: 2rem;
margin-top: 2rem;
}
@@ -2214,7 +2268,8 @@ a.license-badge:hover {
background: var(--code-bg);
padding: 0.2rem 0.4rem;
border-radius: 3px;
font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, monospace;
font-family:
"SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
font-size: 0.9em;
}
@@ -2318,3 +2373,59 @@ a.license-badge:hover {
padding: 0.75rem;
}
}
/* 404 Error Page */
.error-page {
display: flex;
align-items: center;
justify-content: center;
min-height: calc(100vh - 60px);
text-align: center;
padding: 2rem;
}
.error-content {
max-width: 480px;
}
.error-icon {
width: 80px;
height: 80px;
color: var(--secondary);
margin-bottom: 1.5rem;
}
.error-code {
font-size: 8rem;
font-weight: 700;
color: var(--primary);
line-height: 1;
margin-bottom: 0.5rem;
}
.error-content h1 {
font-size: 2rem;
margin-bottom: 0.75rem;
color: var(--fg);
}
.error-content p {
font-size: 1.125rem;
color: var(--secondary);
margin-bottom: 2rem;
}
@media (max-width: 768px) {
.error-code {
font-size: 5rem;
}
.error-icon {
width: 60px;
height: 60px;
}
.error-content h1 {
font-size: 1.5rem;
}
}

View File

@@ -434,6 +434,69 @@ function removeManifestElement(sanitizedId) {
}
}
// Upload repository avatar.
// Validates the selected file client-side (type, size), POSTs it to the
// avatar endpoint, and swaps the avatar image (or placeholder) on success.
// Fix: the file input is now cleared in a `finally` block so it runs on
// EVERY exit path — previously validation failures and the missing-wrapper
// early return skipped `input.value = ''`, so reselecting the same file
// after fixing the problem would not fire a change event.
async function uploadAvatar(input, repository) {
  const file = input.files[0];
  if (!file) return;
  try {
    // Client-side validation (server still validates authoritatively)
    const validTypes = ['image/png', 'image/jpeg', 'image/webp'];
    if (!validTypes.includes(file.type)) {
      alert('Please select a PNG, JPEG, or WebP image');
      return;
    }
    if (file.size > 3 * 1024 * 1024) {
      alert('Image must be less than 3MB');
      return;
    }
    const formData = new FormData();
    formData.append('avatar', file);
    const response = await fetch(`/api/images/${repository}/avatar`, {
      method: 'POST',
      credentials: 'include',
      body: formData
    });
    // Session expired: send the user to log in again
    if (response.status === 401) {
      window.location.href = '/auth/oauth/login';
      return;
    }
    if (!response.ok) {
      const error = await response.text();
      throw new Error(error);
    }
    const data = await response.json();
    // Update the avatar image on the page
    const wrapper = document.querySelector('.repo-hero-icon-wrapper');
    if (!wrapper) return;
    const existingImg = wrapper.querySelector('.repo-hero-icon');
    const placeholder = wrapper.querySelector('.repo-hero-icon-placeholder');
    if (existingImg) {
      // Already showing an avatar: just point it at the new URL
      existingImg.src = data.avatarURL;
    } else if (placeholder) {
      // First avatar: replace the placeholder element with a real <img>
      const newImg = document.createElement('img');
      newImg.src = data.avatarURL;
      newImg.alt = repository;
      newImg.className = 'repo-hero-icon';
      placeholder.replaceWith(newImg);
    }
  } catch (err) {
    console.error('Error uploading avatar:', err);
    alert('Failed to upload avatar: ' + err.message);
  } finally {
    // Clear input so same file can be selected again
    input.value = '';
  }
}
// Close modal when clicking outside
document.addEventListener('DOMContentLoaded', () => {
const modal = document.getElementById('manifest-delete-modal');

View File

@@ -1,42 +0,0 @@
package storage
import (
"context"
"atcr.io/pkg/atproto"
"atcr.io/pkg/auth"
"atcr.io/pkg/auth/oauth"
)
// DatabaseMetrics interface for tracking pull/push counts and querying hold DIDs
type DatabaseMetrics interface {
	// IncrementPullCount records one pull of repository by the given DID.
	IncrementPullCount(did, repository string) error
	// IncrementPushCount records one push of repository by the given DID.
	IncrementPushCount(did, repository string) error
	// GetLatestHoldDIDForRepo returns the most recent hold DID recorded for
	// the (did, repository) pair.
	GetLatestHoldDIDForRepo(did, repository string) (string, error)
}

// ReadmeCache interface for README content caching
type ReadmeCache interface {
	// Get returns the README content for url, served from cache when possible.
	Get(ctx context.Context, url string) (string, error)
	// Invalidate drops any cached entry for url.
	Invalidate(url string) error
}

// RegistryContext bundles all the context needed for registry operations
// This includes both per-request data (DID, hold) and shared services
type RegistryContext struct {
	// Per-request identity and routing information
	DID           string          // User's DID (e.g., "did:plc:abc123")
	Handle        string          // User's handle (e.g., "alice.bsky.social")
	HoldDID       string          // Hold service DID (e.g., "did:web:hold01.atcr.io")
	PDSEndpoint   string          // User's PDS endpoint URL
	Repository    string          // Image repository name (e.g., "debian")
	ServiceToken  string          // Service token for hold authentication (cached by middleware)
	ATProtoClient *atproto.Client // Authenticated ATProto client for this user
	AuthMethod    string          // Auth method used ("oauth" or "app_password")
	// Shared services (same for all requests)
	Database    DatabaseMetrics     // Metrics tracking database
	Authorizer  auth.HoldAuthorizer // Hold access authorization
	Refresher   *oauth.Refresher    // OAuth session manager
	ReadmeCache ReadmeCache         // README content cache
}

View File

@@ -1,146 +0,0 @@
package storage
import (
"context"
"sync"
"testing"
"atcr.io/pkg/atproto"
)
// Mock implementations for testing
// mockDatabaseMetrics is an in-memory DatabaseMetrics that counts calls.
type mockDatabaseMetrics struct {
	mu        sync.Mutex // guards pullCount and pushCount
	pullCount int
	pushCount int
}

// IncrementPullCount increments the in-memory pull counter; never fails.
func (m *mockDatabaseMetrics) IncrementPullCount(did, repository string) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.pullCount++
	return nil
}

// IncrementPushCount increments the in-memory push counter; never fails.
func (m *mockDatabaseMetrics) IncrementPushCount(did, repository string) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.pushCount++
	return nil
}

// GetLatestHoldDIDForRepo satisfies DatabaseMetrics with a stub value.
func (m *mockDatabaseMetrics) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
	// Return empty string for mock - tests can override if needed
	return "", nil
}

// getPullCount returns the current pull counter under the lock.
func (m *mockDatabaseMetrics) getPullCount() int {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.pullCount
}

// getPushCount returns the current push counter under the lock.
func (m *mockDatabaseMetrics) getPushCount() int {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.pushCount
}

// mockReadmeCache is a ReadmeCache stub returning fixed content.
type mockReadmeCache struct{}

// Get always returns the same README content, ignoring url.
func (m *mockReadmeCache) Get(ctx context.Context, url string) (string, error) {
	return "# Test README", nil
}

// Invalidate is a no-op that always succeeds.
func (m *mockReadmeCache) Invalidate(url string) error {
	return nil
}

// mockHoldAuthorizer is a HoldAuthorizer stub that permits everything.
type mockHoldAuthorizer struct{}

// Authorize always grants access regardless of arguments.
func (m *mockHoldAuthorizer) Authorize(holdDID, userDID, permission string) (bool, error) {
	return true, nil
}
// TestRegistryContext_Fields constructs a fully-populated RegistryContext
// and verifies each per-request field round-trips through the struct. This
// mostly guards against accidental field renames/removals.
func TestRegistryContext_Fields(t *testing.T) {
	// Create a sample RegistryContext
	ctx := &RegistryContext{
		DID:          "did:plc:test123",
		Handle:       "alice.bsky.social",
		HoldDID:      "did:web:hold01.atcr.io",
		PDSEndpoint:  "https://bsky.social",
		Repository:   "debian",
		ServiceToken: "test-token",
		ATProtoClient: &atproto.Client{
			// Mock client - would need proper initialization in real tests
		},
		Database:    &mockDatabaseMetrics{},
		ReadmeCache: &mockReadmeCache{},
	}
	// Verify fields are accessible
	if ctx.DID != "did:plc:test123" {
		t.Errorf("Expected DID %q, got %q", "did:plc:test123", ctx.DID)
	}
	if ctx.Handle != "alice.bsky.social" {
		t.Errorf("Expected Handle %q, got %q", "alice.bsky.social", ctx.Handle)
	}
	if ctx.HoldDID != "did:web:hold01.atcr.io" {
		t.Errorf("Expected HoldDID %q, got %q", "did:web:hold01.atcr.io", ctx.HoldDID)
	}
	if ctx.PDSEndpoint != "https://bsky.social" {
		t.Errorf("Expected PDSEndpoint %q, got %q", "https://bsky.social", ctx.PDSEndpoint)
	}
	if ctx.Repository != "debian" {
		t.Errorf("Expected Repository %q, got %q", "debian", ctx.Repository)
	}
	if ctx.ServiceToken != "test-token" {
		t.Errorf("Expected ServiceToken %q, got %q", "test-token", ctx.ServiceToken)
	}
}
// TestRegistryContext_DatabaseInterface verifies both counter methods are
// reachable through the Database interface field and succeed on the mock.
func TestRegistryContext_DatabaseInterface(t *testing.T) {
	rc := &RegistryContext{
		Database: &mockDatabaseMetrics{},
	}

	if err := rc.Database.IncrementPullCount("did:plc:test", "repo"); err != nil {
		t.Errorf("Unexpected error: %v", err)
	}
	if err := rc.Database.IncrementPushCount("did:plc:test", "repo"); err != nil {
		t.Errorf("Unexpected error: %v", err)
	}
}
// TestRegistryContext_ReadmeCacheInterface exercises Get and Invalidate
// through the ReadmeCache interface field using the stub cache.
func TestRegistryContext_ReadmeCacheInterface(t *testing.T) {
	rc := &RegistryContext{
		ReadmeCache: &mockReadmeCache{},
	}

	content, err := rc.ReadmeCache.Get(context.Background(), "https://example.com/README.md")
	if err != nil {
		t.Errorf("Unexpected error: %v", err)
	}
	if content != "# Test README" {
		t.Errorf("Expected content %q, got %q", "# Test README", content)
	}

	if err := rc.ReadmeCache.Invalidate("https://example.com/README.md"); err != nil {
		t.Errorf("Unexpected error: %v", err)
	}
}
// TODO: Add more comprehensive tests:
// - Test ATProtoClient integration
// - Test OAuth Refresher integration
// - Test HoldAuthorizer integration
// - Test nil handling for optional fields
// - Integration tests with real components

View File

@@ -1,93 +0,0 @@
package storage
import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"time"
"atcr.io/pkg/atproto"
"atcr.io/pkg/auth/oauth"
"atcr.io/pkg/auth/token"
)
// EnsureCrewMembership best-effort registers the user as a crew member on
// their default hold. All authorization decisions (allowAllCrew, existing
// membership, record creation) live in the hold's requestCrew endpoint, so
// failures here are logged and swallowed rather than returned.
func EnsureCrewMembership(ctx context.Context, client *atproto.Client, refresher *oauth.Refresher, defaultHoldDID string) {
	// Nothing to do when no default hold is configured.
	if defaultHoldDID == "" {
		return
	}

	// The configured value may be a URL; normalize it to a DID first.
	hold := atproto.ResolveHoldDIDFromURL(defaultHoldDID)
	if hold == "" {
		slog.Warn("failed to resolve hold DID", "defaultHold", defaultHoldDID)
		return
	}

	// Translate the hold DID into the HTTP endpoint we will call.
	endpoint := atproto.ResolveHoldURL(hold)

	// Service tokens require OAuth; app-password sessions have no refresher.
	if refresher == nil {
		slog.Debug("skipping crew registration - no OAuth refresher (app password flow)", "holdDID", hold)
		return
	}

	// The refresher satisfies the OAuthSessionRefresher interface expected here.
	tok, err := token.GetOrFetchServiceToken(ctx, refresher, client.DID(), hold, client.PDSEndpoint())
	if err != nil {
		slog.Warn("failed to get service token", "holdDID", hold, "error", err)
		return
	}

	// Let the hold decide: it checks allowAllCrew, treats existing
	// membership as success, and creates the crew record when authorized.
	if err := requestCrewMembership(ctx, endpoint, tok); err != nil {
		slog.Warn("failed to request crew membership", "holdDID", hold, "error", err)
		return
	}
	slog.Info("successfully registered as crew member", "holdDID", hold, "userDID", client.DID())
}
// requestCrewMembership calls the hold's requestCrew endpoint.
// The endpoint handles all authorization and duplicate checking internally,
// so a 200/201 response is success; any other status is returned as an
// error carrying (a bounded slice of) the hold's response body.
func requestCrewMembership(ctx context.Context, holdEndpoint, serviceToken string) error {
	// Add 5 second timeout to prevent hanging on offline holds.
	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	url := fmt.Sprintf("%s%s", holdEndpoint, atproto.HoldRequestCrew)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", "Bearer "+serviceToken)
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
		// Capture the hold's error message, but cap the read so a
		// misbehaving server cannot make us buffer an unbounded body.
		body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8*1024))
		if readErr != nil {
			return fmt.Errorf("requestCrew failed with status %d (failed to read error body: %w)", resp.StatusCode, readErr)
		}
		return fmt.Errorf("requestCrew failed with status %d: %s", resp.StatusCode, string(body))
	}

	// Best-effort drain of the success payload so the underlying
	// connection can be reused by the transport.
	_, _ = io.Copy(io.Discard, resp.Body)
	return nil
}

View File

@@ -1,14 +0,0 @@
package storage
import (
"context"
"testing"
)
// TestEnsureCrewMembership_EmptyHoldDID verifies the early-return path for
// an empty hold DID: the best-effort helper must neither panic nor need a
// client/refresher when there is nothing to register against.
func TestEnsureCrewMembership_EmptyHoldDID(t *testing.T) {
	EnsureCrewMembership(context.Background(), nil, nil, "")
	// Reaching this point without a panic is the assertion.
}
// TODO: Add comprehensive tests with HTTP client mocking

View File

@@ -3,6 +3,7 @@ package storage
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
@@ -10,9 +11,12 @@ import (
"log/slog"
"net/http"
"strings"
"sync"
"time"
"atcr.io/pkg/appview/db"
"atcr.io/pkg/appview/readme"
"atcr.io/pkg/atproto"
"atcr.io/pkg/auth"
"github.com/distribution/distribution/v3"
"github.com/opencontainers/go-digest"
)
@@ -20,24 +24,24 @@ import (
// ManifestStore implements distribution.ManifestService
// It stores manifests in ATProto as records
type ManifestStore struct {
ctx *RegistryContext // Context with user/hold info
mu sync.RWMutex // Protects lastFetchedHoldDID
lastFetchedHoldDID string // Hold DID from most recently fetched manifest (for pull)
ctx *auth.UserContext // User context with identity, target, permissions
blobStore distribution.BlobStore // Blob store for fetching config during push
sqlDB *sql.DB // Database for pull/push counts
}
// NewManifestStore creates a new ATProto-backed manifest store
func NewManifestStore(ctx *RegistryContext, blobStore distribution.BlobStore) *ManifestStore {
func NewManifestStore(userCtx *auth.UserContext, blobStore distribution.BlobStore, sqlDB *sql.DB) *ManifestStore {
return &ManifestStore{
ctx: ctx,
ctx: userCtx,
blobStore: blobStore,
sqlDB: sqlDB,
}
}
// Exists checks if a manifest exists by digest
func (s *ManifestStore) Exists(ctx context.Context, dgst digest.Digest) (bool, error) {
rkey := digestToRKey(dgst)
_, err := s.ctx.ATProtoClient.GetRecord(ctx, atproto.ManifestCollection, rkey)
_, err := s.ctx.GetATProtoClient().GetRecord(ctx, atproto.ManifestCollection, rkey)
if err != nil {
// If not found, return false without error
if errors.Is(err, atproto.ErrRecordNotFound) {
@@ -51,37 +55,24 @@ func (s *ManifestStore) Exists(ctx context.Context, dgst digest.Digest) (bool, e
// Get retrieves a manifest by digest
func (s *ManifestStore) Get(ctx context.Context, dgst digest.Digest, options ...distribution.ManifestServiceOption) (distribution.Manifest, error) {
rkey := digestToRKey(dgst)
record, err := s.ctx.ATProtoClient.GetRecord(ctx, atproto.ManifestCollection, rkey)
record, err := s.ctx.GetATProtoClient().GetRecord(ctx, atproto.ManifestCollection, rkey)
if err != nil {
return nil, distribution.ErrManifestUnknownRevision{
Name: s.ctx.Repository,
Name: s.ctx.TargetRepo,
Revision: dgst,
}
}
var manifestRecord atproto.Manifest
var manifestRecord atproto.ManifestRecord
if err := json.Unmarshal(record.Value, &manifestRecord); err != nil {
return nil, fmt.Errorf("failed to unmarshal manifest record: %w", err)
}
// Store the hold DID for subsequent blob requests during pull
// Prefer HoldDid (new format) with fallback to HoldEndpoint (legacy URL format)
// The routing repository will cache this for concurrent blob fetches
s.mu.Lock()
if manifestRecord.HoldDid != nil && *manifestRecord.HoldDid != "" {
// New format: DID reference (preferred)
s.lastFetchedHoldDID = *manifestRecord.HoldDid
} else if manifestRecord.HoldEndpoint != nil && *manifestRecord.HoldEndpoint != "" {
// Legacy format: URL reference - convert to DID
s.lastFetchedHoldDID = atproto.ResolveHoldDIDFromURL(*manifestRecord.HoldEndpoint)
}
s.mu.Unlock()
var ociManifest []byte
// New records: Download blob from ATProto blob storage
if manifestRecord.ManifestBlob != nil && manifestRecord.ManifestBlob.Ref.Defined() {
ociManifest, err = s.ctx.ATProtoClient.GetBlob(ctx, manifestRecord.ManifestBlob.Ref.String())
if manifestRecord.ManifestBlob != nil && manifestRecord.ManifestBlob.Ref.Link != "" {
ociManifest, err = s.ctx.GetATProtoClient().GetBlob(ctx, manifestRecord.ManifestBlob.Ref.Link)
if err != nil {
return nil, fmt.Errorf("failed to download manifest blob: %w", err)
}
@@ -89,12 +80,12 @@ func (s *ManifestStore) Get(ctx context.Context, dgst digest.Digest, options ...
// Track pull count (increment asynchronously to avoid blocking the response)
// Only count GET requests (actual downloads), not HEAD requests (existence checks)
if s.ctx.Database != nil {
if s.sqlDB != nil {
// Check HTTP method from context (distribution library stores it as "http.request.method")
if method, ok := ctx.Value("http.request.method").(string); ok && method == "GET" {
go func() {
if err := s.ctx.Database.IncrementPullCount(s.ctx.DID, s.ctx.Repository); err != nil {
slog.Warn("Failed to increment pull count", "did", s.ctx.DID, "repository", s.ctx.Repository, "error", err)
if err := db.IncrementPullCount(s.sqlDB, s.ctx.TargetOwnerDID, s.ctx.TargetRepo); err != nil {
slog.Warn("Failed to increment pull count", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "error", err)
}
}()
}
@@ -121,22 +112,20 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
dgst := digest.FromBytes(payload)
// Upload manifest as blob to PDS
blobRef, err := s.ctx.ATProtoClient.UploadBlob(ctx, payload, mediaType)
blobRef, err := s.ctx.GetATProtoClient().UploadBlob(ctx, payload, mediaType)
if err != nil {
return "", fmt.Errorf("failed to upload manifest blob: %w", err)
}
// Create manifest record with structured metadata
manifestRecord, err := atproto.NewManifestRecord(s.ctx.Repository, dgst.String(), payload)
manifestRecord, err := atproto.NewManifestRecord(s.ctx.TargetRepo, dgst.String(), payload)
if err != nil {
return "", fmt.Errorf("failed to create manifest record: %w", err)
}
// Set the blob reference, hold DID, and hold endpoint
manifestRecord.ManifestBlob = blobRef
if s.ctx.HoldDID != "" {
manifestRecord.HoldDid = &s.ctx.HoldDID // Primary reference (DID)
}
manifestRecord.HoldDID = s.ctx.TargetHoldDID // Primary reference (DID)
// Extract Dockerfile labels from config blob and add to annotations
// Only for image manifests (not manifest lists which don't have config blobs)
@@ -163,10 +152,10 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
if !exists {
platform := "unknown"
if ref.Platform != nil {
platform = fmt.Sprintf("%s/%s", ref.Platform.Os, ref.Platform.Architecture)
platform = fmt.Sprintf("%s/%s", ref.Platform.OS, ref.Platform.Architecture)
}
slog.Warn("Manifest list references non-existent child manifest",
"repository", s.ctx.Repository,
"repository", s.ctx.TargetRepo,
"missingDigest", ref.Digest,
"platform", platform)
return "", distribution.ErrManifestBlobUnknown{Digest: refDigest}
@@ -174,24 +163,43 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
}
}
// Note: Label extraction from config blob is currently disabled because the generated
// Manifest_Annotations type doesn't support arbitrary keys. The lexicon schema would
// need to use "unknown" type for annotations to support dynamic key-value pairs.
// TODO: Update lexicon schema if label extraction is needed.
_ = isManifestList // silence unused variable warning for now
if !isManifestList && s.blobStore != nil && manifestRecord.Config != nil && manifestRecord.Config.Digest != "" {
labels, err := s.extractConfigLabels(ctx, manifestRecord.Config.Digest)
if err != nil {
// Log error but don't fail the push - labels are optional
slog.Warn("Failed to extract config labels", "error", err)
} else if len(labels) > 0 {
// Initialize annotations map if needed
if manifestRecord.Annotations == nil {
manifestRecord.Annotations = make(map[string]string)
}
// Copy labels to annotations as fallback
// Only set label values for keys NOT already in manifest annotations
// This ensures explicit annotations take precedence over Dockerfile LABELs
// (which may be inherited from base images)
for key, value := range labels {
if _, exists := manifestRecord.Annotations[key]; !exists {
manifestRecord.Annotations[key] = value
}
}
slog.Debug("Merged labels from config blob", "labelsCount", len(labels), "annotationsCount", len(manifestRecord.Annotations))
}
}
// Store manifest record in ATProto
rkey := digestToRKey(dgst)
_, err = s.ctx.ATProtoClient.PutRecord(ctx, atproto.ManifestCollection, rkey, manifestRecord)
_, err = s.ctx.GetATProtoClient().PutRecord(ctx, atproto.ManifestCollection, rkey, manifestRecord)
if err != nil {
return "", fmt.Errorf("failed to store manifest record in ATProto: %w", err)
}
// Track push count (increment asynchronously to avoid blocking the response)
if s.ctx.Database != nil {
if s.sqlDB != nil {
go func() {
if err := s.ctx.Database.IncrementPushCount(s.ctx.DID, s.ctx.Repository); err != nil {
slog.Warn("Failed to increment push count", "did", s.ctx.DID, "repository", s.ctx.Repository, "error", err)
if err := db.IncrementPushCount(s.sqlDB, s.ctx.TargetOwnerDID, s.ctx.TargetRepo); err != nil {
slog.Warn("Failed to increment push count", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "error", err)
}
}()
}
@@ -201,9 +209,9 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
for _, option := range options {
if tagOpt, ok := option.(distribution.WithTagOption); ok {
tag = tagOpt.Tag
tagRecord := atproto.NewTagRecord(s.ctx.ATProtoClient.DID(), s.ctx.Repository, tag, dgst.String())
tagRKey := atproto.RepositoryTagToRKey(s.ctx.Repository, tag)
_, err = s.ctx.ATProtoClient.PutRecord(ctx, atproto.TagCollection, tagRKey, tagRecord)
tagRecord := atproto.NewTagRecord(s.ctx.GetATProtoClient().DID(), s.ctx.TargetRepo, tag, dgst.String())
tagRKey := atproto.RepositoryTagToRKey(s.ctx.TargetRepo, tag)
_, err = s.ctx.GetATProtoClient().PutRecord(ctx, atproto.TagCollection, tagRKey, tagRecord)
if err != nil {
return "", fmt.Errorf("failed to store tag in ATProto: %w", err)
}
@@ -212,28 +220,30 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
// Notify hold about manifest upload (for layer tracking and Bluesky posts)
// Do this asynchronously to avoid blocking the push
if tag != "" && s.ctx.ServiceToken != "" && s.ctx.Handle != "" {
go func() {
// Get service token before goroutine (requires context)
serviceToken, _ := s.ctx.GetServiceToken(ctx)
if tag != "" && serviceToken != "" && s.ctx.TargetOwnerHandle != "" {
go func(serviceToken string) {
defer func() {
if r := recover(); r != nil {
slog.Error("Panic in notifyHoldAboutManifest", "panic", r)
}
}()
if err := s.notifyHoldAboutManifest(context.Background(), manifestRecord, tag, dgst.String()); err != nil {
if err := s.notifyHoldAboutManifest(context.Background(), manifestRecord, tag, dgst.String(), serviceToken); err != nil {
slog.Warn("Failed to notify hold about manifest", "error", err)
}
}()
}(serviceToken)
}
// Refresh README cache asynchronously if manifest has io.atcr.readme annotation
// This ensures fresh README content is available on repository pages
// Create or update repo page asynchronously if manifest has relevant annotations
// This ensures repository metadata is synced to user's PDS
go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("Panic in refreshReadmeCache", "panic", r)
slog.Error("Panic in ensureRepoPage", "panic", r)
}
}()
s.refreshReadmeCache(context.Background(), manifestRecord)
s.ensureRepoPage(context.Background(), manifestRecord)
}()
return dgst, nil
@@ -242,7 +252,7 @@ func (s *ManifestStore) Put(ctx context.Context, manifest distribution.Manifest,
// Delete removes a manifest
func (s *ManifestStore) Delete(ctx context.Context, dgst digest.Digest) error {
rkey := digestToRKey(dgst)
return s.ctx.ATProtoClient.DeleteRecord(ctx, atproto.ManifestCollection, rkey)
return s.ctx.GetATProtoClient().DeleteRecord(ctx, atproto.ManifestCollection, rkey)
}
// digestToRKey converts a digest to an ATProto record key
@@ -252,14 +262,6 @@ func digestToRKey(dgst digest.Digest) string {
return dgst.Encoded()
}
// GetLastFetchedHoldDID returns the hold DID from the most recently fetched manifest
// This is used by the routing repository to cache the hold for blob requests
func (s *ManifestStore) GetLastFetchedHoldDID() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.lastFetchedHoldDID
}
// rawManifest is a simple implementation of distribution.Manifest
type rawManifest struct {
mediaType string
@@ -305,18 +307,17 @@ func (s *ManifestStore) extractConfigLabels(ctx context.Context, configDigestStr
// notifyHoldAboutManifest notifies the hold service about a manifest upload
// This enables the hold to create layer records and Bluesky posts
func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRecord *atproto.Manifest, tag, manifestDigest string) error {
// Skip if no service token configured (e.g., anonymous pulls)
if s.ctx.ServiceToken == "" {
func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRecord *atproto.ManifestRecord, tag, manifestDigest, serviceToken string) error {
// Skip if no service token provided
if serviceToken == "" {
return nil
}
// Resolve hold DID to HTTP endpoint
// For did:web, this is straightforward (e.g., did:web:hold01.atcr.io → https://hold01.atcr.io)
holdEndpoint := atproto.ResolveHoldURL(s.ctx.HoldDID)
holdEndpoint := atproto.ResolveHoldURL(s.ctx.TargetHoldDID)
// Use service token from middleware (already cached and validated)
serviceToken := s.ctx.ServiceToken
// Service token is passed in (already cached and validated)
// Build notification request
manifestData := map[string]any{
@@ -355,7 +356,7 @@ func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRec
}
if m.Platform != nil {
mData["platform"] = map[string]any{
"os": m.Platform.Os,
"os": m.Platform.OS,
"architecture": m.Platform.Architecture,
}
}
@@ -365,10 +366,10 @@ func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRec
}
notifyReq := map[string]any{
"repository": s.ctx.Repository,
"repository": s.ctx.TargetRepo,
"tag": tag,
"userDid": s.ctx.DID,
"userHandle": s.ctx.Handle,
"userDid": s.ctx.TargetOwnerDID,
"userHandle": s.ctx.TargetOwnerHandle,
"manifest": manifestData,
}
@@ -406,24 +407,251 @@ func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRec
// Parse response (optional logging)
var notifyResp map[string]any
if err := json.NewDecoder(resp.Body).Decode(&notifyResp); err == nil {
slog.Info("Hold notification successful", "repository", s.ctx.Repository, "tag", tag, "response", notifyResp)
slog.Info("Hold notification successful", "repository", s.ctx.TargetRepo, "tag", tag, "response", notifyResp)
}
return nil
}
// refreshReadmeCache refreshes the README cache for this manifest if it has io.atcr.readme annotation
// This should be called asynchronously after manifest push to keep README content fresh
// NOTE: Currently disabled because the generated Manifest_Annotations type doesn't support
// arbitrary key-value pairs. Would need to update lexicon schema with "unknown" type.
func (s *ManifestStore) refreshReadmeCache(ctx context.Context, manifestRecord *atproto.Manifest) {
// Skip if no README cache configured
if s.ctx.ReadmeCache == nil {
// ensureRepoPage creates or updates a repo page record in the user's PDS if needed
// This syncs repository metadata from manifest annotations to the io.atcr.repo.page collection
// Only creates a new record if one doesn't exist (doesn't overwrite user's custom content)
func (s *ManifestStore) ensureRepoPage(ctx context.Context, manifestRecord *atproto.ManifestRecord) {
// Check if repo page already exists (don't overwrite user's custom content)
rkey := s.ctx.TargetRepo
_, err := s.ctx.GetATProtoClient().GetRecord(ctx, atproto.RepoPageCollection, rkey)
if err == nil {
// Record already exists - don't overwrite
slog.Debug("Repo page already exists, skipping creation", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo)
return
}
// TODO: Re-enable once lexicon supports annotations as map[string]string
// The generated Manifest_Annotations is an empty struct that doesn't support map access.
// For now, README cache refresh on push is disabled.
_ = manifestRecord // silence unused variable warning
// Only continue if it's a "not found" error - other errors mean we should skip
if !errors.Is(err, atproto.ErrRecordNotFound) {
slog.Warn("Failed to check for existing repo page", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "error", err)
return
}
// Get annotations (may be nil if image has no OCI labels)
annotations := manifestRecord.Annotations
if annotations == nil {
annotations = make(map[string]string)
}
// Try to fetch README content from external sources
// Priority: io.atcr.readme annotation > derived from org.opencontainers.image.source > org.opencontainers.image.description
description := s.fetchReadmeContent(ctx, annotations)
// If no README content could be fetched, fall back to description annotation
if description == "" {
description = annotations["org.opencontainers.image.description"]
}
// Try to fetch and upload icon from io.atcr.icon annotation
var avatarRef *atproto.ATProtoBlobRef
if iconURL := annotations["io.atcr.icon"]; iconURL != "" {
avatarRef = s.fetchAndUploadIcon(ctx, iconURL)
}
// Create new repo page record with description and optional avatar
repoPage := atproto.NewRepoPageRecord(s.ctx.TargetRepo, description, avatarRef)
slog.Info("Creating repo page from manifest annotations", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "descriptionLength", len(description), "hasAvatar", avatarRef != nil)
_, err = s.ctx.GetATProtoClient().PutRecord(ctx, atproto.RepoPageCollection, rkey, repoPage)
if err != nil {
slog.Warn("Failed to create repo page", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo, "error", err)
return
}
slog.Info("Repo page created successfully", "did", s.ctx.TargetOwnerDID, "repository", s.ctx.TargetRepo)
}
// fetchReadmeContent attempts to fetch raw README markdown from external
// sources, in priority order: the io.atcr.readme annotation first, then a
// URL derived from org.opencontainers.image.source (probing the common
// default branches). Returns "" when nothing could be fetched.
func (s *ManifestStore) fetchReadmeContent(ctx context.Context, annotations map[string]string) string {
	// Bound total fetch time so a slow host cannot stall the push for long.
	fetchCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	// Priority 1: an explicit README URL supplied via io.atcr.readme.
	if target := annotations["io.atcr.readme"]; target != "" {
		content, err := s.fetchRawReadme(fetchCtx, target)
		switch {
		case err != nil:
			slog.Debug("Failed to fetch README from io.atcr.readme annotation", "url", target, "error", err)
		case content != "":
			slog.Info("Fetched README from io.atcr.readme annotation", "url", target, "length", len(content))
			return content
		}
	}

	// Priority 2: derive a raw README URL from the image source repository,
	// trying the "main" branch before "master".
	if src := annotations["org.opencontainers.image.source"]; src != "" {
		for _, branch := range []string{"main", "master"} {
			target := readme.DeriveReadmeURL(src, branch)
			if target == "" {
				continue
			}
			content, err := s.fetchRawReadme(fetchCtx, target)
			if err != nil {
				// A 404 is expected while probing main vs master; stay quiet for those.
				if !readme.Is404(err) {
					slog.Debug("Failed to fetch README from source URL", "url", target, "branch", branch, "error", err)
				}
				continue
			}
			if content != "" {
				slog.Info("Fetched README from source URL", "sourceURL", src, "branch", branch, "length", len(content))
				return content
			}
		}
	}
	return ""
}
// fetchRawReadme downloads the raw markdown document at readmeURL.
// Raw markdown (not rendered HTML) is what gets stored in the repo page
// record, so this deliberately bypasses the HTML-rendering Fetcher.
func (s *ManifestStore) fetchRawReadme(ctx context.Context, readmeURL string) (string, error) {
	req, err := http.NewRequestWithContext(ctx, "GET", readmeURL, nil)
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("User-Agent", "ATCR-README-Fetcher/1.0")

	httpClient := &http.Client{
		Timeout: 10 * time.Second,
		// Cap redirect chains so a misbehaving host cannot bounce us forever.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if len(via) >= 5 {
				return fmt.Errorf("too many redirects")
			}
			return nil
		},
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("failed to fetch URL: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}

	// The repo page lexicon caps descriptions at 100KB, so stop reading there.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 100*1024))
	if err != nil {
		return "", fmt.Errorf("failed to read response body: %w", err)
	}
	return string(body), nil
}
// fetchAndUploadIcon downloads the image at iconURL and uploads it as a
// blob to the user's PDS.
// Returns the blob reference for use in the repo page record, or nil when
// the icon cannot be fetched, is not a supported image type, is empty, or
// exceeds the lexicon's 3MB size limit. Failures are logged, never
// returned: the icon is optional metadata.
func (s *ManifestStore) fetchAndUploadIcon(ctx context.Context, iconURL string) *atproto.ATProtoBlobRef {
	// Create a context with timeout so a slow icon host cannot stall us.
	fetchCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(fetchCtx, "GET", iconURL, nil)
	if err != nil {
		slog.Debug("Failed to create icon request", "url", iconURL, "error", err)
		return nil
	}
	req.Header.Set("User-Agent", "ATCR-Icon-Fetcher/1.0")

	client := &http.Client{
		Timeout: 10 * time.Second,
		// Cap redirect chains so a misbehaving host cannot bounce us forever.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if len(via) >= 5 {
				return fmt.Errorf("too many redirects")
			}
			return nil
		},
	}
	resp, err := client.Do(req)
	if err != nil {
		slog.Debug("Failed to fetch icon", "url", iconURL, "error", err)
		return nil
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		slog.Debug("Icon fetch returned non-OK status", "url", iconURL, "status", resp.StatusCode)
		return nil
	}

	// Validate content type - only image types the lexicon accepts.
	contentType := resp.Header.Get("Content-Type")
	mimeType := detectImageMimeType(contentType, iconURL)
	if mimeType == "" {
		slog.Debug("Icon has unsupported content type", "url", iconURL, "contentType", contentType)
		return nil
	}

	// The lexicon caps icons at 3MB. Read one byte past the limit so an
	// oversized icon is detected and rejected instead of being silently
	// truncated (a truncated image would be uploaded as a corrupt blob).
	const maxIconSize = 3 * 1024 * 1024
	iconData, err := io.ReadAll(io.LimitReader(resp.Body, maxIconSize+1))
	if err != nil {
		slog.Debug("Failed to read icon data", "url", iconURL, "error", err)
		return nil
	}
	if len(iconData) == 0 {
		slog.Debug("Icon data is empty", "url", iconURL)
		return nil
	}
	if len(iconData) > maxIconSize {
		slog.Debug("Icon exceeds maximum size, skipping", "url", iconURL, "limit", maxIconSize)
		return nil
	}

	// Upload the icon as a blob to the user's PDS. The outer ctx is used
	// deliberately: the upload is not bound by the fetch timeout.
	blobRef, err := s.ctx.GetATProtoClient().UploadBlob(ctx, iconData, mimeType)
	if err != nil {
		slog.Warn("Failed to upload icon blob", "url", iconURL, "error", err)
		return nil
	}
	slog.Info("Uploaded icon blob", "url", iconURL, "size", len(iconData), "mimeType", mimeType, "cid", blobRef.Ref.Link)
	return blobRef
}
// detectImageMimeType maps an icon response to one of the MIME types the
// lexicon accepts: image/png, image/jpeg, or image/webp.
// The Content-Type header wins; when it is absent or unhelpful the URL's
// file extension is consulted (case-insensitively). An empty return value
// means the type is unsupported and the icon should be rejected.
func detectImageMimeType(contentType, url string) string {
	allowed := []string{"image/png", "image/jpeg", "image/webp"}

	// Trust the server's Content-Type when it names a supported type.
	// Prefix matching tolerates parameters such as "; charset=...".
	for _, mt := range allowed {
		if strings.HasPrefix(contentType, mt) {
			return mt
		}
	}

	// Fall back to extension-based detection on the lowercased URL.
	byExt := map[string]string{
		".png":  "image/png",
		".jpg":  "image/jpeg",
		".jpeg": "image/jpeg",
		".webp": "image/webp",
	}
	lower := strings.ToLower(url)
	for ext, mt := range byExt {
		if strings.HasSuffix(lower, ext) {
			return mt
		}
	}

	// Unknown or unsupported type - reject.
	return ""
}

View File

@@ -8,15 +8,13 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"atcr.io/pkg/atproto"
"atcr.io/pkg/auth"
"github.com/distribution/distribution/v3"
"github.com/opencontainers/go-digest"
)
// mockDatabaseMetrics removed - using the one from context_test.go
// mockBlobStore is a minimal mock of distribution.BlobStore for testing
type mockBlobStore struct {
blobs map[digest.Digest][]byte
@@ -72,16 +70,11 @@ func (m *mockBlobStore) Open(ctx context.Context, dgst digest.Digest) (io.ReadSe
return nil, nil // Not needed for current tests
}
// mockRegistryContext creates a mock RegistryContext for testing
func mockRegistryContext(client *atproto.Client, repository, holdDID, did, handle string, database DatabaseMetrics) *RegistryContext {
return &RegistryContext{
ATProtoClient: client,
Repository: repository,
HoldDID: holdDID,
DID: did,
Handle: handle,
Database: database,
}
// mockUserContextForManifest builds an auth.UserContext pre-targeted at
// the given owner/repository/hold for manifest store tests.
func mockUserContextForManifest(pdsEndpoint, repository, holdDID, ownerDID, ownerHandle string) *auth.UserContext {
	uc := auth.NewUserContext(ownerDID, "oauth", "PUT", nil)
	uc.SetTarget(ownerDID, ownerHandle, pdsEndpoint, repository, holdDID)
	return uc
}
// TestDigestToRKey tests digest to record key conversion
@@ -115,82 +108,27 @@ func TestDigestToRKey(t *testing.T) {
// TestNewManifestStore tests creating a new manifest store
func TestNewManifestStore(t *testing.T) {
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
blobStore := newMockBlobStore()
db := &mockDatabaseMetrics{}
userCtx := mockUserContextForManifest(
"https://pds.example.com",
"myapp",
"did:web:hold.example.com",
"did:plc:alice123",
"alice.test",
)
store := NewManifestStore(userCtx, blobStore, nil)
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:alice123", "alice.test", db)
store := NewManifestStore(ctx, blobStore)
if store.ctx.Repository != "myapp" {
t.Errorf("repository = %v, want myapp", store.ctx.Repository)
if store.ctx.TargetRepo != "myapp" {
t.Errorf("repository = %v, want myapp", store.ctx.TargetRepo)
}
if store.ctx.HoldDID != "did:web:hold.example.com" {
t.Errorf("holdDID = %v, want did:web:hold.example.com", store.ctx.HoldDID)
if store.ctx.TargetHoldDID != "did:web:hold.example.com" {
t.Errorf("holdDID = %v, want did:web:hold.example.com", store.ctx.TargetHoldDID)
}
if store.ctx.DID != "did:plc:alice123" {
t.Errorf("did = %v, want did:plc:alice123", store.ctx.DID)
if store.ctx.TargetOwnerDID != "did:plc:alice123" {
t.Errorf("did = %v, want did:plc:alice123", store.ctx.TargetOwnerDID)
}
if store.ctx.Handle != "alice.test" {
t.Errorf("handle = %v, want alice.test", store.ctx.Handle)
}
}
// TestManifestStore_GetLastFetchedHoldDID tests tracking last fetched hold DID
func TestManifestStore_GetLastFetchedHoldDID(t *testing.T) {
tests := []struct {
name string
manifestHoldDID string
manifestHoldURL string
expectedLastFetched string
}{
{
name: "prefers HoldDID",
manifestHoldDID: "did:web:hold01.atcr.io",
manifestHoldURL: "https://hold01.atcr.io",
expectedLastFetched: "did:web:hold01.atcr.io",
},
{
name: "falls back to HoldEndpoint URL conversion",
manifestHoldDID: "",
manifestHoldURL: "https://hold02.atcr.io",
expectedLastFetched: "did:web:hold02.atcr.io",
},
{
name: "empty hold references",
manifestHoldDID: "",
manifestHoldURL: "",
expectedLastFetched: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
store := NewManifestStore(ctx, nil)
// Simulate what happens in Get() when parsing a manifest record
var manifestRecord atproto.Manifest
if tt.manifestHoldDID != "" {
manifestRecord.HoldDid = &tt.manifestHoldDID
}
if tt.manifestHoldURL != "" {
manifestRecord.HoldEndpoint = &tt.manifestHoldURL
}
// Mimic the hold DID extraction logic from Get()
if manifestRecord.HoldDid != nil && *manifestRecord.HoldDid != "" {
store.lastFetchedHoldDID = *manifestRecord.HoldDid
} else if manifestRecord.HoldEndpoint != nil && *manifestRecord.HoldEndpoint != "" {
store.lastFetchedHoldDID = atproto.ResolveHoldDIDFromURL(*manifestRecord.HoldEndpoint)
}
got := store.GetLastFetchedHoldDID()
if got != tt.expectedLastFetched {
t.Errorf("GetLastFetchedHoldDID() = %v, want %v", got, tt.expectedLastFetched)
}
})
if store.ctx.TargetOwnerHandle != "alice.test" {
t.Errorf("handle = %v, want alice.test", store.ctx.TargetOwnerHandle)
}
}
@@ -245,9 +183,14 @@ func TestExtractConfigLabels(t *testing.T) {
blobStore.blobs[configDigest] = configData
// Create manifest store
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
store := NewManifestStore(ctx, blobStore)
userCtx := mockUserContextForManifest(
"https://pds.example.com",
"myapp",
"",
"did:plc:test123",
"test.handle",
)
store := NewManifestStore(userCtx, blobStore, nil)
// Extract labels
labels, err := store.extractConfigLabels(context.Background(), configDigest.String())
@@ -285,9 +228,14 @@ func TestExtractConfigLabels_NoLabels(t *testing.T) {
configDigest := digest.FromBytes(configData)
blobStore.blobs[configDigest] = configData
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
store := NewManifestStore(ctx, blobStore)
userCtx := mockUserContextForManifest(
"https://pds.example.com",
"myapp",
"",
"did:plc:test123",
"test.handle",
)
store := NewManifestStore(userCtx, blobStore, nil)
labels, err := store.extractConfigLabels(context.Background(), configDigest.String())
if err != nil {
@@ -303,9 +251,14 @@ func TestExtractConfigLabels_NoLabels(t *testing.T) {
// TestExtractConfigLabels_InvalidDigest tests error handling for invalid digest
func TestExtractConfigLabels_InvalidDigest(t *testing.T) {
blobStore := newMockBlobStore()
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
store := NewManifestStore(ctx, blobStore)
userCtx := mockUserContextForManifest(
"https://pds.example.com",
"myapp",
"",
"did:plc:test123",
"test.handle",
)
store := NewManifestStore(userCtx, blobStore, nil)
_, err := store.extractConfigLabels(context.Background(), "invalid-digest")
if err == nil {
@@ -322,9 +275,14 @@ func TestExtractConfigLabels_InvalidJSON(t *testing.T) {
configDigest := digest.FromBytes(configData)
blobStore.blobs[configDigest] = configData
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
store := NewManifestStore(ctx, blobStore)
userCtx := mockUserContextForManifest(
"https://pds.example.com",
"myapp",
"",
"did:plc:test123",
"test.handle",
)
store := NewManifestStore(userCtx, blobStore, nil)
_, err := store.extractConfigLabels(context.Background(), configDigest.String())
if err == nil {
@@ -332,28 +290,18 @@ func TestExtractConfigLabels_InvalidJSON(t *testing.T) {
}
}
// TestManifestStore_WithMetrics tests that metrics are tracked
func TestManifestStore_WithMetrics(t *testing.T) {
db := &mockDatabaseMetrics{}
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:alice123", "alice.test", db)
store := NewManifestStore(ctx, nil)
// TestManifestStore_WithoutDatabase tests that nil database is acceptable
func TestManifestStore_WithoutDatabase(t *testing.T) {
userCtx := mockUserContextForManifest(
"https://pds.example.com",
"myapp",
"did:web:hold.example.com",
"did:plc:alice123",
"alice.test",
)
store := NewManifestStore(userCtx, nil, nil)
if store.ctx.Database != db {
t.Error("ManifestStore should store database reference")
}
// Note: Actual metrics tracking happens in Put() and Get() which require
// full mock setup. The important thing is that the database is wired up.
}
// TestManifestStore_WithoutMetrics tests that nil database is acceptable
func TestManifestStore_WithoutMetrics(t *testing.T) {
client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token")
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:alice123", "alice.test", nil)
store := NewManifestStore(ctx, nil)
if store.ctx.Database != nil {
if store.sqlDB != nil {
t.Error("ManifestStore should accept nil database")
}
}
@@ -372,7 +320,7 @@ func TestManifestStore_Exists(t *testing.T) {
name: "manifest exists",
digest: "sha256:abc123",
serverStatus: http.StatusOK,
serverResp: `{"uri":"at://did:plc:test123/io.atcr.manifest/abc123","cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku","value":{}}`,
serverResp: `{"uri":"at://did:plc:test123/io.atcr.manifest/abc123","cid":"bafytest","value":{}}`,
wantExists: true,
wantErr: false,
},
@@ -403,9 +351,14 @@ func TestManifestStore_Exists(t *testing.T) {
}))
defer server.Close()
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", nil)
store := NewManifestStore(ctx, nil)
userCtx := mockUserContextForManifest(
server.URL,
"myapp",
"did:web:hold.example.com",
"did:plc:test123",
"test.handle",
)
store := NewManifestStore(userCtx, nil, nil)
exists, err := store.Exists(context.Background(), tt.digest)
if (err != nil) != tt.wantErr {
@@ -437,7 +390,7 @@ func TestManifestStore_Get(t *testing.T) {
digest: "sha256:abc123",
serverResp: `{
"uri":"at://did:plc:test123/io.atcr.manifest/abc123",
"cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku",
"cid":"bafytest",
"value":{
"$type":"io.atcr.manifest",
"repository":"myapp",
@@ -447,7 +400,7 @@ func TestManifestStore_Get(t *testing.T) {
"mediaType":"application/vnd.oci.image.manifest.v1+json",
"manifestBlob":{
"$type":"blob",
"ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},
"ref":{"$link":"bafytest"},
"mimeType":"application/vnd.oci.image.manifest.v1+json",
"size":100
}
@@ -481,9 +434,7 @@ func TestManifestStore_Get(t *testing.T) {
"holdEndpoint":"https://hold02.atcr.io",
"mediaType":"application/vnd.oci.image.manifest.v1+json",
"manifestBlob":{
"$type":"blob",
"ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},
"mimeType":"application/json",
"ref":{"$link":"bafylegacy"},
"size":100
}
}
@@ -523,10 +474,14 @@ func TestManifestStore_Get(t *testing.T) {
}))
defer server.Close()
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
db := &mockDatabaseMetrics{}
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", db)
store := NewManifestStore(ctx, nil)
userCtx := mockUserContextForManifest(
server.URL,
"myapp",
"did:web:hold.example.com",
"did:plc:test123",
"test.handle",
)
store := NewManifestStore(userCtx, nil, nil)
manifest, err := store.Get(context.Background(), tt.digest)
if (err != nil) != tt.wantErr {
@@ -547,148 +502,6 @@ func TestManifestStore_Get(t *testing.T) {
}
}
// TestManifestStore_Get_HoldDIDTracking tests that Get() stores the holdDID
// for later retrieval via GetLastFetchedHoldDID(): the record's explicit
// holdDid field wins when present, otherwise the DID is derived from the
// legacy holdEndpoint URL.
func TestManifestStore_Get_HoldDIDTracking(t *testing.T) {
	rawOCI := []byte(`{"schemaVersion":2}`)
	cases := []struct {
		name     string
		record   string
		wantHold string
	}{
		{
			name: "tracks HoldDID from new format",
			record: `{
			"uri":"at://did:plc:test123/io.atcr.manifest/abc123",
			"value":{
				"$type":"io.atcr.manifest",
				"holdDid":"did:web:hold01.atcr.io",
				"holdEndpoint":"https://hold01.atcr.io",
				"mediaType":"application/vnd.oci.image.manifest.v1+json",
				"manifestBlob":{"$type":"blob","ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},"mimeType":"application/json","size":100}
			}
		}`,
			wantHold: "did:web:hold01.atcr.io",
		},
		{
			name: "tracks HoldDID from legacy HoldEndpoint",
			record: `{
			"uri":"at://did:plc:test123/io.atcr.manifest/abc123",
			"value":{
				"$type":"io.atcr.manifest",
				"holdEndpoint":"https://hold02.atcr.io",
				"mediaType":"application/vnd.oci.image.manifest.v1+json",
				"manifestBlob":{"$type":"blob","ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},"mimeType":"application/json","size":100}
			}
		}`,
			wantHold: "did:web:hold02.atcr.io",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Fake PDS: blob fetches return the OCI manifest bytes; every
			// other request returns this case's manifest record JSON.
			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				switch r.URL.Path {
				case atproto.SyncGetBlob:
					w.Write(rawOCI)
				default:
					w.Write([]byte(tc.record))
				}
			}))
			defer srv.Close()
			client := atproto.NewClient(srv.URL, "did:plc:test123", "token")
			ctx := mockRegistryContext(client, "myapp", "", "did:plc:test123", "test.handle", nil)
			store := NewManifestStore(ctx, nil)
			if _, err := store.Get(context.Background(), "sha256:abc123"); err != nil {
				t.Fatalf("Get() error = %v", err)
			}
			if got := store.GetLastFetchedHoldDID(); got != tc.wantHold {
				t.Errorf("GetLastFetchedHoldDID() = %v, want %v", got, tc.wantHold)
			}
		})
	}
}
// TestManifestStore_Get_OnlyCountsGETRequests verifies that HEAD requests don't increment pull count
// (and neither do other non-GET methods); only a genuine GET pull bumps the metric.
func TestManifestStore_Get_OnlyCountsGETRequests(t *testing.T) {
	rawOCI := []byte(`{"schemaVersion":2}`)
	cases := []struct {
		name        string
		method      string
		wantCounted bool
	}{
		{name: "GET request increments pull count", method: "GET", wantCounted: true},
		{name: "HEAD request does not increment pull count", method: "HEAD", wantCounted: false},
		{name: "POST request does not increment pull count", method: "POST", wantCounted: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Fake PDS: blob fetches return the OCI bytes, everything else a
			// fixed manifest record.
			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if r.URL.Path == atproto.SyncGetBlob {
					w.Write(rawOCI)
					return
				}
				w.Write([]byte(`{
				"uri": "at://did:plc:test123/io.atcr.manifest/abc123",
				"value": {
					"$type":"io.atcr.manifest",
					"holdDid":"did:web:hold01.atcr.io",
					"mediaType":"application/vnd.oci.image.manifest.v1+json",
					"manifestBlob":{"$type":"blob","ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},"mimeType":"application/json","size":100}
				}
			}`))
			}))
			defer srv.Close()
			client := atproto.NewClient(srv.URL, "did:plc:test123", "token")
			metrics := &mockDatabaseMetrics{}
			ctx := mockRegistryContext(client, "myapp", "did:web:hold01.atcr.io", "did:plc:test123", "test.handle", metrics)
			store := NewManifestStore(ctx, nil)
			// Stash the HTTP method under the same string key the
			// distribution library uses.
			reqCtx := context.WithValue(context.Background(), "http.request.method", tc.method)
			if _, err := store.Get(reqCtx, "sha256:abc123"); err != nil {
				t.Fatalf("Get() error = %v", err)
			}
			// Metrics are incremented asynchronously; give the goroutine a beat.
			time.Sleep(50 * time.Millisecond)
			switch {
			case tc.wantCounted && metrics.getPullCount() == 0:
				t.Error("Expected pull count to be incremented for GET request, but it wasn't")
			case !tc.wantCounted && metrics.getPullCount() > 0:
				t.Errorf("Expected pull count NOT to be incremented for %s request, but it was (count=%d)", tc.method, metrics.getPullCount())
			}
		})
	}
}
// TestManifestStore_Put tests storing manifests
func TestManifestStore_Put(t *testing.T) {
ociManifest := []byte(`{
@@ -760,7 +573,7 @@ func TestManifestStore_Put(t *testing.T) {
// Handle uploadBlob
if r.URL.Path == atproto.RepoUploadBlob {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},"mimeType":"application/json","size":100}}`))
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafytest"},"mimeType":"application/json","size":100}}`))
return
}
@@ -769,7 +582,7 @@ func TestManifestStore_Put(t *testing.T) {
json.NewDecoder(r.Body).Decode(&lastBody)
w.WriteHeader(tt.serverStatus)
if tt.serverStatus == http.StatusOK {
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/abc123","cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"}`))
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/abc123","cid":"bafytest"}`))
} else {
w.Write([]byte(`{"error":"ServerError"}`))
}
@@ -780,10 +593,14 @@ func TestManifestStore_Put(t *testing.T) {
}))
defer server.Close()
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
db := &mockDatabaseMetrics{}
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", db)
store := NewManifestStore(ctx, nil)
userCtx := mockUserContextForManifest(
server.URL,
"myapp",
"did:web:hold.example.com",
"did:plc:test123",
"test.handle",
)
store := NewManifestStore(userCtx, nil, nil)
dgst, err := store.Put(context.Background(), tt.manifest, tt.options...)
if (err != nil) != tt.wantErr {
@@ -821,19 +638,24 @@ func TestManifestStore_Put_WithConfigLabels(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == atproto.RepoUploadBlob {
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},"size":100}}`))
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafytest"},"size":100}}`))
return
}
if r.URL.Path == atproto.RepoPutRecord {
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/config123","cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"}`))
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/config123","cid":"bafytest"}`))
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", nil)
userCtx := mockUserContextForManifest(
server.URL,
"myapp",
"did:web:hold.example.com",
"did:plc:test123",
"test.handle",
)
// Use config digest in manifest
ociManifestWithConfig := []byte(`{
@@ -848,7 +670,7 @@ func TestManifestStore_Put_WithConfigLabels(t *testing.T) {
payload: ociManifestWithConfig,
}
store := NewManifestStore(ctx, blobStore)
store := NewManifestStore(userCtx, blobStore, nil)
_, err := store.Put(context.Background(), manifest)
if err != nil {
@@ -876,7 +698,7 @@ func TestManifestStore_Delete(t *testing.T) {
name: "successful delete",
digest: "sha256:abc123",
serverStatus: http.StatusOK,
serverResp: `{"commit":{"cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku","rev":"12345"}}`,
serverResp: `{"commit":{"cid":"bafytest","rev":"12345"}}`,
wantErr: false,
},
{
@@ -908,9 +730,14 @@ func TestManifestStore_Delete(t *testing.T) {
}))
defer server.Close()
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", nil)
store := NewManifestStore(ctx, nil)
userCtx := mockUserContextForManifest(
server.URL,
"myapp",
"did:web:hold.example.com",
"did:plc:test123",
"test.handle",
)
store := NewManifestStore(userCtx, nil, nil)
err := store.Delete(context.Background(), tt.digest)
if (err != nil) != tt.wantErr {
@@ -1033,7 +860,7 @@ func TestManifestStore_Put_ManifestListValidation(t *testing.T) {
// Handle uploadBlob
if r.URL.Path == atproto.RepoUploadBlob {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},"mimeType":"application/json","size":100}}`))
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafytest"},"mimeType":"application/json","size":100}}`))
return
}
@@ -1045,7 +872,7 @@ func TestManifestStore_Put_ManifestListValidation(t *testing.T) {
// If child should exist, return it; otherwise return RecordNotFound
if tt.childExists || rkey == childDigest.Encoded() {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/` + rkey + `","cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku","value":{}}`))
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/` + rkey + `","cid":"bafytest","value":{}}`))
} else {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error":"RecordNotFound","message":"Record not found"}`))
@@ -1056,7 +883,7 @@ func TestManifestStore_Put_ManifestListValidation(t *testing.T) {
// Handle putRecord
if r.URL.Path == atproto.RepoPutRecord {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/test123","cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"}`))
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/test123","cid":"bafytest"}`))
return
}
@@ -1064,10 +891,14 @@ func TestManifestStore_Put_ManifestListValidation(t *testing.T) {
}))
defer server.Close()
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
db := &mockDatabaseMetrics{}
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", db)
store := NewManifestStore(ctx, nil)
userCtx := mockUserContextForManifest(
server.URL,
"myapp",
"did:web:hold.example.com",
"did:plc:test123",
"test.handle",
)
store := NewManifestStore(userCtx, nil, nil)
manifest := &rawManifest{
mediaType: "application/vnd.oci.image.index.v1+json",
@@ -1117,14 +948,14 @@ func TestManifestStore_Put_ManifestListValidation_MultipleChildren(t *testing.T)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == atproto.RepoUploadBlob {
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},"size":100}}`))
w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafytest"},"size":100}}`))
return
}
if r.URL.Path == atproto.RepoGetRecord {
rkey := r.URL.Query().Get("rkey")
if existingManifests[rkey] {
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/` + rkey + `","cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku","value":{}}`))
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/` + rkey + `","cid":"bafytest","value":{}}`))
} else {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error":"RecordNotFound"}`))
@@ -1133,7 +964,7 @@ func TestManifestStore_Put_ManifestListValidation_MultipleChildren(t *testing.T)
}
if r.URL.Path == atproto.RepoPutRecord {
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/test123","cid":"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"}`))
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/test123","cid":"bafytest"}`))
return
}
@@ -1141,9 +972,14 @@ func TestManifestStore_Put_ManifestListValidation_MultipleChildren(t *testing.T)
}))
defer server.Close()
client := atproto.NewClient(server.URL, "did:plc:test123", "token")
ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", nil)
store := NewManifestStore(ctx, nil)
userCtx := mockUserContextForManifest(
server.URL,
"myapp",
"did:web:hold.example.com",
"did:plc:test123",
"test.handle",
)
store := NewManifestStore(userCtx, nil, nil)
// Create manifest list with both children
manifestList := []byte(`{

View File

@@ -54,7 +54,7 @@ func EnsureProfile(ctx context.Context, client *atproto.Client, defaultHoldDID s
// GetProfile retrieves the user's profile from their PDS
// Returns nil if profile doesn't exist
// Automatically migrates old URL-based defaultHold values to DIDs
func GetProfile(ctx context.Context, client *atproto.Client) (*atproto.SailorProfile, error) {
func GetProfile(ctx context.Context, client *atproto.Client) (*atproto.SailorProfileRecord, error) {
record, err := client.GetRecord(ctx, atproto.SailorProfileCollection, ProfileRKey)
if err != nil {
// Check if it's a 404 (profile doesn't exist)
@@ -65,17 +65,17 @@ func GetProfile(ctx context.Context, client *atproto.Client) (*atproto.SailorPro
}
// Parse the profile record
var profile atproto.SailorProfile
var profile atproto.SailorProfileRecord
if err := json.Unmarshal(record.Value, &profile); err != nil {
return nil, fmt.Errorf("failed to parse profile: %w", err)
}
// Migrate old URL-based defaultHold to DID format
// This ensures backward compatibility with profiles created before DID migration
if profile.DefaultHold != nil && *profile.DefaultHold != "" && !atproto.IsDID(*profile.DefaultHold) {
if profile.DefaultHold != "" && !atproto.IsDID(profile.DefaultHold) {
// Convert URL to DID transparently
migratedDID := atproto.ResolveHoldDIDFromURL(*profile.DefaultHold)
profile.DefaultHold = &migratedDID
migratedDID := atproto.ResolveHoldDIDFromURL(profile.DefaultHold)
profile.DefaultHold = migratedDID
// Persist the migration to PDS in a background goroutine
// Use a lock to ensure only one goroutine migrates this DID
@@ -94,8 +94,7 @@ func GetProfile(ctx context.Context, client *atproto.Client) (*atproto.SailorPro
defer cancel()
// Update the profile on the PDS
now := time.Now().Format(time.RFC3339)
profile.UpdatedAt = &now
profile.UpdatedAt = time.Now()
if err := UpdateProfile(ctx, client, &profile); err != nil {
slog.Warn("Failed to persist URL-to-DID migration", "component", "profile", "did", did, "error", err)
} else {
@@ -110,13 +109,12 @@ func GetProfile(ctx context.Context, client *atproto.Client) (*atproto.SailorPro
// UpdateProfile updates the user's profile
// Normalizes defaultHold to DID format before saving
func UpdateProfile(ctx context.Context, client *atproto.Client, profile *atproto.SailorProfile) error {
func UpdateProfile(ctx context.Context, client *atproto.Client, profile *atproto.SailorProfileRecord) error {
// Normalize defaultHold to DID if it's a URL
// This ensures we always store DIDs, even if user provides a URL
if profile.DefaultHold != nil && *profile.DefaultHold != "" && !atproto.IsDID(*profile.DefaultHold) {
normalized := atproto.ResolveHoldDIDFromURL(*profile.DefaultHold)
profile.DefaultHold = &normalized
slog.Debug("Normalized defaultHold to DID", "component", "profile", "default_hold", normalized)
if profile.DefaultHold != "" && !atproto.IsDID(profile.DefaultHold) {
profile.DefaultHold = atproto.ResolveHoldDIDFromURL(profile.DefaultHold)
slog.Debug("Normalized defaultHold to DID", "component", "profile", "default_hold", profile.DefaultHold)
}
_, err := client.PutRecord(ctx, atproto.SailorProfileCollection, ProfileRKey, profile)

View File

@@ -39,7 +39,7 @@ func TestEnsureProfile_Create(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var createdProfile *atproto.SailorProfile
var createdProfile *atproto.SailorProfileRecord
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// First request: GetRecord (should 404)
@@ -95,16 +95,12 @@ func TestEnsureProfile_Create(t *testing.T) {
t.Fatal("Profile was not created")
}
if createdProfile.LexiconTypeID != atproto.SailorProfileCollection {
t.Errorf("LexiconTypeID = %v, want %v", createdProfile.LexiconTypeID, atproto.SailorProfileCollection)
if createdProfile.Type != atproto.SailorProfileCollection {
t.Errorf("Type = %v, want %v", createdProfile.Type, atproto.SailorProfileCollection)
}
gotDefaultHold := ""
if createdProfile.DefaultHold != nil {
gotDefaultHold = *createdProfile.DefaultHold
}
if gotDefaultHold != tt.wantNormalized {
t.Errorf("DefaultHold = %v, want %v", gotDefaultHold, tt.wantNormalized)
if createdProfile.DefaultHold != tt.wantNormalized {
t.Errorf("DefaultHold = %v, want %v", createdProfile.DefaultHold, tt.wantNormalized)
}
})
}
@@ -158,7 +154,7 @@ func TestGetProfile(t *testing.T) {
name string
serverResponse string
serverStatus int
wantProfile *atproto.SailorProfile
wantProfile *atproto.SailorProfileRecord
wantNil bool
wantErr bool
expectMigration bool // Whether URL-to-DID migration should happen
@@ -269,12 +265,8 @@ func TestGetProfile(t *testing.T) {
}
// Check that defaultHold is migrated to DID in returned profile
gotDefaultHold := ""
if profile.DefaultHold != nil {
gotDefaultHold = *profile.DefaultHold
}
if gotDefaultHold != tt.expectedHoldDID {
t.Errorf("DefaultHold = %v, want %v", gotDefaultHold, tt.expectedHoldDID)
if profile.DefaultHold != tt.expectedHoldDID {
t.Errorf("DefaultHold = %v, want %v", profile.DefaultHold, tt.expectedHoldDID)
}
if tt.expectMigration {
@@ -374,43 +366,44 @@ func TestGetProfile_MigrationLocking(t *testing.T) {
}
}
// testSailorProfile creates a test profile with the given default hold.
// An empty defaultHold leaves the DefaultHold pointer nil; CreatedAt and
// UpdatedAt are both stamped with the current RFC3339 time.
func testSailorProfile(defaultHold string) *atproto.SailorProfile {
	stamp := time.Now().Format(time.RFC3339)
	p := atproto.SailorProfile{LexiconTypeID: atproto.SailorProfileCollection}
	p.CreatedAt = stamp
	p.UpdatedAt = &stamp
	if defaultHold != "" {
		p.DefaultHold = &defaultHold
	}
	return &p
}
// TestUpdateProfile tests updating a user's profile
func TestUpdateProfile(t *testing.T) {
tests := []struct {
name string
profile *atproto.SailorProfile
profile *atproto.SailorProfileRecord
wantNormalized string // Expected defaultHold after normalization
wantErr bool
}{
{
name: "update with DID",
profile: testSailorProfile("did:web:hold02.atcr.io"),
name: "update with DID",
profile: &atproto.SailorProfileRecord{
Type: atproto.SailorProfileCollection,
DefaultHold: "did:web:hold02.atcr.io",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
wantNormalized: "did:web:hold02.atcr.io",
wantErr: false,
},
{
name: "update with URL - should normalize",
profile: testSailorProfile("https://hold02.atcr.io"),
name: "update with URL - should normalize",
profile: &atproto.SailorProfileRecord{
Type: atproto.SailorProfileCollection,
DefaultHold: "https://hold02.atcr.io",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
wantNormalized: "did:web:hold02.atcr.io",
wantErr: false,
},
{
name: "clear default hold",
profile: testSailorProfile(""),
name: "clear default hold",
profile: &atproto.SailorProfileRecord{
Type: atproto.SailorProfileCollection,
DefaultHold: "",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
wantNormalized: "",
wantErr: false,
},
@@ -461,12 +454,8 @@ func TestUpdateProfile(t *testing.T) {
}
// Verify normalization also updated the profile object
gotProfileHold := ""
if tt.profile.DefaultHold != nil {
gotProfileHold = *tt.profile.DefaultHold
}
if gotProfileHold != tt.wantNormalized {
t.Errorf("profile.DefaultHold = %v, want %v (should be updated in-place)", gotProfileHold, tt.wantNormalized)
if tt.profile.DefaultHold != tt.wantNormalized {
t.Errorf("profile.DefaultHold = %v, want %v (should be updated in-place)", tt.profile.DefaultHold, tt.wantNormalized)
}
}
})
@@ -550,8 +539,8 @@ func TestGetProfile_EmptyDefaultHold(t *testing.T) {
t.Fatalf("GetProfile() error = %v", err)
}
if profile.DefaultHold != nil && *profile.DefaultHold != "" {
t.Errorf("DefaultHold = %v, want empty or nil", profile.DefaultHold)
if profile.DefaultHold != "" {
t.Errorf("DefaultHold = %v, want empty string", profile.DefaultHold)
}
}
@@ -564,7 +553,12 @@ func TestUpdateProfile_ServerError(t *testing.T) {
defer server.Close()
client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
profile := testSailorProfile("did:web:hold01.atcr.io")
profile := &atproto.SailorProfileRecord{
Type: atproto.SailorProfileCollection,
DefaultHold: "did:web:hold01.atcr.io",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
err := UpdateProfile(context.Background(), client, profile)

View File

@@ -12,6 +12,7 @@ import (
"time"
"atcr.io/pkg/atproto"
"atcr.io/pkg/auth"
"github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/registry/api/errcode"
"github.com/opencontainers/go-digest"
@@ -32,20 +33,20 @@ var (
// ProxyBlobStore proxies blob requests to an external storage service
type ProxyBlobStore struct {
ctx *RegistryContext // All context and services
holdURL string // Resolved HTTP URL for XRPC requests
ctx *auth.UserContext // User context with identity, target, permissions
holdURL string // Resolved HTTP URL for XRPC requests
httpClient *http.Client
}
// NewProxyBlobStore creates a new proxy blob store
func NewProxyBlobStore(ctx *RegistryContext) *ProxyBlobStore {
func NewProxyBlobStore(userCtx *auth.UserContext) *ProxyBlobStore {
// Resolve DID to URL once at construction time
holdURL := atproto.ResolveHoldURL(ctx.HoldDID)
holdURL := atproto.ResolveHoldURL(userCtx.TargetHoldDID)
slog.Debug("NewProxyBlobStore created", "component", "proxy_blob_store", "hold_did", ctx.HoldDID, "hold_url", holdURL, "user_did", ctx.DID, "repo", ctx.Repository)
slog.Debug("NewProxyBlobStore created", "component", "proxy_blob_store", "hold_did", userCtx.TargetHoldDID, "hold_url", holdURL, "user_did", userCtx.TargetOwnerDID, "repo", userCtx.TargetRepo)
return &ProxyBlobStore{
ctx: ctx,
ctx: userCtx,
holdURL: holdURL,
httpClient: &http.Client{
Timeout: 5 * time.Minute, // Timeout for presigned URL requests and uploads
@@ -61,32 +62,33 @@ func NewProxyBlobStore(ctx *RegistryContext) *ProxyBlobStore {
}
// doAuthenticatedRequest performs an HTTP request with service token authentication
// Uses the service token from middleware to authenticate requests to the hold service
// Uses the service token from UserContext to authenticate requests to the hold service
func (p *ProxyBlobStore) doAuthenticatedRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
// Use service token that middleware already validated and cached
// Middleware fails fast with HTTP 401 if OAuth session is invalid
if p.ctx.ServiceToken == "" {
// Get service token from UserContext (lazy-loaded and cached per holdDID)
serviceToken, err := p.ctx.GetServiceToken(ctx)
if err != nil {
slog.Error("Failed to get service token", "component", "proxy_blob_store", "did", p.ctx.DID, "error", err)
return nil, fmt.Errorf("failed to get service token: %w", err)
}
if serviceToken == "" {
// Should never happen - middleware validates OAuth before handlers run
slog.Error("No service token in context", "component", "proxy_blob_store", "did", p.ctx.DID)
return nil, fmt.Errorf("no service token available (middleware should have validated)")
}
// Add Bearer token to Authorization header
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", p.ctx.ServiceToken))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", serviceToken))
return p.httpClient.Do(req)
}
// checkReadAccess validates that the user has read access to blobs in this hold
func (p *ProxyBlobStore) checkReadAccess(ctx context.Context) error {
if p.ctx.Authorizer == nil {
return nil // No authorization check if authorizer not configured
}
allowed, err := p.ctx.Authorizer.CheckReadAccess(ctx, p.ctx.HoldDID, p.ctx.DID)
canRead, err := p.ctx.CanRead(ctx)
if err != nil {
return fmt.Errorf("authorization check failed: %w", err)
}
if !allowed {
if !canRead {
// Return 403 Forbidden instead of masquerading as missing blob
return errcode.ErrorCodeDenied.WithMessage("read access denied")
}
@@ -95,21 +97,17 @@ func (p *ProxyBlobStore) checkReadAccess(ctx context.Context) error {
// checkWriteAccess validates that the user has write access to blobs in this hold
func (p *ProxyBlobStore) checkWriteAccess(ctx context.Context) error {
if p.ctx.Authorizer == nil {
return nil // No authorization check if authorizer not configured
}
slog.Debug("Checking write access", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.HoldDID)
allowed, err := p.ctx.Authorizer.CheckWriteAccess(ctx, p.ctx.HoldDID, p.ctx.DID)
slog.Debug("Checking write access", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.TargetHoldDID)
canWrite, err := p.ctx.CanWrite(ctx)
if err != nil {
slog.Error("Authorization check error", "component", "proxy_blob_store", "error", err)
return fmt.Errorf("authorization check failed: %w", err)
}
if !allowed {
slog.Warn("Write access denied", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.HoldDID)
return errcode.ErrorCodeDenied.WithMessage(fmt.Sprintf("write access denied to hold %s", p.ctx.HoldDID))
if !canWrite {
slog.Warn("Write access denied", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.TargetHoldDID)
return errcode.ErrorCodeDenied.WithMessage(fmt.Sprintf("write access denied to hold %s", p.ctx.TargetHoldDID))
}
slog.Debug("Write access allowed", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.HoldDID)
slog.Debug("Write access allowed", "component", "proxy_blob_store", "user_did", p.ctx.DID, "hold_did", p.ctx.TargetHoldDID)
return nil
}
@@ -356,10 +354,10 @@ func (p *ProxyBlobStore) Resume(ctx context.Context, id string) (distribution.Bl
// getPresignedURL returns the XRPC endpoint URL for blob operations
func (p *ProxyBlobStore) getPresignedURL(ctx context.Context, operation string, dgst digest.Digest) (string, error) {
// Use XRPC endpoint: /xrpc/com.atproto.sync.getBlob?did={userDID}&cid={digest}
// The 'did' parameter is the USER's DID (whose blob we're fetching), not the hold service DID
// The 'did' parameter is the TARGET OWNER's DID (whose blob we're fetching), not the hold service DID
// Per migration doc: hold accepts OCI digest directly as cid parameter (checks for sha256: prefix)
xrpcURL := fmt.Sprintf("%s%s?did=%s&cid=%s&method=%s",
p.holdURL, atproto.SyncGetBlob, p.ctx.DID, dgst.String(), operation)
p.holdURL, atproto.SyncGetBlob, p.ctx.TargetOwnerDID, dgst.String(), operation)
req, err := http.NewRequestWithContext(ctx, "GET", xrpcURL, nil)
if err != nil {

View File

@@ -1,46 +1,41 @@
package storage
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"atcr.io/pkg/atproto"
"atcr.io/pkg/auth/token"
"github.com/opencontainers/go-digest"
"atcr.io/pkg/auth"
)
// TestGetServiceToken_CachingLogic tests the token caching mechanism
// TestGetServiceToken_CachingLogic tests the global service token caching mechanism
// These tests use the global auth cache functions directly
func TestGetServiceToken_CachingLogic(t *testing.T) {
userDID := "did:plc:test"
userDID := "did:plc:cache-test"
holdDID := "did:web:hold.example.com"
// Test 1: Empty cache - invalidate any existing token
token.InvalidateServiceToken(userDID, holdDID)
cachedToken, _ := token.GetServiceToken(userDID, holdDID)
auth.InvalidateServiceToken(userDID, holdDID)
cachedToken, _ := auth.GetServiceToken(userDID, holdDID)
if cachedToken != "" {
t.Error("Expected empty cache at start")
}
// Test 2: Insert token into cache
// Create a JWT-like token with exp claim for testing
// Format: header.payload.signature where payload has exp claim
testPayload := fmt.Sprintf(`{"exp":%d}`, time.Now().Add(50*time.Second).Unix())
testToken := "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(testPayload) + ".signature"
err := token.SetServiceToken(userDID, holdDID, testToken)
err := auth.SetServiceToken(userDID, holdDID, testToken)
if err != nil {
t.Fatalf("Failed to set service token: %v", err)
}
// Test 3: Retrieve from cache
cachedToken, expiresAt := token.GetServiceToken(userDID, holdDID)
cachedToken, expiresAt := auth.GetServiceToken(userDID, holdDID)
if cachedToken == "" {
t.Fatal("Expected token to be in cache")
}
@@ -56,10 +51,10 @@ func TestGetServiceToken_CachingLogic(t *testing.T) {
// Test 4: Expired token - GetServiceToken automatically removes it
expiredPayload := fmt.Sprintf(`{"exp":%d}`, time.Now().Add(-1*time.Hour).Unix())
expiredToken := "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(expiredPayload) + ".signature"
token.SetServiceToken(userDID, holdDID, expiredToken)
auth.SetServiceToken(userDID, holdDID, expiredToken)
// GetServiceToken should return empty string for expired token
cachedToken, _ = token.GetServiceToken(userDID, holdDID)
cachedToken, _ = auth.GetServiceToken(userDID, holdDID)
if cachedToken != "" {
t.Error("Expected expired token to be removed from cache")
}
// base64URLEncode encodes data as unpadded base64url, the encoding used for
// JWT segments (RFC 7515): standard URL-safe base64 with trailing '=' removed.
func base64URLEncode(data string) string {
	return strings.TrimRight(base64.URLEncoding.EncodeToString([]byte(data)), "=")
}
// TestServiceToken_EmptyInContext tests that operations fail when service token is missing
func TestServiceToken_EmptyInContext(t *testing.T) {
ctx := &RegistryContext{
DID: "did:plc:test",
HoldDID: "did:web:hold.example.com",
PDSEndpoint: "https://pds.example.com",
Repository: "test-repo",
ServiceToken: "", // No service token (middleware didn't set it)
Refresher: nil,
}
// mockUserContextForProxy creates a mock auth.UserContext for proxy blob store testing.
// It sets up both the user identity and target info, and configures test helpers
// to bypass network calls.
func mockUserContextForProxy(did, holdDID, pdsEndpoint, repository string) *auth.UserContext {
userCtx := auth.NewUserContext(did, "oauth", "PUT", nil)
userCtx.SetTarget(did, "test.handle", pdsEndpoint, repository, holdDID)
store := NewProxyBlobStore(ctx)
// Bypass PDS resolution (avoids network calls)
userCtx.SetPDSForTest("test.handle", pdsEndpoint)
// Try a write operation that requires authentication
testDigest := digest.FromString("test-content")
_, err := store.Stat(context.Background(), testDigest)
// Set up mock authorizer that allows access
userCtx.SetAuthorizerForTest(auth.NewMockHoldAuthorizer())
// Should fail because no service token is available
if err == nil {
t.Error("Expected error when service token is empty")
}
// Set default hold DID for push resolution
userCtx.SetDefaultHoldDIDForTest(holdDID)
// Error should indicate authentication issue
if !strings.Contains(err.Error(), "UNAUTHORIZED") && !strings.Contains(err.Error(), "authentication") {
t.Logf("Got error (acceptable): %v", err)
}
return userCtx
}
// TestDoAuthenticatedRequest_BearerTokenInjection tests that Bearer tokens are added to requests
func TestDoAuthenticatedRequest_BearerTokenInjection(t *testing.T) {
// This test verifies the Bearer token injection logic
testToken := "test-bearer-token-xyz"
// Create a test server to verify the Authorization header
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
// Create ProxyBlobStore with service token in context (set by middleware)
ctx := &RegistryContext{
DID: "did:plc:bearer-test",
HoldDID: "did:web:hold.example.com",
PDSEndpoint: "https://pds.example.com",
Repository: "test-repo",
ServiceToken: testToken, // Service token from middleware
Refresher: nil,
}
store := NewProxyBlobStore(ctx)
// Create request
req, err := http.NewRequest(http.MethodGet, testServer.URL+"/test", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
// Do authenticated request
resp, err := store.doAuthenticatedRequest(context.Background(), req)
if err != nil {
t.Fatalf("doAuthenticatedRequest failed: %v", err)
}
defer resp.Body.Close()
// Verify Bearer token was added
expectedHeader := "Bearer " + testToken
if receivedAuthHeader != expectedHeader {
t.Errorf("Expected Authorization header %s, got %s", expectedHeader, receivedAuthHeader)
}
// mockUserContextForProxyWithToken creates a mock UserContext with a pre-populated service token.
func mockUserContextForProxyWithToken(did, holdDID, pdsEndpoint, repository, serviceToken string) *auth.UserContext {
userCtx := mockUserContextForProxy(did, holdDID, pdsEndpoint, repository)
userCtx.SetServiceTokenForTest(holdDID, serviceToken)
return userCtx
}
// TestDoAuthenticatedRequest_ErrorWhenTokenUnavailable tests that authentication failures return proper errors
func TestDoAuthenticatedRequest_ErrorWhenTokenUnavailable(t *testing.T) {
	// Backend stub; it must never be reached, because authentication fails first.
	wasCalled := false
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		wasCalled = true
		w.WriteHeader(http.StatusOK)
	}))
	defer backend.Close()
	// Registry context deliberately lacks a service token (middleware didn't set it).
	regCtx := &RegistryContext{
		DID:          "did:plc:fallback",
		HoldDID:      "did:web:hold.example.com",
		PDSEndpoint:  "https://pds.example.com",
		Repository:   "test-repo",
		ServiceToken: "", // No service token
		Refresher:    nil,
	}
	store := NewProxyBlobStore(regCtx)
	req, err := http.NewRequest(http.MethodGet, backend.URL+"/test", nil)
	if err != nil {
		t.Fatalf("Failed to create request: %v", err)
	}
	// The authenticated request must fail before any network call is made.
	resp, err := store.doAuthenticatedRequest(context.Background(), req)
	if err == nil {
		t.Fatal("Expected doAuthenticatedRequest to fail when no service token is available")
	}
	if resp != nil {
		resp.Body.Close()
	}
	// The error should clearly indicate an authentication/authorization problem.
	if msg := err.Error(); !strings.Contains(msg, "service token") && !strings.Contains(msg, "UNAUTHORIZED") {
		t.Errorf("Expected service token or unauthorized error, got: %v", err)
	}
	if wasCalled {
		t.Error("Expected request to NOT be made when authentication fails")
	}
}
// TestResolveHoldURL tests DID to URL conversion
// TestResolveHoldURL tests DID to URL conversion (pure function)
func TestResolveHoldURL(t *testing.T) {
tests := []struct {
name string
@@ -200,7 +99,7 @@ func TestResolveHoldURL(t *testing.T) {
expected string
}{
{
name: "did:web with http (TEST_MODE)",
name: "did:web with http (localhost)",
holdDID: "did:web:localhost:8080",
expected: "http://localhost:8080",
},
@@ -228,16 +127,16 @@ func TestResolveHoldURL(t *testing.T) {
// TestServiceTokenCacheExpiry tests that expired cached tokens are not used
func TestServiceTokenCacheExpiry(t *testing.T) {
userDID := "did:plc:expiry"
userDID := "did:plc:expiry-test"
holdDID := "did:web:hold.example.com"
// Insert expired token
expiredPayload := fmt.Sprintf(`{"exp":%d}`, time.Now().Add(-1*time.Hour).Unix())
expiredToken := "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(expiredPayload) + ".signature"
token.SetServiceToken(userDID, holdDID, expiredToken)
auth.SetServiceToken(userDID, holdDID, expiredToken)
// GetServiceToken should automatically remove expired tokens
cachedToken, expiresAt := token.GetServiceToken(userDID, holdDID)
cachedToken, expiresAt := auth.GetServiceToken(userDID, holdDID)
// Should return empty string for expired token
if cachedToken != "" {
@@ -272,20 +171,20 @@ func TestServiceTokenCacheKeyFormat(t *testing.T) {
// TestNewProxyBlobStore tests ProxyBlobStore creation
func TestNewProxyBlobStore(t *testing.T) {
ctx := &RegistryContext{
DID: "did:plc:test",
HoldDID: "did:web:hold.example.com",
PDSEndpoint: "https://pds.example.com",
Repository: "test-repo",
}
userCtx := mockUserContextForProxy(
"did:plc:test",
"did:web:hold.example.com",
"https://pds.example.com",
"test-repo",
)
store := NewProxyBlobStore(ctx)
store := NewProxyBlobStore(userCtx)
if store == nil {
t.Fatal("Expected non-nil ProxyBlobStore")
}
if store.ctx != ctx {
if store.ctx != userCtx {
t.Error("Expected context to be set")
}
@@ -310,10 +209,10 @@ func BenchmarkServiceTokenCacheAccess(b *testing.B) {
testPayload := fmt.Sprintf(`{"exp":%d}`, time.Now().Add(50*time.Second).Unix())
testTokenStr := "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(testPayload) + ".signature"
token.SetServiceToken(userDID, holdDID, testTokenStr)
auth.SetServiceToken(userDID, holdDID, testTokenStr)
for b.Loop() {
cachedToken, expiresAt := token.GetServiceToken(userDID, holdDID)
cachedToken, expiresAt := auth.GetServiceToken(userDID, holdDID)
if cachedToken == "" || time.Now().After(expiresAt) {
b.Error("Cache miss in benchmark")
@@ -321,296 +220,55 @@ func BenchmarkServiceTokenCacheAccess(b *testing.B) {
}
}
// TestCompleteMultipartUpload_JSONFormat verifies the JSON request format sent to hold service
// This test would have caught the "partNumber" vs "part_number" bug
func TestCompleteMultipartUpload_JSONFormat(t *testing.T) {
var capturedBody map[string]any
// TestParseJWTExpiry tests JWT expiry parsing
func TestParseJWTExpiry(t *testing.T) {
// Create a JWT with known expiry
futureTime := time.Now().Add(1 * time.Hour).Unix()
testPayload := fmt.Sprintf(`{"exp":%d}`, futureTime)
testToken := "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(testPayload) + ".signature"
// Mock hold service that captures the request body
holdServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.URL.Path, atproto.HoldCompleteUpload) {
t.Errorf("Wrong endpoint called: %s", r.URL.Path)
}
// Capture request body
var body map[string]any
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Errorf("Failed to decode request body: %v", err)
}
capturedBody = body
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{}`))
}))
defer holdServer.Close()
// Create store with mocked hold URL
ctx := &RegistryContext{
DID: "did:plc:test",
HoldDID: "did:web:hold.example.com",
PDSEndpoint: "https://pds.example.com",
Repository: "test-repo",
ServiceToken: "test-service-token", // Service token from middleware
}
store := NewProxyBlobStore(ctx)
store.holdURL = holdServer.URL
// Call completeMultipartUpload
parts := []CompletedPart{
{PartNumber: 1, ETag: "etag-1"},
{PartNumber: 2, ETag: "etag-2"},
}
err := store.completeMultipartUpload(context.Background(), "sha256:abc123", "upload-id-xyz", parts)
expiry, err := auth.ParseJWTExpiry(testToken)
if err != nil {
t.Fatalf("completeMultipartUpload failed: %v", err)
t.Fatalf("ParseJWTExpiry failed: %v", err)
}
// Verify JSON format
if capturedBody == nil {
t.Fatal("No request body was captured")
}
// Check top-level fields
if uploadID, ok := capturedBody["uploadId"].(string); !ok || uploadID != "upload-id-xyz" {
t.Errorf("Expected uploadId='upload-id-xyz', got %v", capturedBody["uploadId"])
}
if digest, ok := capturedBody["digest"].(string); !ok || digest != "sha256:abc123" {
t.Errorf("Expected digest='sha256:abc123', got %v", capturedBody["digest"])
}
// Check parts array
partsArray, ok := capturedBody["parts"].([]any)
if !ok {
t.Fatalf("Expected parts to be array, got %T", capturedBody["parts"])
}
if len(partsArray) != 2 {
t.Fatalf("Expected 2 parts, got %d", len(partsArray))
}
// Verify first part has "part_number" (not "partNumber")
part0, ok := partsArray[0].(map[string]any)
if !ok {
t.Fatalf("Expected part to be object, got %T", partsArray[0])
}
// THIS IS THE KEY CHECK - would have caught the bug
if _, hasPartNumber := part0["partNumber"]; hasPartNumber {
t.Error("Found 'partNumber' (camelCase) - should be 'part_number' (snake_case)")
}
if partNum, ok := part0["part_number"].(float64); !ok || int(partNum) != 1 {
t.Errorf("Expected part_number=1, got %v", part0["part_number"])
}
if etag, ok := part0["etag"].(string); !ok || etag != "etag-1" {
t.Errorf("Expected etag='etag-1', got %v", part0["etag"])
// Verify expiry is close to what we set (within 1 second tolerance)
expectedExpiry := time.Unix(futureTime, 0)
diff := expiry.Sub(expectedExpiry)
if diff < -time.Second || diff > time.Second {
t.Errorf("Expiry mismatch: expected %v, got %v", expectedExpiry, expiry)
}
}
// TestGet_UsesPresignedURLDirectly verifies that Get() doesn't add auth headers to presigned URLs
// This test would have caught the presigned URL authentication bug
func TestGet_UsesPresignedURLDirectly(t *testing.T) {
	payload := []byte("test blob content")
	var authSeenByS3 string
	// Fake S3 endpoint: presigned requests must arrive WITHOUT an Authorization header.
	s3 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		authSeenByS3 = r.Header.Get("Authorization")
		if authSeenByS3 != "" {
			t.Errorf("S3 received Authorization header: %s (should be empty for presigned URLs)", authSeenByS3)
			w.WriteHeader(http.StatusForbidden)
			w.Write([]byte(`<?xml version="1.0"?><Error><Code>SignatureDoesNotMatch</Code></Error>`))
			return
		}
		w.WriteHeader(http.StatusOK)
		w.Write(payload)
	}))
	defer s3.Close()
	// Fake hold service: hands back a presigned URL pointing at the fake S3.
	hold := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		json.NewEncoder(w).Encode(map[string]string{
			"url": s3.URL + "/blob?X-Amz-Signature=fake-signature",
		})
	}))
	defer hold.Close()
	// Store carries a service token, which must NOT leak onto the presigned request.
	regCtx := &RegistryContext{
		DID:          "did:plc:test",
		HoldDID:      "did:web:hold.example.com",
		PDSEndpoint:  "https://pds.example.com",
		Repository:   "test-repo",
		ServiceToken: "test-service-token", // Service token from middleware
	}
	store := NewProxyBlobStore(regCtx)
	store.holdURL = hold.URL
	dgst := digest.FromBytes(payload)
	retrieved, err := store.Get(context.Background(), dgst)
	if err != nil {
		t.Fatalf("Get() failed: %v", err)
	}
	if string(retrieved) != string(payload) {
		t.Errorf("Expected data=%s, got %s", string(payload), string(retrieved))
	}
	if authSeenByS3 != "" {
		t.Errorf("S3 should not receive Authorization header for presigned URLs, got: %s", authSeenByS3)
	}
}
// TestOpen_UsesPresignedURLDirectly verifies that Open() doesn't add auth headers to presigned URLs
// This test would have caught the presigned URL authentication bug
func TestOpen_UsesPresignedURLDirectly(t *testing.T) {
	payload := []byte("test blob stream content")
	var authSeenByS3 string
	// Fake S3 endpoint: presigned requests must arrive WITHOUT an Authorization header.
	s3 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		authSeenByS3 = r.Header.Get("Authorization")
		if authSeenByS3 != "" {
			t.Errorf("S3 received Authorization header: %s (should be empty)", authSeenByS3)
			w.WriteHeader(http.StatusForbidden)
			return
		}
		w.WriteHeader(http.StatusOK)
		w.Write(payload)
	}))
	defer s3.Close()
	// Fake hold service: returns a presigned URL pointing at the fake S3.
	hold := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		json.NewEncoder(w).Encode(map[string]string{
			"url": s3.URL + "/blob?X-Amz-Signature=fake",
		})
	}))
	defer hold.Close()
	// Store carries a service token, which must NOT leak onto the presigned request.
	regCtx := &RegistryContext{
		DID:          "did:plc:test",
		HoldDID:      "did:web:hold.example.com",
		PDSEndpoint:  "https://pds.example.com",
		Repository:   "test-repo",
		ServiceToken: "test-service-token", // Service token from middleware
	}
	store := NewProxyBlobStore(regCtx)
	store.holdURL = hold.URL
	dgst := digest.FromBytes(payload)
	reader, err := store.Open(context.Background(), dgst)
	if err != nil {
		t.Fatalf("Open() failed: %v", err)
	}
	defer reader.Close()
	if authSeenByS3 != "" {
		t.Errorf("S3 should not receive Authorization header for presigned URLs, got: %s", authSeenByS3)
	}
}
// TestMultipartEndpoints_CorrectURLs verifies all multipart XRPC endpoints use correct URLs
// This would have caught the old com.atproto.repo.uploadBlob vs new io.atcr.hold.* endpoints
func TestMultipartEndpoints_CorrectURLs(t *testing.T) {
// TestParseJWTExpiry_InvalidToken tests error handling for invalid tokens
func TestParseJWTExpiry_InvalidToken(t *testing.T) {
tests := []struct {
name string
testFunc func(*ProxyBlobStore) error
expectedPath string
name string
token string
}{
{
name: "startMultipartUpload",
testFunc: func(store *ProxyBlobStore) error {
_, err := store.startMultipartUpload(context.Background(), "sha256:test")
return err
},
expectedPath: atproto.HoldInitiateUpload,
},
{
name: "getPartUploadInfo",
testFunc: func(store *ProxyBlobStore) error {
_, err := store.getPartUploadInfo(context.Background(), "sha256:test", "upload-123", 1)
return err
},
expectedPath: atproto.HoldGetPartUploadURL,
},
{
name: "completeMultipartUpload",
testFunc: func(store *ProxyBlobStore) error {
parts := []CompletedPart{{PartNumber: 1, ETag: "etag1"}}
return store.completeMultipartUpload(context.Background(), "sha256:test", "upload-123", parts)
},
expectedPath: atproto.HoldCompleteUpload,
},
{
name: "abortMultipartUpload",
testFunc: func(store *ProxyBlobStore) error {
return store.abortMultipartUpload(context.Background(), "sha256:test", "upload-123")
},
expectedPath: atproto.HoldAbortUpload,
},
{"empty token", ""},
{"single part", "header"},
{"two parts", "header.payload"},
{"invalid base64 payload", "header.!!!.signature"},
{"missing exp claim", "eyJhbGciOiJIUzI1NiJ9." + base64URLEncode(`{"sub":"test"}`) + ".sig"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedPath string
// Mock hold service that captures request path
holdServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedPath = r.URL.Path
// Return success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
resp := map[string]string{
"uploadId": "test-upload-id",
"url": "https://s3.example.com/presigned",
}
json.NewEncoder(w).Encode(resp)
}))
defer holdServer.Close()
// Create store with service token in context
ctx := &RegistryContext{
DID: "did:plc:test",
HoldDID: "did:web:hold.example.com",
PDSEndpoint: "https://pds.example.com",
Repository: "test-repo",
ServiceToken: "test-service-token", // Service token from middleware
}
store := NewProxyBlobStore(ctx)
store.holdURL = holdServer.URL
// Call the function
_ = tt.testFunc(store) // Ignore error, we just care about the URL
// Verify correct endpoint was called
if capturedPath != tt.expectedPath {
t.Errorf("Expected endpoint %s, got %s", tt.expectedPath, capturedPath)
}
// Verify it's NOT the old endpoint
if strings.Contains(capturedPath, "com.atproto.repo.uploadBlob") {
t.Error("Still using old com.atproto.repo.uploadBlob endpoint!")
_, err := auth.ParseJWTExpiry(tt.token)
if err == nil {
t.Error("Expected error for invalid token")
}
})
}
}
// Note: Tests for doAuthenticatedRequest, Get, Open, completeMultipartUpload, etc.
// require complex dependency mocking (OAuth refresher, PDS resolution, HoldAuthorizer).
// These should be tested at the integration level with proper infrastructure.
//
// The current unit tests cover:
// - Global service token cache (auth.GetServiceToken, auth.SetServiceToken, etc.)
// - URL resolution (atproto.ResolveHoldURL)
// - JWT parsing (auth.ParseJWTExpiry)
// - Store construction (NewProxyBlobStore)

View File

@@ -6,110 +6,75 @@ package storage
import (
"context"
"database/sql"
"log/slog"
"sync"
"atcr.io/pkg/auth"
"github.com/distribution/distribution/v3"
"github.com/distribution/reference"
)
// RoutingRepository routes manifests to ATProto and blobs to external hold service.
// The registry (AppView) is stateless and NEVER stores blobs locally.
// A new instance is created per HTTP request - no caching or synchronization needed.
type RoutingRepository struct {
	distribution.Repository
	// userCtx carries the authenticated user and the target owner/repo/hold.
	userCtx *auth.UserContext
	// sqlDB is used for historical hold-DID lookups on pull.
	sqlDB *sql.DB
}
// NewRoutingRepository creates a new routing repository
func NewRoutingRepository(baseRepo distribution.Repository, ctx *RegistryContext) *RoutingRepository {
func NewRoutingRepository(baseRepo distribution.Repository, userCtx *auth.UserContext, sqlDB *sql.DB) *RoutingRepository {
return &RoutingRepository{
Repository: baseRepo,
Ctx: ctx,
userCtx: userCtx,
sqlDB: sqlDB,
}
}
// Manifests returns the ATProto-backed manifest service.
// A fresh store is built per call - the repository is per-request, so no caching.
func (r *RoutingRepository) Manifests(ctx context.Context, options ...distribution.ManifestServiceOption) (distribution.ManifestService, error) {
	// blobStore is needed by the manifest store for label extraction during push.
	blobStore := r.Blobs(ctx)
	return NewManifestStore(r.userCtx, blobStore, r.sqlDB), nil
}
// Blobs returns a proxy blob store that routes to external hold service
// The registry (AppView) NEVER stores blobs locally - all blobs go through hold service
func (r *RoutingRepository) Blobs(ctx context.Context) distribution.BlobStore {
r.mu.Lock()
// Return cached blob store if available
if r.blobStore != nil {
blobStore := r.blobStore
r.mu.Unlock()
slog.Debug("Returning cached blob store", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository)
return blobStore
}
// Determine if this is a pull (GET) or push (PUT/POST/HEAD/etc) operation
// Pull operations use the historical hold DID from the database (blobs are where they were pushed)
// Push operations use the discovery-based hold DID from user's profile/default
// This allows users to change their default hold and have new pushes go there
isPull := false
if method, ok := ctx.Value("http.request.method").(string); ok {
isPull = method == "GET"
}
holdDID := r.Ctx.HoldDID // Default to discovery-based DID
holdSource := "discovery"
// Only query database for pull operations
if isPull && r.Ctx.Database != nil {
// Query database for the latest manifest's hold DID
if dbHoldDID, err := r.Ctx.Database.GetLatestHoldDIDForRepo(r.Ctx.DID, r.Ctx.Repository); err == nil && dbHoldDID != "" {
// Use hold DID from database (pull case - use historical reference)
holdDID = dbHoldDID
holdSource = "database"
slog.Debug("Using hold from database manifest (pull)", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", dbHoldDID)
} else if err != nil {
// Log error but don't fail - fall back to discovery-based DID
slog.Warn("Failed to query database for hold DID", "component", "storage/blobs", "error", err)
}
// If dbHoldDID is empty (no manifests yet), fall through to use discovery-based DID
// Resolve hold DID: pull uses DB lookup, push uses profile discovery
holdDID, err := r.userCtx.ResolveHoldDID(ctx, r.sqlDB)
if err != nil {
slog.Warn("Failed to resolve hold DID", "component", "storage/blobs", "error", err)
holdDID = r.userCtx.TargetHoldDID
}
if holdDID == "" {
// This should never happen if middleware is configured correctly
panic("hold DID not set in RegistryContext - ensure default_hold_did is configured in middleware")
panic("hold DID not set - ensure default_hold_did is configured in middleware")
}
slog.Debug("Using hold DID for blobs", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", holdDID, "source", holdSource)
slog.Debug("Using hold DID for blobs", "component", "storage/blobs", "did", r.userCtx.TargetOwnerDID, "repo", r.userCtx.TargetRepo, "hold", holdDID, "action", r.userCtx.Action.String())
// Update context with the correct hold DID (may be from database or discovered)
r.Ctx.HoldDID = holdDID
// Create and cache proxy blob store
r.blobStore = NewProxyBlobStore(r.Ctx)
blobStore := r.blobStore
r.mu.Unlock()
return blobStore
return NewProxyBlobStore(r.userCtx)
}
// Tags returns the tag service.
// Tags are stored in ATProto as io.atcr.tag records.
func (r *RoutingRepository) Tags(ctx context.Context) distribution.TagService {
	return NewTagStore(r.userCtx.GetATProtoClient(), r.userCtx.TargetRepo)
}
// Named returns a reference to the repository name.
// If the base repository is set, it delegates to the base.
// Otherwise, it constructs a name from the user context.
func (r *RoutingRepository) Named() reference.Named {
	if r.Repository != nil {
		return r.Repository.Named()
	}
	// Build the name from the target repository carried in the user context.
	if named, err := reference.WithName(r.userCtx.TargetRepo); err == nil {
		return named
	}
	// Fallback: return a simple reference
	named, _ := reference.WithName("unknown")
	return named
}

View File

@@ -2,273 +2,117 @@ package storage
import (
"context"
"sync"
"testing"
"github.com/distribution/distribution/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"atcr.io/pkg/atproto"
"atcr.io/pkg/auth"
)
// mockDatabase is a simple mock for testing
type mockDatabase struct {
holdDID string
err error
// mockUserContext creates a mock auth.UserContext for testing.
// It sets up both the user identity and target info, and configures
// test helpers to bypass network calls.
func mockUserContext(did, authMethod, httpMethod, targetOwnerDID, targetOwnerHandle, targetOwnerPDS, targetRepo, targetHoldDID string) *auth.UserContext {
userCtx := auth.NewUserContext(did, authMethod, httpMethod, nil)
userCtx.SetTarget(targetOwnerDID, targetOwnerHandle, targetOwnerPDS, targetRepo, targetHoldDID)
// Bypass PDS resolution (avoids network calls)
userCtx.SetPDSForTest(targetOwnerHandle, targetOwnerPDS)
// Set up mock authorizer that allows access
userCtx.SetAuthorizerForTest(auth.NewMockHoldAuthorizer())
// Set default hold DID for push resolution
userCtx.SetDefaultHoldDIDForTest(targetHoldDID)
return userCtx
}
func (m *mockDatabase) IncrementPullCount(did, repository string) error {
return nil
}
func (m *mockDatabase) IncrementPushCount(did, repository string) error {
return nil
}
func (m *mockDatabase) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
if m.err != nil {
return "", m.err
}
return m.holdDID, nil
// mockUserContextWithToken creates a mock UserContext with a pre-populated service token.
func mockUserContextWithToken(did, authMethod, httpMethod, targetOwnerDID, targetOwnerHandle, targetOwnerPDS, targetRepo, targetHoldDID, serviceToken string) *auth.UserContext {
userCtx := mockUserContext(did, authMethod, httpMethod, targetOwnerDID, targetOwnerHandle, targetOwnerPDS, targetRepo, targetHoldDID)
userCtx.SetServiceTokenForTest(targetHoldDID, serviceToken)
return userCtx
}
func TestNewRoutingRepository(t *testing.T) {
ctx := &RegistryContext{
DID: "did:plc:test123",
Repository: "debian",
HoldDID: "did:web:hold01.atcr.io",
ATProtoClient: &atproto.Client{},
userCtx := mockUserContext(
"did:plc:test123", // authenticated user
"oauth", // auth method
"GET", // HTTP method
"did:plc:test123", // target owner
"test.handle", // target owner handle
"https://pds.example.com", // target owner PDS
"debian", // repository
"did:web:hold01.atcr.io", // hold DID
)
repo := NewRoutingRepository(nil, userCtx, nil)
if repo.userCtx.TargetOwnerDID != "did:plc:test123" {
t.Errorf("Expected TargetOwnerDID %q, got %q", "did:plc:test123", repo.userCtx.TargetOwnerDID)
}
repo := NewRoutingRepository(nil, ctx)
if repo.Ctx.DID != "did:plc:test123" {
t.Errorf("Expected DID %q, got %q", "did:plc:test123", repo.Ctx.DID)
if repo.userCtx.TargetRepo != "debian" {
t.Errorf("Expected TargetRepo %q, got %q", "debian", repo.userCtx.TargetRepo)
}
if repo.Ctx.Repository != "debian" {
t.Errorf("Expected repository %q, got %q", "debian", repo.Ctx.Repository)
}
if repo.manifestStore != nil {
t.Error("Expected manifestStore to be nil initially")
}
if repo.blobStore != nil {
t.Error("Expected blobStore to be nil initially")
if repo.userCtx.TargetHoldDID != "did:web:hold01.atcr.io" {
t.Errorf("Expected TargetHoldDID %q, got %q", "did:web:hold01.atcr.io", repo.userCtx.TargetHoldDID)
}
}
// TestRoutingRepository_Manifests tests the Manifests() method
func TestRoutingRepository_Manifests(t *testing.T) {
ctx := &RegistryContext{
DID: "did:plc:test123",
Repository: "myapp",
HoldDID: "did:web:hold01.atcr.io",
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
}
userCtx := mockUserContext(
"did:plc:test123",
"oauth",
"GET",
"did:plc:test123",
"test.handle",
"https://pds.example.com",
"myapp",
"did:web:hold01.atcr.io",
)
repo := NewRoutingRepository(nil, ctx)
repo := NewRoutingRepository(nil, userCtx, nil)
manifestService, err := repo.Manifests(context.Background())
require.NoError(t, err)
assert.NotNil(t, manifestService)
// Verify the manifest store is cached
assert.NotNil(t, repo.manifestStore, "manifest store should be cached")
// Call again and verify we get the same instance
manifestService2, err := repo.Manifests(context.Background())
require.NoError(t, err)
assert.Same(t, manifestService, manifestService2, "should return cached manifest store")
}
// TestRoutingRepository_ManifestStoreCaching tests that manifest store is cached
func TestRoutingRepository_ManifestStoreCaching(t *testing.T) {
ctx := &RegistryContext{
DID: "did:plc:test123",
Repository: "myapp",
HoldDID: "did:web:hold01.atcr.io",
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
}
// TestRoutingRepository_Blobs tests the Blobs() method
func TestRoutingRepository_Blobs(t *testing.T) {
userCtx := mockUserContext(
"did:plc:test123",
"oauth",
"GET",
"did:plc:test123",
"test.handle",
"https://pds.example.com",
"myapp",
"did:web:hold01.atcr.io",
)
repo := NewRoutingRepository(nil, ctx)
// First call creates the store
store1, err := repo.Manifests(context.Background())
require.NoError(t, err)
assert.NotNil(t, store1)
// Second call returns cached store
store2, err := repo.Manifests(context.Background())
require.NoError(t, err)
assert.Same(t, store1, store2, "should return cached manifest store instance")
// Verify internal cache
assert.NotNil(t, repo.manifestStore)
}
// TestRoutingRepository_Blobs_PullUsesDatabase tests that GET (pull) uses database hold DID
func TestRoutingRepository_Blobs_PullUsesDatabase(t *testing.T) {
dbHoldDID := "did:web:database.hold.io"
discoveryHoldDID := "did:web:discovery.hold.io"
ctx := &RegistryContext{
DID: "did:plc:test123",
Repository: "myapp",
HoldDID: discoveryHoldDID, // Discovery-based hold (should be overridden for pull)
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
Database: &mockDatabase{holdDID: dbHoldDID},
}
repo := NewRoutingRepository(nil, ctx)
// Create context with GET method (pull operation)
pullCtx := context.WithValue(context.Background(), "http.request.method", "GET")
blobStore := repo.Blobs(pullCtx)
assert.NotNil(t, blobStore)
// Verify the hold DID was updated to use the database value for pull
assert.Equal(t, dbHoldDID, repo.Ctx.HoldDID, "pull (GET) should use database hold DID")
}
// TestRoutingRepository_Blobs_PushUsesDiscovery verifies that every mutating
// HTTP verb routes blobs to the discovery-based hold DID rather than the
// database-recorded one.
func TestRoutingRepository_Blobs_PushUsesDiscovery(t *testing.T) {
	const (
		dbHoldDID        = "did:web:database.hold.io"
		discoveryHoldDID = "did:web:discovery.hold.io"
	)
	// Subtest name and HTTP method are identical, so a plain method list suffices.
	for _, method := range []string{"PUT", "POST", "HEAD", "PATCH", "DELETE"} {
		t.Run(method, func(t *testing.T) {
			regCtx := &RegistryContext{
				DID:           "did:plc:test123",
				Repository:    "myapp-" + method, // unique repo name avoids cache collisions between subtests
				HoldDID:       discoveryHoldDID,
				ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
				Database:      &mockDatabase{holdDID: dbHoldDID},
			}
			repo := NewRoutingRepository(nil, regCtx)

			// Attach the push verb the same way the HTTP layer would.
			pushCtx := context.WithValue(context.Background(), "http.request.method", method)
			store := repo.Blobs(pushCtx)
			assert.NotNil(t, store)

			// Push operations must keep the discovery-based hold DID.
			assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "%s should use discovery hold DID, not database", method)
		})
	}
}
// TestRoutingRepository_Blobs_NoMethodUsesDiscovery tests that missing method defaults to discovery
func TestRoutingRepository_Blobs_NoMethodUsesDiscovery(t *testing.T) {
dbHoldDID := "did:web:database.hold.io"
discoveryHoldDID := "did:web:discovery.hold.io"
ctx := &RegistryContext{
DID: "did:plc:test123",
Repository: "myapp-nomethod",
HoldDID: discoveryHoldDID,
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
Database: &mockDatabase{holdDID: dbHoldDID},
}
repo := NewRoutingRepository(nil, ctx)
// Context without HTTP method (shouldn't happen in practice, but test defensive behavior)
repo := NewRoutingRepository(nil, userCtx, nil)
blobStore := repo.Blobs(context.Background())
assert.NotNil(t, blobStore)
// Without method, should default to discovery (safer for push scenarios)
assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "missing method should use discovery hold DID")
}
// TestRoutingRepository_Blobs_WithoutDatabase verifies that when no database
// is configured, the blob store keeps the discovery-based hold DID.
func TestRoutingRepository_Blobs_WithoutDatabase(t *testing.T) {
	const discoveryHoldDID = "did:web:discovery.hold.io"

	regCtx := &RegistryContext{
		DID:           "did:plc:nocache456",
		Repository:    "uncached-app",
		HoldDID:       discoveryHoldDID,
		ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:nocache456", ""),
		Database:      nil, // deliberately no database
	}
	repo := NewRoutingRepository(nil, regCtx)

	store := repo.Blobs(context.Background())
	assert.NotNil(t, store)

	// With nothing to consult, the discovery-based hold DID must survive.
	assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "should use discovery-based hold DID")
}
// TestRoutingRepository_Blobs_DatabaseEmptyFallback verifies the fallback to
// the discovery-based hold DID when the database reports no hold (empty
// string, e.g. a repository with no manifests yet).
func TestRoutingRepository_Blobs_DatabaseEmptyFallback(t *testing.T) {
	const discoveryHoldDID = "did:web:discovery.hold.io"

	regCtx := &RegistryContext{
		DID:           "did:plc:test123",
		Repository:    "newapp",
		HoldDID:       discoveryHoldDID,
		ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
		Database:      &mockDatabase{holdDID: ""}, // database knows of no hold yet
	}
	repo := NewRoutingRepository(nil, regCtx)

	store := repo.Blobs(context.Background())
	assert.NotNil(t, store)

	// An empty database answer must not clobber the discovery-based hold.
	assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "should fall back to discovery-based hold DID when database returns empty")
}
// TestRoutingRepository_BlobStoreCaching verifies Blobs() memoizes its store:
// repeated calls return the very same instance.
func TestRoutingRepository_BlobStoreCaching(t *testing.T) {
	regCtx := &RegistryContext{
		DID:           "did:plc:test123",
		Repository:    "myapp",
		HoldDID:       "did:web:hold01.atcr.io",
		ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
	}
	repo := NewRoutingRepository(nil, regCtx)

	// First call constructs the store.
	first := repo.Blobs(context.Background())
	assert.NotNil(t, first)

	// Second call must hand back the cached instance, not a new one.
	second := repo.Blobs(context.Background())
	assert.Same(t, first, second, "should return cached blob store instance")

	// The cache field itself should now be populated.
	assert.NotNil(t, repo.blobStore)
}
// TestRoutingRepository_Blobs_PanicOnEmptyHoldDID tests panic when hold DID is empty
func TestRoutingRepository_Blobs_PanicOnEmptyHoldDID(t *testing.T) {
// Use a unique DID/repo to ensure no cache entry exists
ctx := &RegistryContext{
DID: "did:plc:emptyholdtest999",
Repository: "empty-hold-app",
HoldDID: "", // Empty hold DID should panic
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:emptyholdtest999", ""),
}
// Create context without default hold and empty target hold
userCtx := auth.NewUserContext("did:plc:emptyholdtest999", "oauth", "GET", nil)
userCtx.SetTarget("did:plc:emptyholdtest999", "test.handle", "https://pds.example.com", "empty-hold-app", "")
userCtx.SetPDSForTest("test.handle", "https://pds.example.com")
userCtx.SetAuthorizerForTest(auth.NewMockHoldAuthorizer())
// Intentionally NOT setting default hold DID
repo := NewRoutingRepository(nil, ctx)
repo := NewRoutingRepository(nil, userCtx, nil)
// Should panic with empty hold DID
assert.Panics(t, func() {
@@ -278,106 +122,140 @@ func TestRoutingRepository_Blobs_PanicOnEmptyHoldDID(t *testing.T) {
// TestRoutingRepository_Tags tests the Tags() method
func TestRoutingRepository_Tags(t *testing.T) {
ctx := &RegistryContext{
DID: "did:plc:test123",
Repository: "myapp",
HoldDID: "did:web:hold01.atcr.io",
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
}
userCtx := mockUserContext(
"did:plc:test123",
"oauth",
"GET",
"did:plc:test123",
"test.handle",
"https://pds.example.com",
"myapp",
"did:web:hold01.atcr.io",
)
repo := NewRoutingRepository(nil, ctx)
repo := NewRoutingRepository(nil, userCtx, nil)
tagService := repo.Tags(context.Background())
assert.NotNil(t, tagService)
// Call again and verify we get a new instance (Tags() doesn't cache)
// Call again and verify we get a fresh instance (no caching)
tagService2 := repo.Tags(context.Background())
assert.NotNil(t, tagService2)
// Tags service is not cached, so each call creates a new instance
}
// TestRoutingRepository_ConcurrentAccess tests concurrent access to cached stores
func TestRoutingRepository_ConcurrentAccess(t *testing.T) {
ctx := &RegistryContext{
DID: "did:plc:test123",
Repository: "myapp",
HoldDID: "did:web:hold01.atcr.io",
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
// TestRoutingRepository_UserContext tests that UserContext fields are properly set
func TestRoutingRepository_UserContext(t *testing.T) {
testCases := []struct {
name string
httpMethod string
expectedAction auth.RequestAction
}{
{"GET request is pull", "GET", auth.ActionPull},
{"HEAD request is pull", "HEAD", auth.ActionPull},
{"PUT request is push", "PUT", auth.ActionPush},
{"POST request is push", "POST", auth.ActionPush},
{"DELETE request is push", "DELETE", auth.ActionPush},
}
repo := NewRoutingRepository(nil, ctx)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
userCtx := mockUserContext(
"did:plc:test123",
"oauth",
tc.httpMethod,
"did:plc:test123",
"test.handle",
"https://pds.example.com",
"myapp",
"did:web:hold01.atcr.io",
)
var wg sync.WaitGroup
numGoroutines := 10
repo := NewRoutingRepository(nil, userCtx, nil)
// Track all manifest stores returned
manifestStores := make([]distribution.ManifestService, numGoroutines)
blobStores := make([]distribution.BlobStore, numGoroutines)
// Concurrent access to Manifests()
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(index int) {
defer wg.Done()
store, err := repo.Manifests(context.Background())
require.NoError(t, err)
manifestStores[index] = store
}(i)
assert.Equal(t, tc.expectedAction, repo.userCtx.Action, "action should match HTTP method")
})
}
wg.Wait()
// Verify all stores are non-nil (due to race conditions, they may not all be the same instance)
for i := 0; i < numGoroutines; i++ {
assert.NotNil(t, manifestStores[i], "manifest store should not be nil")
}
// After concurrent creation, subsequent calls should return the cached instance
cachedStore, err := repo.Manifests(context.Background())
require.NoError(t, err)
assert.NotNil(t, cachedStore)
// Concurrent access to Blobs()
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(index int) {
defer wg.Done()
blobStores[index] = repo.Blobs(context.Background())
}(i)
}
wg.Wait()
// Verify all stores are non-nil (due to race conditions, they may not all be the same instance)
for i := 0; i < numGoroutines; i++ {
assert.NotNil(t, blobStores[i], "blob store should not be nil")
}
// After concurrent creation, subsequent calls should return the cached instance
cachedBlobStore := repo.Blobs(context.Background())
assert.NotNil(t, cachedBlobStore)
}
// TestRoutingRepository_Blobs_PullPriority tests that database hold DID takes priority for pull (GET)
func TestRoutingRepository_Blobs_PullPriority(t *testing.T) {
dbHoldDID := "did:web:database.hold.io"
discoveryHoldDID := "did:web:discovery.hold.io"
ctx := &RegistryContext{
DID: "did:plc:test123",
Repository: "myapp-priority",
HoldDID: discoveryHoldDID, // Discovery-based hold
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
Database: &mockDatabase{holdDID: dbHoldDID}, // Database has a different hold DID
// TestRoutingRepository_DifferentHoldDIDs tests routing with different hold DIDs
func TestRoutingRepository_DifferentHoldDIDs(t *testing.T) {
testCases := []struct {
name string
holdDID string
}{
{"did:web hold", "did:web:hold01.atcr.io"},
{"did:web with port", "did:web:localhost:8080"},
{"did:plc hold", "did:plc:xyz123"},
}
repo := NewRoutingRepository(nil, ctx)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
userCtx := mockUserContext(
"did:plc:test123",
"oauth",
"PUT",
"did:plc:test123",
"test.handle",
"https://pds.example.com",
"myapp",
tc.holdDID,
)
// For pull (GET), database should take priority
pullCtx := context.WithValue(context.Background(), "http.request.method", "GET")
blobStore := repo.Blobs(pullCtx)
repo := NewRoutingRepository(nil, userCtx, nil)
blobStore := repo.Blobs(context.Background())
assert.NotNil(t, blobStore)
// Database hold DID should take priority over discovery for pull operations
assert.Equal(t, dbHoldDID, repo.Ctx.HoldDID, "database hold DID should take priority over discovery for pull (GET)")
assert.NotNil(t, blobStore, "should create blob store for %s", tc.holdDID)
})
}
}
// TestRoutingRepository_Named verifies Named() synthesizes a repository
// reference from the user context when no base repository is supplied.
func TestRoutingRepository_Named(t *testing.T) {
	userCtx := mockUserContext(
		"did:plc:test123",         // authenticated user
		"oauth",                   // auth method
		"GET",                     // HTTP method
		"did:plc:test123",         // target owner DID
		"test.handle",             // target owner handle
		"https://pds.example.com", // target owner PDS
		"myapp",                   // repository
		"did:web:hold01.atcr.io",  // hold DID
	)
	repo := NewRoutingRepository(nil, userCtx, nil)

	// With a nil base repository, Named() must still construct a name from
	// the context instead of panicking or returning nil.
	named := repo.Named()
	assert.NotNil(t, named)
	assert.Contains(t, named.Name(), "myapp")
}
// TestATProtoResolveHoldURL verifies DID-to-URL resolution for hold DIDs:
// plain did:web domains resolve to HTTPS, while host:port forms (localhost)
// resolve to HTTP.
func TestATProtoResolveHoldURL(t *testing.T) {
	cases := []struct {
		name    string
		holdDID string
		want    string
	}{
		{name: "did:web simple domain", holdDID: "did:web:hold01.atcr.io", want: "https://hold01.atcr.io"},
		{name: "did:web with port (localhost)", holdDID: "did:web:localhost:8080", want: "http://localhost:8080"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := atproto.ResolveHoldURL(tc.holdDID)
			assert.Equal(t, tc.want, got)
		})
	}
}

View File

@@ -36,7 +36,7 @@ func (s *TagStore) Get(ctx context.Context, tag string) (distribution.Descriptor
return distribution.Descriptor{}, distribution.ErrTagUnknown{Tag: tag}
}
var tagRecord atproto.Tag
var tagRecord atproto.TagRecord
if err := json.Unmarshal(record.Value, &tagRecord); err != nil {
return distribution.Descriptor{}, fmt.Errorf("failed to unmarshal tag record: %w", err)
}
@@ -91,7 +91,7 @@ func (s *TagStore) All(ctx context.Context) ([]string, error) {
var tags []string
for _, record := range records {
var tagRecord atproto.Tag
var tagRecord atproto.TagRecord
if err := json.Unmarshal(record.Value, &tagRecord); err != nil {
// Skip invalid records
continue
@@ -116,7 +116,7 @@ func (s *TagStore) Lookup(ctx context.Context, desc distribution.Descriptor) ([]
var tags []string
for _, record := range records {
var tagRecord atproto.Tag
var tagRecord atproto.TagRecord
if err := json.Unmarshal(record.Value, &tagRecord); err != nil {
// Skip invalid records
continue

View File

@@ -229,7 +229,7 @@ func TestTagStore_Tag(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var sentTagRecord *atproto.Tag
var sentTagRecord *atproto.TagRecord
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
@@ -254,7 +254,7 @@ func TestTagStore_Tag(t *testing.T) {
// Parse and verify tag record
recordData := body["record"].(map[string]any)
recordBytes, _ := json.Marshal(recordData)
var tagRecord atproto.Tag
var tagRecord atproto.TagRecord
json.Unmarshal(recordBytes, &tagRecord)
sentTagRecord = &tagRecord
@@ -284,8 +284,8 @@ func TestTagStore_Tag(t *testing.T) {
if !tt.wantErr && sentTagRecord != nil {
// Verify the tag record
if sentTagRecord.LexiconTypeID != atproto.TagCollection {
t.Errorf("LexiconTypeID = %v, want %v", sentTagRecord.LexiconTypeID, atproto.TagCollection)
if sentTagRecord.Type != atproto.TagCollection {
t.Errorf("Type = %v, want %v", sentTagRecord.Type, atproto.TagCollection)
}
if sentTagRecord.Repository != "myapp" {
t.Errorf("Repository = %v, want myapp", sentTagRecord.Repository)
@@ -295,11 +295,11 @@ func TestTagStore_Tag(t *testing.T) {
}
// New records should have manifest field
expectedURI := atproto.BuildManifestURI("did:plc:test123", tt.digest.String())
if sentTagRecord.Manifest == nil || *sentTagRecord.Manifest != expectedURI {
if sentTagRecord.Manifest != expectedURI {
t.Errorf("Manifest = %v, want %v", sentTagRecord.Manifest, expectedURI)
}
// New records should NOT have manifestDigest field
if sentTagRecord.ManifestDigest != nil && *sentTagRecord.ManifestDigest != "" {
if sentTagRecord.ManifestDigest != "" {
t.Errorf("ManifestDigest should be empty for new records, got %v", sentTagRecord.ManifestDigest)
}
}

View File

@@ -0,0 +1,22 @@
{{ define "404" }}
<!DOCTYPE html>
<html lang="en">
<head>
<title>404 - Lost at Sea | ATCR</title>
{{ template "head" . }}
</head>
<body>
{{ template "nav-simple" . }}
<main class="error-page">
<div class="error-content">
<i data-lucide="anchor" class="error-icon"></i>
<div class="error-code">404</div>
<h1>Lost at Sea</h1>
<p>The page you're looking for has drifted into uncharted waters.</p>
<a href="/" class="btn btn-primary">Return to Port</a>
</div>
</main>
<script>lucide.createIcons();</script>
</body>
</html>
{{ end }}

View File

@@ -27,11 +27,20 @@
<!-- Repository Header -->
<div class="repository-header">
<div class="repo-hero">
{{ if .Repository.IconURL }}
<img src="{{ .Repository.IconURL }}" alt="{{ .Repository.Name }}" class="repo-hero-icon">
{{ else }}
<div class="repo-hero-icon-placeholder">{{ firstChar .Repository.Name }}</div>
{{ end }}
<div class="repo-hero-icon-wrapper">
{{ if .Repository.IconURL }}
<img src="{{ .Repository.IconURL }}" alt="{{ .Repository.Name }}" class="repo-hero-icon">
{{ else }}
<div class="repo-hero-icon-placeholder">{{ firstChar .Repository.Name }}</div>
{{ end }}
{{ if $.IsOwner }}
<label class="avatar-upload-overlay" for="avatar-upload">
<i data-lucide="plus"></i>
</label>
<input type="file" id="avatar-upload" accept="image/png,image/jpeg,image/webp"
onchange="uploadAvatar(this, '{{ .Repository.Name }}')" hidden>
{{ end }}
</div>
<div class="repo-hero-info">
<h1>
<a href="/u/{{ .Owner.Handle }}" class="owner-link">{{ .Owner.Handle }}</a>
@@ -130,6 +139,9 @@
{{ if .IsMultiArch }}
<span class="badge-multi">Multi-arch</span>
{{ end }}
{{ if .HasAttestations }}
<span class="badge-attestation"><i data-lucide="shield-check"></i> Attestations</span>
{{ end }}
</div>
<div style="display: flex; gap: 1rem; align-items: center;">
<time class="tag-timestamp" datetime="{{ .Tag.CreatedAt.Format "2006-01-02T15:04:05Z07:00" }}">

View File

@@ -44,15 +44,6 @@
</div>
{{ end }}
{{ if .HasMore }}
<button class="load-more"
hx-get="/api/recent-pushes?offset={{ .NextOffset }}"
hx-target="#push-list"
hx-swap="beforeend">
Load More
</button>
{{ end }}
{{ if eq (len .Pushes) 0 }}
<div class="empty-state">
<p>No pushes yet. Start using ATCR by pushing your first image!</p>

View File

@@ -1,65 +0,0 @@
package appview
import (
"testing"
"atcr.io/pkg/atproto"
)
// TestResolveHoldURL checks hold DID / URL normalization across did:web
// forms, passthrough http(s) URLs, and bare hostnames.
func TestResolveHoldURL(t *testing.T) {
	cases := []struct {
		name, in, want string
	}{
		{"DID with HTTPS domain", "did:web:hold.example.com", "https://hold.example.com"},
		{"DID with HTTP and port (IP)", "did:web:172.28.0.3:8080", "http://172.28.0.3:8080"},
		{"DID with HTTP and port (localhost)", "did:web:127.0.0.1:8080", "http://127.0.0.1:8080"},
		{"DID with localhost", "did:web:localhost:8080", "http://localhost:8080"},
		{"Already HTTPS URL (passthrough)", "https://hold.example.com", "https://hold.example.com"},
		{"Already HTTP URL (passthrough)", "http://172.28.0.3:8080", "http://172.28.0.3:8080"},
		{"Plain hostname (fallback to HTTPS)", "hold.example.com", "https://hold.example.com"},
		{"DID with subdomain", "did:web:hold01.atcr.io", "https://hold01.atcr.io"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := atproto.ResolveHoldURL(tc.in); got != tc.want {
				t.Errorf("ResolveHoldURL(%q) = %q, want %q", tc.in, got, tc.want)
			}
		})
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -13,8 +13,6 @@ import (
"github.com/bluesky-social/indigo/atproto/atclient"
indigo_oauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
lexutil "github.com/bluesky-social/indigo/lex/util"
"github.com/ipfs/go-cid"
)
// Sentinel errors
@@ -303,7 +301,7 @@ type Link struct {
}
// UploadBlob uploads binary data to the PDS and returns a blob reference
func (c *Client) UploadBlob(ctx context.Context, data []byte, mimeType string) (*lexutil.LexBlob, error) {
func (c *Client) UploadBlob(ctx context.Context, data []byte, mimeType string) (*ATProtoBlobRef, error) {
// Use session provider (locked OAuth with DPoP) - prevents nonce races
if c.sessionProvider != nil {
var result struct {
@@ -312,12 +310,15 @@ func (c *Client) UploadBlob(ctx context.Context, data []byte, mimeType string) (
err := c.sessionProvider.DoWithSession(ctx, c.did, func(session *indigo_oauth.ClientSession) error {
apiClient := session.APIClient()
// IMPORTANT: Use io.Reader for blob uploads
// LexDo JSON-encodes []byte (base64), but streams io.Reader as raw bytes
// Use the actual MIME type so PDS can validate against blob:image/* scope
return apiClient.LexDo(ctx,
"POST",
mimeType,
"com.atproto.repo.uploadBlob",
nil,
data,
bytes.NewReader(data),
&result,
)
})
@@ -325,7 +326,7 @@ func (c *Client) UploadBlob(ctx context.Context, data []byte, mimeType string) (
return nil, fmt.Errorf("uploadBlob failed: %w", err)
}
return atProtoBlobRefToLexBlob(&result.Blob)
return &result.Blob, nil
}
// Basic Auth (app passwords)
@@ -356,22 +357,7 @@ func (c *Client) UploadBlob(ctx context.Context, data []byte, mimeType string) (
return nil, fmt.Errorf("failed to decode response: %w", err)
}
return atProtoBlobRefToLexBlob(&result.Blob)
}
// atProtoBlobRefToLexBlob converts an ATProtoBlobRef to a lexutil.LexBlob
func atProtoBlobRefToLexBlob(ref *ATProtoBlobRef) (*lexutil.LexBlob, error) {
// Parse the CID string from the $link field
c, err := cid.Decode(ref.Ref.Link)
if err != nil {
return nil, fmt.Errorf("failed to parse blob CID %q: %w", ref.Ref.Link, err)
}
return &lexutil.LexBlob{
Ref: lexutil.LexLink(c),
MimeType: ref.MimeType,
Size: ref.Size,
}, nil
return &result.Blob, nil
}
// GetBlob downloads a blob by its CID from the PDS

View File

@@ -386,11 +386,11 @@ func TestUploadBlob(t *testing.T) {
t.Errorf("Content-Type = %v, want %v", r.Header.Get("Content-Type"), mimeType)
}
// Send response - use a valid CIDv1 in base32 format
// Send response
response := `{
"blob": {
"$type": "blob",
"ref": {"$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"},
"ref": {"$link": "bafytest123"},
"mimeType": "application/octet-stream",
"size": 17
}
@@ -406,14 +406,12 @@ func TestUploadBlob(t *testing.T) {
t.Fatalf("UploadBlob() error = %v", err)
}
if blobRef.MimeType != mimeType {
t.Errorf("MimeType = %v, want %v", blobRef.MimeType, mimeType)
if blobRef.Type != "blob" {
t.Errorf("Type = %v, want blob", blobRef.Type)
}
// LexBlob.Ref is a LexLink (cid.Cid alias), use .String() to get the CID string
expectedCID := "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
if blobRef.Ref.String() != expectedCID {
t.Errorf("Ref.String() = %v, want %v", blobRef.Ref.String(), expectedCID)
if blobRef.Ref.Link != "bafytest123" {
t.Errorf("Ref.Link = %v, want bafytest123", blobRef.Ref.Link)
}
if blobRef.Size != 17 {

View File

@@ -3,194 +3,17 @@
package main
// Lexicon and CBOR Code Generator
// CBOR Code Generator
//
// This generates:
// 1. Go types from lexicon JSON files (via lex/lexgen library)
// 2. CBOR marshaling code for ATProto records (via cbor-gen)
// 3. Type registration for lexutil (register.go)
// This generates optimized CBOR marshaling code for ATProto records.
//
// Usage:
// go generate ./pkg/atproto/...
//
// Key insight: We use RegisterLexiconTypeID: false to avoid generating init()
// blocks that require CBORMarshaler. This breaks the circular dependency between
// lexgen and cbor-gen. See: https://github.com/bluesky-social/indigo/issues/931
import (
"bytes"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/bluesky-social/indigo/atproto/lexicon"
"github.com/bluesky-social/indigo/lex/lexgen"
"golang.org/x/tools/imports"
)
func main() {
// Find repo root
repoRoot, err := findRepoRoot()
if err != nil {
fmt.Printf("failed to find repo root: %v\n", err)
os.Exit(1)
}
pkgDir := filepath.Join(repoRoot, "pkg/atproto")
lexDir := filepath.Join(repoRoot, "lexicons")
// Step 0: Clean up old register.go to avoid conflicts
// (It will be regenerated at the end)
os.Remove(filepath.Join(pkgDir, "register.go"))
// Step 1: Load all lexicon schemas into catalog (for cross-references)
fmt.Println("Loading lexicons...")
cat := lexicon.NewBaseCatalog()
if err := cat.LoadDirectory(lexDir); err != nil {
fmt.Printf("failed to load lexicons: %v\n", err)
os.Exit(1)
}
// Step 2: Generate Go code for each lexicon file
fmt.Println("Running lexgen...")
config := &lexgen.GenConfig{
RegisterLexiconTypeID: false, // KEY: no init() blocks generated
UnknownType: "map-string-any",
WarningText: "Code generated by generate.go; DO NOT EDIT.",
}
// Track generated types for register.go
var registeredTypes []typeInfo
// Walk lexicon directory and generate code for each file
err = filepath.Walk(lexDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() || !strings.HasSuffix(path, ".json") {
return nil
}
// Load and parse the schema file
data, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("failed to read %s: %w", path, err)
}
var sf lexicon.SchemaFile
if err := json.Unmarshal(data, &sf); err != nil {
return fmt.Errorf("failed to parse %s: %w", path, err)
}
if err := sf.FinishParse(); err != nil {
return fmt.Errorf("failed to finish parse %s: %w", path, err)
}
// Flatten the schema
flat, err := lexgen.FlattenSchemaFile(&sf)
if err != nil {
return fmt.Errorf("failed to flatten schema %s: %w", path, err)
}
// Generate code
var buf bytes.Buffer
gen := &lexgen.CodeGenerator{
Config: config,
Lex: flat,
Cat: &cat,
Out: &buf,
}
if err := gen.WriteLexicon(); err != nil {
return fmt.Errorf("failed to generate code for %s: %w", path, err)
}
// Fix package name: lexgen generates "ioatcr" but we want "atproto"
code := bytes.Replace(buf.Bytes(), []byte("package ioatcr"), []byte("package atproto"), 1)
// Format with goimports
fileName := gen.FileName()
formatted, err := imports.Process(fileName, code, nil)
if err != nil {
// Write unformatted for debugging
outPath := filepath.Join(pkgDir, fileName)
os.WriteFile(outPath+".broken", code, 0644)
return fmt.Errorf("failed to format %s: %w (wrote to %s.broken)", fileName, err, outPath)
}
// Write output file
outPath := filepath.Join(pkgDir, fileName)
if err := os.WriteFile(outPath, formatted, 0644); err != nil {
return fmt.Errorf("failed to write %s: %w", outPath, err)
}
fmt.Printf(" Generated %s\n", fileName)
// Track type for registration - compute type name from NSID
typeName := nsidToTypeName(sf.ID)
registeredTypes = append(registeredTypes, typeInfo{
NSID: sf.ID,
TypeName: typeName,
})
return nil
})
if err != nil {
fmt.Printf("lexgen failed: %v\n", err)
os.Exit(1)
}
// Step 3: Run cbor-gen via exec.Command
// This must be a separate process so it can compile the freshly generated types
fmt.Println("Running cbor-gen...")
if err := runCborGen(repoRoot, pkgDir); err != nil {
fmt.Printf("cbor-gen failed: %v\n", err)
os.Exit(1)
}
// Step 4: Generate register.go
fmt.Println("Generating register.go...")
if err := generateRegisterFile(pkgDir, registeredTypes); err != nil {
fmt.Printf("failed to generate register.go: %v\n", err)
os.Exit(1)
}
fmt.Println("Code generation complete!")
}
type typeInfo struct {
NSID string
TypeName string
}
// nsidToTypeName converts an NSID to a Go type name
// io.atcr.manifest → Manifest
// io.atcr.hold.captain → HoldCaptain
// io.atcr.sailor.profile → SailorProfile
func nsidToTypeName(nsid string) string {
parts := strings.Split(nsid, ".")
if len(parts) < 3 {
return ""
}
// Skip the first two parts (authority, e.g., "io.atcr")
// and capitalize each remaining part
var result string
for _, part := range parts[2:] {
if len(part) > 0 {
result += strings.ToUpper(part[:1]) + part[1:]
}
}
return result
}
func runCborGen(repoRoot, pkgDir string) error {
// Create a temporary Go file that runs cbor-gen
cborGenCode := `//go:build ignore
package main
// This creates pkg/atproto/cbor_gen.go which should be committed to git.
// Only re-run when you modify types in pkg/atproto/types.go
//
// The //go:generate directive is in lexicon.go
import (
"fmt"
@@ -202,81 +25,14 @@ import (
)
func main() {
// Generate map-style encoders for CrewRecord, CaptainRecord, LayerRecord, and TangledProfileRecord
if err := cbg.WriteMapEncodersToFile("cbor_gen.go", "atproto",
// Manifest types
atproto.Manifest{},
atproto.Manifest_BlobReference{},
atproto.Manifest_ManifestReference{},
atproto.Manifest_Platform{},
atproto.Manifest_Annotations{},
atproto.Manifest_BlobReference_Annotations{},
atproto.Manifest_ManifestReference_Annotations{},
// Tag
atproto.Tag{},
// Sailor types
atproto.SailorProfile{},
atproto.SailorStar{},
atproto.SailorStar_Subject{},
// Hold types
atproto.HoldCaptain{},
atproto.HoldCrew{},
atproto.HoldLayer{},
// External types
atproto.CrewRecord{},
atproto.CaptainRecord{},
atproto.LayerRecord{},
atproto.TangledProfileRecord{},
); err != nil {
fmt.Printf("cbor-gen failed: %v\n", err)
fmt.Printf("Failed to generate CBOR encoders: %v\n", err)
os.Exit(1)
}
}
`
// Write temp file
tmpFile := filepath.Join(pkgDir, "cborgen_tmp.go")
if err := os.WriteFile(tmpFile, []byte(cborGenCode), 0644); err != nil {
return fmt.Errorf("failed to write temp cbor-gen file: %w", err)
}
defer os.Remove(tmpFile)
// Run it
cmd := exec.Command("go", "run", tmpFile)
cmd.Dir = pkgDir
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func generateRegisterFile(pkgDir string, types []typeInfo) error {
var buf bytes.Buffer
buf.WriteString("// Code generated by generate.go; DO NOT EDIT.\n\n")
buf.WriteString("package atproto\n\n")
buf.WriteString("import lexutil \"github.com/bluesky-social/indigo/lex/util\"\n\n")
buf.WriteString("func init() {\n")
for _, t := range types {
fmt.Fprintf(&buf, "\tlexutil.RegisterType(%q, &%s{})\n", t.NSID, t.TypeName)
}
buf.WriteString("}\n")
outPath := filepath.Join(pkgDir, "register.go")
return os.WriteFile(outPath, buf.Bytes(), 0644)
}
func findRepoRoot() (string, error) {
dir, err := os.Getwd()
if err != nil {
return "", err
}
for {
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
return dir, nil
}
parent := filepath.Dir(dir)
if parent == dir {
return "", fmt.Errorf("go.mod not found")
}
dir = parent
}
}

View File

@@ -1,24 +0,0 @@
// Code generated by generate.go; DO NOT EDIT.
// Lexicon schema: io.atcr.hold.captain
package atproto
// Represents the hold's ownership and metadata. Stored as a singleton record at rkey 'self' in the hold's embedded PDS.
// NOTE(review): this type lives in a generated file (see "DO NOT EDIT"
// header) — change the io.atcr.hold.captain lexicon JSON and re-run
// `go generate` rather than editing this struct by hand.
type HoldCaptain struct {
// LexiconTypeID pins the record's $type to the lexicon NSID via the cborgen const tag.
LexiconTypeID string `json:"$type" cborgen:"$type,const=io.atcr.hold.captain"`
// allowAllCrew: Allow any authenticated user to register as crew
AllowAllCrew bool `json:"allowAllCrew" cborgen:"allowAllCrew"`
// deployedAt: RFC3339 timestamp of when the hold was deployed
DeployedAt string `json:"deployedAt" cborgen:"deployedAt"`
// enableBlueskyPosts: Enable Bluesky posts when manifests are pushed
EnableBlueskyPosts bool `json:"enableBlueskyPosts" cborgen:"enableBlueskyPosts"`
// owner: DID of the hold owner
Owner string `json:"owner" cborgen:"owner"`
// provider: Deployment provider (e.g., fly.io, aws, etc.)
Provider *string `json:"provider,omitempty" cborgen:"provider,omitempty"`
// public: Whether this hold allows public blob reads (pulls) without authentication
Public bool `json:"public" cborgen:"public"`
// region: S3 region where blobs are stored
Region *string `json:"region,omitempty" cborgen:"region,omitempty"`
}

View File

@@ -1,18 +0,0 @@
// Code generated by generate.go; DO NOT EDIT.
// Lexicon schema: io.atcr.hold.crew
package atproto
// HoldCrew is a crew member in a hold's embedded PDS. Grants access
// permissions to push blobs to the hold. Stored in the hold's embedded PDS
// (one record per member).
// Generated from the io.atcr.hold.crew lexicon; the cborgen const tag pins
// the $type value when the record is CBOR-encoded.
type HoldCrew struct {
	// LexiconTypeID carries the record's $type ("io.atcr.hold.crew").
	LexiconTypeID string `json:"$type" cborgen:"$type,const=io.atcr.hold.crew"`
	// addedAt: RFC3339 timestamp of when the member was added
	AddedAt string `json:"addedAt" cborgen:"addedAt"`
	// member: DID of the crew member
	Member string `json:"member" cborgen:"member"`
	// permissions: Specific permissions granted to this member
	// (permission-string vocabulary is defined by the lexicon, not here)
	Permissions []string `json:"permissions" cborgen:"permissions"`
	// role: Member's role in the hold
	Role string `json:"role" cborgen:"role"`
}

View File

@@ -1,24 +0,0 @@
// Code generated by generate.go; DO NOT EDIT.
// Lexicon schema: io.atcr.hold.layer
package atproto
// HoldLayer represents metadata about a container layer stored in the hold.
// Stored in the hold's embedded PDS for tracking and analytics.
// Generated from the io.atcr.hold.layer lexicon; the cborgen const tag pins
// the $type value when the record is CBOR-encoded.
type HoldLayer struct {
	// LexiconTypeID carries the record's $type ("io.atcr.hold.layer").
	LexiconTypeID string `json:"$type" cborgen:"$type,const=io.atcr.hold.layer"`
	// createdAt: RFC3339 timestamp of when the layer was uploaded
	CreatedAt string `json:"createdAt" cborgen:"createdAt"`
	// digest: Layer digest (e.g., sha256:abc123...)
	Digest string `json:"digest" cborgen:"digest"`
	// mediaType: Media type (e.g., application/vnd.oci.image.layer.v1.tar+gzip)
	MediaType string `json:"mediaType" cborgen:"mediaType"`
	// repository: Repository this layer belongs to
	Repository string `json:"repository" cborgen:"repository"`
	// size: Size in bytes
	Size int64 `json:"size" cborgen:"size"`
	// userDid: DID of user who uploaded this layer
	UserDid string `json:"userDid" cborgen:"userDid"`
	// userHandle: Handle of user (for display purposes)
	UserHandle string `json:"userHandle" cborgen:"userHandle"`
}

View File

@@ -18,9 +18,6 @@ const (
// TagCollection is the collection name for image tags
TagCollection = "io.atcr.tag"
// HoldCollection is the collection name for storage holds (BYOS)
HoldCollection = "io.atcr.hold"
// HoldCrewCollection is the collection name for hold crew (membership) - LEGACY BYOS model
// Stored in owner's PDS for BYOS holds
HoldCrewCollection = "io.atcr.hold.crew"
@@ -41,9 +38,6 @@ const (
// TangledProfileCollection is the collection name for tangled profiles
// Stored in hold's embedded PDS (singleton record at rkey "self")
TangledProfileCollection = "sh.tangled.actor.profile"
// BskyPostCollection is the collection name for Bluesky posts
BskyPostCollection = "app.bsky.feed.post"
// BskyPostCollection is the collection name for Bluesky posts
BskyPostCollection = "app.bsky.feed.post"
@@ -53,6 +47,10 @@ const (
// StarCollection is the collection name for repository stars
StarCollection = "io.atcr.sailor.star"
// RepoPageCollection is the collection name for repository page metadata
// Stored in user's PDS with rkey = repository name
RepoPageCollection = "io.atcr.repo.page"
)
// ManifestRecord represents a container image manifest stored in ATProto
@@ -312,17 +310,6 @@ type HoldRecord struct {
CreatedAt time.Time `json:"createdAt"`
}
// NewHoldRecord creates a new hold record
func NewHoldRecord(endpoint, owner string, public bool) *HoldRecord {
return &HoldRecord{
Type: HoldCollection,
Endpoint: endpoint,
Owner: owner,
Public: public,
CreatedAt: time.Now(),
}
}
// SailorProfileRecord represents a user's profile with registry preferences
// Stored in the user's PDS to configure default hold and other settings
type SailorProfileRecord struct {
@@ -353,6 +340,42 @@ func NewSailorProfileRecord(defaultHold string) *SailorProfileRecord {
}
}
// RepoPageRecord represents repository page metadata (description + avatar).
// Stored in the user's PDS with rkey = repository name.
// Users can edit this directly in their PDS to customize their repository page.
type RepoPageRecord struct {
	// Type should be "io.atcr.repo.page" (see RepoPageCollection)
	Type string `json:"$type"`
	// Repository is the name of the repository (e.g., "myapp")
	Repository string `json:"repository"`
	// Description is the markdown README/description content
	Description string `json:"description,omitempty"`
	// Avatar is the repository avatar/icon blob reference (optional)
	Avatar *ATProtoBlobRef `json:"avatar,omitempty"`
	// CreatedAt is the timestamp of record creation
	CreatedAt time.Time `json:"createdAt"`
	// UpdatedAt is the timestamp of the last modification
	UpdatedAt time.Time `json:"updatedAt"`
}
// NewRepoPageRecord builds a repo page record for the given repository,
// stamping CreatedAt and UpdatedAt with the same current time.
func NewRepoPageRecord(repository, description string, avatar *ATProtoBlobRef) *RepoPageRecord {
	ts := time.Now()
	rec := &RepoPageRecord{
		Type:        RepoPageCollection,
		Repository:  repository,
		Description: description,
		Avatar:      avatar,
		CreatedAt:   ts,
		UpdatedAt:   ts,
	}
	return rec
}
// StarSubject represents the subject of a star (the repository being starred)
type StarSubject struct {
// DID is the DID of the repository owner

View File

@@ -1,18 +0,0 @@
package atproto
// This file contains ATProto record types that are NOT generated from our lexicons.
// These are either external schemas or special types that require manual definition.
// TangledProfileRecord represents a Tangled profile for the hold.
// Collection: sh.tangled.actor.profile (external schema - not controlled by ATCR).
// Stored in hold's embedded PDS (singleton record at rkey "self").
// Uses CBOR encoding for efficient storage in hold's carstore.
// NOTE(review): field semantics below mirror the external sh.tangled schema
// as observed; confirm against the upstream lexicon before relying on them.
type TangledProfileRecord struct {
	// Type carries the record's $type (sh.tangled.actor.profile)
	Type string `json:"$type" cborgen:"$type"`
	// Links: list of profile link strings
	Links []string `json:"links" cborgen:"links"`
	// Stats: list of stat identifiers shown on the profile
	Stats []string `json:"stats" cborgen:"stats"`
	// Bluesky: whether a Bluesky association is enabled
	Bluesky bool `json:"bluesky" cborgen:"bluesky"`
	// Location: free-form location text
	Location string `json:"location" cborgen:"location"`
	// Description: free-form profile description
	Description string `json:"description" cborgen:"description"`
	// PinnedRepositories: repositories pinned on the profile
	PinnedRepositories []string `json:"pinnedRepositories" cborgen:"pinnedRepositories"`
}

View File

@@ -1,360 +0,0 @@
package atproto
//go:generate go run generate.go
import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"time"
)
// Collection names for ATProto records
const (
// ManifestCollection is the collection name for container manifests
ManifestCollection = "io.atcr.manifest"
// TagCollection is the collection name for image tags
TagCollection = "io.atcr.tag"
// HoldCollection is the collection name for storage holds (BYOS) - LEGACY
HoldCollection = "io.atcr.hold"
// HoldCrewCollection is the collection name for hold crew (membership) - LEGACY BYOS model
// Stored in owner's PDS for BYOS holds
HoldCrewCollection = "io.atcr.hold.crew"
// CaptainCollection is the collection name for captain records (hold ownership) - EMBEDDED PDS model
// Stored in hold's embedded PDS (singleton record at rkey "self")
CaptainCollection = "io.atcr.hold.captain"
// CrewCollection is the collection name for crew records (access control) - EMBEDDED PDS model
// Stored in hold's embedded PDS (one record per member)
// Note: Uses same collection name as HoldCrewCollection but stored in different PDS (hold's PDS vs owner's PDS)
CrewCollection = "io.atcr.hold.crew"
// LayerCollection is the collection name for container layer metadata
// Stored in hold's embedded PDS to track which layers are stored
LayerCollection = "io.atcr.hold.layer"
// TangledProfileCollection is the collection name for tangled profiles
// Stored in hold's embedded PDS (singleton record at rkey "self")
TangledProfileCollection = "sh.tangled.actor.profile"
// BskyPostCollection is the collection name for Bluesky posts
BskyPostCollection = "app.bsky.feed.post"
// SailorProfileCollection is the collection name for user profiles
SailorProfileCollection = "io.atcr.sailor.profile"
// StarCollection is the collection name for repository stars
StarCollection = "io.atcr.sailor.star"
)
// NewManifestRecord creates a new manifest record from OCI manifest JSON.
//
// It accepts both single-platform image manifests (config + layers) and
// multi-platform manifest lists / OCI indexes (manifests), detecting the
// kind from the mediaType. Exactly one of the two field sets must be
// present; anything else is rejected with an error.
//
// The returned record's ManifestBlob is left unset; the caller is expected
// to fill it in after uploading the raw manifest bytes to blob storage.
func NewManifestRecord(repository, digest string, ociManifest []byte) (*Manifest, error) {
	// Parse only the OCI fields we need. Nested objects stay as
	// json.RawMessage so they can be decoded into the generated record
	// types below.
	var ociData struct {
		SchemaVersion int               `json:"schemaVersion"`
		MediaType     string            `json:"mediaType"`
		Config        json.RawMessage   `json:"config,omitempty"`
		Layers        []json.RawMessage `json:"layers,omitempty"`
		Manifests     []json.RawMessage `json:"manifests,omitempty"`
		Subject       json.RawMessage   `json:"subject,omitempty"`
		Annotations   map[string]string `json:"annotations,omitempty"`
	}
	if err := json.Unmarshal(ociManifest, &ociData); err != nil {
		return nil, err
	}
	// Detect manifest type based on media type: Docker "manifest.list" or
	// OCI "image.index" both denote a multi-platform index.
	isManifestList := strings.Contains(ociData.MediaType, "manifest.list") ||
		strings.Contains(ociData.MediaType, "image.index")
	// Validate: must have either (config+layers) OR (manifests), never both
	hasImageFields := len(ociData.Config) > 0 || len(ociData.Layers) > 0
	hasIndexFields := len(ociData.Manifests) > 0
	if hasImageFields && hasIndexFields {
		return nil, fmt.Errorf("manifest cannot have both image fields (config/layers) and index fields (manifests)")
	}
	if !hasImageFields && !hasIndexFields {
		return nil, fmt.Errorf("manifest must have either image fields (config/layers) or index fields (manifests)")
	}
	record := &Manifest{
		LexiconTypeID: ManifestCollection,
		Repository:    repository,
		Digest:        digest,
		MediaType:     ociData.MediaType,
		SchemaVersion: int64(ociData.SchemaVersion),
		// ManifestBlob will be set by the caller after uploading to blob storage
		CreatedAt: time.Now().Format(time.RFC3339),
	}
	// Handle annotations - Manifest_Annotations is an empty struct in generated code.
	// We don't copy ociData.Annotations since the generated type doesn't support arbitrary keys.
	if isManifestList {
		// Parse manifest list/index entries into typed references.
		record.Manifests = make([]Manifest_ManifestReference, len(ociData.Manifests))
		for i, m := range ociData.Manifests {
			var ref struct {
				MediaType   string             `json:"mediaType"`
				Digest      string             `json:"digest"`
				Size        int64              `json:"size"`
				Platform    *Manifest_Platform `json:"platform,omitempty"`
				Annotations map[string]string  `json:"annotations,omitempty"`
			}
			if err := json.Unmarshal(m, &ref); err != nil {
				return nil, fmt.Errorf("failed to parse manifest reference %d: %w", i, err)
			}
			// Per-entry annotations are dropped for the same reason as above.
			record.Manifests[i] = Manifest_ManifestReference{
				MediaType: ref.MediaType,
				Digest:    ref.Digest,
				Size:      ref.Size,
				Platform:  ref.Platform,
			}
		}
	} else {
		// Parse image manifest: optional config blob reference...
		if len(ociData.Config) > 0 {
			var config Manifest_BlobReference
			if err := json.Unmarshal(ociData.Config, &config); err != nil {
				return nil, fmt.Errorf("failed to parse config: %w", err)
			}
			record.Config = &config
		}
		// ...and the layer blob references, decoded in place.
		record.Layers = make([]Manifest_BlobReference, len(ociData.Layers))
		for i, layer := range ociData.Layers {
			if err := json.Unmarshal(layer, &record.Layers[i]); err != nil {
				return nil, fmt.Errorf("failed to parse layer %d: %w", i, err)
			}
		}
	}
	// Parse subject if present (valid for both manifest kinds; used by
	// referrers/attestation flows).
	if len(ociData.Subject) > 0 {
		var subject Manifest_BlobReference
		if err := json.Unmarshal(ociData.Subject, &subject); err != nil {
			return nil, err
		}
		record.Subject = &subject
	}
	return record, nil
}
// NewTagRecord builds a tag record pointing at a manifest via AT-URI.
//   - did: the DID of the user (e.g., "did:plc:xyz123")
//   - repository: the repository name (e.g., "myapp")
//   - tag: the tag name (e.g., "latest", "v1.0.0")
//   - manifestDigest: the manifest digest (e.g., "sha256:abc123...")
//
// The legacy ManifestDigest field is intentionally left unset; new records
// carry only the AT-URI reference
// (at://<did>/io.atcr.manifest/<digest-without-sha256-prefix>).
func NewTagRecord(did, repository, tag, manifestDigest string) *Tag {
	uri := BuildManifestURI(did, manifestDigest)
	rec := &Tag{
		LexiconTypeID: TagCollection,
		Repository:    repository,
		Tag:           tag,
		Manifest:      &uri,
		CreatedAt:     time.Now().Format(time.RFC3339),
	}
	return rec
}
// NewSailorProfileRecord builds a sailor profile record. An empty
// defaultHold leaves DefaultHold nil; CreatedAt and UpdatedAt are both set
// to the current RFC3339 timestamp.
func NewSailorProfileRecord(defaultHold string) *SailorProfile {
	ts := time.Now().Format(time.RFC3339)
	profile := &SailorProfile{
		LexiconTypeID: SailorProfileCollection,
		CreatedAt:     ts,
		UpdatedAt:     &ts,
	}
	if defaultHold != "" {
		profile.DefaultHold = &defaultHold
	}
	return profile
}
// NewStarRecord builds a star record for the repository owned by ownerDID,
// stamped with the current RFC3339 time.
func NewStarRecord(ownerDID, repository string) *SailorStar {
	subject := SailorStar_Subject{
		Did:        ownerDID,
		Repository: repository,
	}
	return &SailorStar{
		LexiconTypeID: StarCollection,
		Subject:       subject,
		CreatedAt:     time.Now().Format(time.RFC3339),
	}
}
// NewLayerRecord builds a layer-metadata record for a blob stored in the
// hold, attributing the upload to userDID/userHandle and stamping the
// current RFC3339 time.
func NewLayerRecord(digest string, size int64, mediaType, repository, userDID, userHandle string) *HoldLayer {
	rec := &HoldLayer{
		LexiconTypeID: LayerCollection,
		CreatedAt:     time.Now().Format(time.RFC3339),
		Repository:    repository,
		Digest:        digest,
		MediaType:     mediaType,
		Size:          size,
		UserDid:       userDID,
		UserHandle:    userHandle,
	}
	return rec
}
// StarRecordKey derives the deterministic record key for a star on
// ownerDID's repository: unpadded base64url of "ownerDID/repository".
// Determinism means the same (owner, repo) pair always maps to the same
// rkey, preventing duplicate stars.
func StarRecordKey(ownerDID, repository string) string {
	raw := []byte(ownerDID + "/" + repository)
	return base64.RawURLEncoding.EncodeToString(raw)
}
// ParseStarRecordKey is the inverse of StarRecordKey: it base64url-decodes
// rkey and splits it at the first "/" into ownerDID and repository.
func ParseStarRecordKey(rkey string) (ownerDID, repository string, err error) {
	raw, decErr := base64.RawURLEncoding.DecodeString(rkey)
	if decErr != nil {
		return "", "", fmt.Errorf("failed to decode star rkey: %w", decErr)
	}
	did, repo, found := strings.Cut(string(raw), "/")
	if !found {
		return "", "", fmt.Errorf("invalid star rkey format: %s", string(raw))
	}
	return did, repo, nil
}
// ResolveHoldDIDFromURL normalizes a hold endpoint to a did:web identifier
// so different spellings of the same hold deduplicate:
//   - http://172.28.0.3:8080  → did:web:172.28.0.3:8080
//   - http://hold01.atcr.io   → did:web:hold01.atcr.io
//   - https://hold01.atcr.io  → did:web:hold01.atcr.io
//   - did:web:hold01.atcr.io  → did:web:hold01.atcr.io (passthrough)
//
// An empty input yields an empty string.
func ResolveHoldDIDFromURL(holdURL string) string {
	if holdURL == "" {
		return ""
	}
	// Already a DID: return unchanged.
	if IsDID(holdURL) {
		return holdURL
	}
	// Strip the scheme and a single trailing slash.
	trimmed := strings.TrimPrefix(holdURL, "http://")
	trimmed = strings.TrimPrefix(trimmed, "https://")
	trimmed = strings.TrimSuffix(trimmed, "/")
	// Keep only the host[:port] portion, dropping any path. did:web uses
	// the hostname directly (port included when present).
	hostname, _, _ := strings.Cut(trimmed, "/")
	return "did:web:" + hostname
}
// IsDID reports whether s looks like a DID: the "did:" prefix followed by
// at least one more character.
func IsDID(s string) bool {
	return strings.HasPrefix(s, "did:") && len(s) > len("did:")
}
// RepositoryTagToRKey converts a repository and tag to an ATProto record
// key (which must match ^[a-zA-Z0-9._~-]{1,512}$). The two parts are joined
// with "_" and any "/" is encoded as "~" — tilde is rkey-safe and unlikely
// to appear in repository names, unlike "-".
func RepositoryTagToRKey(repository, tag string) string {
	combined := repository + "_" + tag
	return strings.ReplaceAll(combined, "/", "~")
}
// RKeyToRepositoryTag is the inverse of RepositoryTagToRKey. The rkey is
// split at the LAST underscore (so underscores in the repository survive),
// and "~" in the repository part is decoded back to "/". An rkey without
// any underscore is treated as a bare tag with an empty repository.
func RKeyToRepositoryTag(rkey string) (repository, tag string) {
	idx := strings.LastIndex(rkey, "_")
	if idx < 0 {
		return "", rkey
	}
	repository = strings.ReplaceAll(rkey[:idx], "~", "/")
	tag = rkey[idx+1:]
	return repository, tag
}
// BuildManifestURI creates an AT-URI for a manifest record
// did: The DID of the user (e.g., "did:plc:xyz123")
// manifestDigest: The manifest digest (e.g., "sha256:abc123...")
// Returns: AT-URI in format "at://did:plc:xyz/io.atcr.manifest/<digest-without-sha256-prefix>"
func BuildManifestURI(did, manifestDigest string) string {
// Remove the "sha256:" prefix from the digest to get the rkey
rkey := strings.TrimPrefix(manifestDigest, "sha256:")
return fmt.Sprintf("at://%s/%s/%s", did, ManifestCollection, rkey)
}
// ParseManifestURI extracts the manifest digest from an AT-URI of the form
// "at://<did>/io.atcr.manifest/<rkey>", returning it with the "sha256:"
// prefix restored. It rejects URIs with a wrong scheme, a wrong part count,
// or a collection other than ManifestCollection.
func ParseManifestURI(manifestURI string) (string, error) {
	if !strings.HasPrefix(manifestURI, "at://") {
		return "", fmt.Errorf("invalid AT-URI format: must start with 'at://'")
	}
	// Expect exactly did / collection / rkey after the scheme.
	segments := strings.Split(strings.TrimPrefix(manifestURI, "at://"), "/")
	if len(segments) != 3 {
		return "", fmt.Errorf("invalid AT-URI format: expected 3 parts (did/collection/rkey), got %d", len(segments))
	}
	if segments[1] != ManifestCollection {
		return "", fmt.Errorf("invalid AT-URI: expected collection %s, got %s", ManifestCollection, segments[1])
	}
	// The rkey is the digest without its "sha256:" prefix; restore it.
	return "sha256:" + segments[2], nil
}
// GetManifestDigest resolves the manifest digest referenced by this tag,
// returned with the "sha256:" prefix. The AT-URI manifest field takes
// precedence; the legacy manifestDigest field is only a fallback, and an
// error is returned when neither is populated.
func (t *Tag) GetManifestDigest() (string, error) {
	if uri := t.Manifest; uri != nil && *uri != "" {
		return ParseManifestURI(*uri)
	}
	if legacy := t.ManifestDigest; legacy != nil && *legacy != "" {
		return *legacy, nil
	}
	return "", fmt.Errorf("tag record has neither manifest nor manifestDigest field")
}

View File

@@ -104,7 +104,7 @@ func TestNewManifestRecord(t *testing.T) {
digest string
ociManifest string
wantErr bool
checkFunc func(*testing.T, *Manifest)
checkFunc func(*testing.T, *ManifestRecord)
}{
{
name: "valid OCI manifest",
@@ -112,9 +112,9 @@ func TestNewManifestRecord(t *testing.T) {
digest: "sha256:abc123",
ociManifest: validOCIManifest,
wantErr: false,
checkFunc: func(t *testing.T, record *Manifest) {
if record.LexiconTypeID != ManifestCollection {
t.Errorf("LexiconTypeID = %v, want %v", record.LexiconTypeID, ManifestCollection)
checkFunc: func(t *testing.T, record *ManifestRecord) {
if record.Type != ManifestCollection {
t.Errorf("Type = %v, want %v", record.Type, ManifestCollection)
}
if record.Repository != "myapp" {
t.Errorf("Repository = %v, want myapp", record.Repository)
@@ -143,9 +143,11 @@ func TestNewManifestRecord(t *testing.T) {
if record.Layers[1].Digest != "sha256:layer2" {
t.Errorf("Layers[1].Digest = %v, want sha256:layer2", record.Layers[1].Digest)
}
// Note: Annotations are not copied to generated type (empty struct)
if record.CreatedAt == "" {
t.Error("CreatedAt should not be empty")
if record.Annotations["org.opencontainers.image.created"] != "2025-01-01T00:00:00Z" {
t.Errorf("Annotations missing expected key")
}
if record.CreatedAt.IsZero() {
t.Error("CreatedAt should not be zero")
}
if record.Subject != nil {
t.Error("Subject should be nil")
@@ -158,7 +160,7 @@ func TestNewManifestRecord(t *testing.T) {
digest: "sha256:abc123",
ociManifest: manifestWithSubject,
wantErr: false,
checkFunc: func(t *testing.T, record *Manifest) {
checkFunc: func(t *testing.T, record *ManifestRecord) {
if record.Subject == nil {
t.Fatal("Subject should not be nil")
}
@@ -190,7 +192,7 @@ func TestNewManifestRecord(t *testing.T) {
digest: "sha256:multiarch",
ociManifest: manifestList,
wantErr: false,
checkFunc: func(t *testing.T, record *Manifest) {
checkFunc: func(t *testing.T, record *ManifestRecord) {
if record.MediaType != "application/vnd.oci.image.index.v1+json" {
t.Errorf("MediaType = %v, want application/vnd.oci.image.index.v1+json", record.MediaType)
}
@@ -217,8 +219,8 @@ func TestNewManifestRecord(t *testing.T) {
if record.Manifests[0].Platform.Architecture != "amd64" {
t.Errorf("Platform.Architecture = %v, want amd64", record.Manifests[0].Platform.Architecture)
}
if record.Manifests[0].Platform.Os != "linux" {
t.Errorf("Platform.Os = %v, want linux", record.Manifests[0].Platform.Os)
if record.Manifests[0].Platform.OS != "linux" {
t.Errorf("Platform.OS = %v, want linux", record.Manifests[0].Platform.OS)
}
// Check second manifest (arm64)
@@ -228,7 +230,7 @@ func TestNewManifestRecord(t *testing.T) {
if record.Manifests[1].Platform.Architecture != "arm64" {
t.Errorf("Platform.Architecture = %v, want arm64", record.Manifests[1].Platform.Architecture)
}
if record.Manifests[1].Platform.Variant == nil || *record.Manifests[1].Platform.Variant != "v8" {
if record.Manifests[1].Platform.Variant != "v8" {
t.Errorf("Platform.Variant = %v, want v8", record.Manifests[1].Platform.Variant)
}
},
@@ -266,13 +268,12 @@ func TestNewManifestRecord(t *testing.T) {
func TestNewTagRecord(t *testing.T) {
did := "did:plc:test123"
// Truncate to second precision since RFC3339 doesn't have sub-second precision
before := time.Now().Truncate(time.Second)
before := time.Now()
record := NewTagRecord(did, "myapp", "latest", "sha256:abc123")
after := time.Now().Truncate(time.Second).Add(time.Second)
after := time.Now()
if record.LexiconTypeID != TagCollection {
t.Errorf("LexiconTypeID = %v, want %v", record.LexiconTypeID, TagCollection)
if record.Type != TagCollection {
t.Errorf("Type = %v, want %v", record.Type, TagCollection)
}
if record.Repository != "myapp" {
@@ -285,21 +286,17 @@ func TestNewTagRecord(t *testing.T) {
// New records should have manifest field (AT-URI)
expectedURI := "at://did:plc:test123/io.atcr.manifest/abc123"
if record.Manifest == nil || *record.Manifest != expectedURI {
if record.Manifest != expectedURI {
t.Errorf("Manifest = %v, want %v", record.Manifest, expectedURI)
}
// New records should NOT have manifestDigest field
if record.ManifestDigest != nil && *record.ManifestDigest != "" {
t.Errorf("ManifestDigest should be nil for new records, got %v", record.ManifestDigest)
if record.ManifestDigest != "" {
t.Errorf("ManifestDigest should be empty for new records, got %v", record.ManifestDigest)
}
createdAt, err := time.Parse(time.RFC3339, record.CreatedAt)
if err != nil {
t.Errorf("CreatedAt is not valid RFC3339: %v", err)
}
if createdAt.Before(before) || createdAt.After(after) {
t.Errorf("CreatedAt = %v, want between %v and %v", createdAt, before, after)
if record.UpdatedAt.Before(before) || record.UpdatedAt.After(after) {
t.Errorf("UpdatedAt = %v, want between %v and %v", record.UpdatedAt, before, after)
}
}
@@ -394,50 +391,47 @@ func TestParseManifestURI(t *testing.T) {
}
func TestTagRecord_GetManifestDigest(t *testing.T) {
manifestURI := "at://did:plc:test123/io.atcr.manifest/abc123"
digestValue := "sha256:def456"
tests := []struct {
name string
record Tag
record TagRecord
want string
wantErr bool
}{
{
name: "new record with manifest field",
record: Tag{
Manifest: &manifestURI,
record: TagRecord{
Manifest: "at://did:plc:test123/io.atcr.manifest/abc123",
},
want: "sha256:abc123",
wantErr: false,
},
{
name: "old record with manifestDigest field",
record: Tag{
ManifestDigest: &digestValue,
record: TagRecord{
ManifestDigest: "sha256:def456",
},
want: "sha256:def456",
wantErr: false,
},
{
name: "prefers manifest over manifestDigest",
record: Tag{
Manifest: &manifestURI,
ManifestDigest: &digestValue,
record: TagRecord{
Manifest: "at://did:plc:test123/io.atcr.manifest/abc123",
ManifestDigest: "sha256:def456",
},
want: "sha256:abc123",
wantErr: false,
},
{
name: "no fields set",
record: Tag{},
record: TagRecord{},
want: "",
wantErr: true,
},
{
name: "invalid manifest URI",
record: Tag{
Manifest: func() *string { s := "invalid-uri"; return &s }(),
record: TagRecord{
Manifest: "invalid-uri",
},
want: "",
wantErr: true,
@@ -458,8 +452,6 @@ func TestTagRecord_GetManifestDigest(t *testing.T) {
}
}
// TestNewHoldRecord is removed - HoldRecord is no longer supported (legacy BYOS)
func TestNewSailorProfileRecord(t *testing.T) {
tests := []struct {
name string
@@ -481,72 +473,53 @@ func TestNewSailorProfileRecord(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Truncate to second precision since RFC3339 doesn't have sub-second precision
before := time.Now().Truncate(time.Second)
before := time.Now()
record := NewSailorProfileRecord(tt.defaultHold)
after := time.Now().Truncate(time.Second).Add(time.Second)
after := time.Now()
if record.LexiconTypeID != SailorProfileCollection {
t.Errorf("LexiconTypeID = %v, want %v", record.LexiconTypeID, SailorProfileCollection)
if record.Type != SailorProfileCollection {
t.Errorf("Type = %v, want %v", record.Type, SailorProfileCollection)
}
if tt.defaultHold == "" {
if record.DefaultHold != nil {
t.Errorf("DefaultHold = %v, want nil", record.DefaultHold)
}
} else {
if record.DefaultHold == nil || *record.DefaultHold != tt.defaultHold {
t.Errorf("DefaultHold = %v, want %v", record.DefaultHold, tt.defaultHold)
}
if record.DefaultHold != tt.defaultHold {
t.Errorf("DefaultHold = %v, want %v", record.DefaultHold, tt.defaultHold)
}
createdAt, err := time.Parse(time.RFC3339, record.CreatedAt)
if err != nil {
t.Errorf("CreatedAt is not valid RFC3339: %v", err)
}
if createdAt.Before(before) || createdAt.After(after) {
t.Errorf("CreatedAt = %v, want between %v and %v", createdAt, before, after)
if record.CreatedAt.Before(before) || record.CreatedAt.After(after) {
t.Errorf("CreatedAt = %v, want between %v and %v", record.CreatedAt, before, after)
}
if record.UpdatedAt == nil {
t.Error("UpdatedAt should not be nil")
} else {
updatedAt, err := time.Parse(time.RFC3339, *record.UpdatedAt)
if err != nil {
t.Errorf("UpdatedAt is not valid RFC3339: %v", err)
}
if updatedAt.Before(before) || updatedAt.After(after) {
t.Errorf("UpdatedAt = %v, want between %v and %v", updatedAt, before, after)
}
if record.UpdatedAt.Before(before) || record.UpdatedAt.After(after) {
t.Errorf("UpdatedAt = %v, want between %v and %v", record.UpdatedAt, before, after)
}
// CreatedAt and UpdatedAt should be equal for new records
if !record.CreatedAt.Equal(record.UpdatedAt) {
t.Errorf("CreatedAt (%v) != UpdatedAt (%v)", record.CreatedAt, record.UpdatedAt)
}
})
}
}
func TestNewStarRecord(t *testing.T) {
// Truncate to second precision since RFC3339 doesn't have sub-second precision
before := time.Now().Truncate(time.Second)
before := time.Now()
record := NewStarRecord("did:plc:alice123", "myapp")
after := time.Now().Truncate(time.Second).Add(time.Second)
after := time.Now()
if record.LexiconTypeID != StarCollection {
t.Errorf("LexiconTypeID = %v, want %v", record.LexiconTypeID, StarCollection)
if record.Type != StarCollection {
t.Errorf("Type = %v, want %v", record.Type, StarCollection)
}
if record.Subject.Did != "did:plc:alice123" {
t.Errorf("Subject.Did = %v, want did:plc:alice123", record.Subject.Did)
if record.Subject.DID != "did:plc:alice123" {
t.Errorf("Subject.DID = %v, want did:plc:alice123", record.Subject.DID)
}
if record.Subject.Repository != "myapp" {
t.Errorf("Subject.Repository = %v, want myapp", record.Subject.Repository)
}
createdAt, err := time.Parse(time.RFC3339, record.CreatedAt)
if err != nil {
t.Errorf("CreatedAt is not valid RFC3339: %v", err)
}
if createdAt.Before(before) || createdAt.After(after) {
t.Errorf("CreatedAt = %v, want between %v and %v", createdAt, before, after)
if record.CreatedAt.Before(before) || record.CreatedAt.After(after) {
t.Errorf("CreatedAt = %v, want between %v and %v", record.CreatedAt, before, after)
}
}
@@ -834,8 +807,7 @@ func TestManifestRecord_JSONSerialization(t *testing.T) {
}
// Add hold DID
holdDID := "did:web:hold01.atcr.io"
record.HoldDid = &holdDID
record.HoldDID = "did:web:hold01.atcr.io"
// Serialize to JSON
jsonData, err := json.Marshal(record)
@@ -844,14 +816,14 @@ func TestManifestRecord_JSONSerialization(t *testing.T) {
}
// Deserialize from JSON
var decoded Manifest
var decoded ManifestRecord
if err := json.Unmarshal(jsonData, &decoded); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
// Verify fields
if decoded.LexiconTypeID != record.LexiconTypeID {
t.Errorf("LexiconTypeID = %v, want %v", decoded.LexiconTypeID, record.LexiconTypeID)
if decoded.Type != record.Type {
t.Errorf("Type = %v, want %v", decoded.Type, record.Type)
}
if decoded.Repository != record.Repository {
t.Errorf("Repository = %v, want %v", decoded.Repository, record.Repository)
@@ -859,8 +831,8 @@ func TestManifestRecord_JSONSerialization(t *testing.T) {
if decoded.Digest != record.Digest {
t.Errorf("Digest = %v, want %v", decoded.Digest, record.Digest)
}
if decoded.HoldDid == nil || *decoded.HoldDid != *record.HoldDid {
t.Errorf("HoldDid = %v, want %v", decoded.HoldDid, record.HoldDid)
if decoded.HoldDID != record.HoldDID {
t.Errorf("HoldDID = %v, want %v", decoded.HoldDID, record.HoldDID)
}
if decoded.Config.Digest != record.Config.Digest {
t.Errorf("Config.Digest = %v, want %v", decoded.Config.Digest, record.Config.Digest)
@@ -871,12 +843,14 @@ func TestManifestRecord_JSONSerialization(t *testing.T) {
}
func TestBlobReference_JSONSerialization(t *testing.T) {
blob := Manifest_BlobReference{
blob := BlobReference{
MediaType: "application/vnd.oci.image.layer.v1.tar+gzip",
Digest: "sha256:abc123",
Size: 12345,
Urls: []string{"https://s3.example.com/blob"},
// Note: Annotations is now an empty struct, not a map
URLs: []string{"https://s3.example.com/blob"},
Annotations: map[string]string{
"key": "value",
},
}
// Serialize
@@ -886,7 +860,7 @@ func TestBlobReference_JSONSerialization(t *testing.T) {
}
// Deserialize
var decoded Manifest_BlobReference
var decoded BlobReference
if err := json.Unmarshal(jsonData, &decoded); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
@@ -904,8 +878,8 @@ func TestBlobReference_JSONSerialization(t *testing.T) {
}
func TestStarSubject_JSONSerialization(t *testing.T) {
subject := SailorStar_Subject{
Did: "did:plc:alice123",
subject := StarSubject{
DID: "did:plc:alice123",
Repository: "myapp",
}
@@ -916,14 +890,14 @@ func TestStarSubject_JSONSerialization(t *testing.T) {
}
// Deserialize
var decoded SailorStar_Subject
var decoded StarSubject
if err := json.Unmarshal(jsonData, &decoded); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
// Verify
if decoded.Did != subject.Did {
t.Errorf("Did = %v, want %v", decoded.Did, subject.Did)
if decoded.DID != subject.DID {
t.Errorf("DID = %v, want %v", decoded.DID, subject.DID)
}
if decoded.Repository != subject.Repository {
t.Errorf("Repository = %v, want %v", decoded.Repository, subject.Repository)
@@ -1170,8 +1144,8 @@ func TestNewLayerRecord(t *testing.T) {
t.Fatal("NewLayerRecord() returned nil")
}
if record.LexiconTypeID != LayerCollection {
t.Errorf("LexiconTypeID = %q, want %q", record.LexiconTypeID, LayerCollection)
if record.Type != LayerCollection {
t.Errorf("Type = %q, want %q", record.Type, LayerCollection)
}
if record.Digest != tt.digest {
@@ -1190,8 +1164,8 @@ func TestNewLayerRecord(t *testing.T) {
t.Errorf("Repository = %q, want %q", record.Repository, tt.repository)
}
if record.UserDid != tt.userDID {
t.Errorf("UserDid = %q, want %q", record.UserDid, tt.userDID)
if record.UserDID != tt.userDID {
t.Errorf("UserDID = %q, want %q", record.UserDID, tt.userDID)
}
if record.UserHandle != tt.userHandle {
@@ -1213,7 +1187,7 @@ func TestNewLayerRecord(t *testing.T) {
}
func TestNewLayerRecordJSON(t *testing.T) {
// Test that HoldLayer can be marshaled/unmarshaled to/from JSON
// Test that LayerRecord can be marshaled/unmarshaled to/from JSON
record := NewLayerRecord(
"sha256:abc123",
1024,
@@ -1230,14 +1204,14 @@ func TestNewLayerRecordJSON(t *testing.T) {
}
// Unmarshal back
var decoded HoldLayer
var decoded LayerRecord
if err := json.Unmarshal(jsonData, &decoded); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
// Verify fields match
if decoded.LexiconTypeID != record.LexiconTypeID {
t.Errorf("LexiconTypeID = %q, want %q", decoded.LexiconTypeID, record.LexiconTypeID)
if decoded.Type != record.Type {
t.Errorf("Type = %q, want %q", decoded.Type, record.Type)
}
if decoded.Digest != record.Digest {
t.Errorf("Digest = %q, want %q", decoded.Digest, record.Digest)
@@ -1251,8 +1225,8 @@ func TestNewLayerRecordJSON(t *testing.T) {
if decoded.Repository != record.Repository {
t.Errorf("Repository = %q, want %q", decoded.Repository, record.Repository)
}
if decoded.UserDid != record.UserDid {
t.Errorf("UserDid = %q, want %q", decoded.UserDid, record.UserDid)
if decoded.UserDID != record.UserDID {
t.Errorf("UserDID = %q, want %q", decoded.UserDID, record.UserDID)
}
if decoded.UserHandle != record.UserHandle {
t.Errorf("UserHandle = %q, want %q", decoded.UserHandle, record.UserHandle)
@@ -1261,3 +1235,135 @@ func TestNewLayerRecordJSON(t *testing.T) {
t.Errorf("CreatedAt = %q, want %q", decoded.CreatedAt, record.CreatedAt)
}
}
// TestNewRepoPageRecord exercises the NewRepoPageRecord constructor across
// combinations of description/avatar inputs, and checks that both timestamps
// are set during construction and are equal to each other for a new record.
func TestNewRepoPageRecord(t *testing.T) {
	tests := []struct {
		name        string
		repository  string
		description string
		avatar      *ATProtoBlobRef // nil means "no avatar expected on the record"
	}{
		{
			name:        "with description only",
			repository:  "myapp",
			description: "# My App\n\nA cool container image.",
			avatar:      nil,
		},
		{
			name:        "with avatar only",
			repository:  "another-app",
			description: "",
			avatar: &ATProtoBlobRef{
				Type:     "blob",
				Ref:      Link{Link: "bafyreiabc123"},
				MimeType: "image/png",
				Size:     1024,
			},
		},
		{
			name:        "with both description and avatar",
			repository:  "full-app",
			description: "This is a full description.",
			avatar: &ATProtoBlobRef{
				Type:     "blob",
				Ref:      Link{Link: "bafyreiabc456"},
				MimeType: "image/jpeg",
				Size:     2048,
			},
		},
		{
			name:        "empty values",
			repository:  "",
			description: "",
			avatar:      nil,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Bracket construction so we can assert the record's timestamps
			// fall inside the [before, after] window.
			before := time.Now()
			record := NewRepoPageRecord(tt.repository, tt.description, tt.avatar)
			after := time.Now()
			if record.Type != RepoPageCollection {
				t.Errorf("Type = %v, want %v", record.Type, RepoPageCollection)
			}
			if record.Repository != tt.repository {
				t.Errorf("Repository = %v, want %v", record.Repository, tt.repository)
			}
			if record.Description != tt.description {
				t.Errorf("Description = %v, want %v", record.Description, tt.description)
			}
			if tt.avatar == nil && record.Avatar != nil {
				t.Error("Avatar should be nil")
			}
			if tt.avatar != nil {
				if record.Avatar == nil {
					t.Fatal("Avatar should not be nil")
				}
				// Ref.Link is the CID; checking it is enough to confirm the
				// constructor wired through the caller-supplied blob ref.
				if record.Avatar.Ref.Link != tt.avatar.Ref.Link {
					t.Errorf("Avatar.Ref.Link = %v, want %v", record.Avatar.Ref.Link, tt.avatar.Ref.Link)
				}
			}
			if record.CreatedAt.Before(before) || record.CreatedAt.After(after) {
				t.Errorf("CreatedAt = %v, want between %v and %v", record.CreatedAt, before, after)
			}
			if record.UpdatedAt.Before(before) || record.UpdatedAt.After(after) {
				t.Errorf("UpdatedAt = %v, want between %v and %v", record.UpdatedAt, before, after)
			}
			// CreatedAt and UpdatedAt should be equal for new records
			if !record.CreatedAt.Equal(record.UpdatedAt) {
				t.Errorf("CreatedAt (%v) != UpdatedAt (%v)", record.CreatedAt, record.UpdatedAt)
			}
		})
	}
}
// TestRepoPageRecord_JSONSerialization verifies that a RepoPageRecord survives
// a JSON marshal/unmarshal round trip with all checked fields intact.
func TestRepoPageRecord_JSONSerialization(t *testing.T) {
	original := NewRepoPageRecord(
		"myapp",
		"# My App\n\nA description with **markdown**.",
		&ATProtoBlobRef{
			Type:     "blob",
			Ref:      Link{Link: "bafyreiabc123"},
			MimeType: "image/png",
			Size:     1024,
		},
	)

	// Round-trip the record through JSON.
	raw, err := json.Marshal(original)
	if err != nil {
		t.Fatalf("json.Marshal() error = %v", err)
	}
	var roundTripped RepoPageRecord
	if err := json.Unmarshal(raw, &roundTripped); err != nil {
		t.Fatalf("json.Unmarshal() error = %v", err)
	}

	// Compare the decoded record field-by-field against the source.
	if roundTripped.Type != original.Type {
		t.Errorf("Type = %v, want %v", roundTripped.Type, original.Type)
	}
	if roundTripped.Repository != original.Repository {
		t.Errorf("Repository = %v, want %v", roundTripped.Repository, original.Repository)
	}
	if roundTripped.Description != original.Description {
		t.Errorf("Description = %v, want %v", roundTripped.Description, original.Description)
	}
	if roundTripped.Avatar == nil {
		t.Fatal("Avatar should not be nil")
	}
	if roundTripped.Avatar.Ref.Link != original.Avatar.Ref.Link {
		t.Errorf("Avatar.Ref.Link = %v, want %v", roundTripped.Avatar.Ref.Link, original.Avatar.Ref.Link)
	}
}

View File

@@ -1,103 +0,0 @@
// Code generated by generate.go; DO NOT EDIT.
// Lexicon schema: io.atcr.manifest
package atproto
import (
lexutil "github.com/bluesky-social/indigo/lex/util"
)
// A container image manifest following OCI specification, stored in ATProto
type Manifest struct {
	// LexiconTypeID carries the record's "$type"; the cborgen const tag pins
	// it to io.atcr.manifest.
	LexiconTypeID string `json:"$type" cborgen:"$type,const=io.atcr.manifest"`
	// annotations: Optional metadata annotations
	Annotations *Manifest_Annotations `json:"annotations,omitempty" cborgen:"annotations,omitempty"`
	// config: Reference to image configuration blob
	Config *Manifest_BlobReference `json:"config,omitempty" cborgen:"config,omitempty"`
	// createdAt: Record creation timestamp
	CreatedAt string `json:"createdAt" cborgen:"createdAt"`
	// digest: Content digest (e.g., 'sha256:abc123...')
	Digest string `json:"digest" cborgen:"digest"`
	// holdDid: DID of the hold service where blobs are stored (e.g., 'did:web:hold01.atcr.io'). Primary reference for hold resolution.
	HoldDid *string `json:"holdDid,omitempty" cborgen:"holdDid,omitempty"`
	// holdEndpoint: Hold service endpoint URL where blobs are stored. DEPRECATED: Use holdDid instead. Kept for backward compatibility.
	HoldEndpoint *string `json:"holdEndpoint,omitempty" cborgen:"holdEndpoint,omitempty"`
	// layers: Filesystem layers (for image manifests)
	Layers []Manifest_BlobReference `json:"layers,omitempty" cborgen:"layers,omitempty"`
	// manifestBlob: The full OCI manifest stored as a blob in ATProto.
	ManifestBlob *lexutil.LexBlob `json:"manifestBlob,omitempty" cborgen:"manifestBlob,omitempty"`
	// manifests: Referenced manifests (for manifest lists/indexes)
	Manifests []Manifest_ManifestReference `json:"manifests,omitempty" cborgen:"manifests,omitempty"`
	// mediaType: OCI media type
	MediaType string `json:"mediaType" cborgen:"mediaType"`
	// repository: Repository name (e.g., 'myapp'). Scoped to user's DID.
	Repository string `json:"repository" cborgen:"repository"`
	// schemaVersion: OCI schema version (typically 2)
	SchemaVersion int64 `json:"schemaVersion" cborgen:"schemaVersion"`
	// subject: Optional reference to another manifest (for attestations, signatures)
	Subject *Manifest_BlobReference `json:"subject,omitempty" cborgen:"subject,omitempty"`
}
// Optional metadata annotations
//
// NOTE(review): generated as an empty struct — the schema's annotation map is
// not expanded into Go fields here, so any keys round-trip only via the raw
// record, not this type.
type Manifest_Annotations struct {
}
// Manifest_BlobReference is a "blobReference" in the io.atcr.manifest schema.
//
// Reference to a blob stored in S3 or external storage
type Manifest_BlobReference struct {
	// LexiconTypeID pins "$type" to io.atcr.manifest#blobReference for cborgen.
	LexiconTypeID string `json:"$type,omitempty" cborgen:"$type,const=io.atcr.manifest#blobReference,omitempty"`
	// annotations: Optional metadata
	Annotations *Manifest_BlobReference_Annotations `json:"annotations,omitempty" cborgen:"annotations,omitempty"`
	// digest: Content digest (e.g., 'sha256:...')
	Digest string `json:"digest" cborgen:"digest"`
	// mediaType: MIME type of the blob
	MediaType string `json:"mediaType" cborgen:"mediaType"`
	// size: Size in bytes
	Size int64 `json:"size" cborgen:"size"`
	// urls: Optional direct URLs to blob (for BYOS)
	Urls []string `json:"urls,omitempty" cborgen:"urls,omitempty"`
}

// Optional metadata (empty in the generated Go representation).
type Manifest_BlobReference_Annotations struct {
}
// Manifest_ManifestReference is a "manifestReference" in the io.atcr.manifest schema.
//
// Reference to a manifest in a manifest list/index
type Manifest_ManifestReference struct {
	// LexiconTypeID pins "$type" to io.atcr.manifest#manifestReference for cborgen.
	LexiconTypeID string `json:"$type,omitempty" cborgen:"$type,const=io.atcr.manifest#manifestReference,omitempty"`
	// annotations: Optional metadata
	Annotations *Manifest_ManifestReference_Annotations `json:"annotations,omitempty" cborgen:"annotations,omitempty"`
	// digest: Content digest (e.g., 'sha256:...')
	Digest string `json:"digest" cborgen:"digest"`
	// mediaType: Media type of the referenced manifest
	MediaType string `json:"mediaType" cborgen:"mediaType"`
	// platform: Platform information for this manifest
	Platform *Manifest_Platform `json:"platform,omitempty" cborgen:"platform,omitempty"`
	// size: Size in bytes
	Size int64 `json:"size" cborgen:"size"`
}

// Optional metadata (empty in the generated Go representation).
type Manifest_ManifestReference_Annotations struct {
}
// Manifest_Platform is a "platform" in the io.atcr.manifest schema.
//
// Platform information describing OS and architecture
type Manifest_Platform struct {
	// LexiconTypeID pins "$type" to io.atcr.manifest#platform for cborgen.
	LexiconTypeID string `json:"$type,omitempty" cborgen:"$type,const=io.atcr.manifest#platform,omitempty"`
	// architecture: CPU architecture (e.g., 'amd64', 'arm64', 'arm')
	Architecture string `json:"architecture" cborgen:"architecture"`
	// os: Operating system (e.g., 'linux', 'windows', 'darwin')
	Os string `json:"os" cborgen:"os"`
	// osFeatures: Optional OS features
	OsFeatures []string `json:"osFeatures,omitempty" cborgen:"osFeatures,omitempty"`
	// osVersion: Optional OS version
	OsVersion *string `json:"osVersion,omitempty" cborgen:"osVersion,omitempty"`
	// variant: Optional CPU variant (e.g., 'v7' for ARM)
	Variant *string `json:"variant,omitempty" cborgen:"variant,omitempty"`
}

View File

@@ -1,15 +0,0 @@
// Code generated by generate.go; DO NOT EDIT.
package atproto
import lexutil "github.com/bluesky-social/indigo/lex/util"
// init registers every io.atcr.* lexicon record type with indigo's lexutil
// registry, mapping each NSID to its Go struct so records can be decoded
// from their "$type" field at runtime.
func init() {
	lexutil.RegisterType("io.atcr.hold.captain", &HoldCaptain{})
	lexutil.RegisterType("io.atcr.hold.crew", &HoldCrew{})
	lexutil.RegisterType("io.atcr.hold.layer", &HoldLayer{})
	lexutil.RegisterType("io.atcr.manifest", &Manifest{})
	lexutil.RegisterType("io.atcr.sailor.profile", &SailorProfile{})
	lexutil.RegisterType("io.atcr.sailor.star", &SailorStar{})
	lexutil.RegisterType("io.atcr.tag", &Tag{})
}

View File

@@ -1,16 +0,0 @@
// Code generated by generate.go; DO NOT EDIT.
// Lexicon schema: io.atcr.sailor.profile
package atproto
// User profile for ATCR registry. Stores preferences like default hold for blob storage.
type SailorProfile struct {
	// LexiconTypeID pins "$type" to io.atcr.sailor.profile for cborgen.
	LexiconTypeID string `json:"$type" cborgen:"$type,const=io.atcr.sailor.profile"`
	// createdAt: Profile creation timestamp
	CreatedAt string `json:"createdAt" cborgen:"createdAt"`
	// defaultHold: Default hold endpoint for blob storage. If null, user has opted out of defaults.
	DefaultHold *string `json:"defaultHold,omitempty" cborgen:"defaultHold,omitempty"`
	// updatedAt: Profile last updated timestamp
	UpdatedAt *string `json:"updatedAt,omitempty" cborgen:"updatedAt,omitempty"`
}

View File

@@ -1,25 +0,0 @@
// Code generated by generate.go; DO NOT EDIT.
// Lexicon schema: io.atcr.sailor.star
package atproto
// A star (like) on a container image repository. Stored in the starrer's PDS, similar to Bluesky likes.
type SailorStar struct {
	// LexiconTypeID pins "$type" to io.atcr.sailor.star for cborgen.
	LexiconTypeID string `json:"$type" cborgen:"$type,const=io.atcr.sailor.star"`
	// createdAt: Star creation timestamp
	CreatedAt string `json:"createdAt" cborgen:"createdAt"`
	// subject: The repository being starred
	Subject SailorStar_Subject `json:"subject" cborgen:"subject"`
}

// SailorStar_Subject is a "subject" in the io.atcr.sailor.star schema.
//
// Reference to a repository owned by a user
type SailorStar_Subject struct {
	// LexiconTypeID pins "$type" to io.atcr.sailor.star#subject for cborgen.
	LexiconTypeID string `json:"$type,omitempty" cborgen:"$type,const=io.atcr.sailor.star#subject,omitempty"`
	// did: DID of the repository owner
	Did string `json:"did" cborgen:"did"`
	// repository: Repository name (e.g., 'myapp')
	Repository string `json:"repository" cborgen:"repository"`
}

View File

@@ -1,20 +0,0 @@
// Code generated by generate.go; DO NOT EDIT.
// Lexicon schema: io.atcr.tag
package atproto
// A named tag pointing to a specific manifest digest
type Tag struct {
	// LexiconTypeID pins "$type" to io.atcr.tag for cborgen.
	LexiconTypeID string `json:"$type" cborgen:"$type,const=io.atcr.tag"`
	// createdAt: Tag creation timestamp
	CreatedAt string `json:"createdAt" cborgen:"createdAt"`
	// manifest: AT-URI of the manifest this tag points to (e.g., 'at://did:plc:xyz/io.atcr.manifest/abc123'). Preferred over manifestDigest for new records.
	Manifest *string `json:"manifest,omitempty" cborgen:"manifest,omitempty"`
	// manifestDigest: DEPRECATED: Digest of the manifest (e.g., 'sha256:...'). Kept for backward compatibility with old records. New records should use 'manifest' field instead.
	ManifestDigest *string `json:"manifestDigest,omitempty" cborgen:"manifestDigest,omitempty"`
	// repository: Repository name (e.g., 'myapp'). Scoped to user's DID.
	Repository string `json:"repository" cborgen:"repository"`
	// tag: Tag name (e.g., 'latest', 'v1.0.0', '12-slim')
	Tag string `json:"tag" cborgen:"tag"`
}

View File

@@ -2,14 +2,10 @@
// Service tokens are JWTs issued by a user's PDS to authorize AppView to
// act on their behalf when communicating with hold services. Tokens are
// cached with automatic expiry parsing and 10-second safety margins.
package token
package auth
import (
"encoding/base64"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync"
"time"
)
@@ -18,6 +14,8 @@ import (
// serviceTokenEntry is one cached service token for a (DID, hold DID) pair.
type serviceTokenEntry struct {
	token     string    // JWT issued by the user's PDS
	expiresAt time.Time // when the cached token should no longer be used
	err       error     // NOTE(review): presumably the population error shared with concurrent waiters — confirm at call sites
	once      sync.Once // guards one-time population of this entry — TODO confirm usage at call sites
}
// Global cache for service tokens (DID:HoldDID -> token)
@@ -61,7 +59,7 @@ func SetServiceToken(did, holdDID, token string) error {
cacheKey := did + ":" + holdDID
// Parse JWT to extract expiry (don't verify signature - we trust the PDS)
expiry, err := parseJWTExpiry(token)
expiry, err := ParseJWTExpiry(token)
if err != nil {
// If parsing fails, use default 50s TTL (conservative fallback)
slog.Warn("Failed to parse JWT expiry, using default 50s", "error", err, "cacheKey", cacheKey)
@@ -85,37 +83,6 @@ func SetServiceToken(did, holdDID, token string) error {
return nil
}
// parseJWTExpiry extracts the expiry time from a JWT without verifying the
// signature. We trust tokens from the user's PDS, so signature verification
// isn't needed here. The payload is decoded manually to avoid algorithm
// compatibility issues with full JWT libraries.
//
// RFC 7519 defines "exp" as a NumericDate, which may be a non-integer JSON
// number; the claim is therefore decoded as float64 and truncated to whole
// seconds (integer exps are unaffected).
func parseJWTExpiry(tokenString string) (time.Time, error) {
	// JWT format: header.payload.signature
	parts := strings.Split(tokenString, ".")
	if len(parts) != 3 {
		return time.Time{}, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
	}

	// Decode the payload (second part). JWTs use unpadded base64url.
	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return time.Time{}, fmt.Errorf("failed to decode JWT payload: %w", err)
	}

	// Parse the JSON payload. float64 accepts both integral and fractional
	// NumericDate values; unix-second timestamps fit well within float64's
	// 53-bit integer precision.
	var claims struct {
		Exp float64 `json:"exp"`
	}
	if err := json.Unmarshal(payload, &claims); err != nil {
		return time.Time{}, fmt.Errorf("failed to parse JWT claims: %w", err)
	}

	if claims.Exp == 0 {
		return time.Time{}, fmt.Errorf("JWT missing exp claim")
	}

	return time.Unix(int64(claims.Exp), 0), nil
}
// InvalidateServiceToken removes a service token from the cache
// Used when we detect that a token is invalid or the user's session has expired
func InvalidateServiceToken(did, holdDID string) {

View File

@@ -1,4 +1,4 @@
package token
package auth
import (
"testing"

View File

@@ -21,7 +21,7 @@ type HoldAuthorizer interface {
// GetCaptainRecord retrieves the captain record for a hold
// Used to check public flag and allowAllCrew settings
GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.HoldCaptain, error)
GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error)
// IsCrewMember checks if userDID is a crew member of holdDID
IsCrewMember(ctx context.Context, holdDID, userDID string) (bool, error)
@@ -32,7 +32,7 @@ type HoldAuthorizer interface {
// Read access rules:
// - Public hold: allow anyone (even anonymous)
// - Private hold: require authentication (any authenticated user)
func CheckReadAccessWithCaptain(captain *atproto.HoldCaptain, userDID string) bool {
func CheckReadAccessWithCaptain(captain *atproto.CaptainRecord, userDID string) bool {
if captain.Public {
// Public hold - allow anyone (even anonymous)
return true
@@ -55,7 +55,7 @@ func CheckReadAccessWithCaptain(captain *atproto.HoldCaptain, userDID string) bo
// Write access rules:
// - Must be authenticated
// - Must be hold owner OR crew member
func CheckWriteAccessWithCaptain(captain *atproto.HoldCaptain, userDID string, isCrew bool) bool {
func CheckWriteAccessWithCaptain(captain *atproto.CaptainRecord, userDID string, isCrew bool) bool {
slog.Debug("Checking write access", "userDID", userDID, "owner", captain.Owner, "isCrew", isCrew)
if userDID == "" {

View File

@@ -7,7 +7,7 @@ import (
)
func TestCheckReadAccessWithCaptain_PublicHold(t *testing.T) {
captain := &atproto.HoldCaptain{
captain := &atproto.CaptainRecord{
Public: true,
Owner: "did:plc:owner123",
}
@@ -26,7 +26,7 @@ func TestCheckReadAccessWithCaptain_PublicHold(t *testing.T) {
}
func TestCheckReadAccessWithCaptain_PrivateHold(t *testing.T) {
captain := &atproto.HoldCaptain{
captain := &atproto.CaptainRecord{
Public: false,
Owner: "did:plc:owner123",
}
@@ -45,7 +45,7 @@ func TestCheckReadAccessWithCaptain_PrivateHold(t *testing.T) {
}
func TestCheckWriteAccessWithCaptain_Owner(t *testing.T) {
captain := &atproto.HoldCaptain{
captain := &atproto.CaptainRecord{
Public: false,
Owner: "did:plc:owner123",
}
@@ -58,7 +58,7 @@ func TestCheckWriteAccessWithCaptain_Owner(t *testing.T) {
}
func TestCheckWriteAccessWithCaptain_Crew(t *testing.T) {
captain := &atproto.HoldCaptain{
captain := &atproto.CaptainRecord{
Public: false,
Owner: "did:plc:owner123",
}
@@ -77,7 +77,7 @@ func TestCheckWriteAccessWithCaptain_Crew(t *testing.T) {
}
func TestCheckWriteAccessWithCaptain_Anonymous(t *testing.T) {
captain := &atproto.HoldCaptain{
captain := &atproto.CaptainRecord{
Public: false,
Owner: "did:plc:owner123",
}

View File

@@ -35,7 +35,7 @@ func NewLocalHoldAuthorizerFromInterface(holdPDS any) HoldAuthorizer {
}
// GetCaptainRecord retrieves the captain record from the hold's PDS
func (a *LocalHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.HoldCaptain, error) {
func (a *LocalHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error) {
// Verify that the requested holdDID matches this hold
if holdDID != a.pds.DID() {
return nil, fmt.Errorf("holdDID mismatch: requested %s, this hold is %s", holdDID, a.pds.DID())
@@ -47,7 +47,7 @@ func (a *LocalHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID stri
return nil, fmt.Errorf("failed to get captain record: %w", err)
}
// The PDS returns *atproto.HoldCaptain directly
// The PDS returns *atproto.CaptainRecord directly now (after we update pds to use atproto types)
return pdsCaptain, nil
}

View File

@@ -101,14 +101,14 @@ func (a *RemoteHoldAuthorizer) cleanupRecentDenials() {
// 1. Check database cache
// 2. If cache miss or expired, query hold's XRPC endpoint
// 3. Update cache
func (a *RemoteHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.HoldCaptain, error) {
func (a *RemoteHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error) {
// Try cache first
if a.db != nil {
cached, err := a.getCachedCaptainRecord(holdDID)
if err == nil && cached != nil {
// Cache hit - check if still valid
if time.Since(cached.UpdatedAt) < a.cacheTTL {
return cached.HoldCaptain, nil
return cached.CaptainRecord, nil
}
// Cache expired - continue to fetch fresh data
}
@@ -133,7 +133,7 @@ func (a *RemoteHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID str
// captainRecordWithMeta includes UpdatedAt for cache management
type captainRecordWithMeta struct {
*atproto.HoldCaptain
*atproto.CaptainRecord
UpdatedAt time.Time
}
@@ -145,7 +145,7 @@ func (a *RemoteHoldAuthorizer) getCachedCaptainRecord(holdDID string) (*captainR
WHERE hold_did = ?
`
var record atproto.HoldCaptain
var record atproto.CaptainRecord
var deployedAt, region, provider sql.NullString
var updatedAt time.Time
@@ -172,20 +172,20 @@ func (a *RemoteHoldAuthorizer) getCachedCaptainRecord(holdDID string) (*captainR
record.DeployedAt = deployedAt.String
}
if region.Valid {
record.Region = &region.String
record.Region = region.String
}
if provider.Valid {
record.Provider = &provider.String
record.Provider = provider.String
}
return &captainRecordWithMeta{
HoldCaptain: &record,
UpdatedAt: updatedAt,
CaptainRecord: &record,
UpdatedAt: updatedAt,
}, nil
}
// setCachedCaptainRecord stores a captain record in database cache
func (a *RemoteHoldAuthorizer) setCachedCaptainRecord(holdDID string, record *atproto.HoldCaptain) error {
func (a *RemoteHoldAuthorizer) setCachedCaptainRecord(holdDID string, record *atproto.CaptainRecord) error {
query := `
INSERT INTO hold_captain_records (
hold_did, owner_did, public, allow_all_crew,
@@ -207,8 +207,8 @@ func (a *RemoteHoldAuthorizer) setCachedCaptainRecord(holdDID string, record *at
record.Public,
record.AllowAllCrew,
nullString(record.DeployedAt),
nullStringPtr(record.Region),
nullStringPtr(record.Provider),
nullString(record.Region),
nullString(record.Provider),
time.Now(),
)
@@ -216,7 +216,7 @@ func (a *RemoteHoldAuthorizer) setCachedCaptainRecord(holdDID string, record *at
}
// fetchCaptainRecordFromXRPC queries the hold's XRPC endpoint for captain record
func (a *RemoteHoldAuthorizer) fetchCaptainRecordFromXRPC(ctx context.Context, holdDID string) (*atproto.HoldCaptain, error) {
func (a *RemoteHoldAuthorizer) fetchCaptainRecordFromXRPC(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error) {
// Resolve DID to URL
holdURL := atproto.ResolveHoldURL(holdDID)
@@ -261,20 +261,14 @@ func (a *RemoteHoldAuthorizer) fetchCaptainRecordFromXRPC(ctx context.Context, h
}
// Convert to our type
record := &atproto.HoldCaptain{
LexiconTypeID: atproto.CaptainCollection,
Owner: xrpcResp.Value.Owner,
Public: xrpcResp.Value.Public,
AllowAllCrew: xrpcResp.Value.AllowAllCrew,
DeployedAt: xrpcResp.Value.DeployedAt,
}
// Handle optional pointer fields
if xrpcResp.Value.Region != "" {
record.Region = &xrpcResp.Value.Region
}
if xrpcResp.Value.Provider != "" {
record.Provider = &xrpcResp.Value.Provider
record := &atproto.CaptainRecord{
Type: atproto.CaptainCollection,
Owner: xrpcResp.Value.Owner,
Public: xrpcResp.Value.Public,
AllowAllCrew: xrpcResp.Value.AllowAllCrew,
DeployedAt: xrpcResp.Value.DeployedAt,
Region: xrpcResp.Value.Region,
Provider: xrpcResp.Value.Provider,
}
return record, nil
@@ -414,14 +408,6 @@ func nullString(s string) sql.NullString {
return sql.NullString{String: s, Valid: true}
}
// nullStringPtr converts a *string to sql.NullString
func nullStringPtr(s *string) sql.NullString {
	// A nil pointer or an empty string both map to the SQL NULL value.
	var out sql.NullString
	if s != nil && *s != "" {
		out.String = *s
		out.Valid = true
	}
	return out
}
// getCachedApproval checks if user has a cached crew approval
func (a *RemoteHoldAuthorizer) getCachedApproval(holdDID, userDID string) (bool, error) {
query := `

View File

@@ -14,11 +14,6 @@ import (
"atcr.io/pkg/atproto"
)
// ptrString returns a pointer to the given string
func ptrString(s string) *string {
	v := s
	return &v
}
func TestNewRemoteHoldAuthorizer(t *testing.T) {
// Test with nil database (should still work)
authorizer := NewRemoteHoldAuthorizer(nil, false)
@@ -138,14 +133,14 @@ func TestGetCaptainRecord_CacheHit(t *testing.T) {
holdDID := "did:web:hold01.atcr.io"
// Pre-populate cache with a captain record
captainRecord := &atproto.HoldCaptain{
LexiconTypeID: atproto.CaptainCollection,
Owner: "did:plc:owner123",
Public: true,
AllowAllCrew: false,
DeployedAt: "2025-10-28T00:00:00Z",
Region: ptrString("us-east-1"),
Provider: ptrString("fly.io"),
captainRecord := &atproto.CaptainRecord{
Type: atproto.CaptainCollection,
Owner: "did:plc:owner123",
Public: true,
AllowAllCrew: false,
DeployedAt: "2025-10-28T00:00:00Z",
Region: "us-east-1",
Provider: "fly.io",
}
err := remote.setCachedCaptainRecord(holdDID, captainRecord)

View File

@@ -0,0 +1,80 @@
package auth
import (
"context"
"atcr.io/pkg/atproto"
)
// MockHoldAuthorizer is a test double for HoldAuthorizer.
// It allows tests to control the return values of authorization checks
// without making network calls or querying a real PDS.
type MockHoldAuthorizer struct {
	// Direct result control: each Check* method returns its corresponding
	// field unless Error is set, in which case Error is returned instead.
	CanReadResult  bool
	CanWriteResult bool
	CanAdminResult bool
	Error          error
	// Captain record to return (optional, for GetCaptainRecord)
	CaptainRecord *atproto.CaptainRecord
	// Crew membership (optional, for IsCrewMember)
	IsCrewResult bool
}
// NewMockHoldAuthorizer creates a MockHoldAuthorizer with sensible defaults.
// By default, it allows all access (public hold, user is owner).
func NewMockHoldAuthorizer() *MockHoldAuthorizer {
	return &MockHoldAuthorizer{
		CanReadResult:  true,
		CanWriteResult: true,
		CanAdminResult: false,
		IsCrewResult:   false,
		CaptainRecord: &atproto.CaptainRecord{
			// Use the shared collection constant rather than a hard-coded
			// "io.atcr.hold.captain" string, matching the real authorizers.
			Type:   atproto.CaptainCollection,
			Owner:  "did:plc:mock-owner",
			Public: true,
		},
	}
}
// CheckReadAccess returns the configured CanReadResult; a configured Error
// takes precedence and is returned with a false result.
func (m *MockHoldAuthorizer) CheckReadAccess(ctx context.Context, holdDID, userDID string) (bool, error) {
	if e := m.Error; e != nil {
		return false, e
	}
	return m.CanReadResult, nil
}
// CheckWriteAccess returns the configured CanWriteResult; a configured Error
// takes precedence and is returned with a false result.
func (m *MockHoldAuthorizer) CheckWriteAccess(ctx context.Context, holdDID, userDID string) (bool, error) {
	if e := m.Error; e != nil {
		return false, e
	}
	return m.CanWriteResult, nil
}
// GetCaptainRecord returns the configured CaptainRecord, or a permissive
// default (public hold owned by a mock DID) when none is configured.
// A configured Error takes precedence over both.
func (m *MockHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error) {
	if m.Error != nil {
		return nil, m.Error
	}
	if m.CaptainRecord != nil {
		return m.CaptainRecord, nil
	}
	// Return a default captain record. Use the shared collection constant
	// instead of a hard-coded "io.atcr.hold.captain" string, matching the
	// real authorizers.
	return &atproto.CaptainRecord{
		Type:   atproto.CaptainCollection,
		Owner:  "did:plc:mock-owner",
		Public: true,
	}, nil
}
// IsCrewMember returns the configured IsCrewResult; a configured Error takes
// precedence and is returned with a false result.
func (m *MockHoldAuthorizer) IsCrewMember(ctx context.Context, holdDID, userDID string) (bool, error) {
	if e := m.Error; e != nil {
		return false, e
	}
	return m.IsCrewResult, nil
}

View File

@@ -72,11 +72,19 @@ func RedirectURI(baseURL string) string {
return baseURL + "/auth/oauth/callback"
}
// GetDefaultScopes returns the default OAuth scopes for ATCR registry operations
// testMode determines whether to use transition:generic (test) or rpc scopes (production)
// GetDefaultScopes returns the default OAuth scopes for ATCR registry operations.
// Includes io.atcr.authFullApp permission-set plus individual scopes for PDS compatibility.
// Blob scopes are listed explicitly (not supported in Lexicon permission-sets).
func GetDefaultScopes(did string) []string {
scopes := []string{
return []string{
"atproto",
// Permission-set (for future PDS support)
// See lexicons/io/atcr/authFullApp.json for definition
// Uses "include:" prefix per ATProto permission spec
"include:io.atcr.authFullApp",
// com.atproto scopes must be separate (permission-sets are namespace-limited)
"rpc:com.atproto.repo.getRecord?aud=*",
// Blob scopes (not supported in Lexicon permission-sets)
// Image manifest types (single-arch)
"blob:application/vnd.oci.image.manifest.v1+json",
"blob:application/vnd.docker.distribution.manifest.v2+json",
@@ -85,19 +93,9 @@ func GetDefaultScopes(did string) []string {
"blob:application/vnd.docker.distribution.manifest.list.v2+json",
// OCI artifact manifests (for cosign signatures, SBOMs, attestations)
"blob:application/vnd.cncf.oras.artifact.manifest.v1+json",
// Used for service token validation on holds
"rpc:com.atproto.repo.getRecord?aud=*",
// Image avatars
"blob:image/*",
}
// Add repo scopes
scopes = append(scopes,
fmt.Sprintf("repo:%s", atproto.ManifestCollection),
fmt.Sprintf("repo:%s", atproto.TagCollection),
fmt.Sprintf("repo:%s", atproto.StarCollection),
fmt.Sprintf("repo:%s", atproto.SailorProfileCollection),
)
return scopes
}
// ScopesMatch checks if two scope lists are equivalent (order-independent)
@@ -225,6 +223,18 @@ func (r *Refresher) DoWithSession(ctx context.Context, did string, fn func(sessi
// The session's PersistSessionCallback will save nonce updates to DB
err = fn(session)
// If request failed with auth error, delete session to force re-auth
if err != nil && isAuthError(err) {
slog.Warn("Auth error detected, deleting session to force re-auth",
"component", "oauth/refresher",
"did", did,
"error", err)
// Don't hold the lock while deleting - release first
mutex.Unlock()
_ = r.DeleteSession(ctx, did)
mutex.Lock() // Re-acquire for the deferred unlock
}
slog.Debug("Released session lock for DoWithSession",
"component", "oauth/refresher",
"did", did,
@@ -233,6 +243,19 @@ func (r *Refresher) DoWithSession(ctx context.Context, did string, fn func(sessi
return err
}
// isAuthError checks if an error looks like an OAuth/auth failure
func isAuthError(err error) bool {
	if err == nil {
		return false
	}
	// Match case-insensitively against the substrings that commonly appear
	// in OAuth/authorization failure messages.
	msg := strings.ToLower(err.Error())
	markers := []string{
		"unauthorized",
		"invalid_token",
		"insufficient_scope",
		"token expired",
		"401",
	}
	for _, marker := range markers {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
// resumeSession loads a session from storage
func (r *Refresher) resumeSession(ctx context.Context, did string) (*oauth.ClientSession, error) {
// Parse DID
@@ -257,28 +280,15 @@ func (r *Refresher) resumeSession(ctx context.Context, did string) (*oauth.Clien
return nil, fmt.Errorf("no session found for DID: %s", did)
}
// Validate that session scopes match current desired scopes
// Log scope differences for debugging, but don't delete session
// The PDS will reject requests if scopes are insufficient
// (Permission-sets get expanded by PDS, so exact matching doesn't work)
desiredScopes := r.clientApp.Config.Scopes
if !ScopesMatch(sessionData.Scopes, desiredScopes) {
slog.Debug("Scope mismatch, deleting session",
slog.Debug("Session scopes differ from desired (may be permission-set expansion)",
"did", did,
"storedScopes", sessionData.Scopes,
"desiredScopes", desiredScopes)
// Delete the session from database since scopes have changed
if err := r.clientApp.Store.DeleteSession(ctx, accountDID, sessionID); err != nil {
slog.Warn("Failed to delete session with mismatched scopes", "error", err, "did", did)
}
// Also invalidate UI sessions since OAuth is now invalid
if r.uiSessionStore != nil {
r.uiSessionStore.DeleteByDID(did)
slog.Info("Invalidated UI sessions due to scope mismatch",
"component", "oauth/refresher",
"did", did)
}
return nil, fmt.Errorf("OAuth scopes changed, re-authentication required")
}
// Resume session

View File

@@ -1,18 +1,13 @@
package oauth
import (
"github.com/bluesky-social/indigo/atproto/auth/oauth"
"testing"
)
func TestNewClientApp(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
keyPath := tmpDir + "/oauth-key.bin"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
keyPath := t.TempDir() + "/oauth-key.bin"
store := oauth.NewMemStore()
baseURL := "http://localhost:5000"
scopes := GetDefaultScopes("*")
@@ -32,14 +27,8 @@ func TestNewClientApp(t *testing.T) {
}
func TestNewClientAppWithCustomScopes(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
keyPath := tmpDir + "/oauth-key.bin"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
keyPath := t.TempDir() + "/oauth-key.bin"
store := oauth.NewMemStore()
baseURL := "http://localhost:5000"
scopes := []string{"atproto", "custom:scope"}
@@ -128,13 +117,7 @@ func TestScopesMatch(t *testing.T) {
// ----------------------------------------------------------------------------
func TestNewRefresher(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
@@ -153,13 +136,7 @@ func TestNewRefresher(t *testing.T) {
}
func TestRefresher_SetUISessionStore(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")

View File

@@ -26,11 +26,7 @@ func InteractiveFlowWithCallback(
registerCallback func(handler http.HandlerFunc) error,
displayAuthURL func(string) error,
) (*InteractiveResult, error) {
// Create temporary file store for this flow
store, err := NewFileStore("/tmp/atcr-oauth-temp.json")
if err != nil {
return nil, fmt.Errorf("failed to create OAuth store: %w", err)
}
store := oauth.NewMemStore()
// Create OAuth client app with custom scopes (or defaults if nil)
// Interactive flows are typically for production use (credential helper, etc.)

View File

@@ -2,6 +2,7 @@ package oauth
import (
"context"
"github.com/bluesky-social/indigo/atproto/auth/oauth"
"net/http"
"net/http/httptest"
"strings"
@@ -11,13 +12,7 @@ import (
func TestNewServer(t *testing.T) {
// Create a basic OAuth app for testing
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
@@ -36,13 +31,7 @@ func TestNewServer(t *testing.T) {
}
func TestServer_SetRefresher(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
@@ -60,13 +49,7 @@ func TestServer_SetRefresher(t *testing.T) {
}
func TestServer_SetPostAuthCallback(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
@@ -87,13 +70,7 @@ func TestServer_SetPostAuthCallback(t *testing.T) {
}
func TestServer_SetUISessionStore(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
@@ -151,13 +128,7 @@ func (m *mockRefresher) InvalidateSession(did string) {
// ServeAuthorize tests
func TestServer_ServeAuthorize_MissingHandle(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
@@ -179,13 +150,7 @@ func TestServer_ServeAuthorize_MissingHandle(t *testing.T) {
}
func TestServer_ServeAuthorize_InvalidMethod(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
@@ -209,13 +174,7 @@ func TestServer_ServeAuthorize_InvalidMethod(t *testing.T) {
// ServeCallback tests
func TestServer_ServeCallback_InvalidMethod(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
@@ -237,13 +196,7 @@ func TestServer_ServeCallback_InvalidMethod(t *testing.T) {
}
func TestServer_ServeCallback_OAuthError(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
@@ -270,13 +223,7 @@ func TestServer_ServeCallback_OAuthError(t *testing.T) {
}
func TestServer_ServeCallback_WithPostAuthCallback(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
@@ -315,13 +262,7 @@ func TestServer_ServeCallback_UIFlow_SessionCreationLogic(t *testing.T) {
},
}
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
@@ -345,13 +286,7 @@ func TestServer_ServeCallback_UIFlow_SessionCreationLogic(t *testing.T) {
}
func TestServer_RenderError(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
@@ -380,13 +315,7 @@ func TestServer_RenderError(t *testing.T) {
}
func TestServer_RenderRedirectToSettings(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")

View File

@@ -1,236 +0,0 @@
package oauth
import (
"context"
"encoding/json"
"fmt"
"maps"
"os"
"path/filepath"
"sync"
"time"
"github.com/bluesky-social/indigo/atproto/auth/oauth"
"github.com/bluesky-social/indigo/atproto/syntax"
)
// FileStore implements oauth.ClientAuthStore with file-based persistence.
// All state is held in the two in-memory maps and flushed to a single JSON
// file on every mutation; mu guards both maps and the save-to-disk step.
type FileStore struct {
	path     string                              // location of the JSON file on disk
	sessions map[string]*oauth.ClientSessionData // Key: "did:sessionID"
	requests map[string]*oauth.AuthRequestData   // Key: state
	mu       sync.RWMutex
}

// FileStoreData represents the JSON structure stored on disk.
// It mirrors the two in-memory maps of FileStore exactly.
type FileStoreData struct {
	Sessions map[string]*oauth.ClientSessionData `json:"sessions"`
	Requests map[string]*oauth.AuthRequestData   `json:"requests"`
}
// NewFileStore creates a new file-based OAuth store backed by the JSON file
// at path. Existing data is loaded when the file is present; a missing file
// is not an error (it is created lazily on first save).
func NewFileStore(path string) (*FileStore, error) {
	s := &FileStore{
		path:     path,
		sessions: map[string]*oauth.ClientSessionData{},
		requests: map[string]*oauth.AuthRequestData{},
	}
	switch err := s.load(); {
	case err == nil:
		return s, nil
	case os.IsNotExist(err):
		// No file yet; start with the empty maps above.
		return s, nil
	default:
		return nil, fmt.Errorf("failed to load store: %w", err)
	}
}
// GetDefaultStorePath returns the default storage path for OAuth data.
// It prefers the service location (/var/lib/atcr) used by the AppView and
// falls back to ~/.atcr for CLI tools.
//
// NOTE(review): MkdirAll succeeding only proves the directory exists or was
// created — not that the store file itself will be writable; confirm this is
// acceptable for the deployment targets.
func GetDefaultStorePath() (string, error) {
	const servicePath = "/var/lib/atcr/oauth-sessions.json"
	// Service deployments can create /var/lib/atcr; prefer it when possible.
	if os.MkdirAll(filepath.Dir(servicePath), 0700) == nil {
		return servicePath, nil
	}
	// Otherwise store under the invoking user's home directory.
	home, err := os.UserHomeDir()
	if err != nil {
		return "", fmt.Errorf("failed to get home directory: %w", err)
	}
	dir := filepath.Join(home, ".atcr")
	if err := os.MkdirAll(dir, 0700); err != nil {
		return "", fmt.Errorf("failed to create .atcr directory: %w", err)
	}
	return filepath.Join(dir, "oauth-sessions.json"), nil
}
// GetSession retrieves a session by DID and session ID.
// It returns an error when no matching session is stored.
func (s *FileStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if sess, ok := s.sessions[makeSessionKey(did.String(), sessionID)]; ok {
		return sess, nil
	}
	return nil, fmt.Errorf("session not found: %s/%s", did, sessionID)
}
// SaveSession saves or updates a session (upsert) and persists the store to disk.
func (s *FileStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	// sess is a by-value copy, so storing its address does not alias the caller's.
	s.sessions[makeSessionKey(sess.AccountDID.String(), sess.SessionID)] = &sess
	return s.save()
}
// DeleteSession removes a session (no error if it is absent) and persists
// the store to disk.
func (s *FileStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.sessions, makeSessionKey(did.String(), sessionID))
	return s.save()
}
// GetAuthRequestInfo retrieves authentication request data by OAuth state.
// It returns an error when no request is stored for that state.
func (s *FileStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if req, ok := s.requests[state]; ok {
		return req, nil
	}
	return nil, fmt.Errorf("auth request not found: %s", state)
}
// SaveAuthRequestInfo saves authentication request data keyed by its state
// and persists the store to disk.
func (s *FileStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	// info is a by-value copy; storing its address does not alias the caller's.
	s.requests[info.State] = &info
	return s.save()
}
// DeleteAuthRequestInfo removes authentication request data for state
// (no error if absent) and persists the store to disk.
func (s *FileStore) DeleteAuthRequestInfo(ctx context.Context, state string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.requests, state)
	return s.save()
}
// CleanupExpired removes expired sessions and auth requests.
// Should be called periodically (e.g., every hour).
//
// NOTE: oauth.AuthRequestData and oauth.ClientSessionData carry no expiry
// timestamps, so there is currently nothing that can be expired; this is a
// deliberate no-op kept so callers can schedule it today and gain real
// cleanup later if the stored types grow timestamps. The previous version
// carried an always-false `modified` flag, a placeholder loop, and an
// unused `now` — all dead code removed here without changing behavior.
func (s *FileStore) CleanupExpired() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	// No mutation happens, so no save() is needed.
	return nil
}
// ListSessions returns a snapshot of all stored sessions for
// debugging/management. The returned map is a copy, so mutating it does
// not affect the store (the pointed-to session values are shared).
func (s *FileStore) ListSessions() map[string]*oauth.ClientSessionData {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return maps.Clone(s.sessions)
}
// load reads the JSON store file from disk into the in-memory maps.
// The raw os.ReadFile error is returned unwrapped so the caller can detect
// a missing file with os.IsNotExist.
func (s *FileStore) load() error {
	raw, err := os.ReadFile(s.path)
	if err != nil {
		return err
	}
	var decoded FileStoreData
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return fmt.Errorf("failed to parse store: %w", err)
	}
	// Keep the already-initialized empty maps when the file omits a section.
	if decoded.Sessions != nil {
		s.sessions = decoded.Sessions
	}
	if decoded.Requests != nil {
		s.requests = decoded.Requests
	}
	return nil
}
// save serializes the in-memory maps to disk as indented JSON.
// The caller must hold s.mu. The directory is created mode 0700 and the
// file written mode 0600 because it holds OAuth credentials.
func (s *FileStore) save() error {
	blob, err := json.MarshalIndent(FileStoreData{
		Sessions: s.sessions,
		Requests: s.requests,
	}, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal store: %w", err)
	}
	if err := os.MkdirAll(filepath.Dir(s.path), 0700); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}
	if err := os.WriteFile(s.path, blob, 0600); err != nil {
		return fmt.Errorf("failed to write store: %w", err)
	}
	return nil
}
// makeSessionKey creates the composite "did:sessionID" key used for
// session storage.
func makeSessionKey(did, sessionID string) string {
	return did + ":" + sessionID
}

View File

@@ -1,631 +0,0 @@
package oauth
import (
"context"
"encoding/json"
"os"
"testing"
"time"
"github.com/bluesky-social/indigo/atproto/auth/oauth"
"github.com/bluesky-social/indigo/atproto/syntax"
)
// TestNewFileStore verifies the constructor sets the path and initializes
// both internal maps.
func TestNewFileStore(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	if store == nil {
		t.Fatal("Expected non-nil store")
	}
	if store.path != storePath {
		t.Errorf("Expected path %q, got %q", storePath, store.path)
	}
	if store.sessions == nil {
		t.Error("Expected sessions map to be initialized")
	}
	if store.requests == nil {
		t.Error("Expected requests map to be initialized")
	}
}

// TestFileStore_LoadNonExistent verifies that a missing backing file is not
// an error (the store starts empty).
func TestFileStore_LoadNonExistent(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/nonexistent.json"
	// Should succeed even if file doesn't exist
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() should succeed with non-existent file, got error: %v", err)
	}
	if store == nil {
		t.Fatal("Expected non-nil store")
	}
}

// TestFileStore_LoadCorruptedFile verifies that unparseable JSON on disk
// makes construction fail rather than silently discarding data.
func TestFileStore_LoadCorruptedFile(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/corrupted.json"
	// Create corrupted JSON file
	if err := os.WriteFile(storePath, []byte("invalid json {{{"), 0600); err != nil {
		t.Fatalf("Failed to create corrupted file: %v", err)
	}
	// Should fail to load corrupted file
	_, err := NewFileStore(storePath)
	if err == nil {
		t.Error("Expected error when loading corrupted file")
	}
}

// TestFileStore_GetSession_NotFound verifies lookups of unknown sessions
// return an error and a nil session.
func TestFileStore_GetSession_NotFound(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	ctx := context.Background()
	did, _ := syntax.ParseDID("did:plc:test123")
	sessionID := "session123"
	// Should return error for non-existent session
	session, err := store.GetSession(ctx, did, sessionID)
	if err == nil {
		t.Error("Expected error for non-existent session")
	}
	if session != nil {
		t.Error("Expected nil session for non-existent entry")
	}
}
// TestFileStore_SaveAndGetSession round-trips a session through
// SaveSession/GetSession and checks the stored fields survive intact.
func TestFileStore_SaveAndGetSession(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	ctx := context.Background()
	did, _ := syntax.ParseDID("did:plc:alice123")
	// Create test session
	sessionData := oauth.ClientSessionData{
		AccountDID: did,
		SessionID:  "test-session-123",
		HostURL:    "https://pds.example.com",
		Scopes:     []string{"atproto", "blob:read"},
	}
	// Save session
	if err := store.SaveSession(ctx, sessionData); err != nil {
		t.Fatalf("SaveSession() error = %v", err)
	}
	// Retrieve session
	retrieved, err := store.GetSession(ctx, did, "test-session-123")
	if err != nil {
		t.Fatalf("GetSession() error = %v", err)
	}
	if retrieved == nil {
		t.Fatal("Expected non-nil session")
	}
	if retrieved.SessionID != sessionData.SessionID {
		t.Errorf("Expected sessionID %q, got %q", sessionData.SessionID, retrieved.SessionID)
	}
	if retrieved.AccountDID.String() != did.String() {
		t.Errorf("Expected DID %q, got %q", did.String(), retrieved.AccountDID.String())
	}
	if retrieved.HostURL != sessionData.HostURL {
		t.Errorf("Expected hostURL %q, got %q", sessionData.HostURL, retrieved.HostURL)
	}
}

// TestFileStore_UpdateSession verifies SaveSession is an upsert: saving
// again under the same key overwrites the previous session.
func TestFileStore_UpdateSession(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	ctx := context.Background()
	did, _ := syntax.ParseDID("did:plc:alice123")
	// Save initial session
	sessionData := oauth.ClientSessionData{
		AccountDID: did,
		SessionID:  "test-session-123",
		HostURL:    "https://pds.example.com",
		Scopes:     []string{"atproto"},
	}
	if err := store.SaveSession(ctx, sessionData); err != nil {
		t.Fatalf("SaveSession() error = %v", err)
	}
	// Update session with new scopes
	sessionData.Scopes = []string{"atproto", "blob:read", "blob:write"}
	if err := store.SaveSession(ctx, sessionData); err != nil {
		t.Fatalf("SaveSession() (update) error = %v", err)
	}
	// Retrieve updated session
	retrieved, err := store.GetSession(ctx, did, "test-session-123")
	if err != nil {
		t.Fatalf("GetSession() error = %v", err)
	}
	if len(retrieved.Scopes) != 3 {
		t.Errorf("Expected 3 scopes, got %d", len(retrieved.Scopes))
	}
}

// TestFileStore_DeleteSession verifies a saved session is gone after
// DeleteSession.
func TestFileStore_DeleteSession(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	ctx := context.Background()
	did, _ := syntax.ParseDID("did:plc:alice123")
	// Save session
	sessionData := oauth.ClientSessionData{
		AccountDID: did,
		SessionID:  "test-session-123",
		HostURL:    "https://pds.example.com",
	}
	if err := store.SaveSession(ctx, sessionData); err != nil {
		t.Fatalf("SaveSession() error = %v", err)
	}
	// Verify it exists
	if _, err := store.GetSession(ctx, did, "test-session-123"); err != nil {
		t.Fatalf("GetSession() should succeed before delete, got error: %v", err)
	}
	// Delete session
	if err := store.DeleteSession(ctx, did, "test-session-123"); err != nil {
		t.Fatalf("DeleteSession() error = %v", err)
	}
	// Verify it's gone
	_, err = store.GetSession(ctx, did, "test-session-123")
	if err == nil {
		t.Error("Expected error after deleting session")
	}
}

// TestFileStore_DeleteNonExistentSession verifies deleting an unknown
// session is a silent no-op, not an error.
func TestFileStore_DeleteNonExistentSession(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	ctx := context.Background()
	did, _ := syntax.ParseDID("did:plc:alice123")
	// Delete non-existent session should not error
	if err := store.DeleteSession(ctx, did, "nonexistent"); err != nil {
		t.Errorf("DeleteSession() on non-existent session should not error, got: %v", err)
	}
}
// TestFileStore_SaveAndGetAuthRequestInfo round-trips an auth request
// through SaveAuthRequestInfo/GetAuthRequestInfo keyed by state.
func TestFileStore_SaveAndGetAuthRequestInfo(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	ctx := context.Background()
	// Create test auth request
	did, _ := syntax.ParseDID("did:plc:alice123")
	authRequest := oauth.AuthRequestData{
		State:                   "test-state-123",
		AuthServerURL:           "https://pds.example.com",
		AccountDID:              &did,
		Scopes:                  []string{"atproto", "blob:read"},
		RequestURI:              "urn:ietf:params:oauth:request_uri:test123",
		AuthServerTokenEndpoint: "https://pds.example.com/oauth/token",
	}
	// Save auth request
	if err := store.SaveAuthRequestInfo(ctx, authRequest); err != nil {
		t.Fatalf("SaveAuthRequestInfo() error = %v", err)
	}
	// Retrieve auth request
	retrieved, err := store.GetAuthRequestInfo(ctx, "test-state-123")
	if err != nil {
		t.Fatalf("GetAuthRequestInfo() error = %v", err)
	}
	if retrieved == nil {
		t.Fatal("Expected non-nil auth request")
	}
	if retrieved.State != authRequest.State {
		t.Errorf("Expected state %q, got %q", authRequest.State, retrieved.State)
	}
	if retrieved.AuthServerURL != authRequest.AuthServerURL {
		t.Errorf("Expected authServerURL %q, got %q", authRequest.AuthServerURL, retrieved.AuthServerURL)
	}
}

// TestFileStore_GetAuthRequestInfo_NotFound verifies lookups of unknown
// states return an error.
func TestFileStore_GetAuthRequestInfo_NotFound(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	ctx := context.Background()
	// Should return error for non-existent request
	_, err = store.GetAuthRequestInfo(ctx, "nonexistent-state")
	if err == nil {
		t.Error("Expected error for non-existent auth request")
	}
}

// TestFileStore_DeleteAuthRequestInfo verifies a saved auth request is
// gone after DeleteAuthRequestInfo.
func TestFileStore_DeleteAuthRequestInfo(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	ctx := context.Background()
	// Save auth request
	authRequest := oauth.AuthRequestData{
		State:         "test-state-123",
		AuthServerURL: "https://pds.example.com",
	}
	if err := store.SaveAuthRequestInfo(ctx, authRequest); err != nil {
		t.Fatalf("SaveAuthRequestInfo() error = %v", err)
	}
	// Verify it exists
	if _, err := store.GetAuthRequestInfo(ctx, "test-state-123"); err != nil {
		t.Fatalf("GetAuthRequestInfo() should succeed before delete, got error: %v", err)
	}
	// Delete auth request
	if err := store.DeleteAuthRequestInfo(ctx, "test-state-123"); err != nil {
		t.Fatalf("DeleteAuthRequestInfo() error = %v", err)
	}
	// Verify it's gone
	_, err = store.GetAuthRequestInfo(ctx, "test-state-123")
	if err == nil {
		t.Error("Expected error after deleting auth request")
	}
}
// TestFileStore_ListSessions verifies ListSessions reflects the stored
// sessions and keys them by the composite "did:sessionID" key.
func TestFileStore_ListSessions(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	ctx := context.Background()
	// Initially empty
	sessions := store.ListSessions()
	if len(sessions) != 0 {
		t.Errorf("Expected 0 sessions, got %d", len(sessions))
	}
	// Add multiple sessions
	did1, _ := syntax.ParseDID("did:plc:alice123")
	did2, _ := syntax.ParseDID("did:plc:bob456")
	session1 := oauth.ClientSessionData{
		AccountDID: did1,
		SessionID:  "session-1",
		HostURL:    "https://pds1.example.com",
	}
	session2 := oauth.ClientSessionData{
		AccountDID: did2,
		SessionID:  "session-2",
		HostURL:    "https://pds2.example.com",
	}
	if err := store.SaveSession(ctx, session1); err != nil {
		t.Fatalf("SaveSession() error = %v", err)
	}
	if err := store.SaveSession(ctx, session2); err != nil {
		t.Fatalf("SaveSession() error = %v", err)
	}
	// List sessions
	sessions = store.ListSessions()
	if len(sessions) != 2 {
		t.Errorf("Expected 2 sessions, got %d", len(sessions))
	}
	// Verify we got both sessions
	key1 := makeSessionKey(did1.String(), "session-1")
	key2 := makeSessionKey(did2.String(), "session-2")
	if sessions[key1] == nil {
		t.Error("Expected session1 in list")
	}
	if sessions[key2] == nil {
		t.Error("Expected session2 in list")
	}
}

// TestFileStore_Persistence_Across_Instances verifies data written by one
// FileStore instance is visible to a second instance opened on the same file.
func TestFileStore_Persistence_Across_Instances(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	ctx := context.Background()
	did, _ := syntax.ParseDID("did:plc:alice123")
	// Create first store and save data
	store1, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	sessionData := oauth.ClientSessionData{
		AccountDID: did,
		SessionID:  "persistent-session",
		HostURL:    "https://pds.example.com",
	}
	if err := store1.SaveSession(ctx, sessionData); err != nil {
		t.Fatalf("SaveSession() error = %v", err)
	}
	authRequest := oauth.AuthRequestData{
		State:         "persistent-state",
		AuthServerURL: "https://pds.example.com",
	}
	if err := store1.SaveAuthRequestInfo(ctx, authRequest); err != nil {
		t.Fatalf("SaveAuthRequestInfo() error = %v", err)
	}
	// Create second store from same file
	store2, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("Second NewFileStore() error = %v", err)
	}
	// Verify session persisted
	retrievedSession, err := store2.GetSession(ctx, did, "persistent-session")
	if err != nil {
		t.Fatalf("GetSession() from second store error = %v", err)
	}
	if retrievedSession.SessionID != "persistent-session" {
		t.Errorf("Expected persistent session ID, got %q", retrievedSession.SessionID)
	}
	// Verify auth request persisted
	retrievedAuth, err := store2.GetAuthRequestInfo(ctx, "persistent-state")
	if err != nil {
		t.Fatalf("GetAuthRequestInfo() from second store error = %v", err)
	}
	if retrievedAuth.State != "persistent-state" {
		t.Errorf("Expected persistent state, got %q", retrievedAuth.State)
	}
}
// TestFileStore_FileSecurity verifies the store file is written with
// restrictive 0600 permissions (it holds OAuth credentials).
func TestFileStore_FileSecurity(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	ctx := context.Background()
	did, _ := syntax.ParseDID("did:plc:alice123")
	// Save some data to trigger file creation
	sessionData := oauth.ClientSessionData{
		AccountDID: did,
		SessionID:  "test-session",
		HostURL:    "https://pds.example.com",
	}
	if err := store.SaveSession(ctx, sessionData); err != nil {
		t.Fatalf("SaveSession() error = %v", err)
	}
	// Check file permissions (should be 0600)
	info, err := os.Stat(storePath)
	if err != nil {
		t.Fatalf("Failed to stat file: %v", err)
	}
	mode := info.Mode()
	if mode.Perm() != 0600 {
		t.Errorf("Expected file permissions 0600, got %o", mode.Perm())
	}
}

// TestFileStore_JSONFormat verifies the on-disk file parses as
// FileStoreData with both top-level sections present.
func TestFileStore_JSONFormat(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	ctx := context.Background()
	did, _ := syntax.ParseDID("did:plc:alice123")
	// Save data
	sessionData := oauth.ClientSessionData{
		AccountDID: did,
		SessionID:  "test-session",
		HostURL:    "https://pds.example.com",
	}
	if err := store.SaveSession(ctx, sessionData); err != nil {
		t.Fatalf("SaveSession() error = %v", err)
	}
	// Read and verify JSON format
	data, err := os.ReadFile(storePath)
	if err != nil {
		t.Fatalf("Failed to read file: %v", err)
	}
	var storeData FileStoreData
	if err := json.Unmarshal(data, &storeData); err != nil {
		t.Fatalf("Failed to parse JSON: %v", err)
	}
	if storeData.Sessions == nil {
		t.Error("Expected sessions in JSON")
	}
	if storeData.Requests == nil {
		t.Error("Expected requests in JSON")
	}
}

// TestFileStore_CleanupExpired verifies the (currently no-op) cleanup
// method runs without error on an empty store.
func TestFileStore_CleanupExpired(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	// CleanupExpired should not error even with no data
	if err := store.CleanupExpired(); err != nil {
		t.Errorf("CleanupExpired() error = %v", err)
	}
	// Note: Current implementation doesn't actually clean anything
	// since AuthRequestData and ClientSessionData don't have expiry timestamps
	// This test verifies the method doesn't panic
}
// TestGetDefaultStorePath verifies a non-empty default path is produced;
// the exact location depends on filesystem permissions, so only log it.
func TestGetDefaultStorePath(t *testing.T) {
	path, err := GetDefaultStorePath()
	if err != nil {
		t.Fatalf("GetDefaultStorePath() error = %v", err)
	}
	if path == "" {
		t.Fatal("Expected non-empty path")
	}
	// Path should either be /var/lib/atcr or ~/.atcr
	// We can't assert exact path since it depends on permissions
	t.Logf("Default store path: %s", path)
}

// TestMakeSessionKey pins the composite key format "did:sessionID".
func TestMakeSessionKey(t *testing.T) {
	did := "did:plc:alice123"
	sessionID := "session-456"
	key := makeSessionKey(did, sessionID)
	expected := "did:plc:alice123:session-456"
	if key != expected {
		t.Errorf("Expected key %q, got %q", expected, key)
	}
}

// TestFileStore_ConcurrentAccess runs a concurrent reader and writer to
// smoke-test the store's locking (run with -race for the real signal).
func TestFileStore_ConcurrentAccess(t *testing.T) {
	tmpDir := t.TempDir()
	storePath := tmpDir + "/oauth-test.json"
	store, err := NewFileStore(storePath)
	if err != nil {
		t.Fatalf("NewFileStore() error = %v", err)
	}
	ctx := context.Background()
	// Run concurrent operations
	done := make(chan bool)
	// Writer goroutine
	go func() {
		for i := 0; i < 10; i++ {
			did, _ := syntax.ParseDID("did:plc:alice123")
			sessionData := oauth.ClientSessionData{
				AccountDID: did,
				SessionID:  "session-1",
				HostURL:    "https://pds.example.com",
			}
			store.SaveSession(ctx, sessionData)
			time.Sleep(1 * time.Millisecond)
		}
		done <- true
	}()
	// Reader goroutine
	go func() {
		for i := 0; i < 10; i++ {
			did, _ := syntax.ParseDID("did:plc:alice123")
			store.GetSession(ctx, did, "session-1")
			time.Sleep(1 * time.Millisecond)
		}
		done <- true
	}()
	// Wait for both goroutines
	<-done
	<-done
	// If we got here without panicking, the locking works
	t.Log("Concurrent access test passed")
}

300
pkg/auth/servicetoken.go Normal file
View File

@@ -0,0 +1,300 @@
package auth
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strings"
"time"
"atcr.io/pkg/atproto"
"atcr.io/pkg/auth/oauth"
"github.com/bluesky-social/indigo/atproto/atclient"
indigo_oauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
)
// getErrorHint provides context-specific troubleshooting hints based on the
// structured API error returned by the PDS. The hint is advisory text for
// operators reading logs; it never affects control flow.
func getErrorHint(apiErr *atclient.APIError) string {
	switch apiErr.Name {
	case "use_dpop_nonce":
		return "DPoP nonce mismatch - indigo library should automatically retry with new nonce. If this persists, check for concurrent request issues or PDS session corruption."
	case "invalid_client":
		// Exact-match on the PDS's clock-skew message. (The previous
		// `Message != ""` guard was redundant: equality with a non-empty
		// literal already implies non-empty.)
		if apiErr.Message == "Validation of \"client_assertion\" failed: \"iat\" claim timestamp check failed (it should be in the past)" {
			return "JWT timestamp validation failed - system clock on AppView may be ahead of PDS clock. Check NTP sync with: timedatectl status"
		}
		return "OAuth client authentication failed - check client key configuration and PDS OAuth server status"
	case "invalid_token", "invalid_grant":
		return "OAuth tokens expired or invalidated - user will need to re-authenticate via OAuth flow"
	case "server_error":
		if apiErr.StatusCode == 500 {
			return "PDS returned internal server error - this may occur after repeated DPoP nonce failures or other PDS-side issues. Check PDS logs for root cause."
		}
		return "PDS server error - check PDS health and logs"
	case "invalid_dpop_proof":
		return "DPoP proof validation failed - check system clock sync and DPoP key configuration"
	default:
		if apiErr.StatusCode == 401 || apiErr.StatusCode == 403 {
			return "Authentication/authorization failed - OAuth session may be expired or revoked"
		}
		return "PDS rejected the request - see errorName and errorMessage for details"
	}
}
// ParseJWTExpiry extracts the expiry time from a JWT without verifying the signature
// We trust tokens from the user's PDS, so signature verification isn't needed here
// Manually decodes the JWT payload to avoid algorithm compatibility issues
func ParseJWTExpiry(tokenString string) (time.Time, error) {
// JWT format: header.payload.signature
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return time.Time{}, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
}
// Decode the payload (second part)
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return time.Time{}, fmt.Errorf("failed to decode JWT payload: %w", err)
}
// Parse the JSON payload
var claims struct {
Exp int64 `json:"exp"`
}
if err := json.Unmarshal(payload, &claims); err != nil {
return time.Time{}, fmt.Errorf("failed to parse JWT claims: %w", err)
}
if claims.Exp == 0 {
return time.Time{}, fmt.Errorf("JWT missing exp claim")
}
return time.Unix(claims.Exp, 0), nil
}
// buildServiceAuthURL constructs the com.atproto.server.getServiceAuth URL
// used to mint a short-lived inter-service token addressed to holdDID.
func buildServiceAuthURL(pdsEndpoint, holdDID string) string {
	// Request a 5-minute token (the PDS may grant less). The exp query
	// parameter must be an absolute Unix timestamp, not a duration.
	exp := time.Now().Add(5 * time.Minute).Unix()
	return fmt.Sprintf("%s%s?aud=%s&lxm=%s&exp=%d",
		pdsEndpoint,
		atproto.ServerGetServiceAuth,
		url.QueryEscape(holdDID),
		url.QueryEscape("com.atproto.repo.getRecord"),
		exp,
	)
}
// parseServiceTokenResponse extracts the token from a service auth response
func parseServiceTokenResponse(resp *http.Response) (string, error) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("service auth failed with status %d: %s", resp.StatusCode, string(bodyBytes))
}
var result struct {
Token string `json:"token"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to decode service auth response: %w", err)
}
if result.Token == "" {
return "", fmt.Errorf("empty token in service auth response")
}
return result.Token, nil
}
// GetOrFetchServiceToken gets a service token for hold authentication.
// Handles both OAuth/DPoP and app-password authentication based on authMethod.
// Checks cache first, then fetches from PDS if needed.
//
// For OAuth: Uses DoWithSession() to hold a per-DID lock through the entire PDS interaction.
// This prevents DPoP nonce race conditions when multiple Docker layers upload concurrently.
//
// For app-password: Uses Bearer token authentication without locking (no DPoP complexity).
//
// On any fetch error the cached service token for (did, holdDID) is invalidated
// and the error is returned unwrapped after structured logging.
func GetOrFetchServiceToken(
	ctx context.Context,
	authMethod string,
	refresher *oauth.Refresher, // Required for OAuth, nil for app-password
	did, holdDID, pdsEndpoint string,
) (string, error) {
	// Check cache first to avoid unnecessary PDS calls on every request
	cachedToken, expiresAt := GetServiceToken(did, holdDID)
	// Use cached token if it exists and has > 10s remaining
	if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
		slog.Debug("Using cached service token",
			"did", did,
			"authMethod", authMethod,
			"expiresIn", time.Until(expiresAt).Round(time.Second))
		return cachedToken, nil
	}
	// Cache miss or expiring soon - fetch new service token
	if cachedToken == "" {
		slog.Debug("Service token cache miss, fetching new token", "did", did, "authMethod", authMethod)
	} else {
		slog.Debug("Service token expiring soon, proactively renewing", "did", did, "authMethod", authMethod)
	}
	var serviceToken string
	var err error
	// Branch based on auth method
	// NOTE(review): any authMethod other than AuthMethodOAuth (including "")
	// falls through to the app-password path - confirm that is intended.
	if authMethod == AuthMethodOAuth {
		serviceToken, err = doOAuthFetch(ctx, refresher, did, holdDID, pdsEndpoint)
		// OAuth-specific cleanup: delete stale session on error
		// (forces the user back through the OAuth flow on their next request)
		if err != nil && refresher != nil {
			if delErr := refresher.DeleteSession(ctx, did); delErr != nil {
				slog.Warn("Failed to delete stale OAuth session",
					"component", "auth/servicetoken",
					"did", did,
					"error", delErr)
			}
		}
	} else {
		serviceToken, err = doAppPasswordFetch(ctx, did, holdDID, pdsEndpoint)
	}
	// Unified error handling
	if err != nil {
		InvalidateServiceToken(did, holdDID)
		// Prefer indigo's structured APIError details when available
		var apiErr *atclient.APIError
		if errors.As(err, &apiErr) {
			slog.Error("Service token request failed",
				"component", "auth/servicetoken",
				"authMethod", authMethod,
				"did", did,
				"holdDID", holdDID,
				"pdsEndpoint", pdsEndpoint,
				"error", err,
				"httpStatus", apiErr.StatusCode,
				"errorName", apiErr.Name,
				"errorMessage", apiErr.Message,
				"hint", getErrorHint(apiErr))
		} else {
			// Non-API failure (network error, context cancellation, etc.)
			slog.Error("Service token request failed",
				"component", "auth/servicetoken",
				"authMethod", authMethod,
				"did", did,
				"holdDID", holdDID,
				"pdsEndpoint", pdsEndpoint,
				"error", err)
		}
		return "", err
	}
	// Cache the token (parses JWT to extract actual expiry)
	// Non-fatal on failure - we still have the token, it just won't be cached
	if cacheErr := SetServiceToken(did, holdDID, serviceToken); cacheErr != nil {
		slog.Warn("Failed to cache service token", "error", cacheErr, "did", did, "holdDID", holdDID)
	}
	slog.Debug("Service token obtained", "did", did, "authMethod", authMethod)
	return serviceToken, nil
}
// doOAuthFetch fetches a service token using OAuth/DPoP authentication.
// Uses DoWithSession() for per-DID locking to prevent DPoP nonce races.
// Returns (token, error) without logging - caller handles error logging.
func doOAuthFetch(
	ctx context.Context,
	refresher *oauth.Refresher,
	did, holdDID, pdsEndpoint string,
) (string, error) {
	if refresher == nil {
		return "", fmt.Errorf("refresher is nil (OAuth session required)")
	}
	// fetchErr distinguishes callback failures from session-load failures below
	var serviceToken string
	var fetchErr error
	err := refresher.DoWithSession(ctx, did, func(session *indigo_oauth.ClientSession) error {
		// Double-check cache after acquiring lock (double-checked locking pattern):
		// another goroutine may have populated it while we waited
		cachedToken, expiresAt := GetServiceToken(did, holdDID)
		if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
			slog.Debug("Service token cache hit after lock acquisition",
				"did", did,
				"expiresIn", time.Until(expiresAt).Round(time.Second))
			serviceToken = cachedToken
			return nil
		}
		serviceAuthURL := buildServiceAuthURL(pdsEndpoint, holdDID)
		req, err := http.NewRequestWithContext(ctx, "GET", serviceAuthURL, nil)
		if err != nil {
			fetchErr = fmt.Errorf("failed to create request: %w", err)
			return fetchErr
		}
		// DPoP-authenticated request; the per-DID lock serializes nonce negotiation
		resp, err := session.DoWithAuth(session.Client, req, "com.atproto.server.getServiceAuth")
		if err != nil {
			fetchErr = fmt.Errorf("OAuth request failed: %w", err)
			return fetchErr
		}
		// parseServiceTokenResponse closes resp.Body
		token, parseErr := parseServiceTokenResponse(resp)
		if parseErr != nil {
			fetchErr = parseErr
			return fetchErr
		}
		serviceToken = token
		return nil
	})
	if err != nil {
		// Prefer the more specific error captured inside the callback
		if fetchErr != nil {
			return "", fetchErr
		}
		return "", fmt.Errorf("failed to get OAuth session: %w", err)
	}
	return serviceToken, nil
}
// doAppPasswordFetch fetches a service token using Bearer token authentication.
// Returns (token, error) without logging - caller handles error logging.
func doAppPasswordFetch(
	ctx context.Context,
	did, holdDID, pdsEndpoint string,
) (string, error) {
	bearer, ok := GetGlobalTokenCache().Get(did)
	if !ok {
		return "", fmt.Errorf("no app-password access token available for DID %s", did)
	}
	req, err := http.NewRequestWithContext(ctx, "GET", buildServiceAuthURL(pdsEndpoint, holdDID), nil)
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+bearer)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("request failed: %w", err)
	}
	if resp.StatusCode == http.StatusUnauthorized {
		resp.Body.Close()
		// The stored app-password token no longer works; drop it so the
		// next docker login repopulates the cache.
		GetGlobalTokenCache().Delete(did)
		return "", fmt.Errorf("app-password authentication failed: token expired or invalid")
	}
	// parseServiceTokenResponse closes resp.Body on all remaining paths
	return parseServiceTokenResponse(resp)
}

View File

@@ -0,0 +1,27 @@
package auth
import (
"context"
"testing"
)
// TestGetOrFetchServiceToken_NilRefresher verifies that the OAuth path fails
// fast with a clear error when no refresher is supplied.
func TestGetOrFetchServiceToken_NilRefresher(t *testing.T) {
	ctx := context.Background()
	did := "did:plc:test123"
	holdDID := "did:web:hold.example.com"
	pdsEndpoint := "https://pds.example.com"
	// Test with nil refresher and OAuth auth method - should return error
	_, err := GetOrFetchServiceToken(ctx, AuthMethodOAuth, nil, did, holdDID, pdsEndpoint)
	if err == nil {
		// Fatal (not Error): t.Error does not stop the test, so the
		// err.Error() call below would panic on a nil error.
		t.Fatal("Expected error when refresher is nil for OAuth")
	}
	expectedErrMsg := "refresher is nil (OAuth session required)"
	if err.Error() != expectedErrMsg {
		t.Errorf("Expected error message %q, got %q", expectedErrMsg, err.Error())
	}
}
// Note: Full tests with mocked OAuth refresher and HTTP client will be added
// in the comprehensive test implementation phase

View File

@@ -56,3 +56,22 @@ func ExtractAuthMethod(tokenString string) string {
return claims.AuthMethod
}
// ExtractSubject parses a JWT token string and extracts the Subject claim (the user's DID).
// Returns the subject or empty string if not found or token is invalid.
// This does NOT validate the token - it only parses it to extract the claim
// (validation is done by the distribution library).
func ExtractSubject(tokenString string) string {
	parser := jwt.NewParser(jwt.WithoutClaimsValidation())
	parsed, _, err := parser.ParseUnverified(tokenString, &Claims{})
	if err != nil {
		// Malformed token - no subject to report
		return ""
	}
	if claims, ok := parsed.Claims.(*Claims); ok {
		return claims.Subject
	}
	// Unexpected claims type
	return ""
}

View File

@@ -1,362 +0,0 @@
package token
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"time"
"atcr.io/pkg/atproto"
"atcr.io/pkg/auth"
"atcr.io/pkg/auth/oauth"
"github.com/bluesky-social/indigo/atproto/atclient"
indigo_oauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
)
// getErrorHint provides context-specific troubleshooting hints based on API error type
func getErrorHint(apiErr *atclient.APIError) string {
	switch apiErr.Name {
	case "use_dpop_nonce":
		// DPoP servers rotate nonces; the client library normally retries transparently.
		return "DPoP nonce mismatch - indigo library should automatically retry with new nonce. If this persists, check for concurrent request issues or PDS session corruption."
	case "invalid_client":
		// Match the exact PDS message that indicates clock skew rather than bad credentials
		if apiErr.Message != "" && apiErr.Message == "Validation of \"client_assertion\" failed: \"iat\" claim timestamp check failed (it should be in the past)" {
			return "JWT timestamp validation failed - system clock on AppView may be ahead of PDS clock. Check NTP sync with: timedatectl status"
		}
		return "OAuth client authentication failed - check client key configuration and PDS OAuth server status"
	case "invalid_token", "invalid_grant":
		return "OAuth tokens expired or invalidated - user will need to re-authenticate via OAuth flow"
	case "server_error":
		if apiErr.StatusCode == 500 {
			return "PDS returned internal server error - this may occur after repeated DPoP nonce failures or other PDS-side issues. Check PDS logs for root cause."
		}
		return "PDS server error - check PDS health and logs"
	case "invalid_dpop_proof":
		return "DPoP proof validation failed - check system clock sync and DPoP key configuration"
	default:
		// No specific mapping: fall back to the HTTP status class
		if apiErr.StatusCode == 401 || apiErr.StatusCode == 403 {
			return "Authentication/authorization failed - OAuth session may be expired or revoked"
		}
		return "PDS rejected the request - see errorName and errorMessage for details"
	}
}
// GetOrFetchServiceToken gets a service token for hold authentication.
// Checks cache first, then fetches from PDS with OAuth/DPoP if needed.
// This is the canonical implementation used by both middleware and crew registration.
//
// IMPORTANT: Uses DoWithSession() to hold a per-DID lock through the entire PDS interaction.
// This prevents DPoP nonce race conditions when multiple Docker layers upload concurrently.
func GetOrFetchServiceToken(
	ctx context.Context,
	refresher *oauth.Refresher,
	did, holdDID, pdsEndpoint string,
) (string, error) {
	if refresher == nil {
		return "", fmt.Errorf("refresher is nil (OAuth session required for service tokens)")
	}
	// Check cache first to avoid unnecessary PDS calls on every request
	cachedToken, expiresAt := GetServiceToken(did, holdDID)
	// Use cached token if it exists and has > 10s remaining
	if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
		slog.Debug("Using cached service token",
			"did", did,
			"expiresIn", time.Until(expiresAt).Round(time.Second))
		return cachedToken, nil
	}
	// Cache miss or expiring soon - validate OAuth and get new service token
	if cachedToken == "" {
		slog.Debug("Service token cache miss, fetching new token", "did", did)
	} else {
		slog.Debug("Service token expiring soon, proactively renewing", "did", did)
	}
	// Use DoWithSession to hold the lock through the entire PDS interaction.
	// This prevents DPoP nonce races when multiple goroutines try to fetch service tokens.
	var serviceToken string
	var fetchErr error
	err := refresher.DoWithSession(ctx, did, func(session *indigo_oauth.ClientSession) error {
		// Double-check cache after acquiring lock - another goroutine may have
		// populated it while we were waiting (classic double-checked locking pattern)
		cachedToken, expiresAt := GetServiceToken(did, holdDID)
		if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
			slog.Debug("Service token cache hit after lock acquisition",
				"did", did,
				"expiresIn", time.Until(expiresAt).Round(time.Second))
			serviceToken = cachedToken
			return nil
		}
		// Cache still empty/expired - proceed with PDS call
		// Request 5-minute expiry (PDS may grant less)
		// exp must be absolute Unix timestamp, not relative duration
		// Note: OAuth scope includes #atcr_hold fragment, but service auth aud must be bare DID
		expiryTime := time.Now().Unix() + 300 // 5 minutes from now
		serviceAuthURL := fmt.Sprintf("%s%s?aud=%s&lxm=%s&exp=%d",
			pdsEndpoint,
			atproto.ServerGetServiceAuth,
			url.QueryEscape(holdDID),
			url.QueryEscape("com.atproto.repo.getRecord"),
			expiryTime,
		)
		req, err := http.NewRequestWithContext(ctx, "GET", serviceAuthURL, nil)
		if err != nil {
			fetchErr = fmt.Errorf("failed to create service auth request: %w", err)
			return fetchErr
		}
		// Use OAuth session to authenticate to PDS (with DPoP)
		// The lock is held, so DPoP nonce negotiation is serialized per-DID
		resp, err := session.DoWithAuth(session.Client, req, "com.atproto.server.getServiceAuth")
		if err != nil {
			// Auth error - may indicate expired tokens or corrupted session
			InvalidateServiceToken(did, holdDID)
			// Inspect the error to extract detailed information from indigo's APIError
			var apiErr *atclient.APIError
			if errors.As(err, &apiErr) {
				// Log detailed API error information
				slog.Error("OAuth authentication failed during service token request",
					"component", "token/servicetoken",
					"did", did,
					"holdDID", holdDID,
					"pdsEndpoint", pdsEndpoint,
					"url", serviceAuthURL,
					"error", err,
					"httpStatus", apiErr.StatusCode,
					"errorName", apiErr.Name,
					"errorMessage", apiErr.Message,
					"hint", getErrorHint(apiErr))
			} else {
				// Fallback for non-API errors (network errors, etc.)
				slog.Error("OAuth authentication failed during service token request",
					"component", "token/servicetoken",
					"did", did,
					"holdDID", holdDID,
					"pdsEndpoint", pdsEndpoint,
					"url", serviceAuthURL,
					"error", err,
					"errorType", fmt.Sprintf("%T", err),
					"hint", "Network error or unexpected failure during OAuth request")
			}
			fetchErr = fmt.Errorf("OAuth validation failed: %w", err)
			return fetchErr
		}
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			// Service auth failed
			bodyBytes, _ := io.ReadAll(resp.Body)
			InvalidateServiceToken(did, holdDID)
			slog.Error("Service token request returned non-200 status",
				"component", "token/servicetoken",
				"did", did,
				"holdDID", holdDID,
				"pdsEndpoint", pdsEndpoint,
				"statusCode", resp.StatusCode,
				"responseBody", string(bodyBytes),
				"hint", "PDS rejected the service token request - check PDS logs for details")
			fetchErr = fmt.Errorf("service auth failed with status %d: %s", resp.StatusCode, string(bodyBytes))
			return fetchErr
		}
		// Parse response to get service token
		var result struct {
			Token string `json:"token"`
		}
		if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
			fetchErr = fmt.Errorf("failed to decode service auth response: %w", err)
			return fetchErr
		}
		if result.Token == "" {
			fetchErr = fmt.Errorf("empty token in service auth response")
			return fetchErr
		}
		serviceToken = result.Token
		return nil
	})
	if err != nil {
		// DoWithSession failed (session load or callback error)
		InvalidateServiceToken(did, holdDID)
		// Try to extract detailed error information
		var apiErr *atclient.APIError
		if errors.As(err, &apiErr) {
			slog.Error("Failed to get OAuth session for service token",
				"component", "token/servicetoken",
				"did", did,
				"holdDID", holdDID,
				"pdsEndpoint", pdsEndpoint,
				"error", err,
				"httpStatus", apiErr.StatusCode,
				"errorName", apiErr.Name,
				"errorMessage", apiErr.Message,
				"hint", getErrorHint(apiErr))
		} else if fetchErr == nil {
			// Session load failed (not a fetch error)
			// NOTE(review): callback failures that are not APIErrors were
			// already logged inside the callback, hence the fetchErr guard.
			slog.Error("Failed to get OAuth session for service token",
				"component", "token/servicetoken",
				"did", did,
				"holdDID", holdDID,
				"pdsEndpoint", pdsEndpoint,
				"error", err,
				"errorType", fmt.Sprintf("%T", err),
				"hint", "OAuth session not found in database or token refresh failed")
		}
		// Delete the stale OAuth session to force re-authentication
		// This also invalidates the UI session automatically
		if delErr := refresher.DeleteSession(ctx, did); delErr != nil {
			slog.Warn("Failed to delete stale OAuth session",
				"component", "token/servicetoken",
				"did", did,
				"error", delErr)
		}
		if fetchErr != nil {
			return "", fetchErr
		}
		return "", fmt.Errorf("failed to get OAuth session: %w", err)
	}
	// Cache the token (parses JWT to extract actual expiry)
	if err := SetServiceToken(did, holdDID, serviceToken); err != nil {
		slog.Warn("Failed to cache service token", "error", err, "did", did, "holdDID", holdDID)
		// Non-fatal - we have the token, just won't be cached
	}
	slog.Debug("OAuth validation succeeded, service token obtained", "did", did)
	return serviceToken, nil
}
// GetOrFetchServiceTokenWithAppPassword gets a service token using app-password Bearer authentication.
// Used when auth method is app_password instead of OAuth.
// Unlike the OAuth variant, no per-DID lock is needed (no DPoP nonce negotiation).
func GetOrFetchServiceTokenWithAppPassword(
	ctx context.Context,
	did, holdDID, pdsEndpoint string,
) (string, error) {
	// Check cache first to avoid unnecessary PDS calls on every request
	cachedToken, expiresAt := GetServiceToken(did, holdDID)
	// Use cached token if it exists and has > 10s remaining
	if cachedToken != "" && time.Until(expiresAt) > 10*time.Second {
		slog.Debug("Using cached service token (app-password)",
			"did", did,
			"expiresIn", time.Until(expiresAt).Round(time.Second))
		return cachedToken, nil
	}
	// Cache miss or expiring soon - get app-password token and fetch new service token
	if cachedToken == "" {
		slog.Debug("Service token cache miss, fetching new token with app-password", "did", did)
	} else {
		slog.Debug("Service token expiring soon, proactively renewing with app-password", "did", did)
	}
	// Get app-password access token from cache
	accessToken, ok := auth.GetGlobalTokenCache().Get(did)
	if !ok {
		InvalidateServiceToken(did, holdDID)
		slog.Error("No app-password access token found in cache",
			"component", "token/servicetoken",
			"did", did,
			"holdDID", holdDID,
			"hint", "User must re-authenticate with docker login")
		return "", fmt.Errorf("no app-password access token available for DID %s", did)
	}
	// Call com.atproto.server.getServiceAuth on the user's PDS with Bearer token
	// Request 5-minute expiry (PDS may grant less)
	// exp must be absolute Unix timestamp, not relative duration
	expiryTime := time.Now().Unix() + 300 // 5 minutes from now
	serviceAuthURL := fmt.Sprintf("%s%s?aud=%s&lxm=%s&exp=%d",
		pdsEndpoint,
		atproto.ServerGetServiceAuth,
		url.QueryEscape(holdDID),
		url.QueryEscape("com.atproto.repo.getRecord"),
		expiryTime,
	)
	req, err := http.NewRequestWithContext(ctx, "GET", serviceAuthURL, nil)
	if err != nil {
		return "", fmt.Errorf("failed to create service auth request: %w", err)
	}
	// Set Bearer token authentication (app-password)
	req.Header.Set("Authorization", "Bearer "+accessToken)
	// Make request with standard HTTP client
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		InvalidateServiceToken(did, holdDID)
		slog.Error("App-password service token request failed",
			"component", "token/servicetoken",
			"did", did,
			"holdDID", holdDID,
			"pdsEndpoint", pdsEndpoint,
			"error", err)
		return "", fmt.Errorf("failed to request service token: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusUnauthorized {
		// App-password token is invalid or expired - clear from cache
		auth.GetGlobalTokenCache().Delete(did)
		InvalidateServiceToken(did, holdDID)
		slog.Error("App-password token rejected by PDS",
			"component", "token/servicetoken",
			"did", did,
			"hint", "User must re-authenticate with docker login")
		return "", fmt.Errorf("app-password authentication failed: token expired or invalid")
	}
	if resp.StatusCode != http.StatusOK {
		// Service auth failed
		bodyBytes, _ := io.ReadAll(resp.Body)
		InvalidateServiceToken(did, holdDID)
		slog.Error("Service token request returned non-200 status (app-password)",
			"component", "token/servicetoken",
			"did", did,
			"holdDID", holdDID,
			"pdsEndpoint", pdsEndpoint,
			"statusCode", resp.StatusCode,
			"responseBody", string(bodyBytes))
		return "", fmt.Errorf("service auth failed with status %d: %s", resp.StatusCode, string(bodyBytes))
	}
	// Parse response to get service token
	var result struct {
		Token string `json:"token"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", fmt.Errorf("failed to decode service auth response: %w", err)
	}
	if result.Token == "" {
		return "", fmt.Errorf("empty token in service auth response")
	}
	serviceToken := result.Token
	// Cache the token (parses JWT to extract actual expiry)
	if err := SetServiceToken(did, holdDID, serviceToken); err != nil {
		slog.Warn("Failed to cache service token", "error", err, "did", did, "holdDID", holdDID)
		// Non-fatal - we have the token, just won't be cached
	}
	slog.Debug("App-password validation succeeded, service token obtained", "did", did)
	return serviceToken, nil
}

View File

@@ -1,27 +0,0 @@
package token
import (
"context"
"testing"
)
// TestGetOrFetchServiceToken_NilRefresher verifies that a nil refresher is
// rejected before any network activity.
func TestGetOrFetchServiceToken_NilRefresher(t *testing.T) {
	ctx := context.Background()
	did := "did:plc:test123"
	holdDID := "did:web:hold.example.com"
	pdsEndpoint := "https://pds.example.com"
	// Test with nil refresher - should return error
	_, err := GetOrFetchServiceToken(ctx, nil, did, holdDID, pdsEndpoint)
	if err == nil {
		// Fatal (not Error): t.Error does not stop the test, so the
		// err.Error() call below would panic on a nil error.
		t.Fatal("Expected error when refresher is nil")
	}
	// Compare against the full message actually returned (the old code declared
	// a shorter expectedErrMsg but compared a different literal).
	const expectedErrMsg = "refresher is nil (OAuth session required for service tokens)"
	if err.Error() != expectedErrMsg {
		t.Errorf("Expected error message %q, got %q", expectedErrMsg, err.Error())
	}
}
// Note: Full tests with mocked OAuth refresher and HTTP client will be added
// in the comprehensive test implementation phase

784
pkg/auth/usercontext.go Normal file
View File

@@ -0,0 +1,784 @@
// Package auth provides UserContext for managing authenticated user state
// throughout request handling in the AppView.
package auth
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"sync"
"time"
"atcr.io/pkg/appview/db"
"atcr.io/pkg/atproto"
"atcr.io/pkg/auth/oauth"
)
// Auth method constants (duplicated from token package to avoid import cycle)
const (
	// AuthMethodOAuth marks identities authenticated via OAuth/DPoP sessions.
	AuthMethodOAuth = "oauth"
	// AuthMethodAppPassword marks identities authenticated via app-password Bearer tokens.
	AuthMethodAppPassword = "app_password"
)
// RequestAction represents the type of registry operation
type RequestAction int

const (
	ActionUnknown RequestAction = iota
	ActionPull                  // GET/HEAD - reading from registry
	ActionPush                  // PUT/POST/DELETE - writing to registry
	ActionInspect               // Metadata operations only
)

// String returns the lowercase name of the action ("unknown" for any
// unrecognized value).
func (a RequestAction) String() string {
	names := [...]string{
		ActionPull:    "pull",
		ActionPush:    "push",
		ActionInspect: "inspect",
	}
	if i := int(a); i > 0 && i < len(names) && names[i] != "" {
		return names[i]
	}
	return "unknown"
}
// HoldPermissions describes what the user can do on a specific hold.
// The Can* flags are derived from the role flags (owner > crew > public).
type HoldPermissions struct {
	HoldDID     string   // Hold being checked
	IsOwner     bool     // User is captain of this hold
	IsCrew      bool     // User is a crew member
	IsPublic    bool     // Hold allows public reads
	CanRead     bool     // Computed: can user read blobs?
	CanWrite    bool     // Computed: can user write blobs?
	CanAdmin    bool     // Computed: can user manage crew?
	Permissions []string // Raw permissions from crew record
}
// contextKey is unexported to prevent collisions with other packages' context keys
type contextKey struct{}

// userContextKey is the context key for UserContext
var userContextKey = contextKey{}

// userSetupCache tracks which users have had their profile/crew setup ensured
// (did -> time.Time). NOTE(review): no reader or writer of this map is visible
// in this file chunk - confirm it is used elsewhere before removing.
var userSetupCache sync.Map // did -> time.Time

// userSetupTTL is how long to cache user setup status (1 hour)
const userSetupTTL = 1 * time.Hour
// Dependencies bundles services needed by UserContext.
// Passed to NewUserContext; any field may be nil/empty when the corresponding
// capability is unavailable.
type Dependencies struct {
	Refresher      *oauth.Refresher // OAuth session refresher (nil for app-password-only flows)
	Authorizer     HoldAuthorizer   // Hold permission checks
	DefaultHoldDID string           // AppView's default hold DID
}
// UserContext encapsulates authenticated user state for a request.
// Built early in the middleware chain and available throughout request processing.
//
// Two-phase initialization:
// 1. Middleware phase: Identity is set (DID, authMethod, action)
// 2. Repository() phase: Target is set via SetTarget() (owner, repo, holdDID)
type UserContext struct {
	// === User Identity (set in middleware) ===
	DID             string // User's DID (empty if unauthenticated)
	Handle          string // User's handle (may be empty)
	PDSEndpoint     string // User's PDS endpoint
	AuthMethod      string // "oauth", "app_password", or ""
	IsAuthenticated bool   // true iff DID is non-empty
	// === Request Info ===
	Action     RequestAction // derived from the HTTP method in NewUserContext
	HTTPMethod string
	// === Target Info (set by SetTarget) ===
	TargetOwnerDID    string // whose repo is being accessed
	TargetOwnerHandle string
	TargetOwnerPDS    string
	TargetRepo        string // image name (e.g., "quickslice")
	TargetHoldDID     string // hold where blobs live/will live
	// === Dependencies (injected) ===
	refresher      *oauth.Refresher
	authorizer     HoldAuthorizer
	defaultHoldDID string
	// === Cached State (lazy-loaded) ===
	serviceTokens sync.Map // holdDID -> *serviceTokenEntry
	permissions   sync.Map // holdDID -> *HoldPermissions
	// PDS resolution result; guarded by mu, set once by ResolvePDS or SetPDS
	pdsResolved       bool
	pdsResolveErr     error
	mu                sync.Mutex // protects PDS resolution
	atprotoClient     *atproto.Client
	atprotoClientOnce sync.Once
}
// FromContext retrieves UserContext from context.
// Returns nil if not present (unauthenticated or before middleware).
func FromContext(ctx context.Context) *UserContext {
	if uc, ok := ctx.Value(userContextKey).(*UserContext); ok {
		return uc
	}
	return nil
}
// WithUserContext adds UserContext to context.
// Retrieve it later with FromContext.
func WithUserContext(ctx context.Context, uc *UserContext) context.Context {
	return context.WithValue(ctx, userContextKey, uc)
}
// NewUserContext creates a UserContext from extracted JWT claims.
// The deps parameter provides access to services needed for lazy operations;
// it may be nil, in which case lazy operations that need those services fail.
func NewUserContext(did, authMethod, httpMethod string, deps *Dependencies) *UserContext {
	// Map the HTTP verb onto a registry action; anything else stays unknown.
	var action RequestAction
	switch httpMethod {
	case "GET", "HEAD":
		action = ActionPull
	case "PUT", "POST", "PATCH", "DELETE":
		action = ActionPush
	default:
		action = ActionUnknown
	}
	uc := &UserContext{
		DID:             did,
		AuthMethod:      authMethod,
		IsAuthenticated: did != "",
		Action:          action,
		HTTPMethod:      httpMethod,
	}
	if deps != nil {
		uc.refresher = deps.Refresher
		uc.authorizer = deps.Authorizer
		uc.defaultHoldDID = deps.DefaultHoldDID
	}
	return uc
}
// SetPDS sets the user's PDS endpoint directly, bypassing network resolution.
// Use when PDS is already known (e.g., from previous resolution or client).
func (uc *UserContext) SetPDS(handle, pdsEndpoint string) {
	uc.mu.Lock()
	defer uc.mu.Unlock()
	uc.Handle = handle
	uc.PDSEndpoint = pdsEndpoint
	// Mark resolution complete so a later ResolvePDS call is a no-op.
	uc.pdsResolved = true
	uc.pdsResolveErr = nil
}
// SetTarget sets the target repository information.
// Called in Repository() after resolving the owner identity.
//
// NOTE(review): these fields are written without holding uc.mu - presumably
// SetTarget runs once, before the context is shared across goroutines; confirm.
func (uc *UserContext) SetTarget(ownerDID, ownerHandle, ownerPDS, repo, holdDID string) {
	uc.TargetOwnerDID = ownerDID
	uc.TargetOwnerHandle = ownerHandle
	uc.TargetOwnerPDS = ownerPDS
	uc.TargetRepo = repo
	uc.TargetHoldDID = holdDID
}
// ResolvePDS resolves the user's PDS endpoint (lazy, cached).
// Safe to call multiple times; resolution happens once.
// Note: a failed resolution is cached too - the first error is replayed and
// never retried for the lifetime of this UserContext.
func (uc *UserContext) ResolvePDS(ctx context.Context) error {
	if !uc.IsAuthenticated {
		return nil // Nothing to resolve for anonymous users
	}
	uc.mu.Lock()
	defer uc.mu.Unlock()
	if uc.pdsResolved {
		// Replay the first outcome (success or cached failure)
		return uc.pdsResolveErr
	}
	_, handle, pds, err := atproto.ResolveIdentity(ctx, uc.DID)
	if err != nil {
		uc.pdsResolveErr = err
		uc.pdsResolved = true
		return err
	}
	uc.Handle = handle
	uc.PDSEndpoint = pds
	uc.pdsResolved = true
	return nil
}
// GetServiceToken returns a service token for the target hold.
// Uses internal caching with sync.Once per holdDID.
// Requires target to be set via SetTarget().
func (uc *UserContext) GetServiceToken(ctx context.Context) (string, error) {
	hold := uc.TargetHoldDID
	if hold == "" {
		return "", fmt.Errorf("target hold not set (call SetTarget first)")
	}
	return uc.GetServiceTokenForHold(ctx, hold)
}
// GetServiceTokenForHold returns a service token for an arbitrary hold.
// Uses internal caching with sync.Once per holdDID.
//
// NOTE(review): the sync.Once means both success AND failure are cached for
// the lifetime of this UserContext; entry.expiresAt is recorded here but not
// consulted in this method - confirm a reader elsewhere honors it.
func (uc *UserContext) GetServiceTokenForHold(ctx context.Context, holdDID string) (string, error) {
	if !uc.IsAuthenticated {
		return "", fmt.Errorf("cannot get service token: user not authenticated")
	}
	// Ensure PDS is resolved
	if err := uc.ResolvePDS(ctx); err != nil {
		return "", fmt.Errorf("failed to resolve PDS: %w", err)
	}
	// Load or create cache entry
	entryVal, _ := uc.serviceTokens.LoadOrStore(holdDID, &serviceTokenEntry{})
	entry := entryVal.(*serviceTokenEntry)
	entry.once.Do(func() {
		slog.Debug("Fetching service token",
			"component", "auth/context",
			"userDID", uc.DID,
			"holdDID", holdDID,
			"authMethod", uc.AuthMethod)
		// Use unified service token function (handles both OAuth and app-password)
		serviceToken, err := GetOrFetchServiceToken(
			ctx, uc.AuthMethod, uc.refresher, uc.DID, holdDID, uc.PDSEndpoint,
		)
		entry.token = serviceToken
		entry.err = err
		if err == nil {
			// Parse JWT to get expiry
			expiry, parseErr := ParseJWTExpiry(serviceToken)
			if parseErr == nil {
				entry.expiresAt = expiry.Add(-10 * time.Second) // Safety margin
			} else {
				entry.expiresAt = time.Now().Add(45 * time.Second) // Default fallback
			}
		}
	})
	return entry.token, entry.err
}
// CanRead checks if user can read blobs from target hold.
// - Public hold: any user (even anonymous)
// - Private hold: owner OR crew with blob:read/blob:write
func (uc *UserContext) CanRead(ctx context.Context) (bool, error) {
	switch {
	case uc.TargetHoldDID == "":
		return false, fmt.Errorf("target hold not set (call SetTarget first)")
	case uc.authorizer == nil:
		return false, fmt.Errorf("authorizer not configured")
	default:
		return uc.authorizer.CheckReadAccess(ctx, uc.TargetHoldDID, uc.DID)
	}
}
// CanWrite checks if user can write blobs to target hold.
// - Must be authenticated
// - Must be owner OR crew with blob:write
func (uc *UserContext) CanWrite(ctx context.Context) (bool, error) {
	switch {
	case uc.TargetHoldDID == "":
		return false, fmt.Errorf("target hold not set (call SetTarget first)")
	case !uc.IsAuthenticated:
		return false, nil // anonymous writes are never allowed
	case uc.authorizer == nil:
		return false, fmt.Errorf("authorizer not configured")
	default:
		return uc.authorizer.CheckWriteAccess(ctx, uc.TargetHoldDID, uc.DID)
	}
}
// GetPermissions returns detailed permissions for target hold.
// Lazy-loaded and cached per holdDID.
func (uc *UserContext) GetPermissions(ctx context.Context) (*HoldPermissions, error) {
	hold := uc.TargetHoldDID
	if hold == "" {
		return nil, fmt.Errorf("target hold not set (call SetTarget first)")
	}
	return uc.GetPermissionsForHold(ctx, hold)
}
// GetPermissionsForHold returns detailed permissions for an arbitrary hold.
// Lazy-loaded and cached per holdDID.
func (uc *UserContext) GetPermissionsForHold(ctx context.Context, holdDID string) (*HoldPermissions, error) {
	// Fast path: previously computed permissions for this hold
	if cached, ok := uc.permissions.Load(holdDID); ok {
		return cached.(*HoldPermissions), nil
	}
	if uc.authorizer == nil {
		return nil, fmt.Errorf("authorizer not configured")
	}
	// The captain record tells us the owner and public-read flag
	captain, err := uc.authorizer.GetCaptainRecord(ctx, holdDID)
	if err != nil {
		return nil, fmt.Errorf("failed to get captain record: %w", err)
	}
	perms := &HoldPermissions{
		HoldDID:  holdDID,
		IsPublic: captain.Public,
		IsOwner:  uc.DID != "" && uc.DID == captain.Owner,
	}
	// Crew membership only matters for authenticated non-owners; a lookup
	// failure is logged and treated as "not crew".
	if uc.IsAuthenticated && !perms.IsOwner {
		member, crewErr := uc.authorizer.IsCrewMember(ctx, holdDID, uc.DID)
		if crewErr != nil {
			slog.Warn("Failed to check crew membership",
				"component", "auth/context",
				"holdDID", holdDID,
				"userDID", uc.DID,
				"error", crewErr)
		}
		perms.IsCrew = member
	}
	// Role precedence: owner > crew > public. All Can* flags default to
	// false, which covers authenticated non-crew and anonymous users on
	// private holds.
	switch {
	case perms.IsOwner:
		perms.CanRead, perms.CanWrite, perms.CanAdmin = true, true, true
	case perms.IsCrew:
		// Crew members can read and write (for now, all crew have blob:write)
		// TODO: Check specific permissions from crew record
		perms.CanRead, perms.CanWrite = true, true
	case perms.IsPublic:
		// Public hold - anyone can read
		perms.CanRead = true
	}
	// Cache and return
	uc.permissions.Store(holdDID, perms)
	return perms, nil
}
// IsCrewMember checks if user is crew of target hold.
// Anonymous users are never crew.
func (uc *UserContext) IsCrewMember(ctx context.Context) (bool, error) {
	switch {
	case uc.TargetHoldDID == "":
		return false, fmt.Errorf("target hold not set (call SetTarget first)")
	case !uc.IsAuthenticated:
		return false, nil
	case uc.authorizer == nil:
		return false, fmt.Errorf("authorizer not configured")
	default:
		return uc.authorizer.IsCrewMember(ctx, uc.TargetHoldDID, uc.DID)
	}
}
// EnsureCrewMembership is a standalone function to register as crew on a hold.
// Use this when you don't have a UserContext (e.g., OAuth callback).
// This is best-effort and logs errors without failing.
func EnsureCrewMembership(ctx context.Context, did, pdsEndpoint string, refresher *oauth.Refresher, holdDID string) {
	if holdDID == "" {
		return
	}
	// Only works with OAuth (refresher required) - app passwords can't get service tokens
	if refresher == nil {
		slog.Debug("skipping crew registration - no OAuth refresher (app password flow)", "holdDID", holdDID)
		return
	}
	// Normalize URL to DID if needed
	if !atproto.IsDID(holdDID) {
		// BUG FIX: keep the original URL for logging; the previous code
		// overwrote holdDID before logging it, so failures logged "".
		originalHold := holdDID
		holdDID = atproto.ResolveHoldDIDFromURL(holdDID)
		if holdDID == "" {
			slog.Warn("failed to resolve hold DID", "defaultHold", originalHold)
			return
		}
	}
	// Get service token for the hold (OAuth only at this point)
	serviceToken, err := GetOrFetchServiceToken(ctx, AuthMethodOAuth, refresher, did, holdDID, pdsEndpoint)
	if err != nil {
		slog.Warn("failed to get service token", "holdDID", holdDID, "error", err)
		return
	}
	// Resolve hold DID to HTTP endpoint
	holdEndpoint := atproto.ResolveHoldURL(holdDID)
	if holdEndpoint == "" {
		slog.Warn("failed to resolve hold endpoint", "holdDID", holdDID)
		return
	}
	// Call requestCrew endpoint
	if err := requestCrewMembership(ctx, holdEndpoint, serviceToken); err != nil {
		slog.Warn("failed to request crew membership", "holdDID", holdDID, "error", err)
		return
	}
	slog.Info("successfully registered as crew member", "holdDID", holdDID, "userDID", did)
}
// ensureCrewMembership registers the user as crew on the target hold.
// It is a thin wrapper over EnsureCrewMembershipForHold using TargetHoldDID,
// so SetTarget must have been called beforehand. Called automatically during
// first push; idempotent and best-effort.
func (uc *UserContext) ensureCrewMembership(ctx context.Context) error {
	target := uc.TargetHoldDID
	if target == "" {
		return fmt.Errorf("target hold not set (call SetTarget first)")
	}
	return uc.EnsureCrewMembershipForHold(ctx, target)
}
// EnsureCrewMembershipForHold attempts to register as crew on the specified
// hold. This is the core implementation and accepts any holdDID (or hold URL).
// Called automatically during first push; idempotent. Best-effort: callers
// typically log the returned error rather than failing the request.
func (uc *UserContext) EnsureCrewMembershipForHold(ctx context.Context, holdDID string) error {
	if holdDID == "" {
		return nil // nothing to do
	}
	// Accept either a DID or a hold URL; normalize to a DID.
	if !atproto.IsDID(holdDID) {
		if holdDID = atproto.ResolveHoldDIDFromURL(holdDID); holdDID == "" {
			return fmt.Errorf("failed to resolve hold DID from URL")
		}
	}
	// Crew registration needs a service token, which requires an OAuth session.
	if !uc.IsAuthenticated {
		return fmt.Errorf("cannot register as crew: user not authenticated")
	}
	if uc.refresher == nil {
		return fmt.Errorf("cannot register as crew: OAuth session required")
	}
	serviceToken, err := uc.GetServiceTokenForHold(ctx, holdDID)
	if err != nil {
		return fmt.Errorf("failed to get service token: %w", err)
	}
	// Translate the hold DID into an HTTP endpoint we can call.
	holdEndpoint := atproto.ResolveHoldURL(holdDID)
	if holdEndpoint == "" {
		return fmt.Errorf("failed to resolve hold endpoint for %s", holdDID)
	}
	return requestCrewMembership(ctx, holdEndpoint, serviceToken)
}
// requestCrewMembership calls the hold's requestCrew endpoint.
// The endpoint handles all authorization and duplicate checking internally.
func requestCrewMembership(ctx context.Context, holdEndpoint, serviceToken string) error {
	// Add 5 second timeout to prevent hanging on offline holds.
	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	url := holdEndpoint + atproto.HoldRequestCrew
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", "Bearer "+serviceToken)
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
		// Read a bounded amount of the response body to surface the hold's
		// actual error message; the cap guards against a huge or hostile reply.
		body, readErr := io.ReadAll(io.LimitReader(resp.Body, 64<<10))
		if readErr != nil {
			return fmt.Errorf("requestCrew failed with status %d (failed to read error body: %w)", resp.StatusCode, readErr)
		}
		return fmt.Errorf("requestCrew failed with status %d: %s", resp.StatusCode, string(body))
	}
	return nil
}
// GetUserClient returns an authenticated ATProto client for the user's own PDS.
// Used for profile operations (reading/writing to the user's own repo).
// Returns nil when not authenticated, when the PDS is not resolved, or when
// the auth method has no usable credentials.
func (uc *UserContext) GetUserClient() *atproto.Client {
	if !uc.IsAuthenticated || uc.PDSEndpoint == "" {
		return nil
	}
	switch uc.AuthMethod {
	case AuthMethodOAuth:
		if uc.refresher == nil {
			return nil
		}
		return atproto.NewClientWithSessionProvider(uc.PDSEndpoint, uc.DID, uc.refresher)
	case AuthMethodAppPassword:
		// Best-effort: a missing cached token yields an unauthenticated client.
		token, _ := GetGlobalTokenCache().Get(uc.DID)
		return atproto.NewClient(uc.PDSEndpoint, uc.DID, token)
	}
	return nil
}
// EnsureUserSetup ensures the user has a profile and crew membership.
// Runs at most once per user per userSetupTTL (tracked in userSetupCache) and
// performs the work in a background goroutine so requests are never blocked.
// Safe to call on every request.
func (uc *UserContext) EnsureUserSetup() {
	if !uc.IsAuthenticated || uc.DID == "" {
		return
	}
	// Skip entirely if this user was set up recently.
	if last, ok := userSetupCache.Load(uc.DID); ok && time.Since(last.(time.Time)) < userSetupTTL {
		return
	}
	go func() {
		ctx := context.Background()
		// Step 1: make sure a sailor profile record exists.
		if client := uc.GetUserClient(); client != nil {
			uc.ensureProfile(ctx, client)
		}
		// Step 2: make sure the user is crew on the default hold.
		if hold := uc.defaultHoldDID; hold != "" {
			EnsureCrewMembership(ctx, uc.DID, uc.PDSEndpoint, uc.refresher, hold)
		}
		// Record completion so subsequent requests skip the work.
		userSetupCache.Store(uc.DID, time.Now())
		slog.Debug("User setup complete",
			"component", "auth/usercontext",
			"did", uc.DID,
			"defaultHoldDID", uc.defaultHoldDID)
	}()
}
// ensureProfile creates the sailor profile record if one doesn't already exist.
// Implemented inline to avoid a circular import with the storage package.
func (uc *UserContext) ensureProfile(ctx context.Context, client *atproto.Client) {
	// An existing record means there is nothing to do.
	if existing, err := client.GetRecord(ctx, atproto.SailorProfileCollection, "self"); err == nil && existing != nil {
		return
	}
	// Seed the new profile with the configured default hold, normalized to a DID.
	defaultHold := ""
	if uc.defaultHoldDID != "" {
		defaultHold = atproto.ResolveHoldDIDFromURL(uc.defaultHoldDID)
	}
	record := atproto.NewSailorProfileRecord(defaultHold)
	if _, err := client.PutRecord(ctx, atproto.SailorProfileCollection, "self", record); err != nil {
		slog.Warn("Failed to create sailor profile",
			"component", "auth/usercontext",
			"did", uc.DID,
			"error", err)
		return
	}
	slog.Debug("Created sailor profile",
		"component", "auth/usercontext",
		"did", uc.DID,
		"defaultHold", defaultHold)
}
// GetATProtoClient returns a cached ATProto client for the target owner's PDS.
// The client is authenticated when the requesting user *is* the owner;
// otherwise an anonymous (read-only) client is returned. The result is
// memoized per UserContext via sync.Once.
func (uc *UserContext) GetATProtoClient() *atproto.Client {
	uc.atprotoClientOnce.Do(func() {
		if uc.TargetOwnerPDS == "" {
			return
		}
		// Owner gets credentials appropriate to their auth method.
		if uc.IsAuthenticated && uc.DID == uc.TargetOwnerDID {
			switch {
			case uc.AuthMethod == AuthMethodOAuth && uc.refresher != nil:
				uc.atprotoClient = atproto.NewClientWithSessionProvider(uc.TargetOwnerPDS, uc.TargetOwnerDID, uc.refresher)
				return
			case uc.AuthMethod == AuthMethodAppPassword:
				token, _ := GetGlobalTokenCache().Get(uc.TargetOwnerDID)
				uc.atprotoClient = atproto.NewClient(uc.TargetOwnerPDS, uc.TargetOwnerDID, token)
				return
			}
		}
		// Everyone else reads anonymously.
		uc.atprotoClient = atproto.NewClient(uc.TargetOwnerPDS, uc.TargetOwnerDID, "")
	})
	return uc.atprotoClient
}
// ResolveHoldDID finds the hold for the target repository.
//   - Pull: database lookup (historical hold recorded at push time)
//   - Push (and any unspecified action): discovery via sailor profile → default
//
// Must be called after SetTarget() has populated TargetOwnerDID and TargetRepo.
// On success TargetHoldDID is updated and the DID returned.
func (uc *UserContext) ResolveHoldDID(ctx context.Context, sqlDB *sql.DB) (string, error) {
	if uc.TargetOwnerDID == "" {
		return "", fmt.Errorf("target owner not set")
	}
	var (
		holdDID string
		err     error
	)
	if uc.Action == ActionPull {
		holdDID, err = uc.resolveHoldForPull(ctx, sqlDB)
	} else {
		// Push — and the default for any other action — uses discovery.
		holdDID, err = uc.resolveHoldForPush(ctx)
	}
	if err != nil {
		return "", err
	}
	if holdDID == "" {
		return "", fmt.Errorf("no hold DID found for %s/%s", uc.TargetOwnerDID, uc.TargetRepo)
	}
	uc.TargetHoldDID = holdDID
	return holdDID, nil
}
// resolveHoldForPull looks up the historical hold for the repo in the database,
// falling back to profile discovery when no database or record is available.
func (uc *UserContext) resolveHoldForPull(ctx context.Context, sqlDB *sql.DB) (string, error) {
	if sqlDB == nil {
		// No database configured; discovery is the only option.
		return uc.resolveHoldForPush(ctx)
	}
	holdDID, err := db.GetLatestHoldDIDForRepo(sqlDB, uc.TargetOwnerDID, uc.TargetRepo)
	switch {
	case err != nil:
		slog.Debug("Database lookup failed, falling back to discovery",
			"component", "auth/context",
			"ownerDID", uc.TargetOwnerDID,
			"repo", uc.TargetRepo,
			"error", err)
		return uc.resolveHoldForPush(ctx)
	case holdDID != "":
		return holdDID, nil
	default:
		// No historical record for this repo; discover instead.
		return uc.resolveHoldForPush(ctx)
	}
}
// resolveHoldForPush discovers the hold from the owner's sailor profile,
// falling back to the registry-wide default hold when the profile has none.
func (uc *UserContext) resolveHoldForPush(ctx context.Context) (string, error) {
	// The profile is public data, so query it anonymously.
	client := atproto.NewClient(uc.TargetOwnerPDS, uc.TargetOwnerDID, "")
	if record, err := client.GetRecord(ctx, atproto.SailorProfileCollection, "self"); err == nil && record != nil {
		var profile atproto.SailorProfileRecord
		if json.Unmarshal(record.Value, &profile) == nil && profile.DefaultHold != "" {
			holdDID := profile.DefaultHold
			if !atproto.IsDID(holdDID) {
				holdDID = atproto.ResolveHoldDIDFromURL(holdDID)
			}
			slog.Debug("Found hold from owner's profile",
				"component", "auth/context",
				"ownerDID", uc.TargetOwnerDID,
				"holdDID", holdDID)
			return holdDID, nil
		}
	}
	if uc.defaultHoldDID != "" {
		slog.Debug("Using default hold",
			"component", "auth/context",
			"ownerDID", uc.TargetOwnerDID,
			"defaultHoldDID", uc.defaultHoldDID)
		return uc.defaultHoldDID, nil
	}
	return "", fmt.Errorf("no hold configured for %s and no default hold set", uc.TargetOwnerDID)
}
// =============================================================================
// Test Helper Methods
// =============================================================================
// These methods are designed to make UserContext testable by allowing tests
// to bypass network-dependent code paths (PDS resolution, OAuth token fetching).
// Only use these in tests - they are not intended for production use.
// SetPDSForTest sets the PDS endpoint directly, letting tests skip the DID
// resolution that would otherwise hit the network.
//
// Deprecated: Use SetPDS instead.
func (uc *UserContext) SetPDSForTest(handle, pdsEndpoint string) {
	uc.SetPDS(handle, pdsEndpoint)
}
// SetServiceTokenForTest pre-populates a service token for the given holdDID,
// making it look as though the token was already fetched and cached. This
// bypasses the sync.Once and the OAuth/app-password fetching logic entirely.
func (uc *UserContext) SetServiceTokenForTest(holdDID, token string) {
	entry := &serviceTokenEntry{
		token:     token,
		expiresAt: time.Now().Add(5 * time.Minute),
	}
	// Consume the sync.Once so the real fetch path never runs.
	entry.once.Do(func() {})
	uc.serviceTokens.Store(holdDID, entry)
}
// SetAuthorizerForTest injects an authorizer for permission checks.
// Pair with MockHoldAuthorizer to drive CanRead/CanWrite outcomes in tests.
func (uc *UserContext) SetAuthorizerForTest(a HoldAuthorizer) {
	uc.authorizer = a
}
// SetDefaultHoldDIDForTest sets the default hold DID used as the fallback
// when resolving a hold for push operations. Test-only.
func (uc *UserContext) SetDefaultHoldDIDForTest(did string) {
	uc.defaultHoldDID = did
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
@@ -18,6 +19,44 @@ import (
"github.com/golang-jwt/jwt/v5"
)
// Authentication errors
var (
ErrMissingAuthHeader = errors.New("missing Authorization header")
ErrInvalidAuthFormat = errors.New("invalid Authorization header format")
ErrInvalidAuthScheme = errors.New("invalid authorization scheme: expected 'Bearer' or 'DPoP'")
ErrMissingToken = errors.New("missing token")
ErrMissingDPoPHeader = errors.New("missing DPoP header")
)
// JWT validation errors
var (
ErrInvalidJWTFormat = errors.New("invalid JWT format: expected header.payload.signature")
ErrMissingISSClaim = errors.New("missing 'iss' claim in token")
ErrMissingSubClaim = errors.New("missing 'sub' claim in token")
ErrTokenExpired = errors.New("token has expired")
)
// AuthError describes a denied authorization decision in a structured form,
// so callers can inspect the attempted action and the permissions that would
// have allowed it instead of parsing an error string.
type AuthError struct {
	Action   string   // action attempted: "blob:read", "blob:write", "crew:admin"
	Reason   string   // human-readable denial reason
	Required []string // permission(s) that would grant access
}

// Error implements the error interface.
func (e *AuthError) Error() string {
	return fmt.Sprintf("access denied for %s: %s (required: %s)",
		e.Action, e.Reason, strings.Join(e.Required, " or "))
}

// NewAuthError constructs an AuthError for the given action and reason,
// recording the listed permissions as those that would grant access.
func NewAuthError(action, reason string, required ...string) *AuthError {
	return &AuthError{Action: action, Reason: reason, Required: required}
}
// HTTPClient interface allows injecting a custom HTTP client for testing
type HTTPClient interface {
Do(*http.Request) (*http.Response, error)
@@ -44,13 +83,13 @@ func ValidateDPoPRequest(r *http.Request, httpClient HTTPClient) (*ValidatedUser
// Extract Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return nil, fmt.Errorf("missing Authorization header")
return nil, ErrMissingAuthHeader
}
// Check for DPoP authorization scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid Authorization header format")
return nil, ErrInvalidAuthFormat
}
if parts[0] != "DPoP" {
@@ -59,13 +98,13 @@ func ValidateDPoPRequest(r *http.Request, httpClient HTTPClient) (*ValidatedUser
accessToken := parts[1]
if accessToken == "" {
return nil, fmt.Errorf("missing access token")
return nil, ErrMissingToken
}
// Extract DPoP header
dpopProof := r.Header.Get("DPoP")
if dpopProof == "" {
return nil, fmt.Errorf("missing DPoP header")
return nil, ErrMissingDPoPHeader
}
// TODO: We could verify the DPoP proof locally (signature, HTM, HTU, etc.)
@@ -109,7 +148,7 @@ func extractDIDFromToken(token string) (string, string, error) {
// JWT format: header.payload.signature
parts := strings.Split(token, ".")
if len(parts) != 3 {
return "", "", fmt.Errorf("invalid JWT format")
return "", "", ErrInvalidJWTFormat
}
// Decode payload (base64url)
@@ -129,11 +168,11 @@ func extractDIDFromToken(token string) (string, string, error) {
}
if claims.Sub == "" {
return "", "", fmt.Errorf("missing sub claim (DID)")
return "", "", ErrMissingSubClaim
}
if claims.Iss == "" {
return "", "", fmt.Errorf("missing iss claim (PDS)")
return "", "", ErrMissingISSClaim
}
return claims.Sub, claims.Iss, nil
@@ -216,7 +255,7 @@ func ValidateOwnerOrCrewAdmin(r *http.Request, pds *HoldPDS, httpClient HTTPClie
return nil, fmt.Errorf("DPoP authentication failed: %w", err)
}
} else {
return nil, fmt.Errorf("missing or invalid Authorization header (expected Bearer or DPoP)")
return nil, ErrInvalidAuthScheme
}
// Get captain record to check owner
@@ -243,12 +282,12 @@ func ValidateOwnerOrCrewAdmin(r *http.Request, pds *HoldPDS, httpClient HTTPClie
return user, nil
}
// User is crew but doesn't have admin permission
return nil, fmt.Errorf("crew member lacks required 'crew:admin' permission")
return nil, NewAuthError("crew:admin", "crew member lacks permission", "crew:admin")
}
}
// User is neither owner nor authorized crew
return nil, fmt.Errorf("user is not authorized (must be hold owner or crew admin)")
return nil, NewAuthError("crew:admin", "user is not a crew member", "crew:admin")
}
// ValidateBlobWriteAccess validates that the request has valid authentication
@@ -276,7 +315,7 @@ func ValidateBlobWriteAccess(r *http.Request, pds *HoldPDS, httpClient HTTPClien
return nil, fmt.Errorf("DPoP authentication failed: %w", err)
}
} else {
return nil, fmt.Errorf("missing or invalid Authorization header (expected Bearer or DPoP)")
return nil, ErrInvalidAuthScheme
}
// Get captain record to check owner and public settings
@@ -303,17 +342,18 @@ func ValidateBlobWriteAccess(r *http.Request, pds *HoldPDS, httpClient HTTPClien
return user, nil
}
// User is crew but doesn't have write permission
return nil, fmt.Errorf("crew member lacks required 'blob:write' permission")
return nil, NewAuthError("blob:write", "crew member lacks permission", "blob:write")
}
}
// User is neither owner nor authorized crew
return nil, fmt.Errorf("user is not authorized for blob write (must be hold owner or crew with blob:write permission)")
return nil, NewAuthError("blob:write", "user is not a crew member", "blob:write")
}
// ValidateBlobReadAccess validates that the request has read access to blobs
// If captain.public = true: No auth required (returns nil user to indicate public access)
// If captain.public = false: Requires valid DPoP + OAuth and (captain OR crew with blob:read permission).
// If captain.public = false: Requires valid DPoP + OAuth and (captain OR crew with blob:read or blob:write permission).
// Note: blob:write implicitly grants blob:read access.
// The httpClient parameter is optional and defaults to http.DefaultClient if nil.
func ValidateBlobReadAccess(r *http.Request, pds *HoldPDS, httpClient HTTPClient) (*ValidatedUser, error) {
// Get captain record to check public setting
@@ -344,7 +384,7 @@ func ValidateBlobReadAccess(r *http.Request, pds *HoldPDS, httpClient HTTPClient
return nil, fmt.Errorf("DPoP authentication failed: %w", err)
}
} else {
return nil, fmt.Errorf("missing or invalid Authorization header (expected Bearer or DPoP)")
return nil, ErrInvalidAuthScheme
}
// Check if user is the owner (always has read access)
@@ -352,7 +392,8 @@ func ValidateBlobReadAccess(r *http.Request, pds *HoldPDS, httpClient HTTPClient
return user, nil
}
// Check if user is crew with blob:read permission
// Check if user is crew with blob:read or blob:write permission
// Note: blob:write implicitly grants blob:read access
crew, err := pds.ListCrewMembers(r.Context())
if err != nil {
return nil, fmt.Errorf("failed to check crew membership: %w", err)
@@ -360,17 +401,19 @@ func ValidateBlobReadAccess(r *http.Request, pds *HoldPDS, httpClient HTTPClient
for _, member := range crew {
if member.Record.Member == user.DID {
// Check if this crew member has blob:read permission
if slices.Contains(member.Record.Permissions, "blob:read") {
// Check if this crew member has blob:read or blob:write permission
// blob:write implicitly grants read access (can't push without pulling)
if slices.Contains(member.Record.Permissions, "blob:read") ||
slices.Contains(member.Record.Permissions, "blob:write") {
return user, nil
}
// User is crew but doesn't have read permission
return nil, fmt.Errorf("crew member lacks required 'blob:read' permission")
// User is crew but doesn't have read or write permission
return nil, NewAuthError("blob:read", "crew member lacks permission", "blob:read", "blob:write")
}
}
// User is neither owner nor authorized crew
return nil, fmt.Errorf("user is not authorized for blob read (must be hold owner or crew with blob:read permission)")
return nil, NewAuthError("blob:read", "user is not a crew member", "blob:read", "blob:write")
}
// ServiceTokenClaims represents the claims in a service token JWT
@@ -385,13 +428,13 @@ func ValidateServiceToken(r *http.Request, holdDID string, httpClient HTTPClient
// Extract Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return nil, fmt.Errorf("missing Authorization header")
return nil, ErrMissingAuthHeader
}
// Check for Bearer authorization scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid Authorization header format")
return nil, ErrInvalidAuthFormat
}
if parts[0] != "Bearer" {
@@ -400,7 +443,7 @@ func ValidateServiceToken(r *http.Request, holdDID string, httpClient HTTPClient
tokenString := parts[1]
if tokenString == "" {
return nil, fmt.Errorf("missing token")
return nil, ErrMissingToken
}
slog.Debug("Validating service token", "holdDID", holdDID)
@@ -409,7 +452,7 @@ func ValidateServiceToken(r *http.Request, holdDID string, httpClient HTTPClient
// Split token: header.payload.signature
tokenParts := strings.Split(tokenString, ".")
if len(tokenParts) != 3 {
return nil, fmt.Errorf("invalid JWT format")
return nil, ErrInvalidJWTFormat
}
// Decode payload (second part) to extract claims
@@ -427,7 +470,7 @@ func ValidateServiceToken(r *http.Request, holdDID string, httpClient HTTPClient
// Get issuer (user DID)
issuerDID := claims.Issuer
if issuerDID == "" {
return nil, fmt.Errorf("missing iss claim")
return nil, ErrMissingISSClaim
}
// Verify audience matches this hold service
@@ -445,7 +488,7 @@ func ValidateServiceToken(r *http.Request, holdDID string, httpClient HTTPClient
return nil, fmt.Errorf("failed to get expiration: %w", err)
}
if exp != nil && time.Now().After(exp.Time) {
return nil, fmt.Errorf("token has expired")
return nil, ErrTokenExpired
}
// Verify JWT signature using ATProto's secp256k1 crypto

View File

@@ -771,6 +771,116 @@ func TestValidateBlobReadAccess_PrivateHold(t *testing.T) {
}
}
// TestValidateBlobReadAccess_BlobWriteImpliesRead tests that blob:write grants read access
func TestValidateBlobReadAccess_BlobWriteImpliesRead(t *testing.T) {
ownerDID := "did:plc:owner123"
pds, ctx := setupTestPDSWithBootstrap(t, ownerDID, false, false)
// Verify captain record has public=false (private hold)
_, captain, err := pds.GetCaptainRecord(ctx)
if err != nil {
t.Fatalf("Failed to get captain record: %v", err)
}
if captain.Public {
t.Error("Expected public=false for captain record")
}
// Add crew member with ONLY blob:write permission (no blob:read)
writerDID := "did:plc:writer123"
_, err = pds.AddCrewMember(ctx, writerDID, "writer", []string{"blob:write"})
if err != nil {
t.Fatalf("Failed to add crew writer: %v", err)
}
mockClient := &mockPDSClient{}
// Test writer (has only blob:write permission) can read
t.Run("crew with blob:write can read", func(t *testing.T) {
dpopHelper, err := NewDPoPTestHelper(writerDID, "https://test-pds.example.com")
if err != nil {
t.Fatalf("Failed to create DPoP helper: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/test", nil)
if err := dpopHelper.AddDPoPToRequest(req); err != nil {
t.Fatalf("Failed to add DPoP to request: %v", err)
}
// This should SUCCEED because blob:write implies blob:read
user, err := ValidateBlobReadAccess(req, pds, mockClient)
if err != nil {
t.Errorf("Expected blob:write to grant read access, got error: %v", err)
}
if user == nil {
t.Error("Expected user to be returned for valid read access")
} else if user.DID != writerDID {
t.Errorf("Expected user DID %s, got %s", writerDID, user.DID)
}
})
// Also verify that crew with only blob:read still works
t.Run("crew with blob:read can read", func(t *testing.T) {
readerDID := "did:plc:reader123"
_, err = pds.AddCrewMember(ctx, readerDID, "reader", []string{"blob:read"})
if err != nil {
t.Fatalf("Failed to add crew reader: %v", err)
}
dpopHelper, err := NewDPoPTestHelper(readerDID, "https://test-pds.example.com")
if err != nil {
t.Fatalf("Failed to create DPoP helper: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/test", nil)
if err := dpopHelper.AddDPoPToRequest(req); err != nil {
t.Fatalf("Failed to add DPoP to request: %v", err)
}
user, err := ValidateBlobReadAccess(req, pds, mockClient)
if err != nil {
t.Errorf("Expected blob:read to grant read access, got error: %v", err)
}
if user == nil {
t.Error("Expected user to be returned for valid read access")
} else if user.DID != readerDID {
t.Errorf("Expected user DID %s, got %s", readerDID, user.DID)
}
})
// Verify crew with neither permission cannot read
t.Run("crew without read or write cannot read", func(t *testing.T) {
noPermDID := "did:plc:noperm123"
_, err = pds.AddCrewMember(ctx, noPermDID, "noperm", []string{"crew:admin"})
if err != nil {
t.Fatalf("Failed to add crew member: %v", err)
}
dpopHelper, err := NewDPoPTestHelper(noPermDID, "https://test-pds.example.com")
if err != nil {
t.Fatalf("Failed to create DPoP helper: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/test", nil)
if err := dpopHelper.AddDPoPToRequest(req); err != nil {
t.Fatalf("Failed to add DPoP to request: %v", err)
}
_, err = ValidateBlobReadAccess(req, pds, mockClient)
if err == nil {
t.Error("Expected error for crew without read or write permission")
}
// Verify error message format
if !strings.Contains(err.Error(), "access denied for blob:read") {
t.Errorf("Expected structured error message, got: %v", err)
}
})
}
// TestValidateOwnerOrCrewAdmin tests admin permission checking
func TestValidateOwnerOrCrewAdmin(t *testing.T) {
ownerDID := "did:plc:owner123"

View File

@@ -18,8 +18,8 @@ const (
// CreateCaptainRecord creates the captain record for the hold (first-time only).
// This will FAIL if the captain record already exists. Use UpdateCaptainRecord to modify.
func (p *HoldPDS) CreateCaptainRecord(ctx context.Context, ownerDID string, public bool, allowAllCrew bool, enableBlueskyPosts bool) (cid.Cid, error) {
captainRecord := &atproto.HoldCaptain{
LexiconTypeID: atproto.CaptainCollection,
captainRecord := &atproto.CaptainRecord{
Type: atproto.CaptainCollection,
Owner: ownerDID,
Public: public,
AllowAllCrew: allowAllCrew,
@@ -40,7 +40,7 @@ func (p *HoldPDS) CreateCaptainRecord(ctx context.Context, ownerDID string, publ
}
// GetCaptainRecord retrieves the captain record
func (p *HoldPDS) GetCaptainRecord(ctx context.Context) (cid.Cid, *atproto.HoldCaptain, error) {
func (p *HoldPDS) GetCaptainRecord(ctx context.Context) (cid.Cid, *atproto.CaptainRecord, error) {
// Use repomgr.GetRecord - our types are registered in init()
// so it will automatically unmarshal to the concrete type
recordCID, val, err := p.repomgr.GetRecord(ctx, p.uid, atproto.CaptainCollection, CaptainRkey, cid.Undef)
@@ -49,7 +49,7 @@ func (p *HoldPDS) GetCaptainRecord(ctx context.Context) (cid.Cid, *atproto.HoldC
}
// Type assert to our concrete type
captainRecord, ok := val.(*atproto.HoldCaptain)
captainRecord, ok := val.(*atproto.CaptainRecord)
if !ok {
return cid.Undef, nil, fmt.Errorf("unexpected type for captain record: %T", val)
}

View File

@@ -12,11 +12,6 @@ import (
"atcr.io/pkg/atproto"
)
// ptrString returns a pointer to the given string
func ptrString(s string) *string {
return &s
}
// setupTestPDS creates a test PDS instance in a temporary directory
// It initializes the repo but does NOT create captain/crew records
// Tests should call Bootstrap or create records as needed
@@ -151,8 +146,8 @@ func TestCreateCaptainRecord(t *testing.T) {
if captain.EnableBlueskyPosts != tt.enableBlueskyPosts {
t.Errorf("Expected enableBlueskyPosts=%v, got %v", tt.enableBlueskyPosts, captain.EnableBlueskyPosts)
}
if captain.LexiconTypeID != atproto.CaptainCollection {
t.Errorf("Expected type %s, got %s", atproto.CaptainCollection, captain.LexiconTypeID)
if captain.Type != atproto.CaptainCollection {
t.Errorf("Expected type %s, got %s", atproto.CaptainCollection, captain.Type)
}
if captain.DeployedAt == "" {
t.Error("Expected deployedAt to be set")
@@ -327,40 +322,40 @@ func TestUpdateCaptainRecord_NotFound(t *testing.T) {
func TestCaptainRecord_CBORRoundtrip(t *testing.T) {
tests := []struct {
name string
record *atproto.HoldCaptain
record *atproto.CaptainRecord
}{
{
name: "Basic captain",
record: &atproto.HoldCaptain{
LexiconTypeID: atproto.CaptainCollection,
Owner: "did:plc:alice123",
Public: true,
AllowAllCrew: false,
DeployedAt: "2025-10-16T12:00:00Z",
record: &atproto.CaptainRecord{
Type: atproto.CaptainCollection,
Owner: "did:plc:alice123",
Public: true,
AllowAllCrew: false,
DeployedAt: "2025-10-16T12:00:00Z",
},
},
{
name: "Captain with optional fields",
record: &atproto.HoldCaptain{
LexiconTypeID: atproto.CaptainCollection,
Owner: "did:plc:bob456",
Public: false,
AllowAllCrew: true,
DeployedAt: "2025-10-16T12:00:00Z",
Region: ptrString("us-west-2"),
Provider: ptrString("fly.io"),
record: &atproto.CaptainRecord{
Type: atproto.CaptainCollection,
Owner: "did:plc:bob456",
Public: false,
AllowAllCrew: true,
DeployedAt: "2025-10-16T12:00:00Z",
Region: "us-west-2",
Provider: "fly.io",
},
},
{
name: "Captain with empty optional fields",
record: &atproto.HoldCaptain{
LexiconTypeID: atproto.CaptainCollection,
Owner: "did:plc:charlie789",
Public: true,
AllowAllCrew: true,
DeployedAt: "2025-10-16T12:00:00Z",
Region: ptrString(""),
Provider: ptrString(""),
record: &atproto.CaptainRecord{
Type: atproto.CaptainCollection,
Owner: "did:plc:charlie789",
Public: true,
AllowAllCrew: true,
DeployedAt: "2025-10-16T12:00:00Z",
Region: "",
Provider: "",
},
},
}
@@ -380,15 +375,15 @@ func TestCaptainRecord_CBORRoundtrip(t *testing.T) {
}
// Unmarshal from CBOR
var decoded atproto.HoldCaptain
var decoded atproto.CaptainRecord
err = decoded.UnmarshalCBOR(bytes.NewReader(cborBytes))
if err != nil {
t.Fatalf("UnmarshalCBOR failed: %v", err)
}
// Verify all fields match
if decoded.LexiconTypeID != tt.record.LexiconTypeID {
t.Errorf("LexiconTypeID mismatch: expected %s, got %s", tt.record.LexiconTypeID, decoded.LexiconTypeID)
if decoded.Type != tt.record.Type {
t.Errorf("Type mismatch: expected %s, got %s", tt.record.Type, decoded.Type)
}
if decoded.Owner != tt.record.Owner {
t.Errorf("Owner mismatch: expected %s, got %s", tt.record.Owner, decoded.Owner)
@@ -402,17 +397,11 @@ func TestCaptainRecord_CBORRoundtrip(t *testing.T) {
if decoded.DeployedAt != tt.record.DeployedAt {
t.Errorf("DeployedAt mismatch: expected %s, got %s", tt.record.DeployedAt, decoded.DeployedAt)
}
// Compare Region pointers (may be nil)
if (decoded.Region == nil) != (tt.record.Region == nil) {
t.Errorf("Region nil mismatch: expected %v, got %v", tt.record.Region, decoded.Region)
} else if decoded.Region != nil && *decoded.Region != *tt.record.Region {
t.Errorf("Region mismatch: expected %q, got %q", *tt.record.Region, *decoded.Region)
if decoded.Region != tt.record.Region {
t.Errorf("Region mismatch: expected %s, got %s", tt.record.Region, decoded.Region)
}
// Compare Provider pointers (may be nil)
if (decoded.Provider == nil) != (tt.record.Provider == nil) {
t.Errorf("Provider nil mismatch: expected %v, got %v", tt.record.Provider, decoded.Provider)
} else if decoded.Provider != nil && *decoded.Provider != *tt.record.Provider {
t.Errorf("Provider mismatch: expected %q, got %q", *tt.record.Provider, *decoded.Provider)
if decoded.Provider != tt.record.Provider {
t.Errorf("Provider mismatch: expected %s, got %s", tt.record.Provider, decoded.Provider)
}
})
}

Some files were not shown because too many files have changed in this diff Show More